import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow import keras

import fast_depth_functional as fd


def dense_upsample_block(inputs, out_channels, skip_connection):
    """
    Upsampling block as described by DenseDepth, https://arxiv.org/pdf/1812.11941.pdf
    """
    # Bilinear 2x upsampling, skip concatenation, then two 3x3 convolutions
    x = keras.layers.UpSampling2D(interpolation='bilinear')(inputs)
    x = keras.layers.Concatenate()([x, skip_connection])
    x = keras.layers.Conv2D(filters=out_channels, kernel_size=3, strides=1, padding='same')(x)
    x = keras.layers.Conv2D(filters=out_channels, kernel_size=3, strides=1, padding='same')(x)
    return keras.layers.LeakyReLU(alpha=0.2)(x)


def dense_depth(size, weights=None, shape=(224, 224, 3)):
    """
    Build a DenseDepth-style model: DenseNet encoder plus bilinear upsampling decoder.
    """
    inputs = keras.layers.Input(shape=shape)
    densenet = dense_net(inputs, size, weights, shape)
    densenet_output_channels = densenet.layers[-1].output.shape[-1]

    # Reduce the feature set (pointwise)
    decoder = keras.layers.Conv2D(filters=densenet_output_channels, kernel_size=1,
                                  padding='same')(densenet.output)

    # The actual decoder: halve the channel count at each upsampling stage
    decoder = dense_upsample_block(
        decoder, densenet_output_channels // 2, densenet.get_layer('pool3_pool').output)
    decoder = dense_upsample_block(
        decoder, densenet_output_channels // 4, densenet.get_layer('pool2_pool').output)
    decoder = dense_upsample_block(
        decoder, densenet_output_channels // 8, densenet.get_layer('pool1').output)
    decoder = dense_upsample_block(
        decoder, densenet_output_channels // 16, densenet.get_layer('conv1/relu').output)

    conv3 = keras.layers.Conv2D(
        filters=1, kernel_size=3, strides=1, padding='same', name='conv3')(decoder)

    return keras.Model(inputs=inputs, outputs=conv3, name='dense_depth')


def dense_net(inputs, size, weights=None, shape=(224, 224, 3)):
    """
    Build a DenseNet-121/169/201 encoder without its classification head.
    """
    if size == 121:
        densenet = keras.applications.DenseNet121(input_tensor=inputs, input_shape=shape,
                                                  weights=weights, include_top=False)
    elif size == 169:
        densenet = keras.applications.DenseNet169(input_tensor=inputs, input_shape=shape,
                                                  weights=weights, include_top=False)
    else:
        densenet = keras.applications.DenseNet201(input_tensor=inputs, input_shape=shape,
                                                  weights=weights, include_top=False)

    # Fine-tune the whole encoder, even when starting from pretrained weights
    for layer in densenet.layers:
        layer.trainable = True

    return densenet


def dense_nnconv5(size, weights=None, shape=(224, 224, 3), half_features=True):
    """
    Build a hybrid model: DenseNet encoder with the FastDepth NNConv5 decoder.
    """
    inputs = keras.layers.Input(shape=shape)
    densenet = dense_net(inputs, size, weights, shape)
    densenet_output_shape = densenet.layers[-1].output.shape

    if half_features:
        decode_filters = int(densenet_output_shape[-1]) // 2
    else:
        decode_filters = int(densenet_output_shape[-1])

    # Reduce the feature set (pointwise)
    x = keras.layers.Conv2D(filters=decode_filters, kernel_size=1, padding='same',
                            name='conv2')(densenet.output)

    # TODO: More intermediate layers here?
    # FastDepth decoder: NNConv5 blocks with skip connections from the encoder
    x = fd.nnconv5(x, densenet.get_layer('pool3_pool').output_shape[3], 1,
                   skip_connection=densenet.get_layer('pool3_pool').output)
    x = fd.nnconv5(x, densenet.get_layer('pool2_pool').output_shape[3], 2,
                   skip_connection=densenet.get_layer('pool2_pool').output)
    x = fd.nnconv5(x, densenet.get_layer('pool1').output_shape[3], 3,
                   skip_connection=densenet.get_layer('pool1').output)
    x = fd.nnconv5(x, densenet.get_layer('conv1/relu').output_shape[3], 4,
                   skip_connection=densenet.get_layer('conv1/relu').output)

    # Final pointwise convolution for depth extraction
    x = keras.layers.Conv2D(1, 1, padding='same')(x)
    x = keras.layers.BatchNormalization()(x)
    x = keras.layers.ReLU(6.)(x)

    return keras.Model(inputs=inputs, outputs=x, name='fast_dense_depth')


def crop_and_resize(x):
    """
    Resize images to the 224x224 network input and depth maps to the 112x112
    resolution the decoders produce.
    """
    shape = tf.shape(x['depth'])

    def layer():
        return keras.Sequential([
            keras.layers.experimental.preprocessing.CenterCrop(
                shape[1], shape[2]),
            keras.layers.experimental.preprocessing.Resizing(
                224, 224, interpolation='nearest')
        ])

    def half_layer():
        return keras.Sequential([
            keras.layers.experimental.preprocessing.CenterCrop(
                shape[1], shape[2]),
            keras.layers.experimental.preprocessing.Resizing(
                112, 112, interpolation='nearest')
        ])

    # Reshape the label to 4D explicitly; tuple unpacking is unsupported inside tf.data maps
    return layer()(x['image']), half_layer()(tf.reshape(x['depth'], [shape[0], shape[1], shape[2], 1]))


def load_nyu():
    """
    Load the nyu_depth_v2 train split. Will be downloaded to ../nyu
    :return: shuffled, batched and preprocessed training dataset
    """
    builder = tfds.builder('nyu_depth_v2')
    builder.download_and_prepare(download_dir='../nyu')
    return builder \
        .as_dataset(split='train', shuffle_files=True) \
        .shuffle(buffer_size=1024) \
        .batch(8) \
        .map(crop_and_resize)


def load_nyu_evaluate():
    """
    Load the nyu_depth_v2 validation split. Will be downloaded to ../nyu
    :return: batched and preprocessed validation dataset
    """
    builder = tfds.builder('nyu_depth_v2')
    builder.download_and_prepare(download_dir='../nyu')
    return builder.as_dataset(split='validation').batch(1).map(crop_and_resize)


if __name__ == '__main__':
    model = dense_depth(169, 'imagenet')
    model.summary()
    # model = dense_nnconv5(169, 'imagenet')
    # model.summary()
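
# For reference, a minimal sketch of the interface this file assumes from
# fast_depth_functional.nnconv5, i.e. the NNConv5 decoder stage of FastDepth
# (https://arxiv.org/abs/1903.03273): a 5x5 depthwise separable convolution
# followed by 2x nearest-neighbor upsampling. The additive skip connection and
# layer naming below are assumptions for illustration; the repo's own
# implementation in fast_depth_functional.py is authoritative.
#
#     def nnconv5(x, out_channels, index, skip_connection=None):
#         x = keras.layers.DepthwiseConv2D(5, padding='same', name=f'nnconv5_{index}_dw')(x)
#         x = keras.layers.Conv2D(out_channels, 1, padding='same', name=f'nnconv5_{index}_pw')(x)
#         x = keras.layers.UpSampling2D(interpolation='nearest', name=f'nnconv5_{index}_up')(x)
#         if skip_connection is not None:
#             x = keras.layers.Add(name=f'nnconv5_{index}_skip')([x, skip_connection])
#         return x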
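
# A minimal training sketch combining the model and data-loading helpers above.
# The optimizer, learning rate, loss, and epoch count are illustrative
# assumptions, not choices prescribed by this file:
#
#     model = dense_depth(169, 'imagenet')
#     model.compile(optimizer=keras.optimizers.Adam(learning_rate=1e-4),
#                   loss='mean_absolute_error')
#     model.fit(load_nyu(), epochs=20)
#     model.evaluate(load_nyu_evaluate())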