Add per-pixel loss functions
This commit is contained in:
53
unsupervised/loss.py
Normal file
53
unsupervised/loss.py
Normal file
@@ -0,0 +1,53 @@
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
def l1_loss(target_img, reprojected_img):
    """
    Calculates the per-pixel l1 difference between the target and reprojected image.

    The absolute difference is averaged over the channel axis (axis 3). The channel
    axis is kept with size 1 so the result has the documented rank-4 shape and can
    be combined with other per-pixel maps of shape (batch, height, width, 1).

    :param target_img: Tensor (batch, height, width, 3)
    :param reprojected_img: Tensor, same shape as target_img
    :return: The per-pixel l1 difference -> Tensor (batch, height, width, 1)
    """
    abs_diff = tf.abs(target_img - reprojected_img)
    # keepdims=True: without it the channel axis is dropped and the result is
    # (batch, height, width), which contradicts the documented return shape.
    return tf.reduce_mean(abs_diff, axis=3, keepdims=True)
|
||||
|
||||
|
||||
def l2_loss(target_img, reprojected_img):
    """
    Calculates the per-pixel l2 difference between the target and reprojected image.

    For every pixel this is the square root of the channel-wise mean of the squared
    difference (averaging over channels mirrors what l1_loss does). The channel
    axis is kept with size 1.

    NOTE: the derivative of sqrt is unbounded at 0, so pixels where both images are
    exactly equal can yield non-finite gradients; add a small epsilon inside the
    square root if that matters for training.

    :param target_img: Tensor (batch, height, width, 3)
    :param reprojected_img: Tensor, same shape as target_img
    :return: The per-pixel l2 difference -> Tensor (batch, height, width, 1)
    """
    # The previous expression `diff ** 2 ** (1 / 2)` parsed as `diff ** (2 ** 0.5)`
    # because ** is right-associative, which produces NaN for negative differences.
    # Square first, reduce over channels, then take the root.
    squared_diff = tf.square(target_img - reprojected_img)
    return tf.sqrt(tf.reduce_mean(squared_diff, axis=3, keepdims=True))
|
||||
|
||||
|
||||
def _per_pixel_ssim_dissimilarity(target_img, reprojected_img, max_val=1.0):
    """
    Calculates a per-pixel structural dissimilarity map, (1 - SSIM) / 2, clipped to [0, 1].

    SSIM statistics are computed over 3x3 windows using average pooling. Both images
    are reflect-padded by one pixel first so the output keeps the input height and
    width (this requires height and width of at least 2).

    :param target_img: Tensor (batch, height, width, channels), floating point
    :param reprojected_img: Tensor, same shape as target_img
    :param max_val: Dynamic range of the pixel values, used to scale the SSIM stabilising constants
    :return: Tensor (batch, height, width, 1); 0 means the local windows are identical
    """
    # Standard SSIM stabilising constants (k1 = 0.01, k2 = 0.03), scaled by the range.
    c1 = (0.01 * max_val) ** 2
    c2 = (0.03 * max_val) ** 2

    paddings = [[0, 0], [1, 1], [1, 1], [0, 0]]
    x = tf.pad(target_img, paddings, mode="REFLECT")
    y = tf.pad(reprojected_img, paddings, mode="REFLECT")

    def local_mean(tensor):
        # 3x3 box filter; VALID padding undoes the one pixel of reflect padding.
        return tf.nn.avg_pool2d(tensor, ksize=3, strides=1, padding="VALID")

    mu_x = local_mean(x)
    mu_y = local_mean(y)
    sigma_x = local_mean(x * x) - mu_x * mu_x
    sigma_y = local_mean(y * y) - mu_y * mu_y
    sigma_xy = local_mean(x * y) - mu_x * mu_y

    numerator = (2.0 * mu_x * mu_y + c1) * (2.0 * sigma_xy + c2)
    denominator = (mu_x * mu_x + mu_y * mu_y + c1) * (sigma_x + sigma_y + c2)
    ssim = numerator / denominator

    # SSIM lies in [-1, 1] with 1 meaning identical, so map it onto a [0, 1] loss.
    dissimilarity = tf.clip_by_value((1.0 - ssim) / 2.0, 0.0, 1.0)
    return tf.reduce_mean(dissimilarity, axis=3, keepdims=True)


def make_combined_ssim_l1_loss(ssim_weight: float = 0.85, other_loss_fn=l1_loss, max_val: float = 1.0):
    """
    Create a loss function that combines a per-pixel SSIM term with the per-pixel
    loss calculated by other_loss_fn.

    :param ssim_weight: Weighting in [0, 1] applied to the SSIM term; the other loss is weighted
        by (1 - ssim_weight)
    :param other_loss_fn: Function (target_img, reprojected_img) -> per-pixel loss to combine with
        the SSIM term. It may return (batch, height, width, 1) or (batch, height, width)
    :param max_val: Dynamic range of the image values used for the SSIM constants. The default of
        1.0 assumes images scaled to [0, 1] - TODO confirm against the data pipeline
    :return: Function to calculate the per-pixel combined ssim with other loss function
    """

    def combined_ssim_loss(target_img, reprojected_img):
        """
        Calculates the per-pixel photometric reconstruction loss between the target image and
        an image reprojected from some source image.

        Calculates the following:
        ssim_weight * (1 - SSIM(target_img, reprojected_img)) / 2
            + (1 - ssim_weight) * other_loss_fn(target_img, reprojected_img)

        SSIM is a similarity measure (1 for identical windows), so it is converted to a
        dissimilarity before being added to the loss; otherwise minimising the loss would
        push the images apart.

        :param target_img: Tensor with shape (batch, height, width, 3) - current image we're training on
        :param reprojected_img: Tensor with same shape as target_img, reprojected from some source image
            that should be as close as possible to the target image
        :return: Per-pixel loss -> Tensor with shape (batch, height, width, 1), where height and width
            match target_img height and width
        """
        # tf.image.ssim has no `axis`/`keepdim` arguments, requires `max_val`, and does not
        # return a map with the input's height and width, so SSIM is computed locally here.
        ssim_term = _per_pixel_ssim_dissimilarity(target_img, reprojected_img, max_val)

        other_term = other_loss_fn(target_img, reprojected_img)
        if other_term.shape.rank == 3:
            # Restore the channel axis so both terms align as (batch, height, width, 1)
            # instead of broadcasting against each other.
            other_term = tf.expand_dims(other_term, axis=-1)

        return ssim_weight * ssim_term + (1 - ssim_weight) * other_term

    return combined_ssim_loss
|
||||
Reference in New Issue
Block a user