Add pack layer tests, fix unpack_3d layer
This commit is contained in:
@@ -68,12 +68,12 @@ def pack_3d(inputs, kernel_size, r=2, features_3d=8):
|
|||||||
return packnet_conv2d(x, inputs.shape[3], kernel_size, 1)
|
return packnet_conv2d(x, inputs.shape[3], kernel_size, 1)
|
||||||
|
|
||||||
|
|
||||||
def unpack_3d(inputs, out_channels, kernel_size, r=3, features_3d=8):
|
def unpack_3d(inputs, out_channels, kernel_size, r=2, features_3d=8):
|
||||||
x = packnet_conv2d(inputs, out_channels * (r ** 2) //
|
x = packnet_conv2d(inputs, out_channels * (r ** 2) //
|
||||||
features_3d, kernel_size, 1)
|
features_3d, kernel_size, 1)
|
||||||
x = tf.expand_dims(x, 4) # B x H/2 x W/2 x 4(out)/D x D
|
x = tf.expand_dims(x, 4) # B x H/2 x W/2 x 4(out)/D x D
|
||||||
x = keras.layers.Conv3D(features_3d, kernel_size=3, padding='same')
|
x = keras.layers.Conv3D(features_3d, kernel_size=3, padding='same')(x)
|
||||||
b, c, d, h, w = x.shape
|
b, h, w, c, d = x.shape
|
||||||
x = tf.reshape(x, [b, h, w, c * d])
|
x = tf.reshape(x, [b, h, w, c * d])
|
||||||
return nn.depth_to_space(x, r)
|
return nn.depth_to_space(x, r)
|
||||||
|
|
||||||
|
|||||||
33
packnet_tests.py
Normal file
33
packnet_tests.py
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
import unittest
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
import packnet_functional as p
|
||||||
|
|
||||||
|
|
||||||
|
class PacknetTests(unittest.TestCase):
|
||||||
|
|
||||||
|
def test_pack_3d_layer(self):
|
||||||
|
# 3d packing expects a multiple of 16 for channels due to using 16 groups in group normalisation
|
||||||
|
test_input = tf.random.normal([4, 224, 224, 32])
|
||||||
|
y = p.pack_3d(test_input, 3, features_3d=4)
|
||||||
|
out_shape = [i for i in test_input.shape]
|
||||||
|
out_shape[1] = out_shape[1] // 2
|
||||||
|
out_shape[2] = out_shape[2] // 2
|
||||||
|
# TODO: Anything else we can test here for validity?
|
||||||
|
self.assertEqual(y.shape, out_shape)
|
||||||
|
|
||||||
|
def test_unpack_3d_layer(self):
|
||||||
|
num_output_channels = 32
|
||||||
|
test_input = tf.random.normal([4, 112, 112, 64])
|
||||||
|
y = p.unpack_3d(test_input, num_output_channels, 3, features_3d=4)
|
||||||
|
out_shape = [i for i in test_input.shape]
|
||||||
|
out_shape[1] = out_shape[1] * 2
|
||||||
|
out_shape[2] = out_shape[2] * 2
|
||||||
|
out_shape[3] = num_output_channels
|
||||||
|
# TODO: Anything else we can test here for validity?
|
||||||
|
self.assertEqual(y.shape, out_shape)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
unittest.main()
|
||||||
Reference in New Issue
Block a user