Add resnet18
This commit is contained in:
78
unsupervised/models.py
Normal file
78
unsupervised/models.py
Normal file
@@ -0,0 +1,78 @@
|
||||
import tensorflow.keras as keras
|
||||
import tensorflow.keras.layers as layers
|
||||
|
||||
|
||||
def wrap_mobilenet_nnconv5_for_utrain(model):
|
||||
"""
|
||||
Wraps a fast depth model for use in unsupervised training.
|
||||
|
||||
This just exposes the lower disparity layers as outputs, so they can be used to train at different scales/image
|
||||
resolutions.
|
||||
:param model:
|
||||
:return:
|
||||
"""
|
||||
input = model.input
|
||||
disp_1 = model.get_layer('conv_pw_%d_relu' % 15).output
|
||||
disp_2 = model.get_layer('conv_pw_%d_relu' % 16).output
|
||||
disp_3 = model.get_layer('conv_pw_%d_relu' % 17).output
|
||||
return keras.Model(input, outputs=[disp_1, disp_2, disp_3, model.output])
|
||||
|
||||
|
||||
def res_layer(inputs, out_channels, down_sample=None, stride=1, normalisation=layers.BatchNormalization,
|
||||
activation=layers.ReLU):
|
||||
x = layers.Conv2D(out_channels, 3, padding='same', strides=stride)(inputs)
|
||||
x = normalisation()(x)
|
||||
x = activation()(x)
|
||||
x = layers.Conv2D(out_channels, 3, padding='same', strides=1)(x)
|
||||
x = normalisation()(x)
|
||||
# Residual skip connection. Downsample inputs if necessary
|
||||
if down_sample is not None:
|
||||
inputs = down_sample
|
||||
x = layers.Add()([x, inputs])
|
||||
x = activation()(x)
|
||||
return x
|
||||
|
||||
|
||||
def res_block(inputs, out_channels, num_blocks=1, stride=1, normalisation=layers.BatchNormalization,
|
||||
activation=layers.ReLU):
|
||||
down_sample = None
|
||||
if stride != 1 or inputs.shape[-1] != out_channels:
|
||||
down_sample = layers.Conv2D(out_channels, 1, stride, padding='same')(inputs)
|
||||
down_sample = normalisation()(down_sample)
|
||||
|
||||
x = res_layer(inputs, out_channels, down_sample, stride, normalisation, activation)
|
||||
for i in range(1, num_blocks):
|
||||
x = res_layer(x, out_channels, None, 1, normalisation, activation)
|
||||
return x
|
||||
|
||||
|
||||
def resnet_18(shape=(224, 224, 6)):
|
||||
"""
|
||||
Build the ResNet 18 network (encoder for the pose network)
|
||||
:param shape: Input shape. Note this should support 2 images for the pose net, so 6 channels in that case
|
||||
:return: Resnet encoder (ResNet18)
|
||||
"""
|
||||
inputs = layers.Input(shape)
|
||||
x = keras.layers.Conv2D(64, 7, 2, padding='same')(inputs)
|
||||
x = layers.BatchNormalization()(x)
|
||||
x = layers.ReLU()(x)
|
||||
x = layers.MaxPooling2D(3, 2, 'same')(x)
|
||||
x = res_block(x, 64, 2)
|
||||
x = res_block(x, 128, 2, 2)
|
||||
x = res_block(x, 256, 2, 2)
|
||||
x = res_block(x, 512, 2, 2)
|
||||
return keras.Model(inputs=inputs, outputs=x)
|
||||
|
||||
|
||||
def pose_net(shape=(224, 224, 6), encoder=resnet_18):
|
||||
resnet = encoder(shape=shape)
|
||||
|
||||
# Decoder
|
||||
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# import fast_depth_functional as fd
|
||||
# wrap_mobilenet_nnconv5_for_utrain(fd.mobilenet_nnconv5()).summary()
|
||||
pose_net()
|
||||
Reference in New Issue
Block a user