Adjustments to the CLIP preprocessor. Allow merging in new weights for IP adapters, so you can change the architecture while preserving as much existing data as possible.

This commit is contained in:
Jaret Burkett
2024-01-06 11:56:53 -07:00
parent 645b27f97a
commit b767d29b3c
3 changed files with 99 additions and 36 deletions

View File

@@ -34,10 +34,9 @@ class UpsampleBlock(nn.Module):
class CLIPImagePreProcessor(nn.Module):
def __init__(
self,
input_size=672,
input_size=896,
clip_input_size=224,
downscale_factor: int = 6,
channels=None, # 108
downscale_factor: int = 16,
):
super().__init__()
# make sure they are evenly divisible
@@ -48,27 +47,50 @@ class CLIPImagePreProcessor(nn.Module):
self.clip_input_size = clip_input_size
self.downscale_factor = downscale_factor
subpixel_channels = in_channels * downscale_factor ** 2 # 3 * 6 ** 2 = 108
subpixel_channels = in_channels * downscale_factor ** 2 # 3 * 16 ** 2 = 768
channels = subpixel_channels
if channels is None:
channels = subpixel_channels
upscale_factor = downscale_factor / int((input_size / clip_input_size)) # 16 / (896 / 224) = 4
upscale_factor = downscale_factor / int((input_size / clip_input_size)) # 6 / (672 / 224) = 2
num_upsample_blocks = int(upscale_factor // 2) # 4 // 2 = 2
num_upsample_blocks = int(upscale_factor // 2) # 2 // 2 = 1
# make the residual down up blocks
self.upsample_blocks = nn.ModuleList()
self.subpixel_blocks = nn.ModuleList()
current_channels = channels
current_downscale = downscale_factor
for _ in range(num_upsample_blocks):
# determine the reshuffled channel count for this dimension
output_downscale = current_downscale // 2
out_channels = in_channels * output_downscale ** 2
# out_channels = current_channels // 2
self.upsample_blocks.append(UpsampleBlock(current_channels, out_channels))
current_channels = out_channels
current_downscale = output_downscale
self.subpixel_blocks.append(nn.PixelUnshuffle(current_downscale))
# (bs, 768, 56, 56) -> (bs, 192, 112, 112)
# (bs, 192, 112, 112) -> (bs, 48, 224, 224)
self.conv_out = nn.Conv2d(
current_channels,
out_channels=3,
kernel_size=3,
padding=1
) # (bs, 48, 224, 224) -> (bs, 3, 224, 224)
# do a pooling layer to downscale the input to 1/3 of the size
# (bs, 3, 672, 672) -> (bs, 3, 224, 224)
# (bs, 3, 896, 896) -> (bs, 3, 224, 224)
kernel_size = input_size // clip_input_size
self.res_down = nn.AvgPool2d(
kernel_size=kernel_size,
stride=kernel_size
) # (bs, 3, 672, 672) -> (bs, 3, 224, 224)
) # (bs, 3, 896, 896) -> (bs, 3, 224, 224)
# make a blending for output residual with near 0 weight
self.res_blend = nn.Parameter(torch.tensor(0.001)) # (bs, 3, 224, 224) -> (bs, 3, 224, 224)
self.unshuffle = nn.PixelUnshuffle(downscale_factor) # (bs, 3, 672, 672) -> (bs, 108, 112, 112)
self.unshuffle = nn.PixelUnshuffle(downscale_factor) # (bs, 3, 896, 896) -> (bs, 768, 56, 56)
self.conv_in = nn.Sequential(
nn.Conv2d(
@@ -78,35 +100,23 @@ class CLIPImagePreProcessor(nn.Module):
padding=1
),
nn.GELU()
) # (bs, 108, 112, 112) -> (bs, 108, 112, 112)
self.upsample_blocks = nn.ModuleList()
current_channels = channels
for _ in range(num_upsample_blocks):
out_channels = current_channels // 2
self.upsample_blocks.append(UpsampleBlock(current_channels, out_channels))
current_channels = out_channels
# (bs, 108, 112, 112) -> (bs, 54, 224, 224)
self.conv_out = nn.Conv2d(
current_channels,
out_channels=3,
kernel_size=3,
padding=1
) # (bs, 54, 224, 224) -> (bs, 3, 224, 224)
) # (bs, 768, 56, 56) -> (bs, 768, 56, 56)
# make 2 deep blocks
def forward(self, x):
inputs = x
# resize to input_size x input_size
x = nn.functional.interpolate(x, size=(self.input_size, self.input_size), mode='bicubic')
res = self.res_down(x)
res = self.res_down(inputs)
x = self.unshuffle(x)
x = self.conv_in(x)
for up in self.upsample_blocks:
for up, subpixel in zip(self.upsample_blocks, self.subpixel_blocks):
x = up(x)
block_res = subpixel(inputs)
x = x + block_res
x = self.conv_out(x)
# blend residual
x = x * self.res_blend + res