"""Per-pixel photometric losses between a target image and a reprojected image."""
import tensorflow as tf


def l1_loss(target_img, reprojected_img):
    """
    Calculates the channel-averaged absolute difference between the target and reprojected image

    :param target_img: Tensor (batch, height, width, channels)
    :param reprojected_img: Tensor, same shape as target_img
    :return: The per-pixel mean absolute difference -> Tensor (batch, height, width);
        the channel axis is reduced away (no keepdims)
    """
    return tf.reduce_mean(tf.abs(target_img - reprojected_img), axis=3)


def l2_loss(target_img, reprojected_img):
    """
    Calculates the per-pixel root-mean-square difference between the target and reprojected image

    This is the Euclidean (l2) norm of the per-pixel channel difference, scaled by
    1 / sqrt(channels), mirroring how l1_loss averages over channels.

    :param target_img: Tensor (batch, height, width, channels)
    :param reprojected_img: Tensor, same shape as target_img
    :return: The per-pixel RMS difference -> Tensor (batch, height, width);
        the channel axis is reduced away (no keepdims)
    """
    # `x ** 2 ** (1 / 2)` parses as `x ** (2 ** 0.5)` because `**` is right-associative,
    # which yields NaN for negative differences. Square, average, then take the root.
    # NOTE(review): the gradient of sqrt is unbounded at exactly 0; add a small epsilon
    # inside the sqrt if identical pixels can occur while training with this loss.
    squared_diff = tf.square(target_img - reprojected_img)
    return tf.sqrt(tf.reduce_mean(squared_diff, axis=3))


def _ssim_dissimilarity_map(target_img, reprojected_img, max_val=1.0):
    """
    Calculates the per-pixel structural dissimilarity (1 - SSIM) / 2 using 3x3 local windows

    :param target_img: Float tensor (batch, height, width, channels), height and width >= 2
    :param reprojected_img: Float tensor, same shape as target_img
    :param max_val: Dynamic range of the images (1.0 for images scaled to [0, 1])
    :return: Tensor with the same shape as target_img, values clipped to [0, 1],
        where 0 means the local windows are identical
    """
    c1 = (0.01 * max_val) ** 2
    c2 = (0.03 * max_val) ** 2

    # Reflect-pad one pixel on each spatial side so that the VALID 3x3 pooling below
    # returns a map with the same height and width as the inputs.
    paddings = [[0, 0], [1, 1], [1, 1], [0, 0]]
    x = tf.pad(target_img, paddings, mode="REFLECT")
    y = tf.pad(reprojected_img, paddings, mode="REFLECT")

    def local_mean(tensor):
        return tf.nn.avg_pool2d(tensor, ksize=3, strides=1, padding="VALID")

    mu_x = local_mean(x)
    mu_y = local_mean(y)
    sigma_x = local_mean(x * x) - mu_x * mu_x
    sigma_y = local_mean(y * y) - mu_y * mu_y
    sigma_xy = local_mean(x * y) - mu_x * mu_y

    numerator = (2.0 * mu_x * mu_y + c1) * (2.0 * sigma_xy + c2)
    denominator = (mu_x * mu_x + mu_y * mu_y + c1) * (sigma_x + sigma_y + c2)
    return tf.clip_by_value((1.0 - numerator / denominator) / 2.0, 0.0, 1.0)


def make_combined_ssim_l1_loss(ssim_weight: float = 0.85, other_loss_fn=l1_loss, max_val: float = 1.0):
    """
    Create a loss function that will calculate ssim for the two images, and use the other_loss_fn
    to calculate the per pixel loss

    :param ssim_weight: Weighting that should be applied to the SSIM term vs the other_loss_fn
        term; the other term is weighted by (1 - ssim_weight)
    :param other_loss_fn: Function to combine with the ssim. Called as
        other_loss_fn(target_img, reprojected_img) and expected to return a tensor of shape
        (batch, height, width), as l1_loss and l2_loss do
    :param max_val: Dynamic range of the images, used for the SSIM stabilising constants
        (1.0 for images scaled to [0, 1])
    :return: Function to calculate the per-pixel combined ssim with other loss function
    """

    def combined_ssim_loss(target_img, reprojected_img):
        """
        Calculates the per-pixel photometric reconstruction loss, combining the structural
        dissimilarity between the two images with other_loss_fn.

        Calculates the following:
        ssim_weight * (1 - SSIM(target_img, reprojected_img)) / 2
            + (1 - ssim_weight) * other_loss_fn(target_img, reprojected_img)

        SSIM is a similarity score (1 for identical windows), so it is converted to a
        dissimilarity before weighting; lower is better for both terms.

        :param target_img: Float tensor with shape (batch, height, width, channels) - current
            image we're training on
        :param reprojected_img: Tensor with same shape as target_img, reprojected from some
            source image that should be as close as possible to the target image
        :return: Per-pixel loss -> Tensor with shape (batch, height, width), where height and
            width match target_img height and width
        """
        # tf.image.ssim takes no `axis`/`keepdim` arguments, requires `max_val`, and returns
        # one score per image, so a local-window SSIM map is computed here instead.
        dissimilarity = _ssim_dissimilarity_map(target_img, reprojected_img, max_val)
        # Average over channels so the SSIM term has the same shape as other_loss_fn's output.
        ssim_term = tf.reduce_mean(dissimilarity, axis=3)
        other_term = other_loss_fn(target_img, reprojected_img)
        return ssim_weight * ssim_term + (1 - ssim_weight) * other_term

    return combined_ssim_loss