# Source code for antspynet.utilities.deep_flash

import numpy as np
import ants

from tensorflow.keras.layers import Conv3D
from tensorflow.keras.models import Model
from tensorflow.keras import regularizers

def _deep_flash_mirror(image):
    """Return `image` with its voxel array reversed along the first axis.

    The origin, spacing and direction of the input are copied unchanged.
    """
    return ants.from_numpy(np.flip(image.numpy(), axis=0),
                           origin=image.origin,
                           spacing=image.spacing,
                           direction=image.direction)


def _deep_flash_roi_array(image, template, lower_bound, upper_bound,
                          use_rank_intensity):
    """Crop `image` to the ROI and return its intensity-normalized voxel array."""
    roi = ants.crop_indices(image, lower_bound, upper_bound)
    if use_rank_intensity:
        return ants.rank_intensity(roi).numpy()

    # Histogram matching targets the template ROI rescaled to [-1, 1].  The
    # template ROI is only needed on this path, so it is built on demand.
    template_roi = ants.crop_indices(template, lower_bound, upper_bound)
    template_roi = ((template_roi - template_roi.min())
                    / (template_roi.max() - template_roi.min()) * 2.0 - 1.0)
    return ants.histogram_match_image(roi, template_roi, 255, 64, False).numpy()


def _deep_flash_build_batch(t1_inputs, t2_inputs, t1_template, t2_template,
                            priors, lower_bound, upper_bound, image_size,
                            channel_size, number_of_labels, use_rank_intensity):
    """Assemble the network input for one hemisphere.

    Channel 0 holds the T1 ROI, channel 1 the T2 ROI (only when `t2_inputs`
    is non-empty) and the trailing `number_of_labels` channels hold the
    cropped spatial priors.  There is one batch entry per image in
    `t1_inputs`; the priors are identical for every batch entry.
    """
    batch = np.zeros((len(t1_inputs), *image_size, channel_size))
    for j, image in enumerate(t1_inputs):
        batch[j, :, :, :, 0] = _deep_flash_roi_array(
            image, t1_template, lower_bound, upper_bound, use_rank_intensity)
    for j, image in enumerate(t2_inputs):
        batch[j, :, :, :, 1] = _deep_flash_roi_array(
            image, t2_template, lower_bound, upper_bound, use_rank_intensity)

    first_prior_channel = channel_size - number_of_labels
    for i, prior in enumerate(priors):
        cropped_prior = ants.crop_indices(prior, lower_bound, upper_bound)
        batch[:, :, :, :, first_prior_channel + i] = cropped_prior.numpy()
    return batch


def _deep_flash_native_probabilities(unet_model, batch, roi_reference,
                                     number_of_labels, template_space_image,
                                     native_image, inverse_transforms, verbose):
    """Run the network and yield each predicted map in the space of `native_image`.

    Yields `(output_index, channel, batch_index, image)` tuples.  Output 0 is
    the per-label output (`1 + number_of_labels` channels); every further
    output contributes its single channel 0.  Images are produced one at a
    time so the caller can accumulate them without holding a second full set
    in memory.

    `inverse_transforms` is the transform list used to map back onto
    `native_image`, or None when no registration was performed.
    """
    predicted_data = unet_model.predict(batch, verbose=verbose)

    for output_index, prediction in enumerate(predicted_data):
        channels = range(1 + number_of_labels) if output_index == 0 else (0,)
        for channel in channels:
            for batch_index in range(prediction.shape[0]):
                image = ants.from_numpy(
                    np.squeeze(prediction[batch_index, :, :, :, channel]),
                    origin=roi_reference.origin,
                    spacing=roi_reference.spacing,
                    direction=roi_reference.direction)

                # A fresh canvas is built for every map.  Outside the ROI the
                # background class (output 0, channel 0) has probability 1,
                # everything else probability 0.
                canvas = template_space_image * 0
                if output_index == 0 and channel == 0:
                    canvas = canvas + 1
                image = ants.decrop_image(image, canvas)

                # Batch entry 1 was predicted from the mirrored input, so the
                # result is mirrored back.
                if batch_index == 1:
                    image = _deep_flash_mirror(image)

                if inverse_transforms is not None:
                    image = ants.apply_transforms(fixed=native_image,
                                                  moving=image,
                                                  transformlist=inverse_transforms,
                                                  whichtoinvert=[True],
                                                  interpolator="linear",
                                                  verbose=verbose)
                yield output_index, channel, batch_index, image


def deep_flash(t1,
               t2=None,
               do_preprocessing=True,
               use_rank_intensity=True,
               antsxnet_cache_directory=None,
               verbose=False
               ):

    """
    Hippocampal/Entorhinal segmentation using "Deep Flash"

    Perform hippocampal/entorhinal segmentation in T1 and T1/T2 images using
    labels from Mike Yassa's lab---https://faculty.sites.uci.edu/myassa/

    https://www.nature.com/articles/s41598-024-59440-6

    The labeling is as follows:
    Label 0 :  background
    Label 5 :  left aLEC
    Label 6 :  right aLEC
    Label 7 :  left pMEC
    Label 8 :  right pMEC
    Label 9 :  left perirhinal
    Label 10:  right perirhinal
    Label 11:  left parahippocampal
    Label 12:  right parahippocampal
    Label 13:  left DG/CA2/CA3/CA4
    Label 14:  right DG/CA2/CA3/CA4
    Label 15:  left CA1
    Label 16:  right CA1
    Label 17:  left subiculum
    Label 18:  right subiculum

    Preprocessing on the training data consisted of:
       * n4 bias correction,
       * affine registration to the "deep flash" template.
    which is performed on the input images if do_preprocessing = True.

    Arguments
    ---------
    t1 : ANTsImage
        raw or preprocessed 3-D T1-weighted brain image.

    t2 : ANTsImage
        Optional 3-D T2-weighted brain image.  If specified, it is assumed to be
        pre-aligned to the t1.

    do_preprocessing : boolean
        See description above.

    use_rank_intensity : boolean
        If false, use histogram matching with cropped template ROI.  Otherwise,
        use a rank intensity transform on the cropped ROI.

    antsxnet_cache_directory : string
        Destination directory for storing the downloaded template and model weights.
        Since these can be reused, if it is None, these data will be downloaded to
        ~/.keras/ANTsXNet/.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    Dictionary with the following entries:
        'segmentation_image' : label image using the labeling above.
        'probability_images' : list of probability images, background first,
            followed by the remaining labels in the order listed above.
        'medial_temporal_lobe_probability_image' : probability image of the
            medial temporal lobe region.
        'other_region_probability_image' : probability image of the
            entorhinal/perirhinal/parahippocampal region.
        'hippocampal_probability_image' : probability image of the
            hippocampal region.

    Raises
    ------
    ValueError
        If t1 is not a 3-D image.

    Example
    -------
    >>> image = ants.image_read("t1.nii.gz")
    >>> flash = deep_flash(image)
    """

    # Validate before the package-local imports so bad input fails immediately.
    if t1.dimension != 3:
        raise ValueError("Image dimension must be 3.")

    from ..architectures import create_unet_model_3d
    from ..utilities import get_pretrained_network
    from ..utilities import get_antsxnet_data
    from ..utilities import brain_extraction

    ################################
    #
    # Options temporarily taken from the user
    #
    ################################

    # use_hierarchical_parcellation : boolean
    #     If True, use u-net model with additional outputs of the medial temporal lobe
    #     region, hippocampal, and entorhinal/perirhinal/parahippocampal regions.  Otherwise
    #     the only additional output is the medial temporal lobe.
    #
    # use_contralaterality : boolean
    #     Use both hemispherical models to also predict the corresponding contralateral
    #     segmentation and use both sets of priors to produce the results.

    use_hierarchical_parcellation = True
    use_contralaterality = True

    ################################
    #
    # Preprocess images
    #
    ################################

    t1_preprocessed = t1
    t1_mask = None
    template_transforms = None

    # The cache directory is forwarded here as well so that the templates end
    # up in the same location as the priors and the network weights.
    t1_template = ants.image_read(get_antsxnet_data(
        "deepFlashTemplateT1SkullStripped",
        antsxnet_cache_directory=antsxnet_cache_directory))

    if do_preprocessing:
        if verbose:
            print("Preprocessing T1.")

        # Brain extraction
        probability_mask = brain_extraction(t1_preprocessed, modality="t1",
            antsxnet_cache_directory=antsxnet_cache_directory, verbose=verbose)
        t1_mask = ants.threshold_image(probability_mask, 0.5, 1, 1, 0)
        t1_preprocessed = t1_preprocessed * t1_mask

        # Do bias correction
        t1_preprocessed = ants.n4_bias_field_correction(t1_preprocessed, t1_mask,
            shrink_factor=4, verbose=verbose)

        # Warp to template
        registration = ants.registration(fixed=t1_template, moving=t1_preprocessed,
            type_of_transform="antsRegistrationSyNQuickRepro[a]", verbose=verbose)
        template_transforms = dict(fwdtransforms=registration['fwdtransforms'],
                                   invtransforms=registration['invtransforms'])
        t1_preprocessed = registration['warpedmovout']

    inverse_transforms = None
    if do_preprocessing:
        inverse_transforms = template_transforms['invtransforms']

    # One network input per batch entry: the image itself and, for the
    # contralateral prediction, its mirror image.
    t1_inputs = [t1_preprocessed]
    if use_contralaterality:
        t1_inputs.append(_deep_flash_mirror(t1_preprocessed))

    t2_inputs = []
    t2_template = None
    if t2 is not None:
        t2_template = ants.image_read(get_antsxnet_data(
            "deepFlashTemplateT2SkullStripped",
            antsxnet_cache_directory=antsxnet_cache_directory))
        t2_template = ants.copy_image_info(t1_template, t2_template)

        t2_preprocessed = t2
        if do_preprocessing:
            if verbose:
                print("Preprocessing T2.")

            # Brain extraction (the T2 is assumed to be aligned to the T1)
            t2_preprocessed = t2_preprocessed * t1_mask

            # Do bias correction
            t2_preprocessed = ants.n4_bias_field_correction(t2_preprocessed, t1_mask,
                shrink_factor=4, verbose=verbose)

            # Warp to template
            t2_preprocessed = ants.apply_transforms(fixed=t1_template,
                moving=t2_preprocessed,
                transformlist=template_transforms['fwdtransforms'],
                verbose=verbose)

        t2_inputs.append(t2_preprocessed)
        if use_contralaterality:
            t2_inputs.append(_deep_flash_mirror(t2_preprocessed))

    labels = (0, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18)
    image_size = (64, 64, 96)

    # Odd positions of `labels` are left labels, even positions (after the
    # background) are right labels; both hemispheres have the same count.
    number_of_hemisphere_labels = len(labels[1::2])

    ################################
    #
    # Download spatial priors
    #
    ################################

    spatial_priors_file_name_path = get_antsxnet_data("deepFlashPriors",
        antsxnet_cache_directory=antsxnet_cache_directory)
    spatial_priors = ants.image_read(spatial_priors_file_name_path)
    priors_image_list = [ants.copy_image_info(t1_preprocessed, prior)
                         for prior in ants.ndimage_to_list(spatial_priors)]

    # (name used in the weights file, ROI lower bound, ROI upper bound, priors)
    hemispheres = (
        ("Left", (76, 74, 56), (140, 138, 152), priors_image_list[1::2]),
        ("Right", (20, 74, 56), (84, 138, 152), priors_image_list[2::2]),
    )

    ################################
    #
    # Create model
    #
    ################################

    channel_size = 1 + number_of_hemisphere_labels
    if t2 is not None:
        channel_size += 1

    number_of_classification_labels = 1 + number_of_hemisphere_labels

    unet_model = create_unet_model_3d((*image_size, channel_size),
        number_of_outputs=number_of_classification_labels, mode="classification",
        number_of_filters=(32, 64, 96, 128, 256),
        convolution_kernel_size=(3, 3, 3), deconvolution_kernel_size=(2, 2, 2),
        dropout_rate=0.0, weight_decay=0)

    penultimate_layer = unet_model.layers[-2].output

    # medial temporal lobe
    output1 = Conv3D(filters=1,
                     kernel_size=(1, 1, 1),
                     activation='sigmoid',
                     kernel_regularizer=regularizers.l2(0.0))(penultimate_layer)

    if use_hierarchical_parcellation:
        # EC, perirhinal, and parahippo.
        output2 = Conv3D(filters=1,
                         kernel_size=(1, 1, 1),
                         activation='sigmoid',
                         kernel_regularizer=regularizers.l2(0.0))(penultimate_layer)

        # Hippocampus
        output3 = Conv3D(filters=1,
                         kernel_size=(1, 1, 1),
                         activation='sigmoid',
                         kernel_regularizer=regularizers.l2(0.0))(penultimate_layer)

        unet_model = Model(inputs=unet_model.input,
                           outputs=[unet_model.output, output1, output2, output3])
    else:
        unet_model = Model(inputs=unet_model.input,
                           outputs=[unet_model.output, output1])

    ################################
    #
    # Process left/right in split networks (the same architecture is reused,
    # only the weights change)
    #
    ################################

    probability_images_left = list()
    probability_images_right = list()
    foreground_probability_images_left = list()
    foreground_probability_images_right = list()

    for side, lower_bound, upper_bound, priors in hemispheres:
        is_left_pass = side == "Left"

        network_name = "deepFlash" + side + ("T1" if t2 is None else "Both")
        if use_hierarchical_parcellation:
            network_name += "Hierarchical"
        if use_rank_intensity:
            network_name += "_ri"

        if verbose:
            print(f"DeepFlash: retrieving model weights ({side.lower()}).")
        weights_file_name = get_pretrained_network(network_name,
            antsxnet_cache_directory=antsxnet_cache_directory)
        unet_model.load_weights(weights_file_name)

        if verbose:
            print(f"Prediction ({side.lower()}).")

        batchX = _deep_flash_build_batch(t1_inputs, t2_inputs, t1_template,
            t2_template, priors, lower_bound, upper_bound, image_size,
            channel_size, number_of_hemisphere_labels, use_rank_intensity)
        roi_reference = ants.crop_indices(t1_preprocessed, lower_bound, upper_bound)

        # "own" lists belong to the hemisphere of the current network; the
        # mirrored batch entry contributes to the opposite ("other") side.
        if is_left_pass:
            own_labels, other_labels = probability_images_left, probability_images_right
            own_regions = foreground_probability_images_left
            other_regions = foreground_probability_images_right
        else:
            own_labels, other_labels = probability_images_right, probability_images_left
            own_regions = foreground_probability_images_right
            other_regions = foreground_probability_images_left

        native_probabilities = _deep_flash_native_probabilities(unet_model, batchX,
            roi_reference, number_of_hemisphere_labels, t1_preprocessed, t1,
            inverse_transforms, verbose)

        for output_index, channel, batch_index, image in native_probabilities:
            if output_index == 0:
                own, other, position = own_labels, other_labels, channel
            else:
                own, other, position = own_regions, other_regions, output_index - 1
            target = own if batch_index == 0 else other

            if is_left_pass:
                # First pass: fill the lists.
                target.append(image)
            elif batch_index == 1 or use_contralaterality:
                # Second pass: average with the estimate of the first pass.
                target[position] = (target[position] + image) / 2
            else:
                target.append(image)

    ################################
    #
    # Combine priors
    #
    ################################

    probability_background_image = ants.image_clone(t1) * 0
    for image in probability_images_left[1:]:
        probability_background_image += image
    for image in probability_images_right[1:]:
        probability_background_image += image

    probability_images = [probability_background_image * -1 + 1]
    for left_image, right_image in zip(probability_images_left[1:],
                                       probability_images_right[1:]):
        probability_images.append(left_image)
        probability_images.append(right_image)

    ################################
    #
    # Convert probability images to segmentation
    #
    ################################

    full_mask = t1 * 0 + 1

    # A voxel is foreground when the summed label probability beats the
    # background probability; foreground voxels take the most likely label.
    image_matrix = ants.image_list_to_matrix(probability_images[1:], full_mask)
    background_foreground_matrix = np.stack(
        [ants.image_list_to_matrix([probability_images[0]], full_mask),
         np.expand_dims(np.sum(image_matrix, axis=0), axis=0)])
    foreground_matrix = np.argmax(background_foreground_matrix, axis=0)
    segmentation_matrix = (np.argmax(image_matrix, axis=0) + 1) * foreground_matrix
    segmentation_image = ants.matrix_to_images(
        np.expand_dims(segmentation_matrix, axis=0), full_mask)[0]

    relabeled_image = ants.image_clone(segmentation_image)
    for index, label in enumerate(labels):
        relabeled_image[segmentation_image == index] = label

    foreground_probability_images = [
        left_image + right_image
        for left_image, right_image in zip(foreground_probability_images_left,
                                           foreground_probability_images_right)]

    return_dict = {'segmentation_image' : relabeled_image,
                   'probability_images' : probability_images,
                   'medial_temporal_lobe_probability_image' : foreground_probability_images[0]}
    if use_hierarchical_parcellation:
        return_dict['other_region_probability_image'] = foreground_probability_images[1]
        return_dict['hippocampal_probability_image'] = foreground_probability_images[2]

    return(return_dict)
def _deprecated_flash_unet(input_shape, number_of_outputs,
                           number_of_filters_at_base_layer):
    """Create the attention-gated 3-D u-net shared by all deprecated models."""
    from ..architectures import create_unet_model_3d

    return create_unet_model_3d(input_shape,
        number_of_outputs=number_of_outputs,
        number_of_layers=4,
        number_of_filters_at_base_layer=number_of_filters_at_base_layer,
        dropout_rate=0.0,
        convolution_kernel_size=(3, 3, 3),
        deconvolution_kernel_size=(2, 2, 2),
        weight_decay=1e-5,
        additional_options=("attentionGating",))


def _deprecated_flash_native_probabilities(predicted_data, number_of_outputs,
                                           cropped_image, t1_preprocessed, t1,
                                           inverse_transforms, verbose):
    """Convert each output channel of `predicted_data` to an image in `t1` space.

    `cropped_image` supplies the origin/spacing/direction of the network
    input.  `inverse_transforms` is the transform list used to map back onto
    `t1`, or None when no registration was performed.
    """
    probability_images = list()
    for i in range(number_of_outputs):
        probability_image = ants.from_numpy(
            np.squeeze(predicted_data[0, :, :, :, i]),
            origin=cropped_image.origin,
            spacing=cropped_image.spacing,
            direction=cropped_image.direction)

        # A fresh canvas is built for every map.  Outside the cropped region
        # the background class (channel 0) has probability 1, every other
        # label probability 0.
        canvas = t1_preprocessed * 0
        if i == 0:
            canvas = canvas + 1
        decropped_image = ants.decrop_image(probability_image, canvas)

        if inverse_transforms is not None:
            decropped_image = ants.apply_transforms(fixed=t1,
                moving=decropped_image,
                transformlist=inverse_transforms,
                whichtoinvert=[True],
                interpolator="linear",
                verbose=verbose)
        probability_images.append(decropped_image)
    return probability_images


def _deprecated_flash_hemisphere(side, priors_name, network_name, lower_bound,
                                 upper_bound, number_of_labels, t1_preprocessed,
                                 t1, inverse_transforms,
                                 antsxnet_cache_directory, verbose):
    """Predict the probability images of one hemisphere with its own network.

    The network input is the standardized T1 crop (channel 0) followed by the
    cropped spatial priors.  Returns one image per label in `t1` space.
    """
    from ..utilities import get_pretrained_network
    from ..utilities import get_antsxnet_data

    spatial_priors_file_name_path = get_antsxnet_data(priors_name,
        antsxnet_cache_directory=antsxnet_cache_directory)
    spatial_priors = ants.image_read(spatial_priors_file_name_path)
    priors_image_list = ants.ndimage_to_list(spatial_priors)

    template_size = (64, 96, 96)
    channel_size = 1 + number_of_labels
    unet_model = _deprecated_flash_unet((*template_size, channel_size),
        number_of_outputs=number_of_labels,
        number_of_filters_at_base_layer=16)

    if verbose:
        print(f"DeepFlash: retrieving model weights ({side}).")
    weights_file_name = get_pretrained_network(network_name,
        antsxnet_cache_directory=antsxnet_cache_directory)
    unet_model.load_weights(weights_file_name)

    if verbose:
        print(f"Prediction ({side}).")

    cropped_image = ants.crop_indices(t1_preprocessed, lower_bound, upper_bound)
    image_array = cropped_image.numpy()
    image_array = (image_array - image_array.mean()) / image_array.std()

    batchX = np.zeros((1, *template_size, channel_size))
    batchX[0, :, :, :, 0] = image_array
    for i, prior in enumerate(priors_image_list):
        cropped_prior = ants.crop_indices(prior, lower_bound, upper_bound)
        batchX[0, :, :, :, i + 1] = cropped_prior.numpy()

    predicted_data = unet_model.predict(batchX, verbose=verbose)

    return _deprecated_flash_native_probabilities(predicted_data,
        number_of_labels, cropped_image, t1_preprocessed, t1,
        inverse_transforms, verbose)


def deep_flash_deprecated(t1,
                          do_preprocessing=True,
                          do_per_hemisphere=True,
                          which_hemisphere_models="new",
                          antsxnet_cache_directory=None,
                          verbose=False
                          ):

    """
    Hippocampal/Entorhinal segmentation using "Deep Flash"

    Deprecated; please use deep_flash() instead.

    Perform hippocampal/entorhinal segmentation in T1 images using
    labels from Mike Yassa's lab

    https://faculty.sites.uci.edu/myassa/

    The labeling is as follows:
    Label 0 :  background
    Label 5 :  left aLEC
    Label 6 :  right aLEC
    Label 7 :  left pMEC
    Label 8 :  right pMEC
    Label 9 :  left perirhinal
    Label 10:  right perirhinal
    Label 11:  left parahippocampal
    Label 12:  right parahippocampal
    Label 13:  left DG/CA3
    Label 14:  right DG/CA3
    Label 15:  left CA1
    Label 16:  right CA1
    Label 17:  left subiculum
    Label 18:  right subiculum

    Preprocessing on the training data consisted of:
       * n4 bias correction,
       * denoising,
       * brain extraction, and
       * affine registration to MNI.
    The input T1 should undergo the same steps.  If the input T1 is the raw
    T1, these steps can be performed by the internal preprocessing, i.e. set
    do_preprocessing = True

    Arguments
    ---------
    t1 : ANTsImage
        raw or preprocessed 3-D T1-weighted brain image.

    do_preprocessing : boolean
        See description above.

    do_per_hemisphere : boolean
        If True, do prediction based on separate networks per hemisphere.  Otherwise,
        use the single network trained for both hemispheres.

    which_hemisphere_models : string
        Either "old" or "new".  Selects which set of per-hemisphere network
        weights is loaded.  Only used if do_per_hemisphere = True.

    antsxnet_cache_directory : string
        Destination directory for storing the downloaded template and model weights.
        Since these can be reused, if it is None, these data will be downloaded to
        ~/.keras/ANTsXNet/.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    Dictionary with the following entries:
        'segmentation_image' : label image using the labeling above.
        'probability_images' : list of probability images, background first,
            followed by the remaining labels in the order listed above.

    Raises
    ------
    ValueError
        If t1 is not a 3-D image, or if per-hemisphere models are requested
        and which_hemisphere_models is neither "old" nor "new".

    Example
    -------
    >>> image = ants.image_read("t1.nii.gz")
    >>> flash = deep_flash_deprecated(image)
    """

    print("This function is deprecated.  Please update to deep_flash().")

    # Validate everything up front, before any import, download or
    # preprocessing takes place.
    if t1.dimension != 3:
        raise ValueError("Image dimension must be 3.")

    # The equality test (rather than `not ...`) keeps the existing handling of
    # non-boolean arguments such as None.
    use_single_network = do_per_hemisphere == False  # noqa: E712

    hemisphere_network_names = {
        "old": ("deepFlashLeft16", "deepFlashRight16"),
        "new": ("deepFlashLeft16new", "deepFlashRight16new"),
    }
    if not use_single_network and which_hemisphere_models not in ("old", "new"):
        raise ValueError("which_hemisphere_models must be \"old\" or \"new\".")

    from ..utilities import get_pretrained_network
    from ..utilities import preprocess_brain_image
    from ..utilities import pad_or_crop_image_to_size

    # NOTE(review): None is replaced by the string "ANTsXNet" before being
    # handed to the download helpers; how they resolve it is not visible here.
    if antsxnet_cache_directory is None:
        antsxnet_cache_directory = "ANTsXNet"

    ################################
    #
    # Preprocess images
    #
    ################################

    t1_preprocessed = t1
    inverse_transforms = None
    if do_preprocessing:
        t1_preprocessing = preprocess_brain_image(t1,
            truncate_intensity=(0.01, 0.99),
            brain_extraction_modality="t1",
            template="croppedMni152",
            template_transform_type="antsRegistrationSyNQuickRepro[a]",
            do_bias_correction=True,
            do_denoising=True,
            antsxnet_cache_directory=antsxnet_cache_directory,
            verbose=verbose)
        t1_preprocessed = t1_preprocessing["preprocessed_image"] * t1_preprocessing['brain_mask']
        inverse_transforms = t1_preprocessing['template_transforms']['invtransforms']

    labels = (0, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18)

    if use_single_network:

        ################################
        #
        # Process left/right in same network
        #
        ################################

        template_size = (160, 192, 160)
        unet_model = _deprecated_flash_unet((*template_size, 1),
            number_of_outputs=len(labels),
            number_of_filters_at_base_layer=8)

        if verbose:
            print("DeepFlash: retrieving model weights.")
        weights_file_name = get_pretrained_network("deepFlash",
            antsxnet_cache_directory=antsxnet_cache_directory)
        unet_model.load_weights(weights_file_name)

        if verbose:
            print("Prediction.")

        cropped_image = pad_or_crop_image_to_size(t1_preprocessed, template_size)

        batchX = np.expand_dims(cropped_image.numpy(), axis=0)
        batchX = np.expand_dims(batchX, axis=-1)
        batchX = (batchX - batchX.mean()) / batchX.std()

        predicted_data = unet_model.predict(batchX, verbose=verbose)

        probability_images = _deprecated_flash_native_probabilities(
            predicted_data, len(labels), cropped_image, t1_preprocessed, t1,
            inverse_transforms, verbose)

    else:

        ################################
        #
        # Process left/right in split networks
        #
        ################################

        network_name_left, network_name_right = \
            hemisphere_network_names[which_hemisphere_models]

        # Each hemisphere network predicts background plus seven labels.
        labels_left = (0, 5, 7, 9, 11, 13, 15, 17)
        labels_right = (0, 6, 8, 10, 12, 14, 16, 18)

        probability_images_left = _deprecated_flash_hemisphere("left",
            "priorDeepFlashLeftLabels", network_name_left,
            (30, 51, 0), (94, 147, 96), len(labels_left),
            t1_preprocessed, t1, inverse_transforms,
            antsxnet_cache_directory, verbose)

        probability_images_right = _deprecated_flash_hemisphere("right",
            "priorDeepFlashRightLabels", network_name_right,
            (88, 51, 0), (152, 147, 96), len(labels_right),
            t1_preprocessed, t1, inverse_transforms,
            antsxnet_cache_directory, verbose)

        ################################
        #
        # Combine priors
        #
        ################################

        probability_background_image = ants.image_clone(t1) * 0
        for image in probability_images_left[1:]:
            probability_background_image += image
        for image in probability_images_right[1:]:
            probability_background_image += image

        probability_images = [probability_background_image * -1 + 1]
        for left_image, right_image in zip(probability_images_left[1:],
                                           probability_images_right[1:]):
            probability_images.append(left_image)
            probability_images.append(right_image)

    ################################
    #
    # Convert probability images to segmentation
    #
    ################################

    full_mask = t1 * 0 + 1

    # A voxel is foreground when the summed label probability beats the
    # background probability; foreground voxels take the most likely label.
    image_matrix = ants.image_list_to_matrix(probability_images[1:], full_mask)
    background_foreground_matrix = np.stack(
        [ants.image_list_to_matrix([probability_images[0]], full_mask),
         np.expand_dims(np.sum(image_matrix, axis=0), axis=0)])
    foreground_matrix = np.argmax(background_foreground_matrix, axis=0)
    segmentation_matrix = (np.argmax(image_matrix, axis=0) + 1) * foreground_matrix
    segmentation_image = ants.matrix_to_images(
        np.expand_dims(segmentation_matrix, axis=0), full_mask)[0]

    relabeled_image = ants.image_clone(segmentation_image)
    for index, label in enumerate(labels):
        relabeled_image[segmentation_image == index] = label

    return_dict = {'segmentation_image' : relabeled_image,
                   'probability_images' : probability_images}
    return(return_dict)