From f501beb6f28b6e2b12f47b5ece45524a6598cee1 Mon Sep 17 00:00:00 2001
From: Piv <18462828+Piv200@users.noreply.github.com>
Date: Mon, 5 Jul 2021 20:50:12 +0930
Subject: [PATCH] Start implementing unsupervised train loop, add sfmlearner
 train and utils files for reference

---
 unsupervised/third-party/train.py       | 166 ++++++++++++++++++++++++
 unsupervised/{ => third-party}/utils.py |   0
 unsupervised/train.py                   | 119 ++++-------------
 3 files changed, 190 insertions(+), 95 deletions(-)
 create mode 100644 unsupervised/third-party/train.py
 rename unsupervised/{ => third-party}/utils.py (100%)

diff --git a/unsupervised/third-party/train.py b/unsupervised/third-party/train.py
new file mode 100644
index 0000000..9aed045
--- /dev/null
+++ b/unsupervised/third-party/train.py
@@ -0,0 +1,166 @@
+"""
+Trainer to learn depth information on unlabeled data (raw images/videos)
+
+Allows pluggable depth networks for differing performance (including fast-depth)
+
+NOTE: reference code collected from SfMLearner (TensorFlow) and
+SfmLearner-Pytorch. Helpers such as projective_inverse_warp / inverse_warp and
+the tgt_image / src_image_stack / pred_* tensors are built in those repos and
+are not copied here.
+"""
+
+import numpy as np
+import tensorflow as tf
+import torch
+import torch.nn.functional as F
+
+
+def compute_smooth_loss(pred_disp):
+    # Penalise second-order disparity gradients to favour smooth predictions
+    def gradient(pred):
+        D_dy = pred[:, 1:, :, :] - pred[:, :-1, :, :]
+        D_dx = pred[:, :, 1:, :] - pred[:, :, :-1, :]
+        return D_dx, D_dy
+    dx, dy = gradient(pred_disp)
+    dx2, dxdy = gradient(dx)
+    dydx, dy2 = gradient(dy)
+    return tf.reduce_mean(tf.abs(dx2)) + \
+        tf.reduce_mean(tf.abs(dxdy)) + \
+        tf.reduce_mean(tf.abs(dydx)) + \
+        tf.reduce_mean(tf.abs(dy2))
+
+
+def get_reference_explain_mask(opt, downscaling):
+    # Reference mask marking every pixel as explainable at the given scale
+    tmp = np.array([0, 1])
+    ref_exp_mask = np.tile(tmp,
+                           (opt.batch_size,
+                            int(opt.img_height/(2**downscaling)),
+                            int(opt.img_width/(2**downscaling)),
+                            1))
+    ref_exp_mask = tf.constant(ref_exp_mask, dtype=tf.float32)
+    return ref_exp_mask
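+
+
+# sfm_loss_fn below calls compute_exp_reg_loss, which was not copied over with
+# the rest of the SfMLearner reference code. A sketch of it, assuming it is the
+# usual per-pixel softmax cross-entropy against the reference mask:
+def compute_exp_reg_loss(pred, ref):
+    loss = tf.nn.softmax_cross_entropy_with_logits(
+        labels=tf.reshape(ref, [-1, 2]),
+        logits=tf.reshape(pred, [-1, 2]))
+    return tf.reduce_mean(loss)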
+
+
+def get_sfm_loss_fn(opt):
+    def sfm_loss_fn(y, y_pred):
+        # TODO: Correctly format a batch that is required for this loss function
+        pixel_loss = 0
+        exp_loss = 0
+        smooth_loss = 0
+        tgt_image_all = []
+        src_image_stack_all = []
+        proj_image_stack_all = []
+        proj_error_stack_all = []
+        exp_mask_stack_all = []
+        for s in range(opt.num_scales):
+            if opt.explain_reg_weight > 0:
+                # Construct a reference explainability mask (i.e. all
+                # pixels are explainable)
+                ref_exp_mask = get_reference_explain_mask(opt, s)
+            # Scale the source and target images for computing loss at the
+            # corresponding scale. (resize_area is a TF1 API;
+            # tf.image.resize(..., method='area') is the TF2 equivalent.)
+            curr_tgt_image = tf.image.resize_area(tgt_image,
+                [int(opt.img_height/(2**s)), int(opt.img_width/(2**s))])
+            curr_src_image_stack = tf.image.resize_area(src_image_stack,
+                [int(opt.img_height/(2**s)), int(opt.img_width/(2**s))])
+
+            if opt.smooth_weight > 0:
+                smooth_loss += opt.smooth_weight/(2**s) * \
+                    compute_smooth_loss(pred_disp[s])
+
+            for i in range(opt.num_source):
+                # Inverse warp the source image to the target image frame
+                curr_proj_image = projective_inverse_warp(
+                    curr_src_image_stack[:, :, :, 3*i:3*(i+1)],
+                    tf.squeeze(pred_depth[s], axis=3),
+                    pred_poses[:, i, :],
+                    intrinsics[:, s, :, :])
+                curr_proj_error = tf.abs(curr_proj_image - curr_tgt_image)
+                # Cross-entropy loss as regularization for the
+                # explainability prediction
+                if opt.explain_reg_weight > 0:
+                    curr_exp_logits = tf.slice(pred_exp_logits[s],
+                                               [0, 0, 0, i*2],
+                                               [-1, -1, -1, 2])
+                    exp_loss += opt.explain_reg_weight * \
+                        compute_exp_reg_loss(curr_exp_logits,
+                                             ref_exp_mask)
+                    curr_exp = tf.nn.softmax(curr_exp_logits)
+                # Photo-consistency loss weighted by explainability
+                if opt.explain_reg_weight > 0:
+                    pixel_loss += tf.reduce_mean(curr_proj_error *
+                        tf.expand_dims(curr_exp[:, :, :, 1], -1))
+                else:
+                    pixel_loss += tf.reduce_mean(curr_proj_error)
+                # Prepare images for tensorboard summaries
+                if i == 0:
+                    proj_image_stack = curr_proj_image
+                    proj_error_stack = curr_proj_error
+                    if opt.explain_reg_weight > 0:
+                        exp_mask_stack = tf.expand_dims(
+                            curr_exp[:, :, :, 1], -1)
+                else:
+                    proj_image_stack = tf.concat([proj_image_stack,
+                                                  curr_proj_image], axis=3)
+                    proj_error_stack = tf.concat([proj_error_stack,
+                                                  curr_proj_error], axis=3)
+                    if opt.explain_reg_weight > 0:
+                        exp_mask_stack = tf.concat([exp_mask_stack,
+                            tf.expand_dims(curr_exp[:, :, :, 1], -1)], axis=3)
+            tgt_image_all.append(curr_tgt_image)
+            src_image_stack_all.append(curr_src_image_stack)
+            proj_image_stack_all.append(proj_image_stack)
+            proj_error_stack_all.append(proj_error_stack)
+            if opt.explain_reg_weight > 0:
+                exp_mask_stack_all.append(exp_mask_stack)
+        total_loss = pixel_loss + smooth_loss + exp_loss
+        return total_loss
+    return sfm_loss_fn
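+
+
+# Possible wiring into a Keras train loop once the batch formatting TODO above
+# is resolved (assumes `opt` also carries a learning_rate field):
+#
+#   model.compile(optimizer=tf.keras.optimizers.Adam(opt.learning_rate),
+#                 loss=get_sfm_loss_fn(opt))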
+
+
+def photometric_reconstruction_loss(tgt_img, ref_imgs, intrinsics,
+                                    depth, explainability_mask, pose,
+                                    rotation_mode='euler', padding_mode='zeros'):
+    def one_scale(d, mask):
+        assert(mask is None or d.size()[2:] == mask.size()[2:])
+        assert(pose.size(1) == len(ref_imgs))
+
+        reconstruction_loss = 0
+        b, _, h, w = d.size()
+        downscale = tgt_img.size(2)/h
+
+        tgt_img_scaled = F.interpolate(tgt_img, (h, w), mode='area')
+        ref_imgs_scaled = [F.interpolate(
+            ref_img, (h, w), mode='area') for ref_img in ref_imgs]
+        intrinsics_scaled = torch.cat(
+            (intrinsics[:, 0:2]/downscale, intrinsics[:, 2:]), dim=1)
+
+        warped_imgs = []
+        diff_maps = []
+
+        for i, ref_img in enumerate(ref_imgs_scaled):
+            current_pose = pose[:, i]
+
+            ref_img_warped, valid_points = inverse_warp(ref_img, depth[:, 0], current_pose,
+                                                        intrinsics_scaled,
+                                                        rotation_mode, padding_mode)
+            diff = (tgt_img_scaled - ref_img_warped) * \
+                valid_points.unsqueeze(1).float()
+
+            if explainability_mask is not None:
+                diff = diff * explainability_mask[:, i:i+1].expand_as(diff)
+
+            reconstruction_loss += diff.abs().mean()
+            # NaN guard: NaN is the only value that does not equal itself
+            assert((reconstruction_loss == reconstruction_loss).item() == 1)
+
+            warped_imgs.append(ref_img_warped[0])
+            diff_maps.append(diff[0])
+
+        return reconstruction_loss, warped_imgs, diff_maps
+
+    warped_results, diff_results = [], []
+    if type(explainability_mask) not in [tuple, list]:
+        explainability_mask = [explainability_mask]
+    if type(depth) not in [list, tuple]:
+        depth = [depth]
+
+    total_loss = 0
+    for d, mask in zip(depth, explainability_mask):
+        loss, warped, diff = one_scale(d, mask)
+        total_loss += loss
+        warped_results.append(warped)
+        diff_results.append(diff)
+    return total_loss, warped_results, diff_results

diff --git a/unsupervised/utils.py b/unsupervised/third-party/utils.py
similarity index 100%
rename from unsupervised/utils.py
rename to unsupervised/third-party/utils.py
diff --git a/unsupervised/train.py b/unsupervised/train.py
index b3ec44f..e2162ee 100644
--- a/unsupervised/train.py
+++ b/unsupervised/train.py
@@ -5,106 +5,35 @@ Allows pluggable depth networks for differing performance (including fast-depth)
 """
 import tensorflow as tf
+import tensorflow.keras as keras
 
 
-def compute_smooth_loss(self, pred_disp):
-    def gradient(pred):
-        D_dy = pred[:, 1:, :, :] - pred[:, :-1, :, :]
-        D_dx = pred[:, :, 1:, :] - pred[:, :, :-1, :]
-        return D_dx, D_dy
-    dx, dy = gradient(pred_disp)
-    dx2, dxdy = gradient(dx)
-    dydx, dy2 = gradient(dy)
-    return tf.reduce_mean(tf.abs(dx2)) + \
-        tf.reduce_mean(tf.abs(dxdy)) + \
-        tf.reduce_mean(tf.abs(dydx)) + \
-        tf.reduce_mean(tf.abs(dy2))
+class SFMLearner(keras.Model):
+
+    def __init__(self, depth_model, pose_model):
+        pass
+
+    def train_step(self, data):
+        pass
 
 
-def get_reference_explain_mask(self, downscaling):
-    opt = self.opt
-    tmp = np.array([0, 1])
-    ref_exp_mask = np.tile(tmp,
-                           (opt.batch_size,
-                            int(opt.img_height/(2**downscaling)),
-                            int(opt.img_width/(2**downscaling)),
-                            1))
-    ref_exp_mask = tf.constant(ref_exp_mask, dtype=tf.float32)
-    return ref_exp_mask
+def projective_inverse_warp(depth, pose, t_img, s_imgs, intrinsics):
+    '''
+    SfMLearner inverse warp step:
+    ps ~ K . T(t->s) . Dt(pt) . K^-1 . pt
+    i.e. each target pixel pt is projected to a source pixel ps
+    '''
+    pass
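+
+
+# A rough sketch of the projection above (kept separate so the stub's final
+# API can still change). Assumes depth is [B, H, W], pose a [B, 4, 4]
+# target->source transform and intrinsics [B, 3, 3]; returns [B, H, W, 2]
+# source-pixel coordinates to feed into bilinear_sample below.
+def _project_to_source_sketch(depth, pose, intrinsics):
+    b, h, w = tf.unstack(tf.shape(depth))
+    # Homogeneous pixel grid pt = (u, v, 1) for every target pixel
+    u, v = tf.meshgrid(tf.range(w), tf.range(h))
+    pt = tf.cast(tf.stack([u, v, tf.ones_like(u)], axis=0), tf.float32)
+    pt = tf.tile(tf.reshape(pt, [1, 3, -1]), [b, 1, 1])        # [B, 3, H*W]
+    # Back-project: Dt(pt) . K^-1 . pt
+    cam = tf.linalg.inv(intrinsics) @ pt * tf.reshape(depth, [b, 1, -1])
+    cam = tf.concat([cam, tf.ones([b, 1, h * w])], axis=1)     # homogeneous
+    # Transform to the source frame and project: ps ~ K . T(t->s) . cam
+    ps = intrinsics @ (pose @ cam)[:, :3, :]
+    uv = ps[:, :2, :] / (ps[:, 2:3, :] + 1e-10)
+    return tf.reshape(tf.transpose(uv, [0, 2, 1]), [b, h, w, 2])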
 
 
-def get_sfm_loss_fn(opt):
-    def sfm_loss_fn(y, y_pred):
-        # TODO: Correctly format a batch that is required for this loss function
-        pixel_loss = 0
-        exp_loss = 0
-        smooth_loss = 0
-        tgt_image_all = []
-        src_image_stack_all = []
-        proj_image_stack_all = []
-        proj_error_stack_all = []
-        exp_mask_stack_all = []
-        for s in range(opt.num_scales):
-            if opt.explain_reg_weight > 0:
-                # Construct a reference explainability mask (i.e. all
-                # pixels are explainable)
-                ref_exp_mask = get_reference_explain_mask(s)
-            # Scale the source and target images for computing loss at the
-            # according scale.
-            curr_tgt_image = tf.image.resize_area(tgt_image,
-                [int(opt.img_height/(2**s)), int(opt.img_width/(2**s))])
-            curr_src_image_stack = tf.image.resize_area(src_image_stack,
-                [int(opt.img_height/(2**s)), int(opt.img_width/(2**s))])
-
-            if opt.smooth_weight > 0:
-                smooth_loss += opt.smooth_weight/(2**s) * \
-                    compute_smooth_loss(pred_disp[s])
-
+def bilinear_sample(projected_coords, s_img):
+    '''
+    Sample the 4 closest pixels in the source image at each projected
+    coordinate to get the source image warped into the target frame
+    (a rough sketch is appended at the end of this file)
+    '''
+    pass
 
-            for i in range(opt.num_source):
-                # Inverse warp the source image to the target image frame
-                curr_proj_image = projective_inverse_warp(
-                    curr_src_image_stack[:, :, :, 3*i:3*(i+1)],
-                    tf.squeeze(pred_depth[s], axis=3),
-                    pred_poses[:, i, :],
-                    intrinsics[:, s, :, :])
-                curr_proj_error = tf.abs(curr_proj_image - curr_tgt_image)
-                # Cross-entropy loss as regularization for the
-                # explainability prediction
-                if opt.explain_reg_weight > 0:
-                    curr_exp_logits = tf.slice(pred_exp_logits[s],
-                                               [0, 0, 0, i*2],
-                                               [-1, -1, -1, 2])
-                    exp_loss += opt.explain_reg_weight * \
-                        self.compute_exp_reg_loss(curr_exp_logits,
-                                                  ref_exp_mask)
-                    curr_exp = tf.nn.softmax(curr_exp_logits)
-                # Photo-consistency loss weighted by explainability
-                if opt.explain_reg_weight > 0:
-                    pixel_loss += tf.reduce_mean(curr_proj_error *
-                        tf.expand_dims(curr_exp[:, :, :, 1], -1))
-                else:
-                    pixel_loss += tf.reduce_mean(curr_proj_error)
-                # Prepare images for tensorboard summaries
-                if i == 0:
-                    proj_image_stack = curr_proj_image
-                    proj_error_stack = curr_proj_error
-                    if opt.explain_reg_weight > 0:
-                        exp_mask_stack = tf.expand_dims(
-                            curr_exp[:, :, :, 1], -1)
-                else:
-                    proj_image_stack = tf.concat([proj_image_stack,
-                                                  curr_proj_image], axis=3)
-                    proj_error_stack = tf.concat([proj_error_stack,
-                                                  curr_proj_error], axis=3)
-                    if opt.explain_reg_weight > 0:
-                        exp_mask_stack = tf.concat([exp_mask_stack,
-                            tf.expand_dims(curr_exp[:, :, :, 1], -1)], axis=3)
-            tgt_image_all.append(curr_tgt_image)
-            src_image_stack_all.append(curr_src_image_stack)
-            proj_image_stack_all.append(proj_image_stack)
-            proj_error_stack_all.append(proj_error_stack)
-            if opt.explain_reg_weight > 0:
-                exp_mask_stack_all.append(exp_mask_stack)
-        total_loss = pixel_loss + smooth_loss + exp_loss
-        return total_loss
-    return sfm_loss_fn
+
+def make_sfm_learner_pose_net(input_shape=(224, 224, 3)):
+    pass
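+
+
+# A rough sketch of bilinear_sample above (names and layouts are assumptions,
+# not a final API): projected_coords is [B, H, W, 2] holding (u, v) in
+# source-pixel units and s_img is [B, H, W, C].
+def _bilinear_sample_sketch(projected_coords, s_img):
+    u, v = projected_coords[..., 0], projected_coords[..., 1]
+    u0, v0 = tf.floor(u), tf.floor(v)
+    u1, v1 = u0 + 1.0, v0 + 1.0
+    # Bilinear weights from the fractional offsets
+    w00 = (u1 - u) * (v1 - v)
+    w01 = (u1 - u) * (v - v0)
+    w10 = (u - u0) * (v1 - v)
+    w11 = (u - u0) * (v - v0)
+    height = tf.shape(s_img)[1]
+    width = tf.shape(s_img)[2]
+
+    def gather(us, vs):
+        # Clamp to the image border and look the pixels up per batch element
+        us = tf.clip_by_value(tf.cast(us, tf.int32), 0, width - 1)
+        vs = tf.clip_by_value(tf.cast(vs, tf.int32), 0, height - 1)
+        return tf.gather_nd(s_img, tf.stack([vs, us], axis=-1), batch_dims=1)
+
+    return (w00[..., None] * gather(u0, v0) + w01[..., None] * gather(u0, v1) +
+            w10[..., None] * gather(u1, v0) + w11[..., None] * gather(u1, v1))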
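+
+
+# A minimal sketch of the pose network (layer widths are placeholders, not
+# SfMLearner's exact architecture). It regresses 6-DoF motion (3 translation,
+# 3 Euler angles) per source view; the frame layout fed to it is still TBD.
+def _make_pose_net_sketch(input_shape=(224, 224, 3), num_source=2):
+    inputs = keras.Input(shape=input_shape)
+    x = inputs
+    for filters in (16, 32, 64, 128, 256):
+        x = keras.layers.Conv2D(filters, 3, strides=2, padding='same',
+                                activation='relu')(x)
+    x = keras.layers.GlobalAveragePooling2D()(x)
+    x = keras.layers.Dense(6 * num_source)(x)
+    # Scale down so initial pose estimates start near the identity transform
+    x = keras.layers.Lambda(lambda t: 0.01 * t)(x)
+    return keras.Model(inputs, keras.layers.Reshape((num_source, 6))(x))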