Refactor load/util, start fixing packnet to support NHWC format

This commit is contained in:
Piv
2021-07-19 12:32:56 +09:30
parent d8bf493999
commit 38e7ad069e
6 changed files with 85 additions and 93 deletions

33
load.py Normal file
View File

@@ -0,0 +1,33 @@
from util import crop_and_resize
import tensorflow_datasets as tfds
import tensorflow.keras as keras
def load_nyu(download_dir='../nyu', out_shape=(224, 224)):
"""
Load the nyu_v2 dataset train split. Will be downloaded to ../nyu
:return: nyu_v2 dataset builder
"""
builder = tfds.builder('nyu_depth_v2')
builder.download_and_prepare(download_dir=download_dir)
return builder \
.as_dataset(split='train', shuffle_files=True) \
.shuffle(buffer_size=1024) \
.batch(8) \
.map(lambda x: crop_and_resize(x, out_shape))
def load_nyu_evaluate(download_dir='../nyu', out_shape=(224, 224)):
"""
Load the nyu_v2 dataset validation split. Will be downloaded to ../nyu
:return: nyu_v2 dataset builder
"""
builder = tfds.builder('nyu_depth_v2')
builder.download_and_prepare(download_dir=download_dir)
return builder.as_dataset(split='validation').batch(1).map(lambda x: crop_and_resize(x, out_shape))
def load_kitti(download_dir='../kitti', out_shape=(224, 224)):
ds = tfds.builder('kitti_depth')
ds.download_and_prepare(download_dir=download_dir)
return ds.as_dataset(tfds.Split.TRAIN).batch(8).map(lambda x: crop_and_resize(x, out_shape))