diff --git a/unsupervised/models.py b/unsupervised/models.py index 4ef2331..fe5f603 100644 --- a/unsupervised/models.py +++ b/unsupervised/models.py @@ -19,7 +19,7 @@ def wrap_mobilenet_nnconv5_for_utrain(model): def res_layer(inputs, out_channels, down_sample=None, stride=1, normalisation=layers.BatchNormalization, - activation=layers.ReLU): + activation=layers.ReLU, name=None): x = layers.Conv2D(out_channels, 3, padding='same', strides=stride)(inputs) x = normalisation()(x) x = activation()(x) @@ -29,12 +29,12 @@ def res_layer(inputs, out_channels, down_sample=None, stride=1, normalisation=la if down_sample is not None: inputs = down_sample x = layers.Add()([x, inputs]) - x = activation()(x) + x = activation(name=name)(x) return x def res_block(inputs, out_channels, num_blocks=1, stride=1, normalisation=layers.BatchNormalization, - activation=layers.ReLU): + activation=layers.ReLU, name=None): down_sample = None if stride != 1 or inputs.shape[-1] != out_channels: down_sample = layers.Conv2D(out_channels, 1, stride, padding='same')(inputs) @@ -42,7 +42,7 @@ def res_block(inputs, out_channels, num_blocks=1, stride=1, normalisation=layers x = res_layer(inputs, out_channels, down_sample, stride, normalisation, activation) for i in range(1, num_blocks): - x = res_layer(x, out_channels, None, 1, normalisation, activation) + x = res_layer(x, out_channels, None, 1, normalisation, activation, name) return x @@ -55,18 +55,35 @@ def resnet_18(shape=(224, 224, 6)): inputs = layers.Input(shape) x = keras.layers.Conv2D(64, 7, 2, padding='same')(inputs) x = layers.BatchNormalization()(x) - x = layers.ReLU()(x) - x = layers.MaxPooling2D(3, 2, 'same')(x) - x = res_block(x, 64, 2) - x = res_block(x, 128, 2, 2) - x = res_block(x, 256, 2, 2) - x = res_block(x, 512, 2, 2) - return keras.Model(inputs=inputs, outputs=x) + skip_1 = layers.ReLU(name="res_1")(x) + x = layers.MaxPooling2D(3, 2, 'same')(skip_1) + skip_2 = res_block(x, 64, 2, name="res_2") + skip_3 = res_block(skip_2, 128, 
2, 2, name="res_3") + skip_4 = res_block(skip_3, 256, 2, 2, name="res_4") + skip_5 = res_block(skip_4, 512, 2, 2, name="res_5") + return keras.Model(inputs=inputs, outputs=[skip_1, skip_2, skip_3, skip_4, skip_5]) +# TODO: Monodepth and SfMLearner both run the pose network on all source images. For Monodepth, it would therefore need +# a 9-channel input (3 images: the target and 2 source images) and would produce two 6-DoF poses def pose_net(shape=(224, 224, 6), encoder=resnet_18): resnet = encoder(shape=shape) + for layer in resnet.layers: + layer.trainable = True + + # Concatenate every skip connection + cat_skips = [layers.ReLU()(layers.Conv2D(256, 1)(encode_output[-1])) for encode_output in resnet.outputs] + cat_skips = layers.Concatenate(1)(cat_skips) + + x = layers.Conv2D(256, 3, padding='same')(cat_skips) + x = layers.ReLU()(x) + + x = layers.Conv2D(256, 3, padding='same') + x = layers.ReLU()(x) + + x = layers.Conv2D(256, 12, 1)(x) + # Decoder pass