diff --git a/deeptrack/__init__.py b/deeptrack/__init__.py index 189550bb5..daf7ada36 100644 --- a/deeptrack/__init__.py +++ b/deeptrack/__init__.py @@ -24,7 +24,7 @@ # Create a unit registry with custom pixel-related units. units_registry = UnitRegistry(pint_definitions.split("\n")) - +units = units_registry # Alias for backward compatibility from deeptrack.backend import * diff --git a/deeptrack/features.py b/deeptrack/features.py index 43e809612..b2edb1024 100644 --- a/deeptrack/features.py +++ b/deeptrack/features.py @@ -8231,10 +8231,16 @@ def get( If the input `factor` is not a valid integer or tuple of integers. """ - + # TBE: this seems to create an issue with image normalization when + # only one number is give. IT automatically replicate this value in the + # 3D but pooling is actually only done in 2D. I suggest if only `factor` + # is given to transform it into (facto, factor, 1) by default. + # This should also ensure backcompatibility. + # Ensure factor is a tuple of three integers. if np.size(factor) == 1: - factor = (factor,) * 3 + # factor = (factor,) * 3 + factor = (factor, factor, 1) elif len(factor) != 3: raise ValueError( "Factor must be an integer or a tuple of three integers." @@ -8245,12 +8251,16 @@ def get( with units.context(ctx): image = self.feature(image) - # Downscale the result to the original resolution. - import skimage.measure - image = skimage.measure.block_reduce( - image, (factor[0], factor[1]) + (1,) * (image.ndim - 2), np.mean - ) + # NOTE: The downscaling step is disabled and taken care in + # deeptrack.optics since it now depends on scatter.main_property + + # # Downscale the result to the original resolution. + # import skimage.measure + + # image = skimage.measure.block_reduce( + # image, (factor[0], factor[1]) + (1,) * (image.ndim - 2), np.mean + # ) return image @@ -9707,4 +9717,4 @@ def get( if len(res) == 1: res = res[0] - return res + return res \ No newline at end of file diff --git a/deeptrack/optics.py b/deeptrack/optics.py index ab2ff4203..ef7c6b7b9 100644 --- a/deeptrack/optics.py +++ b/deeptrack/optics.py @@ -142,7 +142,9 @@ def _pad_volume( import numpy as np from numpy.typing import NDArray -from scipy.ndimage import convolve +from scipy.ndimage import convolve #check if still necessary +import torch +import torch.nn.functional as F from deeptrack.backend.units import ( ConversionTable, @@ -150,7 +152,7 @@ def _pad_volume( get_active_scale, get_active_voxel_size, ) -from deeptrack.math import AveragePooling +# from deeptrack.math import AveragePooling from deeptrack.features import propagate_data_to_dependencies from deeptrack.features import DummyFeature, Feature, StructuralFeature from deeptrack.image import Image, pad_image_to_fft @@ -159,9 +161,15 @@ def _pad_volume( from deeptrack import image from deeptrack import units_registry as u -if TYPE_CHECKING: +from deeptrack import TORCH_AVAILABLE, image +from deeptrack.backend import xp +from deeptrack.scatterers import ScatteredVolume, ScatteredField + +if TORCH_AVAILABLE: import torch +if TYPE_CHECKING: + import torch #TODO ***??*** revise Microscope - torch, typing, docstring, unit test class Microscope(StructuralFeature): @@ -242,11 +250,13 @@ def __init__( super().__init__(**kwargs) + print(">>> creating Microscope", type(sample), type(objective)) self._sample = self.add_feature(sample) self._objective = self.add_feature(objective) #TODO: erase following line when rid of Image - self._sample.store_properties() + # self._sample.store_properties() + print(">>> creating Microscope", 
type(self._sample)) def get( self: Microscope, @@ -302,7 +312,11 @@ def get( with u.context(create_context(*objective_properties["voxel_size"])): + # Following code does nothing is upscale is (1, 1, 1). + # It is needed if dt.Upscale is used + upscale = np.round(get_active_scale()) + print(">>> upscale", upscale) def _scale_region_2d( region: list[int], @@ -354,7 +368,8 @@ def _scale_region_2d( volume_scatterers = [ scatterer for scatterer in list_of_scatterers - if not scatterer.get_property("is_field", default=False) + if isinstance(scatterer, ScatteredVolume) + # if not scatterer.get_property("is_field", default=False) ] # All scatterers that are defined as fields. @@ -363,16 +378,10 @@ def _scale_region_2d( field_scatterers = [ scatterer for scatterer in list_of_scatterers - if scatterer.get_property("is_field", default=False) + if isinstance(scatterer, ScatteredField) + # if scatterer.get_property("is_field", default=False) ] - - - - - - - # Merge all volumes into a single volume. sample_volume, limits = _create_volume( volume_scatterers, @@ -393,6 +402,12 @@ def _scale_region_2d( imaged_sample = self._objective.resolve(sample_volume) + + # Handling upscale from dt.Upscale() here to eliminate Image + # wrapping issues. + if np.any(np.array(upscale) != 1): + imaged_sample = _downscale_scatterer(imaged_sample, upscale[:2], scatterer.main_property) + #TODO: TBE """ # Handling separately upscale given by optics. @@ -540,7 +555,7 @@ def __init__( output_region: PropertyLike[ArrayLike[int]] = (0, 0, 128, 128), pupil: Feature = None, illumination: Feature = None, - upscale: int = 1, + upscale: int = 1, # to be deprecated in favor of dt.Upscale() **kwargs: Any, ): """Initialize the `Optics` instance. @@ -964,7 +979,8 @@ def __call__( True """ - from deeptrack.scatterers import MieScatterer # Temporary place for this import. + from deeptrack.scatterers import \ + MieScatterer # Temporary place for this import. 
if isinstance(self, (Darkfield, ISCAT, Holography)) and not isinstance(sample, MieScatterer): warnings.warn( @@ -1845,7 +1861,7 @@ def get( #TODO ***??*** revise _get_position - torch, typing, docstring, unit test def _get_position( - image: Image, + scatterer: ScatteredVolume, mode: str = "corner", return_z: bool = False, ) -> np.ndarray: @@ -1868,38 +1884,38 @@ def _get_position( """ num_outputs = 2 + return_z - - if mode == "corner" and image.size > 0: + if mode == "corner" and scatterer.array.size > 0: import scipy.ndimage - image = image.to_numpy() - - shift = scipy.ndimage.center_of_mass(np.abs(image)) + shift = scipy.ndimage.center_of_mass(np.abs(scatterer.array)) if np.isnan(shift).any(): - shift = np.array(image.shape) / 2 + shift = np.array(scatterer.array.shape) / 2 else: shift = np.zeros((num_outputs)) - position = np.array(image.get_property("position", default=None)) + position = np.array(scatterer.get_property("position", default=None)) if position is None: return position scale = np.array(get_active_scale()) - + print("image size", scatterer.get_property("z", default=0)) + print(scale) + print(shift) + print(np.array([position[0], position[1], scatterer.get_property("z", default=0)])) if len(position) == 3: position = position * scale + 0.5 * (scale - 1) if return_z: return position * scale - shift else: return position[0:2] - shift[0:2] - + elif len(position) == 2: if return_z: outp = ( - np.array([position[0], position[1], image.get_property("z", default=0)]) + np.array([position[0], position[1], scatterer.get_property("z", default=0)]) * scale - shift + 0.5 * (scale - 1) @@ -1911,40 +1927,112 @@ def _get_position( return position +def _bilinear_interpolate_numpy( + scatterer: np.ndarray, x_off: float, y_off: float +) -> np.ndarray: + """Apply bilinear subpixel interpolation in the x–y plane (NumPy).""" + kernel = np.array( + [ + [0.0, 0.0, 0.0], + [0.0, (1 - x_off) * (1 - y_off), (1 - x_off) * y_off], + [0.0, x_off * (1 - y_off), x_off * y_off], + ] + ) + out = np.zeros_like(scatterer) + for z in range(scatterer.shape[2]): + if np.iscomplexobj(scatterer): + out[:, :, z] = ( + convolve(np.real(scatterer[:, :, z]), kernel, mode="constant") + + 1j + * convolve(np.imag(scatterer[:, :, z]), kernel, mode="constant") + ) + else: + out[:, :, z] = convolve(scatterer[:, :, z], kernel, mode="constant") + return out + + +def _bilinear_interpolate_torch( + scatterer: torch.Tensor, x_off: float, y_off: float +) -> torch.Tensor: + """Apply bilinear subpixel interpolation in the x–y plane (Torch). + + Uses grid_sample for autograd-friendly interpolation. 
+ """ + H, W, D = scatterer.shape + + # Normalized shifts in [-1,1] + x_shift = 2 * x_off / (W - 1) + y_shift = 2 * y_off / (H - 1) + + yy, xx = torch.meshgrid( + torch.linspace(-1, 1, H, device=scatterer.device, dtype=scatterer.dtype), + torch.linspace(-1, 1, W, device=scatterer.device, dtype=scatterer.dtype), + indexing="ij", + ) + grid = torch.stack((xx + x_shift, yy + y_shift), dim=-1) # (H,W,2) + grid = grid.unsqueeze(0).repeat(D, 1, 1, 1) # (D,H,W,2) + + inp = scatterer.permute(2, 0, 1).unsqueeze(1) # (D,1,H,W) + + out = F.grid_sample(inp, grid, mode="bilinear", + padding_mode="zeros", align_corners=True) + return out.squeeze(1).permute(1, 2, 0) # (H,W,D) + + #TODO ***??*** revise _create_volume - torch, typing, docstring, unit test def _create_volume( - list_of_scatterers: list, - pad: tuple = (0, 0, 0, 0), - output_region: tuple = (None, None, None, None), + list_of_scatterers: ArrayLike | Sequence[ArrayLike], + pad: tuple[int, int, int, int] = (0, 0, 0, 0), + output_region: tuple[int | None, int | None, int | None, int | None] = (None, None, None, None), refractive_index_medium: float = 1.33, **kwargs: Any, -) -> tuple: - """Converts a list of scatterers into a volumetric representation. +) -> tuple[ArrayLike, np.ndarray]: + """Assemble a volumetric representation from a list of scatterers. + + Each scatterer is represented as an ND array (numpy or torch), with + associated properties such as ``position``, ``intensity``, + ``refractive_index``, or ``value``. Scatterers are inserted into a common + 3D volume with optional padding, output region cropping, and subpixel + interpolation. Parameters ---------- - list_of_scatterers: list or single scatterer - List of scatterers to include in the volume. - pad: tuple of int, optional - Padding for the volume in the format (left, right, top, bottom). - Default is (0, 0, 0, 0). - output_region: tuple of int, optional - Region to output, defined as (x_min, y_min, x_max, y_max). Default is - None. - refractive_index_medium: float, optional - Refractive index of the medium surrounding the scatterers. Default is - 1.33. - **kwargs: Any - Additional arguments for customization. + list_of_scatterers : ArrayLike or Sequence[ArrayLike] + Single scatterer or sequence of scatterers to include in the volume. + Each scatterer must be an ``ndarray`` or ``torch.Tensor`` with + shape ``(nx, ny, nz)`` (or compatible) and carry a ``.properties`` + dictionary including at least ``"position"``. + pad : tuple of int, optional + Padding for the volume in the format ``(left, right, top, bottom)``. + Default is ``(0, 0, 0, 0)``. + output_region : tuple of int or None, optional + Region to output, defined as ``(x_min, y_min, x_max, y_max)``. + Default is ``(None, None, None, None)``, meaning unbounded. + refractive_index_medium : float, optional + Refractive index of the surrounding medium. Default is ``1.33``. + **kwargs : Any + Additional keyword arguments for customization. Returns ------- - tuple - - volume: numpy.ndarray - The generated volume containing the scatterers. - - limits: numpy.ndarray - Spatial limits of the volume. - + volume : ArrayLike + The generated 3D volume containing all scatterers. + Type matches the input backend: ``numpy.ndarray`` or ``torch.Tensor``. + limits : numpy.ndarray of shape (3, 2) + Spatial limits of the volume along each axis, as integers. + ``limits[:, 0]`` are the minima, ``limits[:, 1]`` the maxima. + + Raises + ------ + UserWarning + If a scatterer does not define a valid ``position`` property. 
+ + Notes + ----- + - Subpixel positioning is handled by bilinear interpolation + (NumPy: convolution, Torch: grid_sample). + - Overlapping scatterers are **added** together in the volume. + """ if not isinstance(list_of_scatterers, list): @@ -1953,7 +2041,7 @@ def _create_volume( volume = np.zeros((1, 1, 1), dtype=complex) limits = None OR = np.zeros((4,)) - OR[0] = np.inf if output_region[0] is None else int( + OR[0] =-np.inf if output_region[0] is None else int( output_region[0] - pad[0] ) OR[1] = -np.inf if output_region[1] is None else int( @@ -1962,7 +2050,7 @@ def _create_volume( OR[2] = np.inf if output_region[2] is None else int( output_region[2] + pad[2] ) - OR[3] = -np.inf if output_region[3] is None else int( + OR[3] = np.inf if output_region[3] is None else int( output_region[3] + pad[3] ) @@ -1970,24 +2058,24 @@ def _create_volume( # This accounts for upscale doing AveragePool instead of SumPool. This is # a bit of a hack, but it works for now. - fudge_factor = scale[0] * scale[1] / scale[2] + # fudge_factor = scale[0] * scale[1] / scale[2] for scatterer in list_of_scatterers: + print(">>> scatterer type:", type(scatterer)) + # print(">>> scatterer properties:", scatterer.get_property("radius", None)) + position = _get_position(scatterer, mode="corner", return_z=True) - if scatterer.get_property("intensity", None) is not None: - intensity = scatterer.get_property("intensity") - scatterer_value = intensity * fudge_factor - elif scatterer.get_property("refractive_index", None) is not None: - refractive_index = scatterer.get_property("refractive_index") - scatterer_value = ( - refractive_index - refractive_index_medium - ) - else: + if scatterer.main_property == "intensity": + scatterer_value = scatterer.get_property("intensity") #* fudge_factor + elif scatterer.main_property == "refractive_index": + scatterer_value = scatterer.get_property("refractive_index") - refractive_index_medium + else: # fallback to generic value scatterer_value = scatterer.get_property("value") - scatterer = scatterer * scatterer_value + # Scale the array accordingly + scatterer.array = scatterer.array * scatterer_value if limits is None: limits = np.zeros((3, 2), dtype=np.int32) @@ -1995,26 +2083,26 @@ def _create_volume( limits[:, 1] = np.floor(position).astype(np.int32) + 1 if ( - position[0] + scatterer.shape[0] < OR[0] + position[0] + scatterer.array.shape[0] < OR[0] or position[0] > OR[2] - or position[1] + scatterer.shape[1] < OR[1] + or position[1] + scatterer.array.shape[1] < OR[1] or position[1] > OR[3] ): continue - padded_scatterer = Image( - np.pad( - scatterer, - [(2, 2), (2, 2), (2, 2)], - "constant", - constant_values=0, - ) + # at this point properties do not seem any longer necessary + # just keep the array part + padded_scatterer = np.pad( + scatterer.array, + [(2, 2), (2, 2), (2, 2)], + "constant", + constant_values=0, ) - padded_scatterer.merge_properties_from(scatterer) + # padded_scatterer.merge_properties_from(scatterer) - scatterer = padded_scatterer + # scatterer = padded_scatterer position = _get_position(scatterer, mode="corner", return_z=True) - shape = np.array(scatterer.shape) + shape = np.array(padded_scatterer.shape) if position is None: RuntimeWarning( @@ -2023,37 +2111,22 @@ def _create_volume( ) continue - splined_scatterer = np.zeros_like(scatterer) + # splined_scatterer = np.zeros_like(padded_scatterer) x_off = position[0] - np.floor(position[0]) y_off = position[1] - np.floor(position[1]) + + if isinstance(padded_scatterer, np.ndarray): # get_backend is a method of 
Features and not exposed + splined_scatterer = _bilinear_interpolate_numpy(padded_scatterer, x_off, y_off) + elif isinstance(padded_scatterer, torch.Tensor): + splined_scatterer = _bilinear_interpolate_torch(padded_scatterer, x_off, y_off) + else: + raise TypeError( + f"Unsupported array type {type(padded_scatterer)}. " + "Expected np.ndarray or torch.Tensor." + ) - kernel = np.array( - [ - [0, 0, 0], - [0, (1 - x_off) * (1 - y_off), (1 - x_off) * y_off], - [0, x_off * (1 - y_off), x_off * y_off], - ] - ) - - for z in range(scatterer.shape[2]): - if splined_scatterer.dtype == complex: - splined_scatterer[:, :, z] = ( - convolve( - np.real(scatterer[:, :, z]), kernel, mode="constant" - ) - + convolve( - np.imag(scatterer[:, :, z]), kernel, mode="constant" - ) - * 1j - ) - else: - splined_scatterer[:, :, z] = convolve( - scatterer[:, :, z], kernel, mode="constant" - ) - - scatterer = splined_scatterer - position = np.floor(position) + position = np.floor(position) # check or change name, this is position on the grid new_limits = np.zeros(limits.shape, dtype=np.int32) for i in range(3): new_limits[i, :] = ( @@ -2081,7 +2154,8 @@ def _create_volume( within_volume_position = position - limits[:, 0] - # NOTE: Maybe shouldn't be additive. + # NOTE: Maybe shouldn't be additive + # give options: sum default, but also sum, mean, max, min volume[ int(within_volume_position[0]) : int(within_volume_position[0] + shape[0]), @@ -2091,5 +2165,69 @@ def _create_volume( int(within_volume_position[2]) : int(within_volume_position[2] + shape[2]), - ] += scatterer + ] += splined_scatterer return volume, limits + +# TODO: replace the inner part of this function with AveragePooling from math when +# implemented with torch +def _downscale_scatterer(array, factor, main_property="value"): + """Downscale scatterer array by sum or average pooling depending on property. + + Parameters + ---------- + array : np.ndarray or torch.Tensor + The scatterer array. + factor : tuple[int, int] + Downscale factor (ux, uy). + main_property : str + Determines pooling strategy: + - "intensity" → sum pooling + - else → average pooling + + Returns + ------- + np.ndarray or torch.Tensor + Downscaled array in the same backend as input. 
+ """ + + # Decide pooling op + is_sum = main_property == "intensity" + + # Ensure factor is integer + factor = tuple(int(f) for f in factor) + + + # Case 1: NumPy backend + if isinstance(array, np.ndarray): + import skimage.measure + + pool_shape = (factor[0], factor[1]) + (1,) * (array.ndim - 2) + func = np.sum if is_sum else np.mean + return skimage.measure.block_reduce(array, pool_shape, func) + + # Case 2: Torch backend + elif isinstance(array, torch.Tensor): + if array.ndim < 2: + raise ValueError("Torch pooling requires at least 2D input.") + + # spatial pooling only + kernel = (factor[0], factor[1]) + stride = kernel + + # Flatten extra dims as channels for pooling + b, c = 1, int(np.prod(array.shape[2:])) if array.ndim > 2 else 1 + h, w = array.shape[0], array.shape[1] + x = array.reshape(1, c, h, w) + + if is_sum: + pooled = torch.nn.functional.avg_pool2d(x, kernel, stride) * (kernel[0] * kernel[1]) + else: + pooled = torch.nn.functional.avg_pool2d(x, kernel, stride) + + # Reshape back + new_shape = (pooled.shape[2], pooled.shape[3]) + tuple(array.shape[2:]) + return pooled.reshape(new_shape) + + else: + raise TypeError("Unsupported array type: expected np.ndarray or torch.Tensor.") + diff --git a/deeptrack/scatterers.py b/deeptrack/scatterers.py index 04a7c5eae..b337aa2fc 100644 --- a/deeptrack/scatterers.py +++ b/deeptrack/scatterers.py @@ -166,6 +166,7 @@ import numpy as np from numpy.typing import NDArray from pint import Quantity +from dataclasses import dataclass, field from deeptrack.holography import get_propagation_matrix from deeptrack.backend.units import ( @@ -246,6 +247,9 @@ class Scatterer(Feature): voxel_size=(u.meter, u.meter), ) + #: Default property name (subclasses override this) + main_property: str = "value" + def __init__( self, position: ArrayLike[float] = (32, 32), @@ -310,7 +314,7 @@ def _process_and_get( voxel_size = get_active_voxel_size() # Calls parent _process_and_get. 
- new_image = super()._process_and_get( + new_image = super(Scatterer, self)._process_and_get( *args, voxel_size=voxel_size, upsample=upsample, @@ -333,32 +337,41 @@ def _process_and_get( new_image = new_image[:, ~np.all(new_image == 0, axis=(0, 2))] new_image = new_image[:, :, ~np.all(new_image == 0, axis=(0, 1))] - return [Image(new_image)] - - def _no_wrap_format_input( - self, - *args, - **kwargs - ) -> list: - return self._image_wrapped_format_input(*args, **kwargs) - - def _no_wrap_process_and_get( - self, - *args, - **feature_input - ) -> list: - return self._image_wrapped_process_and_get(*args, **feature_input) - - def _no_wrap_process_output( - self, - *args, - **feature_input - ) -> list: - return self._image_wrapped_process_output(*args, **feature_input) + # # Copy properties + # props = kwargs.copy() + return [self._wrap_output(new_image, kwargs)] + + def _wrap_output(self, array, props) -> ScatteredBase: + """Must be overridden in subclasses to wrap output correctly.""" + raise NotImplementedError + +class VolumeScatterer(Scatterer): + """Abstract scatterer producing ScatteredVolume outputs.""" + def _wrap_output(self, array, props) -> ScatteredVolume: + return [ScatteredVolume( + array=array, + position=props.get("position", (0, 0)), + z=props.get("z", 0.0), + value=props.get("value", 1.0), + intensity=props.get("intensity", None), + refractive_index=props.get("refractive_index", None), + properties=props.copy(), + main_property=self.main_property, + )] + +class FieldScatterer(Scatterer): + def _wrap_output(self, array, props) -> ScatteredField: + return ScatteredField( + array=array, + position=props.get("position", (0, 0)), + wavelength=props.get("wavelength", 0.0), + properties=props.copy(), + main_property=self.main_property, + ) #TODO ***??*** revise PointParticle - torch, typing, docstring, unit test -class PointParticle(Scatterer): +class PointParticle(VolumeScatterer): """Generate a diffraction-limited point particle. A point particle is approximated by the size of a single pixel or voxel. @@ -382,6 +395,8 @@ class PointParticle(Scatterer): """ + main_property = "intensity" + def __init__( self: PointParticle, **kwargs: Any, @@ -405,7 +420,7 @@ def get( #TODO ***??*** revise Ellipse - torch, typing, docstring, unit test -class Ellipse(Scatterer): +class Ellipse(VolumeScatterer): """Generates an elliptical disk scatterer Parameters @@ -446,6 +461,8 @@ class Ellipse(Scatterer): rotation=(u.radian, u.radian), ) + main_property = "refractive_index" + def __init__( self, radius: float = 1e-6, @@ -519,7 +536,7 @@ def get( #TODO ***??*** revise Sphere - torch, typing, docstring, unit test -class Sphere(Scatterer): +class Sphere(VolumeScatterer): """Generates a spherical scatterer Parameters @@ -550,6 +567,8 @@ class Sphere(Scatterer): radius=(u.meter, u.meter), ) + main_property = "refractive_index" + def __init__( self, radius: float = 1e-6, @@ -741,7 +760,7 @@ def get( #TODO ***??*** revise MieScatterer - torch, typing, docstring, unit test -class MieScatterer(Scatterer): +class MieScatterer(FieldScatterer): """Base implementation of a Mie particle. 
New Mie-theory scatterers can be implemented by extending this class, and @@ -864,11 +883,11 @@ def __init__( "Please use input_polarization instead" ) input_polarization = polarization_angle - kwargs.pop("is_field", None) + kwargs.pop("is_field", None) # remove kwargs.pop("crop_empty", None) super().__init__( - is_field=True, + is_field=True, # remove crop_empty=False, L=L, offset_z=offset_z, @@ -1412,3 +1431,53 @@ def inner( refractive_index=refractive_index, **kwargs, ) + + +@dataclass +class ScatteredBase: + """Base class for scatterers (volumes and fields).""" + + array: ArrayLike + position: np.ndarray + z: float = 0.0 + properties: dict[str, Any] = field(default_factory=dict) + main_property: str = None + + def __post_init__(self): + self.position = np.array(self.position, dtype=float).reshape(-1)[:2] + self.z = float(np.atleast_1d(self.z).squeeze()) + + @property + def pos3d(self) -> np.ndarray: + return np.array([*self.position, self.z], dtype=float) + + def as_array(self) -> ArrayLike: + """Return the underlying array. + + Notes + ----- + The raw array is also directly available as ``scatterer.array``. + This method exists mainly for API compatibility and clarity. + + """ + + return self.array + + def get_property(self, key: str, default: Any = None) -> Any: + return getattr(self, key, self.properties.get(key, default)) + + +@dataclass +class ScatteredVolume(ScatteredBase): + """Volumetric object: intensity sources or refractive index contrasts.""" + + refractive_index: float | None = None + intensity: float | None = None + value: float | None = None + + +@dataclass +class ScatteredField(ScatteredBase): + """Complex wavefield (already propagated or emitted).""" + + wavelength: float = 500e-9 \ No newline at end of file diff --git a/testing_code.ipynb b/testing_code.ipynb new file mode 100644 index 000000000..ea4c597f2 --- /dev/null +++ b/testing_code.ipynb @@ -0,0 +1,559 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "ea5fc72f", + "metadata": {}, + "outputs": [ + { + "ename": "ImportError", + "evalue": "cannot import name 'TORCH_AVAILABLE' from partially initialized module 'deeptrack' (most likely due to a circular import) (/Users/841602/Documents/GitHub/DeepTrack2/deeptrack/__init__.py)", + "output_type": "error", + "traceback": [ + "\u001b[31m---------------------------------------------------------------------------\u001b[39m", + "\u001b[31mImportError\u001b[39m Traceback (most recent call last)", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[1]\u001b[39m\u001b[32m, line 1\u001b[39m\n\u001b[32m----> \u001b[39m\u001b[32m1\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mdeeptrack\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mas\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mdt\u001b[39;00m\n\u001b[32m 2\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mnumpy\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mas\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mnp\u001b[39;00m\n\u001b[32m 3\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mmatplotlib\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mpyplot\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mas\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mplt\u001b[39;00m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/Documents/GitHub/DeepTrack2/deeptrack/__init__.py:29\u001b[39m\n\u001b[32m 25\u001b[39m \u001b[38;5;66;03m# Create a unit 
registry with custom pixel-related units.\u001b[39;00m\n\u001b[32m 26\u001b[39m units_registry = UnitRegistry(pint_definitions.split(\u001b[33m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[33m\"\u001b[39m))\n\u001b[32m---> \u001b[39m\u001b[32m29\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mdeeptrack\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01maberrations\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m *\n\u001b[32m 30\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mdeeptrack\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01maugmentations\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m *\n\u001b[32m 31\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mdeeptrack\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mbackend\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m *\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/Documents/GitHub/DeepTrack2/deeptrack/aberrations.py:84\u001b[39m\n\u001b[32m 80\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mtyping\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m Any\n\u001b[32m 82\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mnumpy\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mas\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mnp\u001b[39;00m\n\u001b[32m---> \u001b[39m\u001b[32m84\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mdeeptrack\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mfeatures\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m Feature\n\u001b[32m 85\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mdeeptrack\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mtypes\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m PropertyLike\n\u001b[32m 86\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mdeeptrack\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mutils\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m as_list\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/Documents/GitHub/DeepTrack2/deeptrack/features.py:173\u001b[39m\n\u001b[32m 171\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mdeeptrack\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mbackend\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m TORCH_AVAILABLE, config, xp\n\u001b[32m 172\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mdeeptrack\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mbackend\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mcore\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m DeepTrackNode\n\u001b[32m--> \u001b[39m\u001b[32m173\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mdeeptrack\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mbackend\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01munits\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m ConversionTable, create_context\n\u001b[32m 174\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m 
\u001b[39m\u001b[34;01mdeeptrack\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mimage\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m Image\n\u001b[32m 175\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mdeeptrack\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mproperties\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m PropertyDict, SequentialProperty\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/Documents/GitHub/DeepTrack2/deeptrack/backend/units.py:109\u001b[39m\n\u001b[32m 106\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mnumpy\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m ndarray\n\u001b[32m 107\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mpint\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m Context, Quantity, Unit\n\u001b[32m--> \u001b[39m\u001b[32m109\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mdeeptrack\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m TORCH_AVAILABLE\n\u001b[32m 110\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mdeeptrack\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m units_registry \u001b[38;5;28;01mas\u001b[39;00m u\n\u001b[32m 112\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m TORCH_AVAILABLE:\n", + "\u001b[31mImportError\u001b[39m: cannot import name 'TORCH_AVAILABLE' from partially initialized module 'deeptrack' (most likely due to a circular import) (/Users/841602/Documents/GitHub/DeepTrack2/deeptrack/__init__.py)" + ] + } + ], + "source": [ + "import deeptrack as dt\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "\n", + "\n", + "particle = dt.Sphere(\n", + " position=np.array([0.5, 0.5]) * 64, position_unit=\"pixel\",\n", + " radius=1 , refractive_index=1.45, z = 0,\n", + ")\n", + "\n", + "# particle = dt.PointParticle(\n", + "# position=np.array([0.5, 0.5]) * 64, position_unit=\"pixel\",\n", + "# radius=500 * dt.units.nm, refractive_index=1.45 + 0.02j,\n", + "# )\n", + "\n", + "\n", + "brightfield_microscope = dt.Brightfield(\n", + " wavelength=500 * 1E-9, NA=1.0, resolution=2E-7,\n", + " magnification=2, refractive_index_medium=1.33,\n", + " output_region=(0, 0, 64, 64),\n", + ")\n", + "\n", + "imaged_scatterer = brightfield_microscope(particle)\n", + "image = imaged_scatterer()\n", + "\n", + "fig = plt.figure(figsize=(4, 4))\n", + "plt.imshow(image, cmap=\"gray\")\n", + "plt.show()\n", + "\n", + "\n", + "imaged_scatterer.print_dependencies_tree()" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "de965325", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Sphere(len=1, action=)" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "particle" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "b9cebb43", + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'scatterer' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[31m---------------------------------------------------------------------------\u001b[39m", + "\u001b[31mNameError\u001b[39m Traceback (most recent call last)", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[3]\u001b[39m\u001b[32m, line 1\u001b[39m\n\u001b[32m----> 
\u001b[39m\u001b[32m1\u001b[39m \u001b[43mscatterer\u001b[49m.properties()\n", + "\u001b[31mNameError\u001b[39m: name 'scatterer' is not defined" + ] + } + ], + "source": [ + "scatterer.properties()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "37a5778e", + "metadata": {}, + "outputs": [], + "source": [ + "from typing import Any, TYPE_CHECKING\n", + "from deeptrack.backend.units import (\n", + " ConversionTable,\n", + " create_context,\n", + " get_active_scale,\n", + " get_active_voxel_size,\n", + ")\n", + "from deeptrack.image import Image, pad_image_to_fft\n", + "\n", + "\n", + "def _get_position(\n", + " image: Image,\n", + " mode: str = \"corner\",\n", + " return_z: bool = False,\n", + ") -> np.ndarray:\n", + " \"\"\"Extracts the position of the upper-left corner of a scatterer.\n", + "\n", + " Parameters\n", + " ----------\n", + " image: numpy.ndarray\n", + " Input image or volume containing the scatterer.\n", + " mode: str, optional\n", + " Mode for position extraction. Default is \"corner\".\n", + " return_z: bool, optional\n", + " Whether to include the z-coordinate in the output. Default is False.\n", + "\n", + " Returns\n", + " -------\n", + " numpy.ndarray\n", + " Array containing the position of the scatterer.\n", + " \n", + " \"\"\"\n", + "\n", + " num_outputs = 2 + return_z\n", + "\n", + " if mode == \"corner\" and image.size > 0:\n", + " import scipy.ndimage\n", + "\n", + " image = image.to_numpy()\n", + "\n", + " shift = scipy.ndimage.center_of_mass(np.abs(image))\n", + "\n", + " if np.isnan(shift).any():\n", + " shift = np.array(image.shape) / 2\n", + "\n", + " else:\n", + " shift = np.zeros((num_outputs))\n", + "\n", + " position = np.array(image.get_property(\"position\", default=None))\n", + "\n", + " if position is None:\n", + " return position\n", + "\n", + " scale = np.array(get_active_scale())\n", + "\n", + " if len(position) == 3:\n", + " position = position * scale + 0.5 * (scale - 1)\n", + " if return_z:\n", + " return position * scale - shift\n", + " else:\n", + " return position[0:2] - shift[0:2]\n", + "\n", + " elif len(position) == 2:\n", + " if return_z:\n", + " outp = (\n", + " np.array([position[0], position[1], image.get_property(\"z\", default=0)])\n", + " * scale\n", + " - shift\n", + " + 0.5 * (scale - 1)\n", + " )\n", + " return outp\n", + " else:\n", + " return position * scale[:2] - shift[0:2] + 0.5 * (scale[:2] - 1)\n", + "\n", + " return position\n", + "\n", + "\n", + "# TODO ***??*** revise _create_volume - torch, typing, docstring, unit test\n", + "def _create_volume(\n", + " list_of_scatterers: list,\n", + " pad: tuple = (0, 0, 0, 0),\n", + " output_region: tuple = (None, None, None, None),\n", + " refractive_index_medium: float = 1.33,\n", + " **kwargs: Any,\n", + ") -> tuple:\n", + " \"\"\"Converts a list of scatterers into a volumetric representation.\n", + "\n", + " Parameters\n", + " ----------\n", + " list_of_scatterers: list or single scatterer\n", + " List of scatterers to include in the volume.\n", + " pad: tuple of int, optional\n", + " Padding for the volume in the format (left, right, top, bottom).\n", + " Default is (0, 0, 0, 0).\n", + " output_region: tuple of int, optional\n", + " Region to output, defined as (x_min, y_min, x_max, y_max). Default is \n", + " None.\n", + " refractive_index_medium: float, optional\n", + " Refractive index of the medium surrounding the scatterers. 
Default is \n", + " 1.33.\n", + " **kwargs: Any\n", + " Additional arguments for customization.\n", + "\n", + " Returns\n", + " -------\n", + " tuple\n", + " - volume: numpy.ndarray\n", + " The generated volume containing the scatterers.\n", + " - limits: numpy.ndarray\n", + " Spatial limits of the volume.\n", + "\n", + " \"\"\"\n", + "\n", + " if not isinstance(list_of_scatterers, list):\n", + " list_of_scatterers = [list_of_scatterers]\n", + "\n", + " volume = np.zeros((1, 1, 1), dtype=complex)\n", + " limits = None\n", + " OR = np.zeros((4,))\n", + " OR[0] = -np.inf if output_region[0] is None else int(\n", + " output_region[0] - pad[0]\n", + " )\n", + " OR[1] = np.inf if output_region[1] is None else int(\n", + " output_region[1] - pad[1]\n", + " )\n", + " OR[2] = -np.inf if output_region[2] is None else int(\n", + " output_region[2] + pad[2]\n", + " )\n", + " OR[3] = np.inf if output_region[3] is None else int(\n", + " output_region[3] + pad[3]\n", + " )\n", + "\n", + " scale = np.array(get_active_scale())\n", + "\n", + " # This accounts for upscale doing AveragePool instead of SumPool. This is\n", + " # a bit of a hack, but it works for now.\n", + " fudge_factor = scale[0] * scale[1] / scale[2]\n", + "\n", + " for scatterer in list_of_scatterers:\n", + "\n", + " position = _get_position(scatterer, mode=\"corner\", return_z=True)\n", + "\n", + " if scatterer.get_property(\"intensity\", None) is not None:\n", + " intensity = scatterer.get_property(\"intensity\")\n", + " scatterer_value = intensity * fudge_factor\n", + " elif scatterer.get_property(\"refractive_index\", None) is not None:\n", + " refractive_index = scatterer.get_property(\"refractive_index\")\n", + " scatterer_value = (\n", + " refractive_index - refractive_index_medium\n", + " )\n", + " else:\n", + " scatterer_value = scatterer.get_property(\"value\")\n", + "\n", + " scatterer = scatterer * scatterer_value\n", + "\n", + " if limits is None:\n", + " limits = np.zeros((3, 2), dtype=np.int32)\n", + " limits[:, 0] = np.floor(position).astype(np.int32)\n", + " limits[:, 1] = np.floor(position).astype(np.int32) + 1\n", + "\n", + " if (\n", + " position[0] + scatterer.shape[0] < OR[0]\n", + " or position[0] > OR[2]\n", + " or position[1] + scatterer.shape[1] < OR[1]\n", + " or position[1] > OR[3]\n", + " ):\n", + " continue\n", + "\n", + " padded_scatterer = Image(\n", + " np.pad(\n", + " scatterer,\n", + " [(2, 2), (2, 2), (2, 2)],\n", + " \"constant\",\n", + " constant_values=0,\n", + " )\n", + " )\n", + " padded_scatterer.merge_properties_from(scatterer)\n", + "\n", + " scatterer = padded_scatterer\n", + " position = _get_position(scatterer, mode=\"corner\", return_z=True)\n", + " shape = np.array(scatterer.shape)\n", + "\n", + " if position is None:\n", + " RuntimeWarning(\n", + " \"Optical device received an image without a position property.\"\n", + " \" It will be ignored.\"\n", + " )\n", + " continue\n", + "\n", + " splined_scatterer = np.zeros_like(scatterer)\n", + "\n", + " x_off = position[0] - np.floor(position[0])\n", + " y_off = position[1] - np.floor(position[1])\n", + "\n", + " kernel = np.array(\n", + " [\n", + " [0, 0, 0],\n", + " [0, (1 - x_off) * (1 - y_off), (1 - x_off) * y_off],\n", + " [0, x_off * (1 - y_off), x_off * y_off],\n", + " ]\n", + " )\n", + "\n", + " for z in range(scatterer.shape[2]):\n", + " if splined_scatterer.dtype == complex:\n", + " splined_scatterer[:, :, z] = (\n", + " convolve(\n", + " np.real(scatterer[:, :, z]), kernel, mode=\"constant\"\n", + " )\n", + " + convolve(\n", + " 
np.imag(scatterer[:, :, z]), kernel, mode=\"constant\"\n", + " )\n", + " * 1j\n", + " )\n", + " else:\n", + " splined_scatterer[:, :, z] = convolve(\n", + " scatterer[:, :, z], kernel, mode=\"constant\"\n", + " )\n", + "\n", + " scatterer = splined_scatterer\n", + " position = np.floor(position)\n", + " new_limits = np.zeros(limits.shape, dtype=np.int32)\n", + " for i in range(3):\n", + " new_limits[i, :] = (\n", + " np.min([limits[i, 0], position[i]]),\n", + " np.max([limits[i, 1], position[i] + shape[i]]),\n", + " )\n", + "\n", + " if not (np.array(new_limits) == np.array(limits)).all():\n", + " new_volume = np.zeros(\n", + " np.diff(new_limits, axis=1)[:, 0].astype(np.int32),\n", + " dtype=complex,\n", + " )\n", + " old_region = (limits - new_limits).astype(np.int32)\n", + " limits = limits.astype(np.int32)\n", + " new_volume[\n", + " old_region[0, 0] : \n", + " old_region[0, 0] + limits[0, 1] - limits[0, 0],\n", + " old_region[1, 0] : \n", + " old_region[1, 0] + limits[1, 1] - limits[1, 0],\n", + " old_region[2, 0] : \n", + " old_region[2, 0] + limits[2, 1] - limits[2, 0],\n", + " ] = volume\n", + " volume = new_volume\n", + " limits = new_limits\n", + "\n", + " within_volume_position = position - limits[:, 0]\n", + "\n", + " # NOTE: Maybe shouldn't be additive.\n", + " volume[\n", + " int(within_volume_position[0]) : \n", + " int(within_volume_position[0] + shape[0]),\n", + " \n", + " int(within_volume_position[1]) : \n", + " int(within_volume_position[1] + shape[1]),\n", + "\n", + " int(within_volume_position[2]) : \n", + " int(within_volume_position[2] + shape[2]),\n", + " ] += scatterer\n", + " return volume, limits" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8f8119c4", + "metadata": {}, + "outputs": [], + "source": [ + "V,L=_create_volume(particle)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f580eec2", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "19703b64", + "metadata": {}, + "outputs": [], + "source": [ + "brightfield_microscope = dt.Darkfield(\n", + " wavelength=500 * dt.units.nm, NA=1.0, resolution=1 * dt.units.um,\n", + " magnification=1, refractive_index_medium=1.33, upsample=4,\n", + " output_region=(0, 0, 64, 64),\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "28aaa164", + "metadata": {}, + "outputs": [], + "source": [ + "illuminated_sample = brightfield_microscope(particle)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "31ebd977", + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "def plot_image(title, image):\n", + " \"\"\"Plot a grayscale image with a title.\"\"\"\n", + " plt.imshow(image, cmap=\"gray\")\n", + " plt.title(title, fontsize=30)\n", + " plt.axis(\"off\")\n", + " plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "11878e7a", + "metadata": {}, + "outputs": [], + "source": [ + "plot_image('Illuminated Sample', illuminated_sample.resolve())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "05bbb017", + "metadata": {}, + "outputs": [], + "source": [ + "import random\n", + "import numpy as np\n", + "import torch\n", + "# Set a fixed seed value\n", + "seed = 89\n", + "\n", + "# Python, NumPy, and PyTorch (CPU)\n", + "random.seed(seed)\n", + "np.random.seed(seed)\n", + "torch.manual_seed(seed)\n", + "\n", + "# Only set CUDA seeds if a GPU is available\n", + "if 
torch.cuda.is_available():\n", + " torch.cuda.manual_seed(seed)\n", + " torch.cuda.manual_seed_all(seed)\n", + " torch.backends.cudnn.deterministic = True\n", + " torch.backends.cudnn.benchmark = False\n", + "\n", + "print(f\"Seeds set to {seed} (with CUDA: {torch.cuda.is_available()})\")\n", + "\n", + "\n", + "exp_crop_size = 30\n", + "\n", + "# Same as when selecting a single object.\n", + "sim_crop_size = exp_crop_size\n", + "\n", + "# Size of a pixel in nanometers in the output image.\n", + "pixel_size_nm = 100 # In nm.\n", + "\n", + "# Size of the long ellipsoid semiaxis in nm.\n", + "major_axis_length = 1000 # In nm.\n", + "\n", + "# Eccentricity of ellipsoid.\n", + "eccentricity = 0.185\n", + "\n", + "ellipse = dt.Ellipsoid(\n", + " position=0.5 * np.array([sim_crop_size, sim_crop_size]),\n", + " z=0 * dt.units.nm, # Particle in focus.\n", + " radius=(major_axis_length, eccentricity * major_axis_length) * dt.units.nm, # Axes in nanometers\n", + " intensity=0.35, # Field magnitude squared\n", + " rotation=0.225 * np.pi,\n", + ")\n", + "\n", + "# Set the optical properties of the microscope.\n", + "optics = dt.Darkfield(\n", + " NA=1.0, # Numerical aperture\n", + " wavelength=500 * dt.units.nm,\n", + " refractive_index_medium=1.33,\n", + " output_region=[0, 0, sim_crop_size, sim_crop_size],\n", + " magnification=1,\n", + " resolution=pixel_size_nm * dt.units.nm, # Camera resolution or effective resolution.\n", + " upscale=1,\n", + ")\n", + "\n", + "# Apply transformations. Use `Upscale` to improve simulations rendering.\n", + "sim_crop = (\n", + " optics(ellipse)\n", + " # dt.Upscale(optics(ellipse), factor=2) # Upscale the image to the original size.\n", + " >> dt.Background(0)\n", + " >> dt.Poisson(snr=40)\n", + " >> dt.Multiply(70)\n", + ")\n", + "\n", + "# Convert crop into NumPy array.\n", + "sim_crop = np.squeeze(sim_crop())\n", + "\n", + "# Plot the simulated and experimental crops.\n", + "fig, axes = plt.subplots(1, 2)\n", + "\n", + "# Simulated crop.\n", + "axes[0].imshow(sim_crop, cmap=\"gray\")\n", + "axes[0].axis(\"off\")\n", + "axes[0].set_title(\"Simulated Crop\")\n", + "\n", + "# # Experimental crop.\n", + "# axes[1].imshow(exp_crop, cmap=\"gray\")\n", + "# axes[1].axis(\"off\")\n", + "# axes[1].set_title(\"Experimental Crop\")\n", + "\n", + "# Adjust layout and show plot.\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5d24fe56", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "deeptrack2_edit_env (3.11.8)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.8" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}
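
Note: a minimal sketch of how the new dataclass containers introduced in scatterers.py are meant to be used, assuming the field names and the `get_property` fallback defined in this diff; the array, position, and `radius` values below are made up for illustration.

# Sketch only: exercises ScatteredVolume as defined in this diff.
import numpy as np
from deeptrack.scatterers import ScatteredVolume

sv = ScatteredVolume(
    array=np.ones((5, 5, 3), dtype=complex),   # voxelized scatterer
    position=(32.4, 17.8),                     # x-y position; z kept separate
    z=2.0,
    refractive_index=1.45,
    properties={"radius": 500e-9},             # extra properties dict
    main_property="refractive_index",
)

print(sv.pos3d)                                # [32.4 17.8  2.0]
print(sv.get_property("refractive_index"))     # 1.45  (dataclass field)
print(sv.get_property("radius"))               # 5e-07 (falls back to properties dict)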
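
Note: a minimal usage sketch of the `dt.Upscale` path as it stands after this change, with optics parameters borrowed from the notebook above. Under this diff, `Upscale.get` expands `factor=2` to `(2, 2, 1)`, and the pooling back to the output resolution now happens inside `Microscope.get` through `_downscale_scatterer` (which, as written, takes `main_property` from the last scatterer in the loop).

# Sketch only: assumes this branch is installed as deeptrack.
import numpy as np
import deeptrack as dt

particle = dt.Sphere(
    position=np.array([0.5, 0.5]) * 64, position_unit="pixel",
    radius=1e-6, refractive_index=1.45, z=0,
)

optics = dt.Brightfield(
    wavelength=500e-9, NA=1.0, resolution=2e-7,
    magnification=2, refractive_index_medium=1.33,
    output_region=(0, 0, 64, 64),
)

# Render the sample at twice the internal resolution, then pool back to 64x64.
imaged_particle = dt.Upscale(optics(particle), factor=2)
image = imaged_particle()  # expected output shape: (64, 64, 1)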
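
Note: the TODOs in optics.py ask for unit tests; below is a sketch of a parity check between the two subpixel-shift helpers added there. It assumes both private helpers are importable from deeptrack.optics on this branch. Whether the `convolve` and `grid_sample` conventions agree on shift sign and on which axis `x_off`/`y_off` act is still to be confirmed, so the check only reports the interior discrepancy instead of asserting it.

# Sketch only: compares the NumPy and Torch subpixel-shift backends.
import numpy as np
import torch

from deeptrack.optics import (
    _bilinear_interpolate_numpy,
    _bilinear_interpolate_torch,
)

rng = np.random.default_rng(0)
volume = rng.random((16, 16, 4))   # (H, W, D) real-valued test volume
x_off, y_off = 0.3, 0.6            # subpixel offsets in [0, 1)

out_np = _bilinear_interpolate_numpy(volume, x_off, y_off)
out_torch = _bilinear_interpolate_torch(
    torch.from_numpy(volume), x_off, y_off
).numpy()

# Compare away from the borders, where the two padding conventions differ.
interior = (slice(2, -2), slice(2, -2), slice(None))
print(
    "max interior difference:",
    np.abs(out_np[interior] - out_torch[interior]).max(),
)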