import math

import tensorflow as tf


def euler_to_matrix(x, y, z):
"""
:param x: Tensor of shape (B, 1) - x axis rotation
:param y: Tensor of shape (B, 1) - y axis rotation
:param z: Tensor of shape (B, 1) - z axis rotation
:return: Rotation matrix for the given euler anglers, in the order rotation(x) -> rotation(y) -> rotation(z)
"""
batch_size = tf.shape(z)[0]
# Euler angles should be between -pi and pi, clip so the pose network is coerced to this range
z = tf.clip_by_value(z, -math.pi, math.pi)
y = tf.clip_by_value(y, -math.pi, math.pi)
x = tf.clip_by_value(x, -math.pi, math.pi)
cosx = tf.cos(x)
sinx = tf.sin(x)
cosy = tf.cos(y)
siny = tf.sin(y)
cosz = tf.cos(z)
sinz = tf.sin(z)
    # Rotate about x, then y, then z. Rz appears leftmost in the product
    # because each rotation matrix multiplies on the left of the coordinates.
    # R = Rz(φ)Ry(θ)Rx(ψ)
    #   = | cos(θ)cos(φ)   sin(ψ)sin(θ)cos(φ) - cos(ψ)sin(φ)   cos(ψ)sin(θ)cos(φ) + sin(ψ)sin(φ) |
    #     | cos(θ)sin(φ)   sin(ψ)sin(θ)sin(φ) + cos(ψ)cos(φ)   cos(ψ)sin(θ)sin(φ) - sin(ψ)cos(φ) |
    #     |   -sin(θ)                sin(ψ)cos(θ)                        cos(ψ)cos(θ)             |
row_1 = tf.concat([cosy * cosz, sinx * siny * cosz - cosx * sinz, cosx * siny * cosz + sinx * sinz], 1)
row_2 = tf.concat([cosy * sinz, sinx * siny * sinz + cosx * cosz, cosx * siny * sinz - sinx * cosz], 1)
row_3 = tf.concat([-siny, sinx * cosy, cosx * cosy], 1)
return tf.reshape(tf.concat([row_1, row_2, row_3], axis=1), [batch_size, 3, 3])
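
# A minimal sanity check (illustrative assumption, not part of the original
# file): zero angles should give the identity rotation for every batch element,
# matching the R = Rz(φ)Ry(θ)Rx(ψ) expansion above:
#
#   zeros = tf.zeros([4, 1])
#   euler_to_matrix(zeros, zeros, zeros)  # ~ tf.eye(3, batch_shape=[4])
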
def pose_vec2mat(vec):
"""Converts 6DoF parameters to transformation matrix
Args:
vec: 6DoF parameters in the order of tx, ty, tz, rx, ry, rz -- [B, 6]
Returns:
A transformation matrix -- [B, 3, 4]
"""
translation = tf.slice(vec, [0, 0], [-1, 3])
translation = tf.expand_dims(translation, -1)
rx = tf.slice(vec, [0, 3], [-1, 1])
ry = tf.slice(vec, [0, 4], [-1, 1])
rz = tf.slice(vec, [0, 5], [-1, 1])
rot_mat = euler_to_matrix(rx, ry, rz)
transform_mat = tf.concat([rot_mat, translation], axis=2)
return transform_mat
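
# Illustrative example (an assumption, not from the original file): an all-zero
# pose vector should yield the identity extrinsics [I | 0]:
#
#   pose_vec2mat(tf.zeros([4, 6]))
#   # ~ tf.concat([tf.eye(3, batch_shape=[4]), tf.zeros([4, 3, 1])], axis=2)
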
def image_coordinate(batch, height, width):
"""
Construct a tensor for the given height/width with elements the homogenous coordinates for the pixel
:param batch: Number of images in a batch
:param height: Height of image
:param width: Width of image
:return: Tensor of shape (B, height, width, 3), homogenous coordinates for an image.
Coordinates are in order [x, y, 1]
"""
x_coords = tf.range(width)
y_coords = tf.range(height)
x_mesh, y_mesh = tf.meshgrid(x_coords, y_coords)
    ones_mesh = tf.ones([height, width], dtype=tf.int32)
stacked = tf.stack([x_mesh, y_mesh, ones_mesh], axis=2)
return tf.cast(tf.repeat(tf.expand_dims(stacked, axis=0), batch, axis=0), dtype=tf.float32)
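
# Illustrative example (not from the original file): for a 2x3 image each
# element holds [x, y, 1], so image_coordinate(1, 2, 3)[0, 1, 2] is [2., 1., 1.].
# Note that projective_inverse_warp below expects coordinates shaped
# (batch, 3, height * width), so callers reshape/transpose this grid first
# (see flatten_coordinates below for a sketch).
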
def projective_inverse_warp(source_img, depth, pose, intrinsics, coordinates):
"""
Calculate the reprojected image from the source to the target, based on the given depth, pose and intrinsics
SFM Learner inverse warp step
ps ~ K.T(t->s).Dt(pt)*K^-1.pt
Note that the depth pixel Dt(pt) is multiplied by every coordinate value (just element-wise, not matrix multiplication)
Idea is to map the pixel coordinates of the target image to 3d space (Dt(pt).K^-1.pt), then map these onto
the source image in pixel coordinates (K.T(t->s).{3d coord}), then using the projected coordinates we sample
the pixels in the source image (ps) to reconstruct the target image.
:param source_img: Tensor (batch, height, width, 3)
:param depth: Tensor, (batch, height, width)
:param pose: (batch, 6)
:param intrinsics: (batch, 3, 3) TODO: Intrinsics per image (per source/target image)?
:param coordinates: (batch, 3, height * width) - coordinates for the image. Pass this in so it doesn't need to be
calculated on every warp step
:return: The source image reprojected to the target
"""
    # Convert pose vector (output of the pose net) to a pose matrix (3, 4)
    pose_3x4 = pose_vec2mat(pose)
    # Inverse of the 3x3 intrinsics matrix, used to back-project pixels into camera space
    intrinsics_inverse = tf.linalg.inv(intrinsics)
    # Flatten depth to (B, 1, H*W) so it broadcasts across the 3 coordinate rows
    depth_flat = tf.reshape(depth, [depth.shape[0], 1, depth.shape[1] * depth.shape[2]])
    # p_s ~ K T(t->s) [D_t(p_t) K^-1 p_t ; 1]
    sample_coordinates = tf.matmul(tf.matmul(intrinsics, pose_3x4),
                                   tf.concat([depth_flat * tf.matmul(intrinsics_inverse, coordinates),
                                              tf.ones_like(depth_flat)], axis=1))
    # Normalise the x/y axes (divide by the z axis, keeping the dimension for broadcasting)
    sample_coordinates = sample_coordinates[:, 0:2] / sample_coordinates[:, 2:3]
# Reshape back to image coordinates
sample_coordinates = tf.reshape(tf.transpose(sample_coordinates, [0, 2, 1]),
[depth.shape[0], depth.shape[1], depth.shape[2], 2])
    # Sample the source image at the projected coordinates to reconstruct the target
return bilinear_sampler(source_img, sample_coordinates)
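
# image_coordinate returns (B, H, W, 3) while projective_inverse_warp expects
# (B, 3, H*W). A minimal bridging helper, assuming that call pattern (the name
# and existence of this function are an assumption, not part of the original
# file):
def flatten_coordinates(coordinates):
    """Reshape a (B, H, W, 3) homogeneous pixel grid to (B, 3, H*W)."""
    batch = tf.shape(coordinates)[0]
    flat = tf.reshape(coordinates, [batch, -1, 3])  # (B, H*W, 3)
    return tf.transpose(flat, [0, 2, 1])  # (B, 3, H*W)
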
def bilinear_sampler(imgs, coords):
"""Construct a new image by bilinear sampling from the input image.
Points falling outside the source image boundary have value 0.
Args:
imgs: source image to be sampled from [batch, height_s, width_s, channels]
coords: coordinates of source pixels to sample from [batch, height_t,
width_t, 2]. height_t/width_t correspond to the dimensions of the output
image (don't need to be the same as height_s/width_s). The two channels
correspond to x and y coordinates respectively.
Returns:
A new sampled image [batch, height_t, width_t, channels]
"""
    def _repeat(x, n_repeats):
        """Repeat each element of x n_repeats times, returning a flat tensor."""
rep = tf.transpose(
tf.expand_dims(tf.ones(shape=tf.stack([
n_repeats,
])), 1), [1, 0])
rep = tf.cast(rep, 'float32')
x = tf.matmul(tf.reshape(x, (-1, 1)), rep)
return tf.reshape(x, [-1])
coords_x, coords_y = tf.split(coords, [1, 1], axis=3)
inp_size = imgs.get_shape()
coord_size = coords.get_shape()
out_size = coords.get_shape().as_list()
out_size[3] = imgs.get_shape().as_list()[3]
coords_x = tf.cast(coords_x, 'float32')
coords_y = tf.cast(coords_y, 'float32')
x0 = tf.floor(coords_x)
x1 = x0 + 1
y0 = tf.floor(coords_y)
y1 = y0 + 1
y_max = tf.cast(tf.shape(imgs)[1] - 1, 'float32')
x_max = tf.cast(tf.shape(imgs)[2] - 1, 'float32')
zero = tf.zeros([1], dtype='float32')
x0_safe = tf.clip_by_value(x0, zero, x_max)
y0_safe = tf.clip_by_value(y0, zero, y_max)
x1_safe = tf.clip_by_value(x1, zero, x_max)
y1_safe = tf.clip_by_value(y1, zero, y_max)
    # Bilinear interpolation weights. With clamped coordinates, the two weights
    # along a fully clamped axis cancel, so points outside the grid contribute 0.
    # The explicit zero-masking variant would be:
    # wt_x0 = (x1 - coords_x) * tf.cast(tf.equal(x0, x0_safe), 'float32')
    # wt_x1 = (coords_x - x0) * tf.cast(tf.equal(x1, x1_safe), 'float32')
    # wt_y0 = (y1 - coords_y) * tf.cast(tf.equal(y0, y0_safe), 'float32')
    # wt_y1 = (coords_y - y0) * tf.cast(tf.equal(y1, y1_safe), 'float32')
wt_x0 = x1_safe - coords_x
wt_x1 = coords_x - x0_safe
wt_y0 = y1_safe - coords_y
wt_y1 = coords_y - y0_safe
# indices in the flat image to sample from
dim2 = tf.cast(inp_size[2], 'float32')
dim1 = tf.cast(inp_size[2] * inp_size[1], 'float32')
base = tf.reshape(
_repeat(
tf.cast(tf.range(coord_size[0]), 'float32') * dim1,
coord_size[1] * coord_size[2]),
[out_size[0], out_size[1], out_size[2], 1])
base_y0 = base + y0_safe * dim2
base_y1 = base + y1_safe * dim2
    idx00 = x0_safe + base_y0
    idx01 = x0_safe + base_y1
    idx10 = x1_safe + base_y0
    idx11 = x1_safe + base_y1
# sample from imgs
imgs_flat = tf.reshape(imgs, tf.stack([-1, inp_size[3]]))
imgs_flat = tf.cast(imgs_flat, 'float32')
im00 = tf.reshape(
tf.gather(imgs_flat, tf.cast(idx00, 'int32')), out_size)
im01 = tf.reshape(
tf.gather(imgs_flat, tf.cast(idx01, 'int32')), out_size)
im10 = tf.reshape(
tf.gather(imgs_flat, tf.cast(idx10, 'int32')), out_size)
im11 = tf.reshape(
tf.gather(imgs_flat, tf.cast(idx11, 'int32')), out_size)
w00 = wt_x0 * wt_y0
w01 = wt_x0 * wt_y1
w10 = wt_x1 * wt_y0
w11 = wt_x1 * wt_y1
output = tf.add_n([
w00 * im00, w01 * im01,
w10 * im10, w11 * im11
])
return output
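
if __name__ == "__main__":
    # Hypothetical smoke test (an assumption, not part of the original file):
    # with an identity pose (zero translation and rotation) and constant depth,
    # the warp should roughly reproduce the source image.
    batch, height, width = 2, 8, 12
    source = tf.random.uniform([batch, height, width, 3])
    depth = tf.ones([batch, height, width])
    pose = tf.zeros([batch, 6])
    # Plausible pinhole intrinsics; the exact values are placeholders.
    intrinsics = tf.tile(tf.constant([[[10.0, 0.0, 6.0],
                                       [0.0, 10.0, 4.0],
                                       [0.0, 0.0, 1.0]]]), [batch, 1, 1])
    coords = flatten_coordinates(image_coordinate(batch, height, width))
    warped = projective_inverse_warp(source, depth, pose, intrinsics, coords)
    print("warped shape:", warped.shape)  # expect (2, 8, 12, 3)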