# Origin: git patch e372fe33ba47f81972b88562719296fe6a00a6c8
# From: Piv <18462828+Piv200@users.noreply.github.com>
# Date: Tue, 13 Jul 2021 20:32:45 +0930
# Subject: [PATCH] Add per-pixel loss functions
# File: unsupervised/loss.py
"""Per-pixel loss functions comparing a target image with a reprojected image."""
import tensorflow as tf

# Added under the square root in l2_loss: the gradient of sqrt is unbounded at 0,
# which would produce NaN gradients wherever the two images agree exactly.
_L2_EPSILON = 1e-12


def l1_loss(target_img, reprojected_img):
    """
    Calculates the per-pixel l1 difference between the target and reprojected image,
    averaged over the channel axis.

    :param target_img: Tensor (batch, height, width, 3)
    :param reprojected_img: Tensor, same shape as target_img
    :return: The per-pixel mean absolute difference -> Tensor (batch, height, width, 1)
    """
    # keepdims=True keeps the trailing channel axis so the result is
    # (batch, height, width, 1) as documented, not (batch, height, width).
    return tf.reduce_mean(tf.abs(target_img - reprojected_img), axis=3, keepdims=True)


def l2_loss(target_img, reprojected_img):
    """
    Calculates the per-pixel l2 difference between the target and reprojected image.

    The squared differences are averaged over the channel axis before the square root
    is taken (root-mean-square, i.e. the l2 norm scaled by 1 / sqrt(channels)), which
    mirrors the channel averaging done by l1_loss.

    :param target_img: Tensor (batch, height, width, 3)
    :param reprojected_img: Tensor, same shape as target_img
    :return: The per-pixel root-mean-square difference -> Tensor (batch, height, width, 1)
    """
    # NOTE: `x ** 2 ** (1 / 2)` parses as `x ** (2 ** 0.5)` because `**` is
    # right-associative, so the square and the root are spelled out explicitly.
    squared_error = tf.square(target_img - reprojected_img)
    mean_squared_error = tf.reduce_mean(squared_error, axis=3, keepdims=True)
    return tf.sqrt(mean_squared_error + _L2_EPSILON)


def _ssim_dissimilarity(target_img, reprojected_img):
    """
    Calculates the per-pixel structural dissimilarity (1 - SSIM) / 2 over 3x3 windows.

    tf.image.ssim is not used here because it reduces to one score per image; this
    computes the windowed SSIM statistics directly so the output keeps the input's
    height and width.

    :param target_img: Tensor (batch, height, width, channels)
    :param reprojected_img: Tensor, same shape as target_img
    :return: Tensor (batch, height, width, 1) with values in [0, 1], 0 meaning the
        local windows are identical
    """
    # Standard SSIM stabilising constants (k1=0.01, k2=0.03) for a dynamic range of 1.
    # NOTE(review): this assumes image values are scaled to [0, 1] - confirm against
    # the input pipeline.
    c1 = 0.01 ** 2
    c2 = 0.03 ** 2

    def local_mean(img):
        # Reflect-pad by one pixel so the 3x3 VALID pooling returns the input size.
        padded = tf.pad(img, [[0, 0], [1, 1], [1, 1], [0, 0]], mode="REFLECT")
        return tf.nn.avg_pool2d(padded, ksize=3, strides=1, padding="VALID")

    mu_x = local_mean(target_img)
    mu_y = local_mean(reprojected_img)
    sigma_x = local_mean(target_img * target_img) - mu_x * mu_x
    sigma_y = local_mean(reprojected_img * reprojected_img) - mu_y * mu_y
    sigma_xy = local_mean(target_img * reprojected_img) - mu_x * mu_y

    numerator = (2 * mu_x * mu_y + c1) * (2 * sigma_xy + c2)
    denominator = (mu_x * mu_x + mu_y * mu_y + c1) * (sigma_x + sigma_y + c2)
    dissimilarity = tf.clip_by_value((1 - numerator / denominator) / 2, 0.0, 1.0)

    # Average over channels so the shape matches the other per-pixel losses.
    return tf.reduce_mean(dissimilarity, axis=3, keepdims=True)


def make_combined_ssim_l1_loss(ssim_weight: float = 0.85, other_loss_fn=l1_loss):
    """
    Create a loss function that combines a per-pixel SSIM term with another per-pixel
    loss (l1 by default).

    :param ssim_weight: Weighting applied to the SSIM term; the other loss is weighted
        by (1 - ssim_weight)
    :param other_loss_fn: Function taking (target_img, reprojected_img) and returning a
        per-pixel loss Tensor (batch, height, width, 1) to combine with the SSIM term
    :return: Function to calculate the per-pixel combined ssim with other loss function
    """

    def combined_ssim_loss(target_img, reprojected_img):
        """
        Calculates the per-pixel photometric reconstruction loss between the target
        image and a reprojected image, combined with their structural dissimilarity.

        Calculates the following:
        ssim_weight * (1 - SSIM(target_img, reprojected_img)) / 2
            + (1 - ssim_weight) * other_loss_fn(target_img, reprojected_img)

        SSIM is a similarity (1 for identical windows), so it enters the loss as
        (1 - SSIM) / 2 to make lower values mean a better reconstruction.

        :param target_img: Tensor with shape (batch, height, width, 3) - current image
            we're training on
        :param reprojected_img: Tensor with same shape as target_img, reprojected from
            some source image that should be as close as possible to the target image
        :return: Per-pixel loss -> Tensor with shape (batch, height, width, 1), where
            height and width match target_img height and width
        """
        ssim_term = _ssim_dissimilarity(target_img, reprojected_img)
        other_term = other_loss_fn(target_img, reprojected_img)
        return ssim_weight * ssim_term + (1 - ssim_weight) * other_term

    return combined_ssim_loss