38 lines
1.2 KiB
Python
38 lines
1.2 KiB
Python
import unittest
|
|
|
|
import tensorflow as tf
|
|
|
|
import packnet_functional as p
|
|
|
|
|
|
class PacknetTests(unittest.TestCase):
    """Shape-level tests for the functional PackNet building blocks."""

    def _assert_shape(self, tensor, expected_shape):
        """Assert that ``tensor`` has exactly ``expected_shape``.

        Both sides are plain Python lists, so ``assertEqual`` dispatches to
        ``assertListEqual`` and a failure names the first differing dimension
        instead of relying on ``TensorShape``'s implicit list coercion.
        """
        self.assertEqual(tensor.shape.as_list(), list(expected_shape))

    def test_pack_3d_layer(self):
        """pack_3d output: spatial dims halved, batch and channels unchanged."""
        # 3d packing expects a multiple of 16 for channels due to using 16
        # groups in group normalisation
        test_input = tf.random.normal([4, 224, 224, 32])
        y = p.pack_3d(test_input, 3, features_3d=4)

        expected = test_input.shape.as_list()
        expected[1] //= 2
        expected[2] //= 2
        # TODO: Anything else we can test here for validity?
        self._assert_shape(y, expected)

    def test_unpack_3d_layer(self):
        """unpack_3d output: spatial dims doubled, channels set to the request."""
        num_output_channels = 32
        test_input = tf.random.normal([4, 112, 112, 64])
        y = p.unpack_3d(test_input, num_output_channels, 3, features_3d=4)

        expected = test_input.shape.as_list()
        expected[1] *= 2
        expected[2] *= 2
        expected[3] = num_output_channels
        # TODO: Anything else we can test here for validity?
        self._assert_shape(y, expected)

    def test_packnet(self):
        """make_packnet builds without raising and returns a non-None object."""
        packnet = p.make_packnet()
        self.assertIsNotNone(packnet)
|
|
|
|
|
|
if __name__ == '__main__':
    # Executed directly (not imported): hand this module's TestCases to the
    # unittest runner; '__main__' is the runner's default module argument.
    unittest.main(module='__main__')
|