import math

import torch
import torch.nn as nn


class UpsampleBlock(nn.Module):
    """3x3 conv -> 2x transposed conv -> 3x3 conv; doubles spatial resolution."""

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.conv_in = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1),
            nn.GELU()
        )
        # kernel_size=2 with stride=2 exactly doubles height and width
        self.conv_up = nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
            nn.GELU()
        )

        self.conv_out = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        )

    def forward(self, x):
        x = self.conv_in(x)
        x = self.conv_up(x)
        x = self.conv_out(x)
        return x

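
# For reference (derived from the shape comments in CLIPImagePreProcessor
# below), the two blocks instantiated with the default sizes are:
#   UpsampleBlock(768, 192): (bs, 768, 56, 56)   -> (bs, 192, 112, 112)
#   UpsampleBlock(192, 48):  (bs, 192, 112, 112) -> (bs, 48, 224, 224)
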
class CLIPImagePreProcessor(nn.Module):
    def __init__(
            self,
            input_size=896,
            clip_input_size=224,
            downscale_factor: int = 16,
    ):
        super().__init__()
        # make sure the working size is an exact multiple of the CLIP input size
        assert input_size % clip_input_size == 0
        in_channels = 3

        self.input_size = input_size
        self.clip_input_size = clip_input_size
        self.downscale_factor = downscale_factor

        subpixel_channels = in_channels * downscale_factor ** 2  # 3 * 16 ** 2 = 768
        channels = subpixel_channels

        upscale_factor = downscale_factor / int(input_size / clip_input_size)  # 16 / (896 / 224) = 4

        # each UpsampleBlock doubles the resolution, so log2(upscale_factor)
        # doubling steps bring 56x56 back up to 224x224 (log2(4) = 2 by default)
        num_upsample_blocks = int(math.log2(upscale_factor))

        # build the upsample blocks and their matching PixelUnshuffle residual branches
        self.upsample_blocks = nn.ModuleList()
        self.subpixel_blocks = nn.ModuleList()
        current_channels = channels
        current_downscale = downscale_factor
        for _ in range(num_upsample_blocks):
            # channel count of the input once unshuffled at the new, halved downscale
            output_downscale = current_downscale // 2
            out_channels = in_channels * output_downscale ** 2
            self.upsample_blocks.append(UpsampleBlock(current_channels, out_channels))
            current_channels = out_channels
            current_downscale = output_downscale
            self.subpixel_blocks.append(nn.PixelUnshuffle(current_downscale))

        # (bs, 768, 56, 56) -> (bs, 192, 112, 112)
        # (bs, 192, 112, 112) -> (bs, 48, 224, 224)

        self.conv_out = nn.Conv2d(
            current_channels,
            out_channels=3,
            kernel_size=3,
            padding=1
        )  # (bs, 48, 224, 224) -> (bs, 3, 224, 224)

        # average-pool the input down to the CLIP input size for the residual path
        # (bs, 3, 896, 896) -> (bs, 3, 224, 224)
        kernel_size = input_size // clip_input_size
        self.res_down = nn.AvgPool2d(
            kernel_size=kernel_size,
            stride=kernel_size
        )  # (bs, 3, 896, 896) -> (bs, 3, 224, 224)

        # learned scalar that blends the conv output into the pooled residual,
        # initialized near 0 so the model starts out close to plain average pooling
        self.res_blend = nn.Parameter(torch.tensor(0.001))

        self.unshuffle = nn.PixelUnshuffle(downscale_factor)  # (bs, 3, 896, 896) -> (bs, 768, 56, 56)

        self.conv_in = nn.Sequential(
            nn.Conv2d(
                subpixel_channels,
                channels,
                kernel_size=3,
                padding=1
            ),
            nn.GELU()
        )  # (bs, 768, 56, 56) -> (bs, 768, 56, 56)
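
    # End-to-end shape flow with the defaults (values from the per-layer
    # comments in __init__):
    #   interpolate       (bs, 3, H, W)       -> (bs, 3, 896, 896)
    #   unshuffle         (bs, 3, 896, 896)   -> (bs, 768, 56, 56)
    #   conv_in           (bs, 768, 56, 56)   -> (bs, 768, 56, 56)
    #   upsample block 1  (bs, 768, 56, 56)   -> (bs, 192, 112, 112)
    #   upsample block 2  (bs, 192, 112, 112) -> (bs, 48, 224, 224)
    #   conv_out          (bs, 48, 224, 224)  -> (bs, 3, 224, 224)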
    def forward(self, x):
        # resize to input_size x input_size first so that the pooled and
        # unshuffled residual branches line up with the feature maps
        x = nn.functional.interpolate(x, size=(self.input_size, self.input_size), mode='bicubic')
        inputs = x

        # pooled residual at the CLIP input size
        res = self.res_down(inputs)

        x = self.unshuffle(x)
        x = self.conv_in(x)
        for up, subpixel in zip(self.upsample_blocks, self.subpixel_blocks):
            x = up(x)
            # add an unshuffled view of the input at the matching scale
            block_res = subpixel(inputs)
            x = x + block_res
        x = self.conv_out(x)
        # blend the learned output into the pooled residual
        x = x * self.res_blend + res
        return x
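

if __name__ == "__main__":
    # Minimal smoke test -- an illustrative sketch, not part of the upstream
    # module. Any RGB input is resized to 896x896 internally and mapped to a
    # CLIP-sized output.
    model = CLIPImagePreProcessor()
    out = model(torch.randn(2, 3, 512, 512))
    print(out.shape)  # expected: torch.Size([2, 3, 224, 224])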