# Source code for antspynet.utilities.spatial_transformer_network_utilities


import tensorflow.keras.backend as K
from tensorflow.keras.layers import Layer, InputSpec
from tensorflow.keras import initializers

import tensorflow as tf

class SpatialTransformer2D(Layer):

    """
    Custom layer for the spatial transformer network.

    Arguments
    ---------
    inputs : list of size 2
        The first element are the images and the second element are the
        weights (the flattened affine transform parameters).

    resampled_size : tuple of length 2
        Size of the resampled output images.

    transform_type : string
        Transform type (default = 'affine').

    interpolator_type : string
        Interpolator type (default = 'linear').

    Returns
    -------
    Keras layer
        A 2-D keras layer

    """

    def __init__(self, resampled_size, transform_type='affine',
                 interpolator_type='linear', **kwargs):

        if K.backend() != 'tensorflow':
            raise ValueError("Tensorflow is required for this STN implementation.")

        if len(resampled_size) != 2:
            raise ValueError("Resampled size must be a vector of length 2 (for 2-D).")

        self.resampled_size = resampled_size
        self.transform_type = transform_type
        self.interpolator_type = interpolator_type

        super(SpatialTransformer2D, self).__init__(**kwargs)

    def get_config(self):
        # Without this, a model containing this layer cannot be saved and
        # reloaded (tf.keras serializes custom layers via get_config()).
        config = {"resampled_size": self.resampled_size,
                  "transform_type": self.transform_type,
                  "interpolator_type": self.interpolator_type}
        base_config = super(SpatialTransformer2D, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def build(self, input_shape):
        super(SpatialTransformer2D, self).build(input_shape)

    def call(self, inputs, mask=None):
        # inputs[0]: images (batch, height, width, channels)
        # inputs[1]: affine parameters, reshapeable to (batch, 2, 3)
        images = inputs[0]
        transform_parameters = inputs[1]

        if self.transform_type == 'affine':
            output = self.affine_transform_images(images, transform_parameters,
                                                  self.resampled_size)
        else:
            raise ValueError("Unsupported transform type.")

        input_shape = [K.shape(images)]
        return(tf.reshape(output, self.compute_output_shape(input_shape)))

    def compute_output_shape(self, input_shape):
        output_shape = (input_shape[0][0], self.resampled_size[0],
                        self.resampled_size[1], input_shape[0][-1])
        return(output_shape)

    def affine_transform_images(self, images, affine_transform_parameters,
                                resampled_size):
        # Apply per-sample (2, 3) affine matrices to a regular sampling grid
        # and interpolate the input images at the transformed grid points.
        batch_size = K.int_shape(images)[0]
        number_of_channels = K.int_shape(images)[-1]

        # A -1 leading dimension handles an unknown (None) batch size.
        if batch_size is None:
            transform_parameters = K.reshape(affine_transform_parameters,
                                             shape=(-1, 2, 3))
        else:
            transform_parameters = K.reshape(affine_transform_parameters,
                                             shape=(batch_size, 2, 3))

        regular_grid = self.make_regular_grid(resampled_size)
        sampled_grids = K.dot(transform_parameters, regular_grid)

        if self.interpolator_type == 'linear':
            interpolated_image = self.linear_interpolate(images, sampled_grids,
                                                         resampled_size)
        else:
            raise ValueError("Unsupported interpolator type.")

        if batch_size is None:
            new_output_shape = (-1, resampled_size[0], resampled_size[1],
                                number_of_channels)
        else:
            new_output_shape = (batch_size, resampled_size[0],
                                resampled_size[1], number_of_channels)
        interpolated_image = K.reshape(interpolated_image,
                                       shape=new_output_shape)

        return(interpolated_image)

    def make_regular_grid(self, resampled_size):
        # Homogeneous sampling grid in normalized [-1, 1] coordinates:
        # shape (3, H*W) with rows (x, y, 1).
        x_linear_space = tf.linspace(-1.0, 1.0, resampled_size[1])
        y_linear_space = tf.linspace(-1.0, 1.0, resampled_size[0])

        x_coords, y_coords = tf.meshgrid(x_linear_space, y_linear_space)
        x_coords = K.flatten(x_coords)
        y_coords = K.flatten(y_coords)

        ones = K.ones_like(x_coords)
        regular_grid = K.concatenate([x_coords, y_coords, ones], axis=0)
        regular_grid = K.flatten(regular_grid)
        regular_grid = K.reshape(regular_grid,
                                 (3, resampled_size[0] * resampled_size[1]))
        return(regular_grid)

    def linear_interpolate(self, images, sampled_grids, resampled_size):
        batch_size = K.shape(images)[0]
        height = K.shape(images)[1]
        width = K.shape(images)[2]
        number_of_channels = K.shape(images)[3]

        x = K.cast(K.flatten(sampled_grids[:, 0:1, :]), dtype='float32')
        y = K.cast(K.flatten(sampled_grids[:, 1:2, :]), dtype='float32')

        # Map normalized coordinates from [-1, 1] to pixel coordinates.
        x = 0.5 * (x + 1.0) * K.cast(width, dtype='float32')
        y = 0.5 * (y + 1.0) * K.cast(height, dtype='float32')

        # Integer corners of the interpolation cell.
        x0 = K.cast(x, dtype='int32')
        x1 = x0 + 1
        y0 = K.cast(y, dtype='int32')
        y1 = y0 + 1

        # Clamp corner indices to valid bounds (border replication).
        # Note: requires a statically known spatial shape.
        x_max = int(K.int_shape(images)[2] - 1)
        y_max = int(K.int_shape(images)[1] - 1)

        x0 = K.clip(x0, 0, x_max)
        x1 = K.clip(x1, 0, x_max)
        y0 = K.clip(y0, 0, y_max)
        y1 = K.clip(y1, 0, y_max)

        # Per-batch offsets into the flattened (batch*H*W, channels) image.
        batch_pixels = K.arange(0, batch_size) * (height * width)
        batch_pixels = K.expand_dims(batch_pixels, axis=-1)
        base = K.repeat_elements(
            batch_pixels,
            rep=int(resampled_size[0] * resampled_size[1]), axis=1)
        base = K.flatten(base)

        indices00 = base + y0 * width + x0
        indices01 = base + y1 * width + x0
        indices10 = base + y0 * width + x1
        indices11 = base + y1 * width + x1

        flat_images = K.reshape(images, shape=(-1, number_of_channels))
        flat_images = K.cast(flat_images, dtype='float32')

        pixel_values00 = K.gather(flat_images, indices00)
        pixel_values01 = K.gather(flat_images, indices01)
        pixel_values10 = K.gather(flat_images, indices10)
        pixel_values11 = K.gather(flat_images, indices11)

        x0 = K.cast(x0, dtype='float32')
        x1 = K.cast(x1, dtype='float32')
        y0 = K.cast(y0, dtype='float32')
        y1 = K.cast(y1, dtype='float32')

        # Bilinear weights: each corner weighted by the area of the
        # opposite sub-rectangle.
        weight00 = K.expand_dims(((x1 - x) * (y1 - y)), axis=1)
        weight01 = K.expand_dims(((x1 - x) * (y - y0)), axis=1)
        weight10 = K.expand_dims(((x - x0) * (y1 - y)), axis=1)
        weight11 = K.expand_dims(((x - x0) * (y - y0)), axis=1)

        interpolated_values = (weight00 * pixel_values00 +
                               weight01 * pixel_values01 +
                               weight10 * pixel_values10 +
                               weight11 * pixel_values11)
        return(interpolated_values)
class SpatialTransformer3D(Layer):

    """
    Custom layer for the spatial transformer network.

    Arguments
    ---------
    inputs : list of size 2
        The first element are the images and the second element are the
        weights (the flattened affine transform parameters).

    resampled_size : tuple of length 3
        Size of the resampled output images.

    transform_type : string
        Transform type (default = 'affine').

    interpolator_type : string
        Interpolator type (default = 'linear').

    Returns
    -------
    Keras layer
        A 3-D keras layer

    """

    def __init__(self, resampled_size, transform_type='affine',
                 interpolator_type='linear', **kwargs):

        if K.backend() != 'tensorflow':
            raise ValueError("Tensorflow is required for this STN implementation.")

        if len(resampled_size) != 3:
            raise ValueError("Resampled size must be a vector of length 3 (for 3-D).")

        self.resampled_size = resampled_size
        self.transform_type = transform_type
        self.interpolator_type = interpolator_type

        super(SpatialTransformer3D, self).__init__(**kwargs)

    def get_config(self):
        # Without this, a model containing this layer cannot be saved and
        # reloaded (tf.keras serializes custom layers via get_config()).
        config = {"resampled_size": self.resampled_size,
                  "transform_type": self.transform_type,
                  "interpolator_type": self.interpolator_type}
        base_config = super(SpatialTransformer3D, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def build(self, input_shape):
        super(SpatialTransformer3D, self).build(input_shape)

    def call(self, inputs, mask=None):
        # inputs[0]: images (batch, height, width, depth, channels)
        # inputs[1]: affine parameters, reshapeable to (batch, 3, 4)
        images = inputs[0]
        transform_parameters = inputs[1]

        if self.transform_type == 'affine':
            output = self.affine_transform_images(images, transform_parameters,
                                                  self.resampled_size)
        else:
            raise ValueError("Unsupported transform type.")

        input_shape = [K.shape(images)]
        return(tf.reshape(output, self.compute_output_shape(input_shape)))

    def compute_output_shape(self, input_shape):
        output_shape = (input_shape[0][0], self.resampled_size[0],
                        self.resampled_size[1], self.resampled_size[2],
                        input_shape[0][-1])
        return(output_shape)

    def affine_transform_images(self, images, affine_transform_parameters,
                                resampled_size):
        # Apply per-sample (3, 4) affine matrices to a regular sampling grid
        # and interpolate the input images at the transformed grid points.
        batch_size = K.int_shape(images)[0]
        number_of_channels = K.int_shape(images)[-1]

        # A -1 leading dimension handles an unknown (None) batch size.
        if batch_size is None:
            transform_parameters = K.reshape(affine_transform_parameters,
                                             shape=(-1, 3, 4))
        else:
            transform_parameters = K.reshape(affine_transform_parameters,
                                             shape=(batch_size, 3, 4))

        regular_grid = self.make_regular_grid(resampled_size)
        sampled_grids = K.dot(transform_parameters, regular_grid)

        if self.interpolator_type == 'linear':
            interpolated_image = self.linear_interpolate(images, sampled_grids,
                                                         resampled_size)
        else:
            raise ValueError("Unsupported interpolator type.")

        if batch_size is None:
            new_output_shape = (-1, resampled_size[0], resampled_size[1],
                                resampled_size[2], number_of_channels)
        else:
            new_output_shape = (batch_size, resampled_size[0],
                                resampled_size[1], resampled_size[2],
                                number_of_channels)
        interpolated_image = K.reshape(interpolated_image,
                                       shape=new_output_shape)

        return(interpolated_image)

    def make_regular_grid(self, resampled_size):
        # Homogeneous sampling grid in normalized [-1, 1] coordinates:
        # shape (4, H*W*D) with rows (x, y, z, 1).
        x_linear_space = tf.linspace(-1.0, 1.0, resampled_size[1])
        y_linear_space = tf.linspace(-1.0, 1.0, resampled_size[0])
        z_linear_space = tf.linspace(-1.0, 1.0, resampled_size[2])

        x_coords, y_coords, z_coords = tf.meshgrid(x_linear_space,
                                                   y_linear_space,
                                                   z_linear_space)
        x_coords = K.flatten(x_coords)
        y_coords = K.flatten(y_coords)
        z_coords = K.flatten(z_coords)

        ones = K.ones_like(x_coords)
        regular_grid = K.concatenate([x_coords, y_coords, z_coords, ones],
                                     axis=0)
        regular_grid = K.flatten(regular_grid)
        regular_grid = K.reshape(
            regular_grid,
            (4, resampled_size[0] * resampled_size[1] * resampled_size[2]))
        return(regular_grid)

    def linear_interpolate(self, images, sampled_grids, resampled_size):
        batch_size = K.shape(images)[0]
        height = K.shape(images)[1]
        width = K.shape(images)[2]
        depth = K.shape(images)[3]
        number_of_channels = K.shape(images)[4]

        x = K.cast(K.flatten(sampled_grids[:, 0:1, :]), dtype='float32')
        y = K.cast(K.flatten(sampled_grids[:, 1:2, :]), dtype='float32')
        z = K.cast(K.flatten(sampled_grids[:, 2:3, :]), dtype='float32')

        # Map normalized coordinates from [-1, 1] to voxel coordinates.
        x = 0.5 * (x + 1.0) * K.cast(width, dtype='float32')
        y = 0.5 * (y + 1.0) * K.cast(height, dtype='float32')
        z = 0.5 * (z + 1.0) * K.cast(depth, dtype='float32')

        # Integer corners of the interpolation cell.
        x0 = K.cast(x, dtype='int32')
        x1 = x0 + 1
        y0 = K.cast(y, dtype='int32')
        y1 = y0 + 1
        z0 = K.cast(z, dtype='int32')
        z1 = z0 + 1

        # Clamp corner indices to valid bounds (border replication).
        # Note: requires a statically known spatial shape.
        x_max = int(K.int_shape(images)[2] - 1)
        y_max = int(K.int_shape(images)[1] - 1)
        z_max = int(K.int_shape(images)[3] - 1)

        x0 = K.clip(x0, 0, x_max)
        x1 = K.clip(x1, 0, x_max)
        y0 = K.clip(y0, 0, y_max)
        y1 = K.clip(y1, 0, y_max)
        z0 = K.clip(z0, 0, z_max)
        z1 = K.clip(z1, 0, z_max)

        # Per-batch offsets into the flattened (batch*H*W*D, channels) image.
        batch_pixels = K.arange(0, batch_size) * (height * width * depth)
        batch_pixels = K.expand_dims(batch_pixels, axis=-1)
        base = K.repeat_elements(
            batch_pixels,
            rep=int(resampled_size[0] * resampled_size[1] * resampled_size[2]),
            axis=1)
        base = K.flatten(base)

        # NOTE(review): this index arithmetic addresses the flat volume as if
        # it were laid out (depth, height, width), but K.reshape flattens the
        # (batch, height, width, depth, channels) input with depth varying
        # fastest (y * width * depth + x * depth + z).  Preserved as-is since
        # released model weights depend on current behavior — verify upstream.
        indices000 = base + z0 * (width * height) + y0 * width + x0
        indices001 = base + z1 * (width * height) + y0 * width + x0
        indices010 = base + z0 * (width * height) + y1 * width + x0
        indices011 = base + z1 * (width * height) + y1 * width + x0
        indices100 = base + z0 * (width * height) + y0 * width + x1
        indices101 = base + z1 * (width * height) + y0 * width + x1
        indices110 = base + z0 * (width * height) + y1 * width + x1
        indices111 = base + z1 * (width * height) + y1 * width + x1

        flat_images = K.reshape(images, shape=(-1, number_of_channels))
        flat_images = K.cast(flat_images, dtype='float32')

        pixel_values000 = K.gather(flat_images, indices000)
        pixel_values001 = K.gather(flat_images, indices001)
        pixel_values010 = K.gather(flat_images, indices010)
        pixel_values011 = K.gather(flat_images, indices011)
        pixel_values100 = K.gather(flat_images, indices100)
        pixel_values101 = K.gather(flat_images, indices101)
        pixel_values110 = K.gather(flat_images, indices110)
        pixel_values111 = K.gather(flat_images, indices111)

        x0 = K.cast(x0, dtype='float32')
        x1 = K.cast(x1, dtype='float32')
        y0 = K.cast(y0, dtype='float32')
        y1 = K.cast(y1, dtype='float32')
        z0 = K.cast(z0, dtype='float32')
        z1 = K.cast(z1, dtype='float32')

        # Trilinear weights: each corner weighted by the volume of the
        # opposite sub-box.
        weight000 = K.expand_dims(((x1 - x) * (y1 - y) * (z1 - z)), axis=1)
        weight001 = K.expand_dims(((x1 - x) * (y1 - y) * (z - z0)), axis=1)
        weight010 = K.expand_dims(((x1 - x) * (y - y0) * (z1 - z)), axis=1)
        weight011 = K.expand_dims(((x1 - x) * (y - y0) * (z - z0)), axis=1)
        weight100 = K.expand_dims(((x - x0) * (y1 - y) * (z1 - z)), axis=1)
        weight101 = K.expand_dims(((x - x0) * (y1 - y) * (z - z0)), axis=1)
        weight110 = K.expand_dims(((x - x0) * (y - y0) * (z1 - z)), axis=1)
        weight111 = K.expand_dims(((x - x0) * (y - y0) * (z - z0)), axis=1)

        interpolated_values = (weight000 * pixel_values000 +
                               weight001 * pixel_values001 +
                               weight010 * pixel_values010 +
                               weight011 * pixel_values011 +
                               weight100 * pixel_values100 +
                               weight101 * pixel_values101 +
                               weight110 * pixel_values110 +
                               weight111 * pixel_values111)
        return(interpolated_values)