From 2372b906df66cb1c4a2192f3e2b407c0e4a3a029 Mon Sep 17 00:00:00 2001 From: Piv <18462828+Piv200@users.noreply.github.com> Date: Sun, 1 Aug 2021 10:44:33 +0930 Subject: [PATCH] Add resnet18 --- unsupervised/models.py | 78 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 78 insertions(+) create mode 100644 unsupervised/models.py diff --git a/unsupervised/models.py b/unsupervised/models.py new file mode 100644 index 0000000..4ef2331 --- /dev/null +++ b/unsupervised/models.py @@ -0,0 +1,78 @@ +import tensorflow.keras as keras +import tensorflow.keras.layers as layers + + +def wrap_mobilenet_nnconv5_for_utrain(model): + """ + Wraps a fast depth model for use in unsupervised training. + + This just exposes the lower disparity layers as outputs, so they can be used to train at different scales/image + resolutions. + :param model: + :return: + """ + input = model.input + disp_1 = model.get_layer('conv_pw_%d_relu' % 15).output + disp_2 = model.get_layer('conv_pw_%d_relu' % 16).output + disp_3 = model.get_layer('conv_pw_%d_relu' % 17).output + return keras.Model(input, outputs=[disp_1, disp_2, disp_3, model.output]) + + +def res_layer(inputs, out_channels, down_sample=None, stride=1, normalisation=layers.BatchNormalization, + activation=layers.ReLU): + x = layers.Conv2D(out_channels, 3, padding='same', strides=stride)(inputs) + x = normalisation()(x) + x = activation()(x) + x = layers.Conv2D(out_channels, 3, padding='same', strides=1)(x) + x = normalisation()(x) + # Residual skip connection. Downsample inputs if necessary + if down_sample is not None: + inputs = down_sample + x = layers.Add()([x, inputs]) + x = activation()(x) + return x + + +def res_block(inputs, out_channels, num_blocks=1, stride=1, normalisation=layers.BatchNormalization, + activation=layers.ReLU): + down_sample = None + if stride != 1 or inputs.shape[-1] != out_channels: + down_sample = layers.Conv2D(out_channels, 1, stride, padding='same')(inputs) + down_sample = normalisation()(down_sample) + + x = res_layer(inputs, out_channels, down_sample, stride, normalisation, activation) + for i in range(1, num_blocks): + x = res_layer(x, out_channels, None, 1, normalisation, activation) + return x + + +def resnet_18(shape=(224, 224, 6)): + """ + Build the ResNet 18 network (encoder for the pose network) + :param shape: Input shape. Note this should support 2 images for the pose net, so 6 channels in that case + :return: Resnet encoder (ResNet18) + """ + inputs = layers.Input(shape) + x = keras.layers.Conv2D(64, 7, 2, padding='same')(inputs) + x = layers.BatchNormalization()(x) + x = layers.ReLU()(x) + x = layers.MaxPooling2D(3, 2, 'same')(x) + x = res_block(x, 64, 2) + x = res_block(x, 128, 2, 2) + x = res_block(x, 256, 2, 2) + x = res_block(x, 512, 2, 2) + return keras.Model(inputs=inputs, outputs=x) + + +def pose_net(shape=(224, 224, 6), encoder=resnet_18): + resnet = encoder(shape=shape) + + # Decoder + + pass + + +if __name__ == '__main__': + # import fast_depth_functional as fd + # wrap_mobilenet_nnconv5_for_utrain(fd.mobilenet_nnconv5()).summary() + pose_net()