Source code for antspynet.utilities.brain_age

import statistics
import numpy as np
import tensorflow.keras as keras

import ants

def brain_age(t1,
              do_preprocessing=True,
              number_of_simulations=0,
              sd_affine=0.01,
              antsxnet_cache_directory=None,
              verbose=False):
    """
    Estimate BrainAge from a T1-weighted MR image using the DeepBrainNet
    architecture and weights described here:

    https://github.com/vishnubashyam/DeepBrainNet

    and described in the following article:

    https://pubmed.ncbi.nlm.nih.gov/32591831/

    Preprocessing on the training data consisted of:
       * n4 bias correction,
       * brain extraction, and
       * affine registration to MNI.
    The input T1 should undergo the same steps.  If the input T1 is the raw
    T1, these steps can be performed by the internal preprocessing, i.e. set
    do_preprocessing = True.

    Regardless of do_preprocessing, the image intensities are min-max
    rescaled to [0, 1] before prediction.  Prediction uses the 80 axial
    slices with indices 45..124 along the third image axis.

    Arguments
    ---------
    t1 : ANTsImage
        raw or preprocessed 3-D T1-weighted brain image.

    do_preprocessing : boolean
        If True, run the internal preprocessing (intensity truncation to the
        (0.01, 0.99) quantiles, bias correction, denoising, brain extraction,
        and affine registration to the "croppedMni152" template), then mask
        the result with the brain mask.  See description above.

    number_of_simulations : integer
        Number of random affine perturbations of the input.  The per-slice
        predictions for the original image and for every perturbed copy are
        averaged.  Must be non-negative.

    sd_affine : float
        Standard deviation of the affine transformation parameters used to
        generate the random perturbations.

    antsxnet_cache_directory : string
        Destination directory for storing the downloaded template and model
        weights.  Since these can be reused, if is None, these data will be
        downloaded to a ~/.keras/ANTsXNet/.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    Dictionary with the keys
        'predicted_age' : median over the slices of the per-slice predicted
            age (first column of the model output).
        'brain_age_per_slice' : array returned by the model, one row per
            slice, averaged over the original image and all simulated
            perturbations.

    Raises
    ------
    ValueError
        If t1 is not 3-D, number_of_simulations is negative, the image has
        too few slices along the third axis, or the image has constant
        intensity (so it cannot be rescaled).

    Example
    -------
    >>> image = ants.image_read("t1.nii.gz")
    >>> deep = brain_age(image)
    >>> print("Predicted age: ", deep['predicted_age'])
    """

    # Validate cheap preconditions before importing helpers or doing any
    # expensive preprocessing / model loading.
    if t1.dimension != 3:
        raise ValueError("Image dimension must be 3.")

    if number_of_simulations < 0:
        raise ValueError(
            f"number_of_simulations must be non-negative; got {number_of_simulations}.")

    from ..utilities import preprocess_brain_image
    from ..utilities import get_pretrained_network
    from ..utilities import randomly_transform_image_data

    ################################
    #
    # Preprocess images
    #
    ################################

    t1_preprocessed = t1
    if do_preprocessing:
        t1_preprocessing = preprocess_brain_image(t1,
            truncate_intensity=(0.01, 0.99),
            brain_extraction_modality="t1",
            template="croppedMni152",
            template_transform_type="antsRegistrationSyNQuickRepro[a]",
            do_bias_correction=True,
            do_denoising=True,
            antsxnet_cache_directory=antsxnet_cache_directory,
            verbose=verbose)
        # Zero out everything outside the computed brain mask.
        t1_preprocessed = t1_preprocessing["preprocessed_image"] * t1_preprocessing['brain_mask']

    # Min-max rescale intensities to [0, 1].  A constant image would make the
    # denominator zero and silently produce NaN predictions, so reject it.
    intensity_min = t1_preprocessed.min()
    intensity_range = t1_preprocessed.max() - intensity_min
    if intensity_range == 0:
        raise ValueError("Image has constant intensity; cannot rescale to [0, 1].")
    t1_preprocessed = (t1_preprocessed - intensity_min) / intensity_range

    # The paper only specifies that 80 slices are used for prediction.  This
    # range is a reasonable choice spanning the center of the brain.
    which_slices = list(range(45, 125))
    number_of_axial_slices = t1_preprocessed.shape[2]
    if number_of_axial_slices <= which_slices[-1]:
        raise ValueError(
            f"Image must have at least {which_slices[-1] + 1} slices along the "
            f"third axis; got {number_of_axial_slices}.")

    ################################
    #
    # Load model and weights
    #
    ################################

    model_weights_file_name = get_pretrained_network("brainAgeDeepBrainNet",
        antsxnet_cache_directory=antsxnet_cache_directory)
    model = keras.models.load_model(model_weights_file_name)

    ################################
    #
    # Predict per slice (optionally averaged over random affine perturbations)
    #
    ################################

    if number_of_simulations > 0:
        data_augmentation = randomly_transform_image_data(
            reference_image=t1_preprocessed,
            input_image_list=[[t1_preprocessed]],
            number_of_simulations=number_of_simulations,
            transform_type='affine',
            sd_affine=sd_affine,
            input_image_interpolator='linear')

    # Model input batch: one entry per slice; the same 2-D slice is copied
    # into all three channels.
    batch_x = np.zeros((len(which_slices), *t1_preprocessed.shape[0:2], 3))

    brain_age_per_slice = None
    for i in range(number_of_simulations + 1):
        # Pass 0 uses the unperturbed image; pass i > 0 uses simulation i - 1.
        if i == 0:
            batch_image = t1_preprocessed
        else:
            batch_image = data_augmentation['simulated_images'][i - 1][0]

        for j, slice_index in enumerate(which_slices):
            image_slice = ants.slice_image(batch_image, axis=2, idx=slice_index).numpy()
            # Broadcast the (rows, cols, 1) slice across the 3 channels.
            batch_x[j, :, :, :] = image_slice[:, :, np.newaxis]

        if verbose:
            print("Brain age (DeepBrainNet): predicting brain age per slice (batch = ", i, ")")

        prediction = model.predict(batch_x, verbose=verbose)
        if brain_age_per_slice is None:
            brain_age_per_slice = prediction
        else:
            # Running mean over the i + 1 images processed so far.
            brain_age_per_slice = brain_age_per_slice + (prediction - brain_age_per_slice) / (i + 1)

    # Median over slices of the first model-output column (presumably the
    # only column, i.e. output shape (n_slices, 1) — TODO confirm).
    predicted_age = np.median(brain_age_per_slice[:, 0])

    return {'predicted_age': predicted_age,
            'brain_age_per_slice': brain_age_per_slice}