Finish off pose net
This commit is contained in:
@@ -1,3 +1,4 @@
|
|||||||
|
import tensorflow as tf
|
||||||
import tensorflow.keras as keras
|
import tensorflow.keras as keras
|
||||||
import tensorflow.keras.layers as layers
|
import tensorflow.keras.layers as layers
|
||||||
|
|
||||||
@@ -55,17 +56,16 @@ def resnet_18(shape=(224, 224, 6)):
|
|||||||
inputs = layers.Input(shape)
|
inputs = layers.Input(shape)
|
||||||
x = keras.layers.Conv2D(64, 7, 2, padding='same')(inputs)
|
x = keras.layers.Conv2D(64, 7, 2, padding='same')(inputs)
|
||||||
x = layers.BatchNormalization()(x)
|
x = layers.BatchNormalization()(x)
|
||||||
skip_1 = layers.ReLU(name="res_1")(x)
|
x = layers.ReLU(name="res_1")(x)
|
||||||
x = layers.MaxPooling2D(3, 2, 'same')(skip_1)
|
x = layers.MaxPooling2D(3, 2, 'same')(x)
|
||||||
skip_2 = res_block(x, 64, 2, name="res_2")
|
x = res_block(x, 64, 2, name="res_2")
|
||||||
skip_3 = res_block(skip_2, 128, 2, 2, name="res_3")
|
x = res_block(x, 128, 2, 2, name="res_3")
|
||||||
skip_4 = res_block(skip_3, 256, 2, 2, name="res_4")
|
x = res_block(x, 256, 2, 2, name="res_4")
|
||||||
skip_5 = res_block(skip_4, 512, 2, 2, name="res_5")
|
x = res_block(x, 512, 2, 2, name="res_5")
|
||||||
return keras.Model(inputs=inputs, outputs=[skip_1, skip_2, skip_3, skip_4, skip_5])
|
# Note: Skips aren't used by pose, only depth
|
||||||
|
return keras.Model(inputs=inputs, outputs=x)
|
||||||
|
|
||||||
|
|
||||||
# TODO: Monodepth and SfMLearner both run the pose net on all source images, so for Monodepth the network would need
|
|
||||||
# a 9-channel input (3 images: the target and 2 source images) and would produce two 6-DoF poses
|
|
||||||
def pose_net(shape=(224, 224, 6), encoder=resnet_18):
|
def pose_net(shape=(224, 224, 6), encoder=resnet_18):
|
||||||
resnet = encoder(shape=shape)
|
resnet = encoder(shape=shape)
|
||||||
|
|
||||||
@@ -73,20 +73,23 @@ def pose_net(shape=(224, 224, 6), encoder=resnet_18):
|
|||||||
layer.trainable = True
|
layer.trainable = True
|
||||||
|
|
||||||
# Concatenate every skip connection
|
# Concatenate every skip connection
|
||||||
cat_skips = [layers.ReLU()(layers.Conv2D(256, 1)(encode_output[-1])) for encode_output in resnet.outputs]
|
# Note: Monodepth only uses the final output of the ResNet encoder
|
||||||
cat_skips = layers.Concatenate(1)(cat_skips)
|
x = layers.Conv2D(256, 1)(resnet.output)
|
||||||
|
|
||||||
x = layers.Conv2D(256, 3, padding='same')(cat_skips)
|
|
||||||
x = layers.ReLU()(x)
|
x = layers.ReLU()(x)
|
||||||
|
|
||||||
x = layers.Conv2D(256, 3, padding='same')
|
x = layers.Conv2D(256, 3, padding='same')(x)
|
||||||
x = layers.ReLU()(x)
|
x = layers.ReLU()(x)
|
||||||
|
|
||||||
x = layers.Conv2D(256, 12, 1)(x)
|
x = layers.Conv2D(256, 3, padding='same')(x)
|
||||||
|
x = layers.ReLU()(x)
|
||||||
|
|
||||||
# Decoder
|
# Pose head: a 1x1 convolution down to 6 channels, one per pose parameter
|
||||||
|
x = layers.Conv2D(6, 1, 1)(x)
|
||||||
|
|
||||||
pass
|
x = tf.reduce_mean(x, [1, 2])
|
||||||
|
# Prior work scales the pose output by 0.01 to facilitate training
|
||||||
|
x = 0.01 * layers.Reshape([6])(x)
|
||||||
|
return keras.Model(resnet.input, x)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|||||||
Reference in New Issue
Block a user