Add metrics, prepare for training
This commit is contained in:
@@ -1,7 +1,8 @@
|
|||||||
|
import tensorflow as tf
|
||||||
import tensorflow.keras as keras
|
import tensorflow.keras as keras
|
||||||
|
|
||||||
'''
|
'''
|
||||||
Functional version of fastdepth model. Note that this doesn't work at the moment
|
Functional version of fastdepth model
|
||||||
'''
|
'''
|
||||||
|
|
||||||
|
|
||||||
@@ -16,11 +17,11 @@ def FDDepthwiseBlock(inputs,
|
|||||||
return keras.layers.ReLU(6., name='conv_pw_%d_relu' % block_id)(x)
|
return keras.layers.ReLU(6., name='conv_pw_%d_relu' % block_id)(x)
|
||||||
|
|
||||||
|
|
||||||
def make_fastdepth_functional():
|
def make_fastdepth_functional(weights=None):
|
||||||
# This doesn't work, at least right now...
|
# This doesn't work, at least right now...
|
||||||
input = keras.layers.Input(shape=(224, 224, 3))
|
input = keras.layers.Input(shape=(224, 224, 3))
|
||||||
x = input
|
x = input
|
||||||
mobilenet = keras.applications.MobileNet(include_top=False, weights=None)
|
mobilenet = keras.applications.MobileNet(include_top=False, weights=weights)
|
||||||
for layer in mobilenet.layers:
|
for layer in mobilenet.layers:
|
||||||
x = layer(x)
|
x = layer(x)
|
||||||
if layer.name == 'conv_pw_5_relu':
|
if layer.name == 'conv_pw_5_relu':
|
||||||
@@ -30,6 +31,7 @@ def make_fastdepth_functional():
|
|||||||
elif layer.name == 'conv_pw_1_relu':
|
elif layer.name == 'conv_pw_1_relu':
|
||||||
conv1 = x
|
conv1 = x
|
||||||
|
|
||||||
|
# Fast depth decoder
|
||||||
x = FDDepthwiseBlock(x, 512, block_id=14)
|
x = FDDepthwiseBlock(x, 512, block_id=14)
|
||||||
# TODO: Bilinear interpolation
|
# TODO: Bilinear interpolation
|
||||||
# x = keras.layers.experimental.preprocessing.Resizing(14, 14)
|
# x = keras.layers.experimental.preprocessing.Resizing(14, 14)
|
||||||
@@ -53,5 +55,43 @@ def make_fastdepth_functional():
|
|||||||
return keras.Model(inputs=input, outputs=x, name="fast_depth")
|
return keras.Model(inputs=input, outputs=x, name="fast_depth")
|
||||||
|
|
||||||
|
|
||||||
|
def delta1_metric(y_true, y_pred):
|
||||||
|
maxRatio = tf.maximum(y_pred / y_true, y_true / y_pred)
|
||||||
|
return float((maxRatio < 1.25).float().mean())
|
||||||
|
|
||||||
|
|
||||||
|
def delta2_metric(y_true, y_pred):
|
||||||
|
maxRatio = tf.maximum(y_pred / y_true, y_true / y_pred)
|
||||||
|
return float((maxRatio < 1.2 ** 25).float().mean())
|
||||||
|
|
||||||
|
|
||||||
|
def delta3_metric(y_true, y_pred):
|
||||||
|
maxRatio = tf.maximum(y_pred / y_true, y_true / y_pred)
|
||||||
|
return float((maxRatio < 1.25 ** 3).float().mean())
|
||||||
|
|
||||||
|
|
||||||
|
def fastdepth_for_training():
|
||||||
|
# Pretrained mobilenet on imagenet dataset
|
||||||
|
model = make_fastdepth_functional('imagenet')
|
||||||
|
return model.compile(optimizer=keras.optimizers.SGD(momentum=0.9),
|
||||||
|
loss=keras.losses.MSE(),
|
||||||
|
metrics=[keras.metrics.RootMeanSquaredError(),
|
||||||
|
keras.metrics.MeanSquaredError(),
|
||||||
|
delta1_metric,
|
||||||
|
delta2_metric,
|
||||||
|
delta3_metric])
|
||||||
|
|
||||||
|
|
||||||
|
def train_compiled_model(compiled_model, dataset):
|
||||||
|
"""
|
||||||
|
|
||||||
|
:param compiled_model: Compiled model to train on
|
||||||
|
:param dataset: Dataset to train on (must be compatible with model
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
# TODO: Use tf nyu_v2 dataset to train.
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
make_fastdepth_functional().summary()
|
make_fastdepth_functional().summary()
|
||||||
|
|||||||
8
main.py
8
main.py
@@ -7,13 +7,5 @@ def load_nyu():
|
|||||||
return builder.as_dataset(split='train', shuffle_files=True)
|
return builder.as_dataset(split='train', shuffle_files=True)
|
||||||
|
|
||||||
|
|
||||||
def print_hi(name):
|
|
||||||
# Use a breakpoint in the code line below to debug your script.
|
|
||||||
print(f'Hi, {name}') # Press Ctrl+F8 to toggle the breakpoint.
|
|
||||||
|
|
||||||
|
|
||||||
# Press the green button in the gutter to run the script.
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
load_nyu()
|
load_nyu()
|
||||||
|
|
||||||
# See PyCharm help at https://www.jetbrains.com/help/pycharm/
|
|
||||||
|
|||||||
Reference in New Issue
Block a user