Source code for antspynet.utilities.deep_flash

import numpy as np
import ants

from tensorflow.keras.layers import Conv3D
from tensorflow.keras.models import Model
from tensorflow.keras import regularizers

def _flip_first_axis(image):
    """Return `image` with its voxel array reversed along axis 0.

    Origin, spacing and direction are copied unchanged from `image`.  Axis 0
    is the axis along which the left and right crop boxes differ, so this
    produces the view fed to a hemisphere model for the contralateral side.
    """
    return ants.from_numpy(np.flip(image.numpy(), axis=0),
                           origin=image.origin,
                           spacing=image.spacing,
                           direction=image.direction)


def _image_views(image, use_contralaterality):
    """Return `[image]`, or `[image, flipped image]` when contralaterality is used."""
    views = [image]
    if use_contralaterality:
        views.append(_flip_first_axis(image))
    return views


def _preprocess_t1(t1, t1_template, verbose):
    """Brain-extract, N4-correct and affinely register `t1` to `t1_template`.

    Returns the warped image, the binary brain mask, and a dict holding the
    forward and inverse transform lists of the registration.
    """
    from ..utilities import brain_extraction

    if verbose:
        print("Preprocessing T1.")

    # Brain extraction
    probability_mask = brain_extraction(t1, modality="t1", verbose=verbose)
    t1_mask = ants.threshold_image(probability_mask, 0.5, 1, 1, 0)
    t1_preprocessed = t1 * t1_mask

    # Do bias correction
    t1_preprocessed = ants.n4_bias_field_correction(t1_preprocessed, t1_mask,
                                                    shrink_factor=4, verbose=verbose)

    # Warp to template
    registration = ants.registration(fixed=t1_template, moving=t1_preprocessed,
                                     type_of_transform="antsRegistrationSyNQuickRepro[a]",
                                     verbose=verbose)
    template_transforms = dict(fwdtransforms=registration['fwdtransforms'],
                               invtransforms=registration['invtransforms'])
    return registration['warpedmovout'], t1_mask, template_transforms


def _preprocess_t2(t2, t1_mask, t1_template, template_transforms, verbose):
    """Mask, N4-correct and warp `t2` using the mask and transforms found for the T1."""
    if verbose:
        print("Preprocessing T2.")

    # Brain extraction (the T2 is assumed to be pre-aligned to the T1)
    t2_preprocessed = t2 * t1_mask

    # Do bias correction
    t2_preprocessed = ants.n4_bias_field_correction(t2_preprocessed, t1_mask,
                                                    shrink_factor=4, verbose=verbose)

    # Warp to template
    return ants.apply_transforms(fixed=t1_template, moving=t2_preprocessed,
                                 transformlist=template_transforms['fwdtransforms'],
                                 verbose=verbose)


def _rescaled_template_roi(template, lower_bound, upper_bound):
    """Crop `template` to the given index box and rescale its intensities to [-1, 1]."""
    roi = ants.crop_indices(template, lower_bound, upper_bound)
    return (roi - roi.min()) / (roi.max() - roi.min()) * 2.0 - 1.0


def _label_priors(prior_labels, labels, reference_image):
    """Build one smoothed binary prior per label value from a label image."""
    priors = []
    for label in labels:
        prior_image = ants.threshold_image(prior_labels, label, label, 1, 0)
        prior_image = ants.copy_image_info(reference_image, prior_image)
        priors.append(ants.smooth_image(prior_image, 1.0))
    return priors


def _build_unet(image_size, channel_size, number_of_classification_labels,
                number_of_region_outputs):
    """Create the 3-D u-net and attach the extra single-channel sigmoid outputs.

    The extra heads are created in order so that layer creation order (and
    hence weight loading) is the same as when they are written out one by one.
    """
    from ..architectures import create_unet_model_3d

    unet_model = create_unet_model_3d((*image_size, channel_size),
                                      number_of_outputs=number_of_classification_labels,
                                      mode="classification",
                                      number_of_filters=(32, 64, 96, 128, 256),
                                      convolution_kernel_size=(3, 3, 3),
                                      deconvolution_kernel_size=(2, 2, 2),
                                      dropout_rate=0.0, weight_decay=0)

    penultimate_layer = unet_model.layers[-2].output

    region_outputs = []
    for _ in range(number_of_region_outputs):
        region_outputs.append(
            Conv3D(filters=1, kernel_size=(1, 1, 1), activation='sigmoid',
                   kernel_regularizer=regularizers.l2(0.0))(penultimate_layer))

    return Model(inputs=unet_model.input,
                 outputs=[unet_model.output, *region_outputs])


def _build_batch(modalities, priors, lower_bound, upper_bound, image_size,
                 channel_size, number_of_labels, use_rank_intensity):
    """Assemble the network input for one hemisphere.

    `modalities` is a list of `(views, template)` pairs.  Modality `c` fills
    channel `c`; view `r` of each modality fills batch row `r`.  The cropped
    priors fill the last `number_of_labels` channels of every row.
    """
    number_of_views = len(modalities[0][0])
    batch = np.zeros((number_of_views, *image_size, channel_size))

    for channel, (views, template) in enumerate(modalities):
        # The template ROI is only needed as the histogram matching reference.
        template_roi = None
        if not use_rank_intensity:
            template_roi = _rescaled_template_roi(template, lower_bound, upper_bound)
        for row, view in enumerate(views):
            cropped = ants.crop_indices(view, lower_bound, upper_bound)
            if use_rank_intensity:
                cropped = ants.rank_intensity(cropped)
            else:
                cropped = ants.histogram_match_image(cropped, template_roi, 255, 64, False)
            batch[row, :, :, :, channel] = cropped.numpy()

    first_prior_channel = channel_size - number_of_labels
    for offset, prior in enumerate(priors):
        cropped_prior = ants.crop_indices(prior, lower_bound, upper_bound)
        batch[:, :, :, :, first_prior_channel + offset] = cropped_prior.numpy()

    return batch


def _patch_to_native_space(patch, patch_geometry, t1, t1_preprocessed,
                           fill_with_ones, is_flipped, do_preprocessing,
                           template_transforms, verbose):
    """Turn one predicted patch into a probability image in the space of `t1`.

    The patch is pasted into a full-size image (ones outside the patch for the
    background class, zeros otherwise), un-flipped if it came from the flipped
    view, and, if preprocessing was done, mapped back with the inverse
    transform.
    """
    origin, spacing, direction = patch_geometry
    probability_image = ants.from_numpy(np.squeeze(patch), origin=origin,
                                        spacing=spacing, direction=direction)

    if fill_with_ones:
        canvas = t1_preprocessed * 0 + 1
    else:
        canvas = t1_preprocessed * 0
    probability_image = ants.decrop_image(probability_image, canvas)

    if is_flipped:
        probability_image = _flip_first_axis(probability_image)

    if do_preprocessing:
        probability_image = ants.apply_transforms(
            fixed=t1, moving=probability_image,
            transformlist=template_transforms['invtransforms'],
            whichtoinvert=[True], interpolator="linear", verbose=verbose)

    return probability_image


def _accumulate(images, index, image):
    """Store `image` at `index`: average with an existing entry, else append.

    The first model to contribute to a hemisphere appends; the second model
    finds an entry already there and the two predictions are averaged.
    """
    if index < len(images):
        images[index] = (images[index] + image) / 2
    else:
        images.append(image)


def _run_split_networks(t1, modalities, hemispheres, image_size,
                        number_of_region_outputs, use_rank_intensity,
                        do_preprocessing, template_transforms, verbose):
    """Run the left and right hemisphere models and collect their predictions.

    `hemispheres` is a list of dicts (left first) with the keys `side`,
    `labels`, `priors`, `lower_bound`, `upper_bound` and `network_name`.
    Batch row 0 of each model is the unflipped view and contributes to the
    model's own side; row 1 (if present) is the flipped view and contributes
    to the opposite side.

    Returns two dicts keyed by "left"/"right": the per-class probability
    images and the extra region probability images.
    """
    from ..utilities import get_pretrained_network

    t1_preprocessed = modalities[0][0][0]

    number_of_labels = len(hemispheres[0]["labels"])
    channel_size = len(modalities) + number_of_labels
    number_of_classification_labels = 1 + number_of_labels

    unet_model = _build_unet(image_size, channel_size,
                             number_of_classification_labels,
                             number_of_region_outputs)

    probability_images = {"left": [], "right": []}
    region_images = {"left": [], "right": []}
    opposite_side = {"left": "right", "right": "left"}

    for hemisphere in hemispheres:
        side = hemisphere["side"]
        lower_bound = hemisphere["lower_bound"]
        upper_bound = hemisphere["upper_bound"]

        if verbose:
            print("DeepFlash: retrieving model weights ({}).".format(side))
        weights_file_name = get_pretrained_network(hemisphere["network_name"])
        unet_model.load_weights(weights_file_name)

        if verbose:
            print("Prediction ({}).".format(side))

        batch = _build_batch(modalities, hemisphere["priors"], lower_bound,
                             upper_bound, image_size, channel_size,
                             len(hemisphere["labels"]), use_rank_intensity)
        predicted_data = unet_model.predict(batch, verbose=verbose)

        cropped = ants.crop_indices(t1_preprocessed, lower_bound, upper_bound)
        patch_geometry = (cropped.origin, cropped.spacing, cropped.direction)

        # Index 0: target for the unflipped row; index 1: for the flipped row.
        label_targets = (probability_images[side],
                         probability_images[opposite_side[side]])
        region_targets = (region_images[side],
                          region_images[opposite_side[side]])

        # Per-class probabilities; class 0 is the background.
        for i in range(1 + len(hemisphere["labels"])):
            for j in range(predicted_data[0].shape[0]):
                probability_image = _patch_to_native_space(
                    predicted_data[0][j, :, :, :, i], patch_geometry, t1,
                    t1_preprocessed, i == 0, j == 1, do_preprocessing,
                    template_transforms, verbose)
                _accumulate(label_targets[j], i, probability_image)

        # Extra region outputs (model outputs 1, 2, ...).
        for i in range(1, len(predicted_data)):
            for j in range(predicted_data[i].shape[0]):
                probability_image = _patch_to_native_space(
                    predicted_data[i][j, :, :, :, 0], patch_geometry, t1,
                    t1_preprocessed, False, j == 1, do_preprocessing,
                    template_transforms, verbose)
                _accumulate(region_targets[j], i - 1, probability_image)

    return probability_images, region_images


def _combine_probability_images(t1, probability_images_left, probability_images_right):
    """Merge both hemispheres into `[background, left 1, right 1, left 2, ...]`.

    The background is one minus the sum of all non-background probabilities.
    """
    probability_background_image = ants.image_clone(t1) * 0
    for image in probability_images_left[1:]:
        probability_background_image += image
    for image in probability_images_right[1:]:
        probability_background_image += image

    probability_images = [probability_background_image * -1 + 1]
    for left_image, right_image in zip(probability_images_left[1:],
                                       probability_images_right[1:]):
        probability_images.append(left_image)
        probability_images.append(right_image)
    return probability_images


def _probabilities_to_segmentation(t1, probability_images, labels):
    """Convert probability images to a label image using the values in `labels`.

    A voxel is foreground when the summed label probability exceeds the
    background probability; foreground voxels get the most probable label.
    """
    whole_image_mask = t1 * 0 + 1

    image_matrix = ants.image_list_to_matrix(probability_images[1:], whole_image_mask)
    background_foreground_matrix = np.stack(
        [ants.image_list_to_matrix([probability_images[0]], whole_image_mask),
         np.expand_dims(np.sum(image_matrix, axis=0), axis=0)])
    foreground_matrix = np.argmax(background_foreground_matrix, axis=0)
    segmentation_matrix = (np.argmax(image_matrix, axis=0) + 1) * foreground_matrix
    segmentation_image = ants.matrix_to_images(
        np.expand_dims(segmentation_matrix, axis=0), whole_image_mask)[0]

    # Map the consecutive indices 0..N to the published label values.
    relabeled_image = ants.image_clone(segmentation_image)
    for index, label in enumerate(labels):
        relabeled_image[segmentation_image == index] = label
    return relabeled_image


def deep_flash(t1,
               t2=None,
               which_parcellation="yassa",
               do_preprocessing=True,
               use_rank_intensity=True,
               verbose=False
               ):

    """
    Hippocampal/Entorhinal segmentation using "Deep Flash"

    Perform hippocampal/entorhinal segmentation in T1 and T1/T2 images using
    labels from Mike Yassa's lab---https://faculty.sites.uci.edu/myassa/

    https://www.nature.com/articles/s41598-024-59440-6

    The labeling for the "yassa" parcellation is as follows:
    Label 0 :  background
    Label 5 :  left aLEC
    Label 6 :  right aLEC
    Label 7 :  left pMEC
    Label 8 :  right pMEC
    Label 9 :  left perirhinal
    Label 10:  right perirhinal
    Label 11:  left parahippocampal
    Label 12:  right parahippocampal
    Label 13:  left DG/CA2/CA3/CA4
    Label 14:  right DG/CA2/CA3/CA4
    Label 15:  left CA1
    Label 16:  right CA1
    Label 17:  left subiculum
    Label 18:  right subiculum

    Preprocessing on the training data consisted of:
       * n4 bias correction,
       * affine registration to the "deep flash" template.
    which is performed on the input images if do_preprocessing = True.

    Arguments
    ---------
    t1 : ANTsImage
        raw or preprocessed 3-D T1-weighted brain image.

    t2 : ANTsImage
        Optional 3-D T2-weighted brain image for yassa parcellation.  If
        specified, it is assumed to be pre-aligned to the t1.  It is not
        used by the "wip" parcellation.

    which_parcellation : string --- "yassa" or "wip"
        "yassa":  see above label descriptions.
        "wip":  work-in-progress parcellation that uses a second template
        and label set (14 labels per hemisphere, numbered 104...6010 on the
        left and 204...7010 on the right) and T1-only models.

    do_preprocessing : boolean
        See description above.

    use_rank_intensity : boolean
        If false, use histogram matching with cropped template ROI.  Otherwise,
        use a rank intensity transform on the cropped ROI.  Only for "yassa"
        parcellation ("wip" always uses histogram matching).

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    Dictionary with the keys

        'segmentation_image' : label image in the space of t1,
        'probability_images' : list with the background probability image
            followed by the per-label probability images (left and right
            interleaved),

    and the region probability images, which are

        'medial_temporal_lobe_probability_image',
        'other_region_probability_image',
        'hippocampal_probability_image'

    for the "yassa" parcellation and

        'whole_probability_image',
        'hippocampal_probability_image',
        'amygdala_probability_image'

    for the "wip" parcellation.

    Raises
    ------
    ValueError
        If t1 is not 3-D or which_parcellation is not recognized.

    Example
    -------
    >>> image = ants.image_read("t1.nii.gz")
    >>> flash = deep_flash(image)
    """

    # Validate the cheap things first, before any import, download or model
    # construction is attempted.
    if t1.dimension != 3:
        raise ValueError("Image dimension must be 3.")

    if which_parcellation not in ("yassa", "wip"):
        raise ValueError("Unrecognized parcellation.")

    from ..utilities import get_antsxnet_data

    ################################
    #
    # Options temporarily taken from the user
    #
    ################################

    # use_hierarchical_parcellation : boolean
    #     If True, use u-net model with additional outputs of the medial temporal lobe
    #     region, hippocampal, and entorhinal/perirhinal/parahippocampal regions.  Otherwise
    #     the only additional output is the medial temporal lobe.  ("yassa" only.)
    #
    # use_contralaterality : boolean
    #     Use both hemispherical models to also predict the corresponding contralateral
    #     segmentation and use both sets of priors to produce the results.

    use_hierarchical_parcellation = True
    use_contralaterality = True

    is_yassa = which_parcellation == "yassa"

    ################################
    #
    # Preprocess images
    #
    ################################

    if is_yassa:
        t1_template_name = "deepFlashTemplateT1SkullStripped"
    else:
        t1_template_name = "deepFlashTemplate2T1SkullStripped"
    t1_template = ants.image_read(get_antsxnet_data(t1_template_name))

    t1_preprocessed = t1
    t1_mask = None
    template_transforms = None
    if do_preprocessing:
        t1_preprocessed, t1_mask, template_transforms = _preprocess_t1(
            t1, t1_template, verbose)

    # Each modality is a (views, template) pair; the T1 always comes first.
    modalities = [(_image_views(t1_preprocessed, use_contralaterality), t1_template)]

    if is_yassa and t2 is not None:
        t2_template = ants.image_read(get_antsxnet_data("deepFlashTemplateT2SkullStripped"))
        t2_template = ants.copy_image_info(t1_template, t2_template)
        t2_preprocessed = t2
        if do_preprocessing:
            t2_preprocessed = _preprocess_t2(t2, t1_mask, t1_template,
                                             template_transforms, verbose)
        modalities.append((_image_views(t2_preprocessed, use_contralaterality),
                           t2_template))

    ################################
    #
    # Parcellation-specific labels, spatial priors, crop boxes and networks
    #
    ################################

    if is_yassa:
        labels = (0, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18)
        image_size = (64, 64, 96)

        spatial_priors = ants.image_read(get_antsxnet_data("deepFlashPriors"))
        priors_image_list = [ants.copy_image_info(t1_preprocessed, prior)
                             for prior in ants.ndimage_to_list(spatial_priors)]

        # Odd entries belong to the left hemisphere, even ones to the right.
        labels_left = labels[1::2]
        labels_right = labels[2::2]
        priors_image_left_list = priors_image_list[1::2]
        priors_image_right_list = priors_image_list[2::2]

        lower_bound_left = (76, 74, 56)
        upper_bound_left = (140, 138, 152)
        lower_bound_right = (20, 74, 56)
        upper_bound_right = (84, 138, 152)

        number_of_region_outputs = 3 if use_hierarchical_parcellation else 1
        rank_intensity = use_rank_intensity

        network_suffix = "Both" if t2 is not None else "T1"
        if use_hierarchical_parcellation:
            network_suffix += "Hierarchical"
        if use_rank_intensity:
            network_suffix += "_ri"
        network_name_left = "deepFlashLeft" + network_suffix
        network_name_right = "deepFlashRight" + network_suffix
    else:
        labels_left = [104, 105, 106, 108, 109, 110, 114, 115, 126,
                       6001, 6003, 6008, 6009, 6010]
        labels_right = [204, 205, 206, 208, 209, 210, 214, 215, 226,
                        7001, 7003, 7008, 7009, 7010]
        labels = np.array(np.repeat(0, 1 + len(labels_left) + len(labels_right)))
        labels[1::2] = labels_left
        labels[2::2] = labels_right
        image_size = (64, 64, 128)

        prior_labels = ants.image_read(get_antsxnet_data("deepFlashTemplate2Labels"))
        priors_image_left_list = _label_priors(prior_labels, labels_left, t1_preprocessed)
        priors_image_right_list = _label_priors(prior_labels, labels_right, t1_preprocessed)

        lower_bound_left = (114, 108, 82)
        upper_bound_left = (178, 172, 210)
        lower_bound_right = (50, 108, 82)
        upper_bound_right = (114, 172, 210)

        # whole complex, hippocampus, amygdala
        number_of_region_outputs = 3
        rank_intensity = False

        network_name_left = 'deepFlash2LeftT1Hierarchical'
        network_name_right = 'deepFlash2RightT1Hierarchical'

    hemispheres = [
        {"side": "left",
         "labels": labels_left,
         "priors": priors_image_left_list,
         "lower_bound": lower_bound_left,
         "upper_bound": upper_bound_left,
         "network_name": network_name_left},
        {"side": "right",
         "labels": labels_right,
         "priors": priors_image_right_list,
         "lower_bound": lower_bound_right,
         "upper_bound": upper_bound_right,
         "network_name": network_name_right},
    ]

    ################################
    #
    # Process left/right in split networks
    #
    ################################

    label_probabilities, region_probabilities = _run_split_networks(
        t1, modalities, hemispheres, image_size, number_of_region_outputs,
        rank_intensity, do_preprocessing, template_transforms, verbose)

    ################################
    #
    # Combine priors
    #
    ################################

    probability_images = _combine_probability_images(
        t1, label_probabilities["left"], label_probabilities["right"])

    ################################
    #
    # Convert probability images to segmentation
    #
    ################################

    relabeled_image = _probabilities_to_segmentation(t1, probability_images, labels)

    foreground_probability_images = [
        left_image + right_image
        for left_image, right_image in zip(region_probabilities["left"],
                                           region_probabilities["right"])]

    if is_yassa:
        region_keys = ['medial_temporal_lobe_probability_image']
        if use_hierarchical_parcellation:
            region_keys += ['other_region_probability_image',
                            'hippocampal_probability_image']
    else:
        region_keys = ['whole_probability_image',
                       'hippocampal_probability_image',
                       'amygdala_probability_image']

    return_dict = {'segmentation_image' : relabeled_image,
                   'probability_images' : probability_images}
    for index, key in enumerate(region_keys):
        return_dict[key] = foreground_probability_images[index]

    return(return_dict)