Implement details of dense depth paper
This commit is contained in:
@@ -5,42 +5,37 @@ import tensorflow_datasets as tfds
|
||||
import fast_depth_functional as fd
|
||||
|
||||
|
||||
def dense_upproject(input, out_channels, skip_connection):
|
||||
def dense_upsample_block(input, out_channels, skip_connection):
|
||||
"""
|
||||
Upsample block as described by dense depth in https://arxiv.org/pdf/1812.11941.pdf
|
||||
"""
|
||||
x = keras.layers.UpSampling2D(interpolation='bilinear')(input)
|
||||
x = keras.layers.Concatenate()([x, skip_connection])
|
||||
x = keras.layers.Conv2D(filters=out_channels,
|
||||
kernel_size=3, strides=1, padding='same')(x)
|
||||
x = keras.layers.LeakyReLU(alpha=0.2)(x)
|
||||
x = keras.layers.Conv2D(filters=out_channels,
|
||||
kernel_size=3, strides=1, padding='same')(x)
|
||||
return keras.layers.LeakyReLU(alpha=0.2)(x)
|
||||
|
||||
|
||||
def dense_depth(size, weights=None, shape=(224, 224, 3), half_features=True):
|
||||
def dense_depth(size, weights=None, shape=(224, 224, 3)):
|
||||
input = keras.layers.Input(shape=shape)
|
||||
densenet = dense_net(input, size, weights, shape)
|
||||
densenet_output_shape = densenet.layers[-1].output.shape
|
||||
|
||||
if half_features:
|
||||
decode_filters = densenet_output_shape[-1] // 2
|
||||
else:
|
||||
decode_filters = int(densenet_output_shape[-1])
|
||||
densenet_output_channels = densenet.layers[-1].output.shape[-1]
|
||||
|
||||
# Reduce the feature set (pointwise)
|
||||
decoder = keras.layers.Conv2D(filters=decode_filters, kernel_size=1, padding='same',
|
||||
input_shape=densenet_output_shape, name='conv2')(densenet.output)
|
||||
decoder = keras.layers.Conv2D(filters=densenet_output_channels, kernel_size=1, padding='same')(densenet.output)
|
||||
|
||||
# The actual decoder
|
||||
decoder = dense_upproject(
|
||||
decoder, decode_filters // 2, densenet.get_layer('pool3_pool').output)
|
||||
decoder = dense_upproject(
|
||||
decoder, decode_filters // 4, densenet.get_layer('pool2_pool').output)
|
||||
decoder = dense_upproject(
|
||||
decoder, decode_filters // 8, densenet.get_layer('pool1').output)
|
||||
decoder = dense_upproject(
|
||||
decoder, decode_filters // 16, densenet.get_layer('conv1/relu').output)
|
||||
# Enable to upproject to full image size
|
||||
# decoder = dense_upproject(decoder, int(decode_filters / 32), input)
|
||||
decoder = dense_upsample_block(
|
||||
decoder, densenet_output_channels // 2, densenet.get_layer('pool3_pool').output)
|
||||
decoder = dense_upsample_block(
|
||||
decoder, densenet_output_channels // 4, densenet.get_layer('pool2_pool').output)
|
||||
decoder = dense_upsample_block(
|
||||
decoder, densenet_output_channels // 8, densenet.get_layer('pool1').output)
|
||||
decoder = dense_upsample_block(
|
||||
decoder, densenet_output_channels // 16, densenet.get_layer('conv1/relu').output)
|
||||
|
||||
conv3 = keras.layers.Conv2D(
|
||||
filters=1, kernel_size=3, strides=1, padding='same', name='conv3')(decoder)
|
||||
@@ -89,8 +84,6 @@ def dense_nnconv5(size, weights=None, shape=(224, 224, 3), half_features=True):
|
||||
skip_connection=densenet.get_layer('pool1').output)
|
||||
x = fd.nnconv5(x, densenet.get_layer('conv1/relu').output_shape[3], 4,
|
||||
skip_connection=densenet.get_layer('conv1/relu').output)
|
||||
# Enable to get full dense decode (back to original size)
|
||||
# x = fd.nnconv5(x, int(densenet.get_layer('conv1/relu').output_shape[3] / 2), 5)
|
||||
|
||||
# Final Pointwise for depth extraction
|
||||
x = keras.layers.Conv2D(1, 1, padding='same')(x)
|
||||
|
||||
Reference in New Issue
Block a user