Source code for antspynet.utilities.super_resolution_utilities

import ants
import numpy as np
import tensorflow.keras.backend as K

from . import extract_image_patches
from . import regression_match_image

from tensorflow.keras.models import load_model
import tensorflow as tf

import time
from os import path


def mse(x, y=None):
    """
    Mean square error of a single image or between two images.

    Arguments
    ---------
    x : input image
        ants input image

    y : input image
        ants input image.  When omitted, the mean of ``x ** 2`` is returned.

    Returns
    -------
    Value.

    Example
    -------
    >>> r16 = ants.image_read(ants.get_data("r16"))
    >>> r64 = ants.image_read(ants.get_data("r64"))
    >>> value = mse(r16, r64)
    """

    # With one image the "error" is the image itself; with two it is their
    # element-wise difference.
    residual = x if y is None else x - y
    squared = residual ** 2
    return squared.mean()
def mae(x, y=None):
    """
    Mean absolute error of a single image or between two images.

    Arguments
    ---------
    x : input image
        ants input image

    y : input image
        ants input image.  When omitted, the mean of ``x.abs()`` is returned.

    Returns
    -------
    Value

    Example
    -------
    >>> r16 = ants.image_read(ants.get_data("r16"))
    >>> r64 = ants.image_read(ants.get_data("r64"))
    >>> value = mae(r16, r64)
    """

    # With one image the "error" is the image itself; with two it is their
    # element-wise difference.
    residual = x if y is None else x - y

    # Reduce exactly once: the mean is already a scalar, so calling .mean()
    # on it again is redundant and fails when the scalar is a plain float.
    return residual.abs().mean()
def psnr(x, y):
    """
    Peak signal-to-noise ratio between two images.

    Computed as ``20 * log10(max(x)) - 10 * log10(mse(x, y))``, so the peak
    value is taken from ``x`` only.

    Arguments
    ---------
    x : input image
        ants input image

    y : input image
        ants input image

    Returns
    -------
    Value

    Example
    -------
    >>> r16 = ants.image_read(ants.get_data("r16"))
    >>> r64 = ants.image_read(ants.get_data("r64"))
    >>> value = psnr(r16, r64)
    """

    peak_term = 20 * np.log10(x.max())
    noise_term = 10 * np.log10(mse(x, y))
    return peak_term - noise_term
def ssim(x, y, K=(0.01, 0.03)):
    """
    Structural similarity index (SSI) between two images.

    Implementation of the SSI quantity for two images proposed in

    Z. Wang, A.C. Bovik, H.R. Sheikh, E.P. Simoncelli. "Image quality
    assessment: from error visibility to structural similarity".
    IEEE TIP. 13 (4): 600–612.

    The statistics (means, variances, covariance) are computed globally over
    the whole image rather than over local windows.

    Arguments
    ---------
    x : input image
        ants input image

    y : input image
        ants input image

    K : tuple of length 2
        tuple which contain SSI parameters meant to stabilize the formula
        in case of weak denominators.

    Returns
    -------
    Value

    Example
    -------
    >>> r16 = ants.image_read(ants.get_data("r16"))
    >>> r64 = ants.image_read(ants.get_data("r64"))
    >>> value = ssim(r16, r64)
    """

    # Dynamic range spanned by both images.  Use the signed minimum: taking
    # its absolute value collapses the range for negative intensities
    # (e.g. [-1, 1] would give 0 and disable the stabilizing constants).
    global_max = max(x.max(), y.max())
    global_min = min(x.min(), y.min())
    dynamic_range = global_max - global_min

    # Stabilizing constants scaled by the dynamic range.
    c1 = (K[0] * dynamic_range) ** 2
    c2 = (K[1] * dynamic_range) ** 2

    mu_x = x.mean()
    mu_y = y.mean()
    mu_x_sq = mu_x * mu_x
    mu_y_sq = mu_y * mu_y
    mu_xy = mu_x * mu_y

    # Variances / covariance via E[ab] - E[a]E[b].
    sigma_x_sq = (x * x).mean() - mu_x_sq
    sigma_y_sq = (y * y).mean() - mu_y_sq
    sigma_xy = (x * y).mean() - mu_xy

    numerator = (2 * mu_xy + c1) * (2 * sigma_xy + c2)
    denominator = (mu_x_sq + mu_y_sq + c1) * (sigma_x_sq + sigma_y_sq + c2)

    return numerator / denominator
def gmsd(x, y):
    """
    Gradient magnitude similarity deviation

    A fast and simple metric that correlates to perceptual quality.

    Arguments
    ---------
    x : input image
        ants input image

    y : input image
        ants input image

    Returns
    -------
    Value

    Example
    -------
    >>> r16 = ants.image_read(ants.get_data("r16"))
    >>> r64 = ants.image_read(ants.get_data("r64"))
    >>> value = gmsd(r16, r64)
    """

    grad_x = ants.iMath(x, "Grad")
    grad_y = ants.iMath(y, "Grad")

    # see eqn 4 - 6 in https://arxiv.org/pdf/1308.3052.pdf
    stabilizer = 0.0026

    similarity_top = grad_x * grad_y * 2.0 + stabilizer
    similarity_bottom = grad_x ** 2 + grad_y ** 2 + stabilizer
    similarity_map = similarity_top / similarity_bottom

    # Number of elements in x, accumulated as a plain Python int.
    voxel_count = 1
    for extent in x.shape:
        voxel_count *= extent

    scale = 1.0 / voxel_count
    centered_sq = (similarity_map - similarity_map.mean()) ** 2

    # Square root of the mean squared deviation of the similarity map.
    return np.sqrt(scale * centered_sq.sum())
def apply_super_resolution_model_to_image(
    image,
    model,
    target_range=(-127.5, 127.5),
    batch_size=32,
    regression_order=None,
    verbose=False,
):
    """
    Apply a pretrained deep back projection model for super resolution.

    Helper function for applying a pretrained deep back projection model.
    Apply a patch-wise trained network to perform super-resolution. Can be
    applied to variable sized inputs. Warning: This function may be better
    used on CPU unless the GPU can accommodate the full image size. Warning 2:
    The global intensity range (min to max) of the output will match the
    input where the range is taken over all channels.

    Arguments
    ---------
    image : ANTs image
        input image.

    model : keras object or string
        pretrained keras model or filename.  A filename ending in
        ``.tflite`` is run through ``tf.lite.Interpreter``; any other
        filename is loaded with ``load_model``.

    target_range : 2-element tuple
        a tuple or array defining the (min, max) of the input image
        (e.g., -127.5, 127.5).  Output images will be scaled back to original
        intensity. This range should match the mapping used in the training
        of the network.  If given as (max, min) it is reversed.

    batch_size : integer
        Batch size used for the prediction call (keras models only).

    regression_order : integer
        If specified, Apply the function regression_match_image with
        poly_order=regression_order.  Only applied to single-component
        images.

    verbose : boolean
        If True, show status messages.

    Returns
    -------
    Super-resolution image upscaled to resolution specified by the network.

    Raises
    ------
    ValueError
        If ``model`` is a string that is not an existing file, if the model
        input rank is not 4 or 5, if the image dimension does not match the
        model rank (4 -> 2D, 5 -> 3D), or if the model channel size differs
        from ``image.components``.
    NotImplementedError
        If the keras image data format is not "channels_last".

    Example
    -------
    >>> import ants
    >>> image = ants.image_read(ants.get_ants_data('r16'))
    >>> image_sr = apply_super_resolution_model_to_image(image, get_pretrained_network("dbpn4x"))
    """

    tflite_flag = False

    # -1 for "channels_last"; any other data format maps to 0 and is rejected
    # later, before the output image is assembled.
    channel_axis = -1 if K.image_data_format() == "channels_last" else 0

    if target_range[0] > target_range[1]:
        target_range = target_range[::-1]

    # ------------------------------------------------------------------
    # Resolve the model and the rank of its input tensor.
    # ------------------------------------------------------------------
    start_time = time.time()
    if isinstance(model, str):
        if not path.isfile(model):
            raise ValueError("Model not found.")
        if verbose:
            print("Load model.")
        if path.splitext(model)[1] == '.tflite':
            interpreter = tf.lite.Interpreter(model)
            interpreter.allocate_tensors()
            input_details = interpreter.get_input_details()
            output_details = interpreter.get_output_details()
            shape_length = len(input_details[0]['shape'])
            tflite_flag = True
        else:
            model = load_model(model)
            shape_length = len(model.input_shape)
        if verbose:
            elapsed_time = time.time() - start_time
            print("  (elapsed time: ", elapsed_time, ")")
    else:
        shape_length = len(model.input_shape)

    # ------------------------------------------------------------------
    # Validate model rank against the image.  Boolean `or` / `and` are
    # required here: the bitwise `|` / `&` operators bind tighter than the
    # comparisons and would silently change what is being tested.
    # ------------------------------------------------------------------
    if shape_length < 4 or shape_length > 5:
        raise ValueError("Unexpected input shape.")
    if shape_length == 5 and image.dimension != 3:
        raise ValueError("Expecting 3D input for this model.")
    if shape_length == 4 and image.dimension != 2:
        raise ValueError("Expecting 2D input for this model.")

    if tflite_flag:
        channel_size = input_details[0]['shape'][channel_axis]
    else:
        channel_size = model.input_shape[channel_axis]

    if channel_size != image.components:
        raise ValueError(
            "Channel size of model ({}) does not match ncomponents={} "
            "of the input image.".format(channel_size, image.components)
        )

    # ------------------------------------------------------------------
    # Build the network input: the whole image as a single patch, with
    # intensities mapped linearly onto target_range.
    # ------------------------------------------------------------------
    image_patches = extract_image_patches(
        image,
        patch_size=image.shape,
        max_number_of_patches=1,
        stride_length=image.shape,
        return_as_array=True,
    )
    if image.components == 1:
        # Add the trailing channel axis for scalar images.
        image_patches = np.expand_dims(image_patches, axis=-1)

    # NOTE(review): a constant-intensity image makes the max below zero and
    # the division yields NaN -- confirm whether callers need a guard.
    image_patches = image_patches - image_patches.min()
    image_patches = (
        image_patches / image_patches.max() * (target_range[1] - target_range[0])
        + target_range[0]
    )

    # ------------------------------------------------------------------
    # Prediction.
    # ------------------------------------------------------------------
    if verbose:
        print("Prediction")

    start_time = time.time()

    if tflite_flag:
        image_patches = image_patches.astype('float32')
        interpreter.set_tensor(input_details[0]['index'], image_patches)
        interpreter.invoke()
        out = interpreter.tensor(output_details[0]['index'])
        prediction = out()
    else:
        prediction = model.predict(image_patches, batch_size=batch_size)

    if verbose:
        elapsed_time = time.time() - start_time
        print("  (elapsed time: ", elapsed_time, ")")

    # ------------------------------------------------------------------
    # Map the prediction back onto the intensity range of the input image.
    # ------------------------------------------------------------------
    if verbose:
        print("Reconstruct intensities")

    intensity_range = image.range()
    prediction = prediction - prediction.min()
    prediction = (
        prediction / prediction.max() * (intensity_range[1] - intensity_range[0])
        + intensity_range[0]
    )

    expansion_factor = np.asarray(prediction.shape) / np.asarray(image_patches.shape)

    if channel_axis == 0:
        # The reconstruction below assumes a (batch, spatial..., channel)
        # layout.  Fail with a clear error instead of an undefined name.
        raise NotImplementedError(
            "apply_super_resolution_model_to_image only supports the "
            "'channels_last' image data format."
        )

    # Drop the leading batch axis and the trailing channel axis so that only
    # the per-spatial-axis upsampling ratios remain.
    expansion_factor = expansion_factor[1 : (len(expansion_factor) - 1)]

    if verbose:
        print("ExpansionFactor:", str(expansion_factor))

    upsampled_shape = (
        np.asarray(image.shape) * np.asarray(expansion_factor)
    ).astype(int)

    if image.components == 1:
        # `[..., k]` selects channel k on the last axis for both the 4-D and
        # 5-D prediction tensors.
        prediction_image = ants.make_image(upsampled_shape, prediction[..., 0])
        if regression_order is not None:
            reference_image = ants.resample_image_to_target(image, prediction_image)
            prediction_image = regression_match_image(
                prediction_image, reference_image, poly_order=regression_order
            )
    else:
        image_component_list = [
            ants.make_image(upsampled_shape, prediction[..., k])
            for k in range(image.components)
        ]
        prediction_image = ants.merge_channels(image_component_list)

    prediction_image = ants.copy_image_info(image, prediction_image)
    # Spacing shrinks by the same factor by which the voxel grid grew.
    ants.set_spacing(
        prediction_image,
        tuple(np.asarray(image.spacing) / np.asarray(expansion_factor)),
    )

    return prediction_image