Add Kitti depth dataset
Warning: Using this requires >175gb of disk space (tensorflow will also generate examples that will take up space)
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import tensorflow as tf
|
||||
import tensorflow.keras as keras
|
||||
import tensorflow_datasets as tfds
|
||||
# Needed for the kitti dataset, don't delete
|
||||
|
||||
"""
|
||||
Unofficial tensorflow keras implementation of FastDepth (mobilenet_nnconv5).
|
||||
@@ -162,11 +163,15 @@ def load_model(file):
|
||||
|
||||
def crop_and_resize(x):
|
||||
shape = tf.shape(x['depth'])
|
||||
img_shape = tf.shape(x['image'])
|
||||
# Ensure we get a square for when we resize is later.
|
||||
# For horizontal images this is basically just cropping the sides off
|
||||
center_shape = min(shape[1], shape[2], img_shape[1], img_shape[2])
|
||||
|
||||
def layer():
|
||||
return keras.Sequential([
|
||||
keras.layers.experimental.preprocessing.CenterCrop(
|
||||
shape[1], shape[2]),
|
||||
center_shape, center_shape),
|
||||
keras.layers.experimental.preprocessing.Resizing(
|
||||
224, 224, interpolation='nearest')
|
||||
])
|
||||
@@ -175,13 +180,13 @@ def crop_and_resize(x):
|
||||
return layer()(x['image']), layer()(tf.reshape(x['depth'], [shape[0], shape[1], shape[2], 1]))
|
||||
|
||||
|
||||
def load_nyu():
|
||||
def load_nyu(download_dir='../nyu'):
|
||||
"""
|
||||
Load the nyu_v2 dataset train split. Will be downloaded to ../nyu
|
||||
:return: nyu_v2 dataset builder
|
||||
"""
|
||||
builder = tfds.builder('nyu_depth_v2')
|
||||
builder.download_and_prepare(download_dir='../nyu')
|
||||
builder.download_and_prepare(download_dir=download_dir)
|
||||
return builder \
|
||||
.as_dataset(split='train', shuffle_files=True) \
|
||||
.shuffle(buffer_size=1024) \
|
||||
@@ -189,16 +194,22 @@ def load_nyu():
|
||||
.map(lambda x: crop_and_resize(x))
|
||||
|
||||
|
||||
def load_nyu_evaluate():
|
||||
def load_nyu_evaluate(download_dir='../nyu'):
|
||||
"""
|
||||
Load the nyu_v2 dataset validation split. Will be downloaded to ../nyu
|
||||
:return: nyu_v2 dataset builder
|
||||
"""
|
||||
builder = tfds.builder('nyu_depth_v2')
|
||||
builder.download_and_prepare(download_dir='../nyu')
|
||||
builder.download_and_prepare(download_dir=download_dir)
|
||||
return builder.as_dataset(split='validation').batch(1).map(lambda x: crop_and_resize(x))
|
||||
|
||||
|
||||
def load_kitti(download_dir='../kitti'):
|
||||
ds = tfds.builder('kitti_depth')
|
||||
ds.download_and_prepare(download_dir=download_dir)
|
||||
return ds.as_dataset(tfds.Split.TRAIN).batch(8).map(lambda x: crop_and_resize(x))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
model = mobilenet_nnconv5()
|
||||
model.summary()
|
||||
|
||||
Reference in New Issue
Block a user