pymuster package
Subpackages
Submodules
pymuster.muster module
Multi Session Temporal Registration.
This module is the core of Multi Session Temporal Registration (MUSTER).
- class pymuster.muster.Registration(stages_iterations: list, stages_img_scales: list, stages_deform_scales: list, **kwargs)
Bases:
object
Multi Stage Temporal Registration. Registers a series of images. The images are first registered at a low resolution; the resolution is then increased and the images are registered again. At each resolution the deformation field has a resolution relative to the resampled image resolution, set by deform_res_scale. Accepts the same arguments as StageRegistration, but num_iterations is replaced by stages_iterations, deform_res_scale is replaced by stages_deform_scales, and img_size is replaced by stages_img_scales.
Example
>>> from pymuster.muster import Registration
>>> deform_reg = Registration(
>>>     stages_iterations=[500, 250, 100],
>>>     stages_img_scales=[4, 2, 1],
>>>     stages_deform_scales=[4, 2, 2],
>>>     img_size=[128, 128, 128],
>>>     pix_dim=[1, 1, 1],
>>>     device="cuda:0",
>>> )
>>> out = deform_reg.fit(images)
This example will register the images first at a resolution of [32, 32, 32] with a deformation grid of [8, 8, 8]. Then the resolution is increased to [64, 64, 64] with a deformation grid of [32, 32, 32]. Finally the resolution is increased to [128, 128, 128] with a deformation grid of [64, 64, 64]. See StageRegistration for more information about the arguments.
- Parameters:
stages_iterations (list of int) – Number of iterations for each stage.
stages_img_scales (list of int) – Image rescaling factors for each stage.
stages_deform_scales (list of int) – Deformation rescaling factors for each stage relative to the image rescaling factor.
- deform(images, deform_field, mode='bilinear', padding_mode='zeros', displacement=True)
Transforms the batch of images according to the deformation field.
- Parameters:
images (torch.Tensor or np.ndarray) – The images to deform. Shape
(N, C, x, y, z)
deform_field (torch.Tensor or np.ndarray) – The deformation field. Shape
(N, 3, x, y, z)
mode (str, optional) – The interpolation mode. Either ‘bilinear’ or ‘nearest’.
padding_mode (str, optional) – The padding mode. Either ‘zeros’ or ‘border’.
displacement (bool, optional) – Whether deform_field is a displacement field (True) or an absolute deformation field (False).
- Returns:
The deformed images. Shape
(N, C, x, y, z)
- Return type:
torch.Tensor or np.ndarray
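A minimal sketch of calling deform directly, assuming images is a torch.Tensor of shape (N, C, x, y, z); a zero displacement field leaves the images unchanged:
>>> import torch
>>> zero_field = torch.zeros(images.shape[0], 3, *images.shape[2:])
>>> warped = deform_reg.deform(images, zero_field, mode='bilinear', displacement=True)
>>> warped.shape  # (N, C, x, y, z)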
- fit(images, timepoints=None, masks=None)
Fits the registration model to the images.
- Parameters:
images (numpy.ndarray or torch.Tensor) – The images to register. Shape (N, C, x, y, z)
timepoints (numpy.ndarray or torch.Tensor, optional) – The timepoints of the images. Shape (N,)
masks (numpy.ndarray or torch.Tensor, optional) – The masks of the images. Shape (N, C, x, y, z)
- Returns:
A dictionary containing the deformation field, rotations and translations.
- Return type:
dict
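For example, with unevenly spaced sessions (a sketch; the timepoints array is illustrative and follows the shapes documented above):
>>> import numpy as np
>>> timepoints = np.array([0.0, 0.5, 2.0])  # e.g. years since baseline, shape (N,)
>>> out = deform_reg.fit(images, timepoints=timepoints)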
- get_deform_fields(deform_flow)
Computes all deformation fields from the deformation flow.
The deformations are organized in a matrix such that the deformation that distorts image_i to image_j is given by deform_matrix[j, i]. Another way of viewing the matrix is that the displacement at time t_j of a particle starting at the identity grid in image_i is given by deform_matrix[i, j].
- Parameters:
deform_flow (torch.Tensor or np.ndarray) – The deformation flow. Shape
(N, 3, x, y, z)
- Returns:
The deformation fields. Shape
(N, N, 3, x, y, z)
- Return type:
torch.Tensor or np.ndarray
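A sketch of the indexing convention described above; deform_flow is assumed to come from the dictionary returned by fit (its exact key name is not shown here):
>>> deform_matrix = deform_reg.get_deform_fields(deform_flow)  # shape (N, N, 3, x, y, z)
>>> field_0_to_2 = deform_matrix[2, 0]  # distorts image_0 towards image_2
>>> warped_0 = deform_reg.deform(images[0:1], field_0_to_2[None])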
- get_identity_grid()
Returns the identity grid.
- rigid_transform(images, rotations, translations)
Applies the rigid transformation to the images.
- class pymuster.muster.SpatialTransformer(size, pix_dim=(1, 1, 1), mode='bilinear', padding_mode='zeros')
Bases:
Module
Spatial Transformer for performing grid pull operations on spaces. Based on https://github.com/voxelmorph/voxelmorph. The spatial transformer uses a deformation field to deform the input space. The deformation field is given in mm/step.
- Parameters:
size (tuple) – the target size of the output tensor. shape
(x, y, z)
pix_dim (tuple) – the pixel spacing (mm/px) of the output tensor
(dx, dy, dz)
mode ('bilinear' or 'nearest') – interpolation mode
padding_mode ('zeros' or 'border') – padding mode
- forward(field, transformation_field, displacement=True)
Deforms the source field with the transformation field, which has units of mm/step.
- Parameters:
field – the source space to be transformed
transformation_field – the transformation field to be applied to the source image, with units in mm/step
displacement – whether the transformation field is a displacement field (True) or an absolute deformation field (False)
- Returns:
the transformed source image
- Return type:
torch.Tensor
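A minimal standalone sketch (tensor shapes follow the conventions above; a zero displacement field leaves the source unchanged):
>>> import torch
>>> st = SpatialTransformer(size=(64, 64, 64), pix_dim=(1.0, 1.0, 1.0))
>>> src = torch.randn(1, 1, 64, 64, 64)   # source, shape (N, C, x, y, z)
>>> disp = torch.zeros(1, 3, 64, 64, 64)  # displacement field in mm/step
>>> out = st(src, disp, displacement=True)
>>> out.shape
torch.Size([1, 1, 64, 64, 64])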
- class pymuster.muster.StageRegistration(img_size: tuple, pix_dim: tuple = (1, 1, 1), deform_res_scale: int = 1, device: str = 'cpu', num_iterations: int = 100, spatial_smoothness_penalty: float = 0.0, temporal_smoothness_penalty: float = 0.0, invertability_penalty: float = 0.0, l2_penalty: float = 0.0, smoothing_sigma: float = 0.0, interpolation_mode: str = 'bilinear', integration_steps: int = 7, integration_method: str = 'ss', field_composition_method: str = 'interpolate', affine_adjustment: str = 'none', learning_rate: float = 0.001, betas: tuple = (0.9, 0.999), tol: float = 0.0001, img_similarity_metric: str = 'NCC', img_similarity_spatial_size: int | float = 3, img_similarity_scale_invariant: bool = True, verbose: bool = True)
Bases:
object
Registers a longitudinal series of images with a series of deformation fields at a single resolution.
The images must be in the same space and must be pre-registered using a rigid or affine transformation.
The deformation between two consecutive timepoints is represented by a deformation field. These fields are generated by integrating a vector field, also known as the deformation flow. Various methods for integration are supported:
- integration_method="ss": Use scaling and squaring to integrate stationary vector fields. (Default)
- integration_method="euler": Use Euler integration for stationary vector fields.
- integration_method="rk4": Use Runge-Kutta 4 integration for stationary vector fields.
- integration_method="euler_time": Use Euler integration for time-varying vector fields.
- integration_method="rk4_time": Use Runge-Kutta 4 integration for time-varying vector fields.
For non-consecutive timepoints, the deformation field can be composed in two ways:
- field_composition_method="interpolate": Use interpolation to compose the intermediate deformation fields. (Default)
- field_composition_method="flow_continuation": Continue the deformation field by integrating the intermediate deformation flows.
Supported image similarity metrics include:
- img_similarity_metric="NCC": Normalized local cross correlation between the two images as loss. The local neighborhood is a cube with side length img_similarity_spatial_size.
- img_similarity_metric="L2": L2 norm of the difference between the two images as loss.
- img_similarity_metric="NCCS": Use a Sobel filter to compute the gradient of the two images, and use the normalized local cross correlation between the two gradients as loss. The local window is a cube with side length 3.
- img_similarity_metric="WNCC": Normalized local cross correlation between the two images as loss, where each local neighborhood is weighted by the cross standard deviation of the two images. The local neighborhood is a cube with side length img_similarity_spatial_size.
- img_similarity_metric="GaussNCC": Normalized local cross correlation between the two images as loss, where each local neighborhood is weighted by a Gaussian filter. The standard deviation of the Gaussian filter is given by img_similarity_spatial_size.
- img_similarity_metric="Fourier": Use the Fourier transform to compute a filtered gradient of each image, and compute the global normalized cross correlation between the two filtered gradients as loss. The standard deviation of the Gaussian filter is img_similarity_spatial_size.
- Parameters:
img_size (tuple) – The size of the input images in the format (x, y, z).
pix_dim (tuple) – The pixel dimensions of the input images in the format (dx, dy, dz). Default is (1, 1, 1).
deform_res_scale (int) – The resolution scale of the deformation field. The deformation field will have a resolution of img_size/deform_res_scale. Default is 1.
device (str) – The device to use for computation ('cpu', 'cuda:n', or a torch.device object). Default is "cpu".
num_iterations (int) – Number of optimization iterations. Default is 100.
spatial_smoothness_penalty (float) – Weight for spatial smoothness in the loss function. Default is 0.0.
temporal_smoothness_penalty (float) – Weight for temporal smoothness in the loss function. Default is 0.0.
invertability_penalty (float) – Weight for the invertibility penalty in the loss function. Default is 0.0.
l2_penalty (float) – Weight for the L2 penalty on the deformation flow. Default is 0.0.
smoothing_sigma (float) – Sigma for Gaussian smoothing of the deformation flows. Default is 0.0.
interpolation_mode (str) – The mode for interpolation of the images ('nearest', 'bilinear', or 'bicubic'). Default is "bilinear".
integration_steps (int) – The number of integration steps for integrating the deformation flow. Default is 7.
integration_method (str) – The method for integrating the deformation field ('ss', 'euler', 'rk4', 'euler_time', or 'rk4_time'). Default is "ss".
field_composition_method (str) – The method for composing deformation fields ('interpolate' or 'flow_continuation'). Default is "interpolate".
affine_adjustment (str) – The method for adjusting the affine transformation ('none', 'rigid', or 'affine'). Default is "none".
learning_rate (float) – Learning rate for the optimizer. Default is 1e-3.
betas (tuple) – The beta coefficients for the Adam optimizer. Default is (0.9, 0.999).
tol (float) – Tolerance for optimization convergence. Default is 1e-4.
img_similarity_metric (str) – The image similarity metric ('NCC', 'L2', 'NCCS', 'WNCC', 'GaussNCC', or 'Fourier'). Default is "NCC".
img_similarity_spatial_size (int or float) – Size or standard deviation of the local neighborhood for the image similarity metric. Default is 3.
verbose (bool) – Flag for printing progress during optimization. Default is True.
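A minimal construction sketch using the options above (all values are illustrative only):
>>> from pymuster.muster import StageRegistration
>>> stage_reg = StageRegistration(
>>>     img_size=(128, 128, 128),
>>>     pix_dim=(1, 1, 1),
>>>     deform_res_scale=2,
>>>     num_iterations=250,
>>>     integration_method='ss',
>>>     img_similarity_metric='NCC',
>>>     img_similarity_spatial_size=3,
>>>     device='cuda:0',
>>> )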
- deform(images, deform_field, mode='bilinear', padding_mode='zeros', displacement=True)
Transforms the batch of images according to the deformation field.
- Parameters:
images (torch.Tensor or np.ndarray) – The images to deform. Shape
(N, C, x, y, z)
deform_field (torch.Tensor or np.ndarray) – The deformation field. Shape
(N, 3, x, y, z)
mode (str, optional) – The interpolation mode. Either ‘bilinear’ or ‘nearest’.
padding_mode (str, optional) – The padding mode. Either ‘zeros’ or ‘border’.
displacement (bool, optional) – Whether deform_field is a displacement field (True) or an absolute deformation field (False).
- Returns:
The deformed images. Shape
(N, C, x, y, z)
- Return type:
torch.Tensor or np.ndarray
- fit(images, inital_deform_flow=None, inital_sigmas=None, timepoints=None, initial_rotations=None, initial_translations=None, initial_affine=None, masks=None)
Fits the deformation flow to the images.
- Parameters:
images (torch.Tensor or np.ndarray) – shape
(N, channels, x, y, z)
inital_deform_flow (torch.Tensor or np.ndarray) – shape
(N-1, 3, x, y, z)
timepoints (torch.Tensor or np.ndarray) – shape
(N,)
The timepoints of the images. Used to adjust the temporal penalties. May be None, in which case the timepoints are assumed to be equally spaced.
- Returns:
Dictionary containing the deformation flow and optionally the rotations and translations.
- Return type:
out (dict)
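A usage sketch continuing from the constructor above; the key names of the returned dictionary are an assumption, its documented contents being the deformation flow and, optionally, rotations and translations:
>>> out = stage_reg.fit(images, timepoints=timepoints)
>>> deform_flow = out['deform_flow']  # assumed key name
>>> fields = stage_reg.get_deform_fields(deform_flow)  # shape (N, N, 3, x, y, z)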
- get_deform_fields(deform_flow)
Computes all deformation fields from the deformation flow.
The deformations are organized in a matrix such that the deformation that distorts image_i to image_j is given by deform_matrix[j, i]. Another way of viewing the matrix is that the displacement at time t_j of a particle starting at the identity grid in image_i is given by deform_matrix[i, j].
- Parameters:
deform_flow (torch.Tensor or np.ndarray) – The deformation flow. Shape
(N, 3, x, y, z)
- Returns:
The deformation fields. Shape
(N, N, 3, x, y, z)
- Return type:
torch.Tensor or np.ndarray
- rigid_transform(images, rotations, translations)