import ants
import numpy as np
import tensorflow as tf
from tensorflow import keras
def hypermapp3r_segmentation(t1,
                             flair,
                             number_of_monte_carlo_iterations=30,
                             do_preprocessing=True,
                             verbose=False):
    """
    Perform HyperMapp3r (white matter hyperintensities) segmentation described in

    https://pubmed.ncbi.nlm.nih.gov/35088930/

    with models and architecture ported from

    https://github.com/mgoubran/HyperMapp3r

    Additional documentation and attribution resources found at

    https://hypermapp3r.readthedocs.io/en/latest/

    Preprocessing consists of:
       * n4 bias correction and
       * brain extraction

    The input T1 should undergo the same steps.  If the input T1 is the raw
    T1, these steps can be performed by the internal preprocessing, i.e. set
    do_preprocessing = True

    Regardless of do_preprocessing, the T1 and FLAIR intensities inside the
    brain mask are z-scored (zero mean, unit standard deviation) before
    prediction.  The caller's input images are never modified.

    Arguments
    ---------
    t1 : ANTsImage
        input 3-D t1-weighted MR image.  Assumed to be aligned with the flair.

    flair : ANTsImage
        input 3-D flair MR image.  Assumed to be aligned with the t1.

    number_of_monte_carlo_iterations : int
        Number of network predictions averaged into the output (Monte Carlo
        sampling over the network's SpatialDropout).  Must be >= 1.

    do_preprocessing : boolean
        See description above.  If False, the brain mask is taken to be the
        nonzero voxels of the t1.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    ANTsImage : white matter hyperintensity probability image (the mean of the
    Monte Carlo predictions), resampled into the space of the input t1.

    Example
    -------
    >>> probability_image = hypermapp3r_segmentation(t1, flair)
    """

    from ..architectures import create_hypermapp3r_unet_model_3d
    from ..utilities import preprocess_brain_image
    from ..utilities import get_pretrained_network

    if t1.dimension != 3 or flair.dimension != 3:
        raise ValueError( "Image dimension must be 3." )

    if number_of_monte_carlo_iterations < 1:
        raise ValueError("number_of_monte_carlo_iterations must be >= 1.")

    ################################
    #
    # Preprocess images
    #
    ################################

    if verbose:
        print("************* Preprocessing ***************")
        print("")

    if do_preprocessing:
        t1_preprocessing = preprocess_brain_image(t1,
            truncate_intensity=(0.01, 0.99),
            brain_extraction_modality="t1",
            do_bias_correction=True,
            do_denoising=False,
            verbose=verbose)
        brain_mask = t1_preprocessing['brain_mask']
        t1_preprocessed = t1_preprocessing["preprocessed_image"] * brain_mask

        flair_preprocessing = preprocess_brain_image(flair,
            truncate_intensity=(0.01, 0.99),
            brain_extraction_modality=None,
            do_bias_correction=True,
            do_denoising=False,
            verbose=verbose)
        flair_preprocessed = flair_preprocessing["preprocessed_image"] * brain_mask
    else:
        # If we don't generate the mask from the preprocessing, we assume that we
        # can extract the brain directly from the foreground of the t1 image.
        brain_mask = ants.threshold_image(t1, 0, 0, 0, 1)
        # The z-scoring below writes into these images in place, so work on
        # float copies; otherwise the caller's t1/flair would be overwritten.
        t1_preprocessed = ants.image_clone(t1, pixeltype="float")
        flair_preprocessed = ants.image_clone(flair, pixeltype="float")

    # Z-score each modality using only the voxels inside the brain mask.
    brain_indices = brain_mask > 0
    t1_brain = t1_preprocessed[brain_indices]
    t1_preprocessed[brain_indices] = (t1_brain - t1_brain.mean()) / t1_brain.std()
    flair_brain = flair_preprocessed[brain_indices]
    flair_preprocessed[brain_indices] = (flair_brain - flair_brain.mean()) / flair_brain.std()

    if verbose:
        print(" HyperMapp3r: reorient input images.")

    # 1 mm isotropic 224^3 template with its second axis flipped.  The inputs
    # are translated so that the brain mask's center of mass lands on the
    # template's center of mass.
    channel_size = 2
    input_image_size = (224, 224, 224)
    template_array = np.ones(input_image_size)
    template_direction = np.eye(3)
    template_direction[1, 1] = -1.0
    reorient_template = ants.from_numpy(template_array, origin=(0, 0, 0), spacing=(1, 1, 1),
                                        direction=template_direction)

    center_of_mass_template = ants.get_center_of_mass(reorient_template)
    center_of_mass_image = ants.get_center_of_mass(brain_mask)
    translation = np.asarray(center_of_mass_image) - np.asarray(center_of_mass_template)
    xfrm = ants.create_ants_transform(transform_type="Euler3DTransform",
        center=np.asarray(center_of_mass_template), translation=translation)

    # Channel 0: T1, channel 1: FLAIR.
    batchX = np.zeros((1, *input_image_size, channel_size))
    t1_preprocessed_warped = ants.apply_ants_transform_to_image(xfrm, t1_preprocessed, reorient_template)
    batchX[0,:,:,:,0] = t1_preprocessed_warped.numpy()
    flair_preprocessed_warped = ants.apply_ants_transform_to_image(xfrm, flair_preprocessed, reorient_template)
    batchX[0,:,:,:,1] = flair_preprocessed_warped.numpy()

    if verbose:
        print(" HyperMapp3r: generate network and load weights.")

    model = create_hypermapp3r_unet_model_3d((*input_image_size, channel_size))
    weights_file_name = get_pretrained_network("hyperMapp3r")
    model.load_weights(weights_file_name)

    if verbose:
        print(" HyperMapp3r: prediction.")
        print(" HyperMapp3r: Monte Carlo iterations (SpatialDropout).")

    # Running mean: after iteration i, prediction_array holds the average of
    # the first i + 1 predictions.
    prediction_array = np.zeros(input_image_size)
    for i in range(number_of_monte_carlo_iterations):
        if verbose:
            print(" Monte Carlo iteration", i + 1, "out of", number_of_monte_carlo_iterations)
        prediction_array = (np.squeeze(model.predict(batchX, verbose=verbose)) + i * prediction_array) / (i + 1)

    prediction_image = ants.from_numpy(prediction_array, origin=reorient_template.origin,
        spacing=reorient_template.spacing, direction=reorient_template.direction)

    # Undo the reorientation so the probabilities live in the input t1 space.
    xfrm_inv = xfrm.invert()
    probability_image = xfrm_inv.apply_to_image(prediction_image, t1)

    return probability_image
def wmh_segmentation(flair,
                     t1,
                     white_matter_mask=None,
                     use_combined_model=True,
                     prediction_batch_size=16,
                     patch_stride_length=32,
                     do_preprocessing=True,
                     verbose=False):
    """
    Perform White matter hyperintensity probabilistic segmentation
    given a pre-aligned FLAIR and T1 images.  Note that the underlying
    model is 3-D and requires images to be at least 64 voxels in each
    dimension.

    Preprocessing on the training data consisted of:
       * n4 bias correction,
       * brain extraction

    The input T1 should undergo the same steps.  If the input T1 is the raw
    T1, these steps can be performed by the internal preprocessing, i.e. set
    do_preprocessing = True

    In all cases, the T1 and FLAIR are min-max rescaled using the intensity
    range found inside the white matter mask before patch extraction.

    Arguments
    ---------
    flair : ANTsImage
        input 3-D FLAIR brain image (not skull-stripped).

    t1 : ANTsImage
        input 3-D T1 brain image (not skull-stripped).

    white_matter_mask : ANTsImage
        input white matter mask for patch extraction.  If None, calculated using
        deep_atropos (labels 3 and 4).

    use_combined_model : boolean
        If True, use the combined model ("antsxnetWmhOr" weights); otherwise
        use the original model ("antsxnetWmh" weights).

    prediction_batch_size : int
        Number of patches passed to the network per predict call.  Controls
        memory usage; more consequential for GPU-usage.

    patch_stride_length : 3-D tuple or int
        Stride used both to extract the 64^3 patches and to reassemble the
        patch predictions.  An int is applied to all three dimensions.

    do_preprocessing : boolean
        perform n4 bias correction, intensity truncation (T1 only), and brain
        extraction (computed from the T1 and applied to both images).

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    ANTsImage : WMH segmentation probability image, reconstructed from the
    overlapping patch predictions over the domain of the white matter mask.

    Example
    -------
    >>> flair = ants.image_read("flair.nii.gz")
    >>> t1 = ants.image_read("t1.nii.gz")
    >>> probability_mask = wmh_segmentation(flair, t1)
    """

    from ..architectures import create_sysu_media_unet_model_3d
    from ..utilities import deep_atropos
    from ..utilities import get_pretrained_network
    from ..utilities import preprocess_brain_image

    patch_size = (64, 64, 64)

    # Every dimension must fit at least one full patch.
    if np.any(np.asarray(t1.shape) < np.asarray(patch_size)):
        raise ValueError("Images must be >= 64 voxels per dimension.")

    ################################
    #
    # Preprocess images
    #
    ################################

    if white_matter_mask is None:
        if verbose:
            print("Calculate white matter mask.")
        atropos = deep_atropos(t1, do_preprocessing=True, verbose=verbose)
        # Labels 3 and 4 of the deep_atropos segmentation form the mask.
        white_matter_mask = ants.threshold_image(atropos['segmentation_image'], 3, 4, 1, 0)

    if do_preprocessing:
        if verbose:
            print("Preprocess T1 and FLAIR images.")
        t1_preprocessing = preprocess_brain_image(t1,
            truncate_intensity=(0.01, 0.995),
            brain_extraction_modality="t1",
            do_bias_correction=True,
            do_denoising=False,
            verbose=verbose)
        brain_mask = ants.threshold_image(t1_preprocessing["brain_mask"], 0.5, 1, 1, 0)
        t1_preprocessed = t1_preprocessing["preprocessed_image"] * brain_mask
        flair_preprocessing = preprocess_brain_image(flair,
            truncate_intensity=None,
            brain_extraction_modality=None,
            do_bias_correction=True,
            do_denoising=False,
            verbose=verbose)
        flair_preprocessed = flair_preprocessing["preprocessed_image"] * brain_mask
    else:
        t1_preprocessed = ants.image_clone(t1)
        flair_preprocessed = ants.image_clone(flair)

    # Min-max rescale each modality with the range observed inside the white
    # matter mask (voxels outside the mask may fall outside [0, 1]).
    white_matter_indices = white_matter_mask > 0
    t1_preprocessed_min = t1_preprocessed[white_matter_indices].min()
    t1_preprocessed_max = t1_preprocessed[white_matter_indices].max()
    flair_preprocessed_min = flair_preprocessed[white_matter_indices].min()
    flair_preprocessed_max = flair_preprocessed[white_matter_indices].max()
    t1_preprocessed = (t1_preprocessed - t1_preprocessed_min) / (t1_preprocessed_max - t1_preprocessed_min)
    flair_preprocessed = (flair_preprocessed - flair_preprocessed_min) / (flair_preprocessed_max - flair_preprocessed_min)

    ################################
    #
    # Build model and load weights
    #
    ################################

    if verbose:
        print("Load model and weights.")

    if isinstance(patch_stride_length, int):
        patch_stride_length = (patch_stride_length,) * 3

    number_of_filters = (64, 96, 128, 256, 512)
    channel_size = 2

    model = create_sysu_media_unet_model_3d((*patch_size, channel_size),
                                            number_of_filters=number_of_filters)

    if use_combined_model:
        weights_file_name = get_pretrained_network("antsxnetWmhOr")
    else:
        weights_file_name = get_pretrained_network("antsxnetWmh")
    model.load_weights(weights_file_name)

    ################################
    #
    # Extract patches
    #
    ################################

    if verbose:
        print("Extract patches.")

    t1_patches = ants.extract_image_patches(t1_preprocessed,
        patch_size=patch_size,
        max_number_of_patches="all",
        stride_length=patch_stride_length,
        mask_image=white_matter_mask,
        random_seed=None,
        return_as_array=True)
    flair_patches = ants.extract_image_patches(flair_preprocessed,
        patch_size=patch_size,
        max_number_of_patches="all",
        stride_length=patch_stride_length,
        mask_image=white_matter_mask,
        random_seed=None,
        return_as_array=True)

    total_number_of_patches = t1_patches.shape[0]

    ################################
    #
    # Do prediction and then restack into the image
    #
    ################################

    # Ceiling division: the last batch holds any leftover patches.
    number_of_batches = (total_number_of_patches + prediction_batch_size - 1) // prediction_batch_size

    if verbose:
        print("Total number of patches: ", str(total_number_of_patches))
        print("Prediction batch size: ", str(prediction_batch_size))
        print("Number of batches: ", str(number_of_batches))

    prediction = np.zeros((total_number_of_patches, *patch_size, 1))
    for b in range(number_of_batches):
        start = b * prediction_batch_size
        end = min(start + prediction_batch_size, total_number_of_patches)

        # Channel 0: FLAIR, channel 1: T1.
        batchX = np.zeros((end - start, *patch_size, channel_size))
        batchX[:,:,:,:,0] = flair_patches[start:end,:,:,:]
        batchX[:,:,:,:,1] = t1_patches[start:end,:,:,:]

        if verbose:
            print("Predicting batch ", str(b + 1), " of ", str(number_of_batches))
        prediction[start:end,:,:,:,:] = model.predict(batchX, verbose=verbose)

    if verbose:
        print("Predict patches and reconstruct.")

    # Drop only the trailing channel axis; np.squeeze would also collapse the
    # patch axis when exactly one patch was extracted.
    wmh_probability_image = ants.reconstruct_image_from_patches(prediction[..., 0],
        stride_length=patch_stride_length,
        domain_image=white_matter_mask,
        domain_image_is_mask=True)

    return wmh_probability_image
def shiva_pvs_segmentation(t1,
                           flair=None,
                           which_model="all",
                           do_preprocessing=True,
                           verbose=False):
    """
    Perform segmentation of perivascular (PVS) or Virchow-Robin spaces (VRS)
    described in

    https://pubmed.ncbi.nlm.nih.gov/34262443/

    with the original implementation available here:

    https://github.com/pboutinaud/SHIVA_PVS

    Arguments
    ---------
    t1 : ANTsImage
        input 3-D T1 brain image (not skull-stripped).

    flair : ANTsImage
        (Optional) input 3-D FLAIR brain image (not skull-stripped) aligned to the T1 image.

    which_model : integer or string
        Several models were trained for the case of T1-only or T1/FLAIR image
        pairs.  One can use a specific single trained model or the average of
        the entire ensemble.  I.e., options are:
            * For T1-only: 0, 1, 2, 3, 4, 5.
            * For T1/FLAIR: 0, 1, 2, 3, 4.
            * Or "all" for using the entire ensemble.

    do_preprocessing : boolean
        perform n4 bias correction, intensity truncation, brain extraction
        (from the T1), and [0, 1] intensity normalization.  If False, the
        brain mask used for centering is the nonzero voxels of the t1.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    PVS or VRS segmentation probability image in the space of the input t1.

    Example
    -------
    >>> t1 = ants.image_read("t1.nii.gz")
    >>> probability_mask = shiva_pvs_segmentation(t1)
    """

    from ..utilities import get_pretrained_network
    from ..utilities import preprocess_brain_image
    from ..architectures import create_shiva_unet_model_3d

    ################################
    #
    # Preprocess images
    #
    ################################

    t1_preprocessed = None
    flair_preprocessed = None
    brain_mask = None
    if do_preprocessing:
        if verbose:
            print("Preprocess image(s).")
        t1_preprocessing = preprocess_brain_image(t1,
            truncate_intensity=(0.0, 0.99),
            brain_extraction_modality="t1",
            do_bias_correction=True,
            do_denoising=False,
            intensity_normalization_type="01",
            verbose=verbose)
        brain_mask = ants.threshold_image(t1_preprocessing["brain_mask"], 0.5, 1, 1, 0)
        t1_preprocessed = t1_preprocessing["preprocessed_image"] * brain_mask
        if flair is not None:
            flair_preprocessing = preprocess_brain_image(flair,
                truncate_intensity=(0.0, 0.99),
                brain_extraction_modality=None,
                do_bias_correction=True,
                do_denoising=False,
                intensity_normalization_type="01",
                verbose=verbose)
            flair_preprocessed = flair_preprocessing["preprocessed_image"] * brain_mask
    else:
        t1_preprocessed = ants.image_clone(t1)
        if flair is not None:
            flair_preprocessed = ants.image_clone(flair)
        brain_mask = ants.threshold_image(t1, 0, 0, 0, 1)

    # Translate the images (by a rounded offset) so that the brain mask's
    # center of mass coincides with the center of mass of a 1 mm
    # 160x214x176 template.
    image_shape = (160, 214, 176)
    reorient_template = ants.from_numpy(np.ones(image_shape), origin=(0, 0, 0),
        spacing=(1, 1, 1), direction=np.eye(3))
    center_of_mass_template = ants.get_center_of_mass(reorient_template)
    center_of_mass_image = ants.get_center_of_mass(brain_mask)
    translation = np.round(np.asarray(center_of_mass_image) - np.asarray(center_of_mass_template))
    xfrm = ants.create_ants_transform(transform_type="Euler3DTransform",
        center=np.round(np.asarray(center_of_mass_template)), translation=translation)
    t1_preprocessed = ants.apply_ants_transform_to_image(xfrm, t1_preprocessed, reorient_template)
    if flair is not None:
        flair_preprocessed = ants.apply_ants_transform_to_image(xfrm, flair_preprocessed, reorient_template)

    ################################
    #
    # Load models and predict
    #
    ################################

    # The T1-only and T1/FLAIR ensembles differ only in their input channels,
    # weight-file prefix, and number of members.
    if flair is None:
        input_images = [t1_preprocessed]
        weights_prefix = "pvs_shiva_t1_"
        ensemble_model_ids = [0, 1, 2, 3, 4, 5]
    else:
        input_images = [t1_preprocessed, flair_preprocessed]
        weights_prefix = "pvs_shiva_t1_flair_"
        ensemble_model_ids = [0, 1, 2, 3, 4]
    number_of_modalities = len(input_images)

    # Channel 0: T1, channel 1 (if present): FLAIR.
    batchX = np.zeros((1, *image_shape, number_of_modalities))
    for channel, channel_image in enumerate(input_images):
        batchX[0,:,:,:,channel] = channel_image.numpy()

    model_ids = ensemble_model_ids if which_model == "all" else [which_model]

    # Every ensemble member uses the same architecture, so build the network
    # once and only swap in each member's weights.
    model = create_shiva_unet_model_3d(number_of_modalities=number_of_modalities)

    batchY = None
    for model_id in model_ids:
        model_weights_file = get_pretrained_network(weights_prefix + str(model_id))
        if verbose:
            print("Loading", model_weights_file)
        model.load_weights(model_weights_file)
        member_prediction = model.predict(batchX, verbose=verbose)
        batchY = member_prediction if batchY is None else batchY + member_prediction
    batchY /= len(model_ids)

    # Map the ensemble-averaged prediction back into the input t1 space.
    pvs = ants.from_numpy(np.squeeze(batchY), origin=reorient_template.origin,
        spacing=reorient_template.spacing,
        direction=reorient_template.direction)
    pvs = ants.apply_ants_transform_to_image(xfrm.invert(), pvs, t1)

    return pvs
def shiva_wmh_segmentation(flair,
                           t1=None,
                           which_model="all",
                           do_preprocessing=True,
                           verbose=False):
    """
    Perform segmentation of white matter hyperintensities (WMH) described in

    https://pubmed.ncbi.nlm.nih.gov/38050769/

    with the original implementation available here:

    https://github.com/pboutinaud/SHIVA_WMH

    Arguments
    ---------
    flair : ANTsImage
        input 3-D FLAIR brain image (not skull-stripped).

    t1 : ANTsImage
        (optional) input 3-D T1 brain image (not skull-stripped) aligned to
        the FLAIR image.

    which_model : integer or string
        Several models were trained for the case of FLAIR-only or T1/FLAIR
        image pairs.  One can use a specific single trained model or the
        average of the entire ensemble.  I.e., options are:
            * For FLAIR-only: 0, 1, 2, 3, 4.
            * For T1/FLAIR: 0, 1, 2, 3, 4.
            * Or "all" for using the entire ensemble.

    do_preprocessing : boolean
        perform n4 bias correction, intensity truncation, brain extraction
        (from the FLAIR), and [0, 1] intensity normalization.  If False, the
        brain mask used for centering is the nonzero voxels of the flair.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    WMH segmentation probability image in the space of the input flair.

    Example
    -------
    >>> image = ants.image_read("flair.nii.gz")
    >>> probability_mask = shiva_wmh_segmentation(image)
    """

    from ..utilities import get_pretrained_network
    from ..utilities import preprocess_brain_image
    from ..architectures import create_shiva_unet_model_3d

    ################################
    #
    # Preprocess images
    #
    ################################

    t1_preprocessed = None
    flair_preprocessed = None
    brain_mask = None
    if do_preprocessing:
        if verbose:
            print("Preprocess image(s).")
        flair_preprocessing = preprocess_brain_image(flair,
            truncate_intensity=(0.0, 0.99),
            brain_extraction_modality="flair",
            do_bias_correction=True,
            do_denoising=False,
            intensity_normalization_type="01",
            verbose=verbose)
        brain_mask = ants.threshold_image(flair_preprocessing["brain_mask"], 0.5, 1, 1, 0)
        flair_preprocessed = flair_preprocessing["preprocessed_image"] * brain_mask
        if t1 is not None:
            t1_preprocessing = preprocess_brain_image(t1,
                truncate_intensity=(0.0, 0.99),
                brain_extraction_modality=None,
                do_bias_correction=True,
                do_denoising=False,
                intensity_normalization_type="01",
                verbose=verbose)
            t1_preprocessed = t1_preprocessing["preprocessed_image"] * brain_mask
    else:
        flair_preprocessed = ants.image_clone(flair)
        if t1 is not None:
            t1_preprocessed = ants.image_clone(t1)
        brain_mask = ants.threshold_image(flair, 0, 0, 0, 1)

    # Translate the images (by a rounded offset) so that the brain mask's
    # center of mass coincides with the center of mass of a 1 mm
    # 160x214x176 template.
    image_shape = (160, 214, 176)
    reorient_template = ants.from_numpy(np.ones(image_shape), origin=(0, 0, 0),
        spacing=(1, 1, 1), direction=np.eye(3))
    center_of_mass_template = ants.get_center_of_mass(reorient_template)
    center_of_mass_image = ants.get_center_of_mass(brain_mask)
    translation = np.round(np.asarray(center_of_mass_image) - np.asarray(center_of_mass_template))
    xfrm = ants.create_ants_transform(transform_type="Euler3DTransform",
        center=np.round(np.asarray(center_of_mass_template)), translation=translation)
    flair_preprocessed = ants.apply_ants_transform_to_image(xfrm, flair_preprocessed, reorient_template)
    if t1 is not None:
        t1_preprocessed = ants.apply_ants_transform_to_image(xfrm, t1_preprocessed, reorient_template)

    ################################
    #
    # Load models and predict
    #
    ################################

    # The FLAIR-only and T1/FLAIR ensembles differ only in their input
    # channels and weight-file prefix.
    if t1 is None:
        input_images = [flair_preprocessed]
        weights_prefix = "wmh_shiva_flair_"
    else:
        # Channel 0: T1, channel 1: FLAIR.
        input_images = [t1_preprocessed, flair_preprocessed]
        weights_prefix = "wmh_shiva_t1_flair_"
    ensemble_model_ids = [0, 1, 2, 3, 4]
    number_of_modalities = len(input_images)

    batchX = np.zeros((1, *image_shape, number_of_modalities))
    for channel, channel_image in enumerate(input_images):
        batchX[0,:,:,:,channel] = channel_image.numpy()

    model_ids = ensemble_model_ids if which_model == "all" else [which_model]

    # Every ensemble member uses the same architecture, so build the network
    # once and only swap in each member's weights.
    model = create_shiva_unet_model_3d(number_of_modalities=number_of_modalities)

    batchY = None
    for model_id in model_ids:
        model_weights_file = get_pretrained_network(weights_prefix + str(model_id))
        if verbose:
            print("Loading", model_weights_file)
        model.load_weights(model_weights_file)
        member_prediction = model.predict(batchX, verbose=verbose)
        batchY = member_prediction if batchY is None else batchY + member_prediction
    batchY /= len(model_ids)

    # Map the ensemble-averaged prediction back into the input flair space.
    wmh = ants.from_numpy(np.squeeze(batchY), origin=reorient_template.origin,
        spacing=reorient_template.spacing,
        direction=reorient_template.direction)
    wmh = ants.apply_ants_transform_to_image(xfrm.invert(), wmh, flair)

    return wmh