Finish Projective Inverse Warp algorithm

Author: Piv
Date:   2021-08-24 20:13:30 +09:30
parent b7917ec465
commit c164c9720a
2 changed files with 125 additions and 17 deletions


@@ -53,7 +53,6 @@ def pose_vec2mat(vec):
     ry = tf.slice(vec, [0, 4], [-1, 1])
     rz = tf.slice(vec, [0, 5], [-1, 1])
     rot_mat = euler_to_matrix(rx, ry, rz)
-    rot_mat = tf.squeeze(rot_mat, axis=[1])
     transform_mat = tf.concat([rot_mat, translation], axis=2)
     return transform_mat
@@ -76,10 +75,10 @@ def image_coordinate(batch, height, width):
     stacked = tf.stack([x_mesh, y_mesh, ones_mesh], axis=2)
-    return tf.repeat(tf.expand_dims(stacked, axis=0), batch, axis=0)
+    return tf.cast(tf.repeat(tf.expand_dims(stacked, axis=0), batch, axis=0), dtype=tf.float32)


-def projective_inverse_warp(target_img, source_img, depth, pose, intrinsics, coordinates):
+def projective_inverse_warp(source_img, depth, pose, intrinsics, coordinates):
     """
     Calculate the reprojected image from the source to the target, based on the given depth, pose and intrinsics
@@ -92,12 +91,11 @@ def projective_inverse_warp(target_img, source_img, depth, pose, intrinsics, coordinates):
     the source image in pixel coordinates (K.T(t->s).{3d coord}), then using the projected coordinates we sample
     the pixels in the source image (ps) to reconstruct the target image.
-    :param target_img: Tensor (batch, height, width, 3)
-    :param source_img: Tensor, same shape as target_img
-    :param depth: Tensor, (batch, height, width, 1)
+    :param source_img: Tensor (batch, height, width, 3)
+    :param depth: Tensor, (batch, height, width)
     :param pose: (batch, 6)
     :param intrinsics: (batch, 3, 3) TODO: Intrinsics per image (per source/target image)?
-    :param coordinates: (batch, height, width, 3) - coordinates for the image. Pass this in so it doesn't need to be
+    :param coordinates: (batch, 3, height * width) - coordinates for the image. Pass this in so it doesn't need to be
         calculated on every warp step
     :return: The source image reprojected to the target
     """
@@ -109,23 +107,117 @@ def projective_inverse_warp(target_img, source_img, depth, pose, intrinsics, coordinates):
     # Calculate inverse of the 4x4 intrinsics matrix
     intrinsics_inverse = tf.linalg.inv(intrinsics)
-    # Create grid of homogenous coordinates
-    grid_coords = image_coordinate(*depth.shape)
-    # Flatten the image coords to [B, 3, height * width] so each point can be used in calculations
-    grid_coords = tf.transpose(tf.reshape(grid_coords, [0, 2, 1]))
-    # TODO: Do we need to transpose?
-    depth_flat = tf.transpose(tf.reshape(depth, [0, 2, 1]))
+    depth_flat = tf.reshape(depth, [depth.shape[0], depth.shape[1] * depth.shape[2]])
     # Do the function
     sample_coordinates = tf.matmul(tf.matmul(intrinsics, pose_3x4),
-                                   tf.concat([depth_flat * tf.matmul(intrinsics_inverse, grid_coords),
-                                              tf.ones(depth_flat.shape)], axis=1))
+                                   tf.concat([depth_flat * tf.matmul(intrinsics_inverse, coordinates),
+                                              tf.ones([depth_flat.shape[0], 1, depth_flat.shape[1]])], axis=1))
     # Normalise the x/y axes (divide by z axis)
+    sample_coordinates = sample_coordinates[:, 0:2] / sample_coordinates[:, 2]
     # Reshape back to image coordinates
+    sample_coordinates = tf.reshape(tf.transpose(sample_coordinates, [0, 2, 1]),
+                                    [depth.shape[0], depth.shape[1], depth.shape[2], 2])
     # sample from the source image using the coordinates applied by the function
-    pass
+    return bilinear_sampler(source_img, sample_coordinates)
+
+
+def bilinear_sampler(imgs, coords):
+    """Construct a new image by bilinear sampling from the input image.
+
+    Points falling outside the source image boundary have value 0.
+
+    Args:
+        imgs: source image to be sampled from [batch, height_s, width_s, channels]
+        coords: coordinates of source pixels to sample from [batch, height_t,
+            width_t, 2]. height_t/width_t correspond to the dimensions of the output
+            image (don't need to be the same as height_s/width_s). The two channels
+            correspond to x and y coordinates respectively.
+    Returns:
+        A new sampled image [batch, height_t, width_t, channels]
+    """
+    def _repeat(x, n_repeats):
+        rep = tf.transpose(
+            tf.expand_dims(tf.ones(shape=tf.stack([
+                n_repeats,
+            ])), 1), [1, 0])
+        rep = tf.cast(rep, 'float32')
+        x = tf.matmul(tf.reshape(x, (-1, 1)), rep)
+        return tf.reshape(x, [-1])
+
+    coords_x, coords_y = tf.split(coords, [1, 1], axis=3)
+    inp_size = imgs.get_shape()
+    coord_size = coords.get_shape()
+    out_size = coords.get_shape().as_list()
+    out_size[3] = imgs.get_shape().as_list()[3]
+
+    coords_x = tf.cast(coords_x, 'float32')
+    coords_y = tf.cast(coords_y, 'float32')
+
+    x0 = tf.floor(coords_x)
+    x1 = x0 + 1
+    y0 = tf.floor(coords_y)
+    y1 = y0 + 1
+
+    y_max = tf.cast(tf.shape(imgs)[1] - 1, 'float32')
+    x_max = tf.cast(tf.shape(imgs)[2] - 1, 'float32')
+    zero = tf.zeros([1], dtype='float32')
+
+    x0_safe = tf.clip_by_value(x0, zero, x_max)
+    y0_safe = tf.clip_by_value(y0, zero, y_max)
+    x1_safe = tf.clip_by_value(x1, zero, x_max)
+    y1_safe = tf.clip_by_value(y1, zero, y_max)
+
+    # bilinear interp weights, with points outside the grid having weight 0
+    # wt_x0 = (x1 - coords_x) * tf.cast(tf.equal(x0, x0_safe), 'float32')
+    # wt_x1 = (coords_x - x0) * tf.cast(tf.equal(x1, x1_safe), 'float32')
+    # wt_y0 = (y1 - coords_y) * tf.cast(tf.equal(y0, y0_safe), 'float32')
+    # wt_y1 = (coords_y - y0) * tf.cast(tf.equal(y1, y1_safe), 'float32')
+    wt_x0 = x1_safe - coords_x
+    wt_x1 = coords_x - x0_safe
+    wt_y0 = y1_safe - coords_y
+    wt_y1 = coords_y - y0_safe
+
+    # indices in the flat image to sample from
+    dim2 = tf.cast(inp_size[2], 'float32')
+    dim1 = tf.cast(inp_size[2] * inp_size[1], 'float32')
+    base = tf.reshape(
+        _repeat(
+            tf.cast(tf.range(coord_size[0]), 'float32') * dim1,
+            coord_size[1] * coord_size[2]),
+        [out_size[0], out_size[1], out_size[2], 1])
+    base_y0 = base + y0_safe * dim2
+    base_y1 = base + y1_safe * dim2
+    idx00 = tf.reshape(x0_safe + base_y0, [-1])
+    idx01 = x0_safe + base_y1
+    idx10 = x1_safe + base_y0
+    idx11 = x1_safe + base_y1
+
+    # sample from imgs
+    imgs_flat = tf.reshape(imgs, tf.stack([-1, inp_size[3]]))
+    imgs_flat = tf.cast(imgs_flat, 'float32')
+    im00 = tf.reshape(
+        tf.gather(imgs_flat, tf.cast(idx00, 'int32')), out_size)
+    im01 = tf.reshape(
+        tf.gather(imgs_flat, tf.cast(idx01, 'int32')), out_size)
+    im10 = tf.reshape(
+        tf.gather(imgs_flat, tf.cast(idx10, 'int32')), out_size)
+    im11 = tf.reshape(
+        tf.gather(imgs_flat, tf.cast(idx11, 'int32')), out_size)
+
+    w00 = wt_x0 * wt_y0
+    w01 = wt_x0 * wt_y1
+    w10 = wt_x1 * wt_y0
+    w11 = wt_x1 * wt_y1
+
+    output = tf.add_n([
+        w00 * im00, w01 * im01,
+        w10 * im10, w11 * im11
+    ])
+    return output
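For reference, the new projective_inverse_warp body evaluates the usual rigid reprojection ps ~ K . T(t->s) . [D(pt) . K^-1 . pt ; 1], then divides by the z component and reshapes back to (batch, height, width, 2) sampling coordinates. Below is a minimal sketch of the same chain of matmuls for a single batch with identity pose and intrinsics, written for TF2 eager execution; the toy sizes and the explicit (batch, 1, height * width) depth shape are illustrative choices, not taken from the commit.

import tensorflow as tf

# Toy sizes; purely illustrative
batch, height, width = 1, 4, 5

# Homogeneous pixel grid flattened to (batch, 3, height * width),
# i.e. the layout the new `coordinates` argument expects
xs, ys = tf.meshgrid(tf.range(width), tf.range(height))
grid = tf.stack([tf.cast(xs, tf.float32),
                 tf.cast(ys, tf.float32),
                 tf.ones([height, width])], axis=0)
coords = tf.reshape(grid, [batch, 3, height * width])

intrinsics = tf.eye(3, batch_shape=[batch])                    # K
pose_3x4 = tf.concat([tf.eye(3, batch_shape=[batch]),
                      tf.zeros([batch, 3, 1])], axis=2)        # identity [R | t]
depth = tf.fill([batch, height, width], 2.0)                   # constant depth

# Back-project to camera space, re-project through K.[R | t], divide by z
depth_flat = tf.reshape(depth, [batch, 1, height * width])
cam_points = depth_flat * tf.matmul(tf.linalg.inv(intrinsics), coords)
cam_points_h = tf.concat([cam_points, tf.ones([batch, 1, height * width])], axis=1)
projected = tf.matmul(tf.matmul(intrinsics, pose_3x4), cam_points_h)
pixels = projected[:, 0:2] / projected[:, 2:3]

# Identity pose + identity intrinsics: the samples land back on the source grid
print(tf.reduce_max(tf.abs(pixels - coords[:, 0:2])).numpy())  # -> 0.0

With an identity transform every pixel projects back onto itself regardless of the (constant) depth, which is the property the sanity-check sketch at the end of this page leans on.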

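The added bilinear_sampler gathers the four integer neighbours of each sampling point from the flattened image and blends them with weights proportional to the distance to the opposite corner. A tiny hand-checked sketch follows, assuming the function is exposed by the same warp module the unit tests import; the 2x2 image and the sampling point are made up for illustration.

import tensorflow as tf
import warp  # module name assumed from the test file below

# 1x2x2x1 image with known corner values: I(x=0,y=0)=0, I(1,0)=10, I(0,1)=20, I(1,1)=30
img = tf.constant([[[[0.0], [10.0]],
                    [[20.0], [30.0]]]])

# One sampling point at (x, y) = (0.25, 0.5); coords layout is (batch, h_t, w_t, 2)
coords = tf.constant([[[[0.25, 0.5]]]])

sampled = warp.bilinear_sampler(img, coords)

# Same point by the standard four-corner bilinear formula
x, y = 0.25, 0.5
expected = ((1 - x) * (1 - y) * 0.0 + x * (1 - y) * 10.0
            + (1 - x) * y * 20.0 + x * y * 30.0)    # = 12.5
print(float(tf.squeeze(sampled)), expected)          # both 12.5

Because the active weights use the clamped x1_safe/y1_safe values, points sitting exactly on the far image border end up with zero weight, which is why the test sketch at the end of this page crops the last row and column.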

@@ -42,6 +42,22 @@ class MyTestCase(unittest.TestCase):
         self.assertEqual(coords[0, height - 1, width - 1, 1], height - 1)
         self.assertEqual(coords[0, height - 1, width - 1, 2], 1)

+    def test_warp(self):
+        height = 1000
+        width = 2000
+        coords = warp.image_coordinate(1, height, width)
+        coords = tf.reshape(coords, [1, height * width, 3])
+        coords = tf.transpose(coords, [0, 2, 1])
+        # source image to sample from
+        img = tf.random.uniform([1, height, width, 3]) * 255
+        intrinsics = tf.constant([[[1, 0, 0], [0, 1, 0], [0, 0, 1]]], dtype=tf.float32)
+        disp = tf.random.uniform([1, height, width]) * 255
+        pose = tf.random.uniform([1, 6])
+        warp.projective_inverse_warp(img, disp, pose, intrinsics, coords)
+

 if __name__ == '__main__':
     unittest.main()
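A cheap follow-up to test_warp, sketched here rather than taken from the commit: with a zero pose vector and identity intrinsics the reprojection reduces to ps = pt, so the warped output should reproduce the source image everywhere except the last row and column, where the x+1 and y+1 neighbours are clamped and their weights vanish. The method below would sit alongside test_warp in MyTestCase; it assumes pose_vec2mat maps a zero vector to an identity [R | t] (euler_to_matrix of zero angles), and the sizes, crop and tolerance are arbitrary.

    def test_warp_identity_pose(self):
        # Sketch only: with zero pose and identity intrinsics, ps == pt, so the
        # warp should reproduce the source image away from the last row/column.
        height, width = 8, 12
        coords = warp.image_coordinate(1, height, width)
        coords = tf.transpose(tf.reshape(coords, [1, height * width, 3]), [0, 2, 1])
        img = tf.random.uniform([1, height, width, 3]) * 255
        intrinsics = tf.constant([[[1, 0, 0], [0, 1, 0], [0, 0, 1]]], dtype=tf.float32)
        depth = tf.ones([1, height, width]) * 5.0     # constant depth; value should not matter
        pose = tf.zeros([1, 6])                       # no rotation, no translation
        warped = warp.projective_inverse_warp(img, depth, pose, intrinsics, coords)
        diff = tf.abs(warped[:, :-1, :-1] - img[:, :-1, :-1])
        self.assertLess(float(tf.reduce_max(diff)), 1e-3)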