""" Trainer to learn depth information on unlabeled data (raw images/videos) Allows pluggable depth networks for differing performance (including fast-depth) """ import tensorflow as tf def compute_smooth_loss(self, pred_disp): def gradient(pred): D_dy = pred[:, 1:, :, :] - pred[:, :-1, :, :] D_dx = pred[:, :, 1:, :] - pred[:, :, :-1, :] return D_dx, D_dy dx, dy = gradient(pred_disp) dx2, dxdy = gradient(dx) dydx, dy2 = gradient(dy) return tf.reduce_mean(tf.abs(dx2)) + \ tf.reduce_mean(tf.abs(dxdy)) + \ tf.reduce_mean(tf.abs(dydx)) + \ tf.reduce_mean(tf.abs(dy2)) def get_reference_explain_mask(self, downscaling): opt = self.opt tmp = np.array([0, 1]) ref_exp_mask = np.tile(tmp, (opt.batch_size, int(opt.img_height/(2**downscaling)), int(opt.img_width/(2**downscaling)), 1)) ref_exp_mask = tf.constant(ref_exp_mask, dtype=tf.float32) return ref_exp_mask def get_sfm_loss_fn(opt): def sfm_loss_fn(y, y_pred): # TODO: Correctly format a batch that is required for this loss function pixel_loss = 0 exp_loss = 0 smooth_loss = 0 tgt_image_all = [] src_image_stack_all = [] proj_image_stack_all = [] proj_error_stack_all = [] exp_mask_stack_all = [] for s in range(opt.num_scales): if opt.explain_reg_weight > 0: # Construct a reference explainability mask (i.e. all # pixels are explainable) ref_exp_mask = get_reference_explain_mask(s) # Scale the source and target images for computing loss at the # according scale. curr_tgt_image = tf.image.resize_area(tgt_image, [int(opt.img_height/(2**s)), int(opt.img_width/(2**s))]) curr_src_image_stack = tf.image.resize_area(src_image_stack, [int(opt.img_height/(2**s)), int(opt.img_width/(2**s))]) if opt.smooth_weight > 0: smooth_loss += opt.smooth_weight/(2**s) * \ compute_smooth_loss(pred_disp[s]) for i in range(opt.num_source): # Inverse warp the source image to the target image frame curr_proj_image = projective_inverse_warp( curr_src_image_stack[:, :, :, 3*i:3*(i+1)], tf.squeeze(pred_depth[s], axis=3), pred_poses[:, i, :], intrinsics[:, s, :, :]) curr_proj_error = tf.abs(curr_proj_image - curr_tgt_image) # Cross-entropy loss as regularization for the # explainability prediction if opt.explain_reg_weight > 0: curr_exp_logits = tf.slice(pred_exp_logits[s], [0, 0, 0, i*2], [-1, -1, -1, 2]) exp_loss += opt.explain_reg_weight * \ self.compute_exp_reg_loss(curr_exp_logits, ref_exp_mask) curr_exp = tf.nn.softmax(curr_exp_logits) # Photo-consistency loss weighted by explainability if opt.explain_reg_weight > 0: pixel_loss += tf.reduce_mean(curr_proj_error * tf.expand_dims(curr_exp[:, :, :, 1], -1)) else: pixel_loss += tf.reduce_mean(curr_proj_error) # Prepare images for tensorboard summaries if i == 0: proj_image_stack = curr_proj_image proj_error_stack = curr_proj_error if opt.explain_reg_weight > 0: exp_mask_stack = tf.expand_dims( curr_exp[:, :, :, 1], -1) else: proj_image_stack = tf.concat([proj_image_stack, curr_proj_image], axis=3) proj_error_stack = tf.concat([proj_error_stack, curr_proj_error], axis=3) if opt.explain_reg_weight > 0: exp_mask_stack = tf.concat([exp_mask_stack, tf.expand_dims(curr_exp[:, :, :, 1], -1)], axis=3) tgt_image_all.append(curr_tgt_image) src_image_stack_all.append(curr_src_image_stack) proj_image_stack_all.append(proj_image_stack) proj_error_stack_all.append(proj_error_stack) if opt.explain_reg_weight > 0: exp_mask_stack_all.append(exp_mask_stack) total_loss = pixel_loss + smooth_loss + exp_loss return total_loss return sfm_loss_fn