Merge branch 'dense-depth' into 'main'

Dense depth

See merge request vato007/fast-depth-tf!1
This commit is contained in:
Michael Pivato
2021-03-29 07:31:42 +00:00
3 changed files with 204 additions and 27 deletions

153
dense_depth_functional.py Normal file
View File

@@ -0,0 +1,153 @@
import tensorflow as tf
import tensorflow.keras as keras
import tensorflow_datasets as tfds
import fast_depth_functional as fd
def dense_upproject(input, out_channels, skip_connection):
x = keras.layers.UpSampling2D(interpolation='bilinear')(input)
x = keras.layers.Concatenate()([x, skip_connection])
x = keras.layers.Conv2D(filters=out_channels,
kernel_size=3, strides=1, padding='same')(x)
x = keras.layers.LeakyReLU(alpha=0.2)(x)
x = keras.layers.Conv2D(filters=out_channels,
kernel_size=3, strides=1, padding='same')(x)
return keras.layers.LeakyReLU(alpha=0.2)(x)
def dense_depth(size, weights=None, shape=(224, 224, 3), half_features=True):
input = keras.layers.Input(shape=shape)
densenet = dense_net(input, size, weights, shape)
densenet_output_shape = densenet.layers[-1].output.shape
if half_features:
decode_filters = densenet_output_shape[-1] // 2
else:
decode_filters = int(densenet_output_shape[-1])
# Reduce the feature set (pointwise)
decoder = keras.layers.Conv2D(filters=decode_filters, kernel_size=1, padding='same',
input_shape=densenet_output_shape, name='conv2')(densenet.output)
# The actual decoder
decoder = dense_upproject(
decoder, decode_filters // 2, densenet.get_layer('pool3_pool').output)
decoder = dense_upproject(
decoder, decode_filters // 4, densenet.get_layer('pool2_pool').output)
decoder = dense_upproject(
decoder, decode_filters // 8, densenet.get_layer('pool1').output)
decoder = dense_upproject(
decoder, decode_filters // 16, densenet.get_layer('conv1/relu').output)
# Enable to upproject to full image size
# decoder = dense_upproject(decoder, int(decode_filters / 32), input)
conv3 = keras.layers.Conv2D(
filters=1, kernel_size=3, strides=1, padding='same', name='conv3')(decoder)
return keras.Model(inputs=input, outputs=conv3, name='dense_depth')
def dense_net(input, size, weights=None, shape=(224, 224, 3)):
if size == 121:
densenet = keras.applications.DenseNet121(input_tensor=input, input_shape=shape, weights=weights,
include_top=False)
elif size == 169:
densenet = keras.applications.DenseNet169(input_tensor=input, input_shape=shape, weights=weights,
include_top=False)
else:
densenet = keras.applications.DenseNet201(input_tensor=input, input_shape=shape, weights=weights,
include_top=False)
for layer in densenet.layers:
layer.trainable = True
return densenet
def dense_nnconv5(size, weights=None, shape=(224, 224, 3), half_features=True):
input = keras.layers.Input(shape=shape)
densenet = dense_net(input, size, weights, shape)
densenet_output_shape = densenet.layers[-1].output.shape
if half_features:
decode_filters = int(int(densenet_output_shape[-1]) / 2)
else:
decode_filters = int(densenet_output_shape[-1])
# Reduce the feature set (pointwise)
x = keras.layers.Conv2D(filters=decode_filters, kernel_size=1, padding='same',
input_shape=densenet_output_shape, name='conv2')(densenet.output)
# TODO: More intermediate layers here?
# Fast Depth Decoder
x = fd.nnconv5(x, densenet.get_layer('pool3_pool').output_shape[3], 1,
skip_connection=densenet.get_layer('pool3_pool').output)
x = fd.nnconv5(x, densenet.get_layer('pool2_pool').output_shape[3], 2,
skip_connection=densenet.get_layer('pool2_pool').output)
x = fd.nnconv5(x, densenet.get_layer('pool1').output_shape[3], 3,
skip_connection=densenet.get_layer('pool1').output)
x = fd.nnconv5(x, densenet.get_layer('conv1/relu').output_shape[3], 4,
skip_connection=densenet.get_layer('conv1/relu').output)
# Enable to get full dense decode (back to original size)
# x = fd.nnconv5(x, int(densenet.get_layer('conv1/relu').output_shape[3] / 2), 5)
# Final Pointwise for depth extraction
x = keras.layers.Conv2D(1, 1, padding='same')(x)
x = keras.layers.BatchNormalization()(x)
x = keras.layers.ReLU(6.)(x)
return keras.Model(inputs=input, outputs=x, name="fast_dense_depth")
def crop_and_resize(x):
shape = tf.shape(x['depth'])
def layer():
return keras.Sequential([
keras.layers.experimental.preprocessing.CenterCrop(
shape[1], shape[2]),
keras.layers.experimental.preprocessing.Resizing(
224, 224, interpolation='nearest')
])
def half_layer():
return keras.Sequential([
keras.layers.experimental.preprocessing.CenterCrop(
shape[1], shape[2]),
keras.layers.experimental.preprocessing.Resizing(
112, 112, interpolation='nearest')
])
# Reshape label to 4d, can't use array unwrap as it's unsupported by tensorflow
return layer()(x['image']), half_layer()(tf.reshape(x['depth'], [shape[0], shape[1], shape[2], 1]))
def load_nyu():
"""
Load the nyu_v2 dataset train split. Will be downloaded to ../nyu
:return: nyu_v2 dataset builder
"""
builder = tfds.builder('nyu_depth_v2')
builder.download_and_prepare(download_dir='../nyu')
return builder \
.as_dataset(split='train', shuffle_files=True) \
.shuffle(buffer_size=1024) \
.batch(8) \
.map(lambda x: crop_and_resize(x))
def load_nyu_evaluate():
"""
Load the nyu_v2 dataset validation split. Will be downloaded to ../nyu
:return: nyu_v2 dataset builder
"""
builder = tfds.builder('nyu_depth_v2')
builder.download_and_prepare(download_dir='../nyu')
return builder.as_dataset(split='validation').batch(1).map(lambda x: crop_and_resize(x))
if __name__ == '__main__':
model = dense_depth(169, 'imagenet')
model.summary()
# model = dense_nnconv5(169, 'imagenet')
# model.summary()

View File

@@ -29,14 +29,18 @@ def fix_windows_gpu():
print(e)
def FDDepthwiseBlock(inputs,
out_channels,
block_id=1):
def nnconv5(inputs,
out_channels,
block_id=1,
skip_connection=None):
x = keras.layers.DepthwiseConv2D(5, padding='same')(inputs)
x = keras.layers.BatchNormalization()(x)
x = keras.layers.ReLU(6.)(x)
x = keras.layers.Conv2D(out_channels, 1, padding='same')(x)
x = keras.layers.BatchNormalization()(x)
x = keras.layers.UpSampling2D()(x)
if skip_connection is not None:
x = keras.layers.Add()([x, skip_connection])
return keras.layers.ReLU(6., name='conv_pw_%d_relu' % block_id)(x)
@@ -54,23 +58,14 @@ def mobilenet_nnconv5(weights=None, shape=(224, 224, 3)):
layer.trainable = True
# Fast depth decoder
x = FDDepthwiseBlock(mobilenet.output, 512, block_id=14)
# Nearest neighbour interpolation, used by fast depth paper
x = keras.layers.UpSampling2D()(x)
x = FDDepthwiseBlock(x, 256, block_id=15)
x = keras.layers.UpSampling2D()(x)
x = keras.layers.Add()(
[x, mobilenet.get_layer(name="conv_pw_5_relu").output])
x = FDDepthwiseBlock(x, 128, block_id=16)
x = keras.layers.UpSampling2D()(x)
x = keras.layers.Add()(
[x, mobilenet.get_layer(name="conv_pw_3_relu").output])
x = FDDepthwiseBlock(x, 64, block_id=17)
x = keras.layers.UpSampling2D()(x)
x = keras.layers.Add()(
[x, mobilenet.get_layer(name="conv_pw_1_relu").output])
x = FDDepthwiseBlock(x, 32, block_id=18)
x = keras.layers.UpSampling2D()(x)
x = nnconv5(mobilenet.output, 512, block_id=14)
x = nnconv5(x, 256, block_id=15, skip_connection=mobilenet.get_layer(
name="conv_pw_5_relu").output)
x = nnconv5(x, 128, block_id=16, skip_connection=mobilenet.get_layer(
name="conv_pw_3_relu").output)
x = nnconv5(x, 64, block_id=17, skip_connection=mobilenet.get_layer(
name="conv_pw_1_relu").output)
x = nnconv5(x, 32, block_id=18)
x = keras.layers.Conv2D(1, 1, padding='same')(x)
x = keras.layers.BatchNormalization()(x)
@@ -93,19 +88,21 @@ def delta3_metric(y_true, y_pred):
return tf.nn.moments(tf.cast(maxRatio < tf.convert_to_tensor(1.25 ** 3), tf.float32), axes=None)[0]
def compile(model):
def compile(model, optimiser=keras.optimizers.SGD(), loss=keras.losses.MeanSquaredError(), custom_metrics=None):
"""
Compile FastDepth model with relevant metrics
:param model: Model to compile
:param optimiser: Custom optimiser to use
:param loss: Loss function to use
:param include_metrics: Whether to include metrics (RMSE, MSE, a1,2,3)
"""
# TODO: Learning rate (exponential decay)
model.compile(optimizer=keras.optimizers.SGD(momentum=0.9),
loss=keras.losses.MeanSquaredError(),
model.compile(optimizer=optimiser,
loss=loss,
metrics=[keras.metrics.RootMeanSquaredError(),
keras.metrics.MeanSquaredError(),
delta1_metric,
delta2_metric,
delta3_metric])
delta3_metric] if custom_metrics is None else custom_metrics)
def train(existing_model=None, pretrained_weights='imagenet', epochs=4, save_file=None, dataset=None):
@@ -120,7 +117,7 @@ def train(existing_model=None, pretrained_weights='imagenet', epochs=4, save_fil
"""
if not existing_model:
existing_model = mobilenet_nnconv5(pretrained_weights)
compile(existing_model)
compile(existing_model)
if not dataset:
dataset = load_nyu()
existing_model.fit(dataset, epochs=epochs)
@@ -137,7 +134,7 @@ def evaluate(compiled_model, dataset=None):
where label width/height matches image width/height.
Defaults to Tensorflow nyu_v2 evaluation split dataset (https://www.tensorflow.org/datasets/catalog/nyu_depth_v2)
"""
if not dataset:
if dataset is None:
dataset = load_nyu_evaluate()
compiled_model.evaluate(dataset, verbose=1)
@@ -200,3 +197,8 @@ def load_nyu_evaluate():
builder = tfds.builder('nyu_depth_v2')
builder.download_and_prepare(download_dir='../nyu')
return builder.as_dataset(split='validation').batch(1).map(lambda x: crop_and_resize(x))
if __name__ == '__main__':
model = mobilenet_nnconv5()
model.summary()

22
losses.py Normal file
View File

@@ -0,0 +1,22 @@
import tensorflow as tf
import tensorflow.keras.backend as K
def dense_depth_loss_function(y_true, y_pred, theta=0.1, maxDepthVal=1000.0 / 10.0):
# Point-wise depth
l_depth = K.mean(K.abs(y_pred - y_true), axis=-1)
# Edges
dy_true, dx_true = tf.image.image_gradients(y_true)
dy_pred, dx_pred = tf.image.image_gradients(y_pred)
l_edges = K.mean(K.abs(dy_pred - dy_true) + K.abs(dx_pred - dx_true), axis=-1)
# Structural similarity (SSIM) index
l_ssim = K.clip((1 - tf.image.ssim(y_true, y_pred, maxDepthVal)) * 0.5, 0, 1)
# Weights
w1 = 1.0
w2 = 1.0
w3 = theta
return (w1 * l_ssim) + (w2 * K.mean(l_edges)) + (w3 * K.mean(l_depth))