Add Kitti depth dataset

Warning: Using this requires more than 175 GB of disk space (TensorFlow will also generate prepared examples that take up additional space)
This commit is contained in:
Michael Pivato
2021-04-22 12:13:48 +00:00
parent 02d8cd5810
commit 070aec6eed
3 changed files with 198 additions and 9 deletions

View File

@@ -1,6 +1,7 @@
import tensorflow as tf
import tensorflow.keras as keras
import tensorflow_datasets as tfds
# Needed for the kitti dataset, don't delete
"""
Unofficial tensorflow keras implementation of FastDepth (mobilenet_nnconv5).
@@ -162,11 +163,15 @@ def load_model(file):
def crop_and_resize(x):
shape = tf.shape(x['depth'])
img_shape = tf.shape(x['image'])
# Ensure we get a square so the resize later keeps the aspect ratio.
# For horizontal images this basically just crops the sides off
center_shape = min(shape[1], shape[2], img_shape[1], img_shape[2])
def layer():
return keras.Sequential([
keras.layers.experimental.preprocessing.CenterCrop(
shape[1], shape[2]),
center_shape, center_shape),
keras.layers.experimental.preprocessing.Resizing(
224, 224, interpolation='nearest')
])
@@ -175,13 +180,13 @@ def crop_and_resize(x):
return layer()(x['image']), layer()(tf.reshape(x['depth'], [shape[0], shape[1], shape[2], 1]))
def load_nyu():
def load_nyu(download_dir='../nyu'):
"""
Load the nyu_v2 dataset train split. Will be downloaded to download_dir (defaults to ../nyu).
:param download_dir: directory to download and prepare the dataset in
:return: batched nyu_v2 train dataset of (image, depth) pairs
"""
builder = tfds.builder('nyu_depth_v2')
builder.download_and_prepare(download_dir='../nyu')
builder.download_and_prepare(download_dir=download_dir)
return builder \
.as_dataset(split='train', shuffle_files=True) \
.shuffle(buffer_size=1024) \
@@ -189,16 +194,22 @@ def load_nyu():
.map(lambda x: crop_and_resize(x))
def load_nyu_evaluate():
def load_nyu_evaluate(download_dir='../nyu'):
"""
Load the nyu_v2 dataset validation split. Will be downloaded to download_dir (defaults to ../nyu).
:param download_dir: directory to download and prepare the dataset in
:return: batched nyu_v2 validation dataset of (image, depth) pairs
"""
builder = tfds.builder('nyu_depth_v2')
builder.download_and_prepare(download_dir='../nyu')
builder.download_and_prepare(download_dir=download_dir)
return builder.as_dataset(split='validation').batch(1).map(lambda x: crop_and_resize(x))
def load_kitti(download_dir='../kitti'):
ds = tfds.builder('kitti_depth')
ds.download_and_prepare(download_dir=download_dir)
return ds.as_dataset(tfds.Split.TRAIN).batch(8).map(lambda x: crop_and_resize(x))
if __name__ == '__main__':
model = mobilenet_nnconv5()
model.summary()