From 870429c3efdda22b2e01e0245615a281ac6097b7 Mon Sep 17 00:00:00 2001 From: Piv <18462828+Piv200@users.noreply.github.com> Date: Mon, 29 Mar 2021 17:55:07 +1030 Subject: [PATCH] Refactor fast-depth Addresses the following: - Rename FDDepthwiseBlock to nnconv5 - Add upsampling and skip connections directly to nnconv5 block - Allow custom metrics, loss and optimizer to be passed to compile (keep defaults that reflect original paper) - Correctly use nyu evaluation dataset only when no dataset is provided --- fast_depth_functional.py | 56 +++++++++++++++++++++------------------- 1 file changed, 29 insertions(+), 27 deletions(-) diff --git a/fast_depth_functional.py b/fast_depth_functional.py index bb9c87e..00f92fe 100644 --- a/fast_depth_functional.py +++ b/fast_depth_functional.py @@ -29,14 +29,18 @@ def fix_windows_gpu(): print(e) -def FDDepthwiseBlock(inputs, - out_channels, - block_id=1): +def nnconv5(inputs, + out_channels, + block_id=1, + skip_connection=None): x = keras.layers.DepthwiseConv2D(5, padding='same')(inputs) x = keras.layers.BatchNormalization()(x) x = keras.layers.ReLU(6.)(x) x = keras.layers.Conv2D(out_channels, 1, padding='same')(x) x = keras.layers.BatchNormalization()(x) + x = keras.layers.UpSampling2D()(x) + if skip_connection is not None: + x = keras.layers.Add()([x, skip_connection]) return keras.layers.ReLU(6., name='conv_pw_%d_relu' % block_id)(x) @@ -54,23 +58,14 @@ def mobilenet_nnconv5(weights=None, shape=(224, 224, 3)): layer.trainable = True # Fast depth decoder - x = FDDepthwiseBlock(mobilenet.output, 512, block_id=14) - # Nearest neighbour interpolation, used by fast depth paper - x = keras.layers.UpSampling2D()(x) - x = FDDepthwiseBlock(x, 256, block_id=15) - x = keras.layers.UpSampling2D()(x) - x = keras.layers.Add()( - [x, mobilenet.get_layer(name="conv_pw_5_relu").output]) - x = FDDepthwiseBlock(x, 128, block_id=16) - x = keras.layers.UpSampling2D()(x) - x = keras.layers.Add()( - [x, mobilenet.get_layer(name="conv_pw_3_relu").output]) - x = FDDepthwiseBlock(x, 
64, block_id=17) - x = keras.layers.UpSampling2D()(x) - x = keras.layers.Add()( - [x, mobilenet.get_layer(name="conv_pw_1_relu").output]) - x = FDDepthwiseBlock(x, 32, block_id=18) - x = keras.layers.UpSampling2D()(x) + x = nnconv5(mobilenet.output, 512, block_id=14) + x = nnconv5(x, 256, block_id=15, skip_connection=mobilenet.get_layer( + name="conv_pw_5_relu").output) + x = nnconv5(x, 128, block_id=16, skip_connection=mobilenet.get_layer( + name="conv_pw_3_relu").output) + x = nnconv5(x, 64, block_id=17, skip_connection=mobilenet.get_layer( + name="conv_pw_1_relu").output) + x = nnconv5(x, 32, block_id=18) x = keras.layers.Conv2D(1, 1, padding='same')(x) x = keras.layers.BatchNormalization()(x) @@ -93,19 +88,21 @@ def delta3_metric(y_true, y_pred): return tf.nn.moments(tf.cast(maxRatio < tf.convert_to_tensor(1.25 ** 3), tf.float32), axes=None)[0] -def compile(model): +def compile(model, optimiser=keras.optimizers.SGD(), loss=keras.losses.MeanSquaredError(), custom_metrics=None): """ Compile FastDepth model with relevant metrics :param model: Model to compile + :param optimiser: Custom optimiser to use + :param loss: Loss function to use + :param custom_metrics: Metrics to use instead of the defaults (RMSE, MSE, delta1, delta2, delta3); None keeps the defaults """ - # TODO: Learning rate (exponential decay) - model.compile(optimizer=keras.optimizers.SGD(momentum=0.9), - loss=keras.losses.MeanSquaredError(), + model.compile(optimizer=optimiser, + loss=loss, metrics=[keras.metrics.RootMeanSquaredError(), keras.metrics.MeanSquaredError(), delta1_metric, delta2_metric, - delta3_metric]) + delta3_metric] if custom_metrics is None else custom_metrics) def train(existing_model=None, pretrained_weights='imagenet', epochs=4, save_file=None, dataset=None): @@ -120,7 +117,7 @@ def train(existing_model=None, pretrained_weights='imagenet', epochs=4, save_fil """ if not existing_model: existing_model = mobilenet_nnconv5(pretrained_weights) - compile(existing_model) + compile(existing_model) if not dataset: dataset = load_nyu() 
existing_model.fit(dataset, epochs=epochs) @@ -137,7 +134,7 @@ def evaluate(compiled_model, dataset=None): where label width/height matches image width/height. Defaults to Tensorflow nyu_v2 evaluation split dataset (https://www.tensorflow.org/datasets/catalog/nyu_depth_v2) """ - if not dataset: + if dataset is None: dataset = load_nyu_evaluate() compiled_model.evaluate(dataset, verbose=1) @@ -200,3 +197,8 @@ def load_nyu_evaluate(): builder = tfds.builder('nyu_depth_v2') builder.download_and_prepare(download_dir='../nyu') return builder.as_dataset(split='validation').batch(1).map(lambda x: crop_and_resize(x)) + + +if __name__ == '__main__': + model = mobilenet_nnconv5() + model.summary()