import unittest

import tensorflow as tf

import packnet_functional as p


class PacknetTests(unittest.TestCase):
    """Shape-level smoke tests for the functional PackNet building blocks."""

    def test_pack_3d_layer(self):
        """Check that pack_3d halves height and width and keeps batch/channels."""
        # 3d packing expects a multiple of 16 for channels due to using 16
        # groups in group normalisation.
        test_input = tf.random.normal([4, 224, 224, 32])
        # NOTE(review): the positional 3 is presumably the kernel size —
        # confirm against packnet_functional.pack_3d.
        y = p.pack_3d(test_input, 3, features_3d=4)

        batch, height, width, channels = test_input.shape.as_list()
        expected_shape = [batch, height // 2, width // 2, channels]
        # TODO: Anything else we can test here for validity?
        # Compare plain lists so a failure prints a readable element-wise diff.
        self.assertEqual(y.shape.as_list(), expected_shape)

    def test_unpack_3d_layer(self):
        """Check that unpack_3d doubles height and width and sets the channel count."""
        num_output_channels = 32
        test_input = tf.random.normal([4, 112, 112, 64])
        # NOTE(review): the positional 3 is presumably the kernel size —
        # confirm against packnet_functional.unpack_3d.
        y = p.unpack_3d(test_input, num_output_channels, 3, features_3d=4)

        batch, height, width, _ = test_input.shape.as_list()
        expected_shape = [batch, height * 2, width * 2, num_output_channels]
        # TODO: Anything else we can test here for validity?
        self.assertEqual(y.shape.as_list(), expected_shape)

    def test_packnet(self):
        """Check that make_packnet() returns an object rather than None."""
        packnet = p.make_packnet()
        self.assertIsNotNone(packnet)


if __name__ == '__main__':
    unittest.main()