Start adding pose decoder
This commit is contained in:
@@ -19,7 +19,7 @@ def wrap_mobilenet_nnconv5_for_utrain(model):
|
|||||||
|
|
||||||
|
|
||||||
def res_layer(inputs, out_channels, down_sample=None, stride=1, normalisation=layers.BatchNormalization,
|
def res_layer(inputs, out_channels, down_sample=None, stride=1, normalisation=layers.BatchNormalization,
|
||||||
activation=layers.ReLU):
|
activation=layers.ReLU, name=None):
|
||||||
x = layers.Conv2D(out_channels, 3, padding='same', strides=stride)(inputs)
|
x = layers.Conv2D(out_channels, 3, padding='same', strides=stride)(inputs)
|
||||||
x = normalisation()(x)
|
x = normalisation()(x)
|
||||||
x = activation()(x)
|
x = activation()(x)
|
||||||
@@ -29,12 +29,12 @@ def res_layer(inputs, out_channels, down_sample=None, stride=1, normalisation=la
|
|||||||
if down_sample is not None:
|
if down_sample is not None:
|
||||||
inputs = down_sample
|
inputs = down_sample
|
||||||
x = layers.Add()([x, inputs])
|
x = layers.Add()([x, inputs])
|
||||||
x = activation()(x)
|
x = activation(name=name)(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
def res_block(inputs, out_channels, num_blocks=1, stride=1, normalisation=layers.BatchNormalization,
|
def res_block(inputs, out_channels, num_blocks=1, stride=1, normalisation=layers.BatchNormalization,
|
||||||
activation=layers.ReLU):
|
activation=layers.ReLU, name=None):
|
||||||
down_sample = None
|
down_sample = None
|
||||||
if stride != 1 or inputs.shape[-1] != out_channels:
|
if stride != 1 or inputs.shape[-1] != out_channels:
|
||||||
down_sample = layers.Conv2D(out_channels, 1, stride, padding='same')(inputs)
|
down_sample = layers.Conv2D(out_channels, 1, stride, padding='same')(inputs)
|
||||||
@@ -42,7 +42,7 @@ def res_block(inputs, out_channels, num_blocks=1, stride=1, normalisation=layers
|
|||||||
|
|
||||||
x = res_layer(inputs, out_channels, down_sample, stride, normalisation, activation)
|
x = res_layer(inputs, out_channels, down_sample, stride, normalisation, activation)
|
||||||
for i in range(1, num_blocks):
|
for i in range(1, num_blocks):
|
||||||
x = res_layer(x, out_channels, None, 1, normalisation, activation)
|
x = res_layer(x, out_channels, None, 1, normalisation, activation, name)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
@@ -55,18 +55,35 @@ def resnet_18(shape=(224, 224, 6)):
|
|||||||
inputs = layers.Input(shape)
|
inputs = layers.Input(shape)
|
||||||
x = keras.layers.Conv2D(64, 7, 2, padding='same')(inputs)
|
x = keras.layers.Conv2D(64, 7, 2, padding='same')(inputs)
|
||||||
x = layers.BatchNormalization()(x)
|
x = layers.BatchNormalization()(x)
|
||||||
x = layers.ReLU()(x)
|
skip_1 = layers.ReLU(name="res_1")(x)
|
||||||
x = layers.MaxPooling2D(3, 2, 'same')(x)
|
x = layers.MaxPooling2D(3, 2, 'same')(skip_1)
|
||||||
x = res_block(x, 64, 2)
|
skip_2 = res_block(x, 64, 2, name="res_2")
|
||||||
x = res_block(x, 128, 2, 2)
|
skip_3 = res_block(skip_2, 128, 2, 2, name="res_3")
|
||||||
x = res_block(x, 256, 2, 2)
|
skip_4 = res_block(skip_3, 256, 2, 2, name="res_4")
|
||||||
x = res_block(x, 512, 2, 2)
|
skip_5 = res_block(skip_4, 512, 2, 2, name="res_5")
|
||||||
return keras.Model(inputs=inputs, outputs=x)
|
return keras.Model(inputs=inputs, outputs=[skip_1, skip_2, skip_3, skip_4, skip_5])
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: Monodepth and SfMLearner both run the pose network on all source images at once, so for Monodepth the network would need
|
||||||
|
# 9 input channels (3 images: the target and 2 source images) and would produce two 6-DoF poses
|
||||||
def pose_net(shape=(224, 224, 6), encoder=resnet_18):
    """Start building a pose network on top of a multi-output encoder (work in progress).

    The encoder is instantiated, its deepest feature map is squeezed to 256
    channels, passed through two 3x3 convolutions, and projected by a 1x1
    convolution to 12 channels (two 6-DoF poses, per the TODO above this
    function in the file).

    Args:
        shape: Input shape forwarded to ``encoder`` as the ``shape`` keyword.
        encoder: Callable accepting ``shape=...`` and returning a Keras model
            that exposes ``.layers`` and a list of tensors in ``.outputs``
            (``resnet_18`` by default).

    Returns:
        None. The decoder is not finished, so no model is assembled or
        returned yet; this matches the original behaviour.
    """
    resnet = encoder(shape=shape)

    # Explicitly mark every encoder layer as trainable so the encoder is
    # optimised together with the pose head.
    for layer in resnet.layers:
        layer.trainable = True

    # Squeeze the deepest encoder output to 256 channels with a 1x1 conv.
    # NOTE(review): the original iterated over resnet.outputs, indexed each
    # tensor with [-1] (which slices the batch axis, not the list of outputs)
    # and concatenated the results on axis 1. The encoder outputs have
    # different spatial sizes, so they cannot be concatenated directly; only
    # the last (deepest) output is used here. If several encoders/feature sets
    # are added later, concatenate their squeezed outputs on the channel axis
    # (axis=-1 for channels-last). Confirm this matches the intended design.
    squeezed = layers.Conv2D(256, 1)(resnet.outputs[-1])
    squeezed = layers.ReLU()(squeezed)

    x = layers.Conv2D(256, 3, padding='same')(squeezed)
    x = layers.ReLU()(x)

    # Fix: the original created this layer but never applied it to x, so the
    # following ReLU was called on a layer object instead of a tensor.
    x = layers.Conv2D(256, 3, padding='same')(x)
    x = layers.ReLU()(x)

    # Fix: Keras Conv2D takes (filters, kernel_size, strides). The original
    # Conv2D(256, 12, 1) asked for 256 filters with a 12x12 kernel; the intent
    # appears to be a 1x1 conv producing 12 values (2 poses x 6 DoF).
    # NOTE(review): confirm 12 is the desired number of pose outputs.
    x = layers.Conv2D(12, 1)(x)

    # Decoder
    # TODO: reduce x over its spatial dimensions into pose vectors and wrap
    # the graph in a keras.Model; until then nothing is returned.
    return None
|
||||||
|
|||||||
Reference in New Issue
Block a user