# Source code for camp.StructuredGridOperators.BinaryOperators.L2ImageSimilarity

import torch

from ..UnaryOperators.GradientFilter import Gradient
from ._BinaryFilter import Filter


class L2Similarity(Filter):
    """L2 similarity metric between two :class:`StructuredGrid` objects.

    :meth:`forward` evaluates ``0.5 * (moving - target) ** 2`` and :meth:`c1`
    returns its first derivative ``(moving - target) * grads``. Both keep the
    dimensionality of the input grids.
    """

    def __init__(self, dim=2, device='cpu', dtype=torch.float32):
        super().__init__()
        self.device = device
        self.dtype = dtype
        # Gradient filter for the moving image; used by :meth:`c1` when the
        # caller does not pass precomputed gradients.
        self.gradient_operator = Gradient.Create(dim=dim, device=device, dtype=dtype)

    @staticmethod
    def Create(dim=2, device='cpu', dtype=torch.float32):
        """
        Compare two :class:`StructuredGrid` objects using an L2 similarity metric.

        :param dim: Dimensionality of the :class:`StructuredGrid` to be compared (not including channels).
        :type dim: int
        :param device: Memory location - one of 'cpu', 'cuda', or 'cuda:X' where X specifies the device identifier.
            Default: 'cpu'
        :type device: str
        :param dtype: Data type for the attributes. Specified from torch memory types. Default: ``torch.float32``
        :type dtype: torch.dtype
        :return: L2 comparison object.
        """
        filt = L2Similarity(dim, device, dtype)
        # NOTE(review): .to()/.type() assume Filter is a torch.nn.Module subclass - confirm.
        filt = filt.to(device)
        filt = filt.type(dtype)

        # StructuredGrid objects can't be added to the register buffer, so the
        # .to()/.type() calls above do not move them; move each one explicitly.
        # The class is matched by name (presumably to avoid importing
        # StructuredGrid into this module).
        for val in filt.__dict__.values():
            if type(val).__name__ == 'StructuredGrid':
                val.to_(device)
                val.to_type_(dtype)

        return filt

    def forward(self, target, moving):
        """
        Compare two :class:`StructuredGrid` with L2 similarity metric.

        This is often used for registration, so the variables are labeled as target and moving.
        This function preserves the dimensionality of the original grids.

        :param target: Structured Grid 1
        :type target: :class:`StructuredGrid`
        :param moving: Structured Grid 2
        :type moving: :class:`StructuredGrid`
        :return: L2 similarity ``0.5 * (moving - target) ** 2`` as :class:`StructuredGrid`
        """
        return 0.5 * ((moving - target) ** 2)

    def c1(self, target, moving, grads=None):
        """
        First derivative of the L2 similarity metric.

        :param target: Structured Grid 1
        :type target: :class:`StructuredGrid`
        :param moving: Structured Grid 2
        :type moving: :class:`StructuredGrid`
        :param grads: Gradients of the moving image. If ``None``, they are computed with
            ``self.gradient_operator(moving)``. Default: None
        :type grads: :class:`StructuredGrid`
        :return: ``(moving - target) * grads`` (presumably a :class:`StructuredGrid`, like :meth:`forward`)
        """
        if grads is None:
            # Fallback only; callers that already have the gradients (e.g. reused
            # across iterations) should pass them to avoid recomputation.
            grads = self.gradient_operator(moving)
        return (moving - target) * grads