mirror of
https://github.com/SillyTavern/SillyTavern-Extras.git
synced 2026-03-11 22:40:03 +00:00
59 lines
2.2 KiB
Python
59 lines
2.2 KiB
Python
import torch
|
|
from torch import Tensor
|
|
from torch.nn.functional import affine_grid, grid_sample
|
|
|
|
|
|
def apply_rgb_change(alpha: Tensor, color_change: Tensor, image: Tensor):
|
|
image_rgb = image[:, 0:3, :, :]
|
|
color_change_rgb = color_change[:, 0:3, :, :]
|
|
output_rgb = color_change_rgb * alpha + image_rgb * (1 - alpha)
|
|
return torch.cat([output_rgb, image[:, 3:4, :, :]], dim=1)
|
|
|
|
|
|
def apply_grid_change(grid_change, image: Tensor) -> Tensor:
|
|
n, c, h, w = image.shape
|
|
device = grid_change.device
|
|
grid_change = torch.transpose(grid_change.view(n, 2, h * w), 1, 2).view(n, h, w, 2)
|
|
identity = torch.tensor(
|
|
[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]],
|
|
dtype=grid_change.dtype,
|
|
device=device).unsqueeze(0).repeat(n, 1, 1)
|
|
base_grid = affine_grid(identity, [n, c, h, w], align_corners=False)
|
|
grid = base_grid + grid_change
|
|
resampled_image = grid_sample(image, grid, mode='bilinear', padding_mode='border', align_corners=False)
|
|
return resampled_image
|
|
|
|
|
|
class GridChangeApplier:
|
|
def __init__(self):
|
|
self.last_n = None
|
|
self.last_device = None
|
|
self.last_identity = None
|
|
|
|
def apply(self, grid_change: Tensor, image: Tensor, align_corners: bool = False) -> Tensor:
|
|
n, c, h, w = image.shape
|
|
device = grid_change.device
|
|
grid_change = torch.transpose(grid_change.view(n, 2, h * w), 1, 2).view(n, h, w, 2)
|
|
|
|
if n == self.last_n and device == self.last_device:
|
|
identity = self.last_identity
|
|
else:
|
|
identity = torch.tensor(
|
|
[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]],
|
|
dtype=grid_change.dtype,
|
|
device=device,
|
|
requires_grad=False) \
|
|
.unsqueeze(0).repeat(n, 1, 1)
|
|
self.last_identity = identity
|
|
self.last_n = n
|
|
self.last_device = device
|
|
base_grid = affine_grid(identity, [n, c, h, w], align_corners=align_corners)
|
|
|
|
grid = base_grid + grid_change
|
|
resampled_image = grid_sample(image, grid, mode='bilinear', padding_mode='border', align_corners=align_corners)
|
|
return resampled_image
|
|
|
|
|
|
def apply_color_change(alpha, color_change, image: Tensor) -> Tensor:
|
|
return color_change * alpha + image * (1 - alpha)
|