mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-03 01:29:50 +00:00
24 lines
730 B
Python
24 lines
730 B
Python
import torch
|
|
|
|
|
|
def total_variation(image):
    """
    Compute normalized anisotropic (L1) total variation.

    Inputs:
    - image: tensor whose last three dimensions are (C, H, W), e.g. a batch
      of shape (N, C, H, W) or a single image of shape (C, H, W).

    Returns:
    - TV: scalar tensor equal to the sum of absolute differences between
      horizontally and vertically adjacent pixels, divided by C * H * W.
      The sum runs over every leading (batch) dimension but the divisor does
      not, so for a batch the result is the sum of the per-image values, not
      their mean.
    """
    # Per-image element count; indexing from the end lets the same code serve
    # both (C, H, W) and (N, C, H, W) inputs. For 4-D input this equals
    # image.shape[1] * image.shape[2] * image.shape[3].
    n_elements = image.shape[-3] * image.shape[-2] * image.shape[-1]
    # Absolute differences between horizontally adjacent pixels (along W).
    horizontal = torch.abs(image[..., :, :-1] - image[..., :, 1:])
    # Absolute differences between vertically adjacent pixels (along H).
    vertical = torch.abs(image[..., :-1, :] - image[..., 1:, :])
    return (torch.sum(horizontal) + torch.sum(vertical)) / n_elements
|
|
|
|
|
|
class ComparativeTotalVariation(torch.nn.Module):
    """
    Loss that penalizes the difference in total variation between two images.

    The prediction is pushed toward having the same total variation as the
    target. Both inputs are passed to `total_variation` and the absolute
    difference of the two results is returned.
    """

    def forward(self, pred, target):
        # Evaluate pred first, then target, as the original expression did.
        pred_tv = total_variation(pred)
        target_tv = total_variation(target)
        return (pred_tv - target_tv).abs()
|