# Source code for camp.StructuredGridOperators.UnaryOperators.GradientRegularizer

import torch

from ...StructuredGridOperators.UnaryOperators.GradientFilter import Gradient
from ._UnaryFilter import Filter


class NormGradient(Filter):
    """Weighted, element-wise squared gradient of a field.

    Wraps a ``Gradient`` filter; ``forward`` returns
    ``weight * gradient(vector_field) ** 2``.

    Args:
        weight: Multiplier applied to the squared gradient.
            NOTE(review): presumably a Python scalar; a tensor ``weight`` is
            not moved to ``device``/``dtype`` by ``Create`` -- confirm with
            callers.
        dim: Dimensionality forwarded to ``Gradient.Create``.
        device: Device forwarded to ``Gradient.Create`` and stored on the filter.
        dtype: Data type forwarded to ``Gradient.Create`` and stored on the filter.
    """

    def __init__(self, weight, dim=2, device='cpu', dtype=torch.float32):
        super().__init__()
        self.weight = weight
        self.device = device
        self.dtype = dtype
        self.gradient_operator = Gradient.Create(
            dim=dim, device=device, dtype=dtype
        )

    @staticmethod
    def Create(weight, dim=2, device='cpu', dtype=torch.float32):
        """Construct a ``NormGradient`` and move it to ``device`` / ``dtype``.

        Args:
            weight: Multiplier applied to the squared gradient.
            dim: Dimensionality forwarded to ``Gradient.Create``.
            device: Target device for the filter.
            dtype: Target data type for the filter.

        Returns:
            The constructed ``NormGradient`` after ``.to(device)`` and
            ``.type(dtype)``.
        """
        filt = NormGradient(weight, dim, device, dtype)
        filt = filt.to(device)
        filt = filt.type(dtype)

        # StructuredGrid objects can't be added as registered buffers, so the
        # .to()/.type() calls above don't reach them; move them explicitly.
        # Matching on the class name (rather than isinstance) presumably
        # avoids importing StructuredGrid into this module.
        for val in filt.__dict__.values():
            if type(val).__name__ == 'StructuredGrid':
                val.to_(device)
                val.to_type_(dtype)

        return filt

    def forward(self, vector_field):
        """Return ``weight`` times the element-wise square of the gradient.

        Args:
            vector_field: Input passed straight to the wrapped ``Gradient``
                filter. NOTE(review): expected type/shape is whatever
                ``Gradient`` accepts -- not visible here.

        Returns:
            ``self.weight * (self.gradient_operator(vector_field) ** 2)``.
        """
        return self.weight * (self.gradient_operator(vector_field) ** 2)