diff --git a/packnet_functional.py b/packnet_functional.py index 5bcb578..ae32229 100644 --- a/packnet_functional.py +++ b/packnet_functional.py @@ -68,12 +68,12 @@ def pack_3d(inputs, kernel_size, r=2, features_3d=8): return packnet_conv2d(x, inputs.shape[3], kernel_size, 1) -def unpack_3d(inputs, out_channels, kernel_size, r=3, features_3d=8): +def unpack_3d(inputs, out_channels, kernel_size, r=2, features_3d=8): x = packnet_conv2d(inputs, out_channels * (r ** 2) // features_3d, kernel_size, 1) x = tf.expand_dims(x, 4) # B x H/2 x W/2 x 4(out)/D x D - x = keras.layers.Conv3D(features_3d, kernel_size=3, padding='same') - b, c, d, h, w = x.shape + x = keras.layers.Conv3D(features_3d, kernel_size=3, padding='same')(x) + b, h, w, c, d = x.shape x = tf.reshape(x, [b, h, w, c * d]) return nn.depth_to_space(x, r) diff --git a/packnet_tests.py b/packnet_tests.py new file mode 100644 index 0000000..22b3feb --- /dev/null +++ b/packnet_tests.py @@ -0,0 +1,33 @@ +import unittest + +import tensorflow as tf + +import packnet_functional as p + + +class PacknetTests(unittest.TestCase): + + def test_pack_3d_layer(self): + # 3d packing expects a multiple of 16 for channels due to using 16 groups in group normalisation + test_input = tf.random.normal([4, 224, 224, 32]) + y = p.pack_3d(test_input, 3, features_3d=4) + out_shape = [i for i in test_input.shape] + out_shape[1] = out_shape[1] // 2 + out_shape[2] = out_shape[2] // 2 + # TODO: Anything else we can test here for validity? + self.assertEqual(y.shape, out_shape) + + def test_unpack_3d_layer(self): + num_output_channels = 32 + test_input = tf.random.normal([4, 112, 112, 64]) + y = p.unpack_3d(test_input, num_output_channels, 3, features_3d=4) + out_shape = [i for i in test_input.shape] + out_shape[1] = out_shape[1] * 2 + out_shape[2] = out_shape[2] * 2 + out_shape[3] = num_output_channels + # TODO: Anything else we can test here for validity? + self.assertEqual(y.shape, out_shape) + + +if __name__ == '__main__': + unittest.main()