# fast-depth-tf/train.py

"""
Collection of functions to compile and train the various models with different optimisers, losses and metrics
"""
import tensorflow.keras as keras
from load import load_nyu
from metric import *
def compile(model, optimiser=None, loss=None, custom_metrics=None):
    """
    Compile FastDepth model with relevant metrics

    :param model: Model to compile (anything exposing a Keras-style ``compile(optimizer=, loss=, metrics=)``)
    :param optimiser: Custom optimiser to use. None (default) creates a fresh ``keras.optimizers.SGD()`` for this call
    :param loss: Loss function to use. None (default) creates a fresh ``keras.losses.MeanSquaredError()`` for this call
    :param custom_metrics: Metrics to use instead of the default set (RMSE, MSE, delta1/2/3, MAPE, MAE).
        Only None selects the defaults; any other value (including an empty list) is passed through unchanged.
    """
    # Defaults are built per call: default argument values are evaluated once at function definition,
    # so a default optimiser instance (and its internal state) would otherwise be shared by every
    # model compiled through this helper.
    if optimiser is None:
        optimiser = keras.optimizers.SGD()
    if loss is None:
        loss = keras.losses.MeanSquaredError()
    if custom_metrics is None:
        # delta1/2/3 come from the project's `metric` module (star-imported at file top).
        metrics = [keras.metrics.RootMeanSquaredError(),
                   keras.metrics.MeanSquaredError(),
                   delta1,
                   delta2,
                   delta3,
                   keras.metrics.MeanAbsolutePercentageError(),
                   keras.metrics.MeanAbsoluteError()]
    else:
        metrics = custom_metrics
    model.compile(optimizer=optimiser, loss=loss, metrics=metrics)
def train(existing_model=None, pretrained_weights='imagenet', epochs=4, save_file=None, dataset=None,
          checkpoint='ckpt'):
    """
    Train and save (if a save file is specified) an existing Fast Depth model.

    The model is NOT compiled here, so it must already be compiled (e.g. with ``compile`` in this module).

    :param existing_model: Existing FastDepth model to train. Required: this function does not build a
        model, so passing None raises ValueError.
    :param pretrained_weights: Currently unused by this function (kept for backward compatibility).
        Intended to mirror the ``weights`` parameter of keras.applications.MobileNet.
    :param epochs: Number of epochs to run for
    :param save_file: File/directory to save to after training. By default the model won't be saved
    :param dataset: Train dataset passed to ``fit``. None (default) calls ``load_nyu()``, which downloads
        and uses the tensorflow nyu_v2 dataset
    :param checkpoint: Checkpoint path for a weights-only ModelCheckpoint callback; a falsy value disables it
    :return: The trained model (same object as ``existing_model``)
    :raises ValueError: If ``existing_model`` is None
    """
    # Fail fast, before any dataset download, instead of an opaque AttributeError on `None.fit`.
    if existing_model is None:
        raise ValueError("existing_model is required: build and compile a FastDepth model before calling train()")
    callbacks = []
    if checkpoint:
        callbacks.append(keras.callbacks.ModelCheckpoint(checkpoint, save_weights_only=True))
    # Identity check rather than truthiness: array-like datasets (e.g. numpy arrays) raise on bool(),
    # and an explicitly supplied empty dataset should not silently trigger a download.
    if dataset is None:
        dataset = load_nyu()
    existing_model.fit(dataset, epochs=epochs, callbacks=callbacks)
    if save_file:
        existing_model.save(save_file)
    return existing_model