# fast-depth-tf/unsupervised/train.py
"""
Trainer to learn depth information on unlabeled data (raw images/videos)
Allows pluggable depth networks for differing performance (including fast-depth)
"""
import tensorflow as tf

def compute_smooth_loss(pred_disp):
    """Second-order smoothness penalty on a predicted disparity map.

    Penalizes the mean absolute value of the second-order spatial
    gradients (dx2, dxdy, dydx, dy2), so only curvature is penalized;
    a disparity map that varies linearly across the image (e.g. a
    planar surface) incurs zero cost.
    """
    def gradient(pred):
        D_dy = pred[:, 1:, :, :] - pred[:, :-1, :, :]
        D_dx = pred[:, :, 1:, :] - pred[:, :, :-1, :]
        return D_dx, D_dy
    dx, dy = gradient(pred_disp)
    dx2, dxdy = gradient(dx)
    dydx, dy2 = gradient(dy)
    return tf.reduce_mean(tf.abs(dx2)) + \
           tf.reduce_mean(tf.abs(dxdy)) + \
           tf.reduce_mean(tf.abs(dydx)) + \
           tf.reduce_mean(tf.abs(dy2))

def get_reference_explain_mask(opt, downscaling):
    """Reference explainability mask (all pixels explainable) at one scale.

    Shape: [batch, H / 2**downscaling, W / 2**downscaling, 2], where the
    two channels are the (not-explainable, explainable) one-hot target.
    """
    tmp = np.array([0, 1])
    ref_exp_mask = np.tile(tmp,
                           (opt.batch_size,
                            int(opt.img_height / (2 ** downscaling)),
                            int(opt.img_width / (2 ** downscaling)),
                            1))
    ref_exp_mask = tf.constant(ref_exp_mask, dtype=tf.float32)
    return ref_exp_mask
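
# The loss below calls compute_exp_reg_loss, which is not defined in this
# excerpt. A minimal sketch is given here, assuming the SfMLearner-style
# formulation: per-pixel softmax cross-entropy between the predicted
# explainability logits and the all-explainable reference mask.
def compute_exp_reg_loss(pred, ref):
    loss = tf.nn.softmax_cross_entropy_with_logits(
        labels=tf.reshape(ref, [-1, 2]),
        logits=tf.reshape(pred, [-1, 2]))
    return tf.reduce_mean(loss)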

def get_sfm_loss_fn(opt):
    """Build an SfM-style unsupervised loss function for the given options.

    The returned closure combines a view-synthesis (photometric) loss, an
    optional explainability regularizer, and an optional disparity
    smoothness penalty across opt.num_scales image scales.
    """
    def sfm_loss_fn(y, y_pred):
        # TODO: Correctly format a batch that is required for this loss function.
        # The body below expects tgt_image, src_image_stack and intrinsics
        # (from the input batch) plus pred_disp, pred_depth, pred_poses and
        # pred_exp_logits (from the network outputs) to be unpacked from
        # y / y_pred.
        pixel_loss = 0
        exp_loss = 0
        smooth_loss = 0
        # Collected per scale for TensorBoard summaries; not part of the
        # returned loss value.
        tgt_image_all = []
        src_image_stack_all = []
        proj_image_stack_all = []
        proj_error_stack_all = []
        exp_mask_stack_all = []
        for s in range(opt.num_scales):
            if opt.explain_reg_weight > 0:
                # Construct a reference explainability mask (i.e. all
                # pixels are explainable).
                ref_exp_mask = get_reference_explain_mask(opt, s)
            # Scale the source and target images for computing loss at the
            # corresponding scale. (tf.image.resize_area is the TF1.x API;
            # in TF2 use tf.image.resize(..., method='area').)
            curr_tgt_image = tf.image.resize_area(
                tgt_image,
                [int(opt.img_height / (2 ** s)), int(opt.img_width / (2 ** s))])
            curr_src_image_stack = tf.image.resize_area(
                src_image_stack,
                [int(opt.img_height / (2 ** s)), int(opt.img_width / (2 ** s))])
            if opt.smooth_weight > 0:
                smooth_loss += opt.smooth_weight / (2 ** s) * \
                    compute_smooth_loss(pred_disp[s])
            for i in range(opt.num_source):
                # Inverse-warp the source image into the target image frame.
                curr_proj_image = projective_inverse_warp(
                    curr_src_image_stack[:, :, :, 3 * i:3 * (i + 1)],
                    tf.squeeze(pred_depth[s], axis=3),
                    pred_poses[:, i, :],
                    intrinsics[:, s, :, :])
                curr_proj_error = tf.abs(curr_proj_image - curr_tgt_image)
                # Cross-entropy loss as regularization for the
                # explainability prediction.
                if opt.explain_reg_weight > 0:
                    curr_exp_logits = tf.slice(pred_exp_logits[s],
                                               [0, 0, 0, i * 2],
                                               [-1, -1, -1, 2])
                    exp_loss += opt.explain_reg_weight * \
                        compute_exp_reg_loss(curr_exp_logits, ref_exp_mask)
                    curr_exp = tf.nn.softmax(curr_exp_logits)
                # Photo-consistency loss, weighted by explainability when
                # the regularizer is enabled.
                if opt.explain_reg_weight > 0:
                    pixel_loss += tf.reduce_mean(
                        curr_proj_error *
                        tf.expand_dims(curr_exp[:, :, :, 1], -1))
                else:
                    pixel_loss += tf.reduce_mean(curr_proj_error)
                # Prepare images for TensorBoard summaries.
                if i == 0:
                    proj_image_stack = curr_proj_image
                    proj_error_stack = curr_proj_error
                    if opt.explain_reg_weight > 0:
                        exp_mask_stack = tf.expand_dims(
                            curr_exp[:, :, :, 1], -1)
                else:
                    proj_image_stack = tf.concat([proj_image_stack,
                                                  curr_proj_image], axis=3)
                    proj_error_stack = tf.concat([proj_error_stack,
                                                  curr_proj_error], axis=3)
                    if opt.explain_reg_weight > 0:
                        exp_mask_stack = tf.concat(
                            [exp_mask_stack,
                             tf.expand_dims(curr_exp[:, :, :, 1], -1)], axis=3)
            tgt_image_all.append(curr_tgt_image)
            src_image_stack_all.append(curr_src_image_stack)
            proj_image_stack_all.append(proj_image_stack)
            proj_error_stack_all.append(proj_error_stack)
            if opt.explain_reg_weight > 0:
                exp_mask_stack_all.append(exp_mask_stack)
        total_loss = pixel_loss + smooth_loss + exp_loss
        return total_loss
    return sfm_loss_fn
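

if __name__ == "__main__":
    # Minimal usage sketch: build the loss closure from a hypothetical
    # options namespace. The field values below are placeholders, not the
    # project's actual training configuration; how the batch and network
    # outputs are packed into (y, y_pred) is still governed by the TODO in
    # sfm_loss_fn above.
    from argparse import Namespace

    opt = Namespace(num_scales=4, num_source=2, batch_size=4,
                    img_height=128, img_width=416,
                    smooth_weight=0.5, explain_reg_weight=0.0)
    sfm_loss = get_sfm_loss_fn(opt)
    # The closure can then be passed wherever a (y, y_pred) loss is
    # expected, e.g. model.compile(optimizer='adam', loss=sfm_loss) for a
    # Keras-style model that emits the multi-scale predictions.
    print("Built SfM loss function:", sfm_loss)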