Files
fast-depth-tf/dense_depth_functional.py
2021-04-14 12:38:51 +09:30

147 lines
5.7 KiB
Python

import tensorflow as tf
import tensorflow.keras as keras
import tensorflow_datasets as tfds
import fast_depth_functional as fd
def dense_upsample_block(input, out_channels, skip_connection):
    """
    Upsample block as described by DenseDepth in https://arxiv.org/pdf/1812.11941.pdf

    Bilinearly upsamples ``input`` by 2x and concatenates it with ``skip_connection``.
    It then applies two 3x3 'same'-padded convolutions, each followed by a LeakyReLU
    with alpha 0.2.

    :param input: decoder feature tensor to upsample. Assumed 4-D channels-last;
        TODO confirm.
    :param out_channels: number of filters used by both convolutions
    :param skip_connection: encoder feature tensor concatenated after upsampling. Its
        spatial size must match the upsampled ``input`` for Concatenate to succeed.
    :return: output tensor of the block
    """
    # NOTE: ``input`` shadows the builtin; the name is kept so existing callers still work.
    x = keras.layers.UpSampling2D(interpolation='bilinear')(input)
    x = keras.layers.Concatenate()([x, skip_connection])
    x = keras.layers.Conv2D(filters=out_channels,
                            kernel_size=3, strides=1, padding='same')(x)
    # DenseDepth applies LeakyReLU after *each* convolution. Without this activation
    # the two back-to-back convolutions compose into a single linear map.
    x = keras.layers.LeakyReLU(alpha=0.2)(x)
    x = keras.layers.Conv2D(filters=out_channels,
                            kernel_size=3, strides=1, padding='same')(x)
    return keras.layers.LeakyReLU(alpha=0.2)(x)
def dense_depth(size, weights=None, shape=(224, 224, 3)):
    """
    Build a DenseDepth-style encoder/decoder model.

    Architecture:
    - Encoder: the DenseNet built by ``dense_net``.
    - Projection: a 1x1 convolution over the encoder output that keeps its channel count.
    - Decoder: four ``dense_upsample_block`` stages. Each uses 1/2, 1/4, 1/8 and 1/16
      (floor division) of the encoder channel count, and each concatenates the DenseNet
      activations 'pool3_pool', 'pool2_pool', 'pool1' and 'conv1/relu' in that order.
    - Output: a 3x3 convolution named 'conv3' with a single output channel.

    :param size: DenseNet variant passed to ``dense_net``
    :param weights: encoder weights passed to ``dense_net`` (e.g. 'imagenet' or None)
    :param shape: model input shape, default (224, 224, 3)
    :return: keras.Model named 'dense_depth'
    """
    image = keras.layers.Input(shape=shape)
    encoder = dense_net(image, size, weights, shape)
    channels = encoder.layers[-1].output.shape[-1]

    # Pointwise projection of the encoder features; channel count is unchanged.
    features = keras.layers.Conv2D(filters=channels, kernel_size=1,
                                   padding='same')(encoder.output)

    # Decoder stages: (skip-connection layer name, divisor applied to the filter count).
    stages = (('pool3_pool', 2), ('pool2_pool', 4), ('pool1', 8), ('conv1/relu', 16))
    for skip_name, divisor in stages:
        features = dense_upsample_block(
            features, channels // divisor, encoder.get_layer(skip_name).output)

    depth = keras.layers.Conv2D(filters=1, kernel_size=3, strides=1,
                                padding='same', name='conv3')(features)
    return keras.Model(inputs=image, outputs=depth, name='dense_depth')
def dense_net(input, size, weights=None, shape=(224, 224, 3)):
    """
    Build a headless Keras DenseNet encoder on top of ``input``.

    :param input: tensor the DenseNet is built on (forwarded as ``input_tensor``)
    :param size: DenseNet variant; must be 121, 169 or 201
    :param weights: ``weights`` argument for keras.applications (e.g. 'imagenet' or None)
    :param shape: forwarded as ``input_shape``
    :return: the DenseNet keras.Model, built with ``include_top=False``, with every
        layer marked trainable
    :raises ValueError: if ``size`` is not 121, 169 or 201
    """
    # Validate before touching keras so a bad size fails fast with a clear message.
    # Previously, any unrecognized size silently built a DenseNet201.
    if size not in (121, 169, 201):
        raise ValueError(
            'Unsupported DenseNet size {!r}; expected 121, 169 or 201'.format(size))
    constructor = {
        121: keras.applications.DenseNet121,
        169: keras.applications.DenseNet169,
        201: keras.applications.DenseNet201,
    }[size]
    densenet = constructor(input_tensor=input, input_shape=shape, weights=weights,
                           include_top=False)
    # Explicitly unfreeze the whole encoder so it is fine-tuned with the decoder.
    for layer in densenet.layers:
        layer.trainable = True
    return densenet
def dense_nnconv5(size, weights=None, shape=(224, 224, 3), half_features=True):
    """
    Build a DenseNet encoder followed by a FastDepth-style NNConv5 decoder.

    Architecture:
    - Encoder: the DenseNet built by ``dense_net``.
    - Reduction: a 1x1 convolution named 'conv2'. It uses half the encoder's output
      channels when ``half_features`` is true, otherwise all of them.
    - Decoder: four ``fd.nnconv5`` stages. Each is given the channel count of a DenseNet
      skip layer ('pool3_pool', 'pool2_pool', 'pool1', 'conv1/relu', in that order), a
      1-based stage index, and that layer's output as ``skip_connection``.
    - Head: a 1x1 convolution to one channel, then BatchNormalization and a ReLU capped
      at 6.

    :param size: DenseNet variant passed to ``dense_net``
    :param weights: encoder weights passed to ``dense_net``
    :param shape: model input shape, default (224, 224, 3)
    :param half_features: whether the reduction conv halves the encoder channel count
    :return: keras.Model named "fast_dense_depth"
    """
    image = keras.layers.Input(shape=shape)
    encoder = dense_net(image, size, weights, shape)
    encoder_shape = encoder.layers[-1].output.shape
    encoder_channels = int(encoder_shape[-1])
    decode_filters = encoder_channels // 2 if half_features else encoder_channels

    # Reduce the feature set (pointwise).
    # NOTE(review): ``input_shape`` is presumably ignored on a layer called on an
    # existing tensor; it is kept to match the original construction.
    x = keras.layers.Conv2D(filters=decode_filters, kernel_size=1, padding='same',
                            input_shape=encoder_shape, name='conv2')(encoder.output)

    # TODO: More intermediate layers here?
    # Fast Depth decoder: one stage per skip connection, with a 1-based stage index.
    skip_names = ('pool3_pool', 'pool2_pool', 'pool1', 'conv1/relu')
    for stage, skip_name in enumerate(skip_names, start=1):
        skip = encoder.get_layer(skip_name)
        x = fd.nnconv5(x, skip.output_shape[3], stage, skip_connection=skip.output)

    # Final pointwise conv for depth extraction, normalized and clipped to [0, 6].
    x = keras.layers.Conv2D(1, 1, padding='same')(x)
    x = keras.layers.BatchNormalization()(x)
    x = keras.layers.ReLU(6.)(x)
    return keras.Model(inputs=image, outputs=x, name="fast_dense_depth")
def crop_and_resize(x):
    """
    Preprocess one NYU element for training.

    Both the image and the depth map are center-cropped to the depth map's height and
    width (dims 1 and 2 of ``tf.shape(x['depth'])``). Both are then resized with
    nearest-neighbour interpolation: the image to 224x224, the depth map to 112x112.
    The depth map first gets a trailing channel axis of size 1.

    :param x: dict with 'image' and 'depth' tensors. The reshape below uses the first
        three dims of the depth shape, so depth is expected to be rank 3 (presumably
        batch, height, width -- confirm against the dataset pipeline).
    :return: tuple ``(image, depth)`` of resized tensors
    """
    depth_dims = tf.shape(x['depth'])

    def pipeline(side):
        # Fresh crop + resize layers; the crop size comes from the depth map's spatial dims.
        return keras.Sequential([
            keras.layers.experimental.preprocessing.CenterCrop(
                depth_dims[1], depth_dims[2]),
            keras.layers.experimental.preprocessing.Resizing(
                side, side, interpolation='nearest'),
        ])

    resized_image = pipeline(224)(x['image'])
    # Reshape label to 4d, can't use array unwrap as it's unsupported by tensorflow
    return resized_image, pipeline(112)(
        tf.reshape(x['depth'], [depth_dims[0], depth_dims[1], depth_dims[2], 1]))
def load_nyu(*, batch_size=8, shuffle_buffer=1024):
    """
    Load the nyu_depth_v2 train split as a shuffled, batched, preprocessed pipeline.

    ``download_and_prepare`` is called with ``download_dir='../nyu'``, so raw downloads go
    to ../nyu. Per tfds, the prepared dataset itself lives in the builder's data_dir.
    Input files are shuffled, then elements are shuffled with a ``shuffle_buffer``-sized
    buffer. Elements are then batched and mapped through ``crop_and_resize``.

    :param batch_size: elements per batch (default 8, as before)
    :param shuffle_buffer: element shuffle buffer size (default 1024, as before)
    :return: tf.data.Dataset of ``(image, depth)`` tuples -- not the builder
    """
    builder = tfds.builder('nyu_depth_v2')
    builder.download_and_prepare(download_dir='../nyu')
    return (builder
            .as_dataset(split='train', shuffle_files=True)
            .shuffle(buffer_size=shuffle_buffer)
            .batch(batch_size)
            .map(crop_and_resize))
def load_nyu_evaluate(*, batch_size=1):
    """
    Load the nyu_depth_v2 validation split as a batched, preprocessed pipeline.

    ``download_and_prepare`` is called with ``download_dir='../nyu'``, so raw downloads go
    to ../nyu. Per tfds, the prepared dataset itself lives in the builder's data_dir.
    The split is read without any shuffling, batched, and mapped through
    ``crop_and_resize``.

    :param batch_size: elements per batch (default 1, as before)
    :return: tf.data.Dataset of ``(image, depth)`` tuples -- not the builder
    """
    builder = tfds.builder('nyu_depth_v2')
    builder.download_and_prepare(download_dir='../nyu')
    return builder.as_dataset(split='validation').batch(batch_size).map(crop_and_resize)
if __name__ == '__main__':
    # Build DenseDepth on a DenseNet-169 encoder with ImageNet weights and print the
    # layer summary. For the FastDepth-style decoder, use
    # dense_nnconv5(169, 'imagenet').summary() instead.
    dense_depth(169, 'imagenet').summary()