From e074058faa0e972106d2b05839de3c4fefe0ecfb Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Sat, 10 Feb 2024 09:00:05 -0700
Subject: [PATCH] Work on additional image embedding methods. Finalized zipper
 resampler. It works amazingly well.

---
 extensions_built_in/sd_trainer/SDTrainer.py |   2 +-
 toolkit/custom_adapter.py                   |  17 +-
 toolkit/ip_adapter.py                       |  55 ++++++-
 toolkit/models/clip_fusion.py               |  41 +----
 toolkit/models/ilora.py                     |  20 ++-
 toolkit/models/zipper_resampler.py          | 171 ++++++++++++++++++++
 toolkit/network_mixins.py                   |   2 +-
 7 files changed, 261 insertions(+), 47 deletions(-)
 create mode 100644 toolkit/models/zipper_resampler.py

diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py
index 1647d513..43e87ac6 100644
--- a/extensions_built_in/sd_trainer/SDTrainer.py
+++ b/extensions_built_in/sd_trainer/SDTrainer.py
@@ -998,7 +998,7 @@ class SDTrainer(BaseSDTrainProcess):
                     tensors_0_1=clip_images,
                     is_training=True,
                     has_been_preprocessed=True,
-                    quad_count=quad_count
+                    quad_count=quad_count,
                 )
 
             with self.timer('encode_prompt'):
diff --git a/toolkit/custom_adapter.py b/toolkit/custom_adapter.py
index c1c73909..c6f1c925 100644
--- a/toolkit/custom_adapter.py
+++ b/toolkit/custom_adapter.py
@@ -557,6 +557,21 @@ class CustomAdapter(torch.nn.Module):
             quad_count=4,
     ) -> PromptEmbeds:
         if self.adapter_type == 'ilora':
+            if tensors_0_1 is None:
+                # scale the noise down
+                tensors_0_1 = torch.rand([1, 3, self.input_size, self.input_size], device=self.device)
+                noise_scale = torch.rand([tensors_0_1.shape[0], 1, 1, 1], device=self.device,
+                                         dtype=get_torch_dtype(self.sd_ref().dtype))
+                tensors_0_1 = tensors_0_1 * noise_scale
+                # tensors_0_1 = tensors_0_1 * 0
+                mean = torch.tensor(self.clip_image_processor.image_mean).to(
+                    self.device, dtype=get_torch_dtype(self.sd_ref().dtype)
+                ).detach()
+                std = torch.tensor(self.clip_image_processor.image_std).to(
+                    self.device, dtype=get_torch_dtype(self.sd_ref().dtype)
+                ).detach()
+                tensors_0_1 = torch.clip((255. * tensors_0_1), 0, 255).round() / 255.0
+                clip_image = (tensors_0_1 - mean.view([1, 3, 1, 1])) / std.view([1, 3, 1, 1])
             with torch.no_grad():
                 # on training the clip image is created in the dataloader
                 if not has_been_preprocessed:
@@ -626,7 +641,7 @@ class CustomAdapter(torch.nn.Module):
 
             if not is_training or not self.config.train_image_encoder:
                 img_embeds = img_embeds.detach()
-            self.ilora_module.img_embeds = img_embeds
+            self.ilora_module(img_embeds)
 
     def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
         if self.config.type == 'photo_maker':
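Note on the new `tensors_0_1 is None` branch above: when no conditioning image is supplied, the adapter now fabricates a randomly scaled noise image, snaps it to 8-bit levels, and normalizes it with the CLIP processor's mean/std, so downstream code always receives a CLIP-style tensor. A minimal standalone sketch of that preprocessing; the statistics and resolution below are illustrative, not the values the adapter actually reads from its clip_image_processor:

    import torch

    # illustrative CLIP-style statistics (assumed values, not from the patch)
    image_mean = torch.tensor([0.4815, 0.4578, 0.4082]).view(1, 3, 1, 1)
    image_std = torch.tensor([0.2686, 0.2613, 0.2758]).view(1, 3, 1, 1)

    noise = torch.rand(1, 3, 224, 224) * torch.rand(1, 1, 1, 1)   # randomly scaled noise image
    noise = torch.clip(noise * 255., 0, 255).round() / 255.0      # quantize to 8-bit levels
    clip_image = (noise - image_mean) / image_std                 # same normalization CLIP preprocessing applies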
diff --git a/toolkit/ip_adapter.py b/toolkit/ip_adapter.py
index 3a31c8b0..1751f3e2 100644
--- a/toolkit/ip_adapter.py
+++ b/toolkit/ip_adapter.py
@@ -9,6 +9,7 @@ from torch.nn.modules.module import T
 from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
 
 from toolkit.models.clip_pre_processor import CLIPImagePreProcessor
+from toolkit.models.zipper_resampler import ZipperResampler
 from toolkit.paths import REPOS_ROOT
 from toolkit.saving import load_ip_adapter_model
 from toolkit.train_tools import get_torch_dtype
@@ -33,6 +34,7 @@ from transformers import (
     CLIPVisionModel,
     AutoImageProcessor,
     ConvNextModel,
+    ConvNextV2ForImageClassification,
     ConvNextForImageClassification,
     ConvNextImageProcessor
 )
@@ -226,6 +228,20 @@ class IPAdapter(torch.nn.Module):
                 adapter_config.image_encoder_path,
                 use_safetensors=True,
             ).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
+        elif self.config.image_encoder_arch == 'convnextv2':
+            try:
+                self.clip_image_processor = AutoImageProcessor.from_pretrained(adapter_config.image_encoder_path)
+            except EnvironmentError:
+                print(f"could not load image processor from {adapter_config.image_encoder_path}")
+                self.clip_image_processor = ConvNextImageProcessor(
+                    size=512,
+                    image_mean=[0.485, 0.456, 0.406],
+                    image_std=[0.229, 0.224, 0.225],
+                )
+            self.image_encoder = ConvNextV2ForImageClassification.from_pretrained(
+                adapter_config.image_encoder_path,
+                use_safetensors=True,
+            ).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
         elif self.config.image_encoder_arch == 'vit-hybrid':
             try:
                 self.clip_image_processor = ViTHybridImageProcessor.from_pretrained(adapter_config.image_encoder_path)
@@ -275,8 +291,12 @@ class IPAdapter(torch.nn.Module):
             )
         if 'height' in self.clip_image_processor.size:
             self.input_size = self.clip_image_processor.size['height']
-        else:
+        elif hasattr(self.clip_image_processor, 'crop_size'):
             self.input_size = self.clip_image_processor.crop_size['height']
+        elif 'shortest_edge' in self.clip_image_processor.size.keys():
+            self.input_size = self.clip_image_processor.size['shortest_edge']
+        else:
+            raise ValueError(f"unknown image processor size: {self.clip_image_processor.size}")
         self.current_scale = 1.0
         self.is_active = True
         if adapter_config.type == 'ip':
@@ -311,6 +331,39 @@ class IPAdapter(torch.nn.Module):
                 output_dim=sd.unet.config['cross_attention_dim'],
                 ff_mult=4
             )
+        elif adapter_config.type == 'ipz':
+            dim = sd.unet.config['cross_attention_dim']
+            if hasattr(self.image_encoder.config, 'hidden_sizes'):
+                embedding_dim = self.image_encoder.config.hidden_sizes[-1]
+            else:
+                embedding_dim = self.image_encoder.config.hidden_size
+
+            image_encoder_state_dict = self.image_encoder.state_dict()
+            # max_seq_len = CLIP tokens + CLS token
+            in_tokens = 257
+            if "vision_model.embeddings.position_embedding.weight" in image_encoder_state_dict:
+                # clip
+                in_tokens = int(image_encoder_state_dict["vision_model.embeddings.position_embedding.weight"].shape[0])
+
+            if self.config.image_encoder_arch.startswith('convnext'):
+                in_tokens = 16 * 16
+                embedding_dim = self.image_encoder.config.hidden_sizes[-1]
+
+            is_conv_next = self.config.image_encoder_arch.startswith('convnext')
+
+            out_tokens = self.config.num_tokens if self.config.num_tokens > 0 else in_tokens
+            # ip-adapter-plus
+            image_proj_model = ZipperResampler(
+                in_size=embedding_dim,
+                in_tokens=in_tokens,
+                out_size=dim,
+                out_tokens=out_tokens,
+                hidden_size=embedding_dim,
+                hidden_tokens=in_tokens,
+                # num_blocks=1 if not is_conv_next else 2,
+                num_blocks=1 if not is_conv_next else 2,
+                is_conv_input=is_conv_next
+            )
         elif adapter_config.type == 'ilora':
             # we apply the clip encodings to the LoRA
             image_proj_model = None
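For the new 'ipz' adapter type above, the resampler is sized from the image encoder itself: the CLIP token count is read off the position-embedding table (one CLS token plus one token per patch), while ConvNeXt-style encoders are assumed to yield a 16x16 feature map. A rough standalone illustration of that sizing logic; the checkpoint name openai/clip-vit-large-patch14 is only an example and is not referenced in this patch:

    import torch
    from transformers import CLIPVisionModelWithProjection

    encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
    state_dict = encoder.state_dict()

    # 1 CLS token + (image_size / patch_size) ** 2 patch tokens -> 257 for a 224px ViT-L/14
    in_tokens = int(state_dict["vision_model.embeddings.position_embedding.weight"].shape[0])
    embedding_dim = encoder.config.hidden_size
    print(in_tokens, embedding_dim)  # 257 1024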
diff --git a/toolkit/models/clip_fusion.py b/toolkit/models/clip_fusion.py
index d2674175..f4346fd5 100644
--- a/toolkit/models/clip_fusion.py
+++ b/toolkit/models/clip_fusion.py
@@ -1,6 +1,8 @@
 import torch
 import torch.nn as nn
+from toolkit.models.zipper_resampler import ContextualAlphaMask
+
 
 
 # Conv1d MLP
 # MLP that can alternately be used as a conv1d on dim 1
@@ -86,46 +88,7 @@ class ZipperBlock(nn.Module):
         return x
 
 
-class ContextualAlphaMask(nn.Module):
-    def __init__(
-            self,
-            dim: int = 768,
-    ):
-        super(ContextualAlphaMask, self).__init__()
-        self.dim = dim
-        half_dim = dim // 2
-        quarter_dim = dim // 4
-
-        self.fc1 = nn.Linear(self.dim, self.dim)
-        self.fc2 = nn.Linear(self.dim, half_dim)
-        self.norm1 = nn.LayerNorm(half_dim)
-        self.fc3 = nn.Linear(half_dim, half_dim)
-        self.fc4 = nn.Linear(half_dim, quarter_dim)
-        self.norm2 = nn.LayerNorm(quarter_dim)
-        self.fc5 = nn.Linear(quarter_dim, quarter_dim)
-        self.fc6 = nn.Linear(quarter_dim, 1)
-        # set fc6 weights to near zero
-        self.fc6.weight.data.normal_(mean=0.0, std=0.0001)
-        self.act_fn = nn.GELU()
-
-    def forward(self, x):
-        # x = (batch_size, 77, 768)
-        x = self.fc1(x)
-        x = self.act_fn(x)
-        x = self.fc2(x)
-        x = self.norm1(x)
-        x = self.act_fn(x)
-        x = self.fc3(x)
-        x = self.act_fn(x)
-        x = self.fc4(x)
-        x = self.norm2(x)
-        x = self.act_fn(x)
-        x = self.fc5(x)
-        x = self.act_fn(x)
-        x = self.fc6(x)
-        x = torch.sigmoid(x)
-        return x
diff --git a/toolkit/models/ilora.py b/toolkit/models/ilora.py
index 968f28a9..9d18b6c4 100644
--- a/toolkit/models/ilora.py
+++ b/toolkit/models/ilora.py
@@ -4,6 +4,7 @@ import torch
 import torch.nn as nn
 from typing import TYPE_CHECKING
 from toolkit.models.clip_fusion import ZipperBlock
+from toolkit.models.zipper_resampler import ZipperModule, ZipperResampler
 
 if TYPE_CHECKING:
     from toolkit.lora_special import LoRAModule
@@ -26,7 +27,7 @@ class InstantLoRAMidModule(torch.nn.Module):
         self.lora_module_ref = weakref.ref(lora_module)
         self.instant_lora_module_ref = weakref.ref(instant_lora_module)
 
-        self.zip = ZipperBlock(
+        self.zip = ZipperModule(
             in_size=self.vision_hidden_size,
             in_tokens=self.vision_tokens,
             out_size=self.dim,
@@ -71,7 +72,7 @@ class InstantLoRAModule(torch.nn.Module):
             sd: 'StableDiffusion'
     ):
         super(InstantLoRAModule, self).__init__()
-        self.linear = torch.nn.Linear(2, 1)
+        # self.linear = torch.nn.Linear(2, 1)
         self.sd_ref = weakref.ref(sd)
         self.dim = sd.network.lora_dim
         self.vision_hidden_size = vision_hidden_size
@@ -83,6 +84,15 @@ class InstantLoRAModule(torch.nn.Module):
         # disable merging in. It is slower on inference
         self.sd_ref().network.can_merge_in = False
 
+        self.resampler = ZipperResampler(
+            in_size=self.vision_hidden_size,
+            in_tokens=self.vision_tokens,
+            out_size=self.vision_hidden_size,
+            out_tokens=self.vision_tokens,
+            hidden_size=self.vision_hidden_size,
+            hidden_tokens=self.vision_tokens
+        )
+
         self.ilora_modules = torch.nn.ModuleList()
 
         lora_modules = self.sd_ref().network.get_all_modules()
@@ -99,5 +109,7 @@ class InstantLoRAModule(torch.nn.Module):
             # add a new mid module that will take the original forward and add a vector to it
             # this will be used to add the vector to the original forward
 
-    def forward(self, x):
-        return self.linear(x)
+    def forward(self, img_embeds):
+        img_embeds = self.resampler(img_embeds)
+        self.img_embeds = img_embeds
+
diff --git a/toolkit/models/zipper_resampler.py b/toolkit/models/zipper_resampler.py
new file mode 100644
index 00000000..35f018b0
--- /dev/null
+++ b/toolkit/models/zipper_resampler.py
@@ -0,0 +1,171 @@
+import torch
+import torch.nn as nn
+
+
+class ContextualAlphaMask(nn.Module):
+    def __init__(
+            self,
+            dim: int = 768,
+    ):
+        super(ContextualAlphaMask, self).__init__()
+        self.dim = dim
+
+        half_dim = dim // 2
+        quarter_dim = dim // 4
+
+        self.fc1 = nn.Linear(self.dim, self.dim)
+        self.fc2 = nn.Linear(self.dim, half_dim)
+        self.norm1 = nn.LayerNorm(half_dim)
+        self.fc3 = nn.Linear(half_dim, half_dim)
+        self.fc4 = nn.Linear(half_dim, quarter_dim)
+        self.norm2 = nn.LayerNorm(quarter_dim)
+        self.fc5 = nn.Linear(quarter_dim, quarter_dim)
+        self.fc6 = nn.Linear(quarter_dim, 1)
+        # set fc6 weights to near zero
+        self.fc6.weight.data.normal_(mean=0.0, std=0.0001)
+        self.act_fn = nn.GELU()
+
+    def forward(self, x):
+        # x = (batch_size, 77, 768)
+        x = self.fc1(x)
+        x = self.act_fn(x)
+        x = self.fc2(x)
+        x = self.norm1(x)
+        x = self.act_fn(x)
+        x = self.fc3(x)
+        x = self.act_fn(x)
+        x = self.fc4(x)
+        x = self.norm2(x)
+        x = self.act_fn(x)
+        x = self.fc5(x)
+        x = self.act_fn(x)
+        x = self.fc6(x)
+        x = torch.sigmoid(x)
+        return x
+
+
+class ZipperModule(nn.Module):
+    def __init__(
+            self,
+            in_size,
+            in_tokens,
+            out_size,
+            out_tokens,
+            hidden_size,
+            hidden_tokens,
+            use_residual=False,
+    ):
+        super().__init__()
+        self.in_size = in_size
+        self.in_tokens = in_tokens
+        self.out_size = out_size
+        self.out_tokens = out_tokens
+        self.hidden_size = hidden_size
+        self.hidden_tokens = hidden_tokens
+        self.use_residual = use_residual
+
+        self.act_fn = nn.GELU()
+        self.layernorm = nn.LayerNorm(self.in_size)
+
+        self.conv1 = nn.Conv1d(self.in_tokens, self.hidden_tokens, 1)
+        # act
+        self.fc1 = nn.Linear(self.in_size, self.hidden_size)
+        # act
+        self.conv2 = nn.Conv1d(self.hidden_tokens, self.out_tokens, 1)
+        # act
+        self.fc2 = nn.Linear(self.hidden_size, self.out_size)
+
+    def forward(self, x):
+        residual = x
+        x = self.layernorm(x)
+        x = self.conv1(x)
+        x = self.act_fn(x)
+        x = self.fc1(x)
+        x = self.act_fn(x)
+        x = self.conv2(x)
+        x = self.act_fn(x)
+        x = self.fc2(x)
+        if self.use_residual:
+            x = x + residual
+        return x
+
+
+class ZipperResampler(nn.Module):
+    def __init__(
+            self,
+            in_size,
+            in_tokens,
+            out_size,
+            out_tokens,
+            hidden_size,
+            hidden_tokens,
+            num_blocks=1,
+            is_conv_input=False,
+    ):
+        super().__init__()
+        self.is_conv_input = is_conv_input
+
+        module_list = []
+        for i in range(num_blocks):
+
+            this_in_size = in_size
+            this_in_tokens = in_tokens
+            this_out_size = out_size
+            this_out_tokens = out_tokens
+            this_hidden_size = hidden_size
+            this_hidden_tokens = hidden_tokens
+            use_residual = False
+
+            # maintain middle sizes as hidden_size
+            if i == 0:  # first block
+                this_in_size = in_size
+                this_in_tokens = in_tokens
+                if num_blocks == 1:
+                    this_out_size = out_size
+                    this_out_tokens = out_tokens
+                else:
+                    this_out_size = hidden_size
+                    this_out_tokens = hidden_tokens
+            elif i == num_blocks - 1:  # last block
+                this_out_size = out_size
+                this_out_tokens = out_tokens
+                if num_blocks == 1:
+                    this_in_size = in_size
+                    this_in_tokens = in_tokens
+                else:
+                    this_in_size = hidden_size
+                    this_in_tokens = hidden_tokens
+            else:  # middle blocks
+                this_out_size = hidden_size
+                this_out_tokens = hidden_tokens
+                this_in_size = hidden_size
+                this_in_tokens = hidden_tokens
+                use_residual = True
+
+            module_list.append(ZipperModule(
+                in_size=this_in_size,
+                in_tokens=this_in_tokens,
+                out_size=this_out_size,
+                out_tokens=this_out_tokens,
+                hidden_size=this_hidden_size,
+                hidden_tokens=this_hidden_tokens,
+                use_residual=use_residual
+            ))
+
+        self.blocks = nn.ModuleList(module_list)
+
+        self.ctx_alpha = ContextualAlphaMask(
+            dim=out_size,
+        )
+
+    def forward(self, x):
+        if self.is_conv_input:
+            # flatten
+            x = x.view(x.size(0), x.size(1), -1)
+            # rearrange to (batch, tokens, size)
+            x = x.permute(0, 2, 1)
+
+        for block in self.blocks:
+            x = block(x)
+        alpha = self.ctx_alpha(x)
+        return x * alpha
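As a sanity check on the new module, here is a minimal usage sketch that maps 257 vision tokens of width 1280 down to 16 tokens of width 2048; all sizes are illustrative rather than taken from any config in this patch:

    import torch
    from toolkit.models.zipper_resampler import ZipperResampler

    resampler = ZipperResampler(
        in_size=1280, in_tokens=257,    # illustrative encoder output shape
        out_size=2048, out_tokens=16,   # illustrative cross-attention shape
        hidden_size=1280, hidden_tokens=257,
        num_blocks=1,
    )
    x = torch.randn(2, 257, 1280)   # (batch, tokens, size)
    out = resampler(x)              # ContextualAlphaMask gates each output token
    print(out.shape)                # torch.Size([2, 16, 2048])

With is_conv_input=True the forward pass first flattens a (batch, channels, h, w) feature map and permutes it to (batch, tokens, size), which is how the ConvNeXt path in ip_adapter.py feeds it.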
diff --git a/toolkit/network_mixins.py b/toolkit/network_mixins.py
index 2b7519d3..bfc6bc90 100644
--- a/toolkit/network_mixins.py
+++ b/toolkit/network_mixins.py
@@ -473,7 +473,7 @@ class ToolkitNetworkMixin:
                     del load_sd[key]
             print(f"Missing keys: {to_delete}")
 
-        if len(to_delete) > 0 and self.is_v1:
+        if len(to_delete) > 0 and self.is_v1 and not force_weight_mapping and not (len(to_delete) == 1 and 'emb_params' in to_delete):
             print(" Attempting to load with forced keymap")
             return self.load_weights(file, force_weight_mapping=True)