Refactor fast-depth
Addresses the following:
- Rename the FDDepthwiseBlock block to nnconv5
- Add skip connections directly to the nnconv5 block
- Allow a custom optimizer, loss and metrics for training (the defaults are kept and reflect the original paper)
- Correctly fall back to the NYU evaluation dataset only when no dataset is provided
This commit is contained in:
@@ -29,14 +29,18 @@ def fix_windows_gpu():
|
||||
print(e)
|
||||
|
||||
|
||||
def FDDepthwiseBlock(inputs,
|
||||
out_channels,
|
||||
block_id=1):
|
||||
def nnconv5(inputs,
|
||||
out_channels,
|
||||
block_id=1,
|
||||
skip_connection=None):
|
||||
x = keras.layers.DepthwiseConv2D(5, padding='same')(inputs)
|
||||
x = keras.layers.BatchNormalization()(x)
|
||||
x = keras.layers.ReLU(6.)(x)
|
||||
x = keras.layers.Conv2D(out_channels, 1, padding='same')(x)
|
||||
x = keras.layers.BatchNormalization()(x)
|
||||
x = keras.layers.UpSampling2D()(x)
|
||||
if skip_connection is not None:
|
||||
x = keras.layers.Add()([x, skip_connection])
|
||||
return keras.layers.ReLU(6., name='conv_pw_%d_relu' % block_id)(x)
|
||||
|
||||
|
||||
@@ -54,23 +58,14 @@ def mobilenet_nnconv5(weights=None, shape=(224, 224, 3)):
|
||||
layer.trainable = True
|
||||
|
||||
# Fast depth decoder
|
||||
x = FDDepthwiseBlock(mobilenet.output, 512, block_id=14)
|
||||
# Nearest neighbour interpolation, used by fast depth paper
|
||||
x = keras.layers.UpSampling2D()(x)
|
||||
x = FDDepthwiseBlock(x, 256, block_id=15)
|
||||
x = keras.layers.UpSampling2D()(x)
|
||||
x = keras.layers.Add()(
|
||||
[x, mobilenet.get_layer(name="conv_pw_5_relu").output])
|
||||
x = FDDepthwiseBlock(x, 128, block_id=16)
|
||||
x = keras.layers.UpSampling2D()(x)
|
||||
x = keras.layers.Add()(
|
||||
[x, mobilenet.get_layer(name="conv_pw_3_relu").output])
|
||||
x = FDDepthwiseBlock(x, 64, block_id=17)
|
||||
x = keras.layers.UpSampling2D()(x)
|
||||
x = keras.layers.Add()(
|
||||
[x, mobilenet.get_layer(name="conv_pw_1_relu").output])
|
||||
x = FDDepthwiseBlock(x, 32, block_id=18)
|
||||
x = keras.layers.UpSampling2D()(x)
|
||||
x = nnconv5(mobilenet.output, 512, block_id=14)
|
||||
x = nnconv5(x, 256, block_id=15, skip_connection=mobilenet.get_layer(
|
||||
name="conv_pw_5_relu").output)
|
||||
x = nnconv5(x, 128, block_id=16, skip_connection=mobilenet.get_layer(
|
||||
name="conv_pw_3_relu").output)
|
||||
x = nnconv5(x, 64, block_id=17, skip_connection=mobilenet.get_layer(
|
||||
name="conv_pw_1_relu").output)
|
||||
x = nnconv5(x, 32, block_id=18)
|
||||
|
||||
x = keras.layers.Conv2D(1, 1, padding='same')(x)
|
||||
x = keras.layers.BatchNormalization()(x)
|
||||
@@ -93,19 +88,21 @@ def delta3_metric(y_true, y_pred):
|
||||
return tf.nn.moments(tf.cast(maxRatio < tf.convert_to_tensor(1.25 ** 3), tf.float32), axes=None)[0]
|
||||
|
||||
|
||||
def compile(model):
|
||||
def compile(model, optimiser=keras.optimizers.SGD(), loss=keras.losses.MeanSquaredError(), custom_metrics=None):
|
||||
"""
|
||||
Compile FastDepth model with relevant metrics
|
||||
:param model: Model to compile
|
||||
:param optimiser: Custom optimiser to use
|
||||
:param loss: Loss function to use
|
||||
:param custom_metrics: Metrics to use instead of the defaults; when None, the defaults (RMSE, MSE, delta1, delta2, delta3) are used
|
||||
"""
|
||||
# TODO: Learning rate (exponential decay)
|
||||
model.compile(optimizer=keras.optimizers.SGD(momentum=0.9),
|
||||
loss=keras.losses.MeanSquaredError(),
|
||||
model.compile(optimizer=optimiser,
|
||||
loss=loss,
|
||||
metrics=[keras.metrics.RootMeanSquaredError(),
|
||||
keras.metrics.MeanSquaredError(),
|
||||
delta1_metric,
|
||||
delta2_metric,
|
||||
delta3_metric])
|
||||
delta3_metric] if custom_metrics is None else custom_metrics)
|
||||
|
||||
|
||||
def train(existing_model=None, pretrained_weights='imagenet', epochs=4, save_file=None, dataset=None):
|
||||
@@ -120,7 +117,7 @@ def train(existing_model=None, pretrained_weights='imagenet', epochs=4, save_fil
|
||||
"""
|
||||
if not existing_model:
|
||||
existing_model = mobilenet_nnconv5(pretrained_weights)
|
||||
compile(existing_model)
|
||||
compile(existing_model)
|
||||
if not dataset:
|
||||
dataset = load_nyu()
|
||||
existing_model.fit(dataset, epochs=epochs)
|
||||
@@ -137,7 +134,7 @@ def evaluate(compiled_model, dataset=None):
|
||||
where label width/height matches image width/height.
|
||||
Defaults to the TensorFlow Datasets nyu_depth_v2 validation split (https://www.tensorflow.org/datasets/catalog/nyu_depth_v2)
|
||||
"""
|
||||
if not dataset:
|
||||
if dataset is None:
|
||||
dataset = load_nyu_evaluate()
|
||||
compiled_model.evaluate(dataset, verbose=1)
|
||||
|
||||
@@ -200,3 +197,8 @@ def load_nyu_evaluate():
|
||||
builder = tfds.builder('nyu_depth_v2')
|
||||
builder.download_and_prepare(download_dir='../nyu')
|
||||
return builder.as_dataset(split='validation').batch(1).map(lambda x: crop_and_resize(x))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
model = mobilenet_nnconv5()
|
||||
model.summary()
|
||||
|
||||
Reference in New Issue
Block a user