Addresses the following: - Rename nnconv5 block to nnconv5 - Add skip connections directly to nnconv5 block - Allow custom metrics, loss and optimizer (keep defaults that reflect original paper) to train - Correctly use nyu evaluation dataset only when no dataset is provided
205 lines
7.7 KiB
Python
205 lines
7.7 KiB
Python
import tensorflow as tf
|
|
import tensorflow.keras as keras
|
|
import tensorflow_datasets as tfds
|
|
|
|
"""
|
|
Unofficial TensorFlow/Keras implementation of FastDepth (mobilenet_nnconv5).
|
|
Official PyTorch FastDepth implementation: https://github.com/dwofk/fast-depth
|
|
"""
|
|
|
|
|
|
# Workaround adapted from:
# https://forums.developer.nvidia.com/t/could-not-create-cudnn-handle-cudnn-status-alloc-failed/108261/4?u=mpivato4
# The cuDNN allocation failure seems to be a Windows issue, so GPU memory growth is enabled explicitly.
def fix_windows_gpu():
    """
    Enable memory growth on every visible GPU to avoid cuDNN allocation failures on Windows.

    On success the number of physical and logical GPUs is printed. Memory growth must be set
    before the GPUs are initialised; if it is too late, the RuntimeError TensorFlow raises is
    printed instead of propagated.
    """
    physical_gpus = tf.config.experimental.list_physical_devices('GPU')
    if not physical_gpus:
        return

    try:
        # Memory growth has to be configured identically on all GPUs.
        for device in physical_gpus:
            tf.config.experimental.set_memory_growth(device, True)
        logical_gpus = tf.config.experimental.list_logical_devices('GPU')
        print(len(physical_gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as err:
        print(err)
|
|
|
|
|
|
def nnconv5(inputs,
            out_channels,
            block_id=1,
            skip_connection=None):
    """
    FastDepth NNConv5 decoder block.

    Pipeline: 5x5 depthwise conv -> BN -> ReLU6 -> 1x1 pointwise conv -> BN -> 2x upsampling,
    then an optional additive skip connection, and a final ReLU6 named 'conv_pw_<block_id>_relu'.

    :param inputs: Input feature tensor
    :param out_channels: Number of channels produced by the pointwise convolution
    :param block_id: Used to name the final activation 'conv_pw_<block_id>_relu'. Layer names must
        be unique within a model, so pick an id that does not collide with encoder layers of the
        same name (e.g. MobileNet's 'conv_pw_1_relu' clashes with the default of 1)
    :param skip_connection: Optional tensor added to the upsampled features; it must have the same
        shape as them
    :return: Output tensor of the block
    """
    # Layers are created in the same order as they are applied.
    stages = [
        keras.layers.DepthwiseConv2D(5, padding='same'),
        keras.layers.BatchNormalization(),
        keras.layers.ReLU(6.),
        keras.layers.Conv2D(out_channels, 1, padding='same'),
        keras.layers.BatchNormalization(),
        keras.layers.UpSampling2D(),
    ]
    features = inputs
    for stage in stages:
        features = stage(features)

    if skip_connection is not None:
        features = keras.layers.Add()([features, skip_connection])

    activation = keras.layers.ReLU(6., name='conv_pw_%d_relu' % block_id)
    return activation(features)
|
|
|
|
|
|
def mobilenet_nnconv5(weights=None, shape=(224, 224, 3), max_depth=None):
    """
    Replication of the FastDepth model in Tensorflow, using the keras Functional API.

    Encoder: keras MobileNet without its classification head. Decoder: five nnconv5 blocks,
    with additive skip connections from the encoder activations conv_pw_5_relu, conv_pw_3_relu
    and conv_pw_1_relu. A final 1x1 convolution then produces a single depth channel.

    :param weights: Pretrained weights for MobileNet (see the keras.applications.MobileNet
        `weights` parameter), defaults to None
    :param shape: Input shape of the image, defaults to (224, 224, 3)
    :param max_depth: Optional upper bound for the predicted depth. Defaults to None, i.e. an
        unbounded ReLU output. Pass 6. to reproduce the former ReLU6-capped output layer.
    :return: FastDepth keras Model
    """
    inputs = keras.layers.Input(shape=shape)
    mobilenet = keras.applications.MobileNet(
        input_shape=shape, input_tensor=inputs, include_top=False, weights=weights)
    # Fine-tune the whole encoder together with the decoder.
    for layer in mobilenet.layers:
        layer.trainable = True

    def encoder_output(name):
        # Activation of a named MobileNet layer, used as a skip connection. Add() inside nnconv5
        # requires it to match the shape of the upsampled decoder features.
        return mobilenet.get_layer(name=name).output

    # Fast depth decoder. The block_ids give the decoder activations ('conv_pw_<id>_relu') names
    # that do not clash with the encoder layers looked up below.
    x = nnconv5(mobilenet.output, 512, block_id=14)
    x = nnconv5(x, 256, block_id=15, skip_connection=encoder_output("conv_pw_5_relu"))
    x = nnconv5(x, 128, block_id=16, skip_connection=encoder_output("conv_pw_3_relu"))
    x = nnconv5(x, 64, block_id=17, skip_connection=encoder_output("conv_pw_1_relu"))
    x = nnconv5(x, 32, block_id=18)

    x = keras.layers.Conv2D(1, 1, padding='same')(x)
    x = keras.layers.BatchNormalization()(x)
    # Depth is a non-negative regression target, so a ReLU keeps predictions >= 0. The previous
    # ReLU(6.) also clamped predictions at 6. Any deeper ground truth was then unreachable, and
    # those pixels got zero gradient, so the cap is now opt-in via max_depth.
    # NOTE(review): nyu_v2 depth labels are believed to exceed 6 (metres) — confirm.
    x = keras.layers.ReLU(max_value=max_depth)(x)

    return keras.Model(inputs=inputs, outputs=x, name="fast_depth")
|
|
|
|
|
|
def _delta_accuracy(y_true, y_pred, threshold):
    """
    Fraction of valid pixels whose prediction is within a ratio `threshold` of the ground truth.

    A pixel is a hit when max(y_pred / y_true, y_true / y_pred) < threshold. Ground-truth pixels
    that are not strictly positive (e.g. missing depth readings) are treated as invalid. They are
    excluded from both the hit count and the pixel count, so they no longer count as misses.

    :param y_true: Ground-truth depth tensor
    :param y_pred: Predicted depth tensor, same shape as y_true
    :param threshold: Ratio threshold (1.25, 1.25 ** 2 and 1.25 ** 3 for delta1/2/3)
    :return: Scalar float32 tensor in [0, 1]; 0 when there are no valid pixels
    """
    # Cast both sides so direct calls with mixed dtypes (e.g. float16 labels) also work.
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)
    valid = tf.cast(y_true > 0, tf.float32)
    # Invalid pixels can produce inf/NaN ratios. The `<` comparison maps those to False, and
    # multiplying by `valid` removes them from the count entirely.
    max_ratio = tf.maximum(y_pred / y_true, y_true / y_pred)
    hits = tf.cast(max_ratio < threshold, tf.float32) * valid
    # divide_no_nan returns 0 instead of NaN for a batch without any valid pixel.
    return tf.math.divide_no_nan(tf.reduce_sum(hits), tf.reduce_sum(valid))


def delta1_metric(y_true, y_pred):
    """Delta1 accuracy: fraction of valid pixels with max(pred/true, true/pred) < 1.25."""
    return _delta_accuracy(y_true, y_pred, 1.25)


def delta2_metric(y_true, y_pred):
    """Delta2 accuracy: fraction of valid pixels with max(pred/true, true/pred) < 1.25 ** 2."""
    return _delta_accuracy(y_true, y_pred, 1.25 ** 2)


def delta3_metric(y_true, y_pred):
    """Delta3 accuracy: fraction of valid pixels with max(pred/true, true/pred) < 1.25 ** 3."""
    return _delta_accuracy(y_true, y_pred, 1.25 ** 3)
|
|
|
|
|
|
def compile(model, optimiser=None, loss=None, custom_metrics=None):
    """
    Compile a FastDepth model with its optimiser, loss and evaluation metrics.

    :param model: Model to compile (compiled in place)
    :param optimiser: Optimiser to use. None (default) creates a fresh keras.optimizers.SGD()
    :param loss: Loss function to use. None (default) creates a fresh keras.losses.MeanSquaredError()
    :param custom_metrics: Metrics to use instead of the defaults (RMSE, MSE, delta1/2/3).
        None (default) uses the defaults; an empty list compiles without metrics.
    """
    # Build defaults per call. Default instances created in the signature would be evaluated
    # once and shared by every model compiled here, and optimiser instances carry training state.
    if optimiser is None:
        optimiser = keras.optimizers.SGD()
    if loss is None:
        loss = keras.losses.MeanSquaredError()
    if custom_metrics is None:
        metrics = [keras.metrics.RootMeanSquaredError(),
                   keras.metrics.MeanSquaredError(),
                   delta1_metric,
                   delta2_metric,
                   delta3_metric]
    else:
        metrics = custom_metrics

    model.compile(optimizer=optimiser, loss=loss, metrics=metrics)
|
|
|
|
|
|
def train(existing_model=None, pretrained_weights='imagenet', epochs=4, save_file=None, dataset=None,
          optimiser=None, loss=None, custom_metrics=None):
    """
    Compile, train and save (if a save file is specified) a Fast Depth model.

    :param existing_model: Existing FastDepth model to train. None builds a new one with mobilenet_nnconv5.
        An existing model is recompiled before training.
    :param pretrained_weights: MobileNet weights used when existing_model is None. See keras.applications.MobileNet
        weights parameter for options here.
    :param epochs: Number of epochs to run for
    :param save_file: File/directory to save to after training. By default the model won't be saved
    :param dataset: Train dataset to use. None (default) will DOWNLOAD and use the tensorflow nyu_v2
        train split via load_nyu
    :param optimiser: Optimiser passed to compile. None keeps compile's default
    :param loss: Loss passed to compile. None keeps compile's default
    :param custom_metrics: Metrics passed to compile. None keeps compile's default metrics
    :return: The trained model
    """
    if existing_model is None:
        existing_model = mobilenet_nnconv5(pretrained_weights)

    # Forward only the overrides the caller supplied, so compile() keeps its own defaults.
    compile_overrides = {name: value
                         for name, value in (('optimiser', optimiser),
                                             ('loss', loss),
                                             ('custom_metrics', custom_metrics))
                         if value is not None}
    compile(existing_model, **compile_overrides)

    # `is None` rather than truthiness, so array-like or Sequence inputs are not misjudged.
    if dataset is None:
        dataset = load_nyu()
    existing_model.fit(dataset, epochs=epochs)

    if save_file:
        existing_model.save(save_file)
    return existing_model
|
|
|
|
|
|
def evaluate(compiled_model, dataset=None):
    """
    Evaluate the model with the loss and metrics it was compiled with (by default rmse, mse, delta1/2/3).

    :param compiled_model: Compiled, trained model to evaluate
    :param dataset: Batched dataset yielding (image, depth) pairs in the form produced by
        load_nyu_evaluate. That is, nyu_v2 {'image': image, 'depth': label} samples already
        passed through crop_and_resize, so the depth matches the model's output size.
        Defaults to Tensorflow nyu_v2 validation split dataset (https://www.tensorflow.org/datasets/catalog/nyu_depth_v2)
    :return: The value returned by keras Model.evaluate: the scalar loss, or a list holding the
        loss followed by each compiled metric
    """
    if dataset is None:
        dataset = load_nyu_evaluate()
    return compiled_model.evaluate(dataset, verbose=1)
|
|
|
|
|
|
def forward(model, image):
    """
    Propagate a single image or a batch of images through the model.

    The image(s) are resized to the model's input height/width with nearest-neighbour
    interpolation, the same interpolation crop_and_resize applies to training data.
    (The previous implementation passed the image to crop_and_resize, which expects a
    {'image', 'depth'} dict and returns an (image, depth) tuple, so it could not run.)

    :param model: Keras model with a fixed input height/width (e.g. from mobilenet_nnconv5 or load_model)
    :param image: Image tensor/array, either HWC (single image) or NHWC (batch)
    :return: Model output for the batch; a single image is returned with a batch dimension of 1
    """
    images = tf.convert_to_tensor(image)
    if images.shape.rank == 3:
        # Single HWC image: add the batch axis the model expects.
        images = tf.expand_dims(images, 0)
    height, width = model.input_shape[1:3]
    images = tf.image.resize(images, (height, width), method='nearest')
    return model(images)
|
|
|
|
|
|
def load_model(file):
    """
    Load a FastDepth model that was previously saved to disk.

    The delta1/2/3 metric functions are registered as custom objects, so models compiled
    with them can be deserialised.

    :param file: File/directory the model was saved to
    :return: The loaded keras model
    """
    delta_metrics = {
        'delta1_metric': delta1_metric,
        'delta2_metric': delta2_metric,
        'delta3_metric': delta3_metric,
    }
    return keras.models.load_model(file, custom_objects=delta_metrics)
|
|
|
|
|
|
def crop_and_resize(x):
    """
    Crop and resize a batch of nyu_v2 samples to the 224x224 model input size.

    Both the image and the depth are centre-cropped to the depth's height/width and then
    resized to 224x224 with nearest-neighbour interpolation.

    :param x: Dict with a batched 'image' tensor and a batched 'depth' tensor of rank 3
        (batch, height, width), as produced by batching the nyu_v2 dataset
    :return: Tuple (image, depth); depth gains a trailing channel axis of size 1
    """
    depth_shape = tf.shape(x['depth'])
    batch_size, crop_height, crop_width = depth_shape[0], depth_shape[1], depth_shape[2]

    def make_pipeline():
        # A fresh crop + resize pipeline for each tensor.
        return keras.Sequential([
            keras.layers.experimental.preprocessing.CenterCrop(crop_height, crop_width),
            keras.layers.experimental.preprocessing.Resizing(224, 224, interpolation='nearest'),
        ])

    image = make_pipeline()(x['image'])
    # Give depth an explicit channel axis (N, H, W) -> (N, H, W, 1) so it can go through the
    # same image layers.
    depth_nhwc = tf.reshape(x['depth'], [batch_size, crop_height, crop_width, 1])
    depth = make_pipeline()(depth_nhwc)
    return image, depth
|
|
|
|
|
|
def load_nyu(batch_size=8, download_dir='../nyu'):
    """
    Load the nyu_v2 dataset train split, ready to be passed to model.fit.

    tfds is told to use `download_dir` (../nyu by default, relative to the current working
    directory) as its download location.

    :param batch_size: Number of samples per batch, defaults to 8
    :param download_dir: Download directory passed to tfds download_and_prepare, defaults to '../nyu'
    :return: Shuffled, batched tf.data.Dataset of (image, depth) pairs produced by crop_and_resize
        (not a dataset builder)
    """
    builder = tfds.builder('nyu_depth_v2')
    builder.download_and_prepare(download_dir=download_dir)
    return (builder
            .as_dataset(split='train', shuffle_files=True)
            .shuffle(buffer_size=1024)
            .batch(batch_size)
            # crop_and_resize reshapes a batched (N, H, W) depth, so it must run after batching.
            .map(crop_and_resize))
|
|
|
|
|
|
def load_nyu_evaluate(batch_size=1, download_dir='../nyu'):
    """
    Load the nyu_v2 dataset validation split, ready to be passed to model.evaluate.

    tfds is told to use `download_dir` (../nyu by default, relative to the current working
    directory) as its download location.

    :param batch_size: Number of samples per batch, defaults to 1
    :param download_dir: Download directory passed to tfds download_and_prepare, defaults to '../nyu'
    :return: Batched tf.data.Dataset of (image, depth) pairs produced by crop_and_resize
        (not a dataset builder)
    """
    builder = tfds.builder('nyu_depth_v2')
    builder.download_and_prepare(download_dir=download_dir)
    # crop_and_resize reshapes a batched (N, H, W) depth, so it must run after batching.
    return builder.as_dataset(split='validation').batch(batch_size).map(crop_and_resize)
|
|
|
|
|
|
if __name__ == '__main__':
    # Build a FastDepth model with default arguments and print its layer summary.
    fast_depth = mobilenet_nnconv5()
    fast_depth.summary()
|