"""
Trainer to learn depth information on unlabeled data (raw images/videos).

Allows pluggable depth networks for differing performance (including
fast-depth).
"""
import numpy as np
import tensorflow as tf


def compute_smooth_loss(self, pred_disp):
    """Return a second-order smoothness penalty for a disparity map.

    Takes forward differences of ``pred_disp`` along axes 1 and 2, then
    differences those again, and sums the mean absolute value of the four
    second-order terms (xx, xy, yx, yy).

    Args:
        self: Unused. Kept so existing ``compute_smooth_loss(obj, disp)``
            calls keep working; pass ``None`` when calling it as a plain
            function.
        pred_disp: 4-D tensor; axes 1 and 2 are treated as rows / columns
            (assumed NHWC — TODO confirm).

    Returns:
        Scalar tensor.
    """
    def gradient(pred):
        # Forward differences along the row (axis 1) and column (axis 2) axes.
        d_dy = pred[:, 1:, :, :] - pred[:, :-1, :, :]
        d_dx = pred[:, :, 1:, :] - pred[:, :, :-1, :]
        return d_dx, d_dy

    dx, dy = gradient(pred_disp)
    dx2, dxdy = gradient(dx)
    dydx, dy2 = gradient(dy)
    return (tf.reduce_mean(tf.abs(dx2))
            + tf.reduce_mean(tf.abs(dxdy))
            + tf.reduce_mean(tf.abs(dydx))
            + tf.reduce_mean(tf.abs(dy2)))


def compute_exp_reg_loss(pred, ref):
    """Return the mean softmax cross-entropy of explainability logits vs. a target.

    Both tensors are flattened to ``[-1, 2]``, i.e. their last axis is
    assumed to hold the two classes (logits in ``pred``, targets in ``ref``).
    """
    loss = tf.nn.softmax_cross_entropy_with_logits(
        labels=tf.reshape(ref, [-1, 2]),
        logits=tf.reshape(pred, [-1, 2]))
    return tf.reduce_mean(loss)


def get_reference_explain_mask(self, downscaling, opt=None):
    """Build an all-"explainable" reference mask for one pyramid scale.

    Every pixel gets the two-class target ``[0, 1]``; it is used as the
    target of the explainability regularizer.

    Args:
        self: Object exposing ``.opt``; only read when ``opt`` is None.
        downscaling: Pyramid level ``s``; height and width are divided by
            ``2**s``.
        opt: Options providing ``batch_size``, ``img_height`` and
            ``img_width``. When given, it is used instead of ``self.opt`` so
            the function can be called without an object, e.g.
            ``get_reference_explain_mask(None, s, opt=opt)``.

    Returns:
        float32 constant of shape
        ``[batch_size, int(img_height / 2**s), int(img_width / 2**s), 2]``.
    """
    if opt is None:
        opt = self.opt
    height = int(opt.img_height / (2 ** downscaling))
    width = int(opt.img_width / (2 ** downscaling))
    # Tiling the 1-D [0, 1] with 4 reps yields shape (batch, h, w, 2).
    ref_exp_mask = np.tile(np.array([0, 1]), (opt.batch_size, height, width, 1))
    return tf.constant(ref_exp_mask, dtype=tf.float32)


def get_sfm_loss_fn(opt):
    """Return a ``loss(y, y_pred)`` closure for unsupervised depth/pose training.

    Over ``opt.num_scales`` pyramid levels ``s`` the closure accumulates:
      * a photometric L1 loss between the target image and each of the
        ``opt.num_source`` source images inverse-warped into the target
        frame, weighted per pixel by the predicted explainability when
        ``opt.explain_reg_weight > 0``;
      * a second-order disparity smoothness loss weighted by
        ``opt.smooth_weight / 2**s`` (when ``opt.smooth_weight > 0``);
      * an explainability cross-entropy regularizer weighted by
        ``opt.explain_reg_weight`` (when it is > 0).

    NOTE(review): the closure is still a scaffold. It does not unpack ``y`` /
    ``y_pred``; ``tgt_image``, ``src_image_stack``, ``pred_disp``,
    ``pred_depth``, ``pred_poses``, ``intrinsics``, ``pred_exp_logits`` and
    ``projective_inverse_warp`` are not defined in this module and will raise
    ``NameError`` until they are provided.
    """
    def sfm_loss_fn(y, y_pred):
        # TODO: Correctly format a batch that is required for this loss function
        pixel_loss = 0
        exp_loss = 0
        smooth_loss = 0
        for s in range(opt.num_scales):
            scaled_size = [int(opt.img_height / (2 ** s)),
                           int(opt.img_width / (2 ** s))]
            if opt.explain_reg_weight > 0:
                # Reference explainability mask (i.e. all pixels explainable).
                ref_exp_mask = get_reference_explain_mask(None, s, opt=opt)

            # Scale the source and target images to this pyramid level.
            # NOTE(review): tf.image.resize_area is a TF1 API
            # (tf.compat.v1.image.resize_area under TF2).
            curr_tgt_image = tf.image.resize_area(tgt_image, scaled_size)
            curr_src_image_stack = tf.image.resize_area(src_image_stack, scaled_size)

            if opt.smooth_weight > 0:
                smooth_loss += opt.smooth_weight / (2 ** s) * \
                    compute_smooth_loss(None, pred_disp[s])

            for i in range(opt.num_source):
                # Inverse warp source image i (channels 3i..3i+2 of the
                # stack) into the target frame.
                curr_proj_image = projective_inverse_warp(
                    curr_src_image_stack[:, :, :, 3 * i:3 * (i + 1)],
                    tf.squeeze(pred_depth[s], axis=3),
                    pred_poses[:, i, :],
                    intrinsics[:, s, :, :])
                curr_proj_error = tf.abs(curr_proj_image - curr_tgt_image)

                if opt.explain_reg_weight > 0:
                    # Cross-entropy regularizer on this source's 2-channel logits.
                    curr_exp_logits = tf.slice(pred_exp_logits[s],
                                               [0, 0, 0, i * 2],
                                               [-1, -1, -1, 2])
                    exp_loss += opt.explain_reg_weight * \
                        compute_exp_reg_loss(curr_exp_logits, ref_exp_mask)
                    curr_exp = tf.nn.softmax(curr_exp_logits)
                    # Photo-consistency loss weighted by the "explainable"
                    # probability (channel 1).
                    pixel_loss += tf.reduce_mean(
                        curr_proj_error * tf.expand_dims(curr_exp[:, :, :, 1], -1))
                else:
                    pixel_loss += tf.reduce_mean(curr_proj_error)

        return pixel_loss + smooth_loss + exp_loss

    return sfm_loss_fn


def photometric_reconstruction_loss(tgt_img, ref_imgs, intrinsics, depth,
                                    explainability_mask, pose,
                                    rotation_mode='euler', padding_mode='zeros'):
    """Multi-scale photometric reconstruction loss on PyTorch tensors.

    For every depth scale, each reference image is inverse-warped into the
    target frame using that scale's depth (channel 0) and the matching pose,
    and the mean absolute difference to the area-downsampled target is
    accumulated. The difference is masked by the warp's valid points and,
    when given, by the explainability mask.

    NOTE(review): unlike the rest of this module this function works on
    PyTorch tensors, and ``inverse_warp`` is not defined in this module; it
    presumably comes from an SfmLearner-Pytorch style helper — verify.

    Args:
        tgt_img: Target image, 4-D with spatial size on axes 2 and 3.
        ref_imgs: Sequence of reference images shaped like ``tgt_img``.
        intrinsics: Camera intrinsics; rows ``0:2`` are divided by the
            per-scale downscale factor (assumed ``[b, 3, 3]`` — confirm).
        depth: One 4-D depth tensor or a list/tuple of them (one per scale).
        explainability_mask: ``None``, a tensor, or a list/tuple with one
            entry per scale; channel ``i`` masks reference image ``i``.
        pose: Poses indexed as ``pose[:, i]`` for reference image ``i``.
        rotation_mode: Forwarded to ``inverse_warp``.
        padding_mode: Forwarded to ``inverse_warp``.

    Returns:
        ``(total_loss, warped_results, diff_results)``; the two lists hold,
        per scale, the first batch element of every warped image / diff map.
    """
    # Local imports keep PyTorch an optional dependency of this TF module.
    import torch
    import torch.nn.functional as F

    def one_scale(d, mask):
        assert mask is None or d.size()[2:] == mask.size()[2:]
        assert pose.size(1) == len(ref_imgs)

        reconstruction_loss = 0
        _, _, h, w = d.size()
        downscale = tgt_img.size(2) / h

        tgt_img_scaled = F.interpolate(tgt_img, (h, w), mode='area')
        ref_imgs_scaled = [F.interpolate(ref_img, (h, w), mode='area')
                           for ref_img in ref_imgs]
        # Only rows 0:2 of the intrinsics are rescaled to this resolution.
        intrinsics_scaled = torch.cat(
            (intrinsics[:, 0:2] / downscale, intrinsics[:, 2:]), dim=1)

        warped_imgs = []
        diff_maps = []
        for i, ref_img in enumerate(ref_imgs_scaled):
            current_pose = pose[:, i]
            ref_img_warped, valid_points = inverse_warp(
                ref_img, d[:, 0], current_pose, intrinsics_scaled,
                rotation_mode, padding_mode)
            diff = (tgt_img_scaled - ref_img_warped) * \
                valid_points.unsqueeze(1).float()
            if mask is not None:
                diff = diff * mask[:, i:i + 1].expand_as(diff)

            reconstruction_loss += diff.abs().mean()
            # NaN != NaN, so this trips as soon as the loss becomes NaN.
            assert (reconstruction_loss == reconstruction_loss).item() == 1

            warped_imgs.append(ref_img_warped[0])
            diff_maps.append(diff[0])

        return reconstruction_loss, warped_imgs, diff_maps

    if not isinstance(depth, (list, tuple)):
        depth = [depth]
    if explainability_mask is None:
        # Broadcast "no mask" to every scale; a bare [None] would make zip()
        # silently drop every scale but the first.
        explainability_mask = [None] * len(depth)
    elif not isinstance(explainability_mask, (list, tuple)):
        explainability_mask = [explainability_mask]

    warped_results, diff_results = [], []
    total_loss = 0
    for d, mask in zip(depth, explainability_mask):
        loss, warped, diff = one_scale(d, mask)
        total_loss += loss
        warped_results.append(warped)
        diff_results.append(diff)
    return total_loss, warped_results, diff_results