Files
fast-depth-tf/unsupervised/models.py
2021-08-03 20:25:19 +09:30

96 lines
3.5 KiB
Python

import tensorflow.keras as keras
import tensorflow.keras.layers as layers
def wrap_mobilenet_nnconv5_for_utrain(model):
    """
    Wraps a fast depth model for use in unsupervised training.

    This just exposes the lower disparity layers as outputs, so they can be used to train at different
    scales/image resolutions.

    :param model: Keras model whose graph contains layers named ``conv_pw_15_relu``, ``conv_pw_16_relu``
        and ``conv_pw_17_relu`` (presumably the decoder stages of ``mobilenet_nnconv5`` - confirm).
    :return: A new ``keras.Model`` built on ``model``'s input and layers, with outputs
        ``[conv_pw_15_relu, conv_pw_16_relu, conv_pw_17_relu, model.output]`` in that order.
    :raises ValueError: If one of the named layers does not exist (raised by ``Model.get_layer``).
    """
    # Named ``inputs`` rather than ``input`` to avoid shadowing the builtin.
    inputs = model.input
    disp_1 = model.get_layer('conv_pw_%d_relu' % 15).output
    disp_2 = model.get_layer('conv_pw_%d_relu' % 16).output
    disp_3 = model.get_layer('conv_pw_%d_relu' % 17).output
    return keras.Model(inputs, outputs=[disp_1, disp_2, disp_3, model.output])
def res_layer(inputs, out_channels, down_sample=None, stride=1, normalisation=layers.BatchNormalization,
              activation=layers.ReLU, name=None):
    """
    Build one basic (two-convolution) residual layer.

    Main branch: 3x3 conv (``stride``) -> normalisation -> activation -> 3x3 conv (stride 1) -> normalisation.
    The branch is added to a shortcut and passed through a final activation layer called ``name``.

    :param inputs: Input tensor.
    :param out_channels: Number of filters for both 3x3 convolutions.
    :param down_sample: Optional pre-computed shortcut tensor. If None, ``inputs`` is used as the shortcut,
        so the caller must ensure its shape matches the branch output (``layers.Add`` requires equal shapes).
    :param stride: Stride of the first convolution.
    :param normalisation: Layer class instantiated (with no arguments) after each convolution.
    :param activation: Layer class used for the inner and the final activation.
    :param name: Name given to the final activation layer (None lets Keras pick one).
    :return: Output tensor of the residual layer.
    """
    # Choosing the shortcut creates no layers, so layer creation order (and auto-naming) is unchanged.
    shortcut = inputs if down_sample is None else down_sample

    branch = layers.Conv2D(out_channels, 3, padding='same', strides=stride)(inputs)
    branch = normalisation()(branch)
    branch = activation()(branch)
    branch = layers.Conv2D(out_channels, 3, padding='same', strides=1)(branch)
    branch = normalisation()(branch)

    merged = layers.Add()([branch, shortcut])
    return activation(name=name)(merged)
def res_block(inputs, out_channels, num_blocks=1, stride=1, normalisation=layers.BatchNormalization,
              activation=layers.ReLU, name=None):
    """
    Stack ``num_blocks`` residual layers into one ResNet stage.

    The first layer applies ``stride``. When the stride or the channel count changes, it uses a
    1x1-conv + normalisation projection shortcut. At least one layer is always built.

    :param inputs: Input tensor (channels last).
    :param out_channels: Number of filters for every layer in the stage.
    :param num_blocks: Number of residual layers to stack.
    :param stride: Stride of the first layer (and of the projection shortcut).
    :param normalisation: Layer class used after each convolution.
    :param activation: Layer class used for activations.
    :param name: Name given to the final activation of the LAST layer only, so the stage output can be
        retrieved with ``model.get_layer(name)``. Keras layer names must be unique within a model, so
        the name is never applied to more than one layer.
    :return: Output tensor of the stage.
    """
    down_sample = None
    if stride != 1 or inputs.shape[-1] != out_channels:
        # Projection shortcut so the residual Add sees matching shapes.
        down_sample = layers.Conv2D(out_channels, 1, stride, padding='same')(inputs)
        down_sample = normalisation()(down_sample)

    # The first layer is also the last one when only a single layer is requested.
    first_name = name if num_blocks <= 1 else None
    x = res_layer(inputs, out_channels, down_sample, stride, normalisation, activation, first_name)
    for i in range(1, num_blocks):
        layer_name = name if i == num_blocks - 1 else None
        x = res_layer(x, out_channels, None, 1, normalisation, activation, layer_name)
    return x
def resnet_18(shape=(224, 224, 6)):
    """
    Build the ResNet-18 encoder used by the pose network.

    :param shape: Input shape ``(height, width, channels)``. For the pose net two images are stacked,
        giving 6 channels.
    :return: Keras model with five outputs, in order: the stem activation ``res_1`` (64 channels,
        output stride 2) and the residual stages ``res_2`` .. ``res_5`` (64/128/256/512 channels,
        output strides 4/8/16/32).
    """
    image_stack = layers.Input(shape)

    # Stem: 7x7 stride-2 conv -> BN -> ReLU, then a 3x3 stride-2 max-pool.
    stem = keras.layers.Conv2D(64, 7, 2, padding='same')(image_stack)
    stem = layers.BatchNormalization()(stem)
    features = [layers.ReLU(name="res_1")(stem)]
    pooled = layers.MaxPooling2D(3, 2, 'same')(features[0])

    # Four stages of two residual layers each; every stage after the first halves the resolution.
    features.append(res_block(pooled, 64, 2, name="res_2"))
    for stage_name, channels in (("res_3", 128), ("res_4", 256), ("res_5", 512)):
        features.append(res_block(features[-1], channels, 2, 2, name=stage_name))

    return keras.Model(inputs=image_stack, outputs=features)
# TODO: Monodepth and SfMLearner both run the pose network on all source images at once. For Monodepth this
# would mean a 9-channel input (3 images: the target plus 2 source images) producing two 6-DoF poses.
def pose_net(shape=(224, 224, 6), encoder=resnet_18, num_frames=2):
    """
    Build the pose network: a ResNet encoder followed by a Monodepth2-style pose decoder.

    :param shape: Input shape ``(height, width, channels)``. The frames are stacked along the channel
        axis (6 channels for two RGB images).
    :param encoder: Callable ``encoder(shape=...)`` returning a Keras model whose ``outputs`` are the
        encoder feature maps ordered shallow to deep, as :func:`resnet_18` does.
    :param num_frames: Number of 6-DoF poses to predict. The default of 2 gives a 12-channel
        (2 x 6) output convolution.
    :return: Keras model mapping the stacked frames to a ``(batch, num_frames, 6)`` pose tensor.
        Presumably each 6-vector is axis-angle (3) + translation (3), as in Monodepth2 - confirm
        against the loss that consumes it.
    """
    resnet = encoder(shape=shape)
    # Train the encoder jointly with the decoder (undoes any freezing done inside ``encoder``).
    for layer in resnet.layers:
        layer.trainable = True

    # Squeeze the deepest encoder feature map to 256 channels. The skip connections have different
    # spatial resolutions and cannot be concatenated directly, and Monodepth2's pose decoder likewise
    # consumes only the last encoder feature map. With a single encoder pass there is only one map, so
    # no Concatenate layer is needed.
    x = layers.Conv2D(256, 1)(resnet.outputs[-1])
    x = layers.ReLU()(x)

    x = layers.Conv2D(256, 3, padding='same')(x)
    x = layers.ReLU()(x)
    x = layers.Conv2D(256, 3, padding='same')(x)
    x = layers.ReLU()(x)
    # 1x1 conv to one 6-vector per predicted frame. Keras' argument order is (filters, kernel_size).
    x = layers.Conv2D(6 * num_frames, 1)(x)

    # Average over the spatial grid, then scale by 0.01 (as Monodepth2 does) so initial poses are small.
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Lambda(lambda t: 0.01 * t)(x)
    poses = layers.Reshape((num_frames, 6))(x)
    return keras.Model(inputs=resnet.input, outputs=poses)
if __name__ == '__main__':
    # Smoke test: building the pose network surfaces any graph-construction errors.
    # To inspect the depth-model wrapper instead, run:
    #     import fast_depth_functional as fd
    #     wrap_mobilenet_nnconv5_for_utrain(fd.mobilenet_nnconv5()).summary()
    pose_net(shape=(224, 224, 6))