49 lines
1.8 KiB
Python
49 lines
1.8 KiB
Python
import tensorflow.keras as keras
|
|
import tensorflow_datasets as tfds
|
|
|
|
from losses import dense_depth_loss_function
|
|
from metric import *
|
|
from util import crop_and_resize
|
|
|
|
|
|
def load_nyu(download_dir='../nyu', out_shape=(224, 224), *, batch_size=8, shuffle_buffer_size=1024):
    """
    Load the nyu_depth_v2 train split as a shuffled, batched, resized tf.data pipeline.

    The TFDS builder is prepared with ``download_and_prepare`` before the split is read.

    :param download_dir: directory passed to ``download_and_prepare`` for raw downloads
    :param out_shape: target shape forwarded to ``crop_and_resize`` for every batch
    :param batch_size: number of examples per batch (default 8)
    :param shuffle_buffer_size: size of the example-level shuffle buffer (default 1024)
    :return: ``tf.data.Dataset`` of batches with ``crop_and_resize`` applied
    """
    builder = tfds.builder('nyu_depth_v2')
    # NOTE(review): download_dir only controls where raw archives are stored; the
    # prepared dataset goes to the builder's default data_dir - confirm intended.
    builder.download_and_prepare(download_dir=download_dir)
    dataset = builder.as_dataset(split='train', shuffle_files=True)
    # Shuffle individual examples before batching so each batch mixes examples
    # drawn from across the buffer rather than consecutive records.
    dataset = dataset.shuffle(buffer_size=shuffle_buffer_size).batch(batch_size)
    # crop_and_resize runs after batching, so it receives whole batches
    # (presumably it handles a leading batch dimension - TODO confirm).
    return dataset.map(lambda x: crop_and_resize(x, out_shape))
|
|
|
|
|
|
def load_nyu_evaluate(download_dir='../nyu', out_shape=(224, 224), *, batch_size=1):
    """
    Load the nyu_depth_v2 validation split as a batched, resized tf.data pipeline.

    The split is read in its stored order (no shuffling).

    :param download_dir: directory passed to ``download_and_prepare`` for raw downloads
    :param out_shape: target shape forwarded to ``crop_and_resize`` for every batch
    :param batch_size: number of examples per batch (default 1)
    :return: ``tf.data.Dataset`` of batches with ``crop_and_resize`` applied
    """
    builder = tfds.builder('nyu_depth_v2')
    # NOTE(review): download_dir only controls where raw archives are stored; the
    # prepared dataset goes to the builder's default data_dir - confirm intended.
    builder.download_and_prepare(download_dir=download_dir)
    dataset = builder.as_dataset(split='validation').batch(batch_size)
    return dataset.map(lambda x: crop_and_resize(x, out_shape))
|
|
|
|
|
|
def load_kitti(download_dir='../kitti', out_shape=(224, 224), *, batch_size=8):
    """
    Load the kitti_depth train split as a batched, resized tf.data pipeline.

    :param download_dir: directory passed to ``download_and_prepare`` for raw downloads
    :param out_shape: target shape forwarded to ``crop_and_resize`` for every batch
    :param batch_size: number of examples per batch (default 8)
    :return: ``tf.data.Dataset`` of batches with ``crop_and_resize`` applied
    """
    builder = tfds.builder('kitti_depth')
    builder.download_and_prepare(download_dir=download_dir)
    # NOTE(review): unlike load_nyu, this training split is not shuffled;
    # confirm whether that is intended.
    dataset = builder.as_dataset(split=tfds.Split.TRAIN).batch(batch_size)
    return dataset.map(lambda x: crop_and_resize(x, out_shape))
|
|
|
|
|
|
def load_model(file):
    """
    Load a previously trained FastDepth model from disk.

    The custom loss and metric functions are registered as ``custom_objects``
    so Keras can resolve them by name while deserializing the model.

    :param file: File/directory to load the model from
    :return: the model object produced by ``keras.models.load_model``
    """
    # NOTE(review): 'delta1_metric' maps to `delta1_metric`, but 'delta2_metric'
    # and 'delta3_metric' map to `delta2` / `delta3`; confirm these are the
    # names actually exported by `metric`.
    registered = {
        'delta1_metric': delta1_metric,
        'delta2_metric': delta2,
        'delta3_metric': delta3,
        'dense_depth_loss_function': dense_depth_loss_function,
    }
    loader = keras.models.load_model
    return loader(file, custom_objects=registered)
|