Add compiling packnet model, refactor modules to not duplicate loaders and trainers
This commit is contained in:
@@ -1,9 +1,8 @@
|
||||
import tensorflow as tf
|
||||
import tensorflow.keras as keras
|
||||
import tensorflow_datasets as tfds
|
||||
from load import load_nyu, load_nyu_evaluate
|
||||
|
||||
from load import load_nyu_evaluate
|
||||
from metric import *
|
||||
from util import crop_and_resize
|
||||
# Needed for the kitti dataset, don't delete
|
||||
|
||||
"""
|
||||
Unofficial tensorflow keras implementation of FastDepth (mobilenet_nnconv5).
|
||||
@@ -76,59 +75,6 @@ def mobilenet_nnconv5(weights=None, shape=(224, 224, 3)):
|
||||
return keras.Model(inputs=input, outputs=x, name="fast_depth")
|
||||
|
||||
|
||||
def delta1_metric(y_true, y_pred):
|
||||
maxRatio = tf.maximum(y_pred / y_true, y_true / y_pred)
|
||||
return tf.nn.moments(tf.cast(maxRatio < tf.convert_to_tensor(1.25), tf.float32), axes=None)[0]
|
||||
|
||||
|
||||
def delta2_metric(y_true, y_pred):
|
||||
maxRatio = tf.maximum(y_pred / y_true, y_true / y_pred)
|
||||
return tf.nn.moments(tf.cast(maxRatio < tf.convert_to_tensor(1.25 ** 2), tf.float32), axes=None)[0]
|
||||
|
||||
|
||||
def delta3_metric(y_true, y_pred):
|
||||
maxRatio = tf.maximum(y_pred / y_true, y_true / y_pred)
|
||||
return tf.nn.moments(tf.cast(maxRatio < tf.convert_to_tensor(1.25 ** 3), tf.float32), axes=None)[0]
|
||||
|
||||
|
||||
def compile(model, optimiser=keras.optimizers.SGD(), loss=keras.losses.MeanSquaredError(), custom_metrics=None):
|
||||
"""
|
||||
Compile FastDepth model with relevant metrics
|
||||
:param model: Model to compile
|
||||
:param optimiser: Custom optimiser to use
|
||||
:param loss: Loss function to use
|
||||
:param include_metrics: Whether to include metrics (RMSE, MSE, a1,2,3)
|
||||
"""
|
||||
model.compile(optimizer=optimiser,
|
||||
loss=loss,
|
||||
metrics=[keras.metrics.RootMeanSquaredError(),
|
||||
keras.metrics.MeanSquaredError(),
|
||||
delta1_metric,
|
||||
delta2_metric,
|
||||
delta3_metric] if custom_metrics is None else custom_metrics)
|
||||
|
||||
|
||||
def train(existing_model=None, pretrained_weights='imagenet', epochs=4, save_file=None, dataset=None):
|
||||
"""
|
||||
Compile, train and save (if a save file is specified) a Fast Depth model.
|
||||
:param existing_model: Existing FastDepth model to train. None will create
|
||||
:param pretrained_weights: Weights to use if existing_model is not specified. See keras.applications.MobileNet
|
||||
weights parameter for options here.
|
||||
:param epochs: Number of epochs to run for
|
||||
:param save_file: File/directory to save to after training. By default the model won't be saved
|
||||
:param dataset: Train dataset to use. By default will DOWNLOAD and use tensorflow nyu_v2 dataset
|
||||
"""
|
||||
if not existing_model:
|
||||
existing_model = mobilenet_nnconv5(pretrained_weights)
|
||||
compile(existing_model)
|
||||
if not dataset:
|
||||
dataset = load_nyu()
|
||||
existing_model.fit(dataset, epochs=epochs)
|
||||
if save_file:
|
||||
existing_model.save(save_file)
|
||||
return existing_model
|
||||
|
||||
|
||||
def evaluate(compiled_model, dataset=None):
|
||||
"""
|
||||
Evaluate the model using rmse, delta1/2/3 metrics
|
||||
@@ -152,16 +98,6 @@ def forward(model, image):
|
||||
return model(crop_and_resize(image))
|
||||
|
||||
|
||||
def load_model(file):
|
||||
"""
|
||||
Load previously trained FastDepth model from disk. Will include relevant metrics (custom objects)
|
||||
:param file: File/directory to load the model from
|
||||
:return:
|
||||
"""
|
||||
return keras.models.load_model(file, custom_objects={'delta1_metric': delta1_metric,
|
||||
'delta2_metric': delta2_metric,
|
||||
'delta3_metric': delta3_metric})
|
||||
|
||||
if __name__ == '__main__':
|
||||
model = mobilenet_nnconv5()
|
||||
model.summary()
|
||||
|
||||
Reference in New Issue
Block a user