Add compiling packnet model, refactor modules to not duplicate loaders and trainers
This commit is contained in:
49
train.py
Normal file
49
train.py
Normal file
@@ -0,0 +1,49 @@
|
||||
"""
|
||||
Collection of functions to train the various models, and use different losses
|
||||
"""
|
||||
import tensorflow.keras as keras
|
||||
|
||||
from load import load_nyu
|
||||
from metric import *
|
||||
|
||||
|
||||
def compile(model, optimiser=keras.optimizers.SGD(), loss=keras.losses.MeanSquaredError(), custom_metrics=None):
|
||||
"""
|
||||
Compile FastDepth model with relevant metrics
|
||||
:param model: Model to compile
|
||||
:param optimiser: Custom optimiser to use
|
||||
:param loss: Loss function to use
|
||||
:param include_metrics: Whether to include metrics (RMSE, MSE, a1,2,3)
|
||||
"""
|
||||
model.compile(optimizer=optimiser,
|
||||
loss=loss,
|
||||
metrics=[keras.metrics.RootMeanSquaredError(),
|
||||
keras.metrics.MeanSquaredError(),
|
||||
delta1_metric,
|
||||
delta2,
|
||||
delta3,
|
||||
keras.metrics.MeanAbsolutePercentageError(),
|
||||
keras.metrics.MeanAbsoluteError()] if custom_metrics is None else custom_metrics)
|
||||
|
||||
|
||||
def train(existing_model=None, pretrained_weights='imagenet', epochs=4, save_file=None, dataset=None,
|
||||
checkpoint='ckpt'):
|
||||
"""
|
||||
Compile, train and save (if a save file is specified) a Fast Depth model.
|
||||
:param existing_model: Existing FastDepth model to train. None will create
|
||||
:param pretrained_weights: Weights to use if existing_model is not specified. See keras.applications.MobileNet
|
||||
weights parameter for options here.
|
||||
:param epochs: Number of epochs to run for
|
||||
:param save_file: File/directory to save to after training. By default the model won't be saved
|
||||
:param dataset: Train dataset to use. By default will DOWNLOAD and use tensorflow nyu_v2 dataset
|
||||
:param checkpoint: Checkpoint to save to
|
||||
"""
|
||||
callbacks = []
|
||||
if checkpoint:
|
||||
callbacks.append(keras.callbacks.ModelCheckpoint(checkpoint, save_weights_only=True))
|
||||
if not dataset:
|
||||
dataset = load_nyu()
|
||||
existing_model.fit(dataset, epochs=epochs, callbacks=callbacks)
|
||||
if save_file:
|
||||
existing_model.save(save_file)
|
||||
return existing_model
|
||||
Reference in New Issue
Block a user