diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py
index ba9ad115..be35198e 100644
--- a/extensions_built_in/sd_trainer/SDTrainer.py
+++ b/extensions_built_in/sd_trainer/SDTrainer.py
@@ -1243,8 +1243,10 @@ class SDTrainer(BaseSDTrainProcess):
                         has_been_preprocessed=True,
                         quad_count=quad_count
                     )
-                # else:
-                #     raise ValueError("Adapter images now must be loaded with dataloader or be a reg image")
+                else:
+                    print("No Clip Image")
+                    print([file_item.path for file_item in batch.file_items])
+                    raise ValueError("Could not find clip image")
 
                 if not self.adapter_config.train_image_encoder:
                     # we are not training the image encoder, so we need to detach the embeds
diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index efcdcfdb..9f5e1a5e 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -1293,11 +1293,18 @@ class BaseSDTrainProcess(BaseTrainProcess):
         if self.adapter_config is not None:
             self.setup_adapter()
             if self.adapter_config.train:
-                # set trainable params
-                params.append({
-                    'params': self.adapter.parameters(),
-                    'lr': self.train_config.adapter_lr
-                })
+
+                if isinstance(self.adapter, IPAdapter):
+                    # we have custom LR groups for IPAdapter
+                    adapter_param_groups = self.adapter.get_parameter_groups(self.train_config.adapter_lr)
+                    for group in adapter_param_groups:
+                        params.append(group)
+                else:
+                    # set trainable params
+                    params.append({
+                        'params': self.adapter.parameters(),
+                        'lr': self.train_config.adapter_lr
+                    })
             if self.train_config.gradient_checkpointing:
                 self.adapter.enable_gradient_checkpointing()
 
diff --git a/toolkit/config.py b/toolkit/config.py
index 30a6d538..52de47b8 100644
--- a/toolkit/config.py
+++ b/toolkit/config.py
@@ -43,9 +43,7 @@ def preprocess_config(config: OrderedDict, name: str = None):
     if "name" not in config["config"] and name is None:
         raise ValueError("config file must have a config.name key")
     # we need to replace tags. For now just [name]
-    if name is not None:
-        config["config"]["name"] = name
-    else:
+    if name is None:
         name = config["config"]["name"]
     config_string = json.dumps(config)
     config_string = config_string.replace("[name]", name)
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index 3caecbca..14cc31c2 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -181,6 +181,12 @@ class AdapterConfig:
         self.text_encoder_path: str = kwargs.get('text_encoder_path', None)
         self.text_encoder_arch: str = kwargs.get('text_encoder_arch', 'clip')  # clip t5
 
+        self.train_scaler: bool = kwargs.get('train_scaler', False)
+        self.scaler_lr: Optional[float] = kwargs.get('scaler_lr', None)
+
+        # trains with a scaler to ease channel bias but merges it in on save
+        self.merge_scaler: bool = kwargs.get('merge_scaler', False)
+
 
 class EmbeddingConfig:
     def __init__(self, **kwargs):
diff --git a/toolkit/ip_adapter.py b/toolkit/ip_adapter.py
index a59d5789..5588bf66 100644
--- a/toolkit/ip_adapter.py
+++ b/toolkit/ip_adapter.py
@@ -80,9 +80,14 @@ class MLPProjModelClipFace(torch.nn.Module):
 
 
 class CustomIPAttentionProcessor(IPAttnProcessor2_0):
-    def __init__(self, hidden_size, cross_attention_dim, scale=1.0, num_tokens=4, adapter=None):
+    def __init__(self, hidden_size, cross_attention_dim, scale=1.0, num_tokens=4, adapter=None, train_scaler=False):
         super().__init__(hidden_size, cross_attention_dim, scale=scale, num_tokens=num_tokens)
         self.adapter_ref: weakref.ref = weakref.ref(adapter)
+        self.train_scaler = train_scaler
+        if train_scaler:
+            # self.ip_scaler = torch.nn.Parameter(torch.ones([num_tokens], dtype=torch.float32) * 0.9999)
+            self.ip_scaler = torch.nn.Parameter(torch.ones([1], dtype=torch.float32) * 0.9999)
+            self.ip_scaler.requires_grad_(True)
 
     def __call__(
             self,
@@ -169,6 +174,13 @@ class CustomIPAttentionProcessor(IPAttnProcessor2_0):
 
         # will be none if disabled
         if ip_hidden_states is not None:
+            # apply scaler
+            if self.train_scaler:
+                weight = self.ip_scaler
+                # reshape to (1, self.num_tokens, 1)
+                weight = weight.view(1, -1, 1)
+                ip_hidden_states = ip_hidden_states * weight
+
             # for ip-adapter
             ip_key = self.to_k_ip(ip_hidden_states)
             ip_value = self.to_v_ip(ip_hidden_states)
@@ -185,7 +197,8 @@ class CustomIPAttentionProcessor(IPAttnProcessor2_0):
             ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
             ip_hidden_states = ip_hidden_states.to(query.dtype)
 
-            hidden_states = hidden_states + self.scale * ip_hidden_states
+            scale = self.scale
+            hidden_states = hidden_states + scale * ip_hidden_states
 
         # linear proj
         hidden_states = attn.to_out[0](hidden_states)
@@ -202,6 +215,21 @@ class CustomIPAttentionProcessor(IPAttnProcessor2_0):
 
         return hidden_states
 
+    # this ensures that the ip_scaler is not changed when we load the model
+    # def _apply(self, fn):
+    #     if hasattr(self, "ip_scaler"):
+    #         # Overriding the _apply method to prevent the special_parameter from changing dtype
+    #         self.ip_scaler = fn(self.ip_scaler)
+    #         # Temporarily set the special_parameter to None to exclude it from default _apply processing
+    #         ip_scaler = self.ip_scaler
+    #         self.ip_scaler = None
+    #         super(CustomIPAttentionProcessor, self)._apply(fn)
+    #         # Restore the special_parameter after the default _apply processing
+    #         self.ip_scaler = ip_scaler
+    #         return self
+    #     else:
+    #         return super(CustomIPAttentionProcessor, self)._apply(fn)
+
 
 # loosely based on
 # ref https://github.com/tencent-ailab/IP-Adapter/blob/main/tutorial_train.py
 class IPAdapter(torch.nn.Module):
@@ -485,7 +513,8 @@ class IPAdapter(torch.nn.Module):
                     cross_attention_dim=cross_attention_dim,
                     scale=1.0,
                     num_tokens=self.config.num_tokens,
-                    adapter=self
+                    adapter=self,
+                    train_scaler=self.config.train_scaler or self.config.merge_scaler
                 )
                 if self.sd_ref().is_pixart:
                     # pixart is much more sensitive
@@ -494,7 +523,7 @@ class IPAdapter(torch.nn.Module):
                         "to_v_ip.weight": weights["to_v_ip.weight"] * 0.01,
                     }
 
-                attn_procs[name].load_state_dict(weights)
+                attn_procs[name].load_state_dict(weights, strict=False)
                 attn_processor_names.append(name)
         print(f"Attn Processors")
         print(attn_processor_names)
@@ -568,9 +597,34 @@ class IPAdapter(torch.nn.Module):
         state_dict = OrderedDict()
         if self.config.train_only_image_encoder:
             return self.image_encoder.state_dict()
+        if self.config.train_scaler:
+            state_dict["ip_scale"] = self.adapter_modules.state_dict()
+            # remove items that are not scalers
+            for key in list(state_dict["ip_scale"].keys()):
+                if not key.endswith("ip_scaler"):
+                    del state_dict["ip_scale"][key]
+            return state_dict
         state_dict["image_proj"] = self.image_proj_model.state_dict()
         state_dict["ip_adapter"] = self.adapter_modules.state_dict()
 
+        # handle merge scaler training
+        if self.config.merge_scaler:
+            for key in list(state_dict["ip_adapter"].keys()):
+                if key.endswith("ip_scaler"):
+                    # merge in the scaler so we don't have to save it and it will be compatible with other ip adapters
+                    scale = state_dict["ip_adapter"][key].clone()
+
+                    key_start = key.split(".")[-2]
+                    # reshape to (1, 1)
+                    scale = scale.view(1, 1)
+                    del state_dict["ip_adapter"][key]
+                    # find the to_k_ip and to_v_ip keys
+                    for key2 in list(state_dict["ip_adapter"].keys()):
+                        if key2.endswith(f"{key_start}.to_k_ip.weight"):
+                            state_dict["ip_adapter"][key2] = state_dict["ip_adapter"][key2].clone() * scale
+                        if key2.endswith(f"{key_start}.to_v_ip.weight"):
+                            state_dict["ip_adapter"][key2] = state_dict["ip_adapter"][key2].clone() * scale
+
         if self.config.train_image_encoder:
             state_dict["image_encoder"] = self.image_encoder.state_dict()
         if self.preprocessor is not None:
@@ -866,18 +920,61 @@ class IPAdapter(torch.nn.Module):
         self.image_proj_model.train(mode)
         return super().train(mode)
 
-    def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
+    def get_parameter_groups(self, adapter_lr):
+        param_groups = []
+        # when training just the scaler, we do not train anything else
+        if not self.config.train_scaler:
+            param_groups.append({
+                "params": self.get_non_scaler_parameters(),
+                "lr": adapter_lr,
+            })
+        if self.config.train_scaler or self.config.merge_scaler:
+            scaler_lr = adapter_lr if self.config.scaler_lr is None else self.config.scaler_lr
+            param_groups.append({
+                "params": self.get_scaler_parameters(),
+                "lr": scaler_lr,
+            })
+        return param_groups
+
+    def get_scaler_parameters(self):
+        # only get the scalers from the adapter modules
+        for attn_processor in self.adapter_modules:
+            # only get the scaler
+            # check if it has ip_scaler attribute
+            if hasattr(attn_processor, "ip_scaler"):
+                scaler_param = attn_processor.ip_scaler
+                yield scaler_param
+
+    def get_non_scaler_parameters(self, recurse: bool = True) -> Iterator[Parameter]:
         if self.config.train_only_image_encoder:
             yield from self.image_encoder.parameters(recurse)
             return
+        if self.config.train_scaler:
+            # no params
+            return
+
         for attn_processor in self.adapter_modules:
-            yield from attn_processor.parameters(recurse)
+            if self.config.train_scaler or self.config.merge_scaler:
+                # todo remove scaler
+                if hasattr(attn_processor, "to_k_ip"):
+                    # yield the linear layer
+                    yield from attn_processor.to_k_ip.parameters(recurse)
+                if hasattr(attn_processor, "to_v_ip"):
+                    # yield the linear layer
+                    yield from attn_processor.to_v_ip.parameters(recurse)
+            else:
+                yield from attn_processor.parameters(recurse)
         yield from self.image_proj_model.parameters(recurse)
         if self.config.train_image_encoder:
             yield from self.image_encoder.parameters(recurse)
         if self.preprocessor is not None:
             yield from self.preprocessor.parameters(recurse)
 
+    def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
+        yield from self.get_non_scaler_parameters(recurse)
+        if self.config.train_scaler or self.config.merge_scaler:
+            yield from self.get_scaler_parameters()
+
     def merge_in_weights(self, state_dict: Mapping[str, Any]):
         # merge in img_proj weights
         current_img_proj_state_dict = self.image_proj_model.state_dict()
@@ -975,6 +1072,8 @@ class IPAdapter(torch.nn.Module):
 
     def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
         strict = False
+        if self.config.train_scaler and 'ip_scale' in state_dict:
+            self.adapter_modules.load_state_dict(state_dict["ip_scale"], strict=False)
         if 'ip_adapter' in state_dict:
             try:
                 self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=strict)
diff --git a/toolkit/models/ilora.py b/toolkit/models/ilora.py
index c7c9fd47..745a3687 100644
--- a/toolkit/models/ilora.py
+++ b/toolkit/models/ilora.py
@@ -5,6 +5,10 @@ import torch.nn as nn
 from typing import TYPE_CHECKING
 from toolkit.models.clip_fusion import ZipperBlock
 from toolkit.models.zipper_resampler import ZipperModule, ZipperResampler
+import sys
+from toolkit.paths import REPOS_ROOT
+sys.path.append(REPOS_ROOT)
+from ipadapter.ip_adapter.resampler import Resampler
 
 if TYPE_CHECKING:
     from toolkit.lora_special import LoRAModule
@@ -50,7 +54,7 @@ class InstantLoRAMidModule(torch.nn.Module):
             raise e
         # apply tanh to limit values to -1 to 1
         # scaler = torch.tanh(scaler)
-        return x * scaler
+        return x * (scaler + 1.0)
 
 
 class InstantLoRAModule(torch.nn.Module):
@@ -78,15 +82,30 @@ class InstantLoRAModule(torch.nn.Module):
         lora_modules = self.sd_ref().network.get_all_modules()
 
         # resample the output so each module gets one token with a size of its dim so we can multiply by that
-        self.resampler = ZipperResampler(
-            in_size=self.vision_hidden_size,
-            in_tokens=self.vision_tokens,
-            out_size=self.dim,
-            out_tokens=len(lora_modules),
-            hidden_size=self.vision_hidden_size,
-            hidden_tokens=self.vision_tokens,
-            num_blocks=1,
-        )
+        # self.resampler = ZipperResampler(
+        #     in_size=self.vision_hidden_size,
+        #     in_tokens=self.vision_tokens,
+        #     out_size=self.dim,
+        #     out_tokens=len(lora_modules),
+        #     hidden_size=self.vision_hidden_size,
+        #     hidden_tokens=self.vision_tokens,
+        #     num_blocks=1,
+        # )
+        # heads = 20
+        heads = 12
+        dim = 1280
+        output_dim = self.dim
+        self.resampler = Resampler(
+            dim=dim,
+            depth=4,
+            dim_head=64,
+            heads=heads,
+            num_queries=len(lora_modules),
+            embedding_dim=self.vision_hidden_size,
+            max_seq_len=self.vision_tokens,
+            output_dim=output_dim,
+            ff_mult=4
+        )
 
         for idx, lora_module in enumerate(lora_modules):
             # add a new mid module that will take the original forward and add a vector to it
diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py
index 214ab44e..c5430b37 100644
--- a/toolkit/stable_diffusion_model.py
+++ b/toolkit/stable_diffusion_model.py
@@ -39,7 +39,7 @@ from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffu
     StableDiffusionKDiffusionXLPipeline, StableDiffusionXLRefinerPipeline
 from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAdapter, DDPMScheduler, \
     StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline, DiffusionPipeline, \
-    StableDiffusionXLImg2ImgPipeline, LCMScheduler, Transformer2DModel
+    StableDiffusionXLImg2ImgPipeline, LCMScheduler, Transformer2DModel, AutoencoderTiny
 import diffusers
 from diffusers import \
     AutoencoderKL, \
@@ -872,8 +872,10 @@ class StableDiffusion:
             is_input_scaled=False,
             detach_unconditional=False,
             rescale_cfg=None,
+            return_conditional_pred=False,
             **kwargs,
     ):
+        conditional_pred = None
         # get the embeddings
         if text_embeddings is None and conditional_embeddings is None:
             raise ValueError("Either text_embeddings or conditional_embeddings must be specified")
@@ -1024,9 +1026,12 @@ class StableDiffusion:
                     **kwargs,
                 ).sample
 
+                conditional_pred = noise_pred
+
                 if do_classifier_free_guidance:
                     # perform guidance
                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                    conditional_pred = noise_pred_text
                     noise_pred = noise_pred_uncond + guidance_scale * (
                             noise_pred_text - noise_pred_uncond
                     )
@@ -1112,9 +1117,12 @@ class StableDiffusion:
                 **kwargs,
             ).sample
 
+            conditional_pred = noise_pred
+
             if do_classifier_free_guidance:
                 # perform guidance
                 noise_pred_uncond, noise_pred_text = noise_pred.chunk(2, dim=0)
+                conditional_pred = noise_pred_text
                 if detach_unconditional:
                     noise_pred_uncond = noise_pred_uncond.detach()
                 noise_pred = noise_pred_uncond + guidance_scale * (
@@ -1141,6 +1149,8 @@ class StableDiffusion:
             # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
             noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
 
+        if return_conditional_pred:
+            return noise_pred, conditional_pred
         return noise_pred
 
     def step_scheduler(self, model_input, latent_input, timestep_tensor, noise_scheduler=None):
@@ -1187,23 +1197,30 @@ class StableDiffusion:
             bleed_ratio: float = 0.5,
             bleed_latents: torch.FloatTensor = None,
             is_input_scaled=False,
+            return_first_prediction=False,
             **kwargs,
     ):
         timesteps_to_run = self.noise_scheduler.timesteps[start_timesteps:total_timesteps]
 
+        first_prediction = None
+
         for timestep in tqdm(timesteps_to_run, leave=False):
             timestep = timestep.unsqueeze_(0)
-            noise_pred = self.predict_noise(
+            noise_pred, conditional_pred = self.predict_noise(
                 latents,
                 text_embeddings,
                 timestep,
                 guidance_scale=guidance_scale,
                 add_time_ids=add_time_ids,
                 is_input_scaled=is_input_scaled,
+                return_conditional_pred=True,
                 **kwargs,
             )
             # some schedulers need to run separately, so do that. (euler for example)
+            if return_first_prediction and first_prediction is None:
+                first_prediction = conditional_pred
+
             latents = self.step_scheduler(noise_pred, latents, timestep)
 
             # if not last step, and bleeding, bleed in some latents
@@ -1214,6 +1231,8 @@ class StableDiffusion:
             is_input_scaled = False
 
         # return latents_steps
+        if return_first_prediction:
+            return latents, first_prediction
         return latents
 
     def encode_prompt(
@@ -1311,7 +1330,10 @@ class StableDiffusion:
                 image_list[i] = Resize((image.shape[1] // 8 * 8, image.shape[2] // 8 * 8))(image)
 
         images = torch.stack(image_list)
-        latents = self.vae.encode(images).latent_dist.sample()
+        if isinstance(self.vae, AutoencoderTiny):
+            latents = self.vae.encode(images, return_dict=False)[0]
+        else:
+            latents = self.vae.encode(images).latent_dist.sample()
         # latents = self.vae.encode(images, return_dict=False)[0]
         latents = latents * self.vae.config['scaling_factor']
         latents = latents.to(device, dtype=dtype)
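
Usage note (not part of the patch): the BaseSDTrainProcess.py hunk above feeds the output of IPAdapter.get_parameter_groups() straight into the optimizer's parameter list, so the new ip_scaler parameters can learn at their own rate (scaler_lr) while the rest of the adapter uses adapter_lr. Below is a minimal, self-contained sketch of that per-group learning-rate pattern in plain PyTorch; ToyAttnProcessor and build_param_groups are hypothetical names used only for illustration and do not exist in the toolkit.

# Minimal sketch of the per-group LR pattern (assumption: toy classes,
# not the toolkit's real IPAdapter / CustomIPAttentionProcessor).
import torch


class ToyAttnProcessor(torch.nn.Module):
    def __init__(self, dim=8, train_scaler=True):
        super().__init__()
        self.to_k_ip = torch.nn.Linear(dim, dim, bias=False)
        self.to_v_ip = torch.nn.Linear(dim, dim, bias=False)
        if train_scaler:
            # single learnable scale, initialized just below 1.0 like ip_scaler in the diff
            self.ip_scaler = torch.nn.Parameter(torch.ones([1]) * 0.9999)


def build_param_groups(processors, adapter_lr, scaler_lr=None):
    non_scaler_params = []
    scaler_params = []
    for proc in processors:
        # regular adapter weights go in the base-LR group
        non_scaler_params += list(proc.to_k_ip.parameters())
        non_scaler_params += list(proc.to_v_ip.parameters())
        if hasattr(proc, "ip_scaler"):
            scaler_params.append(proc.ip_scaler)
    groups = [{"params": non_scaler_params, "lr": adapter_lr}]
    if scaler_params:
        # scaler parameters get their own LR, falling back to the adapter LR
        groups.append({"params": scaler_params, "lr": scaler_lr if scaler_lr is not None else adapter_lr})
    return groups


if __name__ == "__main__":
    procs = [ToyAttnProcessor() for _ in range(2)]
    optimizer = torch.optim.AdamW(build_param_groups(procs, adapter_lr=1e-4, scaler_lr=1e-3))
    print([group["lr"] for group in optimizer.param_groups])  # [0.0001, 0.001]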