Adjustments to the CLIP preprocessor. Allow merging in new weights for IP adapters so you can change the architecture while maintaining as much data as possible.

Jaret Burkett
2024-01-06 11:56:53 -07:00
parent 645b27f97a
commit b767d29b3c
3 changed files with 99 additions and 36 deletions
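
The core of the change is a weight merge: when the adapter architecture changes, whatever region of an old checkpoint tensor still fits is copied into the corresponding tensor of the new model, and the rest keeps its fresh initialization. A minimal sketch of that idea with toy tensors (not the commit's actual modules or key names):

    import torch

    # hypothetical example: an old 4x8 projection weight carried into a layer
    # that grew to 6x8; only the overlapping region is copied, the remaining
    # rows keep whatever the new layer was initialized with
    old_weight = torch.randn(4, 8)
    new_weight = torch.randn(6, 8)
    new_weight[:old_weight.shape[0], :old_weight.shape[1]] = old_weight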


@@ -230,7 +230,7 @@ class IPAdapter(torch.nn.Module):
         if self.config.image_encoder_arch == 'clip+':
             # self.clip_image_processor.config
             # We do a 3x downscale of the image, so we need to adjust the input size
-            preprocessor_input_size = self.image_encoder.config.image_size * 3
+            preprocessor_input_size = self.image_encoder.config.image_size * 4
             # update the preprocessor so images come in at the right size
             self.clip_image_processor.size['shortest_edge'] = preprocessor_input_size
@@ -240,7 +240,6 @@ class IPAdapter(torch.nn.Module):
             self.preprocessor = CLIPImagePreProcessor(
                 input_size=preprocessor_input_size,
                 clip_input_size=self.image_encoder.config.image_size,
-                downscale_factor=6
             )
             self.input_size = self.clip_image_processor.size['shortest_edge']
@@ -454,13 +453,68 @@ class IPAdapter(torch.nn.Module):
         if self.preprocessor is not None:
             yield from self.preprocessor.parameters(recurse)

+    def merge_in_weights(self, state_dict: Mapping[str, Any]):
+        # merge in img_proj weights
+        current_img_proj_state_dict = self.image_proj_model.state_dict()
+        for key, value in state_dict["image_proj"].items():
+            if key in current_img_proj_state_dict:
+                current_shape = current_img_proj_state_dict[key].shape
+                new_shape = value.shape
+                if current_shape != new_shape:
+                    # merge in what we can and leave the other values as they are
+                    if len(current_shape) == 1:
+                        current_img_proj_state_dict[key][:new_shape[0]] = value
+                    elif len(current_shape) == 2:
+                        current_img_proj_state_dict[key][:new_shape[0], :new_shape[1]] = value
+                    elif len(current_shape) == 3:
+                        current_img_proj_state_dict[key][:new_shape[0], :new_shape[1], :new_shape[2]] = value
+                    elif len(current_shape) == 4:
+                        current_img_proj_state_dict[key][:new_shape[0], :new_shape[1], :new_shape[2],
+                                                         :new_shape[3]] = value
+                    else:
+                        raise ValueError(f"unknown shape: {current_shape}")
+                    print(f"Force merged in {key}: {list(current_shape)} <<< {list(new_shape)}")
+                else:
+                    current_img_proj_state_dict[key] = value
+        self.image_proj_model.load_state_dict(current_img_proj_state_dict)
+
+        # merge in ip adapter weights
+        current_ip_adapter_state_dict = self.adapter_modules.state_dict()
+        for key, value in state_dict["ip_adapter"].items():
+            if key in current_ip_adapter_state_dict:
+                current_shape = current_ip_adapter_state_dict[key].shape
+                new_shape = value.shape
+                if current_shape != new_shape:
+                    # merge in what we can and leave the other values as they are
+                    if len(current_shape) == 1:
+                        current_ip_adapter_state_dict[key][:new_shape[0]] = value
+                    elif len(current_shape) == 2:
+                        current_ip_adapter_state_dict[key][:new_shape[0], :new_shape[1]] = value
+                    elif len(current_shape) == 3:
+                        current_ip_adapter_state_dict[key][:new_shape[0], :new_shape[1], :new_shape[2]] = value
+                    elif len(current_shape) == 4:
+                        current_ip_adapter_state_dict[key][:new_shape[0], :new_shape[1], :new_shape[2],
+                                                           :new_shape[3]] = value
+                    else:
+                        raise ValueError(f"unknown shape: {current_shape}")
+                    print(f"Force merged in {key}: {list(current_shape)} <<< {list(new_shape)}")
+                else:
+                    current_ip_adapter_state_dict[key] = value
+        self.adapter_modules.load_state_dict(current_ip_adapter_state_dict)
+
     def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
         strict = False
-        self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=strict)
-        self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=strict)
+        try:
+            self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=strict)
+            self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=strict)
+        except Exception as e:
+            print(e)
+            print("could not load ip adapter weights, trying to merge in weights")
+            self.merge_in_weights(state_dict)
         if self.config.train_image_encoder and 'image_encoder' in state_dict:
             self.image_encoder.load_state_dict(state_dict["image_encoder"], strict=strict)
-        if self.config.image_encoder_arch == 'clip+' and 'preprocessor' in state_dict:
+        if self.preprocessor is not None and 'preprocessor' in state_dict:
             self.preprocessor.load_state_dict(state_dict["preprocessor"], strict=strict)

     def enable_gradient_checkpointing(self):
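
With this change, load_state_dict first attempts a normal non-strict load and only falls back to merge_in_weights when the checkpoint's tensor shapes no longer match the current modules. The rank-by-rank branches above could also be written shape-agnostically; the helper below is only a sketch of that equivalence (copy_overlap_ is not part of the commit), and it clamps each dimension with min, whereas the commit's version assumes the checkpoint tensors are never larger than the current ones:

    import torch

    def copy_overlap_(dst: torch.Tensor, src: torch.Tensor) -> None:
        # copy the overlapping region of src into dst in place,
        # leaving the rest of dst untouched
        if dst.dim() != src.dim():
            raise ValueError(f"rank mismatch: {tuple(dst.shape)} vs {tuple(src.shape)}")
        region = tuple(slice(0, min(d, s)) for d, s in zip(dst.shape, src.shape))
        dst[region] = src[region]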


@@ -34,10 +34,9 @@ class UpsampleBlock(nn.Module):
 class CLIPImagePreProcessor(nn.Module):
     def __init__(
             self,
-            input_size=672,
+            input_size=896,
             clip_input_size=224,
-            downscale_factor: int = 6,
-            channels=None,  # 108
+            downscale_factor: int = 16,
     ):
         super().__init__()
         # make sure they are evenly divisible
@@ -48,27 +47,50 @@ class CLIPImagePreProcessor(nn.Module):
         self.clip_input_size = clip_input_size
         self.downscale_factor = downscale_factor

-        subpixel_channels = in_channels * downscale_factor ** 2  # 3 * 6 ** 2 = 108
-        if channels is None:
-            channels = subpixel_channels
+        subpixel_channels = in_channels * downscale_factor ** 2  # 3 * 16 ** 2 = 768
+        channels = subpixel_channels

-        upscale_factor = downscale_factor / int((input_size / clip_input_size))  # 6 / (672 / 224) = 2
-        num_upsample_blocks = int(upscale_factor // 2)  # 2 // 2 = 1
+        upscale_factor = downscale_factor / int((input_size / clip_input_size))  # 16 / (896 / 224) = 4
+        num_upsample_blocks = int(upscale_factor // 2)  # 4 // 2 = 2
+
+        # make the residual down up blocks
+        self.upsample_blocks = nn.ModuleList()
+        self.subpixel_blocks = nn.ModuleList()
+        current_channels = channels
+        current_downscale = downscale_factor
+        for _ in range(num_upsample_blocks):
+            # determine the reshuffled channel count for this dimension
+            output_downscale = current_downscale // 2
+            out_channels = in_channels * output_downscale ** 2
+            # out_channels = current_channels // 2
+            self.upsample_blocks.append(UpsampleBlock(current_channels, out_channels))
+            current_channels = out_channels
+            current_downscale = output_downscale
+            self.subpixel_blocks.append(nn.PixelUnshuffle(current_downscale))
+        # (bs, 768, 56, 56) -> (bs, 192, 112, 112)
+        # (bs, 192, 112, 112) -> (bs, 48, 224, 224)
+
+        self.conv_out = nn.Conv2d(
+            current_channels,
+            out_channels=3,
+            kernel_size=3,
+            padding=1
+        )  # (bs, 48, 224, 224) -> (bs, 3, 224, 224)

         # do a pooling layer to downscale the input to 1/3 of the size
-        # (bs, 3, 672, 672) -> (bs, 3, 224, 224)
+        # (bs, 3, 896, 896) -> (bs, 3, 224, 224)
         kernel_size = input_size // clip_input_size
         self.res_down = nn.AvgPool2d(
             kernel_size=kernel_size,
             stride=kernel_size
-        )  # (bs, 3, 672, 672) -> (bs, 3, 224, 224)
+        )  # (bs, 3, 896, 896) -> (bs, 3, 224, 224)

         # make a blending for output residual with near 0 weight
         self.res_blend = nn.Parameter(torch.tensor(0.001))  # (bs, 3, 224, 224) -> (bs, 3, 224, 224)

-        self.unshuffle = nn.PixelUnshuffle(downscale_factor)  # (bs, 3, 672, 672) -> (bs, 108, 112, 112)
+        self.unshuffle = nn.PixelUnshuffle(downscale_factor)  # (bs, 3, 896, 896) -> (bs, 768, 56, 56)

         self.conv_in = nn.Sequential(
             nn.Conv2d(
@@ -78,35 +100,23 @@ class CLIPImagePreProcessor(nn.Module):
                 padding=1
             ),
             nn.GELU()
-        )  # (bs, 108, 112, 112) -> (bs, 108, 112, 112)
-
-        self.upsample_blocks = nn.ModuleList()
-        current_channels = channels
-        for _ in range(num_upsample_blocks):
-            out_channels = current_channels // 2
-            self.upsample_blocks.append(UpsampleBlock(current_channels, out_channels))
-            current_channels = out_channels
-        # (bs, 108, 112, 112) -> (bs, 54, 224, 224)
-
-        self.conv_out = nn.Conv2d(
-            current_channels,
-            out_channels=3,
-            kernel_size=3,
-            padding=1
-        )  # (bs, 54, 224, 224) -> (bs, 3, 224, 224)
+        )  # (bs, 768, 56, 56) -> (bs, 768, 56, 56)

         # make 2 deep blocks

     def forward(self, x):
         inputs = x
         # resize to input_size x input_size
         x = nn.functional.interpolate(x, size=(self.input_size, self.input_size), mode='bicubic')

-        res = self.res_down(x)
+        res = self.res_down(inputs)

         x = self.unshuffle(x)
         x = self.conv_in(x)
-        for up in self.upsample_blocks:
+        for up, subpixel in zip(self.upsample_blocks, self.subpixel_blocks):
             x = up(x)
+            block_res = subpixel(inputs)
+            x = x + block_res
         x = self.conv_out(x)
         # blend residual
         x = x * self.res_blend + res
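
Under the new defaults the preprocessor takes an 896x896 image (4x the 224 CLIP input), pixel-unshuffles it to (bs, 768, 56, 56), and runs two upsample blocks back up to 224x224, each adding a pixel-unshuffled residual of the raw input, before a final 3-channel conv whose output is blended (with a near-zero initial weight) onto the average-pooled input. A quick check of the commented sizes, assuming a square 896x896 batch (a sketch, not part of the commit):

    import torch
    import torch.nn as nn

    x = torch.randn(1, 3, 896, 896)

    print(nn.PixelUnshuffle(16)(x).shape)      # torch.Size([1, 768, 56, 56])   conv_in input
    print(nn.PixelUnshuffle(8)(x).shape)       # torch.Size([1, 192, 112, 112]) residual for block 1
    print(nn.PixelUnshuffle(4)(x).shape)       # torch.Size([1, 48, 224, 224])  residual for block 2
    print(nn.AvgPool2d(4, stride=4)(x).shape)  # torch.Size([1, 3, 224, 224])   res_down output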


@@ -208,7 +208,6 @@ def load_t2i_model(
     return converted_state_dict

 IP_ADAPTER_MODULES = ['image_proj', 'ip_adapter']

 def save_ip_adapter_from_diffusers(