mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-14 06:57:35 +00:00
Adjustments to the CLIP preprocessor. Allow merging in new weights for IP adapters so you can change the architecture while maintaining as much data as possible.
This commit is contained in:
@@ -34,10 +34,9 @@ class UpsampleBlock(nn.Module):
|
||||
class CLIPImagePreProcessor(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_size=672,
|
||||
input_size=896,
|
||||
clip_input_size=224,
|
||||
downscale_factor: int = 6,
|
||||
channels=None, # 108
|
||||
downscale_factor: int = 16,
|
||||
):
|
||||
super().__init__()
|
||||
# make sure they are evenly divisible
|
||||
@@ -48,27 +47,50 @@ class CLIPImagePreProcessor(nn.Module):
|
||||
self.clip_input_size = clip_input_size
|
||||
self.downscale_factor = downscale_factor
|
||||
|
||||
subpixel_channels = in_channels * downscale_factor ** 2 # 3 * 6 ** 2 = 108
|
||||
subpixel_channels = in_channels * downscale_factor ** 2 # 3 * 16 ** 2 = 768
|
||||
channels = subpixel_channels
|
||||
|
||||
if channels is None:
|
||||
channels = subpixel_channels
|
||||
upscale_factor = downscale_factor / int((input_size / clip_input_size)) # 16 / (896 / 224) = 4
|
||||
|
||||
upscale_factor = downscale_factor / int((input_size / clip_input_size)) # 6 / (672 / 224) = 2
|
||||
num_upsample_blocks = int(upscale_factor // 2) # 4 // 2 = 2
|
||||
|
||||
num_upsample_blocks = int(upscale_factor // 2) # 2 // 2 = 1
|
||||
# make the residual down up blocks
|
||||
self.upsample_blocks = nn.ModuleList()
|
||||
self.subpixel_blocks = nn.ModuleList()
|
||||
current_channels = channels
|
||||
current_downscale = downscale_factor
|
||||
for _ in range(num_upsample_blocks):
|
||||
# determine the reshuffled channel count for this dimension
|
||||
output_downscale = current_downscale // 2
|
||||
out_channels = in_channels * output_downscale ** 2
|
||||
# out_channels = current_channels // 2
|
||||
self.upsample_blocks.append(UpsampleBlock(current_channels, out_channels))
|
||||
current_channels = out_channels
|
||||
current_downscale = output_downscale
|
||||
self.subpixel_blocks.append(nn.PixelUnshuffle(current_downscale))
|
||||
|
||||
# (bs, 768, 56, 56) -> (bs, 192, 112, 112)
|
||||
# (bs, 192, 112, 112) -> (bs, 48, 224, 224)
|
||||
|
||||
self.conv_out = nn.Conv2d(
|
||||
current_channels,
|
||||
out_channels=3,
|
||||
kernel_size=3,
|
||||
padding=1
|
||||
) # (bs, 48, 224, 224) -> (bs, 3, 224, 224)
|
||||
|
||||
# do a pooling layer to downscale the input to 1/3 of the size
|
||||
# (bs, 3, 672, 672) -> (bs, 3, 224, 224)
|
||||
# (bs, 3, 896, 896) -> (bs, 3, 224, 224)
|
||||
kernel_size = input_size // clip_input_size
|
||||
self.res_down = nn.AvgPool2d(
|
||||
kernel_size=kernel_size,
|
||||
stride=kernel_size
|
||||
) # (bs, 3, 672, 672) -> (bs, 3, 224, 224)
|
||||
) # (bs, 3, 896, 896) -> (bs, 3, 224, 224)
|
||||
|
||||
# make a blending for output residual with near 0 weight
|
||||
self.res_blend = nn.Parameter(torch.tensor(0.001)) # (bs, 3, 224, 224) -> (bs, 3, 224, 224)
|
||||
|
||||
self.unshuffle = nn.PixelUnshuffle(downscale_factor) # (bs, 3, 672, 672) -> (bs, 108, 112, 112)
|
||||
self.unshuffle = nn.PixelUnshuffle(downscale_factor) # (bs, 3, 896, 896) -> (bs, 768, 56, 56)
|
||||
|
||||
self.conv_in = nn.Sequential(
|
||||
nn.Conv2d(
|
||||
@@ -78,35 +100,23 @@ class CLIPImagePreProcessor(nn.Module):
|
||||
padding=1
|
||||
),
|
||||
nn.GELU()
|
||||
) # (bs, 108, 112, 112) -> (bs, 108, 112, 112)
|
||||
|
||||
self.upsample_blocks = nn.ModuleList()
|
||||
current_channels = channels
|
||||
for _ in range(num_upsample_blocks):
|
||||
out_channels = current_channels // 2
|
||||
self.upsample_blocks.append(UpsampleBlock(current_channels, out_channels))
|
||||
current_channels = out_channels
|
||||
|
||||
# (bs, 108, 112, 112) -> (bs, 54, 224, 224)
|
||||
|
||||
self.conv_out = nn.Conv2d(
|
||||
current_channels,
|
||||
out_channels=3,
|
||||
kernel_size=3,
|
||||
padding=1
|
||||
) # (bs, 54, 224, 224) -> (bs, 3, 224, 224)
|
||||
) # (bs, 768, 56, 56) -> (bs, 768, 56, 56)
|
||||
|
||||
# make 2 deep blocks
|
||||
|
||||
def forward(self, x):
|
||||
inputs = x
|
||||
# resize to input_size x input_size
|
||||
x = nn.functional.interpolate(x, size=(self.input_size, self.input_size), mode='bicubic')
|
||||
|
||||
res = self.res_down(x)
|
||||
res = self.res_down(inputs)
|
||||
|
||||
x = self.unshuffle(x)
|
||||
x = self.conv_in(x)
|
||||
for up in self.upsample_blocks:
|
||||
for up, subpixel in zip(self.upsample_blocks, self.subpixel_blocks):
|
||||
x = up(x)
|
||||
block_res = subpixel(inputs)
|
||||
x = x + block_res
|
||||
x = self.conv_out(x)
|
||||
# blend residual
|
||||
x = x * self.res_blend + res
|
||||
|
||||
Reference in New Issue
Block a user