Files
fast-depth-tf/fast_depth_functional.py
2021-03-17 21:15:06 +10:30

98 lines
3.5 KiB
Python

import tensorflow as tf
import tensorflow.keras as keras
'''
Functional (Keras functional API) version of the FastDepth model.
'''
def FDDepthwiseBlock(inputs, out_channels, block_id=1):
    """
    Apply one depthwise-separable decoder stage to ``inputs``.

    The stage is a 5x5 depthwise convolution followed by a 1x1 pointwise
    convolution; each convolution is followed by batch normalization and a
    ReLU capped at 6.

    :param inputs: Keras tensor to transform
    :param out_channels: Number of filters of the 1x1 pointwise convolution
    :param block_id: Integer used to name the final activation
        ``conv_pw_<block_id>_relu``
    :return: Output tensor of the final ReLU layer
    """
    layers = keras.layers

    # Depthwise stage: 5x5 spatial filtering, channel count unchanged.
    depthwise = layers.DepthwiseConv2D(5, padding='same')(inputs)
    depthwise = layers.BatchNormalization()(depthwise)
    depthwise = layers.ReLU(6.)(depthwise)

    # Pointwise stage: 1x1 convolution mixing channels to out_channels.
    pointwise = layers.Conv2D(out_channels, 1, padding='same')(depthwise)
    pointwise = layers.BatchNormalization()(pointwise)

    relu_name = 'conv_pw_%d_relu' % block_id
    return layers.ReLU(6., name=relu_name)(pointwise)
def make_fastdepth_functional(weights=None):
    """
    Build the FastDepth network with the Keras functional API.

    A MobileNet backbone (without its classification top) is replayed layer
    by layer on a 224x224x3 input. The outputs of ``conv_pw_1_relu``,
    ``conv_pw_3_relu`` and ``conv_pw_5_relu`` are kept as skip connections.
    The decoder runs five ``FDDepthwiseBlock`` stages, each followed by
    nearest-neighbour resizing (14, 28, 56, 112, 224), and adds the skip
    tensors after the 28, 56 and 112 resizes. A final 1x1 convolution,
    batch normalization and ReLU6 produce a single-channel output.

    NOTE(review): the original author flagged this function with "This
    doesn't work, at least right now..." without recording the cause.
    Verify end to end before relying on it.

    :param weights: Forwarded unchanged to ``keras.applications.MobileNet``
    :return: Uncompiled ``keras.Model`` named ``"fast_depth"``
    :raises ValueError: If the backbone does not expose all three skip layers
    """
    layers = keras.layers
    # Resizing was promoted out of the experimental namespace in later TF
    # releases; use whichever location this installation provides.
    resizing = getattr(layers, 'Resizing', None)
    if resizing is None:
        resizing = layers.experimental.preprocessing.Resizing

    skip_names = ('conv_pw_1_relu', 'conv_pw_3_relu', 'conv_pw_5_relu')

    # Not called `input`, so the builtin of that name is not shadowed.
    image_input = layers.Input(shape=(224, 224, 3))
    x = image_input

    mobilenet = keras.applications.MobileNet(include_top=False, weights=weights)
    skips = {}
    for layer in mobilenet.layers:
        # The backbone's own InputLayer is a placeholder, not a transformation;
        # our image_input takes its place.
        if isinstance(layer, layers.InputLayer):
            continue
        x = layer(x)
        if layer.name in skip_names:
            skips[layer.name] = x

    missing = [name for name in skip_names if name not in skips]
    if missing:
        raise ValueError(
            'MobileNet backbone is missing skip layers: %s' % ', '.join(missing))

    # Fast depth decoder: (out_channels, block_id, output size, skip layer).
    # Nearest neighbour interpolation is used, as in the fast depth paper.
    # TODO: Bilinear interpolation
    decoder_stages = (
        (512, 14, 14, None),
        (256, 15, 28, 'conv_pw_5_relu'),
        (128, 16, 56, 'conv_pw_3_relu'),
        (64, 17, 112, 'conv_pw_1_relu'),
        (32, 18, 224, None),
    )
    for out_channels, block_id, size, skip_name in decoder_stages:
        x = FDDepthwiseBlock(x, out_channels, block_id=block_id)
        x = resizing(size, size, interpolation='nearest')(x)
        if skip_name is not None:
            x = layers.Add()([x, skips[skip_name]])

    # Project to a single output channel.
    x = layers.Conv2D(1, 1, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU(6.)(x)
    return keras.Model(inputs=image_input, outputs=x, name="fast_depth")
def delta1_metric(y_true, y_pred):
    """
    Fraction of elements whose prediction is within a factor of 1.25 of truth.

    For every element the larger of ``y_pred / y_true`` and
    ``y_true / y_pred`` is compared against 1.25.

    :param y_true: Ground-truth tensor
    :param y_pred: Predicted tensor, broadcast-compatible with ``y_true``
    :return: Scalar float32 tensor, the mean of the per-element indicator
    """
    max_ratio = tf.maximum(y_pred / y_true, y_true / y_pred)
    # TensorFlow tensors have no .float()/.mean() methods (that is the PyTorch
    # API); cast the boolean mask and reduce with TF ops instead.
    return tf.reduce_mean(tf.cast(max_ratio < 1.25, tf.float32))
def delta2_metric(y_true, y_pred):
    """
    Fraction of elements whose prediction is within a factor of 1.25**2 of truth.

    For every element the larger of ``y_pred / y_true`` and
    ``y_true / y_pred`` is compared against ``1.25 ** 2``.

    :param y_true: Ground-truth tensor
    :param y_pred: Predicted tensor, broadcast-compatible with ``y_true``
    :return: Scalar float32 tensor, the mean of the per-element indicator
    """
    max_ratio = tf.maximum(y_pred / y_true, y_true / y_pred)
    # Threshold is 1.25 squared; the previous `1.2 ** 25` (about 95) was a typo
    # that made the metric accept almost every prediction.
    # TensorFlow tensors have no .float()/.mean() methods (that is the PyTorch
    # API); cast the boolean mask and reduce with TF ops instead.
    return tf.reduce_mean(tf.cast(max_ratio < 1.25 ** 2, tf.float32))
def delta3_metric(y_true, y_pred):
    """
    Fraction of elements whose prediction is within a factor of 1.25**3 of truth.

    For every element the larger of ``y_pred / y_true`` and
    ``y_true / y_pred`` is compared against ``1.25 ** 3``.

    :param y_true: Ground-truth tensor
    :param y_pred: Predicted tensor, broadcast-compatible with ``y_true``
    :return: Scalar float32 tensor, the mean of the per-element indicator
    """
    max_ratio = tf.maximum(y_pred / y_true, y_true / y_pred)
    # TensorFlow tensors have no .float()/.mean() methods (that is the PyTorch
    # API); cast the boolean mask and reduce with TF ops instead.
    return tf.reduce_mean(tf.cast(max_ratio < 1.25 ** 3, tf.float32))
def fastdepth_for_training():
    """
    Build the FastDepth model with an ImageNet-pretrained backbone and compile it.

    The model is compiled with SGD (momentum 0.9), mean squared error loss,
    and RMSE, MSE and the three delta threshold metrics.

    :return: The compiled ``keras.Model``
    """
    # Pretrained mobilenet on imagenet dataset
    model = make_fastdepth_functional('imagenet')
    # keras.losses.MSE is the functional form and needs (y_true, y_pred);
    # the loss class is what can be instantiated without arguments.
    model.compile(optimizer=keras.optimizers.SGD(momentum=0.9),
                  loss=keras.losses.MeanSquaredError(),
                  metrics=[keras.metrics.RootMeanSquaredError(),
                           keras.metrics.MeanSquaredError(),
                           delta1_metric,
                           delta2_metric,
                           delta3_metric])
    # compile() configures the model in place and returns None, so return the
    # model itself rather than the result of compile().
    return model
def train_compiled_model(compiled_model, dataset):
    """
    Placeholder for the training routine; currently does nothing.

    :param compiled_model: Compiled model to train
    :param dataset: Dataset to train on (must be compatible with the model)
    :return: None
    """
    # TODO: Use tf nyu_v2 dataset to train.
    return None
if __name__ == '__main__':
    # When run as a script, build the model with default arguments and print
    # its layer summary.
    fast_depth = make_fastdepth_functional()
    fast_depth.summary()