Mirror of https://github.com/ostris/ai-toolkit.git
Adjustments to the CLIP preprocessor. Allow merging in new weights for IP adapters so you can change the architecture while maintaining as much data as possible.
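Background for the new merge path: `torch.nn.Module.load_state_dict` with `strict=False` only tolerates missing or unexpected keys; a tensor whose shape changed still raises, so loading an older IP-adapter checkpoint after an architecture change would previously fail outright. A minimal, self-contained illustration of that failure mode (toy `nn.Linear` modules, not the toolkit's classes):

```python
import torch.nn as nn

old = nn.Linear(4, 8)   # stand-in for a layer saved with the old architecture
new = nn.Linear(6, 8)   # the same layer after the architecture changed

try:
    # strict=False only skips missing/unexpected keys; a shape mismatch
    # on an existing key still raises RuntimeError.
    new.load_state_dict(old.state_dict(), strict=False)
except RuntimeError as e:
    print("plain load fails:", e)
```

This is exactly the case the new try/except fallback to `merge_in_weights` in `load_state_dict` below is meant to catch.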
@@ -230,7 +230,7 @@ class IPAdapter(torch.nn.Module):
         if self.config.image_encoder_arch == 'clip+':
             # self.clip_image_processor.config
             # We do a 3x downscale of the image, so we need to adjust the input size
-            preprocessor_input_size = self.image_encoder.config.image_size * 3
+            preprocessor_input_size = self.image_encoder.config.image_size * 4

             # update the preprocessor so images come in at the right size
             self.clip_image_processor.size['shortest_edge'] = preprocessor_input_size
@@ -240,7 +240,6 @@ class IPAdapter(torch.nn.Module):
             self.preprocessor = CLIPImagePreProcessor(
                 input_size=preprocessor_input_size,
                 clip_input_size=self.image_encoder.config.image_size,
-                downscale_factor=6
             )

         self.input_size = self.clip_image_processor.size['shortest_edge']
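For reference, the `* 4` lines up with the preprocessor's new defaults further down: with the usual CLIP vision `image_size` of 224 (an assumption here, not stated in this hunk), the preprocessor now takes 896-pixel inputs, which is also why the explicit `downscale_factor=6` argument is dropped in favour of the new default of 16. A quick arithmetic check:

```python
image_size = 224                         # typical CLIPVisionConfig.image_size
preprocessor_input_size = image_size * 4
assert preprocessor_input_size == 896    # matches the new CLIPImagePreProcessor input_size default
assert 896 // 224 == 4                   # overall downscale back to the CLIP input size
```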
@@ -454,13 +453,68 @@ class IPAdapter(torch.nn.Module):
         if self.preprocessor is not None:
             yield from self.preprocessor.parameters(recurse)

+    def merge_in_weights(self, state_dict: Mapping[str, Any]):
+        # merge in img_proj weights
+        current_img_proj_state_dict = self.image_proj_model.state_dict()
+        for key, value in state_dict["image_proj"].items():
+            if key in current_img_proj_state_dict:
+                current_shape = current_img_proj_state_dict[key].shape
+                new_shape = value.shape
+                if current_shape != new_shape:
+                    # merge in what we can and leave the other values as they are
+                    if len(current_shape) == 1:
+                        current_img_proj_state_dict[key][:new_shape[0]] = value
+                    elif len(current_shape) == 2:
+                        current_img_proj_state_dict[key][:new_shape[0], :new_shape[1]] = value
+                    elif len(current_shape) == 3:
+                        current_img_proj_state_dict[key][:new_shape[0], :new_shape[1], :new_shape[2]] = value
+                    elif len(current_shape) == 4:
+                        current_img_proj_state_dict[key][:new_shape[0], :new_shape[1], :new_shape[2],
+                                                         :new_shape[3]] = value
+                    else:
+                        raise ValueError(f"unknown shape: {current_shape}")
+                    print(f"Force merged in {key}: {list(current_shape)} <<< {list(new_shape)}")
+                else:
+                    current_img_proj_state_dict[key] = value
+        self.image_proj_model.load_state_dict(current_img_proj_state_dict)
+
+        # merge in ip adapter weights
+        current_ip_adapter_state_dict = self.adapter_modules.state_dict()
+        for key, value in state_dict["ip_adapter"].items():
+            if key in current_ip_adapter_state_dict:
+                current_shape = current_ip_adapter_state_dict[key].shape
+                new_shape = value.shape
+                if current_shape != new_shape:
+                    # merge in what we can and leave the other values as they are
+                    if len(current_shape) == 1:
+                        current_ip_adapter_state_dict[key][:new_shape[0]] = value
+                    elif len(current_shape) == 2:
+                        current_ip_adapter_state_dict[key][:new_shape[0], :new_shape[1]] = value
+                    elif len(current_shape) == 3:
+                        current_ip_adapter_state_dict[key][:new_shape[0], :new_shape[1], :new_shape[2]] = value
+                    elif len(current_shape) == 4:
+                        current_ip_adapter_state_dict[key][:new_shape[0], :new_shape[1], :new_shape[2],
+                                                           :new_shape[3]] = value
+                    else:
+                        raise ValueError(f"unknown shape: {current_shape}")
+                    print(f"Force merged in {key}: {list(current_shape)} <<< {list(new_shape)}")
+                else:
+                    current_ip_adapter_state_dict[key] = value
+        self.adapter_modules.load_state_dict(current_ip_adapter_state_dict)
+
     def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
         strict = False
-        self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=strict)
-        self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=strict)
+        try:
+            self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=strict)
+            self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=strict)
+        except Exception as e:
+            print(e)
+            print("could not load ip adapter weights, trying to merge in weights")
+            self.merge_in_weights(state_dict)
         if self.config.train_image_encoder and 'image_encoder' in state_dict:
             self.image_encoder.load_state_dict(state_dict["image_encoder"], strict=strict)
-        if self.config.image_encoder_arch == 'clip+' and 'preprocessor' in state_dict:
+        if self.preprocessor is not None and 'preprocessor' in state_dict:
             self.preprocessor.load_state_dict(state_dict["preprocessor"], strict=strict)

     def enable_gradient_checkpointing(self):
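To make the slice semantics of `merge_in_weights` concrete, here is a self-contained toy example (plain tensors, not the adapter's actual keys): the saved values are copied into the leading slice of the freshly initialized tensor, and every position outside that slice keeps its new initialization. Note this assumes the new tensor is at least as large as the saved one in every dimension.

```python
import torch

# Freshly initialized weight in the new, wider architecture (6 x 10) ...
current = torch.zeros(6, 10)
# ... and the corresponding weight saved from the old architecture (4 x 8).
saved = torch.arange(32, dtype=torch.float32).reshape(4, 8)

new_shape = saved.shape
# Same pattern merge_in_weights uses for 2-D tensors:
current[:new_shape[0], :new_shape[1]] = saved

print(current[:4, :8].equal(saved))  # True: old values preserved
print(current[4:, :].abs().sum())    # 0.0 here: untouched entries keep their init
```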
@@ -34,10 +34,9 @@ class UpsampleBlock(nn.Module):
 class CLIPImagePreProcessor(nn.Module):
     def __init__(
             self,
-            input_size=672,
+            input_size=896,
             clip_input_size=224,
-            downscale_factor: int = 6,
-            channels=None,  # 108
+            downscale_factor: int = 16,
     ):
         super().__init__()
         # make sure they are evenly divisible
@@ -48,27 +47,50 @@ class CLIPImagePreProcessor(nn.Module):
         self.clip_input_size = clip_input_size
         self.downscale_factor = downscale_factor

-        subpixel_channels = in_channels * downscale_factor ** 2  # 3 * 6 ** 2 = 108
-
-        if channels is None:
-            channels = subpixel_channels
-
-        upscale_factor = downscale_factor / int((input_size / clip_input_size))  # 6 / (672 / 224) = 2
-
-        num_upsample_blocks = int(upscale_factor // 2)  # 2 // 2 = 1
+        subpixel_channels = in_channels * downscale_factor ** 2  # 3 * 16 ** 2 = 768
+        channels = subpixel_channels
+
+        upscale_factor = downscale_factor / int((input_size / clip_input_size))  # 16 / (896 / 224) = 4
+
+        num_upsample_blocks = int(upscale_factor // 2)  # 4 // 2 = 2
+
+        # make the residual down up blocks
+        self.upsample_blocks = nn.ModuleList()
+        self.subpixel_blocks = nn.ModuleList()
+        current_channels = channels
+        current_downscale = downscale_factor
+        for _ in range(num_upsample_blocks):
+            # determine the reshuffled channel count for this dimension
+            output_downscale = current_downscale // 2
+            out_channels = in_channels * output_downscale ** 2
+            # out_channels = current_channels // 2
+            self.upsample_blocks.append(UpsampleBlock(current_channels, out_channels))
+            current_channels = out_channels
+            current_downscale = output_downscale
+            self.subpixel_blocks.append(nn.PixelUnshuffle(current_downscale))
+
+        # (bs, 768, 56, 56) -> (bs, 192, 112, 112)
+        # (bs, 192, 112, 112) -> (bs, 48, 224, 224)
+
+        self.conv_out = nn.Conv2d(
+            current_channels,
+            out_channels=3,
+            kernel_size=3,
+            padding=1
+        )  # (bs, 48, 224, 224) -> (bs, 3, 224, 224)

         # do a pooling layer to downscale the input to 1/3 of the size
-        # (bs, 3, 672, 672) -> (bs, 3, 224, 224)
+        # (bs, 3, 896, 896) -> (bs, 3, 224, 224)
         kernel_size = input_size // clip_input_size
         self.res_down = nn.AvgPool2d(
             kernel_size=kernel_size,
             stride=kernel_size
-        )  # (bs, 3, 672, 672) -> (bs, 3, 224, 224)
+        )  # (bs, 3, 896, 896) -> (bs, 3, 224, 224)

         # make a blending for output residual with near 0 weight
         self.res_blend = nn.Parameter(torch.tensor(0.001))  # (bs, 3, 224, 224) -> (bs, 3, 224, 224)

-        self.unshuffle = nn.PixelUnshuffle(downscale_factor)  # (bs, 3, 672, 672) -> (bs, 108, 112, 112)
+        self.unshuffle = nn.PixelUnshuffle(downscale_factor)  # (bs, 3, 896, 896) -> (bs, 768, 56, 56)

         self.conv_in = nn.Sequential(
             nn.Conv2d(
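The numbers in the updated comments can be checked directly from the new defaults (`input_size=896`, `clip_input_size=224`, `downscale_factor=16`, 3 input channels). A quick sanity check mirroring the arithmetic in `__init__` and the shapes `nn.PixelUnshuffle` produces:

```python
import torch
import torch.nn as nn

in_channels, input_size, clip_input_size, downscale_factor = 3, 896, 224, 16

subpixel_channels = in_channels * downscale_factor ** 2              # 3 * 16**2 = 768
upscale_factor = downscale_factor / (input_size // clip_input_size)  # 16 / 4 = 4.0
num_upsample_blocks = int(upscale_factor // 2)                       # 2

x = torch.randn(1, in_channels, input_size, input_size)
print(nn.PixelUnshuffle(downscale_factor)(x).shape)  # (1, 768, 56, 56)
print(nn.PixelUnshuffle(8)(x).shape)                 # (1, 192, 112, 112), residual for block 1
print(nn.PixelUnshuffle(4)(x).shape)                 # (1, 48, 224, 224), residual for block 2
```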
@@ -78,35 +100,23 @@ class CLIPImagePreProcessor(nn.Module):
                 padding=1
             ),
             nn.GELU()
-        )  # (bs, 108, 112, 112) -> (bs, 108, 112, 112)
-
-        self.upsample_blocks = nn.ModuleList()
-        current_channels = channels
-        for _ in range(num_upsample_blocks):
-            out_channels = current_channels // 2
-            self.upsample_blocks.append(UpsampleBlock(current_channels, out_channels))
-            current_channels = out_channels
-
-        # (bs, 108, 112, 112) -> (bs, 54, 224, 224)
-
-        self.conv_out = nn.Conv2d(
-            current_channels,
-            out_channels=3,
-            kernel_size=3,
-            padding=1
-        )  # (bs, 54, 224, 224) -> (bs, 3, 224, 224)
+        )  # (bs, 768, 56, 56) -> (bs, 768, 56, 56)

         # make 2 deep blocks

     def forward(self, x):
+        inputs = x
         # resize to input_size x input_size
         x = nn.functional.interpolate(x, size=(self.input_size, self.input_size), mode='bicubic')

-        res = self.res_down(x)
+        res = self.res_down(inputs)

         x = self.unshuffle(x)
         x = self.conv_in(x)
-        for up in self.upsample_blocks:
+        for up, subpixel in zip(self.upsample_blocks, self.subpixel_blocks):
             x = up(x)
+            block_res = subpixel(inputs)
+            x = x + block_res
         x = self.conv_out(x)
         # blend residual
         x = x * self.res_blend + res
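Putting the new forward pass together: each upsample block's output is summed with a pixel-unshuffled copy of the raw input at the matching resolution, and the final 224x224 prediction is blended with a simple average-pooled residual. A standalone sketch of that flow with a stand-in for `UpsampleBlock` (its real definition is earlier in this file and is not shown in the hunk; here it is assumed to double the spatial size and map in to out channels):

```python
import torch
import torch.nn as nn

class ToyUpsample(nn.Module):
    """Stand-in for UpsampleBlock: 2x spatial upsample + channel projection."""
    def __init__(self, c_in, c_out):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv = nn.Conv2d(c_in, c_out, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(self.up(x))

inputs = torch.randn(1, 3, 896, 896)

unshuffle = nn.PixelUnshuffle(16)
conv_in = nn.Conv2d(768, 768, kernel_size=3, padding=1)
ups = nn.ModuleList([ToyUpsample(768, 192), ToyUpsample(192, 48)])
subpixels = nn.ModuleList([nn.PixelUnshuffle(8), nn.PixelUnshuffle(4)])
conv_out = nn.Conv2d(48, 3, kernel_size=3, padding=1)
res_down = nn.AvgPool2d(kernel_size=4, stride=4)
res_blend = torch.tensor(0.001)

x = conv_in(unshuffle(inputs))            # (1, 768, 56, 56)
for up, subpixel in zip(ups, subpixels):
    x = up(x) + subpixel(inputs)          # residual from the raw image at this scale
x = conv_out(x)                           # (1, 3, 224, 224)
out = x * res_blend + res_down(inputs)    # blend with the pooled 224x224 residual
print(out.shape)                          # torch.Size([1, 3, 224, 224])
```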
@@ -208,7 +208,6 @@ def load_t2i_model(
     return converted_state_dict

-
 IP_ADAPTER_MODULES = ['image_proj', 'ip_adapter']


 def save_ip_adapter_from_diffusers(