More packnet implementation

This commit is contained in:
Piv
2021-07-18 19:54:28 +09:30
parent 603de2bc9f
commit d8bf493999

View File

@@ -30,8 +30,11 @@ def residual_layer(inputs, out_channels, stride, dropout=None):
# Packnet usually expects more than one layer per block (2,2,3,3)
def residual_bock(inputs, out_channels, residual_layers, stride, dropout=None):
pass
def residual_block(inputs, out_channels, residual_layers, stride, dropout=None):
x = inputs
for i in range(0, residual_layers):
x = residual_layer(x, out_channels, stride, dropout)
return x
def packnet_conv2d(inputs, out_channels, kernel_size, stride):
@@ -73,19 +76,43 @@ def unpack_3d(inputs, out_channels, kernel_size, r=3, features_3d=8):
# TODO: Support different size packnet for scaling up/down
def make_packnet(shape=(224, 224, 3), skip_add=True, features_3d=4):
def make_packnet(shape=(224, 224, 3), skip_add=True, features_3d=4, dropout=None):
"""
Make the PackNet depth network.
:param shape: Input shape of the image
:param skip_add: Set to use add rather than concat skip connections, defaults to True
:return:
"""
# ================ ENCODER =================
input = keras.layers.Input(shape=shape)
x = packnet_conv2d(input, 32, 5, 1)
skip_1 = x
x = packnet_conv2d(input, 32, 7, 1)
x = pack_3d(x, 5, features_3d)
x = residual_layer(x, 64, )
x = packnet_conv2d(x, 32, 7, 1)
x = pack_3d(x, 5, features_3d=features_3d)
skip_2 = x
x = residual_block(x, 64, 2, 1, dropout)
x = pack_3d(x, 3, features_3d=features_3d)
skip_3 = x
x = residual_block(x, 128, 2, 1, dropout)
x = pack_3d(x, 3, features_3d=features_3d)
skip_4 = x
x = residual_block(x, 256, 3, 1, dropout)
x = pack_3d(x, 3, features_3d=features_3d)
skip_5 = x
x = residual_block(x, 512, 3, 1, dropout)
x = pack_3d(x, 3, features_3d=features_3d)
# ================ ENCODER =================
x = unpack_3d(x, 512, 3, features_3d=features_3d)
x = keras.layers.Add([x, skip_5]) if skip_add else keras.layers.Concatenate([x, skip_5])
x = packnet_conv2d(x, 512, 3, 1)
x = unpack_3d(x, 256, 3, features_3d=features_3d)
x = keras.layers.Add([x, skip_4]) if skip_add else keras.layers.Concatenate([x, skip_4])
x = packnet_conv2d(x, 256, 3, 1)
# TODO: This is wrong, look at the paper
x = packnet_inverse_depth(x, 1)
x = keras.layers.UpSampling2D()
# TODO: Skip connection
if skip_add: