A pattern found in DenseDepth: each layer in the encoder can be set to trainable individually, and the model's output plus the layers required for skip connections can be used directly by name. Ends up being much cleaner.
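For illustration, a minimal sketch of that pattern (the variable names here are illustrative; conv_pw_5_relu is one of MobileNet's real activation layers):

encoder = keras.applications.MobileNet(include_top=False, weights='imagenet')
for layer in encoder.layers:
    layer.trainable = True  # or toggle individual layers
skip = encoder.get_layer('conv_pw_5_relu').output  # tensor to tap for a skip connection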
import tensorflow as tf
import tensorflow.keras as keras
import tensorflow_datasets as tfds

# Ripped from: https://forums.developer.nvidia.com/t/could-not-create-cudnn-handle-cudnn-status-alloc-failed/108261/4?u=mpivato4
# Seems to be an issue on Windows, so explicitly enable GPU memory growth
def fix_windows_gpu():
    gpus = tf.config.experimental.list_physical_devices('GPU')
    if gpus:
        try:
            # Currently, memory growth needs to be the same across GPUs
            for gpu in gpus:
                tf.config.experimental.set_memory_growth(gpu, True)
            logical_gpus = tf.config.experimental.list_logical_devices('GPU')
            print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
        except RuntimeError as e:
            # Memory growth must be set before GPUs have been initialized
            print(e)

'''
Functional version of the FastDepth model
'''

def FDDepthwiseBlock(inputs, out_channels, block_id=1):
    # FastDepth decoder block: 5x5 depthwise conv + BN + ReLU6,
    # followed by a 1x1 pointwise conv + BN + ReLU6
    x = keras.layers.DepthwiseConv2D(5, padding='same')(inputs)
    x = keras.layers.BatchNormalization()(x)
    x = keras.layers.ReLU(6.)(x)
    x = keras.layers.Conv2D(out_channels, 1, padding='same')(x)
    x = keras.layers.BatchNormalization()(x)
    return keras.layers.ReLU(6., name='conv_pw_%d_relu' % block_id)(x)

def FDDepthwiseBlockNoBN(inputs, out_channels, block_id=1):
    # Same block as above, but without batch normalization
    x = keras.layers.DepthwiseConv2D(5, padding='same')(inputs)
    x = keras.layers.ReLU(6.)(x)
    x = keras.layers.Conv2D(out_channels, 1, padding='same')(x)
    return keras.layers.ReLU(6., name='conv_pw_%d_relu' % block_id)(x)

def make_mobilenet_nnconv5(weights=None, shape=(224, 224, 3)):
    input = keras.layers.Input(shape=shape)
    mobilenet = keras.applications.MobileNet(input_tensor=input, include_top=False, weights=weights)
    for layer in mobilenet.layers:
        layer.trainable = True

    # FastDepth decoder
    x = FDDepthwiseBlock(mobilenet.output, 512, block_id=14)
    # TODO: Bilinear interpolation
    # x = keras.layers.experimental.preprocessing.Resizing(14, 14, interpolation='bilinear')
    # Nearest neighbour interpolation, as used by the FastDepth paper
    x = keras.layers.experimental.preprocessing.Resizing(14, 14, interpolation='nearest')(x)
    x = FDDepthwiseBlock(x, 256, block_id=15)
    x = keras.layers.experimental.preprocessing.Resizing(28, 28, interpolation='nearest')(x)
    # Additive skip connections from the encoder layers at matching resolutions
    x = keras.layers.Add()([x, mobilenet.get_layer(name="conv_pw_5_relu").output])
    x = FDDepthwiseBlock(x, 128, block_id=16)
    x = keras.layers.experimental.preprocessing.Resizing(56, 56, interpolation='nearest')(x)
    x = keras.layers.Add()([x, mobilenet.get_layer(name="conv_pw_3_relu").output])
    x = FDDepthwiseBlock(x, 64, block_id=17)
    x = keras.layers.experimental.preprocessing.Resizing(112, 112, interpolation='nearest')(x)
    x = keras.layers.Add()([x, mobilenet.get_layer(name="conv_pw_1_relu").output])
    x = FDDepthwiseBlock(x, 32, block_id=18)
    x = keras.layers.experimental.preprocessing.Resizing(224, 224, interpolation='nearest')(x)

    # 1x1 conv down to a single depth channel
    x = keras.layers.Conv2D(1, 1, padding='same')(x)
    x = keras.layers.BatchNormalization()(x)
    x = keras.layers.ReLU(6.)(x)
    return keras.Model(inputs=input, outputs=x, name="fast_depth")

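# A minimal sanity check (a sketch, not part of the original file; the function name
# is hypothetical): verifies the decoder maps 224x224x3 inputs to 224x224x1 depth maps
# and that the skip-connection layers exist under the expected MobileNet names.
def _check_nnconv5_shapes():
    model = make_mobilenet_nnconv5(weights=None)
    assert model.output_shape == (None, 224, 224, 1)
    for name in ("conv_pw_1_relu", "conv_pw_3_relu", "conv_pw_5_relu"):
        print(name, model.get_layer(name).output.shape)
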
def make_mobilenet_nnconv5_no_bn(weights=None, shape=(224, 224, 3)):
    input = keras.layers.Input(shape=shape)
    mobilenet = keras.applications.MobileNet(input_tensor=input, include_top=False, weights=weights)
    for layer in mobilenet.layers:
        layer.trainable = True

    # FastDepth decoder without batch normalization
    x = FDDepthwiseBlockNoBN(mobilenet.output, 512, block_id=14)
    # Bilinear interpolation here; the FastDepth paper uses nearest neighbour
    x = keras.layers.experimental.preprocessing.Resizing(14, 14, interpolation='bilinear')(x)
    x = FDDepthwiseBlockNoBN(x, 256, block_id=15)
    x = keras.layers.experimental.preprocessing.Resizing(28, 28, interpolation='bilinear')(x)
    x = keras.layers.Add()([x, mobilenet.get_layer(name="conv_pw_5_relu").output])
    x = FDDepthwiseBlockNoBN(x, 128, block_id=16)
    x = keras.layers.experimental.preprocessing.Resizing(56, 56, interpolation='bilinear')(x)
    x = keras.layers.Add()([x, mobilenet.get_layer(name="conv_pw_3_relu").output])
    x = FDDepthwiseBlockNoBN(x, 64, block_id=17)
    x = keras.layers.experimental.preprocessing.Resizing(112, 112, interpolation='bilinear')(x)
    x = keras.layers.Add()([x, mobilenet.get_layer(name="conv_pw_1_relu").output])
    x = FDDepthwiseBlockNoBN(x, 32, block_id=18)
    x = keras.layers.experimental.preprocessing.Resizing(224, 224, interpolation='bilinear')(x)

    x = keras.layers.Conv2D(1, 1, padding='same')(x)
    x = keras.layers.ReLU(6.)(x)
    # Distinct name so both variants can coexist
    return keras.Model(inputs=input, outputs=x, name="fast_depth_no_bn")

# TODO: Fix these, float handling doesn't match the PyTorch version
# Delta metrics: the fraction of pixels whose prediction/ground-truth ratio
# (in either direction) falls below the threshold 1.25^n
def delta1_metric(y_true, y_pred):
    maxRatio = tf.maximum(y_pred / y_true, y_true / y_pred)
    return tf.reduce_mean(tf.cast(maxRatio < 1.25, tf.float32))


def delta2_metric(y_true, y_pred):
    maxRatio = tf.maximum(y_pred / y_true, y_true / y_pred)
    return tf.reduce_mean(tf.cast(maxRatio < 1.25 ** 2, tf.float32))


def delta3_metric(y_true, y_pred):
    maxRatio = tf.maximum(y_pred / y_true, y_true / y_pred)
    return tf.reduce_mean(tf.cast(maxRatio < 1.25 ** 3, tf.float32))

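# Worked example (a sketch, not in the original): for y_true = [2.0, 4.0] and
# y_pred = [2.2, 8.0], the per-pixel max ratios are 1.1 and 2.0, so only the
# first pixel falls under the 1.25 threshold and delta1 = 0.5.
def _delta_metric_example():
    y_true = tf.constant([2.0, 4.0])
    y_pred = tf.constant([2.2, 8.0])
    print(float(delta1_metric(y_true, y_pred)))  # expected: 0.5
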
def compile(model):
    # TODO: Learning rate (exponential decay)
    model.compile(optimizer=keras.optimizers.SGD(momentum=0.9),
                  loss=keras.losses.MeanSquaredError(),
                  metrics=[keras.metrics.RootMeanSquaredError(),
                           keras.metrics.MeanSquaredError(),
                           delta1_metric,
                           delta2_metric,
                           delta3_metric])

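# One possible way to address the TODO above (a sketch; the schedule values are
# placeholders, not from the original). Keras optimizers accept a schedule directly:
# lr_schedule = keras.optimizers.schedules.ExponentialDecay(
#     initial_learning_rate=0.01, decay_steps=10000, decay_rate=0.9)
# optimizer = keras.optimizers.SGD(learning_rate=lr_schedule, momentum=0.9)
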
def train(existing_model=None, pretrained_weights='imagenet', epochs=4, save_file=None, dataset=None):
    if not existing_model:
        existing_model = make_mobilenet_nnconv5(pretrained_weights)
        compile(existing_model)
    if not dataset:
        dataset = load_nyu()
    existing_model.fit(dataset, epochs=epochs)
    if save_file:
        existing_model.save(save_file)
    return existing_model

def evaluate(compiled_model, dataset=None):
    """
    Evaluate the model using RMSE and delta1/2/3 metrics
    :param compiled_model: Compiled, trained model to evaluate
    :param dataset: Dataset for evaluation. Should be of format {'image': image, 'depth': label},
                    where label width/height matches image width/height.
                    Defaults to the TensorFlow nyu_depth_v2 validation split
                    (https://www.tensorflow.org/datasets/catalog/nyu_depth_v2)
    """
    if not dataset:
        dataset = load_nyu_evaluate()
    compiled_model.evaluate(dataset, verbose=1)

def forward(model, image):
    """
    Propagate a single image or a batch of images through the model. Image(s) should be in NHWC format
    :param model: trained model
    :param image: image tensor or batch of image tensors, NHWC
    :return: predicted depth map(s)
    """
    # Resize to the model's expected 224x224 input. Note: the original called
    # crop_and_resize(image) here, but that helper expects an {'image', 'depth'}
    # dict and returns a tuple, so a plain resize is used instead
    resized = keras.layers.experimental.preprocessing.Resizing(224, 224, interpolation='nearest')(image)
    return model(resized)

def load_model(file):
    return keras.models.load_model(file, custom_objects={'delta1_metric': delta1_metric,
                                                         'delta2_metric': delta2_metric,
                                                         'delta3_metric': delta3_metric})

def crop_and_resize(x):
    shape = tf.shape(x['depth'])

    def layer():
        return keras.Sequential([
            keras.layers.experimental.preprocessing.CenterCrop(shape[1], shape[2]),
            keras.layers.experimental.preprocessing.Resizing(224, 224, interpolation='nearest')
        ])

    # Reshape the label to 4D; can't use array unwrapping as it's unsupported by TensorFlow
    return layer()(x['image']), layer()(tf.reshape(x['depth'], [shape[0], shape[1], shape[2], 1]))

def load_nyu():
    builder = tfds.builder('nyu_depth_v2')
    builder.download_and_prepare(download_dir='../nyu')
    return builder \
        .as_dataset(split='train', shuffle_files=True) \
        .shuffle(buffer_size=1024) \
        .batch(8) \
        .map(lambda x: crop_and_resize(x))

def load_nyu_evaluate():
    builder = tfds.builder('nyu_depth_v2')
    builder.download_and_prepare(download_dir='../nyu')
    return builder.as_dataset(split='validation').batch(1).map(lambda x: crop_and_resize(x))
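

# Example end-to-end usage (a sketch, not in the original; the save path and
# epoch count are placeholders):
if __name__ == '__main__':
    fix_windows_gpu()
    model = train(pretrained_weights='imagenet', epochs=4, save_file='fastdepth.h5')
    evaluate(model)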