# Source code for antspynet.utilities.lung_segmentation


import numpy as np
import ants

def el_bicho(ventilation_image, mask, use_coarse_slices_only=True,
             antsxnet_cache_directory=None, verbose=False):
    """
    Perform functional lung segmentation using hyperpolarized gases.

    https://pubmed.ncbi.nlm.nih.gov/30195415/

    Arguments
    ---------
    ventilation_image : ANTsImage
        input ventilation image.

    mask : ANTsImage
        input mask.

    use_coarse_slices_only : boolean
        If True, apply network only in the dimension of greatest slice thickness.
        If False, apply to all dimensions and average the results.

    antsxnet_cache_directory : string
        Destination directory for storing the downloaded template and model weights.
        Since these can be reused, if it is None, these data will be downloaded to
        ~/.keras/ANTsXNet/.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    Ventilation segmentation and corresponding probability images

    Example
    -------
    >>> image = ants.image_read("ventilation.nii.gz")
    >>> mask = ants.image_read("mask.nii.gz")
    >>> lung_seg = el_bicho(image, mask, use_coarse_slices_only=True, verbose=False)
    """

    from ..architectures import create_unet_model_2d
    from ..utilities import get_pretrained_network
    from ..utilities import pad_or_crop_image_to_size

    if ventilation_image.dimension != 3:
        raise ValueError("Image dimension must be 3.")

    if ventilation_image.shape != mask.shape:
        raise ValueError("Ventilation image and mask size are not the same size.")

    ################################
    #
    # Preprocess image
    #
    ################################

    template_size = (256, 256)
    classes = (0, 1, 2, 3, 4)
    number_of_classification_labels = len(classes)

    image_modalities = ("Ventilation", "Mask")
    channel_size = len(image_modalities)

    # Standardize intensities; work in identity-direction space so that
    # ants.slice_image slicing is consistent across inputs.
    preprocessed_image = (ventilation_image - ventilation_image.mean()) / ventilation_image.std()
    ants.set_direction(preprocessed_image, np.identity(3))

    mask_identity = ants.image_clone(mask)
    ants.set_direction(mask_identity, np.identity(3))

    ################################
    #
    # Build models and load weights
    #
    ################################

    unet_model = create_unet_model_2d((*template_size, channel_size),
        number_of_outputs=number_of_classification_labels,
        number_of_layers=4, number_of_filters_at_base_layer=32, dropout_rate=0.0,
        convolution_kernel_size=(3, 3), deconvolution_kernel_size=(2, 2),
        # BUG FIX: ("attentionGating") is a plain string, not a tuple; it only
        # worked via the substring-membership accident.  Use a real 1-tuple.
        weight_decay=1e-5, additional_options=("attentionGating",))

    if verbose == True:
        print("El Bicho: retrieving model weights.")

    weights_file_name = get_pretrained_network("elBicho",
        antsxnet_cache_directory=antsxnet_cache_directory)
    unet_model.load_weights(weights_file_name)

    ################################
    #
    # Extract slices
    #
    ################################

    # Predict along the dimension of greatest slice thickness only, or along
    # all three dimensions (averaging the results) when requested.
    spacing = ants.get_spacing(preprocessed_image)
    dimensions_to_predict = (spacing.index(max(spacing)),)
    if use_coarse_slices_only == False:
        dimensions_to_predict = list(range(3))

    total_number_of_slices = 0
    for d in range(len(dimensions_to_predict)):
        total_number_of_slices += preprocessed_image.shape[dimensions_to_predict[d]]

    batchX = np.zeros((total_number_of_slices, *template_size, channel_size))

    slice_count = 0
    for d in range(len(dimensions_to_predict)):
        number_of_slices = preprocessed_image.shape[dimensions_to_predict[d]]

        if verbose == True:
            print("Extracting slices for dimension ", dimensions_to_predict[d], ".")

        for i in range(number_of_slices):
            ventilation_slice = pad_or_crop_image_to_size(
                ants.slice_image(preprocessed_image, dimensions_to_predict[d], i),
                template_size)
            batchX[slice_count, :, :, 0] = ventilation_slice.numpy()

            mask_slice = pad_or_crop_image_to_size(
                ants.slice_image(mask_identity, dimensions_to_predict[d], i),
                template_size)
            batchX[slice_count, :, :, 1] = mask_slice.numpy()

            slice_count += 1

    ################################
    #
    # Do prediction and then restack into the image
    #
    ################################

    if verbose == True:
        print("Prediction.")

    prediction = unet_model.predict(batchX, verbose=verbose)

    # Axis permutations mapping the sliced array (slice axis first) back to
    # the original image orientation for each sliced dimension.
    permutations = list()
    permutations.append((0, 1, 2))
    permutations.append((1, 0, 2))
    permutations.append((1, 2, 0))

    probability_images = list()
    for l in range(number_of_classification_labels):
        probability_images.append(ants.image_clone(mask) * 0)

    current_start_slice = 0
    for d in range(len(dimensions_to_predict)):
        current_end_slice = current_start_slice + preprocessed_image.shape[dimensions_to_predict[d]]
        which_batch_slices = range(current_start_slice, current_end_slice)
        for l in range(number_of_classification_labels):
            prediction_per_dimension = prediction[which_batch_slices, :, :, l]
            prediction_array = np.transpose(np.squeeze(prediction_per_dimension),
                                            permutations[dimensions_to_predict[d]])
            prediction_image = ants.copy_image_info(ventilation_image,
                pad_or_crop_image_to_size(ants.from_numpy(prediction_array),
                                          ventilation_image.shape))
            # Incremental (running) mean over the dimensions predicted so far.
            probability_images[l] = probability_images[l] + (prediction_image - probability_images[l]) / (d + 1)
        # BUG FIX: was "current_end_slice + 1".  range(start, end) already
        # consumes exactly shape[dim] batch rows, so the next dimension's
        # rows begin at current_end_slice; adding 1 skipped a row and
        # misaligned every subsequent dimension's predictions.
        current_start_slice = current_end_slice

    ################################
    #
    # Convert probability images to segmentation
    #
    ################################

    # Foreground = argmax over {background prob, summed class probs}; the
    # winning foreground label is then the argmax over the class channels.
    image_matrix = ants.image_list_to_matrix(probability_images[1:(len(probability_images))], mask * 0 + 1)
    background_foreground_matrix = np.stack([ants.image_list_to_matrix([probability_images[0]], mask * 0 + 1),
                                             np.expand_dims(np.sum(image_matrix, axis=0), axis=0)])
    foreground_matrix = np.argmax(background_foreground_matrix, axis=0)
    segmentation_matrix = (np.argmax(image_matrix, axis=0) + 1) * foreground_matrix
    segmentation_image = ants.matrix_to_images(
        np.expand_dims(segmentation_matrix, axis=0), mask * 0 + 1)[0]

    return_dict = {'segmentation_image': segmentation_image,
                   'probability_images': probability_images}
    return (return_dict)
def _lung_patch_based_probability(intensity_image, lung_mask, weights_name,
                                  unet_mode, prediction_batch_size,
                                  patch_stride_length,
                                  antsxnet_cache_directory, verbose):
    """
    Shared patch-wise 3-D U-net prediction for the pulmonary artery and
    airway segmentation routines.

    Builds the (160, 160, 160)-patch U-net, loads the named pretrained
    weights, extracts overlapping patches from intensity_image inside
    lung_mask, predicts them in batches of prediction_batch_size, and
    reconstructs the foreground (channel 1) probability image on the
    lung_mask domain.
    """

    from ..architectures import create_unet_model_3d
    from ..utilities import extract_image_patches
    from ..utilities import reconstruct_image_from_patches
    from ..utilities import get_pretrained_network

    patch_size = (160, 160, 160)

    ################################
    #
    # Build model and load weights
    #
    ################################

    if verbose:
        print("Load model and weights.")

    if isinstance(patch_stride_length, int):
        patch_stride_length = (patch_stride_length,) * 3

    number_of_classification_labels = 2
    channel_size = 1

    model = create_unet_model_3d((*patch_size, channel_size),
        number_of_outputs=number_of_classification_labels, mode=unet_mode,
        number_of_filters=(32, 64, 128, 256, 512),
        convolution_kernel_size=(3, 3, 3), deconvolution_kernel_size=(2, 2, 2),
        dropout_rate=0.0, weight_decay=0)

    weights_file_name = get_pretrained_network(weights_name,
        antsxnet_cache_directory=antsxnet_cache_directory)
    model.load_weights(weights_file_name)

    ################################
    #
    # Extract patches
    #
    ################################

    if verbose:
        print("Extract patches.")

    ct_patches = extract_image_patches(intensity_image,
                                       patch_size=patch_size,
                                       max_number_of_patches="all",
                                       stride_length=patch_stride_length,
                                       mask_image=lung_mask,
                                       random_seed=None,
                                       return_as_array=True)
    total_number_of_patches = ct_patches.shape[0]

    ################################
    #
    # Do prediction and then restack into the image
    #
    ################################

    # Predict in fixed-size batches to bound (GPU) memory usage; the last
    # batch holds whatever remainder does not fill a full batch.
    number_of_batches = total_number_of_patches // prediction_batch_size
    residual_number_of_patches = (total_number_of_patches -
                                  number_of_batches * prediction_batch_size)
    if residual_number_of_patches > 0:
        number_of_batches = number_of_batches + 1

    if verbose:
        print("  Total number of patches: ", str(total_number_of_patches))
        print("  Prediction batch size: ", str(prediction_batch_size))
        print("  Number of batches: ", str(number_of_batches))

    prediction = np.zeros((total_number_of_patches, *patch_size, 2))
    for b in range(number_of_batches):
        if b < number_of_batches - 1 or residual_number_of_patches == 0:
            batch_length = prediction_batch_size
        else:
            batch_length = residual_number_of_patches
        indices = range(b * prediction_batch_size,
                        b * prediction_batch_size + batch_length)
        batchX = np.zeros((batch_length, *patch_size, channel_size))
        batchX[:, :, :, :, 0] = ct_patches[indices, :, :, :]
        if verbose:
            print("Predicting batch ", str(b + 1), " of ", str(number_of_batches))
        prediction[indices, :, :, :, :] = model.predict(batchX, verbose=verbose)

    if verbose:
        print("Predict patches and reconstruct.")

    probability_image = reconstruct_image_from_patches(
        np.squeeze(prediction[:, :, :, :, 1]),
        stride_length=patch_stride_length,
        domain_image=lung_mask,
        domain_image_is_mask=True)

    return probability_image


def lung_pulmonary_artery_segmentation(ct, lung_mask=None,
                                       prediction_batch_size=16,
                                       patch_stride_length=32,
                                       antsxnet_cache_directory=None,
                                       verbose=False):
    """
    Perform pulmonary artery segmentation.

    Training data taken from the PARSE2022 challenge (Luo, Gongning, et al.
    "Efficient automatic segmentation for multi-level pulmonary arteries:
    The PARSE challenge." https://arxiv.org/abs/2304.03708).

    Arguments
    ---------
    ct : ANTsImage
        input ct image

    lung_mask : ANTsImage
        input binary lung mask which defines the patch extraction.  If not
        supplied, one is estimated.

    prediction_batch_size : int
        Control memory usage for prediction.  More consequential for GPU-usage.

    patch_stride_length : 3-D tuple or int
        Dictates the stride length for accumulating predicting patches.

    antsxnet_cache_directory : string
        Destination directory for storing the downloaded template and model weights.
        Since these can be reused, if it is None, these data will be downloaded to
        ~/.keras/ANTsXNet/.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    Segmentation probability image

    Example
    -------
    >>> ct = ants.image_read("ct.nii.gz")
    """

    from ..utilities import lung_extraction

    patch_size = (160, 160, 160)

    if np.any(ct.shape < np.array(patch_size)):
        raise ValueError("Images must be > 160 voxels per dimension.")

    ################################
    #
    # Preprocess images
    #
    ################################

    if lung_mask is None:
        lung_ex = lung_extraction(ct, modality="ct", verbose=verbose)
        lung_mask = ants.threshold_image(lung_ex['segmentation_image'], 1, 3, 1, 0)

    # Rescale HU range [-800, 500] to [0, 1], clamping outliers.
    ct_preprocessed = ants.image_clone(ct)
    ct_preprocessed = (ct_preprocessed + 800) / (500 + 800)
    ct_preprocessed[ct_preprocessed > 1.0] = 1.0
    ct_preprocessed[ct_preprocessed < 0.0] = 0.0

    return _lung_patch_based_probability(ct_preprocessed, lung_mask,
        weights_name="pulmonaryArteryWeights", unet_mode="sigmoid",
        prediction_batch_size=prediction_batch_size,
        patch_stride_length=patch_stride_length,
        antsxnet_cache_directory=antsxnet_cache_directory,
        verbose=verbose)


def lung_airway_segmentation(ct, lung_mask=None,
                             prediction_batch_size=16,
                             patch_stride_length=32,
                             antsxnet_cache_directory=None,
                             verbose=False):
    """
    Perform pulmonary airway segmentation from CT images.

    Training data taken from the EXACT09 challenge.

    Arguments
    ---------
    ct : ANTsImage
        input ct image

    lung_mask : ANTsImage
        input binary lung mask which defines the patch extraction
        (label 1 = left lung, label 2 = right lung, label 3 = main airway).
        If not supplied, one is estimated.

    prediction_batch_size : int
        Control memory usage for prediction.  More consequential for GPU-usage.

    patch_stride_length : 3-D tuple or int
        Dictates the stride length for accumulating predicting patches.

    antsxnet_cache_directory : string
        Destination directory for storing the downloaded template and model weights.
        Since these can be reused, if it is None, these data will be downloaded to
        ~/.keras/ANTsXNet/.

    verbose : boolean
        Print progress to the screen.

    Returns
    -------
    Segmentation probability image

    Example
    -------
    >>> ct = ants.image_read("ct.nii.gz")
    """

    from ..utilities import lung_extraction

    patch_size = (160, 160, 160)

    if np.any(ct.shape < np.array(patch_size)):
        raise ValueError("Images must be > 160 voxels per dimension.")

    ################################
    #
    # Preprocess images
    #
    ################################

    if lung_mask is None:
        lung_ex = lung_extraction(ct, modality="ct", verbose=verbose)
        # Dilate the estimated labels before binarizing so the mask covers
        # the airway margins.
        lung_mask = ants.iMath_MD(lung_ex['segmentation_image'], 2, 3)
        lung_mask = ants.threshold_image(lung_mask, 1, 3, 1, 0)

    # Rescale HU range [-800, 500] to [0, 1], clamping outliers.
    ct_preprocessed = ants.image_clone(ct)
    ct_preprocessed = (ct_preprocessed + 800) / (500 + 800)
    ct_preprocessed[ct_preprocessed > 1.0] = 1.0
    ct_preprocessed[ct_preprocessed < 0.0] = 0.0

    # Unlike the artery model, the airway model is applied to the
    # lung-masked intensities.
    ct_masked = ct_preprocessed * lung_mask

    return _lung_patch_based_probability(ct_masked, lung_mask,
        weights_name="pulmonaryAirwayWeights", unet_mode="classification",
        prediction_batch_size=prediction_batch_size,
        patch_stride_length=patch_stride_length,
        antsxnet_cache_directory=antsxnet_cache_directory,
        verbose=verbose)