Add compiling packnet model, refactor modules to not duplicate loaders and trainers
This commit is contained in:
@@ -19,10 +19,10 @@ def residual_layer(inputs, out_channels, stride, dropout=None):
|
||||
:param dropout:
|
||||
:return:
|
||||
"""
|
||||
x = layers.Conv2D(out_channels, 3, padding='same', stride=stride)(inputs)
|
||||
x = layers.Conv2D(out_channels, 3, padding='same', strides=stride)(inputs)
|
||||
x = layers.Conv2D(out_channels, 3, padding='same')(x)
|
||||
shortcut = layers.Conv2D(
|
||||
out_channels, 3, padding='same', stride=stride)(inputs)
|
||||
out_channels, 3, padding='same', strides=stride)(inputs)
|
||||
if dropout:
|
||||
shortcut = keras.layers.SpatialDropout2D(dropout)(shortcut)
|
||||
x = keras.layers.Concatenate()([x, shortcut])
|
||||
@@ -46,7 +46,7 @@ def packnet_conv2d(inputs, out_channels, kernel_size, stride):
|
||||
|
||||
|
||||
def packnet_inverse_depth(inputs, out_channels=1, min_depth=0.5):
|
||||
x = packnet_conv2d(inputs, out_channels, kernel_size=3, stride=1)
|
||||
x = layers.Conv2D(out_channels, 3, padding='same')(inputs)
|
||||
return keras.activations.sigmoid(x) / min_depth
|
||||
|
||||
|
||||
@@ -64,7 +64,7 @@ def pack_3d(inputs, kernel_size, r=2, features_3d=8):
|
||||
x = tf.expand_dims(x, 4)
|
||||
x = keras.layers.Conv3D(features_3d, kernel_size=3, padding='same')(x)
|
||||
b, h, w, c, d = x.shape
|
||||
x = tf.reshape(x, (b, h, w, c * d))
|
||||
x = keras.layers.Reshape((h, w, c * d))(x)
|
||||
return packnet_conv2d(x, inputs.shape[3], kernel_size, 1)
|
||||
|
||||
|
||||
@@ -74,7 +74,7 @@ def unpack_3d(inputs, out_channels, kernel_size, r=2, features_3d=8):
|
||||
x = tf.expand_dims(x, 4) # B x H/2 x W/2 x 4(out)/D x D
|
||||
x = keras.layers.Conv3D(features_3d, kernel_size=3, padding='same')(x)
|
||||
b, h, w, c, d = x.shape
|
||||
x = tf.reshape(x, [b, h, w, c * d])
|
||||
x = keras.layers.Reshape([h, w, c * d])(x)
|
||||
return nn.depth_to_space(x, r)
|
||||
|
||||
|
||||
@@ -92,7 +92,7 @@ def make_packnet(shape=(224, 224, 3), skip_add=True, features_3d=4, dropout=None
|
||||
input = keras.layers.Input(shape=shape)
|
||||
x = packnet_conv2d(input, 32, 5, 1)
|
||||
skip_1 = x
|
||||
x = packnet_conv2d(x, 32, 7, 1)
|
||||
x = packnet_conv2d(x, 64, 7, 1)
|
||||
x = pack_3d(x, 5, features_3d=features_3d)
|
||||
skip_2 = x
|
||||
x = residual_block(x, 64, 2, 1, dropout)
|
||||
@@ -108,24 +108,43 @@ def make_packnet(shape=(224, 224, 3), skip_add=True, features_3d=4, dropout=None
|
||||
x = pack_3d(x, 3, features_3d=features_3d)
|
||||
# ================ ENCODER =================
|
||||
|
||||
# ================ DECODER =================
|
||||
# layer 7
|
||||
x = unpack_3d(x, 512, 3, features_3d=features_3d)
|
||||
x = keras.layers.Add(
|
||||
[x, skip_5]) if skip_add else keras.layers.Concatenate([x, skip_5])
|
||||
x = keras.layers.Add()(
|
||||
[x, skip_5]) if skip_add else keras.layers.Concatenate()([x, skip_5])
|
||||
x = packnet_conv2d(x, 512, 3, 1)
|
||||
# layer 8
|
||||
x = unpack_3d(x, 256, 3, features_3d=features_3d)
|
||||
x = keras.layers.Add(
|
||||
[x, skip_4]) if skip_add else keras.layers.Concatenate([x, skip_4])
|
||||
x = keras.layers.Add()(
|
||||
[x, skip_4]) if skip_add else keras.layers.Concatenate()([x, skip_4])
|
||||
x = packnet_conv2d(x, 256, 3, 1)
|
||||
# TODO: This is wrong, look at the paper
|
||||
layer_8 = x
|
||||
# layer 9
|
||||
x = packnet_inverse_depth(x, 1)
|
||||
x = keras.layers.UpSampling2D()
|
||||
|
||||
# TODO: Skip connection
|
||||
if skip_add:
|
||||
x = keras.layers.Add([x, ])
|
||||
else:
|
||||
x = keras.layers.Concatenate([x, ])
|
||||
|
||||
x = packnet_conv2d(x, 32, 3, 1)
|
||||
# layer 10
|
||||
u_layer_8 = unpack_3d(layer_8, 128, 3, features_3d=features_3d)
|
||||
x = keras.layers.UpSampling2D()(x)
|
||||
x = keras.layers.Add()([u_layer_8, skip_3, x]) if skip_add else keras.layers.Concatenate()([u_layer_8, skip_3, x])
|
||||
x = packnet_conv2d(x, 128, 3, 1)
|
||||
layer_10 = x
|
||||
# layer 11
|
||||
x = packnet_inverse_depth(x, 1)
|
||||
# layer 12
|
||||
u_layer_10 = unpack_3d(layer_10, 64, 3, features_3d=features_3d)
|
||||
x = keras.layers.UpSampling2D()(x)
|
||||
x = keras.layers.Add()([u_layer_10, skip_2, x]) if skip_add else keras.layers.Concatenate()([u_layer_10, skip_2, x])
|
||||
x = packnet_conv2d(x, 64, 3, 1)
|
||||
layer_12 = x
|
||||
# layer 13
|
||||
x = packnet_inverse_depth(x)
|
||||
# layer 14
|
||||
u_layer_12 = unpack_3d(layer_12, 32, 3, features_3d=features_3d)
|
||||
x = keras.layers.UpSampling2D()(x)
|
||||
x = keras.layers.Add()([u_layer_12, skip_1, x]) if skip_add else keras.layers.Concatenate()([u_layer_12, skip_1, x])
|
||||
x = packnet_conv2d(x, 32, 3, 1)
|
||||
# layer 15
|
||||
x = packnet_inverse_depth(x)
|
||||
# ================ DECODER =================
|
||||
|
||||
return keras.Model(inputs=input, outputs=x, name="PackNet")
|
||||
|
||||
Reference in New Issue
Block a user