91 lines
4.0 KiB
Python
91 lines
4.0 KiB
Python
"""
|
|
Trainer to learn depth information on unlabeled data (raw images/videos)
|
|
|
|
Allows pluggable depth networks for differing performance (including fast-depth)
|
|
"""
|
|
|
|
import tensorflow as tf
|
|
import tensorflow.python.keras as keras
|
|
from unsupervised import warp
|
|
import unsupervised.loss as loss
|
|
|
|
|
|
class UnsupervisedPoseDepthLearner(keras.Model):
    """
    Keras model to learn simultaneous depth + pose from image/video sequences.

    To train this, the datasource should yield 3 frames and camera intrinsics.
    Optionally velocity + timestamp per frame to train to real scale.

    ``frames[1]`` is treated as the target frame; ``frames[0]`` and
    ``frames[2]`` are the source frames that get warped and compared to it.
    """

    def __init__(self, depth_model, pose_model, num_scales=3, *args, **kwargs):
        """
        :param depth_model: Model called on the target frame; its output is
            indexed once per scale.
        :param pose_model: Model called with (target frame, source frame).
        :param num_scales: Number of depth outputs (scales) the loss is
            accumulated over.
        """
        super().__init__(*args, **kwargs)
        self.depth_model = depth_model
        self.pose_model = pose_model
        # TODO: I think num_scales should be something defined on the depth model itself
        self.num_scales = num_scales
        # Weight applied to the smoothness loss before it is added to the total.
        self.smoothness = 1e-3

    def train_step(self, data):
        """
        Compute the multi-scale reprojection + smoothness loss for one batch.

        NOTE: the optimiser step is not implemented yet, so this currently
        only builds the loss and returns nothing.

        :param data: Format: {frames: Mat[3], intrinsics: Tensor}
            NOTE(review): the fields are read as attributes (``data.frames``,
            ``data.intrinsics``) - confirm the data loader yields an object
            that supports attribute access rather than a plain dict.
        """
        frames = data.frames
        intrinsics = data.intrinsics
        target = frames[1]
        source1 = frames[0]
        source2 = frames[2]

        # Pass through depth for target image
        # TODO: Convert frame to tensor (or do this in the dataloader)
        # TODO: Ensure the depth output includes enough outputs for each scale
        depth = self.depth_model(target)

        # Pass through depth -> pose for both source images
        # TODO: Concat these poses using tf.concat
        pose1 = self.pose_model(target, source1)
        pose2 = self.pose_model(target, source2)

        shape = depth[0].shape

        # TODO: Pull coords out of train step into initialiser, then it only needs to be created once.
        # Ideally the size/batch size will still be calculated automatically
        coords = warp.image_coordinate(shape[0], shape[1], shape[2])

        # Loss between target and the un-warped source images. These do not
        # depend on the scale, so they are computed once outside the loop.
        source_loss1 = loss.make_combined_ssim_l1_loss(target, source1)
        source_loss2 = loss.make_combined_ssim_l1_loss(target, source2)

        total_loss = 0

        # For each scale, do the projective inverse warp step and calculate losses
        for scale in range(self.num_scales):
            # TODO: Could simplify this by stacking the source images (see sfmlearner)
            # It isn't too much of an issue right now since we're only using 2 images (left/right)
            # For each depth output (scale), do the projective inverse warp on each
            # input image and calculate the losses.
            # Each source frame must be warped with its own pose
            # (frames[0] <-> pose1, frames[2] <-> pose2).
            warp1 = warp.projective_inverse_warp(source1, depth[scale], pose1, intrinsics, coords)
            warp2 = warp.projective_inverse_warp(source2, depth[scale], pose2, intrinsics, coords)

            # Per pixel loss between the target and each warped source (l1 plus ssim)
            warp_loss1 = loss.make_combined_ssim_l1_loss(target, warp1)
            warp_loss2 = loss.make_combined_ssim_l1_loss(target, warp2)

            # Take the min (per pixel) of the losses of warped/unwarped images
            # (so min across pixels of 4 images), from monodepth2.
            # TODO: Verify the axes are correct
            # NOTE(review): assumes each loss map is rank 4 with the last axis
            # as the channel axis - confirm against unsupervised.loss.
            stacked_losses = tf.concat(
                [warp_loss1, warp_loss2, source_loss1, source_loss2], axis=3)
            reprojection_loss = tf.reduce_mean(tf.reduce_min(stacked_losses, axis=3))

            # Calculate smooth losses
            smooth_loss = loss.smooth_loss(depth[scale], target)

            # TODO: Monodepth also divides the smooth loss by 2 ** scale, why?
            weighted_smooth_loss = self.smoothness * smooth_loss / (2 ** scale)

            # Add to total loss, using the *weighted* smoothness term rather
            # than the raw smooth loss.
            total_loss += reprojection_loss + weighted_smooth_loss

        # Collect losses, average them out (divide by number of scales)
        total_loss /= self.num_scales

        # TODO: Apply optimise step on total loss (not implemented yet)
|