"""Dataset loaders for NYU Depth v2 and KITTI depth built on tensorflow_datasets."""
from util import crop_and_resize
import tensorflow_datasets as tfds
import tensorflow.keras as keras  # NOTE(review): appears unused in this module


def _prepare_builder(name, download_dir):
    """Create the TFDS builder ``name`` and run ``download_and_prepare`` on it.

    :param name: registered TFDS dataset name.
    :param download_dir: directory passed to ``download_and_prepare`` as the
        download location for the raw files.
    :return: the prepared ``tfds`` dataset builder.
    """
    builder = tfds.builder(name)
    builder.download_and_prepare(download_dir=download_dir)
    return builder


def load_nyu(download_dir='../nyu', out_shape=(224, 224), batch_size=8):
    """Load the nyu_depth_v2 train split as a shuffled, batched dataset.

    :param download_dir: directory the raw dataset files are downloaded to
        (defaults to ``../nyu``).
    :param out_shape: target shape forwarded to ``crop_and_resize``.
    :param batch_size: number of examples per batch (default 8).
    :return: a ``tf.data.Dataset`` of batches, each passed through
        ``crop_and_resize`` (not the dataset builder).
    """
    builder = _prepare_builder('nyu_depth_v2', download_dir)
    return (
        builder.as_dataset(split='train', shuffle_files=True)
        .shuffle(buffer_size=1024)
        .batch(batch_size)
        # The map runs after batching, so crop_and_resize receives whole batches.
        .map(lambda x: crop_and_resize(x, out_shape))
    )


def load_nyu_evaluate(download_dir='../nyu', out_shape=(224, 224), batch_size=1):
    """Load the nyu_depth_v2 validation split for evaluation.

    The split is neither file-shuffled nor buffer-shuffled.

    :param download_dir: directory the raw dataset files are downloaded to
        (defaults to ``../nyu``).
    :param out_shape: target shape forwarded to ``crop_and_resize``.
    :param batch_size: number of examples per batch (default 1).
    :return: a ``tf.data.Dataset`` of batches, each passed through
        ``crop_and_resize`` (not the dataset builder).
    """
    builder = _prepare_builder('nyu_depth_v2', download_dir)
    return (
        builder.as_dataset(split='validation')
        .batch(batch_size)
        .map(lambda x: crop_and_resize(x, out_shape))
    )


def load_kitti(download_dir='../kitti', out_shape=(224, 224), batch_size=8):
    """Load the kitti_depth train split as a batched dataset.

    NOTE(review): confirm 'kitti_depth' is a dataset registered in the
    installed tensorflow_datasets version.

    :param download_dir: directory the raw dataset files are downloaded to
        (defaults to ``../kitti``).
    :param out_shape: target shape forwarded to ``crop_and_resize``.
    :param batch_size: number of examples per batch (default 8).
    :return: a ``tf.data.Dataset`` of batches, each passed through
        ``crop_and_resize``.
    """
    builder = _prepare_builder('kitti_depth', download_dir)
    return (
        builder.as_dataset(tfds.Split.TRAIN)
        .batch(batch_size)
        .map(lambda x: crop_and_resize(x, out_shape))
    )