diff --git a/unsupervised/warp_tests.py b/unsupervised/warp_tests.py index 626c7e1..1cf62a1 100644 --- a/unsupervised/warp_tests.py +++ b/unsupervised/warp_tests.py @@ -56,7 +56,7 @@ class MyTestCase(unittest.TestCase): disp = tf.random.uniform([1, height, width]) * 255 pose = tf.random.uniform([1, 6]) - warp.projective_inverse_warp(img, disp, pose, intrinsics, coords) + self.assertEqual(warp.projective_inverse_warp(img, disp, pose, intrinsics, coords).shape, img.shape) if __name__ == '__main__':