""" Collection of functions to train the various models, and use different losses """ import tensorflow.keras as keras from load import load_nyu from metric import * def compile(model, optimiser=keras.optimizers.SGD(), loss=keras.losses.MeanSquaredError(), custom_metrics=None): """ Compile FastDepth model with relevant metrics :param model: Model to compile :param optimiser: Custom optimiser to use :param loss: Loss function to use :param include_metrics: Whether to include metrics (RMSE, MSE, a1,2,3) """ model.compile(optimizer=optimiser, loss=loss, metrics=[keras.metrics.RootMeanSquaredError(), keras.metrics.MeanSquaredError(), delta1, delta2, delta3, keras.metrics.MeanAbsolutePercentageError(), keras.metrics.MeanAbsoluteError()] if custom_metrics is None else custom_metrics) def train(existing_model=None, pretrained_weights='imagenet', epochs=4, save_file=None, dataset=None, checkpoint='ckpt'): """ Compile, train and save (if a save file is specified) a Fast Depth model. :param existing_model: Existing FastDepth model to train. None will create :param pretrained_weights: Weights to use if existing_model is not specified. See keras.applications.MobileNet weights parameter for options here. :param epochs: Number of epochs to run for :param save_file: File/directory to save to after training. By default the model won't be saved :param dataset: Train dataset to use. By default will DOWNLOAD and use tensorflow nyu_v2 dataset :param checkpoint: Checkpoint to save to """ callbacks = [] if checkpoint: callbacks.append(keras.callbacks.ModelCheckpoint(checkpoint, save_weights_only=True)) if not dataset: dataset = load_nyu() existing_model.fit(dataset, epochs=epochs, callbacks=callbacks) if save_file: existing_model.save(save_file) return existing_model