diff --git a/fast_depth_functional.py b/fast_depth_functional.py index 70dd07a..c085687 100644 --- a/fast_depth_functional.py +++ b/fast_depth_functional.py @@ -35,34 +35,34 @@ def FDDepthwiseBlock(inputs, return keras.layers.ReLU(6., name='conv_pw_%d_relu' % block_id)(x) +def FDDepthwiseBlockNoBN(inputs, out_channels, block_id=1): + x = keras.layers.DepthwiseConv2D(5, padding='same')(inputs) + x = keras.layers.ReLU(6.)(x) + x = keras.layers.Conv2D(out_channels, 1, padding='same')(x) + return keras.layers.ReLU(6., name='conv_pw_%d_relu' % block_id)(x) + + def make_mobilenet_nnconv5(weights=None, shape=(224, 224, 3)): input = keras.layers.Input(shape=shape) - x = input - mobilenet = keras.applications.MobileNet(input_tensor=x, include_top=False, weights=weights) + mobilenet = keras.applications.MobileNet(input_tensor=input, include_top=False, weights=weights) for layer in mobilenet.layers: - x = layer(x) - if layer.name == 'conv_pw_5_relu': - conv5 = x - elif layer.name == 'conv_pw_3_relu': - conv3 = x - elif layer.name == 'conv_pw_1_relu': - conv1 = x + layer.trainable = True # Fast depth decoder - x = FDDepthwiseBlock(x, 512, block_id=14) + x = FDDepthwiseBlock(mobilenet.output, 512, block_id=14) # TODO: Bilinear interpolation - # x = keras.layers.experimental.preprocessing.Resizing(14, 14) + # x = keras.layers.experimental.preprocessing.Resizing(14, 14, interpolation='bilinear') # Nearest neighbour interpolation, used by fast depth paper x = keras.layers.experimental.preprocessing.Resizing(14, 14, interpolation='nearest')(x) x = FDDepthwiseBlock(x, 256, block_id=15) x = keras.layers.experimental.preprocessing.Resizing(28, 28, interpolation='nearest')(x) - x = keras.layers.Add()([x, conv5]) + x = keras.layers.Add()([x, mobilenet.get_layer(name="conv_pw_5_relu").output]) x = FDDepthwiseBlock(x, 128, block_id=16) x = keras.layers.experimental.preprocessing.Resizing(56, 56, interpolation='nearest')(x) - x = keras.layers.Add()([x, conv3]) + x = keras.layers.Add()([x, 
mobilenet.get_layer(name="conv_pw_3_relu").output]) x = FDDepthwiseBlock(x, 64, block_id=17) x = keras.layers.experimental.preprocessing.Resizing(112, 112, interpolation='nearest')(x) - x = keras.layers.Add()([x, conv1]) + x = keras.layers.Add()([x, mobilenet.get_layer(name="conv_pw_1_relu").output]) x = FDDepthwiseBlock(x, 32, block_id=18) x = keras.layers.experimental.preprocessing.Resizing(224, 224, interpolation='nearest')(x) @@ -72,6 +72,33 @@ def make_mobilenet_nnconv5(weights=None, shape=(224, 224, 3)): return keras.Model(inputs=input, outputs=x, name="fast_depth") +def make_mobilenet_nnconv5_no_bn(weights=None, shape=(224, 224, 3)): + input = keras.layers.Input(shape=shape) + mobilenet = keras.applications.MobileNet(input_tensor=input, include_top=False, weights=weights) + for layer in mobilenet.layers: + layer.trainable = True + + # Fast depth decoder + x = FDDepthwiseBlockNoBN(mobilenet.output, 512, block_id=14) + # Bilinear interpolation (the fast depth paper uses nearest neighbour) + x = keras.layers.experimental.preprocessing.Resizing(14, 14, interpolation='bilinear')(x) + x = FDDepthwiseBlockNoBN(x, 256, block_id=15) + x = keras.layers.experimental.preprocessing.Resizing(28, 28, interpolation='bilinear')(x) + x = keras.layers.Add()([x, mobilenet.get_layer(name="conv_pw_5_relu").output]) + x = FDDepthwiseBlockNoBN(x, 128, block_id=16) + x = keras.layers.experimental.preprocessing.Resizing(56, 56, interpolation='bilinear')(x) + x = keras.layers.Add()([x, mobilenet.get_layer(name="conv_pw_3_relu").output]) + x = FDDepthwiseBlockNoBN(x, 64, block_id=17) + x = keras.layers.experimental.preprocessing.Resizing(112, 112, interpolation='bilinear')(x) + x = keras.layers.Add()([x, mobilenet.get_layer(name="conv_pw_1_relu").output]) + x = FDDepthwiseBlockNoBN(x, 32, block_id=18) + x = keras.layers.experimental.preprocessing.Resizing(224, 224, interpolation='bilinear')(x) + + x = keras.layers.Conv2D(1, 1, padding='same')(x) + x = keras.layers.ReLU(6.)(x) + return 
keras.Model(inputs=input, outputs=x, name="fast_depth") + + # TODO: Fix these, float doesn't work same as pytorch def delta1_metric(y_true, y_pred): maxRatio = tf.maximum(y_pred / y_true, y_true / y_pred)