diff --git a/README.md b/README.md new file mode 100644 index 0000000..314fc95 --- /dev/null +++ b/README.md @@ -0,0 +1,86 @@ +# Fast Depth TF + +Tensorflow 2.0 Implementation of Fast Depth. + +Original Implementation in PyTorch is here: https://github.com/dwofk/fast-depth + +This code has been tested with Tensorflow 2.4.1, however any version of Tensorflow >2 should work + +The model has also been successfully optimised using the OpenVINO model optimiser + +## Basic Usage + +To train and evaluate the model on the nyu_v2 dataset, simply run: + +`python main.py` + +WARNING: This will download nyu_v2 which is ~100gb when including archived and extracted files, plus another 70gb for +generated examples. + +The following sample demonstrates creating a FastDepth model that can later be used for inference, training or +evaluation. + +```python +import fast_depth_functional as fd + +# No Pretrained weights +model = fd.mobilenet_nnconv5() + +# Imagenet weights +model = fd.mobilenet_nnconv5(weights='imagenet') + +# Load trained model from file +model = fd.load_model('my_fastdepth_model') +``` + +### Train + +Training with the NYU dataset is as simple as running the following: +WARNING: This will download ~30gb and extra ~70gb if you haven't downloaded it already. It also takes a long time to +prepare the examples (>1 hour) + +```python +import fast_depth_functional as fd + +model = fd.mobilenet_nnconv5(weights='imagenet') + +# Train then save the model as keras h5 format +fd.train(model, save_file='fast_depth') + +# A custom dataset can be passed in if required +fd.train(model, dataset=my_dataset) +``` + +### Evaluate + +Evaluation is similar to training. The nyu dataset validation split will be used by default, and if you trained as shown +above, the dataset will have already been downloaded. + +```python +import fast_depth_functional as fd + +model = fd.load_model('fast_depth') +fd.compile(model) +fd.evaluate(model) + +# A custom dataset for evaluation is supported +fd.evaluate(model, dataset=my_evaluation_dataset) +``` + +## Troubleshooting + +### Windows GPU Fix + +If you are using Windows and encounter an error opening cudnn (you should see CUDNN_STATUS_ALLOC_FAILED somewhere before +the error), first check you have correctly installed CUDA toolkit and cuDNN. If you have, then run the Windows GPU fix +included in this repo: + +```python +import fast_depth_functional as fd + +# Windows GPU Fix +fd.fix_windows_gpu() +``` + +More information about this error can be found here: +https://forums.developer.nvidia.com/t/could-not-create-cudnn-handle-cudnn-status-alloc-failed/108261 \ No newline at end of file diff --git a/fast_depth_functional.py b/fast_depth_functional.py index c085687..62d1d11 100644 --- a/fast_depth_functional.py +++ b/fast_depth_functional.py @@ -2,10 +2,22 @@ import tensorflow as tf import tensorflow.keras as keras import tensorflow_datasets as tfds +""" +Unofficial tensorflow keras implementation of FastDepth (mobilenet_nnconv5). +PyTorch (official) Fast Depth Implementation: https://github.com/dwofk/fast-depth -# Ripped from: https://forums.developer.nvidia.com/t/could-not-create-cudnn-handle-cudnn-status-alloc-failed/108261/4?u=mpivato4 +There's also an experimental version that does not use BatchNormalisation, as well as Parametric ReLU and bilinear +upsampling (mobilenet_nnconv5_no_bn) +""" + + +# Ripped from: +# https://forums.developer.nvidia.com/t/could-not-create-cudnn-handle-cudnn-status-alloc-failed/108261/4?u=mpivato4 # Seems to be an issue on windows so explicitly set gpu growth def fix_windows_gpu(): + """ + Fixes Windows GPU bug when attempting to allocate memory using cuDNN + """ gpus = tf.config.experimental.list_physical_devices('GPU') if gpus: try: @@ -19,11 +31,6 @@ def fix_windows_gpu(): print(e) -''' -Functional version of fastdepth model -''' - - def FDDepthwiseBlock(inputs, out_channels, block_id=1): @@ -35,14 +42,13 @@ def FDDepthwiseBlock(inputs, return keras.layers.ReLU(6., name='conv_pw_%d_relu' % block_id)(x) -def FDDepthwiseBlockNoBN(inputs, out_channels, block_id=1): - x = keras.layers.DepthwiseConv2D(5, padding='same')(inputs) - x = keras.layers.ReLU(6.)(x) - x = keras.layers.Conv2D(out_channels, 1, padding='same')(x) - return keras.layers.ReLU(6., name='conv_pw_%d_relu' % block_id)(x) - - -def make_mobilenet_nnconv5(weights=None, shape=(224, 224, 3)): +def mobilenet_nnconv5(weights=None, shape=(224, 224, 3)): + """ + Replication of the FastDepth model in Tensorflow, using the keras Functional API + :param weights: Pretrained weights for MobileNet, defaults to None + :param shape: Input shape of the image, defaults to (224, 224, 3) + :return: FastDepth keras Model + """ input = keras.layers.Input(shape=shape) mobilenet = keras.applications.MobileNet(input_tensor=input, include_top=False, weights=weights) for layer in mobilenet.layers: @@ -50,21 +56,19 @@ def make_mobilenet_nnconv5(weights=None, shape=(224, 224, 3)): # Fast depth decoder x = FDDepthwiseBlock(mobilenet.output, 512, block_id=14) - # TODO: Bilinear interpolation - # x = keras.layers.experimental.preprocessing.Resizing(14, 14, interpolation='bilinear') # Nearest neighbour interpolation, used by fast depth paper - x = keras.layers.experimental.preprocessing.Resizing(14, 14, interpolation='nearest')(x) + x = keras.layers.UpSampling2D()(x) x = FDDepthwiseBlock(x, 256, block_id=15) - x = keras.layers.experimental.preprocessing.Resizing(28, 28, interpolation='nearest')(x) + x = keras.layers.UpSampling2D()(x) x = keras.layers.Add()([x, mobilenet.get_layer(name="conv_pw_5_relu").output]) x = FDDepthwiseBlock(x, 128, block_id=16) - x = keras.layers.experimental.preprocessing.Resizing(56, 56, interpolation='nearest')(x) + x = keras.layers.UpSampling2D()(x) x = keras.layers.Add()([x, mobilenet.get_layer(name="conv_pw_3_relu").output]) x = FDDepthwiseBlock(x, 64, block_id=17) - x = keras.layers.experimental.preprocessing.Resizing(112, 112, interpolation='nearest')(x) + x = keras.layers.UpSampling2D()(x) x = keras.layers.Add()([x, mobilenet.get_layer(name="conv_pw_1_relu").output]) x = FDDepthwiseBlock(x, 32, block_id=18) - x = keras.layers.experimental.preprocessing.Resizing(224, 224, interpolation='nearest')(x) + x = keras.layers.UpSampling2D()(x) x = keras.layers.Conv2D(1, 1, padding='same')(x) x = keras.layers.BatchNormalization()(x) @@ -72,7 +76,25 @@ def make_mobilenet_nnconv5(weights=None, shape=(224, 224, 3)): return keras.Model(inputs=input, outputs=x, name="fast_depth") -def make_mobilenet_nnconv5_no_bn(weights=None, shape=(224, 224, 3)): +#### Experimental #### +def FDDepthwiseBlockNoBN(inputs, out_channels, block_id=1): + x = keras.layers.DepthwiseConv2D(5, padding='same')(inputs) + x = keras.layers.PReLU()(x) + x = keras.layers.Conv2D(out_channels, 1, padding='same')(x) + return keras.layers.PReLU(name='conv_pw_%d_relu' % block_id)(x) + + +def mobilenet_nnconv5_no_bn(weights=None, shape=(224, 224, 3)): + """ + Experimental version of the FastDepth model. + This version has the following changes: + - Bilinear upsampling is used rather than nearest neighbour + - No BatchNormalisation in decoder + - Parametric ReLU in Decoder rather than ReLU + :param weights: Pretrained weights for MobileNet, defaults to None + :param shape: Input shape of the image, defaults to (224, 224, 3) + :return: Experimental FastDepth keras Model + """ input = keras.layers.Input(shape=shape) mobilenet = keras.applications.MobileNet(input_tensor=input, include_top=False, weights=weights) for layer in mobilenet.layers: @@ -80,26 +102,24 @@ def make_mobilenet_nnconv5_no_bn(weights=None, shape=(224, 224, 3)): # Fast depth decoder x = FDDepthwiseBlockNoBN(mobilenet.output, 512, block_id=14) - # Nearest neighbour interpolation, used by fast depth paper - x = keras.layers.experimental.preprocessing.Resizing(14, 14, interpolation='bilinear')(x) + x = keras.layers.UpSampling2D(interpolation='bilinear')(x) x = FDDepthwiseBlockNoBN(x, 256, block_id=15) - x = keras.layers.experimental.preprocessing.Resizing(28, 28, interpolation='bilinear')(x) + x = keras.layers.UpSampling2D(interpolation='bilinear')(x) x = keras.layers.Add()([x, mobilenet.get_layer(name="conv_pw_5_relu").output]) x = FDDepthwiseBlockNoBN(x, 128, block_id=16) - x = keras.layers.experimental.preprocessing.Resizing(56, 56, interpolation='bilinear')(x) + x = keras.layers.UpSampling2D(interpolation='bilinear')(x) x = keras.layers.Add()([x, mobilenet.get_layer(name="conv_pw_3_relu").output]) x = FDDepthwiseBlockNoBN(x, 64, block_id=17) - x = keras.layers.experimental.preprocessing.Resizing(112, 112, interpolation='bilinear')(x) + x = keras.layers.UpSampling2D(interpolation='bilinear')(x) x = keras.layers.Add()([x, mobilenet.get_layer(name="conv_pw_1_relu").output]) x = FDDepthwiseBlockNoBN(x, 32, block_id=18) - x = keras.layers.experimental.preprocessing.Resizing(224, 224, interpolation='bilinear')(x) + x = keras.layers.UpSampling2D(interpolation='bilinear')(x) x = keras.layers.Conv2D(1, 1, padding='same')(x) - x = keras.layers.ReLU(6.)(x) - return keras.Model(inputs=input, outputs=x, name="fast_depth") + x = keras.layers.PReLU()(x) + return keras.Model(inputs=input, outputs=x, name="fast_depth_experimental") -# TODO: Fix these, float doesn't work same as pytorch def delta1_metric(y_true, y_pred): maxRatio = tf.maximum(y_pred / y_true, y_true / y_pred) return tf.nn.moments(tf.cast(maxRatio < tf.convert_to_tensor(1.25), tf.float32), axes=None)[0] @@ -116,6 +136,10 @@ def delta3_metric(y_true, y_pred): def compile(model): + """ + Compile FastDepth model with relevant metrics + :param model: Model to compile + """ # TODO: Learning rate (exponential decay) model.compile(optimizer=keras.optimizers.SGD(momentum=0.9), loss=keras.losses.MeanSquaredError(), @@ -127,8 +151,17 @@ def compile(model): def train(existing_model=None, pretrained_weights='imagenet', epochs=4, save_file=None, dataset=None): + """ + Compile, train and save (if a save file is specified) a Fast Depth model. + :param existing_model: Existing FastDepth model to train. None will create + :param pretrained_weights: Weights to use if existing_model is not specified. See keras.applications.MobileNet + weights parameter for options here. + :param epochs: Number of epochs to run for + :param save_file: File/directory to save to after training. By default the model won't be saved + :param dataset: Train dataset to use. By default will DOWNLOAD and use tensorflow nyu_v2 dataset + """ if not existing_model: - existing_model = make_mobilenet_nnconv5(pretrained_weights) + existing_model = mobilenet_nnconv5(pretrained_weights) compile(existing_model) if not dataset: dataset = load_nyu() @@ -162,6 +195,11 @@ def forward(model, image): def load_model(file): + """ + Load previously trained FastDepth model from disk. Will include relevant metrics (custom objects) + :param file: File/directory to load the model from + :return: + """ return keras.models.load_model(file, custom_objects={'delta1_metric': delta1_metric, 'delta2_metric': delta2_metric, 'delta3_metric': delta3_metric}) @@ -181,6 +219,10 @@ def crop_and_resize(x): def load_nyu(): + """ + Load the nyu_v2 dataset train split. Will be downloaded to ../nyu + :return: nyu_v2 dataset builder + """ builder = tfds.builder('nyu_depth_v2') builder.download_and_prepare(download_dir='../nyu') return builder \ @@ -191,6 +233,10 @@ def load_nyu(): def load_nyu_evaluate(): + """ + Load the nyu_v2 dataset validation split. Will be downloaded to ../nyu + :return: nyu_v2 dataset builder + """ builder = tfds.builder('nyu_depth_v2') builder.download_and_prepare(download_dir='../nyu') return builder.as_dataset(split='validation').batch(1).map(lambda x: crop_and_resize(x)) diff --git a/main.py b/main.py index 903e485..d5a6ff7 100644 --- a/main.py +++ b/main.py @@ -2,6 +2,7 @@ import fast_depth_functional as fd if __name__ == '__main__': fd.fix_windows_gpu() - model = fd.load_model('fast_depth_nyu_v2_224_224_3_e1') + model = fd.mobilenet_nnconv5_no_bn(weights='imagenet') fd.compile(model) + fd.train(existing_model=model, save_file='../fast-depth-experimental') fd.evaluate(model)