import math

import tensorflow as tf


def euler_to_matrix(x, y, z):
    """Build rotation matrices from Euler angles.

    Each angle is clipped to [-pi, pi] before use so that the pose network
    feeding this function is coerced into that range.

    :param x: Tensor of shape (B, 1) - x axis rotation
    :param y: Tensor of shape (B, 1) - y axis rotation
    :param z: Tensor of shape (B, 1) - z axis rotation
    :return: Tensor of shape (B, 3, 3). Rotation matrix for the given euler
        angles, applied in the order rotation(x) -> rotation(y) -> rotation(z),
        i.e. R = Rz(z) . Ry(y) . Rx(x)
    """
    batch_size = tf.shape(z)[0]

    x = tf.clip_by_value(x, -math.pi, math.pi)
    y = tf.clip_by_value(y, -math.pi, math.pi)
    z = tf.clip_by_value(z, -math.pi, math.pi)

    cosx, sinx = tf.cos(x), tf.sin(x)
    cosy, siny = tf.cos(y), tf.sin(y)
    cosz, sinz = tf.cos(z), tf.sin(z)

    # Rotation matrices multiply coordinates from the left, so rotating about
    # x, then y, then z is the product Rz . Ry . Rx, which expands to:
    #   | cy*cz   sx*sy*cz - cx*sz   cx*sy*cz + sx*sz |
    #   | cy*sz   sx*sy*sz + cx*cz   cx*sy*sz - sx*cz |
    #   | -sy     sx*cy              cx*cy            |
    row_1 = tf.concat([cosy * cosz,
                       sinx * siny * cosz - cosx * sinz,
                       cosx * siny * cosz + sinx * sinz], 1)
    row_2 = tf.concat([cosy * sinz,
                       sinx * siny * sinz + cosx * cosz,
                       cosx * siny * sinz - sinx * cosz], 1)
    row_3 = tf.concat([-siny, sinx * cosy, cosx * cosy], 1)

    # Each row is (B, 3); lay them side by side as (B, 9) and fold to (B, 3, 3).
    return tf.reshape(tf.concat([row_1, row_2, row_3], axis=1),
                      [batch_size, 3, 3])


def pose_vec2mat(vec):
    """Converts 6DoF parameters to transformation matrix

    Args:
        vec: 6DoF parameters in the order of tx, ty, tz, rx, ry, rz -- [B, 6]
    Returns:
        A transformation matrix [R | t] -- [B, 3, 4]
    """
    # Translation becomes a (B, 3, 1) column to sit to the right of R.
    translation = tf.expand_dims(vec[:, 0:3], -1)
    # Keep each angle as (B, 1), which is what euler_to_matrix expects.
    rx = vec[:, 3:4]
    ry = vec[:, 4:5]
    rz = vec[:, 5:6]
    rot_mat = euler_to_matrix(rx, ry, rz)
    return tf.concat([rot_mat, translation], axis=2)


def image_coordinate(batch, height, width):
    """
    Construct a tensor for the given height/width with elements the homogenous
    coordinates for the pixel

    Note that the result is laid out as (B, height, width, 3); it is not in the
    (B, 3, height * width) layout that projective_inverse_warp takes for its
    `coordinates` argument.

    :param batch: Number of images in a batch
    :param height: Height of image
    :param width: Width of image
    :return: float32 Tensor of shape (B, height, width, 3), homogenous
        coordinates for an image. Coordinates are in order [x, y, 1]
    """
    x_mesh, y_mesh = tf.meshgrid(tf.range(width), tf.range(height))
    ones_mesh = tf.ones([height, width], dtype=tf.int32)
    stacked = tf.stack([x_mesh, y_mesh, ones_mesh], axis=2)
    batched = tf.repeat(tf.expand_dims(stacked, axis=0), batch, axis=0)
    return tf.cast(batched, dtype=tf.float32)


def projective_inverse_warp(source_img, depth, pose, intrinsics, coordinates):
    """
    Calculate the reprojected image from the source to the target, based on the
    given depth, pose and intrinsics

    SFM Learner inverse warp step
        ps ~ K.T(t->s).Dt(pt)*K^-1.pt

    Note that the depth pixel Dt(pt) is multiplied by every coordinate value
    (just element-wise, not matrix multiplication)

    Idea is to map the pixel coordinates of the target image to 3d space
    (Dt(pt).K^-1.pt), then map these onto the source image in pixel coordinates
    (K.T(t->s).{3d coord}), then using the projected coordinates we sample the
    pixels in the source image (ps) to reconstruct the target image.

    The depth and the projected z row are kept with an explicit singleton
    channel axis, (B, 1, H*W), so they broadcast across the coordinate channels
    for any batch size.

    :param source_img: Tensor (batch, height, width, 3)
    :param depth: Tensor, (batch, height, width). Its static shape is read, so
        all three dimensions must be known.
    :param pose: (batch, 6)
    :param intrinsics: (batch, 3, 3)
        TODO: Intrinsics per image (per source/target image)?
    :param coordinates: (batch, 3, height * width) - homogenous pixel
        coordinates for the image. Pass this in so it doesn't need to be
        calculated on every warp step
    :return: The source image reprojected to the target
    """
    batch, height, width = depth.shape[0], depth.shape[1], depth.shape[2]

    # Pose vector (output of pose net) -> [R | t] matrix, (B, 3, 4)
    pose_3x4 = pose_vec2mat(pose)
    intrinsics_inverse = tf.linalg.inv(intrinsics)

    # (B, 1, H*W): the singleton axis broadcasts over the 3 coordinate rows.
    depth_flat = tf.reshape(depth, [batch, 1, height * width])

    # Back-project target pixels into 3d camera space: Dt(pt) * K^-1 . pt
    cam_coords = depth_flat * tf.matmul(intrinsics_inverse, coordinates)
    # Append a row of ones -> homogenous (B, 4, H*W) so the 3x4 pose applies.
    cam_coords = tf.concat([cam_coords, tf.ones_like(depth_flat)], axis=1)

    # Project into the source view: K . T(t->s) . {3d coord} -> (B, 3, H*W)
    projected = tf.matmul(tf.matmul(intrinsics, pose_3x4), cam_coords)

    # Normalise the x/y rows by the z row. Slicing 2:3 (rather than indexing 2)
    # keeps the z row as (B, 1, H*W). There is no guard against z == 0.
    sample_coordinates = projected[:, 0:2] / projected[:, 2:3]

    # (B, 2, H*W) -> (B, H*W, 2) -> (B, H, W, 2) image-shaped [x, y] coordinates
    sample_coordinates = tf.reshape(
        tf.transpose(sample_coordinates, [0, 2, 1]),
        [batch, height, width, 2])

    # Sample from the source image at the projected coordinates
    return bilinear_sampler(source_img, sample_coordinates)


def bilinear_sampler(imgs, coords):
    """Construct a new image by bilinear sampling from the input image.

    Points falling outside the source image boundary have value 0.

    Args:
        imgs: source image to be sampled from [batch, height_s, width_s, channels]
        coords: coordinates of source pixels to sample from [batch, height_t,
            width_t, 2]. height_t/width_t correspond to the dimensions of the
            output image (don't need to be the same as height_s/width_s). The
            two channels correspond to x and y coordinates respectively.
            The static shapes of both `imgs` and `coords` are read, so they
            must be fully defined.
    Returns:
        A new sampled image [batch, height_t, width_t, channels]
    """
    coords_x, coords_y = tf.split(coords, [1, 1], axis=3)
    inp_size = imgs.get_shape()
    coord_size = coords.get_shape()
    out_size = coord_size.as_list()
    out_size[3] = inp_size.as_list()[3]

    coords_x = tf.cast(coords_x, 'float32')
    coords_y = tf.cast(coords_y, 'float32')

    # Integer corners surrounding each sample point.
    x0 = tf.floor(coords_x)
    x1 = x0 + 1
    y0 = tf.floor(coords_y)
    y1 = y0 + 1

    y_max = tf.cast(tf.shape(imgs)[1] - 1, 'float32')
    x_max = tf.cast(tf.shape(imgs)[2] - 1, 'float32')
    zero = tf.zeros([1], dtype='float32')

    # Clip the corners into the image so every gather index is valid.
    x0_safe = tf.clip_by_value(x0, zero, x_max)
    y0_safe = tf.clip_by_value(y0, zero, y_max)
    x1_safe = tf.clip_by_value(x1, zero, x_max)
    y1_safe = tf.clip_by_value(y1, zero, y_max)

    # Bilinear interpolation weights, measured against the *clipped* corners.
    # When a point is outside the grid both corners on that axis clip to the
    # same pixel and their weights are equal and opposite, so the contribution
    # cancels to 0.
    wt_x0 = x1_safe - coords_x
    wt_x1 = coords_x - x0_safe
    wt_y0 = y1_safe - coords_y
    wt_y1 = coords_y - y0_safe

    # Indices into the flattened (batch * height_s * width_s, channels) image.
    dim2 = tf.cast(inp_size[2], 'float32')
    dim1 = tf.cast(inp_size[2] * inp_size[1], 'float32')
    # Offset of each batch element's first pixel, repeated once per output pixel.
    batch_offsets = tf.cast(tf.range(coord_size[0]), 'float32') * dim1
    base = tf.reshape(
        tf.repeat(batch_offsets, coord_size[1] * coord_size[2]),
        [out_size[0], out_size[1], out_size[2], 1])

    base_y0 = base + y0_safe * dim2
    base_y1 = base + y1_safe * dim2
    idx00 = x0_safe + base_y0
    idx01 = x0_safe + base_y1
    idx10 = x1_safe + base_y0
    idx11 = x1_safe + base_y1

    # Sample the four corner pixels from imgs.
    imgs_flat = tf.cast(tf.reshape(imgs, [-1, inp_size[3]]), 'float32')

    def _gather_corner(idx):
        # Look up the pixels at the given flat indices, shaped like the output.
        return tf.reshape(tf.gather(imgs_flat, tf.cast(idx, 'int32')), out_size)

    im00 = _gather_corner(idx00)
    im01 = _gather_corner(idx01)
    im10 = _gather_corner(idx10)
    im11 = _gather_corner(idx11)

    w00 = wt_x0 * wt_y0
    w01 = wt_x0 * wt_y1
    w10 = wt_x1 * wt_y0
    w11 = wt_x1 * wt_y1

    return tf.add_n([
        w00 * im00, w01 * im01,
        w10 * im10, w11 * im11
    ])