Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
39c45bc
docs, type hints,
Pwhsky Jul 24, 2025
e928b87
types
Pwhsky Jul 24, 2025
44f3f72
docs
Pwhsky Jul 28, 2025
9385fb3
torch compatible
Pwhsky Jul 28, 2025
5123938
added averagepooling for torch
Pwhsky Jul 28, 2025
8d1763f
u
Pwhsky Jul 28, 2025
30e4c3c
unit test for avgpooling
Pwhsky Jul 28, 2025
e6b94e2
parenthesis
Pwhsky Jul 28, 2025
a88966d
input image
Pwhsky Jul 28, 2025
ecf5efc
capital
Pwhsky Jul 28, 2025
937dc39
u
Pwhsky Jul 28, 2025
eac7e7a
u
Pwhsky Jul 28, 2025
d71e476
u
Pwhsky Jul 28, 2025
4771183
u
Pwhsky Jul 28, 2025
fbd36c9
u
Pwhsky Jul 28, 2025
bb107e0
u
Pwhsky Jul 28, 2025
958d8b4
u
Pwhsky Jul 28, 2025
10097d9
u
Pwhsky Jul 28, 2025
2f030c3
u
Pwhsky Jul 28, 2025
1d872c0
u
Pwhsky Jul 28, 2025
3e0cee9
u
Pwhsky Jul 28, 2025
c8f6fc0
u
Pwhsky Jul 28, 2025
981a7cd
u
Pwhsky Jul 28, 2025
46fb873
u
Pwhsky Jul 28, 2025
3c11ccd
Update math.py
Pwhsky Jul 29, 2025
a411bbf
avg pooling with torch and numpy independent
Pwhsky Jul 29, 2025
6f27463
docs for averagepooling, unit tests todo
Pwhsky Jul 29, 2025
00fa795
self.get_backend
Pwhsky Jul 30, 2025
a1e839c
separated numpy and torch unit tests
Pwhsky Jul 30, 2025
fe6b1bc
u
Pwhsky Jul 30, 2025
5b94e5a
attempt at fixing broken optics
Pwhsky Jul 30, 2025
2332197
u
Pwhsky Jul 30, 2025
3c61159
undo fix
Pwhsky Jul 30, 2025
a249737
no kwargs in torch pool
Pwhsky Jul 30, 2025
c8d2fb8
docs typo
Pwhsky Jul 30, 2025
9207126
docs
Pwhsky Jul 30, 2025
288d944
Added shape handling for len(dim) == 2
Pwhsky Aug 6, 2025
43c4247
xp tests for avgpool
Pwhsky Sep 3, 2025
447aff0
Update math.py
Pwhsky Sep 5, 2025
ebbc05c
Merge branch 'develop' into AL/math/pool
Pwhsky Sep 5, 2025
070d63d
Merge branch 'develop' into AL/math/pool
Pwhsky Sep 5, 2025
69694e2
Formatting for AveragePooling
Pwhsky Sep 6, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 137 additions & 29 deletions deeptrack/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -1069,24 +1069,23 @@ def __init__(
super().__init__(ndimage.median_filter, size=ksize, **kwargs)


#TODO ***AL*** revise Pool - torch, typing, docstring, unit test
class Pool(Feature):
class Pool(Feature): # Deprecated, children will be independent in the future.
"""Downsamples the image by applying a function to local regions of the
image.

This class reduces the resolution of an image by dividing it into
non-overlapping blocks of size `ksize` and applying the specified pooling
function to each block. The result is a downsampled image where each pixel
value represents the result of the pooling function applied to the
corresponding block.
corresponding block. This pooling only works with numpy functions.

Parameters
----------
pooling_function: function
pooling_function: Numpy function
A function that is applied to each local region of the image.
DOES NOT NEED TO BE WRAPPED IN ANOTHER FUNCTION.
The `pooling_function` must accept the input image as a keyword argument
named `input`, as it is called via `utils.safe_call`.
The `pooling_function` must accept the input image as a keyword
argument named `input`, as it is called via `utils.safe_call`.
Examples include `np.mean`, `np.max`, `np.min`, etc.
ksize: int
Size of the pooling kernel.
Expand All @@ -1095,7 +1094,8 @@ class Pool(Feature):

Methods
-------
`get(image: np.ndarray | Image, ksize: int, **kwargs: Any) --> np.ndarray`
`get(image: NDArray,
ksize: int, **kwargs: Any) --> NDArray`
Applies the pooling function to the input image.

Examples
Expand Down Expand Up @@ -1152,17 +1152,17 @@ def __init__(

def get(
self: Pool,
image: np.ndarray | Image,
image: NDArray,
ksize: int,
**kwargs: Any,
) -> np.ndarray:
) -> NDArray:
"""Applies the pooling function to the input image.

This method applies the pooling function to the input image.
This method applies `pooling_function` to the input image.

Parameters
----------
image: np.ndarray
image: NDArray | torch.Tensor
The input image to pool.
ksize: int
Size of the pooling kernel.
Expand All @@ -1171,7 +1171,7 @@ def get(

Returns
-------
np.ndarray
NDArray | torch.Tensor
The pooled image.

"""
Expand All @@ -1188,54 +1188,54 @@ def get(
)


#TODO ***AL*** revise AveragePooling - torch, typing, docstring, unit test
class AveragePooling(Pool):
"""Apply average pooling to an image.

This class reduces the resolution of an image by dividing it into
non-overlapping blocks of size `ksize` and applying the average function to
each block. The result is a downsampled image where each pixel value
`AveragePooling` reduces the resolution of an image by dividing it into
non-overlapping blocks of size `ksize` and applying the `average` function
to each block. The result is a downsampled image where each pixel value
represents the average value within the corresponding block of the
original image.
original image. This is useful for reducing the size of an image while
retaining the most significant features.

If the backend is NumPy, the downsampling is performed using
`skimage.measure.block_reduce`. If the backend is PyTorch, it is performed
using `torch.nn.functional.avg_pool2d`.

Parameters
----------
ksize: int
Size of the pooling kernel.
**kwargs: dict
**kwargs: Any
Additional parameters sent to the pooling function.

Examples
--------
>>> import deeptrack as dt
>>> import numpy as np

Create an input image:
>>> import numpy as np
>>>
>>> input_image = np.random.rand(32, 32)

Define an average pooling feature:
Define and use an average-pooling feature:

>>> average_pooling = dt.AveragePooling(ksize=4)
>>> output_image = average_pooling(input_image)
>>> print(output_image.shape)
(8, 8)

Notes
-----
Calling this feature returns a `np.ndarray` by default. If
`store_properties` is set to `True`, the returned array will be
automatically wrapped in an `Image` object. This behavior is handled
internally and does not affect the return type of the `get()` method.

"""

def __init__(
self: Pool,
self: AveragePooling,
ksize: PropertyLike[int] = 3,
**kwargs: Any,
):
"""Initialize the parameters for average pooling.

This constructor initializes the parameters for average pooling.
This constructor initializes the parameters for average-pooling.

Parameters
----------
Expand All @@ -1248,6 +1248,114 @@ def __init__(

super().__init__(np.mean, ksize=ksize, **kwargs)

def get(
    self: AveragePooling,
    image: NDArray[Any] | torch.Tensor,
    ksize: int = 3,
    **kwargs: Any,
) -> NDArray[Any] | torch.Tensor:
    """Average pooling of input.

    Queries the current backend once via `self.get_backend()` and
    dispatches to `._get_numpy()` for "numpy" or `._get_torch()` for
    "torch".

    Parameters
    ----------
    image: array or tensor
        Input array or tensor to be pooled.
    ksize: int
        Kernel size of the pooling operation.
    **kwargs: Any
        Additional keyword arguments forwarded unchanged to the
        backend-specific method.

    Returns
    -------
    array or tensor
        The pooled input as `NDArray` or `torch.Tensor` depending on
        the backend.

    Raises
    ------
    NotImplementedError
        If the backend is neither "numpy" nor "torch".

    """

    # Read the backend a single time so that the value used for dispatch
    # is exactly the value reported in the error message below.
    backend = self.get_backend()

    if backend == "numpy":
        return self._get_numpy(image, ksize, **kwargs)

    if backend == "torch":
        return self._get_torch(image, ksize, **kwargs)

    raise NotImplementedError(f"Backend {backend} not supported")

def _get_numpy(
    self: AveragePooling,
    image: NDArray[Any],
    ksize: int = 3,
    **kwargs: Any,
) -> NDArray[Any]:
    """Average pooling with the NumPy backend enabled.

    Hands the image to scikit-image's `block_reduce()` through
    `utils.safe_call`, with `np.average` as the reduction applied to
    each block and `ksize` as the block size.

    Parameters
    ----------
    image: NDArray
        Input array to be pooled.
    ksize: int
        Kernel size of the pooling operation, passed to `block_reduce()`
        as `block_size`.
    **kwargs: Any
        Additional keyword arguments passed to `utils.safe_call` together
        with `image`, `func`, and `block_size`.

    Returns
    -------
    array
        The pooled image as a NumPy array.

    """

    block_reducer = skimage.measure.block_reduce

    return utils.safe_call(
        block_reducer,
        image=image,
        func=np.average,
        block_size=ksize,
        **kwargs,
    )

def _get_torch(
self: AveragePooling,
image: torch.Tensor,
ksize: int=3,
**kwargs: Any,
) -> torch.Tensor:
"""Average pooling with the PyTorch backend enabled.

Returns the result of the image passed to a Pytorch average pooling
layer.

Parameters
----------
image: torch.Tensor
Input tensor to be pooled.
ksize: int
Kernel size of the pooling operation.

Returns
-------
torch.Tensor
The pooled image as a `torch.Tensor`.

"""

# If input tensor is 2D
if len(image.shape) == 2:
# Add batch dimension for max pooling.
expanded_image = image.unsqueeze(0)

pooled_image = torch.nn.functional.avg_pool2d(
expanded_image, kernel_size=ksize,
)
# Remove the expanded dim.
return pooled_image.squeeze(0)

return torch.nn.functional.avg_pool2d(
image,
kernel_size=ksize,
)


class MaxPooling(Pool):
"""Apply max-pooling to images.
Expand Down
15 changes: 13 additions & 2 deletions deeptrack/tests/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,17 @@ def test_Blur(self):
#self.assertTrue(xp.all(blurred_image == expected_output))



def test_AveragePooling(self):
    # A 2x4 input pooled with ksize=2 should collapse to a single row
    # holding the mean of each 2x2 block: (1+2+5+6)/4 and (3+4+7+8)/4.
    pooling = math.AveragePooling(ksize=2)
    source = xp.asarray([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=float)
    result = pooling.resolve(source)

    target = xp.asarray([[3.5, 5.5]])

    # Check the values first, then the shape of the pooled output.
    self.assertTrue(xp.all(result == target))
    self.assertEqual(result.shape, (1, 2))

def test_MaxPooling(self):
input_image = xp.asarray([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=float)
feature = math.MaxPooling(ksize=2)
Expand All @@ -109,8 +120,7 @@ def test_MinPooling(self):
class TestMath_Torch(TestMath_Numpy):
    # Inherits every test from TestMath_Numpy; only the BACKEND class
    # attribute is overridden (presumably read by the base class to
    # select the backend under test - confirm against TestMath_Numpy).
    BACKEND = "torch"



class TestMath(unittest.TestCase):

def test_GaussianBlur(self):
Expand All @@ -130,6 +140,7 @@ def test_AveragePooling(self):
pooled_image = feature.resolve(input_image)
self.assertTrue(np.all(pooled_image == [[3.5, 5.5]]))


def test_MaxPooling(self):
input_image = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
feature = math.MaxPooling(ksize=2)
Expand Down
Loading