Remove half-features from dense_depth
@@ -64,55 +64,27 @@ def dense_nnconv5(size, weights=None, shape=(224, 224, 3), half_features=True):
     densenet = dense_net(input, size, weights, shape)
     densenet_output_shape = densenet.layers[-1].output.shape
 
-    if half_features:
-        decode_filters = int(int(densenet_output_shape[-1]) / 2)
-    else:
-        decode_filters = int(densenet_output_shape[-1])
-
     # Reduce the feature set (pointwise)
-    x = keras.layers.Conv2D(filters=decode_filters, kernel_size=1, padding='same',
-                            input_shape=densenet_output_shape, name='conv2')(densenet.output)
+    decoder = keras.layers.Conv2D(filters=int(densenet_output_shape[-1]), kernel_size=1, padding='same',
+                                  input_shape=densenet_output_shape, name='conv2')(densenet.output)
 
     # TODO: More intermediate layers here?
 
     # Fast Depth Decoder
-    x = fd.nnconv5(x, densenet.get_layer('pool3_pool').output_shape[3], 1,
-                   skip_connection=densenet.get_layer('pool3_pool').output)
-    x = fd.nnconv5(x, densenet.get_layer('pool2_pool').output_shape[3], 2,
-                   skip_connection=densenet.get_layer('pool2_pool').output)
-    x = fd.nnconv5(x, densenet.get_layer('pool1').output_shape[3], 3,
-                   skip_connection=densenet.get_layer('pool1').output)
-    x = fd.nnconv5(x, densenet.get_layer('conv1/relu').output_shape[3], 4,
-                   skip_connection=densenet.get_layer('conv1/relu').output)
+    decoder = fd.nnconv5(decoder, densenet.get_layer('pool3_pool').output_shape[3], 1,
+                         skip_connection=densenet.get_layer('pool3_pool').output)
+    decoder = fd.nnconv5(decoder, densenet.get_layer('pool2_pool').output_shape[3], 2,
+                         skip_connection=densenet.get_layer('pool2_pool').output)
+    decoder = fd.nnconv5(decoder, densenet.get_layer('pool1').output_shape[3], 3,
+                         skip_connection=densenet.get_layer('pool1').output)
+    decoder = fd.nnconv5(decoder, densenet.get_layer('conv1/relu').output_shape[3], 4,
+                         skip_connection=densenet.get_layer('conv1/relu').output)
 
     # Final Pointwise for depth extraction
-    x = keras.layers.Conv2D(1, 1, padding='same')(x)
-    x = keras.layers.BatchNormalization()(x)
-    x = keras.layers.ReLU(6.)(x)
-    return keras.Model(inputs=input, outputs=x, name="fast_dense_depth")
+    decoder = keras.layers.Conv2D(1, 1, padding='same')(decoder)
+    decoder = keras.layers.BatchNormalization()(decoder)
+    decoder = keras.layers.ReLU(6.)(decoder)
+    return keras.Model(inputs=input, outputs=decoder, name="fast_dense_depth")
 
 
-def crop_and_resize(x):
-    shape = tf.shape(x['depth'])
-
-    def layer():
-        return keras.Sequential([
-            keras.layers.experimental.preprocessing.CenterCrop(
-                shape[1], shape[2]),
-            keras.layers.experimental.preprocessing.Resizing(
-                224, 224, interpolation='nearest')
-        ])
-
-    def half_layer():
-        return keras.Sequential([
-            keras.layers.experimental.preprocessing.CenterCrop(
-                shape[1], shape[2]),
-            keras.layers.experimental.preprocessing.Resizing(
-                112, 112, interpolation='nearest')
-        ])
-
-    # Reshape label to 4d, can't use array unwrap as it's unsupported by tensorflow
-    return layer()(x['image']), half_layer()(tf.reshape(x['depth'], [shape[0], shape[1], shape[2], 1]))
-
-
 def load_nyu():
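Note: the fd module itself is not part of this diff. A minimal sketch of what fd.nnconv5 might look like, inferred from the call sites above and from the FastDepth-style NNConv5 block (5x5 depthwise-separable convolution followed by 2x nearest-neighbor upsampling); the additive skip connection is an assumption, though a plausible one, since each call passes the skip layer's own channel count as filters:

    from tensorflow import keras

    def nnconv5(x, filters, index, skip_connection=None):
        # 5x5 depthwise-separable convolution projecting to `filters` channels
        x = keras.layers.SeparableConv2D(filters, 5, padding='same',
                                         name=f'nnconv5_{index}')(x)
        x = keras.layers.BatchNormalization()(x)
        x = keras.layers.ReLU(6.)(x)
        # 2x nearest-neighbor upsampling, doubling spatial resolution
        x = keras.layers.UpSampling2D(2, interpolation='nearest')(x)
        if skip_connection is not None:
            # Assumed additive skip; shapes line up because `filters` is taken
            # from the skip layer's output shape at each call site
            x = keras.layers.Add()([x, skip_connection])
        return x

Under that reading, a 224x224x3 input gives a 7x7 DenseNet feature map, and the four 2x upsampling stages bring the depth output to 112x112, which matches the 112x112 target produced by half_layer in the crop_and_resize function deleted above.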
@@ -126,7 +98,7 @@ def load_nyu():
         .as_dataset(split='train', shuffle_files=True) \
         .shuffle(buffer_size=1024) \
         .batch(8) \
-        .map(lambda x: crop_and_resize(x))
+        .map(lambda x: fd.crop_and_resize(x))
 
 
 def load_nyu_evaluate():
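Only the tail of load_nyu() appears in the hunk above. A hedged reconstruction of the full function, assuming its builder setup mirrors load_nyu_evaluate() below and that the fd helpers are importable (module name assumed):

    import tensorflow_datasets as tfds
    import fd  # assumed import for the FastDepth helper module

    def load_nyu():
        builder = tfds.builder('nyu_depth_v2')
        builder.download_and_prepare(download_dir='../nyu')
        return builder \
            .as_dataset(split='train', shuffle_files=True) \
            .shuffle(buffer_size=1024) \
            .batch(8) \
            .map(lambda x: fd.crop_and_resize(x))

Since the deleted crop_and_resize is now called as fd.crop_and_resize both here and in load_nyu_evaluate, the function has presumably moved into the fd module rather than been removed outright.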
@@ -136,7 +108,7 @@ def load_nyu_evaluate():
     """
     builder = tfds.builder('nyu_depth_v2')
     builder.download_and_prepare(download_dir='../nyu')
-    return builder.as_dataset(split='validation').batch(1).map(lambda x: crop_and_resize(x))
+    return builder.as_dataset(split='validation').batch(1).map(lambda x: fd.crop_and_resize(x))
 
 
 if __name__ == '__main__':
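The __main__ body is not shown in this diff. A hypothetical training sketch under the post-commit signature (half_features removed; the size argument is assumed to select the DenseNet variant):

    if __name__ == '__main__':
        model = dense_nnconv5(121)  # `size` assumed to pick DenseNet-121
        model.compile(optimizer='adam', loss='mse')  # optimizer/loss are assumptions
        model.fit(load_nyu(), epochs=1)
        model.evaluate(load_nyu_evaluate())

crop_and_resize yields (image, depth) tuples, so the mapped datasets can be passed straight to model.fit and model.evaluate.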