Add compiling packnet model, refactor modules to not duplicate loaders and trainers

This commit is contained in:
Piv
2021-07-23 22:41:46 +09:30
parent 66cbc7faf6
commit 3254eef4bf
8 changed files with 135 additions and 96 deletions

View File

@@ -1,9 +1,8 @@
import tensorflow as tf
import tensorflow.keras as keras
import tensorflow_datasets as tfds
from load import load_nyu, load_nyu_evaluate
from load import load_nyu_evaluate
from metric import *
from util import crop_and_resize
# Needed for the kitti dataset, don't delete
"""
Unofficial tensorflow keras implementation of FastDepth (mobilenet_nnconv5).
@@ -76,59 +75,6 @@ def mobilenet_nnconv5(weights=None, shape=(224, 224, 3)):
return keras.Model(inputs=input, outputs=x, name="fast_depth")
def delta1_metric(y_true, y_pred):
maxRatio = tf.maximum(y_pred / y_true, y_true / y_pred)
return tf.nn.moments(tf.cast(maxRatio < tf.convert_to_tensor(1.25), tf.float32), axes=None)[0]
def delta2_metric(y_true, y_pred):
maxRatio = tf.maximum(y_pred / y_true, y_true / y_pred)
return tf.nn.moments(tf.cast(maxRatio < tf.convert_to_tensor(1.25 ** 2), tf.float32), axes=None)[0]
def delta3_metric(y_true, y_pred):
maxRatio = tf.maximum(y_pred / y_true, y_true / y_pred)
return tf.nn.moments(tf.cast(maxRatio < tf.convert_to_tensor(1.25 ** 3), tf.float32), axes=None)[0]
def compile(model, optimiser=keras.optimizers.SGD(), loss=keras.losses.MeanSquaredError(), custom_metrics=None):
"""
Compile FastDepth model with relevant metrics
:param model: Model to compile
:param optimiser: Custom optimiser to use
:param loss: Loss function to use
:param include_metrics: Whether to include metrics (RMSE, MSE, a1,2,3)
"""
model.compile(optimizer=optimiser,
loss=loss,
metrics=[keras.metrics.RootMeanSquaredError(),
keras.metrics.MeanSquaredError(),
delta1_metric,
delta2_metric,
delta3_metric] if custom_metrics is None else custom_metrics)
def train(existing_model=None, pretrained_weights='imagenet', epochs=4, save_file=None, dataset=None):
"""
Compile, train and save (if a save file is specified) a Fast Depth model.
:param existing_model: Existing FastDepth model to train. None will create
:param pretrained_weights: Weights to use if existing_model is not specified. See keras.applications.MobileNet
weights parameter for options here.
:param epochs: Number of epochs to run for
:param save_file: File/directory to save to after training. By default the model won't be saved
:param dataset: Train dataset to use. By default will DOWNLOAD and use tensorflow nyu_v2 dataset
"""
if not existing_model:
existing_model = mobilenet_nnconv5(pretrained_weights)
compile(existing_model)
if not dataset:
dataset = load_nyu()
existing_model.fit(dataset, epochs=epochs)
if save_file:
existing_model.save(save_file)
return existing_model
def evaluate(compiled_model, dataset=None):
"""
Evaluate the model using rmse, delta1/2/3 metrics
@@ -152,16 +98,6 @@ def forward(model, image):
return model(crop_and_resize(image))
def load_model(file):
"""
Load previously trained FastDepth model from disk. Will include relevant metrics (custom objects)
:param file: File/directory to load the model from
:return:
"""
return keras.models.load_model(file, custom_objects={'delta1_metric': delta1_metric,
'delta2_metric': delta2_metric,
'delta3_metric': delta3_metric})
if __name__ == '__main__':
model = mobilenet_nnconv5()
model.summary()