Finish Projective Inverse Warp algorithm
@@ -53,7 +53,6 @@ def pose_vec2mat(vec):
     ry = tf.slice(vec, [0, 4], [-1, 1])
     rz = tf.slice(vec, [0, 5], [-1, 1])
     rot_mat = euler_to_matrix(rx, ry, rz)
     rot_mat = tf.squeeze(rot_mat, axis=[1])
     transform_mat = tf.concat([rot_mat, translation], axis=2)
     return transform_mat
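For context, pose_vec2mat above builds a (batch, 3, 4) camera transform from a 6-DoF pose vector: the last three entries are Euler angles fed to euler_to_matrix, and the resulting rotation is concatenated with the translation column. A minimal single-example NumPy sketch of the same construction, assuming the vector is ordered [tx, ty, tz, rx, ry, rz] and an Rz * Ry * Rx Euler convention (the repo's euler_to_matrix may use a different convention):

import numpy as np

def euler_to_matrix_single(rx, ry, rz):
    # Rotation matrix from Euler angles, assuming R = Rz @ Ry @ Rx.
    cx, sx = np.cos(rx), np.sin(rx)
    cy, sy = np.cos(ry), np.sin(ry)
    cz, sz = np.cos(rz), np.sin(rz)
    R_x = np.array([[1, 0, 0], [0, cx, -sx], [0, sx, cx]])
    R_y = np.array([[cy, 0, sy], [0, 1, 0], [-sy, 0, cy]])
    R_z = np.array([[cz, -sz, 0], [sz, cz, 0], [0, 0, 1]])
    return R_z @ R_y @ R_x

vec = np.array([0.1, 0.0, -0.2, 0.0, 0.05, 0.0])            # hypothetical pose vector
translation = vec[:3].reshape(3, 1)
rotation = euler_to_matrix_single(*vec[3:])
transform_mat = np.concatenate([rotation, translation], axis=1)   # (3, 4)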
@@ -76,10 +75,10 @@ def image_coordinate(batch, height, width):
     stacked = tf.stack([x_mesh, y_mesh, ones_mesh], axis=2)

-    return tf.repeat(tf.expand_dims(stacked, axis=0), batch, axis=0)
+    return tf.cast(tf.repeat(tf.expand_dims(stacked, axis=0), batch, axis=0), dtype=tf.float32)

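The cast added above matters because the grid is multiplied with the (float) inverse intrinsics later in the file, which requires matching float dtypes. For reference, a standalone sketch of the homogeneous pixel grid that image_coordinate produces, plus the flattened (batch, 3, height * width) layout that the new coordinates parameter expects (the helper name here is illustrative, not the repo's implementation):

import tensorflow as tf

def homogeneous_pixel_grid(batch, height, width):
    # (batch, height, width, 3) grid of (x, y, 1) pixel coordinates, as float32.
    x_mesh, y_mesh = tf.meshgrid(tf.range(width), tf.range(height))   # each (height, width)
    ones_mesh = tf.ones_like(x_mesh)
    stacked = tf.stack([x_mesh, y_mesh, ones_mesh], axis=2)
    grid = tf.repeat(tf.expand_dims(stacked, axis=0), batch, axis=0)
    return tf.cast(grid, tf.float32)

# Flattened once up front to (batch, 3, height * width) and reused on every warp step.
grid = homogeneous_pixel_grid(2, 4, 5)
coordinates = tf.transpose(tf.reshape(grid, [2, 4 * 5, 3]), [0, 2, 1])   # (2, 3, 20)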
-def projective_inverse_warp(target_img, source_img, depth, pose, intrinsics, coordinates):
+def projective_inverse_warp(source_img, depth, pose, intrinsics, coordinates):
     """
     Calculate the reprojected image from the source to the target, based on the given depth, pose and intrinsics
@@ -92,12 +91,11 @@ def projective_inverse_warp(target_img, source_img, depth, pose, intrinsics, coo
     the source image in pixel coordinates (K.T(t->s).{3d coord}), then using the projected coordinates we sample
     the pixels in the source image (ps) to reconstruct the target image.

-    :param target_img: Tensor (batch, height, width, 3)
-    :param source_img: Tensor, same shape as target_img
-    :param depth: Tensor, (batch, height, width, 1)
+    :param source_img: Tensor (batch, height, width, 3)
+    :param depth: Tensor, (batch, height, width)
     :param pose: (batch, 6)
     :param intrinsics: (batch, 3, 3) TODO: Intrinsics per image (per source/target image)?
-    :param coordinates: (batch, height, width, 3) - coordinates for the image. Pass this in so it doesn't need to be
+    :param coordinates: (batch, 3, height * width) - coordinates for the image. Pass this in so it doesn't need to be
     calculated on every warp step
     :return: The source image reprojected to the target
     """
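Written out for a single target pixel, the projection the docstring describes is p_s ~ K . T(t->s) . [D(p_t) . K^-1 . p_t, 1]: the target pixel is lifted to a 3D point using its depth and the inverse intrinsics, moved into the source frame by the relative pose, and projected back through the intrinsics. A NumPy sketch with made-up numbers (not code from this repo), mirroring the matmul chain in the next hunk:

import numpy as np

K = np.array([[100.0,   0.0, 64.0],
              [  0.0, 100.0, 48.0],
              [  0.0,   0.0,  1.0]])                    # intrinsics
T = np.hstack([np.eye(3), [[0.1], [0.0], [0.0]]])       # 3x4 target-to-source transform
p_t = np.array([32.0, 24.0, 1.0])                       # homogeneous target pixel (x, y, 1)
d = 2.0                                                 # predicted depth at that pixel

cam_point = d * (np.linalg.inv(K) @ p_t)                # lift to 3D in the target camera frame
proj = K @ T @ np.append(cam_point, 1.0)                # transform and project into the source camera
x_s, y_s = proj[0] / proj[2], proj[1] / proj[2]         # source pixel coordinates to sample (p_s)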
@@ -109,23 +107,117 @@ def projective_inverse_warp(target_img, source_img, depth, pose, intrinsics, coo
     # Calculate inverse of the 3x3 intrinsics matrix
     intrinsics_inverse = tf.linalg.inv(intrinsics)

-    # Create grid of homogenous coordinates
-    grid_coords = image_coordinate(*depth.shape)
-    # Flatten the image coords to [B, 3, height * width] so each point can be used in calculations
-    grid_coords = tf.transpose(tf.reshape(grid_coords, [0, 2, 1]))
-
-    # TODO: Do we need to transpose?
-    depth_flat = tf.transpose(tf.reshape(depth, [0, 2, 1]))
+    depth_flat = tf.reshape(depth, [depth.shape[0], depth.shape[1] * depth.shape[2]])

     # Project the depth-scaled camera points into the source image: K . T(t->s) . [depth * K^-1 . p_t, 1]
     sample_coordinates = tf.matmul(tf.matmul(intrinsics, pose_3x4),
-                                   tf.concat([depth_flat * tf.matmul(intrinsics_inverse, grid_coords),
-                                              tf.ones(depth_flat.shape)], axis=1))
+                                   tf.concat([depth_flat * tf.matmul(intrinsics_inverse, coordinates),
+                                              tf.ones([depth_flat.shape[0], 1, depth_flat.shape[1]])], axis=1))

+    # Normalise the x/y axes (divide by z axis)
+    sample_coordinates = sample_coordinates[:, 0:2] / sample_coordinates[:, 2]

+    # Reshape back to image coordinates
+    sample_coordinates = tf.reshape(tf.transpose(sample_coordinates, [0, 2, 1]),
+                                    [depth.shape[0], depth.shape[1], depth.shape[2], 2])

+    # sample from the source image using the coordinates applied by the function
+    return bilinear_sampler(source_img, sample_coordinates)

-    pass

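A sketch of how the finished warp could be driven, with the shapes the new signature expects. The sizes and intrinsics below are made up for illustration, and the call itself is left commented out because it depends on the rest of this module (pose_vec2mat, image_coordinate, and so on):

import tensorflow as tf

batch, height, width = 4, 128, 416                        # illustrative sizes

source_img = tf.random.uniform([batch, height, width, 3])
depth = tf.random.uniform([batch, height, width], minval=0.1, maxval=10.0)
pose = tf.zeros([batch, 6])                               # [tx, ty, tz, rx, ry, rz], identity motion here
K = tf.constant([[200.0,   0.0, width / 2.0],
                 [  0.0, 200.0, height / 2.0],
                 [  0.0,   0.0, 1.0]])
intrinsics = tf.repeat(K[tf.newaxis], batch, axis=0)      # (batch, 3, 3)

# Homogeneous pixel grid, computed once and reused for every warp call.
x, y = tf.meshgrid(tf.range(width, dtype=tf.float32), tf.range(height, dtype=tf.float32))
grid = tf.stack([tf.reshape(x, [-1]), tf.reshape(y, [-1]), tf.ones([height * width])], axis=0)
coordinates = tf.repeat(grid[tf.newaxis], batch, axis=0)  # (batch, 3, height * width)

# warped = projective_inverse_warp(source_img, depth, pose, intrinsics, coordinates)
# warped: (batch, height, width, 3), the source image resampled to line up with the target.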
+def bilinear_sampler(imgs, coords):
+    """Construct a new image by bilinear sampling from the input image.
+
+    Points falling outside the source image boundary have value 0.
+
+    Args:
+      imgs: source image to be sampled from [batch, height_s, width_s, channels]
+      coords: coordinates of source pixels to sample from [batch, height_t,
+        width_t, 2]. height_t/width_t correspond to the dimensions of the output
+        image (don't need to be the same as height_s/width_s). The two channels
+        correspond to x and y coordinates respectively.
+    Returns:
+      A new sampled image [batch, height_t, width_t, channels]
+    """
+
+    def _repeat(x, n_repeats):
+        rep = tf.transpose(
+            tf.expand_dims(tf.ones(shape=tf.stack([
+                n_repeats,
+            ])), 1), [1, 0])
+        rep = tf.cast(rep, 'float32')
+        x = tf.matmul(tf.reshape(x, (-1, 1)), rep)
+        return tf.reshape(x, [-1])
+
+    coords_x, coords_y = tf.split(coords, [1, 1], axis=3)
+    inp_size = imgs.get_shape()
+    coord_size = coords.get_shape()
+    out_size = coords.get_shape().as_list()
+    out_size[3] = imgs.get_shape().as_list()[3]
+
+    coords_x = tf.cast(coords_x, 'float32')
+    coords_y = tf.cast(coords_y, 'float32')
+
+    x0 = tf.floor(coords_x)
+    x1 = x0 + 1
+    y0 = tf.floor(coords_y)
+    y1 = y0 + 1
+
+    y_max = tf.cast(tf.shape(imgs)[1] - 1, 'float32')
+    x_max = tf.cast(tf.shape(imgs)[2] - 1, 'float32')
+    zero = tf.zeros([1], dtype='float32')
+
+    x0_safe = tf.clip_by_value(x0, zero, x_max)
+    y0_safe = tf.clip_by_value(y0, zero, y_max)
+    x1_safe = tf.clip_by_value(x1, zero, x_max)
+    y1_safe = tf.clip_by_value(y1, zero, y_max)
+
+    # bilinear interp weights, with points outside the grid having weight 0
+    # wt_x0 = (x1 - coords_x) * tf.cast(tf.equal(x0, x0_safe), 'float32')
+    # wt_x1 = (coords_x - x0) * tf.cast(tf.equal(x1, x1_safe), 'float32')
+    # wt_y0 = (y1 - coords_y) * tf.cast(tf.equal(y0, y0_safe), 'float32')
+    # wt_y1 = (coords_y - y0) * tf.cast(tf.equal(y1, y1_safe), 'float32')
+
+    wt_x0 = x1_safe - coords_x
+    wt_x1 = coords_x - x0_safe
+    wt_y0 = y1_safe - coords_y
+    wt_y1 = coords_y - y0_safe
+
+    # indices in the flat image to sample from
+    dim2 = tf.cast(inp_size[2], 'float32')
+    dim1 = tf.cast(inp_size[2] * inp_size[1], 'float32')
+    base = tf.reshape(
+        _repeat(
+            tf.cast(tf.range(coord_size[0]), 'float32') * dim1,
+            coord_size[1] * coord_size[2]),
+        [out_size[0], out_size[1], out_size[2], 1])
+
+    base_y0 = base + y0_safe * dim2
+    base_y1 = base + y1_safe * dim2
+    idx00 = tf.reshape(x0_safe + base_y0, [-1])
+    idx01 = x0_safe + base_y1
+    idx10 = x1_safe + base_y0
+    idx11 = x1_safe + base_y1
+
+    # sample from imgs
+    imgs_flat = tf.reshape(imgs, tf.stack([-1, inp_size[3]]))
+    imgs_flat = tf.cast(imgs_flat, 'float32')
+    im00 = tf.reshape(
+        tf.gather(imgs_flat, tf.cast(idx00, 'int32')), out_size)
+    im01 = tf.reshape(
+        tf.gather(imgs_flat, tf.cast(idx01, 'int32')), out_size)
+    im10 = tf.reshape(
+        tf.gather(imgs_flat, tf.cast(idx10, 'int32')), out_size)
+    im11 = tf.reshape(
+        tf.gather(imgs_flat, tf.cast(idx11, 'int32')), out_size)
+
+    w00 = wt_x0 * wt_y0
+    w01 = wt_x0 * wt_y1
+    w10 = wt_x1 * wt_y0
+    w11 = wt_x1 * wt_y1
+
+    output = tf.add_n([
+        w00 * im00, w01 * im01,
+        w10 * im10, w11 * im11
+    ])
+    return output
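To see what bilinear_sampler computes, here is a tiny hand-checkable case, written as a standalone NumPy snippet rather than with the function above: sampling a single-channel 2x2 image at a fractional coordinate blends the four neighbouring pixels by their x/y weights.

import numpy as np

img = np.array([[0.0, 10.0],
                [20.0, 30.0]])       # 2x2 single-channel image
x, y = 0.25, 0.5                     # fractional sample location

x0, y0 = int(np.floor(x)), int(np.floor(y))
x1, y1 = x0 + 1, y0 + 1
wx1, wy1 = x - x0, y - y0            # weight towards the x1/y1 neighbours
wx0, wy0 = 1.0 - wx1, 1.0 - wy1

value = (wx0 * wy0 * img[y0, x0] + wx1 * wy0 * img[y0, x1] +
         wx0 * wy1 * img[y1, x0] + wx1 * wy1 * img[y1, x1])
# value == 12.5: halfway between the rows, a quarter of the way across the columns.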