diff --git a/packnet_functional.py b/packnet_functional.py index c404c07..3929df6 100644 --- a/packnet_functional.py +++ b/packnet_functional.py @@ -22,10 +22,10 @@ def residual_layer(inputs, out_channels, stride, dropout=None): x = layers.Conv2D(out_channels, 3, padding='same', strides=stride)(inputs) x = layers.Conv2D(out_channels, 3, padding='same')(x) shortcut = layers.Conv2D( - out_channels, 3, padding='same', strides=stride)(inputs) + out_channels, 1, padding='same', strides=stride)(inputs) if dropout: shortcut = keras.layers.SpatialDropout2D(dropout)(shortcut) - x = keras.layers.Concatenate()([x, shortcut]) + x = keras.layers.Add()([x, shortcut]) x = group_norm.GroupNormalization(16)(x) return keras.layers.ELU()(x) @@ -53,10 +53,10 @@ def packnet_inverse_depth(inputs, out_channels=1, min_depth=0.5): def pack_3d(inputs, kernel_size, r=2, features_3d=8): """ Implementatino of the 3d packing block proposed here: https://arxiv.org/abs/1905.02693 - :param inputs: - :param kernel_size: - :param r: - :param features_3d: + :param inputs: Tensor inputs + :param kernel_size: Conv3D kernels size + :param r: Packing factor + :param features_3d: Packing depth (increase to increase number of parameters and accuracy) :return: """ # Data format for single image in nyu is HWC (space_to_depth uses NHWC as default) @@ -78,7 +78,6 @@ def unpack_3d(inputs, out_channels, kernel_size, r=2, features_3d=8): return nn.depth_to_space(x, r) -# TODO: Support different size packnet for scaling up/down # TODO: Support different channel format (right now we're supporting NHWC, we should also support NCHW) def make_packnet(shape=(224, 224, 3), skip_add=True, features_3d=4, dropout=None): """ @@ -109,42 +108,48 @@ def make_packnet(shape=(224, 224, 3), skip_add=True, features_3d=4, dropout=None # ================ ENCODER ================= # ================ DECODER ================= - # layer 7 - x = unpack_3d(x, 512, 3, features_3d=features_3d) + # Addition requires we half the outputs so there is a matching number of channels + divide_factor = (2 if skip_add else 1) + # layer 12 - 13 + x = unpack_3d(x, 512 // divide_factor, 3, features_3d=features_3d) x = keras.layers.Add()( [x, skip_5]) if skip_add else keras.layers.Concatenate()([x, skip_5]) x = packnet_conv2d(x, 512, 3, 1) - # layer 8 - x = unpack_3d(x, 256, 3, features_3d=features_3d) + # layer 14 - 15 + x = unpack_3d(x, 256 // divide_factor, 3, features_3d=features_3d) x = keras.layers.Add()( [x, skip_4]) if skip_add else keras.layers.Concatenate()([x, skip_4]) x = packnet_conv2d(x, 256, 3, 1) layer_8 = x - # layer 9 + # layer 16 x = packnet_inverse_depth(x, 1) - # layer 10 - u_layer_8 = unpack_3d(layer_8, 128, 3, features_3d=features_3d) + # layer 17 - 18 + u_layer_8 = unpack_3d(layer_8, 128 // divide_factor, 3, features_3d=features_3d) x = keras.layers.UpSampling2D()(x) x = keras.layers.Add()([u_layer_8, skip_3, x]) if skip_add else keras.layers.Concatenate()([u_layer_8, skip_3, x]) x = packnet_conv2d(x, 128, 3, 1) layer_10 = x - # layer 11 + # layer 19 x = packnet_inverse_depth(x, 1) - # layer 12 + # layer 20 - 21 u_layer_10 = unpack_3d(layer_10, 64, 3, features_3d=features_3d) x = keras.layers.UpSampling2D()(x) x = keras.layers.Add()([u_layer_10, skip_2, x]) if skip_add else keras.layers.Concatenate()([u_layer_10, skip_2, x]) x = packnet_conv2d(x, 64, 3, 1) layer_12 = x - # layer 13 + # layer 22 x = packnet_inverse_depth(x) - # layer 14 - u_layer_12 = unpack_3d(layer_12, 32, 3, features_3d=features_3d) + # layer 23 - 24 + u_layer_12 = unpack_3d(layer_12, 64, 3, features_3d=features_3d) x = keras.layers.UpSampling2D()(x) x = keras.layers.Add()([u_layer_12, skip_1, x]) if skip_add else keras.layers.Concatenate()([u_layer_12, skip_1, x]) - x = packnet_conv2d(x, 32, 3, 1) - # layer 15 + x = packnet_conv2d(x, 64, 3, 1) + # layer 25 x = packnet_inverse_depth(x) # ================ DECODER ================= return keras.Model(inputs=input, outputs=x, name="PackNet") + +if __name__ == '__main__': + # This is the implementation used by the packnet sfm paper + make_packnet(features_3d=8, skip_add=False).summary() \ No newline at end of file