diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py
index f798cbc8..7ba5862f 100644
--- a/extensions_built_in/sd_trainer/SDTrainer.py
+++ b/extensions_built_in/sd_trainer/SDTrainer.py
@@ -376,7 +376,8 @@ class SDTrainer(BaseSDTrainProcess):
             # 3 just do mode for now?
             # if args.weighting_scheme == "sigma_sqrt":
             sigmas = self.sd.noise_scheduler.get_sigmas(timesteps, pred.ndim, dtype, self.device_torch)
-            weighting = (sigmas ** -2.0).float()
+            # weighting = (sigmas ** -2.0).float()
+            weighting = torch.ones_like(sigmas)
             # elif args.weighting_scheme == "logit_normal":
             #     # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
             #     u = torch.normal(mean=args.logit_mean, std=args.logit_std, size=(bsz,), device=accelerator.device)
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index 75c4d753..c6a3245b 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -188,6 +188,10 @@ class AdapterConfig:
         # trains with a scaler to ease channel bias but merges it in on save
         self.merge_scaler: bool = kwargs.get('merge_scaler', False)
 
+        # for ilora
+        self.head_dim: int = kwargs.get('head_dim', 1024)
+        self.num_heads: int = kwargs.get('num_heads', 1)
+
 
 class EmbeddingConfig:
     def __init__(self, **kwargs):
diff --git a/toolkit/custom_adapter.py b/toolkit/custom_adapter.py
index e953df82..44750c9b 100644
--- a/toolkit/custom_adapter.py
+++ b/toolkit/custom_adapter.py
@@ -145,7 +145,8 @@ class CustomAdapter(torch.nn.Module):
                 self.ilora_module = InstantLoRAModule(
                     vision_tokens=vision_tokens,
                     vision_hidden_size=vision_hidden_size,
-                    head_dim=1024,
+                    head_dim=self.config.head_dim,
+                    num_heads=self.config.num_heads,
                     sd=self.sd_ref()
                 )
             elif self.adapter_type == 'text_encoder':
@@ -878,6 +879,11 @@ class CustomAdapter(torch.nn.Module):
             self.vision_encoder.gradient_checkpointing = True
 
     def get_additional_save_metadata(self) -> Dict[str, Any]:
+        additional = {}
         if self.config.type == 'ilora':
-            return self.ilora_module.get_additional_save_metadata()
-        return {}
\ No newline at end of file
+            extra = self.ilora_module.get_additional_save_metadata()
+            for k, v in extra.items():
+                additional[k] = v
+        additional['clip_layer'] = self.config.clip_layer
+        additional['image_encoder_arch'] = self.config.image_encoder_arch
+        return additional
\ No newline at end of file
diff --git a/toolkit/lora_special.py b/toolkit/lora_special.py
index af917026..f48b9612 100644
--- a/toolkit/lora_special.py
+++ b/toolkit/lora_special.py
@@ -157,6 +157,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
             use_bias: bool = False,
             is_lorm: bool = False,
             ignore_if_contains=None,
+            only_if_contains=None,
             parameter_threshold: float = 0.0,
             attn_only: bool = False,
             target_lin_modules=LoRANetwork.UNET_TARGET_REPLACE_MODULE,
@@ -186,6 +187,9 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
         if ignore_if_contains is None:
             ignore_if_contains = []
         self.ignore_if_contains = ignore_if_contains
+
+        self.only_if_contains: Union[List, None] = only_if_contains
+
         self.lora_dim = lora_dim
         self.alpha = alpha
         self.conv_lora_dim = conv_lora_dim
@@ -250,6 +254,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
             loras = []
             skipped = []
             attached_modules = []
+            lora_shape_dict = {}
             for name, module in root_module.named_modules():
                 if module.__class__.__name__ in target_replace_modules:
                     for child_name, child_module in module.named_modules():
@@ -269,6 +274,9 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
                             lora_name = prefix + "." + name + "." + child_name
                             lora_name = lora_name.replace(".", "_")
+                            if self.only_if_contains is not None and not any([word in lora_name for word in self.only_if_contains]):
+                                continue
+
                             dim = None
                             alpha = None
@@ -316,6 +324,8 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
                                 use_bias=use_bias,
                             )
                             loras.append(lora)
+                            lora_shape_dict[lora_name] = [list(lora.lora_down.weight.shape), list(lora.lora_up.weight.shape)]
             return loras, skipped
 
         text_encoders = text_encoder if type(text_encoder) == list else [text_encoder]
diff --git a/toolkit/models/ilora.py b/toolkit/models/ilora.py
index 97914990..925b59e6 100644
--- a/toolkit/models/ilora.py
+++ b/toolkit/models/ilora.py
@@ -181,6 +181,7 @@ class InstantLoRAModule(torch.nn.Module):
             vision_hidden_size: int,
             vision_tokens: int,
             head_dim: int,
+            num_heads: int,  # number of resampler output tokens (num_queries)
             sd: 'StableDiffusion'
     ):
         super(InstantLoRAModule, self).__init__()
@@ -190,6 +191,7 @@
         self.vision_hidden_size = vision_hidden_size
         self.vision_tokens = vision_tokens
         self.head_dim = head_dim
+        self.num_heads = num_heads
 
         # stores the projection vector. Grabbed by modules
         self.img_embeds: List[torch.Tensor] = None
@@ -243,7 +245,7 @@
             depth=4,
             dim_head=64,
             heads=12,
-            num_queries=1,  # output tokens
+            num_queries=num_heads,  # output tokens
             embedding_dim=vision_hidden_size,
             max_seq_len=vision_tokens,
             output_dim=head_dim,
@@ -261,25 +263,26 @@
         self.migrate_weight_mapping()
 
     def migrate_weight_mapping(self):
-        # changes the names of the modules to common ones
-        keymap = self.sd_ref().network.get_keymap()
-        save_keymap = {}
-        if keymap is not None:
-            for ldm_key, diffusers_key in keymap.items():
-                # invert them
-                save_keymap[diffusers_key] = ldm_key
-
-            new_keymap = {}
-            for key, value in self.weight_mapping:
-                if key in save_keymap:
-                    new_keymap[save_keymap[key]] = value
-                else:
-                    print(f"Key {key} not found in keymap")
-                    new_keymap[key] = value
-            self.weight_mapping = new_keymap
-        else:
-            print("No keymap found. Using default names")
-        return
+        return
+        # # changes the names of the modules to common ones
+        # keymap = self.sd_ref().network.get_keymap()
+        # save_keymap = {}
+        # if keymap is not None:
+        #     for ldm_key, diffusers_key in keymap.items():
+        #         # invert them
+        #         save_keymap[diffusers_key] = ldm_key
+        #
+        #     new_keymap = {}
+        #     for key, value in self.weight_mapping:
+        #         if key in save_keymap:
+        #             new_keymap[save_keymap[key]] = value
+        #         else:
+        #             print(f"Key {key} not found in keymap")
+        #             new_keymap[key] = value
+        #     self.weight_mapping = new_keymap
+        # else:
+        #     print("No keymap found. Using default names")
Using default names") + # return def forward(self, img_embeds): @@ -291,7 +294,8 @@ class InstantLoRAModule(torch.nn.Module): img_embeds = self.resampler(img_embeds) img_embeds = self.proj_module(img_embeds) if len(img_embeds.shape) == 3: - img_embeds = img_embeds.squeeze(1) + # merge the heads + img_embeds = img_embeds.mean(dim=1) self.img_embeds = [] # get all the slices @@ -304,6 +308,11 @@ class InstantLoRAModule(torch.nn.Module): def get_additional_save_metadata(self) -> Dict[str, Any]: # save the weight mapping return { - "weight_mapping": self.weight_mapping + "weight_mapping": self.weight_mapping, + "num_heads": self.num_heads, + "vision_hidden_size": self.vision_hidden_size, + "head_dim": self.head_dim, + "vision_tokens": self.vision_tokens, + "output_size": self.output_size, } diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index d792d359..1f62f348 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -576,13 +576,11 @@ class StableDiffusion: if self.is_xl: pipeline = Pipe( vae=self.vae, - transformer=self.unet, + unet=self.unet, text_encoder=self.text_encoder[0], text_encoder_2=self.text_encoder[1], - text_encoder_3=self.text_encoder[2], tokenizer=self.tokenizer[0], tokenizer_2=self.tokenizer[1], - tokenizer_3=self.tokenizer[2], scheduler=noise_scheduler, **extra_args ).to(self.device_torch)