mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-27 09:44:02 +00:00
Bug fixes with ip adapter training. Made a clip pre processor that can be trained with ip adapter to help augment the clip input to squeeze in more detail from a larget input. moved clip processing to the dataloader for speed.
This commit is contained in:
113
toolkit/models/clip_pre_processor.py
Normal file
113
toolkit/models/clip_pre_processor.py
Normal file
@@ -0,0 +1,113 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class UpsampleBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.conv_in = nn.Sequential(
|
||||
nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1),
|
||||
nn.GELU()
|
||||
)
|
||||
self.conv_up = nn.Sequential(
|
||||
nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
|
||||
nn.GELU()
|
||||
)
|
||||
|
||||
self.conv_out = nn.Sequential(
|
||||
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv_in(x)
|
||||
x = self.conv_up(x)
|
||||
x = self.conv_out(x)
|
||||
return x
|
||||
|
||||
|
||||
class CLIPImagePreProcessor(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_size=672,
|
||||
clip_input_size=224,
|
||||
downscale_factor: int = 6,
|
||||
channels=None, # 108
|
||||
):
|
||||
super().__init__()
|
||||
# make sure they are evenly divisible
|
||||
assert input_size % clip_input_size == 0
|
||||
in_channels = 3
|
||||
|
||||
self.input_size = input_size
|
||||
self.clip_input_size = clip_input_size
|
||||
self.downscale_factor = downscale_factor
|
||||
|
||||
subpixel_channels = in_channels * downscale_factor ** 2 # 3 * 6 ** 2 = 108
|
||||
|
||||
if channels is None:
|
||||
channels = subpixel_channels
|
||||
|
||||
upscale_factor = downscale_factor / int((input_size / clip_input_size)) # 6 / (672 / 224) = 2
|
||||
|
||||
num_upsample_blocks = int(upscale_factor // 2) # 2 // 2 = 1
|
||||
|
||||
# do a pooling layer to downscale the input to 1/3 of the size
|
||||
# (bs, 3, 672, 672) -> (bs, 3, 224, 224)
|
||||
kernel_size = input_size // clip_input_size
|
||||
self.res_down = nn.AvgPool2d(
|
||||
kernel_size=kernel_size,
|
||||
stride=kernel_size
|
||||
) # (bs, 3, 672, 672) -> (bs, 3, 224, 224)
|
||||
|
||||
# make a blending for output residual with near 0 weight
|
||||
self.res_blend = nn.Parameter(torch.tensor(0.001)) # (bs, 3, 224, 224) -> (bs, 3, 224, 224)
|
||||
|
||||
self.unshuffle = nn.PixelUnshuffle(downscale_factor) # (bs, 3, 672, 672) -> (bs, 108, 112, 112)
|
||||
|
||||
self.conv_in = nn.Sequential(
|
||||
nn.Conv2d(
|
||||
subpixel_channels,
|
||||
channels,
|
||||
kernel_size=3,
|
||||
padding=1
|
||||
),
|
||||
nn.GELU()
|
||||
) # (bs, 108, 112, 112) -> (bs, 108, 112, 112)
|
||||
|
||||
self.upsample_blocks = nn.ModuleList()
|
||||
current_channels = channels
|
||||
for _ in range(num_upsample_blocks):
|
||||
out_channels = current_channels // 2
|
||||
self.upsample_blocks.append(UpsampleBlock(current_channels, out_channels))
|
||||
current_channels = out_channels
|
||||
|
||||
# (bs, 108, 112, 112) -> (bs, 54, 224, 224)
|
||||
|
||||
self.conv_out = nn.Conv2d(
|
||||
current_channels,
|
||||
out_channels=3,
|
||||
kernel_size=3,
|
||||
padding=1
|
||||
) # (bs, 54, 224, 224) -> (bs, 3, 224, 224)
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
# resize to input_size x input_size
|
||||
x = nn.functional.interpolate(x, size=(self.input_size, self.input_size), mode='bicubic')
|
||||
|
||||
res = self.res_down(x)
|
||||
|
||||
x = self.unshuffle(x)
|
||||
x = self.conv_in(x)
|
||||
for up in self.upsample_blocks:
|
||||
x = up(x)
|
||||
x = self.conv_out(x)
|
||||
# blend residual
|
||||
x = x * self.res_blend + res
|
||||
return x
|
||||
Reference in New Issue
Block a user