107 lines
4.5 KiB
Python
107 lines
4.5 KiB
Python
import tensorflow.keras as keras
|
|
|
|
import fast_depth_functional as fd
|
|
|
|
|
|
def dense_upsample_block(input, out_channels, skip_connection):
    """
    Upsample block as described by dense depth in https://arxiv.org/pdf/1812.11941.pdf

    The block applies ``UpSampling2D`` (Keras default factor, bilinear interpolation)
    to ``input``. It concatenates ``skip_connection`` onto the result along the last
    axis, then runs two 3x3 stride-1 'same' convolutions, each followed by
    LeakyReLU(alpha=0.2).

    :param input: feature tensor to upsample.
    :param out_channels: number of filters used by both convolutions.
    :param skip_connection: tensor concatenated with the upsampled features. It must
        have the same spatial size as the upsampled tensor for Concatenate to succeed.
    :return: the output tensor of the second LeakyReLU.
    """
    upsampled = keras.layers.UpSampling2D(interpolation='bilinear')(input)
    features = keras.layers.Concatenate()([upsampled, skip_connection])

    # Two identical conv -> LeakyReLU stages (layers are created in the same order
    # as a hand-unrolled version, so auto-generated layer names are unchanged).
    for _ in range(2):
        conv = keras.layers.Conv2D(filters=out_channels, kernel_size=3,
                                   strides=1, padding='same')
        features = keras.layers.LeakyReLU(alpha=0.2)(conv(features))

    return features
|
|
|
|
|
|
def dense_depth(size, weights=None, shape=(224, 224, 3)):
    """
    Make the dense depth network graph using the keras functional API.

    The encoder is a headless DenseNet built by ``dense_net``. Its output goes through
    a 1x1 convolution and then five ``dense_upsample_block`` stages. Their skip
    connections are, in order, the DenseNet layers 'pool3_pool', 'pool2_pool',
    'pool1', 'conv1/relu', and finally the input image itself. A final 3x3
    convolution named 'conv3' produces a single-channel output map.

    Note that you should use the dense depth loss function, and use Adam as the
    optimiser with a learning rate of 0.0001 (default learning rate of Adam is 0.001).

    :param size: DenseNet depth forwarded to ``dense_net`` (121, 169 or 201).
    :param weights: encoder weights forwarded to keras.applications
        (e.g. None for random initialisation, or 'imagenet').
    :param shape: input image shape, used for ``keras.layers.Input`` and the encoder.
    :return: keras.Model named 'dense_depth' mapping the input image to the
        single-channel output of 'conv3'.
    """
    image = keras.layers.Input(shape=shape)
    densenet = dense_net(image, size, weights, shape)

    # Channel count of the encoder output. int() normalises it so every width
    # computed below is a plain int (older TF versions return Dimension objects,
    # which reject true division).
    encoder_channels = int(densenet.layers[-1].output.shape[-1])

    # Reduce the feature set (pointwise)
    decoder = keras.layers.Conv2D(
        filters=encoder_channels, kernel_size=1, padding='same')(densenet.output)

    # The actual decoder: stage k (1-based) uses encoder_channels // 2**k filters and
    # concatenates the output of the named encoder layer.
    skip_layer_names = ('pool3_pool', 'pool2_pool', 'pool1', 'conv1/relu')
    for stage, layer_name in enumerate(skip_layer_names, start=1):
        decoder = dense_upsample_block(
            decoder,
            encoder_channels // (2 ** stage),
            densenet.get_layer(layer_name).output)

    # Last stage concatenates the raw input image, with encoder_channels // 32 filters.
    decoder = dense_upsample_block(decoder, encoder_channels // 32, image)

    conv3 = keras.layers.Conv2D(
        filters=1, kernel_size=3, strides=1, padding='same', name='conv3')(decoder)

    return keras.Model(inputs=image, outputs=conv3, name='dense_depth')
|
|
|
|
|
|
def dense_net(input, size, weights=None, shape=(224, 224, 3)):
    """
    Instantiate a headless DenseNet encoder from keras.applications.

    :param input: tensor passed as ``input_tensor``, so the returned model is built
        on top of it.
    :param size: DenseNet depth; one of 121, 169 or 201.
    :param weights: forwarded to keras.applications (e.g. None or 'imagenet').
    :param shape: forwarded as ``input_shape``.
    :return: the DenseNet model (``include_top=False``) with every layer trainable.
    :raises ValueError: if ``size`` is not 121, 169 or 201. Previously any other
        value silently built DenseNet201, hiding typos.
    """
    if size == 121:
        constructor = keras.applications.DenseNet121
    elif size == 169:
        constructor = keras.applications.DenseNet169
    elif size == 201:
        constructor = keras.applications.DenseNet201
    else:
        raise ValueError(
            'Unsupported DenseNet size {!r}; expected 121, 169 or 201'.format(size))

    densenet = constructor(input_tensor=input, input_shape=shape, weights=weights,
                           include_top=False)

    # Explicitly mark every encoder layer trainable so no part of the DenseNet
    # is frozen during training.
    for layer in densenet.layers:
        layer.trainable = True

    return densenet
|
|
|
|
|
|
def dense_nnconv5(size, weights=None, shape=(224, 224, 3), half_features=True):
    """
    Build a DenseNet encoder with a FastDepth-style ``nnconv5`` decoder.

    The encoder output goes through a 1x1 convolution named 'conv2'. Four
    ``fd.nnconv5`` stages follow, with skip connections from the DenseNet layers
    'pool3_pool', 'pool2_pool', 'pool1' and 'conv1/relu'. Each stage is given that
    skip layer's channel count and a 1-based stage id. A final 1x1 convolution,
    BatchNormalization and ReLU capped at 6 produce a single-channel output.

    :param size: DenseNet depth forwarded to ``dense_net`` (121, 169 or 201).
    :param weights: encoder weights forwarded to keras.applications
        (e.g. None or 'imagenet').
    :param shape: input image shape.
    :param half_features: currently unused; kept for API compatibility.
    :return: keras.Model named 'fast_dense_depth'.
    """
    image = keras.layers.Input(shape=shape)
    densenet = dense_net(image, size, weights, shape)
    encoder_channels = int(densenet.layers[-1].output.shape[-1])

    # Reduce the feature set (pointwise). No input_shape is passed: it only matters
    # for a model's first layer, and Keras 3 warns when it is given elsewhere.
    decoder = keras.layers.Conv2D(filters=encoder_channels, kernel_size=1,
                                  padding='same', name='conv2')(densenet.output)

    # TODO: More intermediate layers here?

    # Fast Depth Decoder. Channel counts are read from ``.output.shape`` rather than
    # ``Layer.output_shape``, which Keras 3 removed. Assumes fd.nnconv5's positional
    # args are (input, filters, stage id) — confirm in fast_depth_functional.
    skip_layer_names = ('pool3_pool', 'pool2_pool', 'pool1', 'conv1/relu')
    for stage_id, layer_name in enumerate(skip_layer_names, start=1):
        skip = densenet.get_layer(layer_name).output
        decoder = fd.nnconv5(decoder, int(skip.shape[-1]), stage_id,
                             skip_connection=skip)

    # Final Pointwise for depth extraction; ReLU(6.) clips the output to [0, 6].
    decoder = keras.layers.Conv2D(1, 1, padding='same')(decoder)
    decoder = keras.layers.BatchNormalization()(decoder)
    decoder = keras.layers.ReLU(6.)(decoder)

    return keras.Model(inputs=image, outputs=decoder, name="fast_dense_depth")
|
|
|
|
|
|
if __name__ == '__main__':
    # Smoke check: build DenseDepth with a DenseNet-169 encoder initialised from
    # ImageNet weights and print the resulting layer table.
    dense_depth(169, 'imagenet').summary()

    # To inspect the nnconv5-decoder variant instead:
    # dense_nnconv5(169, 'imagenet').summary()
|