Refactor load/util, start fixing packnet to support NHWC format
This commit is contained in:
@@ -1,4 +1,3 @@
|
||||
import tensorflow as tf
|
||||
import tensorflow.keras as keras
|
||||
import tensorflow_datasets as tfds
|
||||
|
||||
@@ -25,7 +24,8 @@ def dense_depth(size, weights=None, shape=(224, 224, 3)):
|
||||
densenet_output_channels = densenet.layers[-1].output.shape[-1]
|
||||
|
||||
# Reduce the feature set (pointwise)
|
||||
decoder = keras.layers.Conv2D(filters=densenet_output_channels, kernel_size=1, padding='same')(densenet.output)
|
||||
decoder = keras.layers.Conv2D(
|
||||
filters=densenet_output_channels, kernel_size=1, padding='same')(densenet.output)
|
||||
|
||||
# The actual decoder
|
||||
decoder = dense_upsample_block(
|
||||
@@ -66,19 +66,19 @@ def dense_nnconv5(size, weights=None, shape=(224, 224, 3), half_features=True):
|
||||
|
||||
# Reduce the feature set (pointwise)
|
||||
decoder = keras.layers.Conv2D(filters=int(densenet_output_shape[-1]), kernel_size=1, padding='same',
|
||||
input_shape=densenet_output_shape, name='conv2')(densenet.output)
|
||||
input_shape=densenet_output_shape, name='conv2')(densenet.output)
|
||||
|
||||
# TODO: More intermediate layers here?
|
||||
|
||||
# Fast Depth Decoder
|
||||
decoder = fd.nnconv5(decoder, densenet.get_layer('pool3_pool').output_shape[3], 1,
|
||||
skip_connection=densenet.get_layer('pool3_pool').output)
|
||||
skip_connection=densenet.get_layer('pool3_pool').output)
|
||||
decoder = fd.nnconv5(decoder, densenet.get_layer('pool2_pool').output_shape[3], 2,
|
||||
skip_connection=densenet.get_layer('pool2_pool').output)
|
||||
skip_connection=densenet.get_layer('pool2_pool').output)
|
||||
decoder = fd.nnconv5(decoder, densenet.get_layer('pool1').output_shape[3], 3,
|
||||
skip_connection=densenet.get_layer('pool1').output)
|
||||
skip_connection=densenet.get_layer('pool1').output)
|
||||
decoder = fd.nnconv5(decoder, densenet.get_layer('conv1/relu').output_shape[3], 4,
|
||||
skip_connection=densenet.get_layer('conv1/relu').output)
|
||||
skip_connection=densenet.get_layer('conv1/relu').output)
|
||||
|
||||
# Final Pointwise for depth extraction
|
||||
decoder = keras.layers.Conv2D(1, 1, padding='same')(decoder)
|
||||
@@ -87,30 +87,6 @@ def dense_nnconv5(size, weights=None, shape=(224, 224, 3), half_features=True):
|
||||
return keras.Model(inputs=input, outputs=decoder, name="fast_dense_depth")
|
||||
|
||||
|
||||
def load_nyu(download_dir='../nyu'):
|
||||
"""
|
||||
Load the nyu_v2 dataset train split. Will be downloaded to ../nyu
|
||||
:return: nyu_v2 dataset builder
|
||||
"""
|
||||
builder = tfds.builder('nyu_depth_v2')
|
||||
builder.download_and_prepare(download_dir=download_dir)
|
||||
return builder \
|
||||
.as_dataset(split='train', shuffle_files=True) \
|
||||
.shuffle(buffer_size=1024) \
|
||||
.batch(8) \
|
||||
.map(lambda x: fd.crop_and_resize(x))
|
||||
|
||||
|
||||
def load_nyu_evaluate(download_dir='../nyu'):
|
||||
"""
|
||||
Load the nyu_v2 dataset validation split. Will be downloaded to ../nyu
|
||||
:return: nyu_v2 dataset builder
|
||||
"""
|
||||
builder = tfds.builder('nyu_depth_v2')
|
||||
builder.download_and_prepare(download_dir=download_dir)
|
||||
return builder.as_dataset(split='validation').batch(1).map(lambda x: fd.crop_and_resize(x))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
model = dense_depth(169, 'imagenet')
|
||||
model.summary()
|
||||
|
||||
Reference in New Issue
Block a user