Source code for antspynet.utilities.white_matter_hyperintensity_segmentation


import ants
import numpy as np
from tensorflow import keras

[docs]def sysu_media_wmh_segmentation(flair, t1=None, use_ensemble=True, antsxnet_cache_directory=None, verbose=False): """ Perform WMH segmentation using the winning submission in the MICCAI 2017 challenge by the sysu_media team using FLAIR or T1/FLAIR. The MICCAI challenge is discussed in https://pubmed.ncbi.nlm.nih.gov/30908194/ with the sysu_media's team entry is discussed in https://pubmed.ncbi.nlm.nih.gov/30125711/ with the original implementation available here: https://github.com/hongweilibran/wmh_ibbmTum The original implementation used global thresholding as a quick brain extraction approach. Due to possible generalization difficulties, we leave such post-processing steps to the user. For brain or white matter masking see functions brain_extraction or deep_atropos, respectively. Arguments --------- flair : ANTsImage input 3-D FLAIR brain image (not skull-stripped). t1 : ANTsImage input 3-D T1 brain image (not skull-stripped). use_ensemble : boolean check whether to use all 3 sets of weights. antsxnet_cache_directory : string Destination directory for storing the downloaded template and model weights. Since these can be reused, if is None, these data will be downloaded to a ~/.keras/ANTsXNet/. verbose : boolean Print progress to the screen. Returns ------- WMH segmentation probability image Example ------- >>> image = ants.image_read("flair.nii.gz") >>> probability_mask = sysu_media_wmh_segmentation(image) """ from ..architectures import create_sysu_media_unet_model_2d from ..utilities import get_pretrained_network from ..utilities import pad_or_crop_image_to_size from ..utilities import preprocess_brain_image from ..utilities import binary_dice_coefficient if flair.dimension != 3: raise ValueError( "Image dimension must be 3." ) image_size = (200, 200) ################################ # # Preprocess images # ################################ def closest_simplified_direction_matrix(direction): closest = (np.abs(direction) + 0.5).astype(int).astype(float) closest[direction < 0] *= -1.0 return closest simplified_direction = closest_simplified_direction_matrix(flair.direction) flair_preprocessing = preprocess_brain_image(flair, truncate_intensity=None, brain_extraction_modality=None, do_bias_correction=False, do_denoising=False, antsxnet_cache_directory=antsxnet_cache_directory, verbose=verbose) flair_preprocessed = flair_preprocessing["preprocessed_image"] flair_preprocessed.set_direction(simplified_direction) flair_preprocessed.set_origin((0, 0, 0)) flair_preprocessed.set_spacing((1, 1, 1)) number_of_channels = 1 t1_preprocessed = None if t1 is not None: t1_preprocessing = preprocess_brain_image(t1, truncate_intensity=None, brain_extraction_modality=None, do_bias_correction=False, do_denoising=False, antsxnet_cache_directory=antsxnet_cache_directory, verbose=verbose) t1_preprocessed = t1_preprocessing["preprocessed_image"] t1_preprocessed.set_direction(simplified_direction) t1_preprocessed.set_origin((0, 0, 0)) t1_preprocessed.set_spacing((1, 1, 1)) number_of_channels = 2 ################################ # # Reorient images # ################################ reference_image = ants.make_image((256, 256, 256), voxval=0, spacing=(1, 1, 1), origin=(0, 0, 0), direction=np.identity(3)) center_of_mass_reference = np.floor(ants.get_center_of_mass(reference_image * 0 + 1)) center_of_mass_image = np.floor(ants.get_center_of_mass(flair_preprocessed)) translation = np.asarray(center_of_mass_image) - np.asarray(center_of_mass_reference) xfrm = ants.create_ants_transform(transform_type="Euler3DTransform", center=np.asarray(center_of_mass_reference), translation=translation) flair_preprocessed_warped = ants.apply_ants_transform_to_image( xfrm, flair_preprocessed, reference_image, interpolation="nearestneighbor") crop_image = ants.image_clone(flair_preprocessed) * 0 + 1 crop_image_warped = ants.apply_ants_transform_to_image( xfrm, crop_image, reference_image, interpolation="nearestneighbor") flair_preprocessed_warped = ants.crop_image(flair_preprocessed_warped, crop_image_warped, 1) if t1 is not None: t1_preprocessed_warped = ants.apply_ants_transform_to_image( xfrm, t1_preprocessed, reference_image, interpolation="nearestneighbor") t1_preprocessed_warped = ants.crop_image(t1_preprocessed_warped, crop_image_warped, 1) ################################ # # Gaussian normalize intensity # ################################ mean_flair = flair_preprocessed.mean() std_flair = flair_preprocessed.std() if number_of_channels == 2: mean_t1 = t1_preprocessed.mean() std_t1 = t1_preprocessed.std() flair_preprocessed_warped = (flair_preprocessed_warped - mean_flair) / std_flair if number_of_channels == 2: t1_preprocessed_warped = (t1_preprocessed_warped - mean_t1) / std_t1 ################################ # # Build models and load weights # ################################ number_of_models = 1 if use_ensemble: number_of_models = 3 if verbose: print("White matter hyperintensity: retrieving model weights.") unet_models = list() for i in range(number_of_models): if number_of_channels == 1: weights_file_name = get_pretrained_network("sysuMediaWmhFlairOnlyModel" + str(i), antsxnet_cache_directory=antsxnet_cache_directory) else: weights_file_name = get_pretrained_network("sysuMediaWmhFlairT1Model" + str(i), antsxnet_cache_directory=antsxnet_cache_directory) unet_model = create_sysu_media_unet_model_2d((*image_size, number_of_channels)) unet_loss = binary_dice_coefficient(smoothing_factor=1.) unet_model.compile(optimizer=keras.optimizers.Adam(learning_rate=2e-4), loss=unet_loss) unet_model.load_weights(weights_file_name) unet_models.append(unet_model) ################################ # # Extract slices # ################################ dimensions_to_predict = [2] total_number_of_slices = 0 for d in range(len(dimensions_to_predict)): total_number_of_slices += flair_preprocessed_warped.shape[dimensions_to_predict[d]] batchX = np.zeros((total_number_of_slices, *image_size, number_of_channels)) slice_count = 0 for d in range(len(dimensions_to_predict)): number_of_slices = flair_preprocessed_warped.shape[dimensions_to_predict[d]] if verbose: print("Extracting slices for dimension ", dimensions_to_predict[d], ".") for i in range(number_of_slices): flair_slice = pad_or_crop_image_to_size(ants.slice_image(flair_preprocessed_warped, dimensions_to_predict[d], i), image_size) batchX[slice_count,:,:,0] = flair_slice.numpy() if number_of_channels == 2: t1_slice = pad_or_crop_image_to_size(ants.slice_image(t1_preprocessed_warped, dimensions_to_predict[d], i), image_size) batchX[slice_count,:,:,1] = t1_slice.numpy() slice_count += 1 ################################ # # Do prediction and then restack into the image # ################################ if verbose: print("Prediction.") prediction = unet_models[0].predict(np.transpose(batchX, axes=(0, 2, 1, 3)), verbose=verbose) if number_of_models > 1: for i in range(1, number_of_models, 1): prediction += unet_models[i].predict(np.transpose(batchX, axes=(0, 2, 1, 3)), verbose=verbose) prediction /= number_of_models prediction = np.transpose(prediction, axes=(0, 2, 1, 3)) permutations = list() permutations.append((0, 1, 2)) permutations.append((1, 0, 2)) permutations.append((1, 2, 0)) prediction_image_average = ants.image_clone(flair_preprocessed_warped) * 0 current_start_slice = 0 for d in range(len(dimensions_to_predict)): current_end_slice = current_start_slice + flair_preprocessed_warped.shape[dimensions_to_predict[d]] which_batch_slices = range(current_start_slice, current_end_slice) prediction_per_dimension = prediction[which_batch_slices,:,:,:] prediction_array = np.transpose(np.squeeze(prediction_per_dimension), permutations[dimensions_to_predict[d]]) prediction_image = ants.copy_image_info(flair_preprocessed_warped, pad_or_crop_image_to_size(ants.from_numpy(prediction_array), flair_preprocessed_warped.shape)) prediction_image_average = prediction_image_average + (prediction_image - prediction_image_average) / (d + 1) current_start_slice = current_end_slice probability_image = ants.apply_ants_transform_to_image( ants.invert_ants_transform(xfrm), prediction_image_average, flair_preprocessed) probability_image = ants.copy_image_info(flair, probability_image) return(probability_image)
def hypermapp3r_segmentation(t1, flair, number_of_monte_carlo_iterations=30, do_preprocessing=True, antsxnet_cache_directory=None, verbose=False): """ Perform HyperMapp3r (white matter hyperintensities) segmentation described in https://pubmed.ncbi.nlm.nih.gov/35088930/ with models and architecture ported from https://github.com/mgoubran/HyperMapp3r Additional documentation and attribution resources found at https://hypermapp3r.readthedocs.io/en/latest/ Preprocessing consists of: * n4 bias correction and * brain extraction The input T1 should undergo the same steps. If the input T1 is the raw T1, these steps can be performed by the internal preprocessing, i.e. set do_preprocessing = True Arguments --------- t1 : ANTsImage input 3-D t1-weighted MR image. Assumed to be aligned with the flair. flair : ANTsImage input 3-D flair MR image. Assumed to be aligned with the t1. do_preprocessing : boolean See description above. antsxnet_cache_directory : string Destination directory for storing the downloaded template and model weights. Since these can be reused, if is None, these data will be downloaded to a ~/.keras/ANTsXNet/. verbose : boolean Print progress to the screen. Returns ------- ANTs labeled wmh segmentationimage. Example ------- >>> mask = hypermapp3r_segmentation(t1, flair) """ from ..architectures import create_hypermapp3r_unet_model_3d from ..utilities import preprocess_brain_image from ..utilities import get_pretrained_network if t1.dimension != 3: raise ValueError( "Image dimension must be 3." ) ################################ # # Preprocess images # ################################ if verbose: print("************* Preprocessing ***************") print("") t1_preprocessed = t1 brain_mask = None if do_preprocessing: t1_preprocessing = preprocess_brain_image(t1, truncate_intensity=(0.01, 0.99), brain_extraction_modality="t1", do_bias_correction=True, do_denoising=False, antsxnet_cache_directory=antsxnet_cache_directory, verbose=verbose) brain_mask = t1_preprocessing['brain_mask'] t1_preprocessed = t1_preprocessing["preprocessed_image"] * brain_mask else: # If we don't generate the mask from the preprocessing, we assume that we # can extract the brain directly from the foreground of the t1 image. brain_mask = ants.threshold_image(t1, 0, 0, 0, 1) t1_preprocessed_mean = t1_preprocessed[brain_mask > 0].mean() t1_preprocessed_std = t1_preprocessed[brain_mask > 0].std() t1_preprocessed[brain_mask > 0] = (t1_preprocessed[brain_mask > 0] - t1_preprocessed_mean) / t1_preprocessed_std flair_preprocessed = flair if do_preprocessing: flair_preprocessing = preprocess_brain_image(flair, truncate_intensity=(0.01, 0.99), brain_extraction_modality=None, do_bias_correction=True, do_denoising=False, antsxnet_cache_directory=antsxnet_cache_directory, verbose=verbose) flair_preprocessed = flair_preprocessing["preprocessed_image"] * brain_mask flair_preprocessed_mean = flair_preprocessed[brain_mask > 0].mean() flair_preprocessed_std = flair_preprocessed[brain_mask > 0].std() flair_preprocessed[brain_mask > 0] = (flair_preprocessed[brain_mask > 0] - flair_preprocessed_mean) / flair_preprocessed_std if verbose: print(" HyperMapp3r: reorient input images.") channel_size = 2 input_image_size = (224, 224, 224) template_array = np.ones(input_image_size) template_direction = np.eye(3) template_direction[1, 1] = -1.0 reorient_template = ants.from_numpy(template_array, origin=(0, 0, 0), spacing=(1, 1, 1), direction=template_direction) center_of_mass_template = ants.get_center_of_mass(reorient_template) center_of_mass_image = ants.get_center_of_mass(brain_mask) translation = np.asarray(center_of_mass_image) - np.asarray(center_of_mass_template) xfrm = ants.create_ants_transform(transform_type="Euler3DTransform", center=np.asarray(center_of_mass_template), translation=translation) batchX = np.zeros((1, *input_image_size, channel_size)) t1_preprocessed_warped = ants.apply_ants_transform_to_image(xfrm, t1_preprocessed, reorient_template) batchX[0,:,:,:,0] = t1_preprocessed_warped.numpy() flair_preprocessed_warped = ants.apply_ants_transform_to_image(xfrm, flair_preprocessed, reorient_template) batchX[0,:,:,:,1] = flair_preprocessed_warped.numpy() if verbose: print(" HyperMapp3r: generate network and load weights.") model = create_hypermapp3r_unet_model_3d((*input_image_size, 2)) weights_file_name = get_pretrained_network("hyperMapp3r", antsxnet_cache_directory=antsxnet_cache_directory) model.load_weights(weights_file_name) if verbose: print(" HyperMapp3r: prediction.") if verbose: print(" HyperMapp3r: Monte Carlo iterations (SpatialDropout).") prediction_array = np.zeros(input_image_size) for i in range(number_of_monte_carlo_iterations): if verbose: print(" Monte Carlo iteration", i + 1, "out of", number_of_monte_carlo_iterations) prediction_array = (np.squeeze(model.predict(batchX, verbose=verbose)) + i * prediction_array) / (i + 1) prediction_image = ants.from_numpy(prediction_array, origin=reorient_template.origin, spacing=reorient_template.spacing, direction=reorient_template.direction) xfrm_inv = xfrm.invert() probability_image = xfrm_inv.apply_to_image(prediction_image, t1) return(probability_image) def wmh_segmentation(flair, t1, white_matter_mask=None, use_combined_model=True, prediction_batch_size=16, patch_stride_length=32, do_preprocessing=True, antsxnet_cache_directory=None, verbose=False): """ Perform White matter hyperintensity probabilistic segmentation given a pre-aligned FLAIR and T2 images. Note that the underlying model is 3-D and requires images to be of > 64 voxels in each dimension. Preprocessing on the training data consisted of: * n4 bias correction, * brain extraction The input T1 should undergo the same steps. If the input T1 is the raw T1, these steps can be performed by the internal preprocessing, i.e. set \code{do_preprocessing = True} Arguments --------- flair : ANTsImage input 3-D FLAIR brain image (not skull-stripped). t1 : ANTsImage input 3-D T1 brain image (not skull-stripped). white_matter_mask : ANTsImage input white matter mask for patch extraction. If None, calculated using deep_atropos (labels 3 and 4). use_combined_model : boolean Original or combined. prediction_batch_size : int Control memory usage for prediction. More consequential for GPU-usage. patch_stride_length : 3-D tuple or int Dictates the stride length for accumulating predicting patches. do_preprocessing : boolean perform n4 bias correction, intensity truncation, brain extraction. antsxnet_cache_directory : string Destination directory for storing the downloaded template and model weights. Since these can be reused, if is None, these data will be downloaded to a ~/.keras/ANTsXNet/. verbose : boolean Print progress to the screen. Returns ------- WMH segmentation probability image Example ------- >>> flair = ants.image_read("flair.nii.gz") >>> t1 = ants.image_read("t1.nii.gz") >>> probability_mask = wmh_segmentation(flair, t1) """ from ..architectures import create_sysu_media_unet_model_3d from ..utilities import deep_atropos from ..utilities import extract_image_patches from ..utilities import reconstruct_image_from_patches from ..utilities import get_pretrained_network from ..utilities import preprocess_brain_image if np.any(t1.shape < np.array((64, 64, 64))): raise ValueError("Images must be > 64 voxels per dimension.") ################################ # # Preprocess images # ################################ if white_matter_mask is None: if verbose: print("Calculate white matter mask.") atropos = deep_atropos(t1, do_preprocessing=True, verbose=verbose) white_matter_mask = ants.threshold_image(atropos['segmentation_image'], 3, 4, 1, 0) t1_preprocessed = None flair_preprocessed = None if do_preprocessing: if verbose: print("Preprocess T1 and FLAIR images.") t1_preprocessing = preprocess_brain_image(t1, truncate_intensity=(0.01, 0.995), brain_extraction_modality="t1", do_bias_correction=True, do_denoising=False, antsxnet_cache_directory=antsxnet_cache_directory, verbose=verbose) brain_mask = ants.threshold_image(t1_preprocessing["brain_mask"], 0.5, 1, 1, 0) t1_preprocessed = t1_preprocessing["preprocessed_image"] * brain_mask flair_preprocessing = preprocess_brain_image(flair, truncate_intensity=None, brain_extraction_modality=None, do_bias_correction=True, do_denoising=False, antsxnet_cache_directory=antsxnet_cache_directory, verbose=verbose) flair_preprocessed = flair_preprocessing["preprocessed_image"] * brain_mask else: t1_preprocessed = ants.image_clone(t1) flair_preprocessed = ants.image_clone(flair) white_matter_indices = white_matter_mask > 0 t1_preprocessed_min = t1_preprocessed[white_matter_indices].min() t1_preprocessed_max = t1_preprocessed[white_matter_indices].max() flair_preprocessed_min = flair_preprocessed[white_matter_indices].min() flair_preprocessed_max = flair_preprocessed[white_matter_indices].max() t1_preprocessed = (t1_preprocessed - t1_preprocessed_min) / (t1_preprocessed_max - t1_preprocessed_min) flair_preprocessed = (flair_preprocessed - flair_preprocessed_min) / (flair_preprocessed_max - flair_preprocessed_min) ################################ # # Build model and load weights # ################################ if verbose: print("Load model and weights.") patch_size = (64, 64, 64) if isinstance(patch_stride_length, int): patch_stride_length = (patch_stride_length,) * 3 number_of_filters = (64, 96, 128, 256, 512) channel_size = 2 model = create_sysu_media_unet_model_3d((*patch_size, channel_size), number_of_filters=number_of_filters) weights_file_name = None if use_combined_model: weights_file_name = get_pretrained_network("antsxnetWmhOr", antsxnet_cache_directory=antsxnet_cache_directory) else: weights_file_name = get_pretrained_network("antsxnetWmh", antsxnet_cache_directory=antsxnet_cache_directory) model.load_weights(weights_file_name) ################################ # # Extract patches # ################################ if verbose: print("Extract patches.") t1_patches = extract_image_patches(t1_preprocessed, patch_size=patch_size, max_number_of_patches="all", stride_length=patch_stride_length, mask_image=white_matter_mask, random_seed=None, return_as_array=True) flair_patches = extract_image_patches(flair_preprocessed, patch_size=patch_size, max_number_of_patches="all", stride_length=patch_stride_length, mask_image=white_matter_mask, random_seed=None, return_as_array=True) total_number_of_patches = t1_patches.shape[0] ################################ # # Do prediction and then restack into the image # ################################ number_of_batches = total_number_of_patches // prediction_batch_size residual_number_of_patches = total_number_of_patches - number_of_batches * prediction_batch_size if residual_number_of_patches > 0: number_of_batches = number_of_batches + 1 if verbose: print("Total number of patches: ", str(total_number_of_patches)) print("Prediction batch size: ", str(prediction_batch_size)) print("Number of batches: ", str(number_of_batches + 1)) prediction = np.zeros((total_number_of_patches, *patch_size, 1)) for b in range(number_of_batches): batchX = None if b < number_of_batches - 1 or residual_number_of_patches == 0: batchX = np.zeros((prediction_batch_size, *patch_size, channel_size)) else: batchX = np.zeros((residual_number_of_patches, *patch_size, channel_size)) indices = range(b * prediction_batch_size, b * prediction_batch_size + batchX.shape[0]) batchX[:,:,:,:,0] = flair_patches[indices,:,:,:] batchX[:,:,:,:,1] = t1_patches[indices,:,:,:] if verbose: print("Predicting batch ", str(b + 1), " of ", str(number_of_batches)) prediction[indices,:,:,:,:] = model.predict(batchX, verbose=verbose) if verbose: print("Predict patches and reconstruct.") wmh_probability_image = reconstruct_image_from_patches(np.squeeze(prediction), stride_length=patch_stride_length, domain_image=white_matter_mask, domain_image_is_mask=True) return(wmh_probability_image)