Functional models are much easier to work with, and I don't need any of the advanced features that would require model subclassing.
58 lines
2.2 KiB
Python
58 lines
2.2 KiB
Python
import tensorflow.keras as keras
|
|
|
|
'''
|
|
Keras Functional-API version of the FastDepth model. Note: this implementation did not work at the time this note was written.
|
|
'''
|
|
|
|
|
|
def FDDepthwiseBlock(inputs,
                     out_channels,
                     block_id=1,
                     kernel_size=5):
    """Apply one FastDepth decoder stage (depthwise-separable convolution).

    The stage is a ``kernel_size`` x ``kernel_size`` depthwise convolution
    followed by a 1x1 pointwise convolution to ``out_channels``. Each
    convolution is followed by BatchNormalization and a ReLU capped at 6.
    Both convolutions use stride 1 and ``padding='same'``, so the spatial
    size is unchanged.

    Layers are named with the same scheme as Keras' MobileNet blocks
    (``conv_dw_<id>``, ``conv_dw_<id>_bn``, ``conv_dw_<id>_relu``,
    ``conv_pw_<id>``, ``conv_pw_<id>_bn``, ``conv_pw_<id>_relu``). Layer
    names must be unique within a model, so ``block_id`` must not be reused
    and must not collide with the ids of a MobileNet encoder in the same
    model.

    Args:
        inputs: feature tensor to transform.
        out_channels: number of filters of the pointwise (1x1) convolution.
        block_id: integer used to build the layer names.
        kernel_size: depthwise kernel size (default 5, as before).

    Returns:
        The output tensor of the final ReLU6 layer.
    """
    dw_name = 'conv_dw_%d' % block_id
    pw_name = 'conv_pw_%d' % block_id

    # use_bias=False: the BatchNormalization right after each convolution
    # removes the mean and adds its own learned offset, so a conv bias is
    # redundant (Keras' MobileNet blocks do the same).
    x = keras.layers.DepthwiseConv2D(kernel_size, padding='same',
                                     use_bias=False, name=dw_name)(inputs)
    x = keras.layers.BatchNormalization(name=dw_name + '_bn')(x)
    x = keras.layers.ReLU(6., name=dw_name + '_relu')(x)

    x = keras.layers.Conv2D(out_channels, 1, padding='same',
                            use_bias=False, name=pw_name)(x)
    x = keras.layers.BatchNormalization(name=pw_name + '_bn')(x)
    return keras.layers.ReLU(6., name=pw_name + '_relu')(x)
|
|
|
|
|
|
def make_fastdepth_functional(input_shape=(224, 224, 3), interpolation='nearest'):
    """Build the FastDepth network with the Keras Functional API.

    Encoder: ``keras.applications.MobileNet`` without its classification
    head and with random initial weights (``weights=None``).

    Decoder: five ``FDDepthwiseBlock`` stages (512, 256, 128, 64 and 32
    channels), each followed by a 2x upsample. After the 2nd, 3rd and 4th
    upsample, the encoder outputs ``conv_pw_5_relu``, ``conv_pw_3_relu`` and
    ``conv_pw_1_relu`` are added as skip connections. A final 1x1
    convolution, BatchNormalization and ReLU6 produce a single output channel.

    Args:
        input_shape: ``(height, width, channels)`` of the input, channels
            last. Height and width must be multiples of 32 so that the five
            2x upsamples exactly undo the encoder's downsampling and the skip
            connections have matching sizes. Defaults to ``(224, 224, 3)``.
        interpolation: upsampling method passed to
            ``keras.layers.UpSampling2D``. ``'nearest'`` (the default) is the
            method used by the FastDepth paper; ``'bilinear'`` is also
            accepted.

    Returns:
        A ``keras.Model`` named ``"fast_depth"`` whose output has the input's
        height and width and one channel.

    Raises:
        ValueError: if a known height or width is not a multiple of 32.
    """
    for dim in input_shape[:2]:
        if dim is not None and dim % 32:
            raise ValueError(
                'input height and width must be multiples of 32, got %r'
                % (tuple(input_shape),))

    inputs = keras.layers.Input(shape=input_shape)

    # Build the encoder directly on our input tensor so its intermediate
    # outputs can be looked up by layer name. This avoids replaying a
    # separately built MobileNet layer by layer, which also applied its own
    # InputLayer to a tensor and assumed the layer list is one straight chain.
    mobilenet = keras.applications.MobileNet(
        input_tensor=inputs, include_top=False, weights=None)

    # Skip-connection features. For a 224x224 input they are 112x112x64,
    # 56x56x128 and 28x28x256, matching the decoder tensors they are added to.
    conv1 = mobilenet.get_layer('conv_pw_1_relu').output
    conv3 = mobilenet.get_layer('conv_pw_3_relu').output
    conv5 = mobilenet.get_layer('conv_pw_5_relu').output
    # Encoder output at 1/32 of the input resolution (7x7 for 224x224).
    x = mobilenet.output

    def upsample(tensor):
        # Double height and width. Unlike resizing to hard-coded sizes, this
        # follows the actual input resolution.
        return keras.layers.UpSampling2D(
            size=2, interpolation=interpolation)(tensor)

    x = FDDepthwiseBlock(x, 512, block_id=14)
    x = upsample(x)
    x = FDDepthwiseBlock(x, 256, block_id=15)
    x = upsample(x)
    x = keras.layers.Add()([x, conv5])
    x = FDDepthwiseBlock(x, 128, block_id=16)
    x = upsample(x)
    x = keras.layers.Add()([x, conv3])
    x = FDDepthwiseBlock(x, 64, block_id=17)
    x = upsample(x)
    x = keras.layers.Add()([x, conv1])
    x = FDDepthwiseBlock(x, 32, block_id=18)
    x = upsample(x)

    # 1x1 projection to a single channel. No conv bias is used because the
    # following BatchNormalization supplies its own learned offset.
    x = keras.layers.Conv2D(1, 1, padding='same', use_bias=False)(x)
    x = keras.layers.BatchNormalization()(x)
    # NOTE(review): ReLU6 clips the prediction to [0, 6]; confirm this covers
    # the range of the depth targets.
    x = keras.layers.ReLU(6.)(x)

    return keras.Model(inputs=inputs, outputs=x, name="fast_depth")
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: build the model and print its layer summary.
    model = make_fastdepth_functional()
    model.summary()
|