diff --git a/unsupervised/warp.py b/unsupervised/warp.py index ac70095..390d202 100644 --- a/unsupervised/warp.py +++ b/unsupervised/warp.py @@ -53,7 +53,6 @@ def pose_vec2mat(vec): ry = tf.slice(vec, [0, 4], [-1, 1]) rz = tf.slice(vec, [0, 5], [-1, 1]) rot_mat = euler_to_matrix(rx, ry, rz) - rot_mat = tf.squeeze(rot_mat, axis=[1]) transform_mat = tf.concat([rot_mat, translation], axis=2) return transform_mat @@ -76,10 +75,10 @@ def image_coordinate(batch, height, width): stacked = tf.stack([x_mesh, y_mesh, ones_mesh], axis=2) - return tf.repeat(tf.expand_dims(stacked, axis=0), batch, axis=0) + return tf.cast(tf.repeat(tf.expand_dims(stacked, axis=0), batch, axis=0), dtype=tf.float32) -def projective_inverse_warp(target_img, source_img, depth, pose, intrinsics, coordinates): +def projective_inverse_warp(source_img, depth, pose, intrinsics, coordinates): """ Calculate the reprojected image from the source to the target, based on the given depth, pose and intrinsics @@ -92,12 +91,11 @@ def projective_inverse_warp(target_img, source_img, depth, pose, intrinsics, coo the source image in pixel coordinates (K.T(t->s).{3d coord}), then using the projected coordinates we sample the pixels in the source image (ps) to reconstruct the target image. - :param target_img: Tensor (batch, height, width, 3) - :param source_img: Tensor, same shape as target_img - :param depth: Tensor, (batch, height, width, 1) + :param source_img: Tensor (batch, height, width, 3) + :param depth: Tensor, (batch, height, width) :param pose: (batch, 6) :param intrinsics: (batch, 3, 3) TODO: Intrinsics per image (per source/target image)? - :param coordinates: (batch, height, width, 3) - coordinates for the image. Pass this in so it doesn't need to be + :param coordinates: (batch, 3, height * width) - coordinates for the image. Pass this in so it doesn't need to be calculated on every warp step :return: The source image reprojected to the target """ @@ -109,23 +107,117 @@ def projective_inverse_warp(target_img, source_img, depth, pose, intrinsics, coo # Calculate inverse of the 4x4 intrinsics matrix intrinsics_inverse = tf.linalg.inv(intrinsics) - # Create grid of homogenous coordinates - grid_coords = image_coordinate(*depth.shape) - # Flatten the image coords to [B, 3, height * width] so each point can be used in calculations - grid_coords = tf.transpose(tf.reshape(grid_coords, [0, 2, 1])) - - # TODO: Do we need to transpose? - depth_flat = tf.transpose(tf.reshape(depth, [0, 2, 1])) + depth_flat = tf.reshape(depth, [depth.shape[0], depth.shape[1] * depth.shape[2]]) # Do the function sample_coordinates = tf.matmul(tf.matmul(intrinsics, pose_3x4), - tf.concat([depth_flat * tf.matmul(intrinsics_inverse, grid_coords), - tf.ones(depth_flat.shape)], axis=1)) + tf.concat([depth_flat * tf.matmul(intrinsics_inverse, coordinates), + tf.ones([depth_flat.shape[0], 1, depth_flat.shape[1]])], axis=1)) # Normalise the x/y axes (divide by z axis) + sample_coordinates = sample_coordinates[:, 0:2] / sample_coordinates[:, 2] # Reshape back to image coordinates + sample_coordinates = tf.reshape(tf.transpose(sample_coordinates, [0, 2, 1]), + [depth.shape[0], depth.shape[1], depth.shape[2], 2]) # sample from the source image using the coordinates applied by the function + return bilinear_sampler(source_img, sample_coordinates) - pass + +def bilinear_sampler(imgs, coords): + """Construct a new image by bilinear sampling from the input image. + + Points falling outside the source image boundary have value 0. + + Args: + imgs: source image to be sampled from [batch, height_s, width_s, channels] + coords: coordinates of source pixels to sample from [batch, height_t, + width_t, 2]. height_t/width_t correspond to the dimensions of the output + image (don't need to be the same as height_s/width_s). The two channels + correspond to x and y coordinates respectively. + Returns: + A new sampled image [batch, height_t, width_t, channels] + """ + + def _repeat(x, n_repeats): + rep = tf.transpose( + tf.expand_dims(tf.ones(shape=tf.stack([ + n_repeats, + ])), 1), [1, 0]) + rep = tf.cast(rep, 'float32') + x = tf.matmul(tf.reshape(x, (-1, 1)), rep) + return tf.reshape(x, [-1]) + + coords_x, coords_y = tf.split(coords, [1, 1], axis=3) + inp_size = imgs.get_shape() + coord_size = coords.get_shape() + out_size = coords.get_shape().as_list() + out_size[3] = imgs.get_shape().as_list()[3] + + coords_x = tf.cast(coords_x, 'float32') + coords_y = tf.cast(coords_y, 'float32') + + x0 = tf.floor(coords_x) + x1 = x0 + 1 + y0 = tf.floor(coords_y) + y1 = y0 + 1 + + y_max = tf.cast(tf.shape(imgs)[1] - 1, 'float32') + x_max = tf.cast(tf.shape(imgs)[2] - 1, 'float32') + zero = tf.zeros([1], dtype='float32') + + x0_safe = tf.clip_by_value(x0, zero, x_max) + y0_safe = tf.clip_by_value(y0, zero, y_max) + x1_safe = tf.clip_by_value(x1, zero, x_max) + y1_safe = tf.clip_by_value(y1, zero, y_max) + + # bilinear interp weights, with points outside the grid having weight 0 + # wt_x0 = (x1 - coords_x) * tf.cast(tf.equal(x0, x0_safe), 'float32') + # wt_x1 = (coords_x - x0) * tf.cast(tf.equal(x1, x1_safe), 'float32') + # wt_y0 = (y1 - coords_y) * tf.cast(tf.equal(y0, y0_safe), 'float32') + # wt_y1 = (coords_y - y0) * tf.cast(tf.equal(y1, y1_safe), 'float32') + + wt_x0 = x1_safe - coords_x + wt_x1 = coords_x - x0_safe + wt_y0 = y1_safe - coords_y + wt_y1 = coords_y - y0_safe + + # indices in the flat image to sample from + dim2 = tf.cast(inp_size[2], 'float32') + dim1 = tf.cast(inp_size[2] * inp_size[1], 'float32') + base = tf.reshape( + _repeat( + tf.cast(tf.range(coord_size[0]), 'float32') * dim1, + coord_size[1] * coord_size[2]), + [out_size[0], out_size[1], out_size[2], 1]) + + base_y0 = base + y0_safe * dim2 + base_y1 = base + y1_safe * dim2 + idx00 = tf.reshape(x0_safe + base_y0, [-1]) + idx01 = x0_safe + base_y1 + idx10 = x1_safe + base_y0 + idx11 = x1_safe + base_y1 + + # sample from imgs + imgs_flat = tf.reshape(imgs, tf.stack([-1, inp_size[3]])) + imgs_flat = tf.cast(imgs_flat, 'float32') + im00 = tf.reshape( + tf.gather(imgs_flat, tf.cast(idx00, 'int32')), out_size) + im01 = tf.reshape( + tf.gather(imgs_flat, tf.cast(idx01, 'int32')), out_size) + im10 = tf.reshape( + tf.gather(imgs_flat, tf.cast(idx10, 'int32')), out_size) + im11 = tf.reshape( + tf.gather(imgs_flat, tf.cast(idx11, 'int32')), out_size) + + w00 = wt_x0 * wt_y0 + w01 = wt_x0 * wt_y1 + w10 = wt_x1 * wt_y0 + w11 = wt_x1 * wt_y1 + + output = tf.add_n([ + w00 * im00, w01 * im01, + w10 * im10, w11 * im11 + ]) + return output diff --git a/unsupervised/warp_tests.py b/unsupervised/warp_tests.py index 7a0c1c2..626c7e1 100644 --- a/unsupervised/warp_tests.py +++ b/unsupervised/warp_tests.py @@ -42,6 +42,22 @@ class MyTestCase(unittest.TestCase): self.assertEqual(coords[0, height - 1, width - 1, 1], height - 1) self.assertEqual(coords[0, height - 1, width - 1, 2], 1) + def test_warp(self): + height = 1000 + width = 2000 + coords = warp.image_coordinate(1, height, width) + coords = tf.reshape(coords, [1, height * width, 3]) + coords = tf.transpose(coords, [0, 2, 1]) + # source image to sample from + img = tf.random.uniform([1, height, width, 3]) * 255 + + intrinsics = tf.constant([[[1, 0, 0], [0, 1, 0], [0, 0, 1]]], dtype=tf.float32) + + disp = tf.random.uniform([1, height, width]) * 255 + pose = tf.random.uniform([1, 6]) + + warp.projective_inverse_warp(img, disp, pose, intrinsics, coords) + if __name__ == '__main__': unittest.main()