diff --git a/unsupervised/models.py b/unsupervised/models.py
index fe5f603..7e56d09 100644
--- a/unsupervised/models.py
+++ b/unsupervised/models.py
@@ -1,3 +1,4 @@
+import tensorflow as tf
 import tensorflow.keras as keras
 import tensorflow.keras.layers as layers
 
@@ -55,17 +56,16 @@ def resnet_18(shape=(224, 224, 6)):
     inputs = layers.Input(shape)
     x = keras.layers.Conv2D(64, 7, 2, padding='same')(inputs)
     x = layers.BatchNormalization()(x)
-    skip_1 = layers.ReLU(name="res_1")(x)
-    x = layers.MaxPooling2D(3, 2, 'same')(skip_1)
-    skip_2 = res_block(x, 64, 2, name="res_2")
-    skip_3 = res_block(skip_2, 128, 2, 2, name="res_3")
-    skip_4 = res_block(skip_3, 256, 2, 2, name="res_4")
-    skip_5 = res_block(skip_4, 512, 2, 2, name="res_5")
-    return keras.Model(inputs=inputs, outputs=[skip_1, skip_2, skip_3, skip_4, skip_5])
+    x = layers.ReLU(name="res_1")(x)
+    x = layers.MaxPooling2D(3, 2, 'same')(x)
+    x = res_block(x, 64, 2, name="res_2")
+    x = res_block(x, 128, 2, 2, name="res_3")
+    x = res_block(x, 256, 2, 2, name="res_4")
+    x = res_block(x, 512, 2, 2, name="res_5")
+    # Note: Skips aren't used by pose, only depth
+    return keras.Model(inputs=inputs, outputs=x)
 
-# TODO Monodepth and sfm learner both solve the posenet on all source images. So for the case of monodepth, it would need
-# 9 as the input (3 images - target and 2 source images) and would produce 2 6DOF poses
 
 
 def pose_net(shape=(224, 224, 6), encoder=resnet_18):
     resnet = encoder(shape=shape)
@@ -73,20 +73,23 @@ def pose_net(shape=(224, 224, 6), encoder=resnet_18):
         layer.trainable = True
 
-    # Concatenate every skip connection
-    cat_skips = [layers.ReLU()(layers.Conv2D(256, 1)(encode_output[-1])) for encode_output in resnet.outputs]
-    cat_skips = layers.Concatenate(1)(cat_skips)
-
-    x = layers.Conv2D(256, 3, padding='same')(cat_skips)
+    # Note: Monodepth only uses output of resnet
+    x = layers.Conv2D(256, 1)(resnet.output)
     x = layers.ReLU()(x)
-    x = layers.Conv2D(256, 3, padding='same')
+    x = layers.Conv2D(256, 3, padding='same')(x)
     x = layers.ReLU()(x)
-    x = layers.Conv2D(256, 12, 1)(x)
+    x = layers.Conv2D(256, 3, padding='same')(x)
+    x = layers.ReLU()(x)
 
-    # Decoder
-    pass
+    # The magic pose step
+    x = layers.Conv2D(6, 1, 1)(x)
+    x = tf.reduce_mean(x, [1, 2])
+    # Previous works scale by 0.01 to facilitate training
+    x = 0.01 * layers.Reshape([6])(x)
+    return keras.Model(resnet.input, x)
 
 
 if __name__ == '__main__':
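
The encoder relies on res_block, which is defined elsewhere in unsupervised/models.py and not shown in this diff. Only its signature is implied by the call sites above (x, filters, number of basic units, optional stride, name); the body below is a minimal sketch under that assumption, not the repo's actual implementation:

import tensorflow.keras as keras
import tensorflow.keras.layers as layers

def res_block(x, filters, blocks, stride=1, name=None):
    """Hypothetical stack of `blocks` basic residual units; the first unit downsamples by `stride`."""
    for i in range(blocks):
        s = stride if i == 0 else 1
        shortcut = x
        # Project the shortcut when the spatial size or channel count changes
        if s != 1 or x.shape[-1] != filters:
            shortcut = layers.Conv2D(filters, 1, s, padding='same')(x)
            shortcut = layers.BatchNormalization()(shortcut)
        y = layers.Conv2D(filters, 3, s, padding='same')(x)
        y = layers.BatchNormalization()(y)
        y = layers.ReLU()(y)
        y = layers.Conv2D(filters, 3, 1, padding='same')(y)
        y = layers.BatchNormalization()(y)
        # Name only the final activation so it matches the "res_N" outputs above
        x = layers.ReLU(name=name if i == blocks - 1 else None)(layers.Add()([y, shortcut]))
    return x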
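
The pose head follows the SfMLearner-style regression referenced in the removed TODO: a 1x1 convolution projects to 6 channels, the spatial mean collapses the feature map to one 6-DoF vector (3 rotation + 3 translation parameters) per stacked image pair, and the result is scaled by 0.01 as the in-code comment notes. A quick shape sanity check (a usage sketch, assuming the file is importable as unsupervised.models):

import numpy as np
from unsupervised.models import pose_net, resnet_18  # assumed import path

model = pose_net(shape=(224, 224, 6), encoder=resnet_18)
pair = np.zeros((4, 224, 224, 6), dtype=np.float32)  # batch of 4 target/source pairs
pose = model.predict(pair)
print(pose.shape)  # expected: (4, 6), one 6-DoF vector per pair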