import numpy as np
import tensorflow as tf
import ants
def _refine_stage_crop_bounds(probability_array, crop_shape):
    """Return the (lower, upper) voxel indices of the refine-stage crop window.

    The window has size ``crop_shape`` and is positioned around the centroid
    of the initial-stage foreground, then shifted (if needed) so that it lies
    entirely inside ``probability_array``.
    """
    # Voxels where the network output is exactly 1 are used first so existing
    # results are unchanged; if there are none (the output is floating point),
    # fall back to the usual 0.5 probability threshold instead of taking the
    # mean of an empty index set (NaN).
    foreground = np.where(probability_array == 1)
    if foreground[0].size == 0:
        foreground = np.where(probability_array >= 0.5)
    if foreground[0].size == 0:
        raise ValueError("Error: initial stage segmentation is empty; "
                         "cannot locate the hippocampus.")

    centroid = np.array([axis_indices.mean() for axis_indices in foreground])
    window = np.array(crop_shape)

    lower = (np.floor(centroid - 0.5 * window) - 1).astype(int)

    # Keep the window inside the volume so both the crop and the later
    # paste-back slice assignment use valid, full-size index ranges.
    largest_lower = np.maximum(np.array(probability_array.shape) - window, 0)
    lower = np.clip(lower, 0, largest_lower).astype(int)
    upper = (lower + window).astype(int)

    return lower, upper


def hippmapp3r_segmentation(t1,
                            do_preprocessing=True,
                            antsxnet_cache_directory=None,
                            verbose=False):
    """
    Perform HippMapp3r (hippocampal) segmentation described in

    https://www.ncbi.nlm.nih.gov/pubmed/31609046

    with models and architecture ported from

    https://github.com/mgoubran/HippMapp3r

    Additional documentation and attribution resources found at

    https://hippmapp3r.readthedocs.io/en/latest/

    Preprocessing consists of:

       * n4 bias correction and
       * brain extraction

    The input T1 should undergo the same steps.  If the input T1 is the raw
    T1, these steps can be performed by the internal preprocessing, i.e. set
    do_preprocessing = True

    Arguments
    ---------
    t1 : ANTsImage
        Input 3-D image.

    do_preprocessing : boolean
        See description above.

    antsxnet_cache_directory : string
        Destination directory for storing the downloaded template and model
        weights so that they can be reused.  If None, these data are
        downloaded to ~/.keras/ANTsXNet/.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    ANTs labeled hippocampal image (labels 1 and 2) in the space of ``t1``.

    Raises
    ------
    ValueError
        If ``t1`` is not 3-D, if the initial stage finds no foreground voxels,
        or if fewer than two clusters are found after the refine stage.

    Example
    -------
    >>> mask = hippmapp3r_segmentation(t1)
    """

    from ..architectures import create_hippmapp3r_unet_model_3d
    from ..utilities import preprocess_brain_image
    from ..utilities import get_pretrained_network
    from ..utilities import get_antsxnet_data

    if t1.dimension != 3:
        raise ValueError( "Image dimension must be 3." )

    #########################################
    #
    # Preprocessing
    #

    if verbose:
        print("************* Preprocessing ***************")
        print("")

    t1_preprocessed = t1
    if do_preprocessing:
        t1_preprocessing = preprocess_brain_image(t1,
            truncate_intensity=None,
            brain_extraction_modality="t1",
            template=None,
            do_bias_correction=True,
            do_denoising=False,
            antsxnet_cache_directory=antsxnet_cache_directory,
            verbose=verbose)
        # Apply the brain mask so only brain tissue is passed on.
        t1_preprocessed = t1_preprocessing["preprocessed_image"] * t1_preprocessing['brain_mask']

    #########################################
    #
    # Perform initial (stage 1) segmentation
    #

    if verbose:
        print("************* Initial stage segmentation ***************")
        print("")

    # Normalize to mprage_hippmapp3r space
    if verbose:
        print(" HippMapp3r: template normalization.")

    template_file_name_path = get_antsxnet_data("mprage_hippmapp3r", antsxnet_cache_directory=antsxnet_cache_directory)
    template_image = ants.image_read(template_file_name_path)

    registration = ants.registration(fixed=template_image, moving=t1_preprocessed,
        type_of_transform="antsRegistrationSyNQuickRepro[t]", verbose=verbose)
    image = registration['warpedmovout']
    # Only the inverse transforms are needed, to map the result back to t1.
    inverse_transforms = registration['invtransforms']

    # Threshold at 10th percentile of non-zero voxels in "robust range (fslmaths)"
    if verbose:
        print(" HippMapp3r: threshold.")

    image_array = image.numpy()
    image_robust_range = np.quantile(image_array[np.where(image_array != 0)], (0.02, 0.98))
    threshold_value = 0.10 * (image_robust_range[1] - image_robust_range[0]) + image_robust_range[0]
    # Voxels in [-10000, threshold_value] get 0, everything else gets 1.
    thresholded_mask = ants.threshold_image(image, -10000, threshold_value, 0, 1)
    thresholded_image = image * thresholded_mask

    # Standardize image
    if verbose:
        print(" HippMapp3r: standardize.")

    mean_image = np.mean(thresholded_image[thresholded_mask==1])
    sd_image = np.std(thresholded_image[thresholded_mask==1])
    image_normalized = (image - mean_image) / sd_image
    image_normalized = image_normalized * thresholded_mask

    # Trim and resample image
    if verbose:
        print(" HippMapp3r: trim and resample to (160, 160, 128).")

    image_cropped = ants.crop_image(image_normalized, thresholded_mask, 1)
    shape_initial_stage = (160, 160, 128)
    image_resampled = ants.resample_image(image_cropped, shape_initial_stage, use_voxels=True, interp_type=1)

    if verbose:
        print(" HippMapp3r: generate first network and download weights.")

    model_initial_stage = create_hippmapp3r_unet_model_3d((*shape_initial_stage, 1), do_first_network=True)

    initial_stage_weights_file_name = get_pretrained_network("hippMapp3rInitial", antsxnet_cache_directory=antsxnet_cache_directory)
    model_initial_stage.load_weights(initial_stage_weights_file_name)

    if verbose:
        print(" HippMapp3r: prediction.")

    # Add batch (first) and channel (last) axes for the network input.
    data_initial_stage = np.expand_dims(image_resampled.numpy(), axis=0)
    data_initial_stage = np.expand_dims(data_initial_stage, axis=-1)
    mask_array = model_initial_stage.predict(data_initial_stage, verbose=verbose)

    #########################################
    #
    # Perform refined (stage 2) segmentation
    #

    if verbose:
        print("")
        print("")
        print("************* Refine stage segmentation ***************")
        print("")

    mask_array = np.squeeze(mask_array)

    # Crop a fixed-size window around the initial-stage foreground centroid.
    shape_refine_stage = (112, 112, 64)
    lower, upper = _refine_stage_crop_bounds(mask_array, shape_refine_stage)

    image_trimmed = ants.crop_indices(image_resampled, lower, upper)

    if verbose:
        print(" HippMapp3r: generate second network and download weights.")

    model_refine_stage = create_hippmapp3r_unet_model_3d((*shape_refine_stage, 1), do_first_network=False)

    refine_stage_weights_file_name = get_pretrained_network("hippMapp3rRefine", antsxnet_cache_directory=antsxnet_cache_directory)
    model_refine_stage.load_weights(refine_stage_weights_file_name)

    data_refine_stage = np.expand_dims(image_trimmed.numpy(), axis=0)
    data_refine_stage = np.expand_dims(data_refine_stage, axis=-1)

    if verbose:
        print(" HippMapp3r: Monte Carlo iterations (SpatialDropout).")

    number_of_mci_iterations = 30
    prediction_refine_stage = np.zeros(shape_refine_stage)
    for i in range(number_of_mci_iterations):
        # NOTE: this sets TensorFlow's global seed, once per iteration.
        tf.random.set_seed(i)
        if verbose:
            print(" Monte Carlo iteration", i + 1, "out of", number_of_mci_iterations)
        # Running mean of the predictions seen so far.
        prediction_refine_stage = \
            (np.squeeze(model_refine_stage.predict(data_refine_stage, verbose=verbose)) + \
             i * prediction_refine_stage ) / (i + 1)

    # Paste the refined probabilities back into the full resampled volume.
    prediction_refine_stage_array = np.zeros(image_resampled.shape)
    prediction_refine_stage_array[lower[0]:upper[0],
                                  lower[1]:upper[1],
                                  lower[2]:upper[2]] = prediction_refine_stage
    probability_mask_refine_stage_resampled = ants.copy_image_info(image_resampled, ants.from_numpy(prediction_refine_stage_array))

    # Probabilities in [0, 0.5] become background; the remainder is split
    # into connected clusters of which only labels 1 and 2 are kept.
    segmentation_image_resampled = ants.label_clusters(
        ants.threshold_image(probability_mask_refine_stage_resampled, 0.0, 0.5, 0, 1), min_cluster_size=10)
    segmentation_image_resampled[segmentation_image_resampled > 2] = 0

    geom = ants.label_geometry_measures(segmentation_image_resampled)
    if len(geom['VolumeInMillimeters']) < 2:
        raise ValueError("Error: left and right hippocampus not found.")

    # Swap labels 1 and 2 (via temporary label 3) when the first cluster's
    # x-centroid is smaller than the second's, so labels are ordered by side.
    if geom['Centroid_x'][0] < geom['Centroid_x'][1]:
        segmentation_image_resampled[segmentation_image_resampled == 1] = 3
        segmentation_image_resampled[segmentation_image_resampled == 2] = 1
        segmentation_image_resampled[segmentation_image_resampled == 3] = 2

    # Map the labels back to the space of the input image.
    segmentation_image = ants.apply_transforms(fixed=t1,
        moving=segmentation_image_resampled, transformlist=inverse_transforms,
        whichtoinvert=[True], interpolator="genericLabel", verbose=verbose)

    return segmentation_image