From b767d29b3cc6e1966d462a76ca6928e72d84df1b Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Sat, 6 Jan 2024 11:56:53 -0700
Subject: [PATCH] Adjustments to the clip preprocessor. Allow merging in new
 weights for ip adapters so you can change the architecture while maintaining
 as much data as possible

---
 toolkit/ip_adapter.py                | 64 +++++++++++++++++++++++--
 toolkit/models/clip_pre_processor.py | 70 ++++++++++++++++------------
 toolkit/saving.py                    |  1 -
 3 files changed, 99 insertions(+), 36 deletions(-)

diff --git a/toolkit/ip_adapter.py b/toolkit/ip_adapter.py
index 3c0b80d1..abae9d79 100644
--- a/toolkit/ip_adapter.py
+++ b/toolkit/ip_adapter.py
@@ -230,7 +230,7 @@ class IPAdapter(torch.nn.Module):
         if self.config.image_encoder_arch == 'clip+':
             # self.clip_image_processor.config
             # We do a 3x downscale of the image, so we need to adjust the input size
-            preprocessor_input_size = self.image_encoder.config.image_size * 3
+            preprocessor_input_size = self.image_encoder.config.image_size * 4
 
             # update the preprocessor so images come in at the right size
             self.clip_image_processor.size['shortest_edge'] = preprocessor_input_size
@@ -240,7 +240,6 @@ class IPAdapter(torch.nn.Module):
             self.preprocessor = CLIPImagePreProcessor(
                 input_size=preprocessor_input_size,
                 clip_input_size=self.image_encoder.config.image_size,
-                downscale_factor=6
             )
 
         self.input_size = self.clip_image_processor.size['shortest_edge']
@@ -454,13 +453,68 @@ class IPAdapter(torch.nn.Module):
         if self.preprocessor is not None:
             yield from self.preprocessor.parameters(recurse)
 
+    def merge_in_weights(self, state_dict: Mapping[str, Any]):
+        # merge in img_proj weights
+        current_img_proj_state_dict = self.image_proj_model.state_dict()
+        for key, value in state_dict["image_proj"].items():
+            if key in current_img_proj_state_dict:
+                current_shape = current_img_proj_state_dict[key].shape
+                new_shape = value.shape
+                if current_shape != new_shape:
+                    # merge in what we can and leave the other values as they are
+                    if len(current_shape) == 1:
+                        current_img_proj_state_dict[key][:new_shape[0]] = value
+                    elif len(current_shape) == 2:
+                        current_img_proj_state_dict[key][:new_shape[0], :new_shape[1]] = value
+                    elif len(current_shape) == 3:
+                        current_img_proj_state_dict[key][:new_shape[0], :new_shape[1], :new_shape[2]] = value
+                    elif len(current_shape) == 4:
+                        current_img_proj_state_dict[key][:new_shape[0], :new_shape[1], :new_shape[2],
+                                                         :new_shape[3]] = value
+                    else:
+                        raise ValueError(f"unknown shape: {current_shape}")
+                    print(f"Force merged in {key}: {list(current_shape)} <<< {list(new_shape)}")
+                else:
+                    current_img_proj_state_dict[key] = value
+        self.image_proj_model.load_state_dict(current_img_proj_state_dict)
+
+        # merge in ip adapter weights
+        current_ip_adapter_state_dict = self.adapter_modules.state_dict()
+        for key, value in state_dict["ip_adapter"].items():
+            if key in current_ip_adapter_state_dict:
+                current_shape = current_ip_adapter_state_dict[key].shape
+                new_shape = value.shape
+                if current_shape != new_shape:
+                    # merge in what we can and leave the other values as they are
+                    if len(current_shape) == 1:
+                        current_ip_adapter_state_dict[key][:new_shape[0]] = value
+                    elif len(current_shape) == 2:
+                        current_ip_adapter_state_dict[key][:new_shape[0], :new_shape[1]] = value
+                    elif len(current_shape) == 3:
+                        current_ip_adapter_state_dict[key][:new_shape[0], :new_shape[1], :new_shape[2]] = value
+                    elif len(current_shape) == 4:
+                        current_ip_adapter_state_dict[key][:new_shape[0], :new_shape[1], :new_shape[2],
+                                                           :new_shape[3]] = value
+                    else:
+                        raise ValueError(f"unknown shape: {current_shape}")
+                    print(f"Force merged in {key}: {list(current_shape)} <<< {list(new_shape)}")
+                else:
+                    current_ip_adapter_state_dict[key] = value
+        self.adapter_modules.load_state_dict(current_ip_adapter_state_dict)
+
     def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
         strict = False
-        self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=strict)
-        self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=strict)
+        try:
+            self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=strict)
+            self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=strict)
+        except Exception as e:
+            print(e)
+            print("could not load ip adapter weights, trying to merge in weights")
+            self.merge_in_weights(state_dict)
         if self.config.train_image_encoder and 'image_encoder' in state_dict:
             self.image_encoder.load_state_dict(state_dict["image_encoder"], strict=strict)
-        if self.config.image_encoder_arch == 'clip+' and 'preprocessor' in state_dict:
+        if self.preprocessor is not None and 'preprocessor' in state_dict:
             self.preprocessor.load_state_dict(state_dict["preprocessor"], strict=strict)
 
     def enable_gradient_checkpointing(self):
diff --git a/toolkit/models/clip_pre_processor.py b/toolkit/models/clip_pre_processor.py
index 851b0f18..7956da0b 100644
--- a/toolkit/models/clip_pre_processor.py
+++ b/toolkit/models/clip_pre_processor.py
@@ -34,10 +34,9 @@ class UpsampleBlock(nn.Module):
 class CLIPImagePreProcessor(nn.Module):
     def __init__(
             self,
-            input_size=672,
+            input_size=896,
             clip_input_size=224,
-            downscale_factor: int = 6,
-            channels=None,  # 108
+            downscale_factor: int = 16,
     ):
         super().__init__()
         # make sure they are evenly divisible
@@ -48,27 +47,50 @@ class CLIPImagePreProcessor(nn.Module):
         self.clip_input_size = clip_input_size
         self.downscale_factor = downscale_factor
 
-        subpixel_channels = in_channels * downscale_factor ** 2  # 3 * 6 ** 2 = 108
+        subpixel_channels = in_channels * downscale_factor ** 2  # 3 * 16 ** 2 = 768
+        channels = subpixel_channels
 
-        if channels is None:
-            channels = subpixel_channels
+        upscale_factor = downscale_factor / int((input_size / clip_input_size))  # 16 / (896 / 224) = 4
 
-        upscale_factor = downscale_factor / int((input_size / clip_input_size))  # 6 / (672 / 224) = 2
+        num_upsample_blocks = int(upscale_factor // 2)  # 4 // 2 = 2
 
-        num_upsample_blocks = int(upscale_factor // 2)  # 2 // 2 = 1
+        # make the residual down up blocks
+        self.upsample_blocks = nn.ModuleList()
+        self.subpixel_blocks = nn.ModuleList()
+        current_channels = channels
+        current_downscale = downscale_factor
+        for _ in range(num_upsample_blocks):
+            # determine the reshuffled channel count for this dimension
+            output_downscale = current_downscale // 2
+            out_channels = in_channels * output_downscale ** 2
+            # out_channels = current_channels // 2
+            self.upsample_blocks.append(UpsampleBlock(current_channels, out_channels))
+            current_channels = out_channels
+            current_downscale = output_downscale
+            self.subpixel_blocks.append(nn.PixelUnshuffle(current_downscale))
+
+        # (bs, 768, 56, 56) -> (bs, 192, 112, 112)
+        # (bs, 192, 112, 112) -> (bs, 48, 224, 224)
+
+        self.conv_out = nn.Conv2d(
+            current_channels,
+            out_channels=3,
+            kernel_size=3,
+            padding=1
+        )  # (bs, 48, 224, 224) -> (bs, 3, 224, 224)
 
         # do a pooling layer to downscale the input to 1/3 of the size
-        # (bs, 3, 672, 672) -> (bs, 3, 224, 224)
+        # (bs, 3, 896, 896) -> (bs, 3, 224, 224)
         kernel_size = input_size // clip_input_size
         self.res_down = nn.AvgPool2d(
             kernel_size=kernel_size,
             stride=kernel_size
-        )  # (bs, 3, 672, 672) -> (bs, 3, 224, 224)
+        )  # (bs, 3, 896, 896) -> (bs, 3, 224, 224)
 
         # make a blending for output residual with near 0 weight
         self.res_blend = nn.Parameter(torch.tensor(0.001))  # (bs, 3, 224, 224) -> (bs, 3, 224, 224)
 
-        self.unshuffle = nn.PixelUnshuffle(downscale_factor)  # (bs, 3, 672, 672) -> (bs, 108, 112, 112)
+        self.unshuffle = nn.PixelUnshuffle(downscale_factor)  # (bs, 3, 896, 896) -> (bs, 768, 56, 56)
 
         self.conv_in = nn.Sequential(
             nn.Conv2d(
@@ -78,35 +100,23 @@ class CLIPImagePreProcessor(nn.Module):
                 padding=1
             ),
             nn.GELU()
-        )  # (bs, 108, 112, 112) -> (bs, 108, 112, 112)
-
-        self.upsample_blocks = nn.ModuleList()
-        current_channels = channels
-        for _ in range(num_upsample_blocks):
-            out_channels = current_channels // 2
-            self.upsample_blocks.append(UpsampleBlock(current_channels, out_channels))
-            current_channels = out_channels
-
-        # (bs, 108, 112, 112) -> (bs, 54, 224, 224)
-
-        self.conv_out = nn.Conv2d(
-            current_channels,
-            out_channels=3,
-            kernel_size=3,
-            padding=1
-        )  # (bs, 54, 224, 224) -> (bs, 3, 224, 224)
+        )  # (bs, 768, 56, 56) -> (bs, 768, 56, 56)
 
+        # make 2 deep blocks
     def forward(self, x):
+        inputs = x
         # resize to input_size x input_size
         x = nn.functional.interpolate(x, size=(self.input_size, self.input_size), mode='bicubic')
-        res = self.res_down(x)
+        res = self.res_down(inputs)
         x = self.unshuffle(x)
         x = self.conv_in(x)
-        for up in self.upsample_blocks:
+        for up, subpixel in zip(self.upsample_blocks, self.subpixel_blocks):
             x = up(x)
+            block_res = subpixel(inputs)
+            x = x + block_res
         x = self.conv_out(x)
         # blend residual
         x = x * self.res_blend + res
diff --git a/toolkit/saving.py b/toolkit/saving.py
index 9bedcd8d..3a5789fd 100644
--- a/toolkit/saving.py
+++ b/toolkit/saving.py
@@ -208,7 +208,6 @@ def load_t2i_model(
     return converted_state_dict
 
 
-IP_ADAPTER_MODULES = ['image_proj', 'ip_adapter']
 
 
 def save_ip_adapter_from_diffusers(
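
A note on the new merge path: when a saved checkpoint's tensors no longer match the current module shapes, merge_in_weights copies the overlapping slice of each old tensor into the resized tensor and leaves the remaining, freshly initialized values untouched. The sketch below is a standalone illustration of that idea, not code from this patch; the helper name merge_matching_slices is made up, and the per-dimension slice tuple is a compact stand-in for the explicit 1D through 4D branches in the diff.

import torch

def merge_matching_slices(current: dict, incoming: dict) -> dict:
    # copy whatever region of each incoming tensor fits into the current tensor
    for key, value in incoming.items():
        if key not in current:
            continue
        if current[key].shape == value.shape:
            current[key] = value
        else:
            overlap = tuple(slice(0, min(c, n)) for c, n in zip(current[key].shape, value.shape))
            current[key][overlap] = value[overlap]
    return current

# toy example: load a 4-output projection into a model that now has 6 outputs
old_layer = torch.nn.Linear(8, 4)
new_layer = torch.nn.Linear(8, 6)
new_layer.load_state_dict(merge_matching_slices(new_layer.state_dict(), old_layer.state_dict()))
# rows 0-3 of new_layer.weight now hold the old weights; rows 4-5 keep their fresh init

As in the patch, only the leading region of each dimension is copied, so this is a sensible warm start mainly when the new architecture is a strict widening of the old one.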
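The shape comments added to CLIPImagePreProcessor can be sanity-checked with stock PyTorch ops. The snippet below only exercises nn.PixelUnshuffle and nn.AvgPool2d (the UpsampleBlock definition is outside this diff) and assumes the new defaults, input_size=896 and downscale_factor=16.

import torch
import torch.nn as nn

x = torch.randn(1, 3, 896, 896)  # preprocessor input after the bicubic resize

# main branch: one unshuffle at the full downscale factor of 16
print(nn.PixelUnshuffle(16)(x).shape)  # torch.Size([1, 768, 56, 56])

# per-block residuals: the loop halves the downscale factor each step (16 -> 8 -> 4),
# which is where the (192, 112, 112) and (48, 224, 224) residual shapes come from
print(nn.PixelUnshuffle(8)(x).shape)   # torch.Size([1, 192, 112, 112])
print(nn.PixelUnshuffle(4)(x).shape)   # torch.Size([1, 48, 224, 224])

# pooled residual used in the final blend: kernel_size = 896 // 224 = 4
print(nn.AvgPool2d(kernel_size=4, stride=4)(x).shape)  # torch.Size([1, 3, 224, 224])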