Source code for antspynet.utilities.deep_atropos

import numpy as np
import ants

# Output channel order of both networks; the index is the segmentation label.
_TISSUE_CLASSES = ("background",
                   "csf",
                   "gray matter",
                   "white matter",
                   "deep gray matter",
                   "brain stem",
                   "cerebellum")

# Pretrained weight identifiers for the multi-modal (list input) networks.
_HCP_WEIGHTS_BY_NETWORK = {
    "t1": "DeepAtroposHcpT1Weights",
    "t1_t2": "DeepAtroposHcpT1T2Weights",
    "t1_fa": "DeepAtroposHcpT1FAWeights",
    "t1_t2_fa": "DeepAtroposHcpT1T2FAWeights",
}


def deep_atropos(t1,
                 do_preprocessing=True,
                 use_spatial_priors=1,
                 do_denoising=True,
                 verbose=False):

    """
    Six-tissue segmentation.

    Perform Atropos-style six tissue segmentation using deep learning.

    The labeling is as follows:
    Label 0 :  background
    Label 1 :  CSF
    Label 2 :  gray matter
    Label 3 :  white matter
    Label 4 :  deep gray matter
    Label 5 :  brain stem
    Label 6 :  cerebellum

    Two modes are available, selected by the type of ``t1``.

    Single image (``t1`` is an ANTsImage).  Preprocessing on the training
    data consisted of:
       * n4 bias correction,
       * denoising,
       * brain extraction, and
       * affine registration to MNI.
    The input T1 should undergo the same steps.  If the input T1 is the raw
    T1, these steps can be performed by the internal preprocessing, i.e. set
    do_preprocessing = True

    Multi-modal (``t1`` is a list ``[T1, T2, FA]``).  Use None as a
    placeholder for an unavailable T2 and/or FA image.  In this mode the
    images are always bias corrected, brain extracted and affinely
    registered to the "hcpinterT1Template" template internally, and the
    results are always mapped back to the space of the T1 image.  The
    ``do_preprocessing``, ``use_spatial_priors`` and ``do_denoising``
    arguments are not used in this mode.

    Arguments
    ---------
    t1 : ANTsImage or list of ANTsImage/None
        raw or preprocessed 3-D T1-weighted brain image, or a list of
        length 3 holding (in order) the T1, T2, and FA images.

    do_preprocessing : boolean
        See description above.  Single-image mode only.

    use_spatial_priors : integer
        Use MNI spatial tissue priors (0 or 1).  '0' means no priors and
        '1' means the cerebellar prior only.  Default is 1.  Single-image
        mode only.

    do_denoising : boolean
        Activate denoising within preprocessing (default True).
        Single-image mode only.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    Dictionary with keys 'segmentation_image' (the label image) and
    'probability_images' (a list with one probability image per label).

    Raises
    ------
    ValueError
        If a single input image is not 3-D, if ``use_spatial_priors`` is
        neither 0 nor 1 (single-image mode), if an input list does not have
        length 3, or if the T1 entry of an input list is None.

    Example
    -------
    >>> image = ants.image_read("t1.nii.gz")
    >>> flash = deep_atropos(image)
    """

    if isinstance(t1, list):
        return _deep_atropos_multimodal(t1, verbose)

    # Validate before any data download or model construction happens.
    if t1.dimension != 3:
        raise ValueError("Image dimension must be 3.")
    if use_spatial_priors not in (0, 1):
        raise ValueError("use_spatial_priors must be a 0 or 1")

    return _deep_atropos_single_t1(t1, do_preprocessing, use_spatial_priors,
                                   do_denoising, verbose)


def _build_result(probability_images, reference_image):
    """Arg-max the probability images into a label image and pack the result."""
    # All-ones mask: every voxel of the reference image takes part.
    full_mask = reference_image * 0 + 1
    image_matrix = ants.image_list_to_matrix(probability_images, full_mask)
    segmentation_matrix = np.argmax(image_matrix, axis=0)
    segmentation_image = ants.matrix_to_images(
        np.expand_dims(segmentation_matrix, axis=0), full_mask)[0]
    return {'segmentation_image': segmentation_image,
            'probability_images': probability_images}


def _truncate_image_intensity(image, truncate_values=(0.01, 0.99)):
    """Return a copy of image with intensities clamped to the given quantiles."""
    truncated_image = ants.image_clone(image)
    lower = truncated_image.quantile(truncate_values[0])
    upper = truncated_image.quantile(truncate_values[1])
    truncated_image[image < lower] = lower
    truncated_image[image > upper] = upper
    return truncated_image


def _select_modalities(images):
    """Validate a [T1, T2, FA] list; return (network name, available images)."""
    if len(images) != 3:
        raise ValueError("Length of input list must be 3.  Input images are (in order): [T1, T2, FA]. " +
                         "If a particular modality or modalities is not available, use None as a placeholder.")

    t1_image, t2_image, fa_image = images
    if t1_image is None:
        raise ValueError("T1 modality must be specified.")

    which_network = "t1"
    input_images = [t1_image]
    if t2_image is not None:
        which_network += "_t2"
        input_images.append(t2_image)
    if fa_image is not None:
        which_network += "_fa"
        input_images.append(fa_image)
    return which_network, input_images


def _deep_atropos_single_t1(t1, do_preprocessing, use_spatial_priors, do_denoising, verbose):
    """Run the single-T1 (cropped MNI152, octant patch) network."""

    from ..architectures import create_unet_model_3d
    from ..utilities import get_pretrained_network
    from ..utilities import get_antsxnet_data
    from ..utilities import preprocess_brain_image

    ################################
    #
    # Preprocess images
    #
    ################################

    t1_preprocessed = t1
    t1_preprocessing = None
    if do_preprocessing:
        t1_preprocessing = preprocess_brain_image(t1,
            truncate_intensity=(0.01, 0.99),
            brain_extraction_modality="t1",
            template="croppedMni152",
            template_transform_type="antsRegistrationSyNQuickRepro[a]",
            do_bias_correction=True,
            do_denoising=do_denoising,
            verbose=verbose)
        t1_preprocessed = t1_preprocessing["preprocessed_image"] * t1_preprocessing['brain_mask']

    ################################
    #
    # Build model and load weights
    #
    ################################

    patch_size = (112, 112, 112)
    stride_length = (t1_preprocessed.shape[0] - patch_size[0],
                     t1_preprocessed.shape[1] - patch_size[1],
                     t1_preprocessed.shape[2] - patch_size[2])

    mni_priors = None
    channel_size = 1
    if use_spatial_priors != 0:
        mni_priors = ants.ndimage_to_list(ants.image_read(get_antsxnet_data("croppedMni152Priors")))
        mni_priors = [ants.copy_image_info(t1_preprocessed, prior) for prior in mni_priors]
        channel_size = 2

    unet_model = create_unet_model_3d((*patch_size, channel_size),
        number_of_outputs=len(_TISSUE_CLASSES), mode="classification",
        number_of_layers=4, number_of_filters_at_base_layer=16, dropout_rate=0.0,
        convolution_kernel_size=(3, 3, 3), deconvolution_kernel_size=(2, 2, 2),
        weight_decay=1e-5, additional_options="attentionGating")

    if verbose:
        print("DeepAtropos:  retrieving model weights.")

    # use_spatial_priors has already been validated to be 0 or 1.
    if use_spatial_priors == 0:
        weights_file_name = get_pretrained_network("sixTissueOctantBrainSegmentation")
    else:
        weights_file_name = get_pretrained_network("sixTissueOctantBrainSegmentationWithPriors1")
    unet_model.load_weights(weights_file_name)

    ################################
    #
    # Do prediction and normalize to native space
    #
    ################################

    if verbose:
        print("Prediction.")

    t1_preprocessed = (t1_preprocessed - t1_preprocessed.mean()) / t1_preprocessed.std()
    image_patches = ants.extract_image_patches(t1_preprocessed, patch_size=patch_size,
                                               max_number_of_patches="all", stride_length=stride_length,
                                               return_as_array=True)
    batchX = np.zeros((*image_patches.shape, channel_size))
    batchX[:,:,:,:,0] = image_patches
    if channel_size > 1:
        # Index 6 matches the cerebellum label, i.e. the cerebellar prior.
        prior_patches = ants.extract_image_patches(mni_priors[6], patch_size=patch_size,
                            max_number_of_patches="all", stride_length=stride_length,
                            return_as_array=True)
        batchX[:,:,:,:,1] = prior_patches

    predicted_data = unet_model.predict(batchX, verbose=verbose)

    probability_images = []
    for i, tissue in enumerate(_TISSUE_CLASSES):
        if verbose:
            print("Reconstructing image", tissue)

        reconstructed_image = ants.reconstruct_image_from_patches(predicted_data[:,:,:,:,i],
            domain_image=t1_preprocessed, stride_length=stride_length)

        if do_preprocessing:
            # Undo the template registration so the output lives in the input T1 space.
            reconstructed_image = ants.apply_transforms(fixed=t1,
                moving=reconstructed_image,
                transformlist=t1_preprocessing['template_transforms']['invtransforms'],
                whichtoinvert=[True], interpolator="linear", verbose=verbose)
        probability_images.append(reconstructed_image)

    return _build_result(probability_images, t1)


def _deep_atropos_multimodal(t1, verbose):
    """Run the multi-modal [T1, T2, FA] (HCP template) network."""

    # Validate first so malformed input fails before any download or model build.
    which_network, input_images = _select_modalities(t1)

    from ..architectures import create_unet_model_3d
    from ..utilities import get_pretrained_network
    from ..utilities import get_antsxnet_data
    from ..utilities import brain_extraction

    if verbose:
        print("Prediction using", which_network)

    ################################
    #
    # Preprocess images
    #
    ################################

    hcp_t1_template = ants.image_read(get_antsxnet_data("hcpinterT1Template"))
    hcp_template_brain_mask = ants.image_read(get_antsxnet_data("hcpinterTemplateBrainMask"))
    hcp_template_brain_segmentation = ants.image_read(get_antsxnet_data("hcpinterTemplateBrainSegmentation"))
    hcp_t1_template = hcp_t1_template * hcp_template_brain_mask

    reg = None
    t1_mask = None
    preprocessed_images = []
    for i, image in enumerate(input_images):
        n4 = ants.n4_bias_field_correction(_truncate_image_intensity(image),
                                           mask=image*0+1,
                                           convergence={'iters': [50, 50, 50, 50], 'tol': 0.0},
                                           rescale_intensities=True,
                                           verbose=verbose)
        if i == 0:
            # The T1 defines both the brain mask and the transform to the
            # template; the remaining modalities reuse them.
            t1_bext = brain_extraction(input_images[0], modality="t1threetissue", verbose=verbose)
            t1_mask = ants.threshold_image(t1_bext['segmentation_image'], 1, 1, 1, 0)
            n4 = n4 * t1_mask
            reg = ants.registration(hcp_t1_template, n4,
                                    type_of_transform="antsRegistrationSyNQuick[a]",
                                    verbose=verbose)
            warped = reg['warpedmovout']
        else:
            n4 = n4 * t1_mask
            warped = ants.apply_transforms(hcp_t1_template, n4,
                                           transformlist=reg['fwdtransforms'],
                                           verbose=verbose)
        preprocessed_images.append(ants.iMath_normalize(warped))

    ################################
    #
    # Build model and load weights
    #
    ################################

    patch_size = (192, 224, 192)
    stride_length = (hcp_t1_template.shape[0] - patch_size[0],
                     hcp_t1_template.shape[1] - patch_size[1],
                     hcp_t1_template.shape[2] - patch_size[2])

    # One smoothed binary prior per non-background label (1..6) of the template segmentation.
    hcp_template_priors = []
    for label in range(1, 7):
        prior = ants.threshold_image(hcp_template_brain_segmentation, label, label, 1, 0)
        hcp_template_priors.append(ants.smooth_image(prior, 1.0))

    number_of_classification_labels = len(_TISSUE_CLASSES)
    channel_size = len(input_images) + len(hcp_template_priors)

    unet_model = create_unet_model_3d((*patch_size, channel_size),
        number_of_outputs=number_of_classification_labels, mode="classification",
        number_of_filters=(16, 32, 64, 128), dropout_rate=0.0,
        convolution_kernel_size=(3, 3, 3), deconvolution_kernel_size=(2, 2, 2),
        weight_decay=0.0)

    if verbose:
        print("DeepAtropos:  retrieving model weights.")

    weights_file_name = get_pretrained_network(_HCP_WEIGHTS_BY_NETWORK[which_network])
    unet_model.load_weights(weights_file_name)

    ################################
    #
    # Do prediction and normalize to native space
    #
    ################################

    if verbose:
        print("Prediction.")

    # The code predicts 8 patches, one at a time.
    # NOTE(review): 8 presumably corresponds to the 2x2x2 octants produced by
    # the stride chosen above -- confirm against extract_image_patches.
    number_of_patches = 8
    predicted_data = np.zeros((number_of_patches, *patch_size, number_of_classification_labels))
    batchX = np.zeros((1, *patch_size, channel_size))

    # Channel order: available modalities first, then the six template priors.
    channel_images = preprocessed_images + hcp_template_priors

    for h in range(number_of_patches):
        # Patches are re-extracted for every h rather than cached, so only one
        # channel's patch stack is held at a time.
        for index, channel_image in enumerate(channel_images):
            patches = ants.extract_image_patches(channel_image,
                                                 patch_size=patch_size,
                                                 max_number_of_patches="all",
                                                 stride_length=stride_length,
                                                 return_as_array=True)
            batchX[0,:,:,:,index] = patches[h,:,:,:]
        predicted_data[h,:,:,:,:] = unet_model.predict(batchX, verbose=verbose)

    probability_images = []
    for i, tissue in enumerate(_TISSUE_CLASSES):
        if verbose:
            print("Reconstructing image", tissue)

        reconstructed_image = ants.reconstruct_image_from_patches(predicted_data[:,:,:,:,i],
            domain_image=hcp_t1_template, stride_length=stride_length)

        # Registration to the template is unconditional in this mode, so the
        # inverse transform must be unconditional too; otherwise the
        # template-space probabilities would be paired with a native-space
        # mask when building the segmentation.
        probability_images.append(ants.apply_transforms(fixed=input_images[0],
            moving=reconstructed_image,
            transformlist=reg['invtransforms'],
            whichtoinvert=[True], interpolator="linear", verbose=verbose))

    return _build_result(probability_images, input_images[0])