# Source code for antspynet.utilities.brain_extraction

import numpy as np
import tensorflow as tf

import ants

# Maps each ANTs-trained modality to (pretrained weights prefix, is_standard_network).
# "Standard" networks use the wider (16, 32, 64, 128) filter configuration and a
# single sigmoid output, except for the multi-label modalities listed below.
_ANTS_BRAIN_EXTRACTION_NETWORKS = {
    "t1.v0": ("brainExtraction", False),
    "t1.v1": ("brainExtractionT1v1", True),
    "t1": ("brainExtractionRobustT1", True),
    "t2.v0": ("brainExtractionT2", False),
    "t2": ("brainExtractionRobustT2", True),
    "t2star": ("brainExtractionRobustT2Star", True),
    "flair.v0": ("brainExtractionFLAIR", False),
    "flair": ("brainExtractionRobustFLAIR", True),
    "bold.v0": ("brainExtractionBOLD", False),
    "bold": ("brainExtractionRobustBOLD", True),
    "fa.v0": ("brainExtractionFA", False),
    "fa": ("brainExtractionRobustFA", True),
    "mra": ("brainExtractionMra", True),
    "t1t2infant": ("brainExtractionInfantT1T2", False),
    "t1infant": ("brainExtractionInfantT1", False),
    "t2infant": ("brainExtractionInfantT2", False),
    "t1threetissue": ("brainExtractionBrainWeb20", True),
    "t1hemi": ("brainExtractionT1Hemi", True),
    "t1lobes": ("brainExtractionT1Lobes", True),
}

# Number of softmax outputs for the multi-label (segmentation) modalities.
_MULTILABEL_OUTPUTS = {
    "t1threetissue": 4,  # background, brain, meninges/csf, misc. head
    "t1hemi": 3,         # background, left, right
    "t1lobes": 6,        # background, frontal, parietal, temporal, occipital, misc
}


def brain_extraction(image, modality, verbose=False):
    """
    Perform brain extraction using U-net and ANTs-based training data.  "NoBrainer"
    is also possible where brain extraction uses U-net and FreeSurfer training data
    ported from the

    https://github.com/neuronets/nobrainer-models

    Arguments
    ---------
    image : ANTsImage
        input image (or list of images for multi-modal scenarios).  A list with a
        single element is treated exactly like a bare image.  A caller-supplied
        list is never modified.

    modality : string
        Modality image type.  Options include:
            * "t1": T1-weighted MRI---ANTs-trained.  Previous versions are
              specified as "t1.v0", "t1.v1".
            * "t1nobrainer": T1-weighted MRI---FreeSurfer-trained: h/t Satra Ghosh
              and Jakub Kaczmarzyk.  Accepts a single input image only.
            * "t1combined": Brian's combination of "t1" and "t1nobrainer".  One
              can also specify "t1combined[X]" where X is the morphological
              radius.  X = 12 by default.
            * "t1threetissue": T1-weighted MRI---originally developed from
              BrainWeb20 (and later expanded).  Label 1: brain + subdural CSF,
              label 2: sinuses + skull, label 3: other head, face, neck tissue.
            * "t1hemi": Label 1 of "t1threetissue" subdivided into left and
              right hemispheres.
            * "t1lobes": Labels 1) frontal, 2) parietal, 3) temporal,
              4) occipital, 5) csf, cerebellum, and brain stem.
            * "flair": FLAIR MRI.  Previous versions are specified as "flair.v0".
            * "t2": T2 MRI.  Previous versions are specified as "t2.v0".
            * "t2star": T2Star MRI.
            * "bold": 3-D mean BOLD MRI.  Previous versions are specified as
              "bold.v0".
            * "fa": fractional anisotropy.  Previous versions are specified as
              "fa.v0".
            * "mra": MRA h/t Tyler Hanson "mmbop".
            * "t1t2infant": Combined T1-w/T2-w infant MRI h/t Martin Styner.
            * "t1infant": T1-w infant MRI h/t Martin Styner.
            * "t2infant": T2-w infant MRI h/t Martin Styner.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    Depends on ``modality``:
        * "t1threetissue", "t1hemi", "t1lobes": a dictionary with keys
          'segmentation_image' (ANTsImage holding the per-voxel argmax over the
          class probabilities) and 'probability_images' (list of per-class
          ANTsImage probabilities, in the space of the input image).
        * "t1combined[X]": ANTsImage equal to the sum of the refined
          "t1nobrainer" mask, an eroded "t1" mask and the "t1" mask (so it is
          not a binary mask).
        * "t1nobrainer": ANTsImage returned by ``ants.label_clusters`` applied
          to the network output resampled to the input grid, with a minimum
          cluster size of round(649933.7 / voxel volume) voxels.
        * all other modalities: ANTs probability brain mask image.

    Raises
    ------
    ValueError
        If no image is given, the image is not 3-D, the modality is unknown,
        or "t1nobrainer" receives several images or an all-zero image.

    Example
    -------
    >>> probability_brain_mask = brain_extraction(brain_image, modality="t1")
    """

    # Build a private list of inputs: the float conversion below must never
    # overwrite the elements of a list owned by the caller.
    if isinstance(image, list):
        if len(image) == 0:
            raise ValueError("At least one input image is required.")
        input_images = list(image)
    else:
        input_images = [image]
    channel_size = len(input_images)

    from ..architectures import create_unet_model_3d
    from ..utilities import get_pretrained_network
    from ..utilities import get_antsxnet_data
    from ..architectures import create_nobrainer_unet_model_3d

    if channel_size == 1:
        # Unwrap a one-element list so the single-image code paths below
        # (t1combined, NoBrainer, recursion) always see a bare image.
        image = input_images[0]
        if modality in ("t1hemi", "t1lobes"):
            # These networks are fed the input masked by label 1
            # (brain + subdural CSF) of the "t1threetissue" segmentation.
            bext = brain_extraction(image, modality="t1threetissue", verbose=verbose)
            mask = ants.threshold_image(bext['segmentation_image'], 1, 1, 1, 0)
            input_images = [image * mask]

    if input_images[0].dimension != 3:
        raise ValueError("Image dimension must be 3.")

    input_images = [img if img.pixeltype == 'float' else img.clone('float')
                    for img in input_images]

    if "t1combined" in modality:
        # Need to change with voxel resolution
        morphological_radius = 12
        if '[' in modality and ']' in modality:
            morphological_radius = int(modality.split("[")[1].split("]")[0])

        brain_extraction_t1 = brain_extraction(image, modality="t1", verbose=verbose)
        brain_mask = ants.iMath_get_largest_component(
            ants.threshold_image(brain_extraction_t1, 0.5, 10000))
        brain_mask = ants.morphology(brain_mask, "close", morphological_radius).iMath_fill_holes()

        # Run NoBrainer only on a dilated neighborhood of the "t1" mask.
        brain_extraction_t1nobrainer = brain_extraction(
            image * ants.iMath_MD(brain_mask, radius=morphological_radius),
            modality="t1nobrainer", verbose=verbose)
        brain_extraction_combined = ants.iMath_fill_holes(
            ants.iMath_get_largest_component(brain_extraction_t1nobrainer * brain_mask))

        brain_extraction_combined = (brain_extraction_combined
                                     + ants.iMath_ME(brain_mask, morphological_radius)
                                     + brain_mask)
        return brain_extraction_combined

    if modality == "t1nobrainer":
        #####################
        #
        # NoBrainer
        #
        #####################
        if channel_size != 1:
            raise ValueError("The \"t1nobrainer\" modality accepts a single input image.")

        if verbose:
            print("NoBrainer: generating network.")
        model = create_nobrainer_unet_model_3d((None, None, None, 1))
        weights_file_name = get_pretrained_network("brainExtractionNoBrainer")
        model.load_weights(weights_file_name)

        if verbose:
            print("NoBrainer: preprocessing (intensity truncation and resampling).")
        image_array = image.numpy()
        nonzero_intensities = image_array[image_array != 0]
        if nonzero_intensities.size == 0:
            raise ValueError("Input image contains no nonzero voxels.")
        image_robust_range = np.quantile(nonzero_intensities, (0.02, 0.98))
        threshold_value = 0.10 * (image_robust_range[1] - image_robust_range[0]) + image_robust_range[0]
        # Mask is 0 for intensities in [-10000, threshold_value] and 1 elsewhere,
        # i.e. low-intensity voxels are zeroed before prediction.
        thresholded_mask = ants.threshold_image(image, -10000, threshold_value, 0, 1)
        thresholded_image = image * thresholded_mask

        image_resampled = ants.resample_image(thresholded_image, (256, 256, 256), use_voxels=True)
        image_array = np.expand_dims(image_resampled.numpy(), axis=0)
        image_array = np.expand_dims(image_array, axis=-1)

        if verbose:
            print("NoBrainer: predicting mask.")
        brain_mask_array = np.squeeze(model.predict(image_array, verbose=verbose))
        brain_mask_resampled = ants.copy_image_info(image_resampled, ants.from_numpy(brain_mask_array))
        brain_mask_image = ants.resample_image(brain_mask_resampled, image.shape,
                                               use_voxels=True, interp_type=1)

        # Minimum cluster size expressed in voxels of the input image.
        spacing = ants.get_spacing(image)
        spacing_product = spacing[0] * spacing[1] * spacing[2]
        minimum_brain_volume = round(649933.7 / spacing_product)
        brain_mask_labeled = ants.label_clusters(brain_mask_image, minimum_brain_volume)
        return brain_mask_labeled

    #####################
    #
    # ANTs-based
    #
    #####################
    try:
        weights_file_name_prefix, is_standard_network = _ANTS_BRAIN_EXTRACTION_NETWORKS[modality]
    except KeyError:
        raise ValueError("Unknown modality type.") from None
    is_multilabel = modality in _MULTILABEL_OUTPUTS

    if verbose:
        print("Brain extraction: retrieving model weights.")
    weights_file_name = get_pretrained_network(weights_file_name_prefix)

    if verbose:
        print("Brain extraction: retrieving template.")
    if modality == "t1threetissue":
        reorient_template = ants.image_read(get_antsxnet_data("nki"))
    elif modality in ("t1hemi", "t1lobes"):
        reorient_template = ants.image_read(get_antsxnet_data("hcpyaT1Template"))
        reorient_template_mask = ants.image_read(get_antsxnet_data("hcpyaTemplateBrainMask"))
        reorient_template = reorient_template * reorient_template_mask
        reorient_template = ants.resample_image(reorient_template, (1, 1, 1),
                                                use_voxels=False, interp_type=0)
        reorient_template = ants.pad_or_crop_image_to_size(reorient_template, (160, 192, 160))
        # Shift the template by (0, 0, -10) about its center of mass.
        template_shift = ants.create_ants_transform(
            transform_type="Euler3DTransform",
            center=np.asarray(ants.get_center_of_mass(reorient_template)),
            translation=(0, 0, -10))
        reorient_template = template_shift.apply_to_image(reorient_template)
    else:
        reorient_template = ants.image_read(get_antsxnet_data("S_template3"))
        if is_standard_network and modality not in ("t1.v1", "mra"):
            ants.set_spacing(reorient_template, (1.5, 1.5, 1.5))
    resampled_image_size = reorient_template.shape

    number_of_filters = (8, 16, 32, 64)
    number_of_classification_labels = 2
    mode = "classification"
    weight_decay = 1e-5
    if is_standard_network:
        number_of_filters = (16, 32, 64, 128)
        number_of_classification_labels = 1
        mode = "sigmoid"
    if is_multilabel:
        mode = "classification"
        number_of_classification_labels = _MULTILABEL_OUTPUTS[modality]
        weight_decay = 0

    unet_model = create_unet_model_3d((*resampled_image_size, channel_size),
                                      number_of_outputs=number_of_classification_labels,
                                      mode=mode,
                                      number_of_filters=number_of_filters,
                                      dropout_rate=0.0,
                                      convolution_kernel_size=3,
                                      deconvolution_kernel_size=2,
                                      weight_decay=weight_decay)
    unet_model.load_weights(weights_file_name)

    if verbose:
        print("Brain extraction: normalizing image to the template.")
    # Translation-only alignment of the first input's center of mass onto the
    # template's center of mass; the same transform is used for every channel.
    center_of_mass_template = ants.get_center_of_mass(reorient_template)
    center_of_mass_image = ants.get_center_of_mass(input_images[0])
    translation = np.asarray(center_of_mass_image) - np.asarray(center_of_mass_template)
    xfrm = ants.create_ants_transform(transform_type="Euler3DTransform",
                                      center=np.asarray(center_of_mass_template),
                                      translation=translation)

    batchX = np.zeros((1, *resampled_image_size, channel_size))
    for channel, input_image in enumerate(input_images):
        warped_image = ants.apply_ants_transform_to_image(xfrm, input_image, reorient_template)
        if is_standard_network and modality != "t1.v1":
            batchX[0, :, :, :, channel] = ants.iMath(warped_image, "Normalize").numpy()
        else:
            warped_array = warped_image.numpy()
            batchX[0, :, :, :, channel] = (warped_array - warped_array.mean()) / warped_array.std()

    if verbose:
        print("Brain extraction: prediction and decoding.")
    predicted_data = unet_model.predict(batchX, verbose=verbose)
    probability_images = ants.one_hot_to_segmentation(predicted_data[0, :, :, :], reorient_template)

    if verbose:
        print("Brain extraction: renormalize probability mask to native space.")
    xfrm_inv = xfrm.invert()

    if is_multilabel:
        probability_images_warped = [
            xfrm_inv.apply_to_image(probability_image, input_images[0])
            for probability_image in probability_images[:number_of_classification_labels]]
        # All-ones image on the input grid: every voxel participates in the argmax.
        domain_image = input_images[0] * 0 + 1
        image_matrix = ants.image_list_to_matrix(probability_images_warped, domain_image)
        segmentation_matrix = np.argmax(image_matrix, axis=0)
        segmentation_image = ants.matrix_to_images(
            np.expand_dims(segmentation_matrix, axis=0), domain_image)[0]
        return {'segmentation_image': segmentation_image,
                'probability_images': probability_images_warped}

    # Last output channel: index 1 for the two-class softmax networks,
    # index 0 for the single-output sigmoid networks.
    probability_image = xfrm_inv.apply_to_image(
        probability_images[number_of_classification_labels - 1], input_images[0])
    return probability_image