Source code for antspynet.utilities.deep_atropos

import numpy as np
import ants

def deep_atropos(t1,
                 do_preprocessing=True,
                 use_spatial_priors=1,
                 antsxnet_cache_directory=None,
                 verbose=False):
    """
    Six-tissue segmentation.

    Perform Atropos-style six tissue segmentation using deep learning.

    The labeling is as follows:
    Label 0 :  background
    Label 1 :  CSF
    Label 2 :  gray matter
    Label 3 :  white matter
    Label 4 :  deep gray matter
    Label 5 :  brain stem
    Label 6 :  cerebellum

    Preprocessing on the training data consisted of:
       * n4 bias correction,
       * denoising,
       * brain extraction, and
       * affine registration to MNI.

    The input T1 should undergo the same steps.  If the input T1 is the raw
    T1, these steps can be performed by the internal preprocessing, i.e. set
    do_preprocessing = True

    Arguments
    ---------
    t1 : ANTsImage
        raw or preprocessed 3-D T1-weighted brain image.

    do_preprocessing : boolean
        See description above.

    use_spatial_priors : integer
        Use MNI spatial tissue priors (0 or 1).  Currently, '0' (no priors)
        and '1' (cerebellar prior only) are the only two options.  Default
        is 1.

    antsxnet_cache_directory : string
        Destination directory for storing the downloaded template and model
        weights.  Since these can be reused, if it is None, these data will
        be downloaded to ~/.keras/ANTsXNet/.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    Dictionary with two entries:
        'segmentation_image' : the label image obtained by taking, per voxel,
            the label with the largest probability.
        'probability_images' : list with one probability image per label, in
            the label order given above.

    Raises
    ------
    ValueError
        If the input image is not 3-D or if use_spatial_priors is neither
        0 nor 1.

    Example
    -------
    >>> image = ants.image_read("t1.nii.gz")
    >>> seg = deep_atropos(image)
    >>> seg['segmentation_image']
    """

    # Validate the arguments before anything expensive happens (package
    # imports, preprocessing, downloads, model construction).
    if t1.dimension != 3:
        raise ValueError("Image dimension must be 3.")

    if use_spatial_priors not in (0, 1):
        raise ValueError("use_spatial_priors must be a 0 or 1")

    # Imported lazily, inside the function, as in the original implementation.
    from ..architectures import create_unet_model_3d
    from ..utilities import get_pretrained_network
    from ..utilities import get_antsxnet_data
    from ..utilities import preprocess_brain_image
    from ..utilities import extract_image_patches
    from ..utilities import reconstruct_image_from_patches

    ################################
    #
    # Preprocess images
    #
    ################################

    t1_preprocessed = t1
    t1_preprocessing = None
    if do_preprocessing:
        t1_preprocessing = preprocess_brain_image(
            t1,
            truncate_intensity=(0.01, 0.99),
            brain_extraction_modality="t1",
            template="croppedMni152",
            template_transform_type="antsRegistrationSyNQuickRepro[a]",
            do_bias_correction=True,
            do_denoising=True,
            antsxnet_cache_directory=antsxnet_cache_directory,
            verbose=verbose)
        # Keep only the voxels inside the extracted brain mask.
        t1_preprocessed = (t1_preprocessing["preprocessed_image"]
                           * t1_preprocessing['brain_mask'])

    ################################
    #
    # Build model and load weights
    #
    ################################

    patch_size = (112, 112, 112)
    # One stride per axis equal to (image extent - patch extent).
    # NOTE(review): presumably this yields the corner ("octant") patches; the
    # exact tiling depends on extract_image_patches.  An image that is not
    # larger than the patch along some axis gives a non-positive stride --
    # confirm how the patch utilities handle that when do_preprocessing=False.
    stride_length = tuple(extent - patch
                          for extent, patch
                          in zip(t1_preprocessed.shape[:3], patch_size))

    classes = ("background", "csf", "gray matter", "white matter",
               "deep gray matter", "brain stem", "cerebellum")

    use_priors = use_spatial_priors != 0

    mni_priors = None
    channel_size = 1
    if use_priors:
        priors_file_name = get_antsxnet_data(
            "croppedMni152Priors",
            antsxnet_cache_directory=antsxnet_cache_directory)
        # Give every prior the header information of the preprocessed T1.
        mni_priors = [
            ants.copy_image_info(t1_preprocessed, prior)
            for prior in ants.ndimage_to_list(ants.image_read(priors_file_name))
        ]
        channel_size = 2

    unet_model = create_unet_model_3d(
        (*patch_size, channel_size),
        number_of_outputs=len(classes),
        mode="classification",
        number_of_layers=4,
        number_of_filters_at_base_layer=16,
        dropout_rate=0.0,
        convolution_kernel_size=(3, 3, 3),
        deconvolution_kernel_size=(2, 2, 2),
        weight_decay=1e-5,
        # NOTE(review): a plain string, not a 1-tuple; the value is kept
        # unchanged because how create_unet_model_3d interprets it is not
        # visible from here.
        additional_options="attentionGating")

    if verbose:
        print("DeepAtropos: retrieving model weights.")

    if use_priors:
        network_name = "sixTissueOctantBrainSegmentationWithPriors1"
    else:
        network_name = "sixTissueOctantBrainSegmentation"
    weights_file_name = get_pretrained_network(
        network_name, antsxnet_cache_directory=antsxnet_cache_directory)
    unet_model.load_weights(weights_file_name)

    ################################
    #
    # Do prediction and normalize to native space
    #
    ################################

    if verbose:
        print("Prediction.")

    # Standardize intensities to zero mean and unit standard deviation.
    t1_preprocessed = ((t1_preprocessed - t1_preprocessed.mean())
                       / t1_preprocessed.std())

    image_patches = extract_image_patches(t1_preprocessed,
                                          patch_size=patch_size,
                                          max_number_of_patches="all",
                                          stride_length=stride_length,
                                          return_as_array=True)

    # Channel 0 holds the T1 patches, channel 1 (if present) the prior patches.
    batchX = np.zeros((*image_patches.shape, channel_size))
    batchX[:, :, :, :, 0] = image_patches
    if channel_size > 1:
        # Index 6 is the cerebellum label; presumably the priors file is
        # ordered like the labels -- confirm against the priors data.
        prior_patches = extract_image_patches(mni_priors[6],
                                              patch_size=patch_size,
                                              max_number_of_patches="all",
                                              stride_length=stride_length,
                                              return_as_array=True)
        batchX[:, :, :, :, 1] = prior_patches

    predicted_data = unet_model.predict(batchX, verbose=verbose)

    probability_images = []
    for label, class_name in enumerate(classes):
        if verbose:
            print("Reconstructing image", class_name)

        reconstructed_image = reconstruct_image_from_patches(
            predicted_data[:, :, :, :, label],
            domain_image=t1_preprocessed,
            stride_length=stride_length)

        if do_preprocessing:
            # Map the probability image from template space back onto the
            # input image grid using the inverse template transform.
            reconstructed_image = ants.apply_transforms(
                fixed=t1,
                moving=reconstructed_image,
                transformlist=t1_preprocessing['template_transforms']['invtransforms'],
                whichtoinvert=[True],
                interpolator="linear",
                verbose=verbose)

        probability_images.append(reconstructed_image)

    # All-ones mask over the input grid so that every voxel gets a label.
    full_mask = t1 * 0 + 1
    image_matrix = ants.image_list_to_matrix(probability_images, full_mask)
    # Per voxel, choose the label whose probability is largest.
    segmentation_matrix = np.argmax(image_matrix, axis=0)
    segmentation_image = ants.matrix_to_images(
        np.expand_dims(segmentation_matrix, axis=0), full_mask)[0]

    return {'segmentation_image': segmentation_image,
            'probability_images': probability_images}