# Source code for camp.StructuredGridTools.GradientFlowFilter

import torch
import torch.nn.functional as F
from ..StructuredGridOperators.UnaryOperators.GradientFilter import Gradient
from ..Core.StructuredGridClass import StructuredGrid

from ._BaseTool import Filter


# TODO Multiscale matching
# TODO Use only one applier
# TODO Check for memory-saving and computation-saving opportunities

class IterativeMatch(Filter):
    """Iteratively deform ``source`` toward ``target`` by gradient flow.

    Each call to :meth:`step` computes a body force from ``similarity.c1``,
    smooths it with ``operator.apply_inverse``, scales it by ``step_size``,
    optionally projects it with ``operator.project_incompressible``, and
    composes the resulting increment into the accumulated deformation
    ``field``. The source image (and its precomputed gradients) are then
    resampled through ``field``.

    Args:
        source (StructuredGrid): Moving image.
        target (StructuredGrid): Fixed image; must have the same size as ``source``.
        similarity: Callable ``similarity(target, moving)`` whose result supports
            ``.sum()``; must also expose ``c1(target, moving, moving_grads)``.
        operator: Object exposing ``apply_inverse(v)`` and
            ``project_incompressible(v)``.
        regularization: Optional callable applied to ``field - identity``; its
            summed result is added to the energy when set.
        step_size (float): Scale applied to the smoothed body force.
        incompressible (bool): If True, project every update with
            ``operator.project_incompressible``.
        device: Device forwarded to the gradient filter and used by :meth:`Create`.
        dtype: Data type forwarded to the gradient filter and used by :meth:`Create`.

    Raises:
        RuntimeError: If ``source``/``target`` are not ``StructuredGrid`` objects
            or their sizes differ.
    """

    def __init__(self, source, target, similarity, operator, regularization=None, step_size=0.001,
                 incompressible=True, device='cpu', dtype=torch.float32):
        super(IterativeMatch, self).__init__(source, target)

        if not type(source).__name__ == 'StructuredGrid':
            raise RuntimeError(
                f'Only type "StructuredGrid" for source is accepted, got {type(source).__name__}'
            )

        if not type(target).__name__ == 'StructuredGrid':
            raise RuntimeError(
                f'Only type "StructuredGrid" for target is accepted, got {type(target).__name__}'
            )

        # Every dimension must match, not merely one of them; grids with a
        # different number of dimensions are rejected up front.
        if len(source.size) != len(target.size) or not all(
            s.item() == t.item() for s, t in zip(source.size, target.size)
        ):
            raise RuntimeError(
                f'Images must have the same size - Target Size: {target.size}, Source Size: {source.size}'
            )

        self.device = device
        self.dtype = dtype
        self.similarity = similarity
        self.regularization = regularization
        self.operator = operator
        self.step_size = step_size
        self.incompressible = incompressible

        # Work on clones rather than on the caller's grids.
        self.target = target.clone()
        self.source = source.clone()
        self.moving = source.clone()

        # Accumulated deformation; starts as the identity lookup table.
        self.field = StructuredGrid.FromGrid(source)
        self.field.set_to_identity_lut_()

        # Fixed identity reference used by the regularization term in energy().
        self.identity = StructuredGrid.FromGrid(source)
        self.identity.set_to_identity_lut_()

        # Per-iteration increment; step() resets it to identity after composing it.
        self.update = StructuredGrid.FromGrid(source)
        self.update.set_to_identity_lut_()

        self.initial_energy = self.energy()

        # Source gradients are computed once here; step() resamples them through
        # the current field instead of recomputing them from the moving image.
        self.gradients = Gradient.Create(dim=len(source.size), device=device, dtype=dtype)(source)
        self.moving_grads = self.gradients.clone()

    @staticmethod
    def Create(source, target, similarity, operator, regularization=None, step_size=0.001,
               incompressible=True, device='cpu', dtype=torch.float32):
        """Build an :class:`IterativeMatch` and place it on ``device`` with ``dtype``.

        Arguments are the same as for the constructor.

        Returns:
            IterativeMatch: The configured filter, with its module state and every
            ``StructuredGrid`` attribute moved to ``device``/``dtype``.
        """
        match = IterativeMatch(source, target, similarity, operator, regularization, step_size,
                               incompressible, device, dtype)
        match = match.to(device=device, dtype=dtype)

        # StructuredGrid objects cannot be registered as buffers, so .to() above
        # does not move them; move each one explicitly.
        for val in match.__dict__.values():
            if type(val).__name__ == 'StructuredGrid':
                val.to_(device)
                val.to_type_(dtype)

        return match

    def energy(self):
        """Return the current total energy as a Python number.

        The energy is ``similarity(target, moving).sum()`` plus, when a
        regularization is set, ``regularization(field - identity).sum()``.
        """
        energy = self.similarity(self.target, self.moving).sum()
        if self.regularization:
            reg_e = self.regularization(self.field - self.identity).sum()
            energy = energy + reg_e
        return energy.item()

    @staticmethod
    def _apply_field(x, field, interpolation_mode='bilinear', padding_mode='zeros'):
        """Resample ``x`` at the world-space coordinates stored in ``field``.

        Args:
            x (StructuredGrid): Grid to resample. Its ``origin``, ``spacing`` and
                ``size`` define the world-to-normalized coordinate mapping.
            field (StructuredGrid): Coordinate lookup table with one channel per
                spatial dimension.
            interpolation_mode (str): ``mode`` forwarded to ``F.grid_sample``.
            padding_mode (str): ``padding_mode`` forwarded to ``F.grid_sample``.

        Returns:
            StructuredGrid: A grid built from ``x`` holding the resampled data.
        """
        grid = field.clone()

        # Map world coordinates into the [-1, 1] range expected by grid_sample.
        # NOTE(review): with align_corners=True, -1/+1 address the first/last voxel
        # centres, which suggests a divisor of spacing * (size - 1) / 2 rather than
        # spacing * size / 2 -- confirm against StructuredGrid's origin convention.
        broadcast = (*x.size.shape, *([1] * len(x.size)))
        grid = grid - x.origin.view(*broadcast)
        grid = grid / (x.spacing * (x.size / 2)).view(*broadcast)
        grid = grid - 1

        # grid_sample expects (N, *spatial, ndim): move channels last, add a batch
        # axis, and reverse the coordinate order to grid_sample's (x, y[, z]).
        grid = grid.data.permute(torch.arange(1, len(grid.shape())).tolist() + [0])
        grid = grid.data.view(1, *grid.shape)
        resample_grid = grid.flip(-1)

        out_tensor = F.grid_sample(x.data.view(1, *x.data.shape), resample_grid,
                                   mode=interpolation_mode, align_corners=True,
                                   padding_mode=padding_mode).squeeze(0)

        out = StructuredGrid.FromGrid(
            x,
            tensor=out_tensor,
            channels=out_tensor.shape[0]
        )
        return out

    def step(self):
        """Run one gradient-flow iteration and return the resulting energy."""
        # Similarity body force.
        body_v = self.similarity.c1(self.target, self.moving, self.moving_grads)

        # Smooth the body force with the operator's inverse, then scale it.
        body_v = self.operator.apply_inverse(body_v)
        body_v = self.step_size * body_v

        if self.incompressible:
            body_v = self.operator.project_incompressible(body_v)

        # Increment = identity displaced by -body_v.
        self.update = self.update - body_v

        # Compose: sample the accumulated field at the increment's locations.
        # Border padding repeats edge values for out-of-bounds lookups.
        self.field = self._apply_field(self.field, self.update, padding_mode="border")

        # Warp the source image and its precomputed gradients through the new field.
        self.moving = self._apply_field(self.source, self.field)
        self.moving_grads = self._apply_field(self.gradients, self.field)

        # Reset the increment for the next iteration.
        self.update.set_to_identity_lut_()

        new_energy = self.energy()
        return new_energy

    def get_field(self):
        """Return the accumulated deformation field."""
        return self.field

    def get_image(self):
        """Return the source image resampled through the current field."""
        return self.moving