Add compiling packnet model, refactor modules to not duplicate loaders and trainers
load.py | 19
@@ -1,6 +1,9 @@
-from util import crop_and_resize
-import tensorflow_datasets as tfds
 import tensorflow.keras as keras
+import tensorflow_datasets as tfds
+
+from losses import dense_depth_loss_function
+from metric import *
+from util import crop_and_resize
 
 
 def load_nyu(download_dir='../nyu', out_shape=(224, 224)):
@@ -31,3 +34,15 @@ def load_kitti(download_dir='../kitti', out_shape=(224, 224)):
     ds = tfds.builder('kitti_depth')
     ds.download_and_prepare(download_dir=download_dir)
     return ds.as_dataset(tfds.Split.TRAIN).batch(8).map(lambda x: crop_and_resize(x, out_shape))
+
+
+def load_model(file):
+    """
+    Load previously trained FastDepth model from disk. Will include relevant metrics (custom objects)
+    :param file: File/directory to load the model from
+    :return:
+    """
+    return keras.models.load_model(file, custom_objects={'delta1_metric': delta1_metric,
+                                                         'delta2_metric': delta2,
+                                                         'delta3_metric': delta3,
+                                                         'dense_depth_loss_function': dense_depth_loss_function})
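
For context, a minimal sketch of how the refactored loader module might be used downstream. The module and function names match the diff above; the checkpoint path and the evaluation step are illustrative assumptions, not part of this commit.

# Hypothetical usage sketch; assumes load.py exposes load_kitti and load_model as in the diff,
# and that crop_and_resize yields (image, depth) pairs the compiled model can consume.
from load import load_kitti, load_model

train_ds = load_kitti(download_dir='../kitti', out_shape=(224, 224))  # batched tf.data pipeline
model = load_model('fastdepth_checkpoint')  # placeholder path; restores custom loss and delta metrics
model.evaluate(train_ds.take(10))           # quick sanity check on a few batches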