Refactor load/util, start fixing packnet to support NHWC format
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
import tensorflow as tf
|
||||
import tensorflow.keras as keras
|
||||
import tensorflow_datasets as tfds
|
||||
from load import load_nyu, load_nyu_evaluate
|
||||
from util import crop_and_resize
|
||||
# Needed for the kitti dataset, don't delete
|
||||
|
||||
"""
|
||||
@@ -160,56 +162,6 @@ def load_model(file):
|
||||
'delta2_metric': delta2_metric,
|
||||
'delta3_metric': delta3_metric})
|
||||
|
||||
|
||||
def crop_and_resize(x):
|
||||
shape = tf.shape(x['depth'])
|
||||
img_shape = tf.shape(x['image'])
|
||||
# Ensure we get a square for when we resize is later.
|
||||
# For horizontal images this is basically just cropping the sides off
|
||||
center_shape = min(shape[1], shape[2], img_shape[1], img_shape[2])
|
||||
|
||||
def layer():
|
||||
return keras.Sequential([
|
||||
keras.layers.experimental.preprocessing.CenterCrop(
|
||||
center_shape, center_shape),
|
||||
keras.layers.experimental.preprocessing.Resizing(
|
||||
224, 224, interpolation='nearest')
|
||||
])
|
||||
|
||||
# Reshape label to 4d, can't use array unwrap as it's unsupported by tensorflow
|
||||
return layer()(x['image']), layer()(tf.reshape(x['depth'], [shape[0], shape[1], shape[2], 1]))
|
||||
|
||||
|
||||
def load_nyu(download_dir='../nyu'):
|
||||
"""
|
||||
Load the nyu_v2 dataset train split. Will be downloaded to ../nyu
|
||||
:return: nyu_v2 dataset builder
|
||||
"""
|
||||
builder = tfds.builder('nyu_depth_v2')
|
||||
builder.download_and_prepare(download_dir=download_dir)
|
||||
return builder \
|
||||
.as_dataset(split='train', shuffle_files=True) \
|
||||
.shuffle(buffer_size=1024) \
|
||||
.batch(8) \
|
||||
.map(lambda x: crop_and_resize(x))
|
||||
|
||||
|
||||
def load_nyu_evaluate(download_dir='../nyu'):
|
||||
"""
|
||||
Load the nyu_v2 dataset validation split. Will be downloaded to ../nyu
|
||||
:return: nyu_v2 dataset builder
|
||||
"""
|
||||
builder = tfds.builder('nyu_depth_v2')
|
||||
builder.download_and_prepare(download_dir=download_dir)
|
||||
return builder.as_dataset(split='validation').batch(1).map(lambda x: crop_and_resize(x))
|
||||
|
||||
|
||||
def load_kitti(download_dir='../kitti'):
|
||||
ds = tfds.builder('kitti_depth')
|
||||
ds.download_and_prepare(download_dir=download_dir)
|
||||
return ds.as_dataset(tfds.Split.TRAIN).batch(8).map(lambda x: crop_and_resize(x))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
model = mobilenet_nnconv5()
|
||||
model.summary()
|
||||
|
||||
Reference in New Issue
Block a user