Add compiling packnet model, refactor modules to not duplicate loaders and trainers

This commit is contained in:
Piv
2021-07-23 22:41:46 +09:30
parent 66cbc7faf6
commit 3254eef4bf
8 changed files with 135 additions and 96 deletions

19
load.py
View File

@@ -1,6 +1,9 @@
from util import crop_and_resize
import tensorflow_datasets as tfds
import tensorflow.keras as keras
import tensorflow_datasets as tfds
from losses import dense_depth_loss_function
from metric import *
from util import crop_and_resize
def load_nyu(download_dir='../nyu', out_shape=(224, 224)):
@@ -31,3 +34,15 @@ def load_kitti(download_dir='../kitti', out_shape=(224, 224)):
ds = tfds.builder('kitti_depth')
ds.download_and_prepare(download_dir=download_dir)
return ds.as_dataset(tfds.Split.TRAIN).batch(8).map(lambda x: crop_and_resize(x, out_shape))
def load_model(file):
"""
Load previously trained FastDepth model from disk. Will include relevant metrics (custom objects)
:param file: File/directory to load the model from
:return:
"""
return keras.models.load_model(file, custom_objects={'delta1_metric': delta1_metric,
'delta2_metric': delta2,
'delta3_metric': delta3,
'dense_depth_loss_function': dense_depth_loss_function})