# Mirror of https://github.com/SillyTavern/SillyTavern-Extras.git
# Synced 2026-03-12 06:50:09 +00:00 (181 lines, 6.6 KiB, Python)
from typing import Optional, List
|
|
|
|
import torch
|
|
from matplotlib import pyplot
|
|
from torch import Tensor
|
|
from torch.nn import Module, Sequential, Tanh, Sigmoid
|
|
|
|
from tha3.nn.image_processing_util import GridChangeApplier, apply_color_change
|
|
from tha3.nn.common.resize_conv_unet import ResizeConvUNet, ResizeConvUNetArgs
|
|
from tha3.util import numpy_linear_to_srgb
|
|
from tha3.module.module_factory import ModuleFactory
|
|
from tha3.nn.conv import create_conv3_from_block_args, create_conv3
|
|
from tha3.nn.nonlinearity_factory import ReLUFactory
|
|
from tha3.nn.normalization import InstanceNorm2dFactory
|
|
from tha3.nn.util import BlockArgs
|
|
|
|
|
|
class Editor07Args:
    """Plain settings holder used to construct an ``Editor07`` module.

    Every constructor argument is stored unchanged on the instance under the
    same name. ``Editor07.__init__`` and ``Editor07.forward`` read these
    attributes; nothing is validated here.
    """

    def __init__(self,
                 image_size: int = 512,
                 image_channels: int = 4,
                 num_pose_params: int = 6,
                 start_channels: int = 32,
                 bottleneck_image_size: int = 32,
                 num_bottleneck_blocks: int = 6,
                 max_channels: int = 512,
                 upsampling_mode: str = 'nearest',
                 block_args: Optional[BlockArgs] = None,
                 use_separable_convolution: bool = False):
        """Store the configuration values.

        Args:
            image_size: Forwarded to ``ResizeConvUNetArgs(image_size=...)``;
                ``Editor07.forward`` also tiles the pose vector to
                ``image_size x image_size``.
            image_channels: Channel count of each input image; also the
                output channel count of the color-change and alpha heads.
            num_pose_params: Length of the pose vector; contributes that many
                channels to the U-Net input.
            start_channels: Forwarded to the U-Net; also the input channel
                count of the three output heads of ``Editor07``.
            bottleneck_image_size: Forwarded to ``ResizeConvUNetArgs``.
            num_bottleneck_blocks: Forwarded to ``ResizeConvUNetArgs``.
            max_channels: Forwarded to ``ResizeConvUNetArgs``.
            upsampling_mode: Forwarded to ``ResizeConvUNetArgs`` as
                ``upsample_mode``.
            block_args: Block settings shared by the U-Net and the
                color/alpha heads. When ``None``, a ``BlockArgs`` using
                ``InstanceNorm2dFactory()`` and ``ReLUFactory(inplace=False)``
                is created.
            use_separable_convolution: Forwarded to ``ResizeConvUNetArgs``.
        """
        if block_args is None:
            # Created per call rather than as a default argument value, so
            # separate instances never share one default BlockArgs object.
            block_args = BlockArgs(
                normalization_layer_factory=InstanceNorm2dFactory(),
                nonlinearity_factory=ReLUFactory(inplace=False))

        self.block_args = block_args
        self.upsampling_mode = upsampling_mode
        self.max_channels = max_channels
        self.num_bottleneck_blocks = num_bottleneck_blocks
        self.bottleneck_image_size = bottleneck_image_size
        self.start_channels = start_channels
        self.num_pose_params = num_pose_params
        self.image_channels = image_channels
        self.image_size = image_size
        self.use_separable_convolution = use_separable_convolution
|
|
|
|
|
|
class Editor07(Module):
    """Network that refines a warped image given the original image and a pose.

    A ``ResizeConvUNet`` body produces a feature map from which three heads
    derive a color change (``Tanh``), an alpha mask (``Sigmoid``) and a
    2-channel correction that is added to the incoming grid change.

    The list returned by ``forward`` is indexed by the ``*_INDEX`` class
    constants defined at the bottom of this class.
    """

    def __init__(self, args: Editor07Args):
        """Build the U-Net body and the three output heads from ``args``."""
        super().__init__()
        self.args = args

        self.body = ResizeConvUNet(ResizeConvUNetArgs(
            image_size=args.image_size,
            # Must match the concatenation in forward(): two images, the
            # pose channels, and the 2-channel grid change.
            input_channels=2 * args.image_channels + args.num_pose_params + 2,
            start_channels=args.start_channels,
            bottleneck_image_size=args.bottleneck_image_size,
            num_bottleneck_blocks=args.num_bottleneck_blocks,
            max_channels=args.max_channels,
            upsample_mode=args.upsampling_mode,
            block_args=args.block_args,
            use_separable_convolution=args.use_separable_convolution))
        self.color_change_creator = Sequential(
            create_conv3_from_block_args(
                in_channels=self.args.start_channels,
                out_channels=self.args.image_channels,
                bias=True,
                block_args=self.args.block_args),
            Tanh())
        self.alpha_creator = Sequential(
            create_conv3_from_block_args(
                in_channels=self.args.start_channels,
                out_channels=self.args.image_channels,
                bias=True,
                block_args=self.args.block_args),
            Sigmoid())
        # Requested as zero-initialised and bias-free, so the predicted grid
        # offset presumably starts at zero -- confirm against create_conv3.
        self.grid_change_creator = create_conv3(
            in_channels=self.args.start_channels,
            out_channels=2,
            bias=False,
            initialization_method='zero',
            use_spectral_norm=False)
        self.grid_change_applier = GridChangeApplier()

    def forward(self,
                input_original_image: Tensor,
                input_warped_image: Tensor,
                input_grid_change: Tensor,
                pose: Tensor,
                *args) -> List[Tensor]:
        """Run the editor.

        Args:
            input_original_image: Original image batch; also the image that
                the refined grid change is applied to.
            input_warped_image: Warped image batch; used only as network
                input.
            input_grid_change: Incoming grid change; the 2-channel predicted
                correction is added to it.
            pose: 2-D tensor (batch, pose channels); tiled over
                ``args.image_size x args.image_size`` before concatenation.
            *args: Ignored.

        Returns:
            A list of five tensors, in the order given by the ``*_INDEX``
            constants: color-changed image, color-change alpha, color
            change, warped image, grid change.
        """
        n, c = pose.shape
        pose = pose.view(n, c, 1, 1).repeat(1, 1, self.args.image_size, self.args.image_size)
        feature = torch.cat([input_original_image, input_warped_image, input_grid_change, pose], dim=1)

        # Call the module instead of .forward() so that any registered
        # hooks run; only the last element of the body's output is used.
        feature = self.body(feature)[-1]
        output_grid_change = input_grid_change + self.grid_change_creator(feature)

        output_color_change = self.color_change_creator(feature)
        output_color_change_alpha = self.alpha_creator(feature)
        # The warped output is recomputed from the original image with the
        # refined grid change; input_warped_image is not reused here.
        output_warped_image = self.grid_change_applier.apply(output_grid_change, input_original_image)
        output_color_changed = apply_color_change(output_color_change_alpha, output_color_change, output_warped_image)

        return [
            output_color_changed,
            output_color_change_alpha,
            output_color_change,
            output_warped_image,
            output_grid_change,
        ]

    # Positions of each tensor in the list returned by forward().
    COLOR_CHANGED_IMAGE_INDEX = 0
    COLOR_CHANGE_ALPHA_INDEX = 1
    COLOR_CHANGE_IMAGE_INDEX = 2
    WARPED_IMAGE_INDEX = 3
    GRID_CHANGE_INDEX = 4
    OUTPUT_LENGTH = 5
|
|
|
|
|
|
class Editor07Factory(ModuleFactory):
    """``ModuleFactory`` that builds ``Editor07`` modules from stored args."""

    def __init__(self, args: Editor07Args):
        """Remember the ``Editor07Args`` used for every created module."""
        super().__init__()
        self.args = args

    def create(self) -> Module:
        """Return a new ``Editor07`` constructed from ``self.args``."""
        return Editor07(self.args)
|
|
|
|
|
|
def show_image(pytorch_image: Tensor) -> None:
    """Display one image tensor in a matplotlib window.

    The tensor is mapped with ``(x + 1) / 2`` (so ``[-1, 1]`` becomes
    ``[0, 1]``), its leading dimension is squeezed if it has size 1, the
    first three channels are passed through ``numpy_linear_to_srgb``, and
    the result is shown channel-last with ``pyplot.imshow``.

    Args:
        pytorch_image: Image tensor that is 3-D after ``squeeze(0)``, with
            channels first. Presumably RGBA with linear-light color in
            ``[-1, 1]`` -- confirm against callers.
    """
    # Tensor.numpy() raises for CUDA tensors and for tensors that require
    # grad; detach() and cpu() make both cases work and are no-ops otherwise.
    numpy_image = ((pytorch_image + 1.0) / 2.0).squeeze(0).detach().cpu().numpy()
    # The arithmetic above produced a fresh tensor, so this in-place write
    # never touches the caller's data.
    numpy_image[0:3, :, :] = numpy_linear_to_srgb(numpy_image[0:3, :, :])
    # Channels-first (C, H, W) -> channels-last (H, W, C) for imshow.
    numpy_image = numpy_image.transpose(1, 2, 0)
    pyplot.imshow(numpy_image)
    pyplot.show()
|
|
|
|
|
|
if __name__ == "__main__":
    # Micro-benchmark: time Editor07's forward pass on a CUDA device using
    # all-zero inputs, and print each timing plus the average.
    cuda = torch.device('cuda')

    image_size = 512
    image_channels = 4
    num_pose_params = 6
    # Build the args from the locals above so the module configuration and
    # the input tensor shapes below cannot drift apart.
    args = Editor07Args(
        image_size=image_size,
        image_channels=image_channels,
        start_channels=32,
        num_pose_params=num_pose_params,
        bottleneck_image_size=32,
        num_bottleneck_blocks=6,
        max_channels=512,
        upsampling_mode='nearest',
        block_args=BlockArgs(
            initialization_method='he',
            use_spectral_norm=False,
            normalization_layer_factory=InstanceNorm2dFactory(),
            nonlinearity_factory=ReLUFactory(inplace=False)))
    module = Editor07(args).to(cuda)

    image_count = 1
    input_image = torch.zeros(image_count, image_channels, image_size, image_size, device=cuda)
    warped_image = torch.zeros(image_count, image_channels, image_size, image_size, device=cuda)
    grid_change = torch.zeros(image_count, 2, image_size, image_size, device=cuda)
    pose = torch.zeros(image_count, num_pose_params, device=cuda)

    repeat = 100
    # The first iterations are run but excluded from the timings.
    warmup = 2
    acc = 0.0
    for i in range(repeat + warmup):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)

        start.record()
        module(input_image, warped_image, grid_change, pose)
        end.record()
        # Wait for the GPU so both events have completed before reading them.
        torch.cuda.synchronize()
        if i >= warmup:
            # elapsed_time() reports milliseconds.
            elapsed_time = start.elapsed_time(end)
            print("%d:" % i, elapsed_time)
            acc = acc + elapsed_time

    print("average:", acc / repeat)
|