Add dense-depth and experimental dense-net-nnconv5 models
Since dense-depth will use half labels by default, the nyu train/eval datasets can be loaded from here at half resolutions for labels
This commit is contained in:
153
dense_depth_functional.py
Normal file
153
dense_depth_functional.py
Normal file
@@ -0,0 +1,153 @@
|
||||
import tensorflow as tf
|
||||
import tensorflow.keras as keras
|
||||
import tensorflow_datasets as tfds
|
||||
|
||||
import fast_depth_functional as fd
|
||||
|
||||
|
||||
def dense_upproject(input, out_channels, skip_connection):
|
||||
x = keras.layers.UpSampling2D(interpolation='bilinear')(input)
|
||||
x = keras.layers.Concatenate()([x, skip_connection])
|
||||
x = keras.layers.Conv2D(filters=out_channels,
|
||||
kernel_size=3, strides=1, padding='same')(x)
|
||||
x = keras.layers.LeakyReLU(alpha=0.2)(x)
|
||||
x = keras.layers.Conv2D(filters=out_channels,
|
||||
kernel_size=3, strides=1, padding='same')(x)
|
||||
return keras.layers.LeakyReLU(alpha=0.2)(x)
|
||||
|
||||
|
||||
def dense_depth(size, weights=None, shape=(224, 224, 3), half_features=True):
|
||||
input = keras.layers.Input(shape=shape)
|
||||
densenet = dense_net(input, size, weights, shape)
|
||||
densenet_output_shape = densenet.layers[-1].output.shape
|
||||
|
||||
if half_features:
|
||||
decode_filters = densenet_output_shape[-1] // 2
|
||||
else:
|
||||
decode_filters = int(densenet_output_shape[-1])
|
||||
|
||||
# Reduce the feature set (pointwise)
|
||||
decoder = keras.layers.Conv2D(filters=decode_filters, kernel_size=1, padding='same',
|
||||
input_shape=densenet_output_shape, name='conv2')(densenet.output)
|
||||
|
||||
# The actual decoder
|
||||
decoder = dense_upproject(
|
||||
decoder, decode_filters // 2, densenet.get_layer('pool3_pool').output)
|
||||
decoder = dense_upproject(
|
||||
decoder, decode_filters // 4, densenet.get_layer('pool2_pool').output)
|
||||
decoder = dense_upproject(
|
||||
decoder, decode_filters // 8, densenet.get_layer('pool1').output)
|
||||
decoder = dense_upproject(
|
||||
decoder, decode_filters // 16, densenet.get_layer('conv1/relu').output)
|
||||
# Enable to upproject to full image size
|
||||
# decoder = dense_upproject(decoder, int(decode_filters / 32), input)
|
||||
|
||||
conv3 = keras.layers.Conv2D(
|
||||
filters=1, kernel_size=3, strides=1, padding='same', name='conv3')(decoder)
|
||||
return keras.Model(inputs=input, outputs=conv3, name='dense_depth')
|
||||
|
||||
|
||||
def dense_net(input, size, weights=None, shape=(224, 224, 3)):
|
||||
if size == 121:
|
||||
densenet = keras.applications.DenseNet121(input_tensor=input, input_shape=shape, weights=weights,
|
||||
include_top=False)
|
||||
elif size == 169:
|
||||
densenet = keras.applications.DenseNet169(input_tensor=input, input_shape=shape, weights=weights,
|
||||
include_top=False)
|
||||
else:
|
||||
densenet = keras.applications.DenseNet201(input_tensor=input, input_shape=shape, weights=weights,
|
||||
include_top=False)
|
||||
|
||||
for layer in densenet.layers:
|
||||
layer.trainable = True
|
||||
|
||||
return densenet
|
||||
|
||||
|
||||
def dense_nnconv5(size, weights=None, shape=(224, 224, 3), half_features=True):
|
||||
input = keras.layers.Input(shape=shape)
|
||||
densenet = dense_net(input, size, weights, shape)
|
||||
densenet_output_shape = densenet.layers[-1].output.shape
|
||||
|
||||
if half_features:
|
||||
decode_filters = int(int(densenet_output_shape[-1]) / 2)
|
||||
else:
|
||||
decode_filters = int(densenet_output_shape[-1])
|
||||
|
||||
# Reduce the feature set (pointwise)
|
||||
x = keras.layers.Conv2D(filters=decode_filters, kernel_size=1, padding='same',
|
||||
input_shape=densenet_output_shape, name='conv2')(densenet.output)
|
||||
|
||||
# TODO: More intermediate layers here?
|
||||
|
||||
# Fast Depth Decoder
|
||||
x = fd.nnconv5(x, densenet.get_layer('pool3_pool').output_shape[3], 1,
|
||||
skip_connection=densenet.get_layer('pool3_pool').output)
|
||||
x = fd.nnconv5(x, densenet.get_layer('pool2_pool').output_shape[3], 2,
|
||||
skip_connection=densenet.get_layer('pool2_pool').output)
|
||||
x = fd.nnconv5(x, densenet.get_layer('pool1').output_shape[3], 3,
|
||||
skip_connection=densenet.get_layer('pool1').output)
|
||||
x = fd.nnconv5(x, densenet.get_layer('conv1/relu').output_shape[3], 4,
|
||||
skip_connection=densenet.get_layer('conv1/relu').output)
|
||||
# Enable to get full dense decode (back to original size)
|
||||
# x = fd.nnconv5(x, int(densenet.get_layer('conv1/relu').output_shape[3] / 2), 5)
|
||||
|
||||
# Final Pointwise for depth extraction
|
||||
x = keras.layers.Conv2D(1, 1, padding='same')(x)
|
||||
x = keras.layers.BatchNormalization()(x)
|
||||
x = keras.layers.ReLU(6.)(x)
|
||||
return keras.Model(inputs=input, outputs=x, name="fast_dense_depth")
|
||||
|
||||
|
||||
def crop_and_resize(x):
|
||||
shape = tf.shape(x['depth'])
|
||||
|
||||
def layer():
|
||||
return keras.Sequential([
|
||||
keras.layers.experimental.preprocessing.CenterCrop(
|
||||
shape[1], shape[2]),
|
||||
keras.layers.experimental.preprocessing.Resizing(
|
||||
224, 224, interpolation='nearest')
|
||||
])
|
||||
|
||||
def half_layer():
|
||||
return keras.Sequential([
|
||||
keras.layers.experimental.preprocessing.CenterCrop(
|
||||
shape[1], shape[2]),
|
||||
keras.layers.experimental.preprocessing.Resizing(
|
||||
112, 112, interpolation='nearest')
|
||||
])
|
||||
|
||||
# Reshape label to 4d, can't use array unwrap as it's unsupported by tensorflow
|
||||
return layer()(x['image']), half_layer()(tf.reshape(x['depth'], [shape[0], shape[1], shape[2], 1]))
|
||||
|
||||
|
||||
def load_nyu():
|
||||
"""
|
||||
Load the nyu_v2 dataset train split. Will be downloaded to ../nyu
|
||||
:return: nyu_v2 dataset builder
|
||||
"""
|
||||
builder = tfds.builder('nyu_depth_v2')
|
||||
builder.download_and_prepare(download_dir='../nyu')
|
||||
return builder \
|
||||
.as_dataset(split='train', shuffle_files=True) \
|
||||
.shuffle(buffer_size=1024) \
|
||||
.batch(8) \
|
||||
.map(lambda x: crop_and_resize(x))
|
||||
|
||||
|
||||
def load_nyu_evaluate():
|
||||
"""
|
||||
Load the nyu_v2 dataset validation split. Will be downloaded to ../nyu
|
||||
:return: nyu_v2 dataset builder
|
||||
"""
|
||||
builder = tfds.builder('nyu_depth_v2')
|
||||
builder.download_and_prepare(download_dir='../nyu')
|
||||
return builder.as_dataset(split='validation').batch(1).map(lambda x: crop_and_resize(x))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
model = dense_depth(169, 'imagenet')
|
||||
model.summary()
|
||||
# model = dense_nnconv5(169, 'imagenet')
|
||||
# model.summary()
|
||||
22
losses.py
Normal file
22
losses.py
Normal file
@@ -0,0 +1,22 @@
|
||||
import tensorflow as tf
|
||||
import tensorflow.keras.backend as K
|
||||
|
||||
|
||||
def dense_depth_loss_function(y_true, y_pred, theta=0.1, maxDepthVal=1000.0 / 10.0):
|
||||
# Point-wise depth
|
||||
l_depth = K.mean(K.abs(y_pred - y_true), axis=-1)
|
||||
|
||||
# Edges
|
||||
dy_true, dx_true = tf.image.image_gradients(y_true)
|
||||
dy_pred, dx_pred = tf.image.image_gradients(y_pred)
|
||||
l_edges = K.mean(K.abs(dy_pred - dy_true) + K.abs(dx_pred - dx_true), axis=-1)
|
||||
|
||||
# Structural similarity (SSIM) index
|
||||
l_ssim = K.clip((1 - tf.image.ssim(y_true, y_pred, maxDepthVal)) * 0.5, 0, 1)
|
||||
|
||||
# Weights
|
||||
w1 = 1.0
|
||||
w2 = 1.0
|
||||
w3 = theta
|
||||
|
||||
return (w1 * l_ssim) + (w2 * K.mean(l_edges)) + (w3 * K.mean(l_depth))
|
||||
Reference in New Issue
Block a user