diff --git a/fast_depth_functional.py b/fast_depth_functional.py index eee428b..4da5289 100644 --- a/fast_depth_functional.py +++ b/fast_depth_functional.py @@ -5,59 +5,53 @@ Functional version of fastdepth model. Note that this doesn't work at the moment ''' -def _depthwise_conv_block(inputs, - pointwise_conv_filters, - depth_multiplier=1, - strides=(1, 1), - block_id=1): - channel_axis = 1 if keras.backend.image_data_format() == 'channels_first' else -1 - pointwise_conv_filters = int(pointwise_conv_filters) - - if strides == (1, 1): - x = inputs - else: - x = keras.layers.ZeroPadding2D(((0, 1), (0, 1)), name='conv_pad_%d' % block_id)( - inputs) - x = keras.layers.DepthwiseConv2D((3, 3), - padding='same' if strides == (1, 1) else 'valid', - depth_multiplier=depth_multiplier, - strides=strides, - use_bias=False, - name='conv_dw_%d' % block_id)( - x) - x = keras.layers.BatchNormalization( - axis=channel_axis, name='conv_dw_%d_bn' % block_id)( - x) - x = keras.layers.ReLU(6., name='conv_dw_%d_relu' % block_id)(x) - - x = keras.layers.Conv2D( - pointwise_conv_filters, (1, 1), - padding='same', - use_bias=False, - strides=(1, 1), - name='conv_pw_%d' % block_id)( - x) - x = keras.layers.BatchNormalization( - axis=channel_axis, name='conv_pw_%d_bn' % block_id)( - x) +def FDDepthwiseBlock(inputs, + out_channels, + block_id=1): + x = keras.layers.DepthwiseConv2D(5, padding='same')(inputs) + x = keras.layers.BatchNormalization()(x) + x = keras.layers.ReLU(6.)(x) + x = keras.layers.Conv2D(out_channels, 1, padding='same')(x) + x = keras.layers.BatchNormalization()(x) return keras.layers.ReLU(6., name='conv_pw_%d_relu' % block_id)(x) def make_fastdepth_functional(): # This doesn't work, at least right now... - mobilenet = keras.applications.MobileNet(include_top=False) input = keras.layers.Input(shape=(224, 224, 3)) - x = mobilenet(input) - x = _depthwise_conv_block(x, 512, block_id=14) - x = _depthwise_conv_block(x, 256, block_id=15) - x = keras.layers.Add()([x, mobilenet.get_layer('conv_pw_5_relu').output]) - x = _depthwise_conv_block(x, 128, block_id=16) - x = keras.layers.Add()([x, mobilenet.get_layer('conv_pw_3_relu').output]) - x = _depthwise_conv_block(x, 64, block_id=17) - x = keras.layers.Add()([x, mobilenet.get_layer('conv_pw_1_relu').output]) - x = _depthwise_conv_block(x, 32, block_id=18) + x = input + mobilenet = keras.applications.MobileNet(include_top=False, weights=None) + for layer in mobilenet.layers: + x = layer(x) + if layer.name == 'conv_pw_5_relu': + conv5 = x + elif layer.name == 'conv_pw_3_relu': + conv3 = x + elif layer.name == 'conv_pw_1_relu': + conv1 = x - x = keras.layers.Conv2D(1, 1)(x) + x = FDDepthwiseBlock(x, 512, block_id=14) + # TODO: Bilinear interpolation + # x = keras.layers.experimental.preprocessing.Resizing(14, 14) + # Nearest neighbour interpolation, used by fast depth paper + x = keras.layers.experimental.preprocessing.Resizing(14, 14, interpolation='nearest')(x) + x = FDDepthwiseBlock(x, 256, block_id=15) + x = keras.layers.experimental.preprocessing.Resizing(28, 28, interpolation='nearest')(x) + x = keras.layers.Add()([x, conv5]) + x = FDDepthwiseBlock(x, 128, block_id=16) + x = keras.layers.experimental.preprocessing.Resizing(56, 56, interpolation='nearest')(x) + x = keras.layers.Add()([x, conv3]) + x = FDDepthwiseBlock(x, 64, block_id=17) + x = keras.layers.experimental.preprocessing.Resizing(112, 112, interpolation='nearest')(x) + x = keras.layers.Add()([x, conv1]) + x = FDDepthwiseBlock(x, 32, block_id=18) + x = keras.layers.experimental.preprocessing.Resizing(224, 224, interpolation='nearest')(x) + + x = keras.layers.Conv2D(1, 1, padding='same')(x) x = keras.layers.BatchNormalization()(x) - x = keras.layers.ReLU()(x) - return keras.Model(input, x, name="fast_depth") + x = keras.layers.ReLU(6.)(x) + return keras.Model(inputs=input, outputs=x, name="fast_depth") + + +if __name__ == '__main__': + make_fastdepth_functional().summary() diff --git a/main.py b/main.py index 051da5b..62b1885 100644 --- a/main.py +++ b/main.py @@ -1,70 +1,6 @@ -import tensorflow as tf -from tensorflow import keras - import tensorflow_datasets as tfds -class DecodeConv(keras.layers.Layer): - def __init__(self, out_filters, **kwargs): - super().__init__(**kwargs) - # Should be depthwise followed by batchnorm and relu. - self.depthwise = keras.layers.DepthwiseConv2D(5) - self.batch_norm = keras.layers.BatchNormalization() - self.relu = keras.layers.ReLU(6.) - self.pointwise = keras.layers.Conv2D(out_filters, 1) - self.pointwise_bn = keras.layers.BatchNormalization() - self.pointwise_rl = keras.layers.ReLU(6.) - - def call(self, inputs, **kwargs): - inputs = self.depthwise(inputs, **kwargs) - inputs = self.batch_norm(inputs, **kwargs) - inputs = self.relu(inputs, **kwargs) - inputs = self.pointwise(inputs, **kwargs) - inputs = self.pointwise_bn(inputs, **kwargs) - return self.pointwise_rl(inputs, **kwargs) - - -class FastDepth(keras.Model): - def get_config(self): - # TODO: What to put here? - pass - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.mobile_net = keras.applications.MobileNet(include_top=False) - - # TODO: Try keras.layers.SeparableConv2D as well, should do the same thing if relu is used as activation - # It probably doesn't, since - self.decode_conv1 = DecodeConv(512) - self.decode_conv2 = DecodeConv(256) - self.decode_conv3 = DecodeConv(128) - self.decode_conv4 = DecodeConv(64) - self.decode_conv5 = DecodeConv(32) - - self.final_pointwise = keras.layers.Conv2D(1, 1) - self.final_pointwise_bn = keras.layers.BatchNormalization() - self.final_pointwise_relu = keras.layers.ReLU() - - def call(self, inputs, is_training=False, **kwargs): - # Go through mobilenet, then each decode layer, including skip connections using: - # keras.layers.Add() - inputs = self.mobile_net(inputs, is_training=is_training, **kwargs) - - # FastDepth Additive Decoder - inputs = self.decode_conv1(inputs, is_training=is_training, **kwargs) - inputs = self.decode_conv2(inputs, is_training=is_training, **kwargs) - inputs = inputs + self.mobile_net.get_layer('conv_pw_5_relu').output - inputs = self.decode_conv3(inputs, is_training=is_training, **kwargs) - inputs = inputs + self.mobile_net.get_layer('conv_pw_3_relu').output - inputs = self.decode_conv4(inputs, is_training=is_training, **kwargs) - inputs = inputs + self.mobile_net.get_layer('conv_pw_1_relu').output - inputs = self.decode_conv5(inputs, is_training=is_training, **kwargs) - - inputs = self.final_pointwise(inputs, is_training=is_training, **kwargs) - inputs = self.final_pointwise_bn(inputs, is_training=is_training, **kwargs) - return self.final_pointwise_relu(inputs, is_training=is_training, **kwargs) - - def load_nyu(): builder = tfds.builder('nyu_depth_v2') builder.download_and_prepare(download_dir='../nyu')