Source code for antspynet.utilities.unet_utilities
import numpy as np
import ants
[docs]def encode_unet(segmentations_array,
segmentation_labels=None):
"""
Basic one-hot transformation of segmentations array
Arguments
---------
segmentations_array : numpy array
multi-label numpy array
segmentation_labels : tuple or list
Note that a background label (typically 0) needs to be included.
Returns
-------
An n-d array of shape batch_size x width x height x <depth> x number_of_segmentation_labels
Example
-------
>>> import ants
>>> image = ants.image_read(ants.get_ants_data('r16'))
>>> seg = ants.kmeans_segmentation(image, 3)['segmentation']
>>> one_hot = encode_unet(seg.numpy().astype('int'))
"""
if segmentation_labels is None:
segmentation_labels = np.unique(segmentations_array)
number_of_labels = len(segmentation_labels)
dim_segmentations = segmentations_array.shape
image_dimension = 2
if len(dim_segmentations) == 4:
image_dimension = 3
if number_of_labels < 2:
raise ValueError("At least two segmentation labels need to be specified.")
one_hot_array = np.zeros((*dim_segmentations, number_of_labels))
for i in range(number_of_labels):
per_label = np.zeros_like(segmentations_array)
per_label[segmentations_array == segmentation_labels[i]] = 1
if image_dimension == 2:
one_hot_array[:,:,:,i] = per_label
else:
one_hot_array[:,:,:,:,i] = per_label
return one_hot_array
[docs]def decode_unet(y_predicted,
domain_image):
"""
Decoding function for the u-net prediction outcome
Arguments
---------
y_predicted : an array
Shape batch_size x width x height x <depth> x number_of_segmentation_labels
domain_image : ANTs image
Defines the geometry of the returned probability images
Returns
-------
List of probability images.
Example
-------
>>> import ants
>>> image = ants.image_read(ants.get_ants_data('r16'))
"""
batch_size = y_predicted.shape[0]
number_of_labels = y_predicted.shape[-1]
image_dimension = 2
if len(y_predicted.shape) == 5:
image_dimension = 3
batch_probability_images = list()
for i in range(batch_size):
probability_images = list()
for j in range(number_of_labels):
if image_dimension == 2:
image_array = np.squeeze(y_predicted[i,:,:,j])
else:
image_array = np.squeeze(y_predicted[i,:,:,:,j])
probability_images.append(ants.from_numpy(image_array,
origin=domain_image.origin, spacing=domain_image.spacing,
direction=domain_image.direction))
batch_probability_images.append(probability_images)
return batch_probability_images