Source code for camp.FileIO.ITKFileIO

import torch
import SimpleITK as sitk

from ..Core import *


[docs]def LoadITKFile(filename, device='cpu', dtype=torch.float32): """ Load an ITK compatible file using the SimpleITK package into a :class:`StructuredGrid` object. :param filename: File path :type filename: str :param device: Memory location - one of 'cpu', 'cuda', or 'cuda:X' where X specifies the device identifier. Default: 'cpu' :type device: str, optional :param dtype: Data type, specified from torch memory types. Default: 'torch.float32' :type dtype: str, optional :return: :class:`StructuredGrid` """ itk_image = sitk.ReadImage(filename) # ITK ordering is x, y, z. But numpy is z, y, x image_size = torch.as_tensor(itk_image.GetSize()[::-1], dtype=dtype) image_spacing = torch.as_tensor(itk_image.GetSpacing()[::-1], dtype=dtype) image_origin = torch.as_tensor(itk_image.GetOrigin()[::-1], dtype=dtype) channels = itk_image.GetNumberOfComponentsPerPixel() dataArray = sitk.GetArrayFromImage(itk_image) if dataArray.dtype == 'uint16': dataArray = dataArray.astype('int32') tensor = torch.as_tensor(dataArray) image_origin = image_origin[image_size != 1.0] image_spacing = image_spacing[image_size != 1.0] image_size = image_size[image_size != 1.0] tensor = tensor.squeeze() # Make sure that the channels are accounted for if channels == 1: tensor = tensor.view(channels, *tensor.shape) else: tensor = tensor.permute([-1] + torch.arange(0, len(image_size)).tolist()) out = StructuredGrid( size=image_size, spacing=image_spacing, origin=image_origin, device=device, dtype=dtype, tensor=tensor, channels=channels ) return out
[docs]def SaveITKFile(grid, f_name): """ Save a :class:`StructuredGrid` object to an ITK compatible file using the SimpleITK package. :param grid: :class:`StructuredGrid` to be saved. :param f_name: File path :type f_name: str :return: None """ dim = len(grid.size) # Need to put the vector in the last dimension vector_grid = grid.data.permute(list(range(1, dim +1)) + [0]).squeeze(-1) # it will always be this size now if dim == 2 and vector_grid.shape[-1] == 3: itk_image = sitk.GetImageFromArray(vector_grid.cpu().numpy(), isVector=True) # elif dim == 2: # itk_image = sitk.GetImageFromArray(vector_grid.unsqueeze(-2).cpu().numpy()) else: itk_image = sitk.GetImageFromArray(vector_grid.cpu().numpy()) spacing = grid.spacing.tolist() if dim == 2: spacing = [1.0] + spacing origin = grid.origin.tolist() if dim == 2: origin = [1.0] + origin # ITK ordering is x, y, z. But numpy is z, y, x itk_image.SetSpacing(spacing[::-1]) itk_image.SetOrigin(origin[::-1]) sitk.WriteImage(itk_image, f_name)