import tensorflow as tf
import tensorflow.keras.backend as K
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (Add, Activation, BatchNormalization, Concatenate, Dropout, ReLU, LeakyReLU,
Conv3D, Cropping3D, Conv3DTranspose, Input, Lambda, MaxPooling3D,
ReLU, SpatialDropout3D, UpSampling3D, ZeroPadding3D,
Cropping2D, Conv2D, Conv2DTranspose, MaxPooling2D, UpSampling2D, ZeroPadding2D)
from ..utilities import InstanceNormalization
from ..utilities import ResampleTensorLayer2D, ResampleTensorLayer3D
[docs]def create_nobrainer_unet_model_3d(input_image_size):
"""
Implementation of the "NoBrainer" U-net architecture
Creates a keras model of the U-net deep learning architecture for image
segmentation available at:
https://github.com/neuronets/nobrainer/
Arguments
---------
input_image_size : tuple of length 4
Used for specifying the input tensor shape. The shape (or dimension) of
that tensor is the image dimensions followed by the number of channels
(e.g., red, green, and blue).
Returns
-------
Keras model
A 3-D keras model defining the U-net network.
Example
-------
>>> model = create_nobrainer_unet_model_3d((None, None, None, 1))
>>> model.summary()
"""
number_of_outputs = 1
number_of_filters_at_base_layer = 16
convolution_kernel_size = (3, 3, 3)
deconvolution_kernel_size = (2, 2, 2)
inputs = Input(shape=input_image_size)
# Encoding path
outputs = Conv3D(filters=number_of_filters_at_base_layer,
kernel_size=convolution_kernel_size,
padding='same')(inputs)
outputs = ReLU()(outputs)
outputs = Conv3D(filters=number_of_filters_at_base_layer*2,
kernel_size=convolution_kernel_size,
padding='same')(outputs)
outputs = ReLU()(outputs)
skip1 = outputs
outputs = MaxPooling3D(pool_size=(2, 2, 2))(outputs)
outputs = Conv3D(filters=number_of_filters_at_base_layer*2,
kernel_size=convolution_kernel_size,
padding='same')(outputs)
outputs = ReLU()(outputs)
outputs = Conv3D(filters=number_of_filters_at_base_layer*4,
kernel_size=convolution_kernel_size,
padding='same')(outputs)
outputs = ReLU()(outputs)
skip2 = outputs
outputs = MaxPooling3D(pool_size=(2, 2, 2))(outputs)
outputs = Conv3D(filters=number_of_filters_at_base_layer*4,
kernel_size=convolution_kernel_size,
padding='same')(outputs)
outputs = ReLU()(outputs)
outputs = Conv3D(filters=number_of_filters_at_base_layer*8,
kernel_size=convolution_kernel_size,
padding='same')(outputs)
outputs = ReLU()(outputs)
skip3 = outputs
outputs = MaxPooling3D(pool_size=(2, 2, 2))(outputs)
outputs = Conv3D(filters=number_of_filters_at_base_layer*8,
kernel_size=convolution_kernel_size,
padding='same')(outputs)
outputs = ReLU()(outputs)
outputs = Conv3D(filters=number_of_filters_at_base_layer*16,
kernel_size=convolution_kernel_size,
padding='same')(outputs)
outputs = ReLU()(outputs)
# Decoding path
outputs = Conv3DTranspose(filters=number_of_filters_at_base_layer*16,
kernel_size=deconvolution_kernel_size,
strides=2,
padding='same')(outputs)
outputs = Concatenate()([skip3, outputs])
outputs = Conv3D(filters=number_of_filters_at_base_layer*8,
kernel_size=convolution_kernel_size,
padding='same')(outputs)
outputs = ReLU()(outputs)
outputs = Conv3D(filters=number_of_filters_at_base_layer*8,
kernel_size=convolution_kernel_size,
padding='same')(outputs)
outputs = ReLU()(outputs)
outputs = Conv3DTranspose(filters=number_of_filters_at_base_layer*8,
kernel_size=deconvolution_kernel_size,
strides=2,
padding='same')(outputs)
outputs = Concatenate()([skip2, outputs])
outputs = Conv3D(filters=number_of_filters_at_base_layer*4,
kernel_size=convolution_kernel_size,
padding='same')(outputs)
outputs = ReLU()(outputs)
outputs = Conv3D(filters=number_of_filters_at_base_layer*4,
kernel_size=convolution_kernel_size,
padding='same')(outputs)
outputs = ReLU()(outputs)
outputs = Conv3DTranspose(filters=number_of_filters_at_base_layer*4,
kernel_size=deconvolution_kernel_size,
strides=2,
padding='same')(outputs)
outputs = Concatenate()([skip1, outputs])
outputs = Conv3D(filters=number_of_filters_at_base_layer*2,
kernel_size=convolution_kernel_size,
padding='same')(outputs)
outputs = ReLU()(outputs)
outputs = Conv3D(filters=number_of_filters_at_base_layer*2,
kernel_size=convolution_kernel_size,
padding='same')(outputs)
outputs = ReLU()(outputs)
convActivation = ''
if number_of_outputs == 1:
convActivation = 'sigmoid'
else:
convActivation = 'softmax'
outputs = Conv3D(filters=number_of_outputs,
kernel_size=1,
activation = convActivation)(outputs)
unet_model = Model(inputs=inputs, outputs=outputs)
return unet_model
[docs]def create_hippmapp3r_unet_model_3d(input_image_size,
do_first_network=True,
data_format="channels_last"):
"""
Implementation of the "HippMapp3r" U-net architecture
Creates a keras model implementation of the u-net architecture
described here:
https://onlinelibrary.wiley.com/doi/pdf/10.1002/hbm.24811
with the implementation available here:
https://github.com/mgoubran/HippMapp3r
Arguments
---------
input_image_size : tuple of length 4
Used for specifying the input tensor shape. The shape (or dimension) of
that tensor is the image dimensions followed by the number of channels
(e.g., red, green, and blue).
do_first_network : boolean
Boolean dictating if the model built should be the first (initial) network
or second (refinement) network.
data_format : string
One of "channels_first" or "channels_last". We do this for this specific
architecture as the original weights were saved in "channels_first" format.
Returns
-------
Keras model
A 3-D keras model defining the U-net network.
Example
-------
>>> shape_initial_stage = (160, 160, 128)
>>> model_initial_stage = antspynet.create_hippmapp3r_unet_model_3d((*shape_initial_stage, 1), True)
>>> model_initial_stage.load_weights(antspynet.get_pretrained_network("hippMapp3rInitial"))
>>> shape_refine_stage = (112, 112, 64)
>>> model_refine_stage = antspynet.create_hippmapp3r_unet_model_3d((*shape_refine_stage, 1), False)
>>> model_refine_stage.load_weights(antspynet.get_pretrained_network("hippMapp3rRefine"))
"""
channels_axis = None
if data_format == "channels_last":
channels_axis = 4
elif data_format == "channels_first":
channels_axis = 1
else:
raise ValueError("Unexpected string for data_format.")
def convB_3d_layer(input, number_of_filters, kernel_size=3, strides=1):
block = Conv3D(filters=number_of_filters,
kernel_size=kernel_size,
strides=strides,
padding='same',
data_format=data_format)(input)
block = InstanceNormalization(axis=channels_axis)(block)
block = LeakyReLU()(block)
return(block)
def residual_block_3d(input, number_of_filters):
block = convB_3d_layer(input, number_of_filters)
block = SpatialDropout3D(rate=0.3,
data_format=data_format)(block)
block = convB_3d_layer(block, number_of_filters)
return(block)
def upsample_block_3d(input, number_of_filters):
block = UpSampling3D(data_format=data_format)(input)
block = convB_3d_layer(block, number_of_filters)
return(block)
def feature_block_3d(input, number_of_filters):
block = convB_3d_layer(input, number_of_filters)
block = convB_3d_layer(block, number_of_filters, kernel_size=1)
return(block)
number_of_filters_at_base_layer = 16
number_of_layers = 6
if do_first_network == False:
number_of_layers = 5
inputs = Input(shape=input_image_size)
# Encoding path
add = None
encoding_convolution_layers = []
for i in range(number_of_layers):
number_of_filters = number_of_filters_at_base_layer * 2 ** i
conv = None
if i == 0:
conv = convB_3d_layer(inputs, number_of_filters)
else:
conv = convB_3d_layer(add, number_of_filters, strides=2)
residual_block = residual_block_3d(conv, number_of_filters)
add = Add()([conv, residual_block])
encoding_convolution_layers.append(add)
# Decoding path
outputs = encoding_convolution_layers[number_of_layers-1]
# 256
number_of_filters = (number_of_filters_at_base_layer *
2 ** (number_of_layers - 2))
outputs = upsample_block_3d(outputs, number_of_filters)
if do_first_network == True:
# 256, 128
outputs = Concatenate(axis=channels_axis)([encoding_convolution_layers[4], outputs])
outputs = feature_block_3d(outputs, number_of_filters)
number_of_filters = int(number_of_filters / 2)
outputs = upsample_block_3d(outputs, number_of_filters)
# 128, 64
outputs = Concatenate(axis=channels_axis)([encoding_convolution_layers[3], outputs])
outputs = feature_block_3d(outputs, number_of_filters)
number_of_filters = int(number_of_filters / 2)
outputs = upsample_block_3d(outputs, number_of_filters)
# 64, 32
outputs = Concatenate(axis=channels_axis)([encoding_convolution_layers[2], outputs])
feature64 = feature_block_3d(outputs, number_of_filters)
number_of_filters = int(number_of_filters / 2)
outputs = upsample_block_3d(feature64, number_of_filters)
back64 = None
if do_first_network == True:
back64 = convB_3d_layer(feature64, 1, 1)
else:
back64 = Conv3D(filters=1,
kernel_size=1,
data_format=data_format)(feature64)
back64 = UpSampling3D(data_format=data_format)(back64)
# 32, 16
outputs = Concatenate(axis=channels_axis)([encoding_convolution_layers[1], outputs])
feature32 = feature_block_3d(outputs, number_of_filters)
number_of_filters = int(number_of_filters / 2)
outputs = upsample_block_3d(feature32, number_of_filters)
back32 = None
if do_first_network == True:
back32 = convB_3d_layer(feature32, 1, 1)
else:
back32 = Conv3D(filters=1,
kernel_size=1,
data_format=data_format)(feature32)
back32 = Add()([back64, back32])
back32 = UpSampling3D(data_format=data_format)(back32)
# Final
outputs = Concatenate(axis=channels_axis)([encoding_convolution_layers[0], outputs])
outputs = convB_3d_layer(outputs, number_of_filters, 3)
outputs = convB_3d_layer(outputs, number_of_filters, 1)
if do_first_network == True:
outputs = convB_3d_layer(outputs, 1, 1)
else:
outputs = Conv3D(filters=1,
kernel_size=1,
data_format=data_format)(outputs)
outputs = Add()([back32, outputs])
outputs = Activation('sigmoid')(outputs)
unet_model = Model(inputs=inputs, outputs=outputs)
return(unet_model)
def create_shiva_unet_model_3d(number_of_modalities=1,
number_of_outputs=1):
"""
Implementation of the "shiva" U-net architecture used for PVS and WMH
segmentation.
Publications:
* PVS: https://pubmed.ncbi.nlm.nih.gov/34262443/
* WMH: https://pubmed.ncbi.nlm.nih.gov/38050769/
with respective GitHub repositories:
* PVS: https://github.com/pboutinaud/SHIVA_PVS
* WMH: https://github.com/pboutinaud/SHIVA_WMH
Arguments
---------
number_of_modalities : integer
Specifies number of channels in the architecture.
number_of_outputs : integer
Specifies the number of outputs per voxel. Determines
final activation function (1 = sigmoid, >1 = softmax).
Returns
-------
Keras model
A 3-D keras model defining the U-net network.
Example
-------
>>> model = antspynet.create_shiva_unet_model_3d()
"""
def get_pad_shape(target_layer, reference_layer):
delta = K.int_shape(target_layer)[1] - K.int_shape(reference_layer)[1]
if delta % 2 != 0:
pad_shape_0 = (int(delta/2), int(delta/2) + 1)
else:
pad_shape_0 = (int(delta/2), int(delta/2))
delta = K.int_shape(target_layer)[2] - K.int_shape(reference_layer)[2]
if delta % 2 != 0:
pad_shape_1 = (int(delta/2), int(delta/2) + 1)
else:
pad_shape_1 = (int(delta/2), int(delta/2))
delta = K.int_shape(target_layer)[3] - K.int_shape(reference_layer)[3]
if delta % 2 != 0:
pad_shape_2 = (int(delta/2), int(delta/2) + 1)
else:
pad_shape_2 = (int(delta/2), int(delta/2))
if pad_shape_0 == (0, 0) and pad_shape_1 == (0, 0) and pad_shape_2 == (0, 0):
return None
else:
return((pad_shape_0, pad_shape_1, pad_shape_2))
input_image_size=(160, 214, 176, number_of_modalities)
number_of_filters = (10, 18, 32, 58, 104, 187, 337)
inputs = Input(shape=input_image_size)
# encoding layers
encoding_layers = list()
outputs = inputs
for i in range(len(number_of_filters)):
outputs = Conv3D(filters=number_of_filters[i],
kernel_size=3,
padding='same',
use_bias=False)(outputs)
outputs = BatchNormalization()(outputs)
outputs = Activation('swish')(outputs)
outputs = Conv3D(filters=number_of_filters[i],
kernel_size=3,
padding='same',
use_bias=False)(outputs)
outputs = BatchNormalization()(outputs)
outputs = Activation('swish')(outputs)
encoding_layers.append(outputs)
outputs = MaxPooling3D(pool_size=2)(outputs)
dropout_rate = 0.05
if i > 0:
dropout_rate = 0.5
outputs = Dropout(rate=dropout_rate)(outputs)
# decoding layers
for i in range(len(encoding_layers)-1, -1, -1):
upsample_layer = UpSampling3D(size=2)(outputs)
pad_shape = get_pad_shape(encoding_layers[i], upsample_layer)
if i > 0 and pad_shape is not None:
zero_layer = ZeroPadding3D(padding=pad_shape)(upsample_layer)
outputs = Concatenate(axis=-1)([zero_layer, encoding_layers[i]])
else:
outputs = Concatenate(axis=-1)([upsample_layer, encoding_layers[i]])
outputs = Conv3D(filters=K.int_shape(outputs)[-1],
kernel_size=3,
padding='same',
use_bias=False)(outputs)
outputs = BatchNormalization()(outputs)
outputs = Activation('swish')(outputs)
outputs = Conv3D(filters=number_of_filters[i],
kernel_size=3,
padding='same',
use_bias=False)(outputs)
outputs = BatchNormalization()(outputs)
outputs = Activation('swish')(outputs)
outputs = Dropout(rate=0.5)(outputs)
# final
outputs = Conv3D(filters=10,
kernel_size=3,
padding='same',
use_bias=False)(outputs)
outputs = BatchNormalization()(outputs)
outputs = Activation('swish')(outputs)
outputs = Conv3D(filters=10,
kernel_size=3,
padding='same',
use_bias=False)(outputs)
outputs = BatchNormalization()(outputs)
outputs = Activation('swish')(outputs)
activation = "sigmoid"
if number_of_outputs > 1:
activation = "softmax"
outputs = Conv3D(filters=number_of_outputs,
kernel_size=1,
activation=activation,
padding='same')(outputs)
unet_model = Model(inputs=inputs, outputs=outputs)
return(unet_model)
def create_hypermapp3r_unet_model_3d(input_image_size,
data_format="channels_last"):
"""
Implementation of the "HyperMapp3r" U-net architecture
https://pubmed.ncbi.nlm.nih.gov/35088930/
Arguments
---------
input_image_size : tuple of length 4
Used for specifying the input tensor shape. The shape (or dimension) of
that tensor is the image dimensions followed by the number of channels
(e.g., red, green, and blue).
data_format : string
One of "channels_first" or "channels_last". We do this for this specific
architecture as the original weights were saved in "channels_first" format.
Returns
-------
Keras model
A 3-D keras model defining the U-net network.
Example
-------
>>> model = antspynet.create_hypermapp3r_unet_model_3d((2, 224, 224, 224), True)
>>> model.load_weights(antspynet.get_pretrained_network("hypermapper_224iso_multi"))
"""
channels_axis = None
if data_format == "channels_last":
channels_axis = 4
elif data_format == "channels_first":
channels_axis = 1
else:
raise ValueError("Unexpected string for data_format.")
def convB_3d_layer(input, number_of_filters, kernel_size=3, strides=1):
block = Conv3D(filters=number_of_filters,
kernel_size=kernel_size,
strides=strides,
padding='same',
data_format=data_format)(input)
block = InstanceNormalization(axis=channels_axis)(block)
block = LeakyReLU()(block)
return(block)
def residual_block_3d(input, number_of_filters):
block = convB_3d_layer(input, number_of_filters)
block = SpatialDropout3D(rate=0.3,
data_format=data_format)(block)
block = convB_3d_layer(block, number_of_filters)
return(block)
def upsample_block_3d(input, number_of_filters):
block = UpSampling3D(data_format=data_format)(input)
block = convB_3d_layer(block, number_of_filters)
return(block)
def feature_block_3d(input, number_of_filters):
block = convB_3d_layer(input, number_of_filters)
block = convB_3d_layer(block, number_of_filters, kernel_size=1)
return(block)
number_of_filters_at_base_layer = 8
number_of_layers = 4
inputs = Input(shape=input_image_size)
# Encoding path
add = None
encoding_convolution_layers = []
for i in range(number_of_layers):
number_of_filters = number_of_filters_at_base_layer * 2 ** i
conv = None
if i == 0:
conv = convB_3d_layer(inputs, number_of_filters)
else:
conv = convB_3d_layer(add, number_of_filters, strides=2)
residual_block = residual_block_3d(conv, number_of_filters)
add = Add()([conv, residual_block])
encoding_convolution_layers.append(add)
# Decoding path
outputs = encoding_convolution_layers[number_of_layers-1]
# 64
number_of_filters = (number_of_filters_at_base_layer *
2 ** (number_of_layers - 2))
outputs = upsample_block_3d(outputs, number_of_filters)
# 64, 32
outputs = Concatenate(axis=channels_axis)([encoding_convolution_layers[2], outputs])
feature64 = feature_block_3d(outputs, number_of_filters)
number_of_filters = int(number_of_filters / 2)
outputs = upsample_block_3d(feature64, number_of_filters)
back64 = Conv3D(filters=1,
kernel_size=1,
data_format=data_format)(feature64)
back64 = UpSampling3D(data_format=data_format)(back64)
# 32, 16
outputs = Concatenate(axis=channels_axis)([encoding_convolution_layers[1], outputs])
feature32 = feature_block_3d(outputs, number_of_filters)
number_of_filters = int(number_of_filters / 2)
outputs = upsample_block_3d(feature32, number_of_filters)
back32 = Conv3D(filters=1,
kernel_size=1,
data_format=data_format)(feature32)
back32 = Add()([back64, back32])
back32 = UpSampling3D(data_format=data_format)(back32)
# Final
outputs = Concatenate(axis=channels_axis)([encoding_convolution_layers[0], outputs])
outputs = convB_3d_layer(outputs, number_of_filters, 3)
outputs = convB_3d_layer(outputs, number_of_filters, 1)
outputs = Conv3D(filters=1,
kernel_size=1,
data_format=data_format)(outputs)
outputs = Add()([back32, outputs])
outputs = Activation('sigmoid')(outputs)
unet_model = Model(inputs=inputs, outputs=outputs)
return(unet_model)
def create_sysu_media_unet_model_3d(input_image_size,
number_of_filters=None,
anatomy="wmh"):
"""
3-D variation of the sysu_media U-net architecture
Creates a keras model implementation of the u-net architecture
in the 2017 MICCAI WMH challenge by the sysu_medial team described
here:
https://pubmed.ncbi.nlm.nih.gov/30125711/
with the original implementation available here:
https://github.com/hongweilibran/wmh_ibbmTum
Arguments
---------
input_image_size : tuple
Tuple of length 4, (width, height, depth, channels).
anatomy : string
"wmh" or "claustrum"
Returns
-------
Keras model
A 3-D keras model defining the U-net network.
Example
-------
>>> image_size = (200, 200, 200)
>>> model = antspynet.create_sysu_media_unet_model_3d((*image_size, 1))
"""
def get_crop_shape(target_layer, reference_layer):
delta = K.int_shape(target_layer)[1] - K.int_shape(reference_layer)[1]
if delta % 2 != 0:
cropShape0 = (int(delta/2), int(delta/2) + 1)
else:
cropShape0 = (int(delta/2), int(delta/2))
delta = K.int_shape(target_layer)[2] - K.int_shape(reference_layer)[2]
if delta % 2 != 0:
cropShape1 = (int(delta/2), int(delta/2) + 1)
else:
cropShape1 = (int(delta/2), int(delta/2))
delta = K.int_shape(target_layer)[3] - K.int_shape(reference_layer)[3]
if delta % 2 != 0:
cropShape2 = (int(delta/2), int(delta/2) + 1)
else:
cropShape2 = (int(delta/2), int(delta/2))
return((cropShape0, cropShape1, cropShape2))
inputs = Input(shape=input_image_size)
if number_of_filters is None:
if anatomy == "wmh":
number_of_filters = (64, 96, 128, 256, 512)
elif anatomy == "claustrum":
number_of_filters = (32, 64, 96, 128, 256)
# encoding layers
encoding_layers = list()
outputs = inputs
for i in range(len(number_of_filters)):
kernel1 = 3
kernel2 = 3
if i == 0 and anatomy == "wmh":
kernel1 = 5
kernel2 = 5
elif i == 3:
kernel1 = 3
kernel2 = 4
outputs = Conv3D(filters=number_of_filters[i],
kernel_size=kernel1,
padding='same')(outputs)
outputs = Activation('relu')(outputs)
outputs = Conv3D(filters=number_of_filters[i],
kernel_size=kernel2,
padding='same')(outputs)
outputs = Activation('relu')(outputs)
encoding_layers.append(outputs)
if i < 4:
outputs = MaxPooling3D(pool_size=(2, 2, 2))(outputs)
# decoding layers
for i in range(len(encoding_layers)-2, -1, -1):
upsample_layer = UpSampling3D(size=(2, 2, 2))(outputs)
crop_shape = get_crop_shape(encoding_layers[i], upsample_layer)
cropped_layer = Cropping3D(cropping=crop_shape)(encoding_layers[i])
outputs = Concatenate(axis=-1)([upsample_layer, cropped_layer])
outputs = Conv3D(filters=number_of_filters[i],
kernel_size=3,
padding='same')(outputs)
outputs = Activation('relu')(outputs)
outputs = Conv3D(filters=number_of_filters[i],
kernel_size=3,
padding='same')(outputs)
outputs = Activation('relu')(outputs)
# final
crop_shape = get_crop_shape(inputs, outputs)
outputs = ZeroPadding3D(padding=crop_shape)(outputs)
outputs = Conv3D(filters=1,
kernel_size=1,
activation='sigmoid',
padding='same')(outputs)
unet_model = Model(inputs=inputs, outputs=outputs)
return(unet_model)
[docs]def create_hypothalamus_unet_model_3d(input_image_size):
"""
Implementation of the U-net architecture for hypothalamus segmentation
described in
https://pubmed.ncbi.nlm.nih.gov/32853816/
and ported from the original implementation:
https://github.com/BBillot/hypothalamus_seg
The network has is characterized by the following parameters:
* 3 resolution levels: 24 ---> 48 ---> 96 filters
* convolution: kernel size: (3, 3, 3), activation: 'elu',
* pool size: (2, 2, 2)
Returns
-------
Keras model
A 3-D keras model defining the U-net network.
Example
-------
>>> model = create_hypothalamus_unet_model_3d((160, 160, 160, 1))
"""
convolution_kernel_size = (3, 3, 3)
pool_size = (2, 2, 2)
number_of_outputs = 11
number_of_layers = 3
number_of_filters_at_base_layer = 24
number_of_filters = list()
for i in range(number_of_layers):
number_of_filters.append(number_of_filters_at_base_layer * 2**i)
inputs = Input(shape=(*input_image_size, 1))
# Encoding path
encoding_convolution_layers = []
pool = None
for i in range(number_of_layers):
if i == 0:
conv = Conv3D(filters=number_of_filters[i],
kernel_size=convolution_kernel_size,
padding='same',
activation='elu')(inputs)
else:
conv = Conv3D(filters=number_of_filters[i],
kernel_size=convolution_kernel_size,
padding='same',
activation='elu')(pool)
conv = Conv3D(filters=number_of_filters[i],
kernel_size=convolution_kernel_size,
padding='same',
activation='elu')(conv)
encoding_convolution_layers.append(conv)
conv = BatchNormalization(axis=-1)(conv)
if i < number_of_layers - 1:
pool = MaxPooling3D(pool_size=pool_size)(conv)
else:
outputs = conv
# Decoding path
for i in range(1, number_of_layers):
deconv = UpSampling3D(size=pool_size)(outputs)
outputs = Concatenate(axis=4)([encoding_convolution_layers[number_of_layers-i-1], deconv])
outputs = Conv3D(filters=number_of_filters[number_of_layers-i-1],
kernel_size=convolution_kernel_size,
padding='same',
activation='elu')(outputs)
outputs = Conv3D(filters=number_of_filters[number_of_layers-i-1],
kernel_size=convolution_kernel_size,
padding='same',
activation='elu')(outputs)
outputs = BatchNormalization(axis=-1)(outputs)
outputs = Conv3D(filters=number_of_outputs,
kernel_size=(1, 1, 1),
activation='softmax')(outputs)
unet_model = Model(inputs=inputs, outputs=outputs)
return(unet_model)
def create_partial_convolution_unet_model_2d(input_image_size,
number_of_priors=0,
number_of_filters=(64, 128, 256, 512, 512, 512, 512, 512),
kernel_size=(7, 5, 5, 3, 3, 3, 3, 3),
use_partial_conv=True):
"""
2-D implementation of the U-net architecture for inpainting using partial
convolution.
https://arxiv.org/abs/1804.07723
Arguments
---------
input_image_size : tuple of length 3
Tuple of ints of length 3 specifying 2-D image size and channel size.
number_of_priors : int
Specify tissue priors for use during the decoding branch.
number_of_filters: tuple
Specifies the filter schedule. Defaults to the number of filters used in
the paper.
kernel_size: single scalar or tuple of same length as the number of filters.
Specifies the kernel size schedule for the encoding path. Defaults to the
kernel sizes used in the paper.
use_partial_conv: boolean
Testing. Switch between vanilla convolution layers and partial convolution layers.
Returns
-------
Keras model
A 2-D keras model defining the U-net network.
Example
-------
>>> model = create_partial_convolution_unet_model_2d((256, 256, 1)))
"""
from ..utilities import PartialConv2D
if isinstance(kernel_size, int):
kernel_size = [kernel_size] * len(number_of_filters)
elif len(kernel_size) == 1:
kernel_size = [kernel_size[0]] * len(number_of_filters)
elif len(kernel_size) != len(number_of_filters):
raise ValueError("kernel_size must be a scalar or of equal length as the number_of_filters.")
input_image = Input(input_image_size)
input_mask = Input(input_image_size)
if number_of_priors > 0:
input_priors = Input((input_image_size[0], input_image_size[1], number_of_priors))
inputs = [input_image, input_mask, input_priors]
else:
inputs = [input_image, input_mask]
# Encoding path
number_of_layers = len(number_of_filters)
encoding_convolution_layers = []
pool = None
mask = None
for i in range(number_of_layers):
if i == 0:
if use_partial_conv:
conv, mask = PartialConv2D(filters=number_of_filters[i],
kernel_size=kernel_size[i],
padding="same")([inputs[0], inputs[1]])
else:
conv = Conv2D(filters=number_of_filters[i],
kernel_size=kernel_size[i],
padding='same')(inputs[0])
else:
if use_partial_conv:
mask = ResampleTensorLayer2D(shape=(pool.shape[1], pool.shape[2]),
interpolation_type='nearest_neighbor')(mask)
conv, mask = PartialConv2D(filters=number_of_filters[i],
kernel_size=kernel_size[i],
padding="same")([pool, mask])
else:
conv = Conv2D(filters=number_of_filters[i],
kernel_size=kernel_size[i],
padding='same')(pool)
conv = ReLU()(conv)
if use_partial_conv:
conv, mask = PartialConv2D(filters=number_of_filters[i],
kernel_size=kernel_size[i],
padding="same")([conv, mask])
else:
conv = Conv2D(filters=number_of_filters[i],
kernel_size=kernel_size[i],
padding='same')(conv)
conv = ReLU()(conv)
encoding_convolution_layers.append(conv)
if i < number_of_layers - 1:
pool = MaxPooling2D(pool_size=(2,2),
strides=(2,2))(encoding_convolution_layers[i])
# Decoding path
outputs = encoding_convolution_layers[number_of_layers - 1]
for i in range(1, number_of_layers):
deconv = Conv2DTranspose(filters=number_of_filters[number_of_layers-i-1],
kernel_size=2,
padding='same')(outputs)
deconv = UpSampling2D(size=(2,2))(deconv)
if use_partial_conv:
mask = UpSampling2D(size=(2,2),
interpolation="nearest")(mask)
outputs = Concatenate(axis=3)([deconv, encoding_convolution_layers[number_of_layers-i-1]])
if use_partial_conv:
mask = Lambda(lambda x: tf.repeat(tf.gather(x[0], [0], axis=-1), tf.shape(x[1])[-1], axis=-1))([mask, outputs])
if number_of_priors > 0:
resampled_priors = ResampleTensorLayer2D(shape=(outputs.shape[1], outputs.shape[2]),
interpolation_type='linear')(input_priors)
outputs = Concatenate(axis=3)([outputs, resampled_priors])
if use_partial_conv:
resampled_priors_mask = Lambda(lambda x: tf.ones_like(x))(resampled_priors)
mask = Concatenate(axis=3)([mask, resampled_priors_mask])
if use_partial_conv:
outputs, mask = PartialConv2D(filters=number_of_filters[number_of_layers-i-1],
kernel_size=3,
padding="same")([outputs, mask])
else:
outputs = Conv2D(filters=number_of_filters[number_of_layers-i-1],
kernel_size=3,
padding='same')(outputs)
outputs = ReLU()(outputs)
if use_partial_conv:
outputs, mask = PartialConv2D(filters=number_of_filters[number_of_layers-i-1],
kernel_size=3,
padding="same")([outputs, mask])
else:
outputs = Conv2D(filters=number_of_filters[number_of_layers-i-1],
kernel_size=3,
padding='same')(outputs)
outputs = ReLU()(outputs)
outputs = Conv2D(filters=1,
kernel_size=(1, 1),
activation = 'linear')(outputs)
unet_model = Model(inputs=inputs, outputs=outputs)
return unet_model
def create_partial_convolution_unet_model_3d(input_image_size,
number_of_priors=0,
number_of_filters=(64, 128, 256, 512, 512, 512, 512, 512),
kernel_size=(7, 5, 5, 3, 3, 3, 3, 3),
use_partial_conv=True):
"""
3-D implementation of the U-net architecture for inpainting using partial
convolution.
https://arxiv.org/abs/1804.07723
Arguments
---------
input_image_size : tuple of length 3
Tuple of ints of length 3 specifying 2-D image size and channel size.
number_of_priors : int
Specify tissue priors for use during the decoding branch.
number_of_filters: tuple
Specifies the filter schedule. Defaults to the number of filters used in
the paper.
kernel_size: single scalar or tuple of same length as the number of filters.
Specifies the kernel size schedule for the encoding path. Defaults to the
kernel sizes used in the paper.
use_partial_conv: boolean
Testing. Switch between vanilla convolution layers and partial convolution layers.
Returns
-------
Keras model
A 3-D keras model defining the U-net network.
Example
-------
>>> model = create_partial_convolution_unet_model_3d((256, 256, 256, 1)))
"""
from ..utilities import PartialConv3D
if isinstance(kernel_size, int):
kernel_size = [kernel_size] * len(number_of_filters)
elif len(kernel_size) == 1:
kernel_size = [kernel_size[0]] * len(number_of_filters)
elif len(kernel_size) != len(number_of_filters):
raise ValueError("kernel_size must be a scalar or of equal length as the number_of_filters.")
input_image = Input(input_image_size)
input_mask = Input(input_image_size)
if number_of_priors > 0:
input_priors = Input((input_image_size[0], input_image_size[1], input_image_size[2], number_of_priors))
inputs = [input_image, input_mask, input_priors]
else:
inputs = [input_image, input_mask]
# Encoding path
number_of_layers = len(number_of_filters)
encoding_convolution_layers = []
pool = None
mask = None
for i in range(number_of_layers):
if i == 0:
if use_partial_conv:
conv, mask = PartialConv3D(filters=number_of_filters[i],
kernel_size=kernel_size[i],
padding="same")([inputs[0], inputs[1]])
else:
conv = Conv3D(filters=number_of_filters[i],
kernel_size=kernel_size[i],
padding='same')(inputs[0])
else:
if use_partial_conv:
mask = ResampleTensorLayer3D(shape=(pool.shape[1], pool.shape[2], pool.shape[3]),
interpolation_type='nearest_neighbor')(mask)
conv, mask = PartialConv3D(filters=number_of_filters[i],
kernel_size=kernel_size[i],
padding="same")([pool, mask])
else:
conv = Conv3D(filters=number_of_filters[i],
kernel_size=kernel_size[i],
padding='same')(pool)
conv = ReLU()(conv)
if use_partial_conv:
conv, mask = PartialConv3D(filters=number_of_filters[i],
kernel_size=kernel_size[i],
padding="same")([conv, mask])
else:
conv = Conv3D(filters=number_of_filters[i],
kernel_size=kernel_size[i],
padding='same')(conv)
conv = ReLU()(conv)
encoding_convolution_layers.append(conv)
if i < number_of_layers - 1:
pool = MaxPooling3D(pool_size=(2,2,2),
strides=(2,2,2))(encoding_convolution_layers[i])
# Decoding path
outputs = encoding_convolution_layers[number_of_layers - 1]
for i in range(1, number_of_layers):
deconv = Conv3DTranspose(filters=number_of_filters[number_of_layers-i-1],
kernel_size=2,
padding='same')(outputs)
deconv = UpSampling3D(size=(2,2,2))(deconv)
if use_partial_conv:
mask = UpSampling3D(size=(2,2,2))(mask)
outputs = Concatenate(axis=4)([deconv, encoding_convolution_layers[number_of_layers-i-1]])
if use_partial_conv:
mask = Lambda(lambda x: tf.repeat(tf.gather(x[0], [0], axis=-1), tf.shape(x[1])[-1], axis=-1))([mask, outputs])
if number_of_priors > 0:
resampled_priors = ResampleTensorLayer3D(shape=(outputs.shape[1], outputs.shape[2], outputs.shape[3]),
interpolation_type='linear')(input_priors)
outputs = Concatenate(axis=4)([outputs, resampled_priors])
if use_partial_conv:
resampled_priors_mask = Lambda(lambda x: tf.ones_like(x))(resampled_priors)
mask = Concatenate(axis=4)([mask, resampled_priors_mask])
if use_partial_conv:
outputs, mask = PartialConv3D(filters=number_of_filters[number_of_layers-i-1],
kernel_size=3,
padding="same")([outputs, mask])
else:
outputs = Conv3D(filters=number_of_filters[number_of_layers-i-1],
kernel_size=3,
padding='same')(outputs)
outputs = ReLU()(outputs)
if use_partial_conv:
outputs, mask = PartialConv3D(filters=number_of_filters[number_of_layers-i-1],
kernel_size=3,
padding="same")([outputs, mask])
else:
outputs = Conv3D(filters=number_of_filters[number_of_layers-i-1],
kernel_size=3,
padding='same')(outputs)
outputs = ReLU()(outputs)
outputs = Conv3D(filters=1,
kernel_size=(1, 1, 1),
activation = 'linear')(outputs)
unet_model = Model(inputs=inputs, outputs=outputs)
return unet_model