Refactor fast-depth

Addresses the following:
 - Rename nnconv5 block to nnconv5
 - Add skip connections directly to nnconv5 block
 - Allow custom metrics, loss and optimizer (keep defaults that reflect original paper) to train
 - Correctly use nyu evaluation dataset only when no dataset is provided
This commit is contained in:
Piv
2021-03-29 17:55:07 +10:30
parent 3325ea0c0c
commit 870429c3ef

View File

@@ -29,14 +29,18 @@ def fix_windows_gpu():
print(e) print(e)
def FDDepthwiseBlock(inputs, def nnconv5(inputs,
out_channels, out_channels,
block_id=1): block_id=1,
skip_connection=None):
x = keras.layers.DepthwiseConv2D(5, padding='same')(inputs) x = keras.layers.DepthwiseConv2D(5, padding='same')(inputs)
x = keras.layers.BatchNormalization()(x) x = keras.layers.BatchNormalization()(x)
x = keras.layers.ReLU(6.)(x) x = keras.layers.ReLU(6.)(x)
x = keras.layers.Conv2D(out_channels, 1, padding='same')(x) x = keras.layers.Conv2D(out_channels, 1, padding='same')(x)
x = keras.layers.BatchNormalization()(x) x = keras.layers.BatchNormalization()(x)
x = keras.layers.UpSampling2D()(x)
if skip_connection is not None:
x = keras.layers.Add()([x, skip_connection])
return keras.layers.ReLU(6., name='conv_pw_%d_relu' % block_id)(x) return keras.layers.ReLU(6., name='conv_pw_%d_relu' % block_id)(x)
@@ -54,23 +58,14 @@ def mobilenet_nnconv5(weights=None, shape=(224, 224, 3)):
layer.trainable = True layer.trainable = True
# Fast depth decoder # Fast depth decoder
x = FDDepthwiseBlock(mobilenet.output, 512, block_id=14) x = nnconv5(mobilenet.output, 512, block_id=14)
# Nearest neighbour interpolation, used by fast depth paper x = nnconv5(x, 256, block_id=15, skip_connection=mobilenet.get_layer(
x = keras.layers.UpSampling2D()(x) name="conv_pw_5_relu").output)
x = FDDepthwiseBlock(x, 256, block_id=15) x = nnconv5(x, 128, block_id=16, skip_connection=mobilenet.get_layer(
x = keras.layers.UpSampling2D()(x) name="conv_pw_3_relu").output)
x = keras.layers.Add()( x = nnconv5(x, 64, block_id=17, skip_connection=mobilenet.get_layer(
[x, mobilenet.get_layer(name="conv_pw_5_relu").output]) name="conv_pw_1_relu").output)
x = FDDepthwiseBlock(x, 128, block_id=16) x = nnconv5(x, 32, block_id=18)
x = keras.layers.UpSampling2D()(x)
x = keras.layers.Add()(
[x, mobilenet.get_layer(name="conv_pw_3_relu").output])
x = FDDepthwiseBlock(x, 64, block_id=17)
x = keras.layers.UpSampling2D()(x)
x = keras.layers.Add()(
[x, mobilenet.get_layer(name="conv_pw_1_relu").output])
x = FDDepthwiseBlock(x, 32, block_id=18)
x = keras.layers.UpSampling2D()(x)
x = keras.layers.Conv2D(1, 1, padding='same')(x) x = keras.layers.Conv2D(1, 1, padding='same')(x)
x = keras.layers.BatchNormalization()(x) x = keras.layers.BatchNormalization()(x)
@@ -93,19 +88,21 @@ def delta3_metric(y_true, y_pred):
return tf.nn.moments(tf.cast(maxRatio < tf.convert_to_tensor(1.25 ** 3), tf.float32), axes=None)[0] return tf.nn.moments(tf.cast(maxRatio < tf.convert_to_tensor(1.25 ** 3), tf.float32), axes=None)[0]
def compile(model): def compile(model, optimiser=keras.optimizers.SGD(), loss=keras.losses.MeanSquaredError(), custom_metrics=None):
""" """
Compile FastDepth model with relevant metrics Compile FastDepth model with relevant metrics
:param model: Model to compile :param model: Model to compile
:param optimiser: Custom optimiser to use
:param loss: Loss function to use
:param include_metrics: Whether to include metrics (RMSE, MSE, a1,2,3)
""" """
# TODO: Learning rate (exponential decay) model.compile(optimizer=optimiser,
model.compile(optimizer=keras.optimizers.SGD(momentum=0.9), loss=loss,
loss=keras.losses.MeanSquaredError(),
metrics=[keras.metrics.RootMeanSquaredError(), metrics=[keras.metrics.RootMeanSquaredError(),
keras.metrics.MeanSquaredError(), keras.metrics.MeanSquaredError(),
delta1_metric, delta1_metric,
delta2_metric, delta2_metric,
delta3_metric]) delta3_metric] if custom_metrics is None else custom_metrics)
def train(existing_model=None, pretrained_weights='imagenet', epochs=4, save_file=None, dataset=None): def train(existing_model=None, pretrained_weights='imagenet', epochs=4, save_file=None, dataset=None):
@@ -120,7 +117,7 @@ def train(existing_model=None, pretrained_weights='imagenet', epochs=4, save_fil
""" """
if not existing_model: if not existing_model:
existing_model = mobilenet_nnconv5(pretrained_weights) existing_model = mobilenet_nnconv5(pretrained_weights)
compile(existing_model) compile(existing_model)
if not dataset: if not dataset:
dataset = load_nyu() dataset = load_nyu()
existing_model.fit(dataset, epochs=epochs) existing_model.fit(dataset, epochs=epochs)
@@ -137,7 +134,7 @@ def evaluate(compiled_model, dataset=None):
where label width/height matches image width/height. where label width/height matches image width/height.
Defaults to Tensorflow nyu_v2 evaluation split dataset (https://www.tensorflow.org/datasets/catalog/nyu_depth_v2) Defaults to Tensorflow nyu_v2 evaluation split dataset (https://www.tensorflow.org/datasets/catalog/nyu_depth_v2)
""" """
if not dataset: if dataset is None:
dataset = load_nyu_evaluate() dataset = load_nyu_evaluate()
compiled_model.evaluate(dataset, verbose=1) compiled_model.evaluate(dataset, verbose=1)
@@ -200,3 +197,8 @@ def load_nyu_evaluate():
builder = tfds.builder('nyu_depth_v2') builder = tfds.builder('nyu_depth_v2')
builder.download_and_prepare(download_dir='../nyu') builder.download_and_prepare(download_dir='../nyu')
return builder.as_dataset(split='validation').batch(1).map(lambda x: crop_and_resize(x)) return builder.as_dataset(split='validation').batch(1).map(lambda x: crop_and_resize(x))
if __name__ == '__main__':
model = mobilenet_nnconv5()
model.summary()