"""Functional (Keras) version of the FastDepth model.

A MobileNet (v1) encoder is followed by a decoder of five depthwise-separable
conv blocks, each followed by 2x nearest-neighbour upsampling, with additive
skip connections from the encoder layers ``conv_pw_5_relu``,
``conv_pw_3_relu`` and ``conv_pw_1_relu``. A final 1x1 conv + BN + ReLU
produces a single-channel output map.
"""
import tensorflow.keras as keras


def _depthwise_conv_block(inputs, pointwise_conv_filters, depth_multiplier=1,
                          strides=(1, 1), block_id=1):
    """Apply a MobileNet-style depthwise-separable conv block.

    The block is: optional zero-padding (strided case only), a 3x3 depthwise
    conv, BatchNorm, ReLU6, a 1x1 pointwise conv, BatchNorm, and ReLU6.

    Args:
        inputs: Input tensor.
        pointwise_conv_filters: Number of output channels of the 1x1 conv
            (cast to ``int``).
        depth_multiplier: Depthwise channel multiplier.
        strides: Depthwise conv strides. With ``(1, 1)`` the depthwise conv
            uses ``'same'`` padding. Otherwise the input is first padded by
            one row/column on the bottom/right and ``'valid'`` padding is used.
        block_id: Integer used to build unique layer names
            (``conv_dw_<id>``, ``conv_pw_<id>``, ...).

    Returns:
        Output tensor of the final ReLU6 layer.
    """
    channel_axis = 1 if keras.backend.image_data_format() == 'channels_first' else -1
    pointwise_conv_filters = int(pointwise_conv_filters)

    if strides == (1, 1):
        x = inputs
    else:
        x = keras.layers.ZeroPadding2D(((0, 1), (0, 1)),
                                       name='conv_pad_%d' % block_id)(inputs)
    x = keras.layers.DepthwiseConv2D((3, 3),
                                     padding='same' if strides == (1, 1) else 'valid',
                                     depth_multiplier=depth_multiplier,
                                     strides=strides,
                                     use_bias=False,
                                     name='conv_dw_%d' % block_id)(x)
    x = keras.layers.BatchNormalization(
        axis=channel_axis, name='conv_dw_%d_bn' % block_id)(x)
    x = keras.layers.ReLU(6., name='conv_dw_%d_relu' % block_id)(x)

    x = keras.layers.Conv2D(pointwise_conv_filters, (1, 1),
                            padding='same',
                            use_bias=False,
                            strides=(1, 1),
                            name='conv_pw_%d' % block_id)(x)
    x = keras.layers.BatchNormalization(
        axis=channel_axis, name='conv_pw_%d_bn' % block_id)(x)
    return keras.layers.ReLU(6., name='conv_pw_%d_relu' % block_id)(x)


def make_fastdepth_functional(input_shape=(224, 224, 3), weights='imagenet'):
    """Build the FastDepth model with the Keras functional API.

    Args:
        input_shape: Shape of one input image, ``(height, width, channels)``
            in the default ``channels_last`` layout. Height and width should
            be divisible by 32 so that each upsampled decoder stage matches
            the spatial size of the encoder skip it is added to.
        weights: Passed through to ``keras.applications.MobileNet``
            (e.g. ``'imagenet'`` or ``None``).

    Returns:
        A ``keras.Model`` named ``"fast_depth"`` whose output is a
        single-channel map at the input's spatial resolution.
    """
    mobilenet = keras.applications.MobileNet(
        input_shape=input_shape, include_top=False, weights=weights)

    # Build the decoder on MobileNet's own graph instead of calling the model
    # as a layer on a fresh Input. Calling it as a layer left the
    # `get_layer(...).output` skip tensors disconnected from the model input,
    # which made `keras.Model(...)` fail.
    def skip(layer_name):
        return mobilenet.get_layer(layer_name).output

    # Encoder output is at 1/32 of the input resolution. Each decoder stage
    # upsamples by 2 so its features can be added to the next skip connection.
    x = mobilenet.output
    x = _depthwise_conv_block(x, 512, block_id=14)
    x = keras.layers.UpSampling2D(2, name='upsample_14')(x)   # 1/16
    x = _depthwise_conv_block(x, 256, block_id=15)
    x = keras.layers.UpSampling2D(2, name='upsample_15')(x)   # 1/8
    x = keras.layers.Add()([x, skip('conv_pw_5_relu')])
    x = _depthwise_conv_block(x, 128, block_id=16)
    x = keras.layers.UpSampling2D(2, name='upsample_16')(x)   # 1/4
    x = keras.layers.Add()([x, skip('conv_pw_3_relu')])
    x = _depthwise_conv_block(x, 64, block_id=17)
    x = keras.layers.UpSampling2D(2, name='upsample_17')(x)   # 1/2
    x = keras.layers.Add()([x, skip('conv_pw_1_relu')])
    x = _depthwise_conv_block(x, 32, block_id=18)
    x = keras.layers.UpSampling2D(2, name='upsample_18')(x)   # full resolution

    # Project to a single output channel.
    x = keras.layers.Conv2D(1, 1)(x)
    x = keras.layers.BatchNormalization()(x)
    x = keras.layers.ReLU()(x)
    return keras.Model(mobilenet.input, x, name="fast_depth")