Start implementing unsupervised train loop, add sfmlearner train and utils files for reference
This commit is contained in:
166
unsupervised/third-party/train.py
vendored
Normal file
166
unsupervised/third-party/train.py
vendored
Normal file
@@ -0,0 +1,166 @@
|
||||
"""
|
||||
Trainer to learn depth information on unlabeled data (raw images/videos)
|
||||
|
||||
Allows pluggable depth networks for differing performance (including fast-depth)
|
||||
"""
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
def compute_smooth_loss(self, pred_disp):
    """Return a second-order smoothness penalty for ``pred_disp``.

    Neighbouring-element differences are taken along axes 1 and 2 of the
    input, then taken once more on each of those results. The penalty is the
    sum of the mean absolute values of the four resulting second-order
    difference tensors.

    Args:
        self: Not read by this function.
        pred_disp: Tensor indexed with four axes (presumably
            batch, height, width, channels -- TODO confirm against caller).

    Returns:
        The summed penalty, as produced by ``tf.reduce_mean`` additions.
    """
    def _finite_diff(tensor):
        # Difference between each element and its predecessor, first along
        # axis 1 and then along axis 2; returned as (axis-2, axis-1).
        along_axis1 = tensor[:, 1:, :, :] - tensor[:, :-1, :, :]
        along_axis2 = tensor[:, :, 1:, :] - tensor[:, :, :-1, :]
        return along_axis2, along_axis1

    first_x, first_y = _finite_diff(pred_disp)
    # Tuple concatenation yields (dx2, dxdy, dydx, dy2) in that order.
    second_order = _finite_diff(first_x) + _finite_diff(first_y)

    penalty = tf.reduce_mean(tf.abs(second_order[0]))
    for term in second_order[1:]:
        penalty = penalty + tf.reduce_mean(tf.abs(term))
    return penalty
|
||||
|
||||
|
||||
def get_reference_explain_mask(self, downscaling):
    """Build the constant reference explainability mask for one scale.

    The two-element pattern ``[0, 1]`` is tiled so that the result has shape
    ``(batch_size, img_height / 2**downscaling, img_width / 2**downscaling, 2)``
    (spatial sizes truncated with ``int``), i.e. every position along the
    last axis holds the pair ``[0, 1]``.

    Args:
        self: Object exposing an ``opt`` attribute with ``batch_size``,
            ``img_height`` and ``img_width``.
        downscaling: Exponent of the power-of-two factor by which the image
            height and width are divided.

    Returns:
        A ``tf.float32`` constant tensor holding the tiled mask.
    """
    # Imported here because this file's import block only brings in
    # tensorflow; without it the `np` references below raise NameError.
    import numpy as np

    opt = self.opt
    scale = 2 ** downscaling
    tmp = np.array([0, 1])
    ref_exp_mask = np.tile(tmp,
                           (opt.batch_size,
                            int(opt.img_height / scale),
                            int(opt.img_width / scale),
                            1))
    return tf.constant(ref_exp_mask, dtype=tf.float32)
|
||||
|
||||
|
||||
def get_sfm_loss_fn(opt):
    """Return a loss closure configured by ``opt``.

    The returned ``sfm_loss_fn(y, y_pred)`` accumulates three terms over
    ``opt.num_scales`` scales and ``opt.num_source`` source views -- a pixel
    (photo-consistency) term, a smoothness term weighted by
    ``opt.smooth_weight`` and an explainability regularisation term weighted
    by ``opt.explain_reg_weight`` -- and returns their sum.

    NOTE(review): the closure reads ``tgt_image``, ``src_image_stack``,
    ``pred_disp``, ``pred_depth``, ``pred_poses``, ``intrinsics``,
    ``pred_exp_logits``, ``projective_inverse_warp`` and ``self``, none of
    which are defined in the visible part of this file; they presumably have
    to be derived from ``y`` / ``y_pred`` (see the TODO below) -- confirm.

    Args:
        opt: Options object; the attributes read here are ``num_scales``,
            ``num_source``, ``img_height``, ``img_width``, ``smooth_weight``
            and ``explain_reg_weight``.

    Returns:
        The ``sfm_loss_fn`` closure.
    """
    def sfm_loss_fn(y, y_pred):
        # TODO: Correctly format a batch that is required for this loss function
        # NOTE(review): neither `y` nor `y_pred` is read anywhere below.
        pixel_loss = 0
        exp_loss = 0
        smooth_loss = 0
        # Per-scale collections; they are filled below but not returned.
        tgt_image_all = []
        src_image_stack_all = []
        proj_image_stack_all = []
        proj_error_stack_all = []
        exp_mask_stack_all = []
        for s in range(opt.num_scales):
            if opt.explain_reg_weight > 0:
                # Construct a reference explainability mask (i.e. all
                # pixels are explainable)
                # NOTE(review): get_reference_explain_mask is declared as
                # (self, downscaling); this call passes a single argument.
                ref_exp_mask = get_reference_explain_mask(s)
            # Scale the source and target images for computing loss at the
            # corresponding scale.
            curr_tgt_image = tf.image.resize_area(tgt_image,
                                                  [int(opt.img_height/(2**s)), int(opt.img_width/(2**s))])
            curr_src_image_stack = tf.image.resize_area(src_image_stack,
                                                        [int(opt.img_height/(2**s)), int(opt.img_width/(2**s))])

            if opt.smooth_weight > 0:
                # The smoothness weight is halved at each successive scale.
                # NOTE(review): compute_smooth_loss is declared as
                # (self, pred_disp); this call passes a single argument.
                smooth_loss += opt.smooth_weight/(2**s) * \
                    compute_smooth_loss(pred_disp[s])

            for i in range(opt.num_source):
                # Inverse warp the source image to the target image frame
                # (channels 3*i .. 3*i+2 of the stack belong to source i).
                curr_proj_image = projective_inverse_warp(
                    curr_src_image_stack[:, :, :, 3*i:3*(i+1)],
                    tf.squeeze(pred_depth[s], axis=3),
                    pred_poses[:, i, :],
                    intrinsics[:, s, :, :])
                curr_proj_error = tf.abs(curr_proj_image - curr_tgt_image)
                # Cross-entropy loss as regularization for the
                # explainability prediction
                if opt.explain_reg_weight > 0:
                    # Two logit channels per source view, starting at 2*i.
                    curr_exp_logits = tf.slice(pred_exp_logits[s],
                                               [0, 0, 0, i*2],
                                               [-1, -1, -1, 2])
                    exp_loss += opt.explain_reg_weight * \
                        self.compute_exp_reg_loss(curr_exp_logits,
                                                  ref_exp_mask)
                    curr_exp = tf.nn.softmax(curr_exp_logits)
                # Photo-consistency loss weighted by explainability
                if opt.explain_reg_weight > 0:
                    # Channel 1 of the softmax output is used as the weight.
                    pixel_loss += tf.reduce_mean(curr_proj_error *
                                                 tf.expand_dims(curr_exp[:, :, :, 1], -1))
                else:
                    pixel_loss += tf.reduce_mean(curr_proj_error)
                # Prepare images for tensorboard summaries
                if i == 0:
                    # First source view starts the per-scale stacks.
                    proj_image_stack = curr_proj_image
                    proj_error_stack = curr_proj_error
                    if opt.explain_reg_weight > 0:
                        exp_mask_stack = tf.expand_dims(
                            curr_exp[:, :, :, 1], -1)
                else:
                    # Later source views are appended along the last axis.
                    proj_image_stack = tf.concat([proj_image_stack,
                                                  curr_proj_image], axis=3)
                    proj_error_stack = tf.concat([proj_error_stack,
                                                  curr_proj_error], axis=3)
                    if opt.explain_reg_weight > 0:
                        exp_mask_stack = tf.concat([exp_mask_stack,
                                                    tf.expand_dims(curr_exp[:, :, :, 1], -1)], axis=3)
            tgt_image_all.append(curr_tgt_image)
            src_image_stack_all.append(curr_src_image_stack)
            proj_image_stack_all.append(proj_image_stack)
            proj_error_stack_all.append(proj_error_stack)
            if opt.explain_reg_weight > 0:
                exp_mask_stack_all.append(exp_mask_stack)
        total_loss = pixel_loss + smooth_loss + exp_loss
        return total_loss
    return sfm_loss_fn
|
||||
|
||||
|
||||
def photometric_reconstruction_loss(tgt_img, ref_imgs, intrinsics,
                                    depth, explainability_mask, pose,
                                    rotation_mode='euler', padding_mode='zeros'):
    """Sum the photometric reconstruction loss over one or more depth scales.

    For every (depth, mask) pair, the target and reference images are resized
    to the depth map's spatial size, each reference image is warped with
    ``inverse_warp`` and the mean absolute difference to the resized target
    is accumulated.

    NOTE(review): the body uses torch-style tensor methods (``.size()``,
    ``.unsqueeze()``, ``.expand_as()``) and ``F.interpolate`` next to
    ``tf.concat``; ``F`` and ``inverse_warp`` are not defined in the visible
    part of this file -- confirm which framework the inputs come from.

    Args:
        tgt_img: Target image batch; ``tgt_img.size(2)`` is read as its
            height.
        ref_imgs: Sequence of reference image batches, one per pose.
        intrinsics: Camera intrinsics; the first two entries along axis 1 are
            divided by the downscale factor, the rest are left unchanged.
        depth: A depth tensor, or a list/tuple of them (one per scale).
        explainability_mask: ``None``, a mask tensor, or a list/tuple of
            masks paired with ``depth`` via ``zip``.
        pose: Pose tensor; ``pose.size(1)`` must equal ``len(ref_imgs)``.
        rotation_mode: Forwarded to ``inverse_warp``.
        padding_mode: Forwarded to ``inverse_warp``.

    Returns:
        ``(total_loss, warped_results, diff_results)`` where the two lists
        hold, per scale, the first-batch-element warped images and
        difference maps for every reference image.
    """
    def one_scale(d, mask):
        # Loss for a single scale: `d` is that scale's depth map and `mask`
        # its explainability mask (or None).
        assert(mask is None or d.size()[2:] == mask.size()[2:])
        assert(pose.size(1) == len(ref_imgs))

        reconstruction_loss = 0
        _, _, h, w = d.size()
        downscale = tgt_img.size(2)/h

        tgt_img_scaled = F.interpolate(tgt_img, (h, w), mode='area')
        ref_imgs_scaled = [F.interpolate(ref_img, (h, w), mode='area')
                           for ref_img in ref_imgs]
        # tf.concat takes `axis`; it has no `dim` keyword.
        intrinsics_scaled = tf.concat(
            (intrinsics[:, 0:2]/downscale, intrinsics[:, 2:]), axis=1)

        warped_imgs = []
        diff_maps = []

        for i, ref_img in enumerate(ref_imgs_scaled):
            current_pose = pose[:, i]

            # Use this scale's depth `d`: the enclosing `depth` has been
            # wrapped into a list by the time this runs and cannot be
            # indexed with `[:, 0]`.
            ref_img_warped, valid_points = inverse_warp(ref_img, d[:, 0], current_pose,
                                                        intrinsics_scaled,
                                                        rotation_mode, padding_mode)
            diff = (tgt_img_scaled - ref_img_warped) * \
                valid_points.unsqueeze(1).float()

            # Likewise use this scale's `mask`; the enclosing
            # `explainability_mask` is a list and is never None here.
            if mask is not None:
                diff = diff * mask[:, i:i+1].expand_as(diff)

            reconstruction_loss += diff.abs().mean()
            # NaN guard: a NaN value does not compare equal to itself.
            assert((reconstruction_loss == reconstruction_loss).item() == 1)

            # Keep only the first batch element of each result.
            warped_imgs.append(ref_img_warped[0])
            diff_maps.append(diff[0])

        return reconstruction_loss, warped_imgs, diff_maps

    warped_results, diff_results = [], []
    # Normalise single inputs to one-element lists so they can be zipped.
    if not isinstance(explainability_mask, (tuple, list)):
        explainability_mask = [explainability_mask]
    if not isinstance(depth, (list, tuple)):
        depth = [depth]

    total_loss = 0
    for d, mask in zip(depth, explainability_mask):
        loss, warped, diff = one_scale(d, mask)
        total_loss += loss
        warped_results.append(warped)
        diff_results.append(diff)
    return total_loss, warped_results, diff_results
|
||||
Reference in New Issue
Block a user