50 lines
2.1 KiB
Python
50 lines
2.1 KiB
Python
"""
Collection of functions to compile and train the various models with different losses.
"""
|
|
import tensorflow.keras as keras
|
|
|
|
from load import load_nyu
|
|
from metric import *
|
|
|
|
|
|
def compile(model, optimiser=None, loss=None, custom_metrics=None):
    """
    Compile a FastDepth model with the relevant metrics.

    Note: this module-level function shadows the built-in ``compile`` within
    this module; the name is kept for backward compatibility with callers.

    :param model: Model to compile (anything exposing a Keras-style ``compile``).
    :param optimiser: Optimiser to use. ``None`` (the default) creates a fresh
        ``keras.optimizers.SGD()`` for this call.
    :param loss: Loss function to use. ``None`` (the default) creates a fresh
        ``keras.losses.MeanSquaredError()`` for this call.
    :param custom_metrics: Metrics passed to ``model.compile``. ``None`` (the
        default) uses RMSE, MSE, delta1, delta2, delta3, MAPE and MAE; any other
        value, including an empty list, is passed through unchanged.
    """
    # Build defaults per call. Instances created in the signature are evaluated
    # once, when the function is defined, and would be shared by every model
    # compiled here. Optimisers carry per-model state (iteration count, slot
    # variables), so sharing one across models is unsafe.
    if optimiser is None:
        optimiser = keras.optimizers.SGD()
    if loss is None:
        loss = keras.losses.MeanSquaredError()
    if custom_metrics is None:
        custom_metrics = [
            keras.metrics.RootMeanSquaredError(),
            keras.metrics.MeanSquaredError(),
            delta1,
            delta2,
            delta3,
            keras.metrics.MeanAbsolutePercentageError(),
            keras.metrics.MeanAbsoluteError(),
        ]
    model.compile(optimizer=optimiser, loss=loss, metrics=custom_metrics)
|
|
|
|
|
|
def train(existing_model=None, pretrained_weights='imagenet', epochs=4, save_file=None, dataset=None,
          checkpoint='ckpt'):
    """
    Train and save (if a save file is specified) an already-compiled FastDepth model.

    This function does not build or compile a model; pass one that has already
    been compiled (e.g. via this module's ``compile``).

    :param existing_model: Compiled FastDepth model to train. Required; ``None``
        raises ``ValueError``.
    :param pretrained_weights: Currently unused; kept for backward compatibility.
        It was intended as the ``weights`` argument of ``keras.applications.MobileNet``
        when creating a new model.
    :param epochs: Number of epochs to run for.
    :param save_file: File/directory to save to after training. By default the
        model is not saved.
    :param dataset: Train dataset to use. ``None`` (the default) will DOWNLOAD
        and use the tensorflow nyu_v2 dataset via ``load_nyu()``.
    :param checkpoint: Filepath given to ``keras.callbacks.ModelCheckpoint``
        (weights only). A falsy value disables checkpointing.
    :return: The trained model (the same object as ``existing_model``).
    :raises ValueError: If ``existing_model`` is None.
    """
    if existing_model is None:
        # No model constructor is called here, so fail clearly instead of
        # crashing with AttributeError on ``None.fit``.
        raise ValueError("existing_model is required: train() does not create a new model")

    callbacks = []
    if checkpoint:
        # NOTE(review): newer Keras versions require the filepath to end in
        # '.weights.h5' when save_weights_only=True. Confirm for the installed version.
        callbacks.append(keras.callbacks.ModelCheckpoint(checkpoint, save_weights_only=True))

    # Compare against None explicitly. A truthiness test would call __len__ on
    # some dataset types and treat an empty dataset as "not provided".
    if dataset is None:
        dataset = load_nyu()

    existing_model.fit(dataset, epochs=epochs, callbacks=callbacks)

    if save_file:
        existing_model.save(save_file)

    return existing_model
|