import tensorflow.keras as keras
import tensorflow.keras.layers as layers


def wrap_mobilenet_nnconv5_for_utrain(model, layer_indices=(15, 16, 17)):
    """
    Wrap a fast-depth model for use in unsupervised training.

    The returned model shares ``model``'s input and layers, but additionally
    exposes intermediate (lower-resolution) disparity layers as outputs so
    they can be used to train at different scales/image resolutions.

    :param model: Keras model that contains a layer named ``conv_pw_<i>_relu``
        for every ``i`` in ``layer_indices`` (presumably the model built by
        ``fast_depth_functional.mobilenet_nnconv5`` -- TODO confirm).
    :param layer_indices: Indices of the ``conv_pw_<i>_relu`` layers to expose,
        in output order. Defaults to ``(15, 16, 17)``.
    :return: A ``keras.Model`` with the same input as ``model``. Its outputs are
        the selected intermediate layer outputs, followed by ``model.output``
        as the last element.
    :raises ValueError: If a requested layer name does not exist in ``model``
        (raised by ``keras.Model.get_layer``).
    """
    model_input = model.input  # avoid shadowing the ``input`` builtin
    scale_outputs = [
        model.get_layer('conv_pw_%d_relu' % index).output
        for index in layer_indices
    ]
    return keras.Model(model_input, outputs=scale_outputs + [model.output])


def res_layer(inputs, out_channels, down_sample=None, stride=1,
              normalisation=layers.BatchNormalization, activation=layers.ReLU):
    """
    Build one ResNet basic block: two 3x3 convolutions plus a skip connection.

    Layout: conv3x3(stride) -> norm -> act -> conv3x3 -> norm -> add(skip) -> act

    :param inputs: Input tensor.
    :param out_channels: Number of filters in both convolutions.
    :param down_sample: Optional tensor used as the skip branch instead of
        ``inputs``. It must already match the main branch's output shape.
        Supply it whenever ``stride != 1`` or ``inputs`` does not have
        ``out_channels`` channels; otherwise the ``Add`` fails on a shape
        mismatch.
    :param stride: Stride of the first convolution.
    :param normalisation: Zero-argument callable returning a normalisation layer.
    :param activation: Zero-argument callable returning an activation layer.
    :return: Output tensor of the block.
    """
    x = layers.Conv2D(out_channels, 3, padding='same', strides=stride)(inputs)
    x = normalisation()(x)
    x = activation()(x)
    x = layers.Conv2D(out_channels, 3, padding='same', strides=1)(x)
    x = normalisation()(x)

    # Residual skip connection; use the pre-computed projection if provided.
    skip = inputs if down_sample is None else down_sample
    x = layers.Add()([x, skip])
    return activation()(x)


def res_block(inputs, out_channels, num_blocks=1, stride=1,
              normalisation=layers.BatchNormalization, activation=layers.ReLU):
    """
    Build a ResNet stage made of ``num_blocks`` basic blocks (``res_layer``).

    Only the first basic block can change resolution or channel count. When
    ``stride != 1`` or the channel count changes, the skip connection is
    projected with a 1x1 convolution plus normalisation so it matches the
    main branch.

    :param inputs: Input tensor.
    :param out_channels: Number of output channels of every block in the stage.
    :param num_blocks: Number of basic blocks. At least one block is always
        built, even if this is < 1.
    :param stride: Stride applied by the first block of the stage.
    :param normalisation: Zero-argument callable returning a normalisation layer.
    :param activation: Zero-argument callable returning an activation layer.
    :return: Output tensor of the stage.
    """
    shortcut = None
    if stride != 1 or inputs.shape[-1] != out_channels:
        # Projection shortcut: 1x1 conv with the same stride as the main branch.
        shortcut = layers.Conv2D(out_channels, 1, strides=stride,
                                 padding='same')(inputs)
        shortcut = normalisation()(shortcut)

    x = res_layer(inputs, out_channels, shortcut, stride, normalisation,
                  activation)
    for _ in range(num_blocks - 1):
        x = res_layer(x, out_channels, None, 1, normalisation, activation)
    return x


def resnet_18(shape=(224, 224, 6)):
    """
    Build the ResNet-18 network (encoder for the pose network).

    The stem is a 7x7/2 convolution, BatchNorm, ReLU and 3x3/2 max-pooling.
    It is followed by four stages of two basic blocks each, with 64, 128, 256
    and 512 channels. With 'same' padding, the five stride-2 operations reduce
    each spatial dimension by a factor of 32 (rounded up).

    :param shape: Input shape (height, width, channels). This should support
        two stacked images for the pose net, so 6 channels in that case.
    :return: ResNet-18 encoder as a ``keras.Model``.
    """
    inputs = layers.Input(shape)

    # Stem.
    x = layers.Conv2D(64, 7, 2, padding='same')(inputs)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)
    x = layers.MaxPooling2D(3, 2, 'same')(x)

    # Residual stages; stages 2-4 halve the spatial resolution.
    x = res_block(x, 64, num_blocks=2)
    x = res_block(x, 128, num_blocks=2, stride=2)
    x = res_block(x, 256, num_blocks=2, stride=2)
    x = res_block(x, 512, num_blocks=2, stride=2)

    return keras.Model(inputs=inputs, outputs=x)


def pose_net(shape=(224, 224, 6), encoder=resnet_18):
    """
    Build the pose network.

    NOTE(review): unfinished stub. The encoder is built but no decoder exists
    yet, so this function returns ``None``.

    :param shape: Input shape passed to ``encoder``.
    :param encoder: Callable taking ``shape=`` and returning the encoder model.
    :return: ``None`` (until the decoder is implemented).
    """
    resnet = encoder(shape=shape)  # noqa: F841 -- consumed once the decoder exists
    # TODO: implement the pose decoder on top of ``resnet.output`` and return a Model.
    return None


if __name__ == '__main__':
    # Alternative smoke test (requires the project-local fast_depth_functional):
    # import fast_depth_functional as fd
    # wrap_mobilenet_nnconv5_for_utrain(fd.mobilenet_nnconv5()).summary()
    pose_net()