Add documentation and README, use Upsampling2D rather than image Resizing layer
This commit is contained in:
86
README.md
Normal file
86
README.md
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
# Fast Depth TF
|
||||||
|
|
||||||
|
Tensorflow 2.0 Implementation of Fast Depth.
|
||||||
|
|
||||||
|
Original Implementation in PyTorch is here: https://github.com/dwofk/fast-depth
|
||||||
|
|
||||||
|
This code has been tested with TensorFlow 2.4.1; however, any version of TensorFlow from 2.0 onwards should work.
|
||||||
|
|
||||||
|
The model has also been successfully optimised using the OpenVINO model optimiser
|
||||||
|
|
||||||
|
## Basic Usage
|
||||||
|
|
||||||
|
To train and evaluate the model on the nyu_v2 dataset, simply run:
|
||||||
|
|
||||||
|
`python main.py`
|
||||||
|
|
||||||
|
WARNING: This will download nyu_v2, which takes up ~100 GB including archived and extracted files, plus another ~70 GB for
|
||||||
|
generated examples.
|
||||||
|
|
||||||
|
The following sample demonstrates creating a FastDepth model that can later be used for inference, training or
|
||||||
|
evaluation.
|
||||||
|
|
||||||
|
```python
|
||||||
|
import fast_depth_functional as fd
|
||||||
|
|
||||||
|
# No Pretrained weights
|
||||||
|
model = fd.mobilenet_nnconv5()
|
||||||
|
|
||||||
|
# Imagenet weights
|
||||||
|
model = fd.mobilenet_nnconv5(weights='imagenet')
|
||||||
|
|
||||||
|
# Load trained model from file
|
||||||
|
model = fd.load_model('my_fastdepth_model')
|
||||||
|
```
|
||||||
|
|
||||||
|
### Train
|
||||||
|
|
||||||
|
Training with the NYU dataset is as simple as running the following:
|
||||||
|
WARNING: This will download ~30 GB, plus an extra ~70 GB, if you haven't already downloaded the dataset. It also takes a long time to
|
||||||
|
prepare the examples (>1 hour)
|
||||||
|
|
||||||
|
```python
|
||||||
|
import fast_depth_functional as fd
|
||||||
|
|
||||||
|
model = fd.mobilenet_nnconv5(weights='imagenet')
|
||||||
|
|
||||||
|
# Train then save the model as keras h5 format
|
||||||
|
fd.train(model, save_file='fast_depth')
|
||||||
|
|
||||||
|
# A custom dataset can be passed in if required
|
||||||
|
fd.train(model, dataset=my_dataset)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Evaluate
|
||||||
|
|
||||||
|
Evaluation is similar to training. The nyu dataset validation split will be used by default, and if you trained as shown
|
||||||
|
above, the dataset will have already been downloaded.
|
||||||
|
|
||||||
|
```python
|
||||||
|
import fast_depth_functional as fd
|
||||||
|
|
||||||
|
model = fd.load_model('fast_depth')
|
||||||
|
fd.compile(model)
|
||||||
|
fd.evaluate(model)
|
||||||
|
|
||||||
|
# A custom dataset for evaluation is supported
|
||||||
|
fd.evaluate(model, dataset=my_evaluation_dataset)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Troubleshooting
|
||||||
|
|
||||||
|
### Windows GPU Fix
|
||||||
|
|
||||||
|
If you are using Windows and encounter an error opening cudnn (you should see CUDNN_STATUS_ALLOC_FAILED somewhere before
|
||||||
|
the error), first check you have correctly installed CUDA toolkit and cuDNN. If you have, then run the Windows GPU fix
|
||||||
|
included in this repo:
|
||||||
|
|
||||||
|
```python
|
||||||
|
import fast_depth_functional as fd
|
||||||
|
|
||||||
|
# Windows GPU Fix
|
||||||
|
fd.fix_windows_gpu()
|
||||||
|
```
|
||||||
|
|
||||||
|
More information about this error can be found here:
|
||||||
|
https://forums.developer.nvidia.com/t/could-not-create-cudnn-handle-cudnn-status-alloc-failed/108261
|
||||||
@@ -2,10 +2,22 @@ import tensorflow as tf
|
|||||||
import tensorflow.keras as keras
|
import tensorflow.keras as keras
|
||||||
import tensorflow_datasets as tfds
|
import tensorflow_datasets as tfds
|
||||||
|
|
||||||
|
"""
|
||||||
|
Unofficial tensorflow keras implementation of FastDepth (mobilenet_nnconv5).
|
||||||
|
PyTorch (official) Fast Depth Implementation: https://github.com/dwofk/fast-depth
|
||||||
|
|
||||||
# Ripped from: https://forums.developer.nvidia.com/t/could-not-create-cudnn-handle-cudnn-status-alloc-failed/108261/4?u=mpivato4
|
There's also an experimental version that drops BatchNormalisation in the decoder and uses Parametric ReLU and bilinear
|
||||||
|
upsampling (mobilenet_nnconv5_no_bn)
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
# Ripped from:
|
||||||
|
# https://forums.developer.nvidia.com/t/could-not-create-cudnn-handle-cudnn-status-alloc-failed/108261/4?u=mpivato4
|
||||||
# Seems to be an issue on windows so explicitly set gpu growth
|
# Seems to be an issue on windows so explicitly set gpu growth
|
||||||
def fix_windows_gpu():
|
def fix_windows_gpu():
|
||||||
|
"""
|
||||||
|
Fixes Windows GPU bug when attempting to allocate memory using cuDNN
|
||||||
|
"""
|
||||||
gpus = tf.config.experimental.list_physical_devices('GPU')
|
gpus = tf.config.experimental.list_physical_devices('GPU')
|
||||||
if gpus:
|
if gpus:
|
||||||
try:
|
try:
|
||||||
@@ -19,11 +31,6 @@ def fix_windows_gpu():
|
|||||||
print(e)
|
print(e)
|
||||||
|
|
||||||
|
|
||||||
'''
|
|
||||||
Functional version of fastdepth model
|
|
||||||
'''
|
|
||||||
|
|
||||||
|
|
||||||
def FDDepthwiseBlock(inputs,
|
def FDDepthwiseBlock(inputs,
|
||||||
out_channels,
|
out_channels,
|
||||||
block_id=1):
|
block_id=1):
|
||||||
@@ -35,14 +42,13 @@ def FDDepthwiseBlock(inputs,
|
|||||||
return keras.layers.ReLU(6., name='conv_pw_%d_relu' % block_id)(x)
|
return keras.layers.ReLU(6., name='conv_pw_%d_relu' % block_id)(x)
|
||||||
|
|
||||||
|
|
||||||
def FDDepthwiseBlockNoBN(inputs, out_channels, block_id=1):
|
def mobilenet_nnconv5(weights=None, shape=(224, 224, 3)):
|
||||||
x = keras.layers.DepthwiseConv2D(5, padding='same')(inputs)
|
"""
|
||||||
x = keras.layers.ReLU(6.)(x)
|
Replication of the FastDepth model in Tensorflow, using the keras Functional API
|
||||||
x = keras.layers.Conv2D(out_channels, 1, padding='same')(x)
|
:param weights: Pretrained weights for MobileNet, defaults to None
|
||||||
return keras.layers.ReLU(6., name='conv_pw_%d_relu' % block_id)(x)
|
:param shape: Input shape of the image, defaults to (224, 224, 3)
|
||||||
|
:return: FastDepth keras Model
|
||||||
|
"""
|
||||||
def make_mobilenet_nnconv5(weights=None, shape=(224, 224, 3)):
|
|
||||||
input = keras.layers.Input(shape=shape)
|
input = keras.layers.Input(shape=shape)
|
||||||
mobilenet = keras.applications.MobileNet(input_tensor=input, include_top=False, weights=weights)
|
mobilenet = keras.applications.MobileNet(input_tensor=input, include_top=False, weights=weights)
|
||||||
for layer in mobilenet.layers:
|
for layer in mobilenet.layers:
|
||||||
@@ -50,21 +56,19 @@ def make_mobilenet_nnconv5(weights=None, shape=(224, 224, 3)):
|
|||||||
|
|
||||||
# Fast depth decoder
|
# Fast depth decoder
|
||||||
x = FDDepthwiseBlock(mobilenet.output, 512, block_id=14)
|
x = FDDepthwiseBlock(mobilenet.output, 512, block_id=14)
|
||||||
# TODO: Bilinear interpolation
|
|
||||||
# x = keras.layers.experimental.preprocessing.Resizing(14, 14, interpolation='bilinear')
|
|
||||||
# Nearest neighbour interpolation, used by fast depth paper
|
# Nearest neighbour interpolation, used by fast depth paper
|
||||||
x = keras.layers.experimental.preprocessing.Resizing(14, 14, interpolation='nearest')(x)
|
x = keras.layers.UpSampling2D()(x)
|
||||||
x = FDDepthwiseBlock(x, 256, block_id=15)
|
x = FDDepthwiseBlock(x, 256, block_id=15)
|
||||||
x = keras.layers.experimental.preprocessing.Resizing(28, 28, interpolation='nearest')(x)
|
x = keras.layers.UpSampling2D()(x)
|
||||||
x = keras.layers.Add()([x, mobilenet.get_layer(name="conv_pw_5_relu").output])
|
x = keras.layers.Add()([x, mobilenet.get_layer(name="conv_pw_5_relu").output])
|
||||||
x = FDDepthwiseBlock(x, 128, block_id=16)
|
x = FDDepthwiseBlock(x, 128, block_id=16)
|
||||||
x = keras.layers.experimental.preprocessing.Resizing(56, 56, interpolation='nearest')(x)
|
x = keras.layers.UpSampling2D()(x)
|
||||||
x = keras.layers.Add()([x, mobilenet.get_layer(name="conv_pw_3_relu").output])
|
x = keras.layers.Add()([x, mobilenet.get_layer(name="conv_pw_3_relu").output])
|
||||||
x = FDDepthwiseBlock(x, 64, block_id=17)
|
x = FDDepthwiseBlock(x, 64, block_id=17)
|
||||||
x = keras.layers.experimental.preprocessing.Resizing(112, 112, interpolation='nearest')(x)
|
x = keras.layers.UpSampling2D()(x)
|
||||||
x = keras.layers.Add()([x, mobilenet.get_layer(name="conv_pw_1_relu").output])
|
x = keras.layers.Add()([x, mobilenet.get_layer(name="conv_pw_1_relu").output])
|
||||||
x = FDDepthwiseBlock(x, 32, block_id=18)
|
x = FDDepthwiseBlock(x, 32, block_id=18)
|
||||||
x = keras.layers.experimental.preprocessing.Resizing(224, 224, interpolation='nearest')(x)
|
x = keras.layers.UpSampling2D()(x)
|
||||||
|
|
||||||
x = keras.layers.Conv2D(1, 1, padding='same')(x)
|
x = keras.layers.Conv2D(1, 1, padding='same')(x)
|
||||||
x = keras.layers.BatchNormalization()(x)
|
x = keras.layers.BatchNormalization()(x)
|
||||||
@@ -72,7 +76,25 @@ def make_mobilenet_nnconv5(weights=None, shape=(224, 224, 3)):
|
|||||||
return keras.Model(inputs=input, outputs=x, name="fast_depth")
|
return keras.Model(inputs=input, outputs=x, name="fast_depth")
|
||||||
|
|
||||||
|
|
||||||
def make_mobilenet_nnconv5_no_bn(weights=None, shape=(224, 224, 3)):
|
#### Experimental ####
|
||||||
|
def FDDepthwiseBlockNoBN(inputs, out_channels, block_id=1):
|
||||||
|
x = keras.layers.DepthwiseConv2D(5, padding='same')(inputs)
|
||||||
|
x = keras.layers.PReLU()(x)
|
||||||
|
x = keras.layers.Conv2D(out_channels, 1, padding='same')(x)
|
||||||
|
return keras.layers.PReLU(name='conv_pw_%d_relu' % block_id)(x)
|
||||||
|
|
||||||
|
|
||||||
|
def mobilenet_nnconv5_no_bn(weights=None, shape=(224, 224, 3)):
|
||||||
|
"""
|
||||||
|
Experimental version of the FastDepth model.
|
||||||
|
This version has the following changes:
|
||||||
|
- Bilinear upsampling is used rather than nearest neighbour
|
||||||
|
- No BatchNormalisation in decoder
|
||||||
|
- Parametric ReLU in Decoder rather than ReLU
|
||||||
|
:param weights: Pretrained weights for MobileNet, defaults to None
|
||||||
|
:param shape: Input shape of the image, defaults to (224, 224, 3)
|
||||||
|
:return: Experimental FastDepth keras Model
|
||||||
|
"""
|
||||||
input = keras.layers.Input(shape=shape)
|
input = keras.layers.Input(shape=shape)
|
||||||
mobilenet = keras.applications.MobileNet(input_tensor=input, include_top=False, weights=weights)
|
mobilenet = keras.applications.MobileNet(input_tensor=input, include_top=False, weights=weights)
|
||||||
for layer in mobilenet.layers:
|
for layer in mobilenet.layers:
|
||||||
@@ -80,26 +102,24 @@ def make_mobilenet_nnconv5_no_bn(weights=None, shape=(224, 224, 3)):
|
|||||||
|
|
||||||
# Fast depth decoder
|
# Fast depth decoder
|
||||||
x = FDDepthwiseBlockNoBN(mobilenet.output, 512, block_id=14)
|
x = FDDepthwiseBlockNoBN(mobilenet.output, 512, block_id=14)
|
||||||
# Nearest neighbour interpolation, used by fast depth paper
|
x = keras.layers.UpSampling2D(interpolation='bilinear')(x)
|
||||||
x = keras.layers.experimental.preprocessing.Resizing(14, 14, interpolation='bilinear')(x)
|
|
||||||
x = FDDepthwiseBlockNoBN(x, 256, block_id=15)
|
x = FDDepthwiseBlockNoBN(x, 256, block_id=15)
|
||||||
x = keras.layers.experimental.preprocessing.Resizing(28, 28, interpolation='bilinear')(x)
|
x = keras.layers.UpSampling2D(interpolation='bilinear')(x)
|
||||||
x = keras.layers.Add()([x, mobilenet.get_layer(name="conv_pw_5_relu").output])
|
x = keras.layers.Add()([x, mobilenet.get_layer(name="conv_pw_5_relu").output])
|
||||||
x = FDDepthwiseBlockNoBN(x, 128, block_id=16)
|
x = FDDepthwiseBlockNoBN(x, 128, block_id=16)
|
||||||
x = keras.layers.experimental.preprocessing.Resizing(56, 56, interpolation='bilinear')(x)
|
x = keras.layers.UpSampling2D(interpolation='bilinear')(x)
|
||||||
x = keras.layers.Add()([x, mobilenet.get_layer(name="conv_pw_3_relu").output])
|
x = keras.layers.Add()([x, mobilenet.get_layer(name="conv_pw_3_relu").output])
|
||||||
x = FDDepthwiseBlockNoBN(x, 64, block_id=17)
|
x = FDDepthwiseBlockNoBN(x, 64, block_id=17)
|
||||||
x = keras.layers.experimental.preprocessing.Resizing(112, 112, interpolation='bilinear')(x)
|
x = keras.layers.UpSampling2D(interpolation='bilinear')(x)
|
||||||
x = keras.layers.Add()([x, mobilenet.get_layer(name="conv_pw_1_relu").output])
|
x = keras.layers.Add()([x, mobilenet.get_layer(name="conv_pw_1_relu").output])
|
||||||
x = FDDepthwiseBlockNoBN(x, 32, block_id=18)
|
x = FDDepthwiseBlockNoBN(x, 32, block_id=18)
|
||||||
x = keras.layers.experimental.preprocessing.Resizing(224, 224, interpolation='bilinear')(x)
|
x = keras.layers.UpSampling2D(interpolation='bilinear')(x)
|
||||||
|
|
||||||
x = keras.layers.Conv2D(1, 1, padding='same')(x)
|
x = keras.layers.Conv2D(1, 1, padding='same')(x)
|
||||||
x = keras.layers.ReLU(6.)(x)
|
x = keras.layers.PReLU()(x)
|
||||||
return keras.Model(inputs=input, outputs=x, name="fast_depth")
|
return keras.Model(inputs=input, outputs=x, name="fast_depth_experimental")
|
||||||
|
|
||||||
|
|
||||||
# TODO: Fix these, float doesn't work same as pytorch
|
|
||||||
def delta1_metric(y_true, y_pred):
|
def delta1_metric(y_true, y_pred):
|
||||||
maxRatio = tf.maximum(y_pred / y_true, y_true / y_pred)
|
maxRatio = tf.maximum(y_pred / y_true, y_true / y_pred)
|
||||||
return tf.nn.moments(tf.cast(maxRatio < tf.convert_to_tensor(1.25), tf.float32), axes=None)[0]
|
return tf.nn.moments(tf.cast(maxRatio < tf.convert_to_tensor(1.25), tf.float32), axes=None)[0]
|
||||||
@@ -116,6 +136,10 @@ def delta3_metric(y_true, y_pred):
|
|||||||
|
|
||||||
|
|
||||||
def compile(model):
|
def compile(model):
|
||||||
|
"""
|
||||||
|
Compile FastDepth model with relevant metrics
|
||||||
|
:param model: Model to compile
|
||||||
|
"""
|
||||||
# TODO: Learning rate (exponential decay)
|
# TODO: Learning rate (exponential decay)
|
||||||
model.compile(optimizer=keras.optimizers.SGD(momentum=0.9),
|
model.compile(optimizer=keras.optimizers.SGD(momentum=0.9),
|
||||||
loss=keras.losses.MeanSquaredError(),
|
loss=keras.losses.MeanSquaredError(),
|
||||||
@@ -127,8 +151,17 @@ def compile(model):
|
|||||||
|
|
||||||
|
|
||||||
def train(existing_model=None, pretrained_weights='imagenet', epochs=4, save_file=None, dataset=None):
|
def train(existing_model=None, pretrained_weights='imagenet', epochs=4, save_file=None, dataset=None):
|
||||||
|
"""
|
||||||
|
Compile, train and save (if a save file is specified) a Fast Depth model.
|
||||||
|
:param existing_model: Existing FastDepth model to train. If not provided, a new mobilenet_nnconv5 model is created
|
||||||
|
:param pretrained_weights: Weights to use if existing_model is not specified. See keras.applications.MobileNet
|
||||||
|
weights parameter for options here.
|
||||||
|
:param epochs: Number of epochs to run for
|
||||||
|
:param save_file: File/directory to save to after training. By default the model won't be saved
|
||||||
|
:param dataset: Train dataset to use. By default will DOWNLOAD and use tensorflow nyu_v2 dataset
|
||||||
|
"""
|
||||||
if not existing_model:
|
if not existing_model:
|
||||||
existing_model = make_mobilenet_nnconv5(pretrained_weights)
|
existing_model = mobilenet_nnconv5(pretrained_weights)
|
||||||
compile(existing_model)
|
compile(existing_model)
|
||||||
if not dataset:
|
if not dataset:
|
||||||
dataset = load_nyu()
|
dataset = load_nyu()
|
||||||
@@ -162,6 +195,11 @@ def forward(model, image):
|
|||||||
|
|
||||||
|
|
||||||
def load_model(file):
|
def load_model(file):
|
||||||
|
"""
|
||||||
|
Load previously trained FastDepth model from disk. Will include relevant metrics (custom objects)
|
||||||
|
:param file: File/directory to load the model from
|
||||||
|
:return: The model returned by keras.models.load_model
|
||||||
|
"""
|
||||||
return keras.models.load_model(file, custom_objects={'delta1_metric': delta1_metric,
|
return keras.models.load_model(file, custom_objects={'delta1_metric': delta1_metric,
|
||||||
'delta2_metric': delta2_metric,
|
'delta2_metric': delta2_metric,
|
||||||
'delta3_metric': delta3_metric})
|
'delta3_metric': delta3_metric})
|
||||||
@@ -181,6 +219,10 @@ def crop_and_resize(x):
|
|||||||
|
|
||||||
|
|
||||||
def load_nyu():
|
def load_nyu():
|
||||||
|
"""
|
||||||
|
Load the nyu_v2 dataset train split. Will be downloaded to ../nyu
|
||||||
|
:return: nyu_v2 dataset builder
|
||||||
|
"""
|
||||||
builder = tfds.builder('nyu_depth_v2')
|
builder = tfds.builder('nyu_depth_v2')
|
||||||
builder.download_and_prepare(download_dir='../nyu')
|
builder.download_and_prepare(download_dir='../nyu')
|
||||||
return builder \
|
return builder \
|
||||||
@@ -191,6 +233,10 @@ def load_nyu():
|
|||||||
|
|
||||||
|
|
||||||
def load_nyu_evaluate():
|
def load_nyu_evaluate():
|
||||||
|
"""
|
||||||
|
Load the nyu_v2 dataset validation split. Will be downloaded to ../nyu
|
||||||
|
:return: nyu_v2 validation dataset, in batches of 1, mapped through crop_and_resize
|
||||||
|
"""
|
||||||
builder = tfds.builder('nyu_depth_v2')
|
builder = tfds.builder('nyu_depth_v2')
|
||||||
builder.download_and_prepare(download_dir='../nyu')
|
builder.download_and_prepare(download_dir='../nyu')
|
||||||
return builder.as_dataset(split='validation').batch(1).map(lambda x: crop_and_resize(x))
|
return builder.as_dataset(split='validation').batch(1).map(lambda x: crop_and_resize(x))
|
||||||
|
|||||||
3
main.py
3
main.py
@@ -2,6 +2,7 @@ import fast_depth_functional as fd
|
|||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
fd.fix_windows_gpu()
|
fd.fix_windows_gpu()
|
||||||
model = fd.load_model('fast_depth_nyu_v2_224_224_3_e1')
|
model = fd.mobilenet_nnconv5_no_bn(weights='imagenet')
|
||||||
fd.compile(model)
|
fd.compile(model)
|
||||||
|
fd.train(existing_model=model, save_file='../fast-depth-experimental')
|
||||||
fd.evaluate(model)
|
fd.evaluate(model)
|
||||||
|
|||||||
Reference in New Issue
Block a user