From 3812957bc95a55c224f8b5703e42cf1b900de0e2 Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Fri, 14 Mar 2025 18:03:00 -0600
Subject: [PATCH] Added ability to train control loras. Other important bug fixes thrown in

---
 extensions_built_in/sd_trainer/SDTrainer.py |   2 +
 jobs/process/BaseSDTrainProcess.py          |   3 +
 toolkit/config_modules.py                   |  37 ++-
 toolkit/custom_adapter.py                   |  63 +++++-
 toolkit/models/control_lora_adapter.py      | 239 ++++++++++++++++++++
 toolkit/network_mixins.py                   |  26 ++-
 toolkit/stable_diffusion_model.py           |  14 +-
 7 files changed, 365 insertions(+), 19 deletions(-)
 create mode 100644 toolkit/models/control_lora_adapter.py

diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py
index 6d4f836a..929aafec 100644
--- a/extensions_built_in/sd_trainer/SDTrainer.py
+++ b/extensions_built_in/sd_trainer/SDTrainer.py
@@ -1674,6 +1674,8 @@ class SDTrainer(BaseSDTrainProcess):
             with self.timer('predict_unet'):
                 if unconditional_embeds is not None:
                     unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=dtype).detach()
+                if self.adapter and isinstance(self.adapter, CustomAdapter):
+                    noisy_latents = self.adapter.condition_noisy_latents(noisy_latents, batch)
                 noise_pred = self.predict_noise(
                     noisy_latents=noisy_latents.to(self.device_torch, dtype=dtype),
                     timesteps=timesteps,
diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index 149c6668..30813a2e 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -580,6 +580,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
                     direct_save = True
                 if self.adapter_config.type == 'redux':
                     direct_save = True
+                if self.adapter_config.type == 'control_lora':
+                    direct_save = True
                 save_ip_adapter_from_diffusers(
                     state_dict,
                     output_file=file_path,
@@ -1362,6 +1364,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
             self.adapter = CustomAdapter(
                 sd=self.sd,
                 adapter_config=self.adapter_config,
+                train_config=self.train_config,
             )
             self.adapter.to(self.device_torch, dtype=dtype)
             if latest_save_path is not None and not is_control_net:
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index fcff3b3f..d67b79eb 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -106,7 +106,7 @@ class LoRMConfig:
     })
 
 
-NetworkType = Literal['lora', 'locon', 'lorm']
+NetworkType = Literal['lora', 'locon', 'lorm', 'lokr']
 
 
 class NetworkConfig:
@@ -151,7 +151,7 @@ class NetworkConfig:
         self.lokr_factor = kwargs.get('lokr_factor', -1)
 
 
-AdapterTypes = Literal['t2i', 'ip', 'ip+', 'clip', 'ilora', 'photo_maker', 'control_net']
+AdapterTypes = Literal['t2i', 'ip', 'ip+', 'clip', 'ilora', 'photo_maker', 'control_net', 'control_lora']
 
 CLIPLayer = Literal['penultimate_hidden_states', 'image_embeds', 'last_hidden_state']
@@ -234,6 +234,13 @@ class AdapterConfig:
         # for llm adapter
         self.num_cloned_blocks: int = kwargs.get('num_cloned_blocks', 0)
         self.quantize_llm: bool = kwargs.get('quantize_llm', False)
+
+        # for control lora only
+        lora_config: dict = kwargs.get('lora_config', None)
+        if lora_config is not None:
+            self.lora_config: NetworkConfig = NetworkConfig(**lora_config)
+        else:
+            self.lora_config = None
 
 
 class EmbeddingConfig:
@@ -521,6 +528,32 @@ class ModelConfig:
 
         self.arch: ModelArch = kwargs.get("arch", None)
         # handle migrating to new model arch
+        if self.arch is not None:
+            # reverse the arch to the old style
+            if self.arch == 'sd2':
+                self.is_v2 = True
+            elif self.arch == 'sd3':
+                self.is_v3 = True
+            elif self.arch == 'sdxl':
+                self.is_xl = True
+            elif self.arch == 'pixart':
+                self.is_pixart = True
+            elif self.arch == 'pixart_sigma':
+                self.is_pixart_sigma = True
+            elif self.arch == 'auraflow':
+                self.is_auraflow = True
+            elif self.arch == 'flux':
+                self.is_flux = True
+            elif self.arch == 'flex2':
+                self.is_flex2 = True
+            elif self.arch == 'lumina2':
+                self.is_lumina2 = True
+            elif self.arch == 'vega':
+                self.is_vega = True
+            elif self.arch == 'ssd':
+                self.is_ssd = True
+            else:
+                pass
         if self.arch is None:
             if kwargs.get('is_v2', False):
                 self.arch = 'sd2'
diff --git a/toolkit/custom_adapter.py b/toolkit/custom_adapter.py
index 81f7f455..5d5ee945 100644
--- a/toolkit/custom_adapter.py
+++ b/toolkit/custom_adapter.py
@@ -7,8 +7,10 @@ from torch.nn import Parameter
 from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, T5EncoderModel, CLIPTextModel, \
     CLIPTokenizer, T5Tokenizer
 
+from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
 from toolkit.models.clip_fusion import CLIPFusionModule
 from toolkit.models.clip_pre_processor import CLIPImagePreProcessor
+from toolkit.models.control_lora_adapter import ControlLoraAdapter
 from toolkit.models.ilora import InstantLoRAModule
 from toolkit.models.single_value_adapter import SingleValueAdapter
 from toolkit.models.te_adapter import TEAdapter
@@ -29,7 +31,7 @@ from ipadapter.ip_adapter.attention_processor import AttnProcessor, IPAttnProces
     AttnProcessor2_0
 from ipadapter.ip_adapter.ip_adapter import ImageProjModel
 from ipadapter.ip_adapter.resampler import Resampler
-from toolkit.config_modules import AdapterConfig, AdapterTypes
+from toolkit.config_modules import AdapterConfig, AdapterTypes, TrainConfig
 from toolkit.prompt_utils import PromptEmbeds
 import weakref
 
@@ -58,10 +60,11 @@ import torch.nn.functional as F
 
 
 class CustomAdapter(torch.nn.Module):
-    def __init__(self, sd: 'StableDiffusion', adapter_config: 'AdapterConfig'):
+    def __init__(self, sd: 'StableDiffusion', adapter_config: 'AdapterConfig', train_config: 'TrainConfig'):
         super().__init__()
         self.config = adapter_config
         self.sd_ref: weakref.ref = weakref.ref(sd)
+        self.train_config = train_config
         self.device = self.sd_ref().unet.device
         self.image_processor: CLIPImageProcessor = None
         self.input_size = 224
@@ -97,6 +100,7 @@ class CustomAdapter(torch.nn.Module):
         self.vd_adapter: VisionDirectAdapter = None
         self.single_value_adapter: SingleValueAdapter = None
         self.redux_adapter: ReduxImageEncoder = None
+        self.control_lora: ControlLoraAdapter = None
 
         self.conditional_embeds: Optional[torch.Tensor] = None
         self.unconditional_embeds: Optional[torch.Tensor] = None
@@ -240,6 +244,13 @@ class CustomAdapter(torch.nn.Module):
         elif self.adapter_type == 'redux':
             vision_hidden_size = self.vision_encoder.config.hidden_size
             self.redux_adapter = ReduxImageEncoder(vision_hidden_size, 4096, self.device, torch_dtype)
+        elif self.adapter_type == 'control_lora':
+            self.control_lora = ControlLoraAdapter(
+                self,
+                sd=self.sd_ref(),
+                config=self.config,
+                train_config=self.train_config
+            )
         else:
             raise ValueError(f"unknown adapter type: {self.adapter_type}")
 
@@ -271,7 +282,7 @@ class CustomAdapter(torch.nn.Module):
     def setup_clip(self):
         adapter_config = self.config
         sd = self.sd_ref()
-        if self.config.type in ["text_encoder", "llm_adapter", "single_value"]:
+        if self.config.type in ["text_encoder", "llm_adapter", "single_value", "control_lora"]:
             return
         if self.config.type == 'photo_maker':
             try:
@@ -481,6 +492,14 @@ class CustomAdapter(torch.nn.Module):
                 for k2, v2 in v.items():
                     new_dict[k + '.' + k2] = v2
             self.redux_adapter.load_state_dict(new_dict, strict=True)
+
+        if self.adapter_type == 'control_lora':
+            # state dict is separated, so recombine it
+            new_dict = {}
+            for k, v in state_dict.items():
+                for k2, v2 in v.items():
+                    new_dict[k + '.' + k2] = v2
+            self.control_lora.load_weights(new_dict, strict=strict)
 
         pass
@@ -532,6 +551,11 @@ class CustomAdapter(torch.nn.Module):
             for k, v in d.items():
                 state_dict[k] = v
             return state_dict
+        elif self.adapter_type == 'control_lora':
+            d = self.control_lora.get_state_dict()
+            for k, v in d.items():
+                state_dict[k] = v
+            return state_dict
         else:
             raise NotImplementedError
 
@@ -541,6 +565,33 @@ class CustomAdapter(torch.nn.Module):
             self.unconditional_embeds = extra_values.to(self.device, get_torch_dtype(self.sd_ref().dtype))
         else:
             self.conditional_embeds = extra_values.to(self.device, get_torch_dtype(self.sd_ref().dtype))
+
+    def condition_noisy_latents(self, latents: torch.Tensor, batch: DataLoaderBatchDTO):
+        with torch.no_grad():
+            if self.adapter_type in ['control_lora']:
+                sd: StableDiffusion = self.sd_ref()
+                control_tensor = batch.control_tensor
+                if control_tensor is None:
+                    # concat random normal noise onto the latents
+                    # check dimension, this is before they are rearranged
+                    # it is latent_model_input = torch.cat([latents, control_image], dim=2) after rearranging
+                    latents = torch.cat((latents, torch.randn_like(latents)), dim=1)
+                    return latents.detach()
+                # control images are 0-1, convert them to -1 to 1
+                control_tensor = control_tensor * 2 - 1
+
+                control_tensor = control_tensor.to(sd.vae_device_torch, dtype=sd.torch_dtype)
+
+                # if it does not match the size of batch.tensor (bs, ch, h, w), resize it
+                if control_tensor.shape[2] != batch.tensor.shape[2] or control_tensor.shape[3] != batch.tensor.shape[3]:
+                    control_tensor = F.interpolate(control_tensor, size=(batch.tensor.shape[2], batch.tensor.shape[3]), mode='bicubic')
+
+                # encode it
+                control_latent = sd.encode_images(control_tensor).to(latents.device, latents.dtype)
+                # concat it onto the latents
+                latents = torch.cat((latents, control_latent), dim=1)
+                return latents.detach()
+        return latents
 
     def condition_prompt(
@@ -548,7 +599,7 @@ class CustomAdapter(torch.nn.Module):
             self,
             prompt: Union[List[str], str],
             is_unconditional: bool = False,
     ):
-        if self.adapter_type == 'clip_fusion' or self.adapter_type == 'ilora' or self.adapter_type == 'vision_direct' or self.adapter_type == 'redux':
+        if self.adapter_type in ['clip_fusion', 'ilora', 'vision_direct', 'redux', 'control_lora']:
             return prompt
         elif self.adapter_type == 'text_encoder':
             # todo allow for training
@@ -1067,6 +1118,10 @@ class CustomAdapter(torch.nn.Module):
             yield from self.single_value_adapter.parameters(recurse)
         elif self.config.type == 'redux':
             yield from self.redux_adapter.parameters(recurse)
+        elif self.config.type == 'control_lora':
+            param_list = self.control_lora.get_params()
+            for param in param_list:
+                yield param
         else:
             raise NotImplementedError
 
diff --git a/toolkit/models/control_lora_adapter.py b/toolkit/models/control_lora_adapter.py
new file mode 100644
index 00000000..387400ed
--- /dev/null
+++ b/toolkit/models/control_lora_adapter.py
@@ -0,0 +1,239 @@
+import inspect
+import weakref
+import torch
+from typing import TYPE_CHECKING
+from toolkit.lora_special import LoRASpecialNetwork
+from diffusers import FluxTransformer2DModel
+# weakref
+
+
+if TYPE_CHECKING:
+    from toolkit.stable_diffusion_model import StableDiffusion
+    from toolkit.config_modules import AdapterConfig, TrainConfig, ModelConfig
+    from toolkit.custom_adapter import CustomAdapter
+
+
+# after each step we concat the control image with the latents
+# latent_model_input = torch.cat([latents, control_image], dim=2)
+# the x_embedder has a full rank lora to handle the additional channels
+# this replaces the x_embedder with a full rank lora. on flux this is
+# x_embedder (diffusers) or img_in (bfl)
+
+# Flux
+# img_in.lora_A.weight [128, 128]
+# img_in.lora_B.bias [3072]
+# img_in.lora_B.weight [3072, 128]
+
+
+class ImgEmbedder(torch.nn.Module):
+    def __init__(
+        self,
+        adapter: 'ControlLoraAdapter',
+        orig_layer: torch.nn.Module,
+        in_channels=128,
+        out_channels=3072,
+        bias=True
+    ):
+        super().__init__()
+        self.adapter_ref: weakref.ref = weakref.ref(adapter)
+        self.orig_layer_ref: weakref.ref = weakref.ref(orig_layer)
+        self.lora_A = torch.nn.Linear(in_channels, in_channels, bias=False)  # lora down
+        self.lora_B = torch.nn.Linear(in_channels, out_channels, bias=bias)  # lora up
+
+    @classmethod
+    def from_model(
+        cls,
+        model: FluxTransformer2DModel,
+        adapter: 'ControlLoraAdapter',
+        num_channel_multiplier=2
+    ):
+        if model.__class__.__name__ == 'FluxTransformer2DModel':
+            x_embedder: torch.nn.Linear = model.x_embedder
+            img_embedder = cls(
+                adapter,
+                orig_layer=x_embedder,
+                in_channels=x_embedder.in_features * num_channel_multiplier,  # adding additional control img channels
+                out_channels=x_embedder.out_features,
+                bias=x_embedder.bias is not None
+            )
+
+            # hijack the forward method
+            x_embedder._orig_ctrl_lora_forward = x_embedder.forward
+            x_embedder.forward = img_embedder.forward
+            dtype = x_embedder.weight.dtype
+            device = x_embedder.weight.device
+
+            # since we are adding control channels, we want those channels to be zero starting out
+            # so they have no effect. lora_B gets the original weight and bias, with 0s concatenated
+            # for the new input channels, and lora_A is identity so the original output is reproduced on init
+            img_embedder.lora_A.weight.data = torch.eye(x_embedder.in_features * num_channel_multiplier).to(dtype=torch.float32, device=device)
+            weight_b = x_embedder.weight.data.clone().to(dtype=torch.float32, device=device)
+            # concat 0s for the new channels
+            weight_b = torch.cat([weight_b, torch.zeros(weight_b.shape[0], weight_b.shape[1] * (num_channel_multiplier - 1)).to(device)], dim=1)
+            img_embedder.lora_B.weight.data = weight_b.clone().to(dtype=torch.float32)
+            img_embedder.lora_B.bias.data = x_embedder.bias.data.clone().to(dtype=torch.float32)
+
+            # update the config of the transformer
+            model.config.in_channels = model.config.in_channels * num_channel_multiplier
+            model.config["in_channels"] = model.config.in_channels
+
+            return img_embedder
+        else:
+            raise ValueError("Model not supported")
+
+    @property
+    def is_active(self):
+        return self.adapter_ref().is_active
+
+    def forward(self, x):
+        if not self.is_active:
+            # make sure lora is not active
+            self.adapter_ref().control_lora.is_active = False
+            return self.orig_layer_ref()._orig_ctrl_lora_forward(x)
+
+        # make sure lora is active
+        self.adapter_ref().control_lora.is_active = True
+
+        orig_device = x.device
+        orig_dtype = x.dtype
+        x = x.to(self.lora_A.weight.device, dtype=self.lora_A.weight.dtype)
+
+        x = self.lora_A(x)
+        x = self.lora_B(x)
+        x = x.to(orig_device, dtype=orig_dtype)
+        return x
+
+
+class ControlLoraAdapter(torch.nn.Module):
+    def __init__(
+        self,
+        adapter: 'CustomAdapter',
+        sd: 'StableDiffusion',
+        config: 'AdapterConfig',
+        train_config: 'TrainConfig'
+    ):
+        super().__init__()
+        self.adapter_ref: weakref.ref = weakref.ref(adapter)
+        self.sd_ref = weakref.ref(sd)
+        self.model_config: ModelConfig = sd.model_config
+        self.network_config = config.lora_config
+        self.train_config = train_config
+        if self.network_config is None:
+            raise ValueError("LoRA config is missing")
+
+        network_kwargs = {} if self.network_config.network_kwargs is None else self.network_config.network_kwargs
+        if hasattr(sd, 'target_lora_modules'):
+            network_kwargs['target_lin_modules'] = sd.target_lora_modules
+
+        if 'ignore_if_contains' not in network_kwargs:
+            network_kwargs['ignore_if_contains'] = []
+
+        # always ignore x_embedder
+        network_kwargs['ignore_if_contains'].append('x_embedder')
+
+        self.device_torch = sd.device_torch
+        self.control_lora = LoRASpecialNetwork(
+            text_encoder=sd.text_encoder,
+            unet=sd.unet,
+            lora_dim=self.network_config.linear,
+            multiplier=1.0,
+            alpha=self.network_config.linear_alpha,
+            train_unet=self.train_config.train_unet,
+            train_text_encoder=self.train_config.train_text_encoder,
+            conv_lora_dim=self.network_config.conv,
+            conv_alpha=self.network_config.conv_alpha,
+            is_sdxl=self.model_config.is_xl or self.model_config.is_ssd,
+            is_v2=self.model_config.is_v2,
+            is_v3=self.model_config.is_v3,
+            is_pixart=self.model_config.is_pixart,
+            is_auraflow=self.model_config.is_auraflow,
+            is_flux=self.model_config.is_flux,
+            is_lumina2=self.model_config.is_lumina2,
+            is_ssd=self.model_config.is_ssd,
+            is_vega=self.model_config.is_vega,
+            dropout=self.network_config.dropout,
+            use_text_encoder_1=self.model_config.use_text_encoder_1,
+            use_text_encoder_2=self.model_config.use_text_encoder_2,
+            use_bias=False,
+            is_lorm=False,
+            network_config=self.network_config,
+            network_type=self.network_config.type,
+            transformer_only=self.network_config.transformer_only,
+            is_transformer=sd.is_transformer,
+            base_model=sd,
+            **network_kwargs
+        )
+        self.control_lora.force_to(self.device_torch, dtype=torch.float32)
+        self.control_lora._update_torch_multiplier()
+        self.control_lora.apply_to(
+            sd.text_encoder,
+            sd.unet,
+            self.train_config.train_text_encoder,
+            self.train_config.train_unet
+        )
+        self.control_lora.can_merge_in = False
+        self.control_lora.prepare_grad_etc(sd.text_encoder, sd.unet)
+        if self.train_config.gradient_checkpointing:
+            self.control_lora.enable_gradient_checkpointing()
+
+        self.x_embedder = ImgEmbedder.from_model(sd.unet, self)
+        self.x_embedder.to(self.device_torch)
+
+    def get_params(self):
+        # LyCORIS doesn't have default_lr
+        config = {
+            'text_encoder_lr': self.train_config.lr,
+            'unet_lr': self.train_config.lr,
+        }
+        sig = inspect.signature(self.control_lora.prepare_optimizer_params)
+        if 'default_lr' in sig.parameters:
+            config['default_lr'] = self.train_config.lr
+        if 'learning_rate' in sig.parameters:
+            config['learning_rate'] = self.train_config.lr
+        params_net = self.control_lora.prepare_optimizer_params(
+            **config
+        )
+
+        # we want only tensors here
+        params = []
+        for p in params_net:
+            if isinstance(p, dict):
+                params += p["params"]
+            elif isinstance(p, torch.Tensor):
+                params.append(p)
+            elif isinstance(p, list):
+                params += p
+
+        params += list(self.x_embedder.parameters())
+
+        # we need to be able to yield from the list like `yield from params`
+        return params
+
+    def load_weights(self, state_dict, strict=True):
+        lora_sd = {}
+        img_embedder_sd = {}
+        for key, value in state_dict.items():
+            if "x_embedder" in key:
+                new_key = key.replace("transformer.x_embedder.", "")
+                img_embedder_sd[new_key] = value
+            else:
+                lora_sd[key] = value
+
+        # todo: process the state dict before loading
+        self.control_lora.load_weights(lora_sd)
+        self.x_embedder.load_state_dict(img_embedder_sd, strict=strict)
+
+    def get_state_dict(self):
+        lora_sd = self.control_lora.get_state_dict(dtype=torch.float32)
+        # todo: make sure we match the LoRA key format used elsewhere.
+        img_embedder_sd = self.x_embedder.state_dict()
+        for key, value in img_embedder_sd.items():
+            lora_sd[f"transformer.x_embedder.{key}"] = value
+        return lora_sd
+
+    @property
+    def is_active(self):
+        return self.adapter_ref().is_active
diff --git a/toolkit/network_mixins.py b/toolkit/network_mixins.py
index 6346eacd..b52af32b 100644
--- a/toolkit/network_mixins.py
+++ b/toolkit/network_mixins.py
@@ -491,13 +491,8 @@ class ToolkitNetworkMixin:
             keymap = new_keymap
 
         return keymap
-
-    def save_weights(
-            self: Network,
-            file, dtype=torch.float16,
-            metadata=None,
-            extra_state_dict: Optional[OrderedDict] = None
-    ):
+
+    def get_state_dict(self: Network, extra_state_dict=None, dtype=torch.float16):
         keymap = self.get_keymap()
 
         save_keymap = {}
@@ -506,9 +501,6 @@ class ToolkitNetworkMixin:
                 # invert them
                 save_keymap[diffusers_key] = ldm_key
 
-        if metadata is not None and len(metadata) == 0:
-            metadata = None
-
         state_dict = self.state_dict()
 
         save_dict = OrderedDict()
@@ -556,10 +548,22 @@ class ToolkitNetworkMixin:
             save_dict = new_save_dict
 
         save_dict = self.base_model_ref().convert_lora_weights_before_save(save_dict)
+        return save_dict
+
+    def save_weights(
+            self: Network,
+            file, dtype=torch.float16,
+            metadata=None,
+            extra_state_dict: Optional[OrderedDict] = None
+    ):
+        save_dict = self.get_state_dict(extra_state_dict=extra_state_dict, dtype=dtype)
+
+        if metadata is not None and len(metadata) == 0:
+            metadata = None
 
         if metadata is None:
             metadata = OrderedDict()
-        metadata = add_model_hash_to_meta(state_dict, metadata)
+        metadata = add_model_hash_to_meta(save_dict, metadata)
         if os.path.splitext(file)[1] == ".safetensors":
             from safetensors.torch import save_file
             save_file(save_dict, file, metadata)
diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py
index 892385b6..5f459e56 100644
--- a/toolkit/stable_diffusion_model.py
+++ b/toolkit/stable_diffusion_model.py
@@ -49,7 +49,8 @@ from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAda
     StableDiffusionXLImg2ImgPipeline, LCMScheduler, Transformer2DModel, AutoencoderTiny, ControlNetModel, \
     StableDiffusionXLControlNetPipeline, StableDiffusionControlNetPipeline, StableDiffusion3Pipeline, \
     StableDiffusion3Img2ImgPipeline, PixArtSigmaPipeline, AuraFlowPipeline, AuraFlowTransformer2DModel, FluxPipeline, \
-    FluxTransformer2DModel, FlowMatchEulerDiscreteScheduler, SD3Transformer2DModel, Lumina2Text2ImgPipeline
+    FluxTransformer2DModel, FlowMatchEulerDiscreteScheduler, SD3Transformer2DModel, Lumina2Text2ImgPipeline, \
+    FluxControlPipeline
 from toolkit.models.lumina2 import Lumina2Transformer2DModel
 from toolkit.models.flex2 import Flex2Pipeline
 import diffusers
@@ -155,6 +156,7 @@ class StableDiffusion:
         self.model_config = model_config
 
         self.prediction_type = "v_prediction" if self.model_config.is_v_pred else "epsilon"
+        self.arch = model_config.arch
 
         self.device_state = None
 
@@ -1239,6 +1241,10 @@ class StableDiffusion:
                 Pipe = FluxPipeline
                 if self.is_flex2:
                     Pipe = Flex2Pipeline
+                if self.adapter is not None and isinstance(self.adapter, CustomAdapter):
+                    # see if it is a control lora
+                    if self.adapter.control_lora is not None:
+                        Pipe = FluxControlPipeline
 
                 pipeline = Pipe(
                     vae=self.vae,
@@ -1358,6 +1364,9 @@ class StableDiffusion:
                 validation_image = validation_image.resize((gen_config.width, gen_config.height))
                 extra['image'] = validation_image
                 extra['controlnet_conditioning_scale'] = gen_config.adapter_conditioning_scale
+            if isinstance(self.adapter, CustomAdapter) and self.adapter.control_lora is not None:
+                validation_image = validation_image.resize((gen_config.width, gen_config.height))
+                extra['control_image'] = validation_image
             if isinstance(self.adapter, IPAdapter) or isinstance(self.adapter, ClipVisionAdapter):
                 transform = transforms.Compose([
                     transforms.ToTensor(),
@@ -2136,7 +2145,8 @@ class StableDiffusion:
                 w=latent_model_input.shape[3] // 2,
                 ph=2,
                 pw=2,
-                c=latent_model_input.shape[1],
+                # c=latent_model_input.shape[1],
+                c=self.vae.config.latent_channels
             )
 
             if bypass_guidance_embedding:
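
Usage note: the new 'control_lora' adapter type reads a nested lora_config block and parses it into a NetworkConfig (see the AdapterConfig change above). A minimal sketch of that wiring, assuming the adapter block of a job config is passed straight into AdapterConfig; the rank/alpha values below are placeholders, not recommendations, and in practice this block would normally live in the job's YAML config rather than be constructed by hand.

    # Hypothetical usage sketch: configuring a control LoRA adapter in Python.
    # Keys mirror the fields read by AdapterConfig / NetworkConfig in this patch;
    # concrete values are illustrative only.
    from toolkit.config_modules import AdapterConfig

    adapter_config = AdapterConfig(
        type='control_lora',           # new AdapterTypes value added in this patch
        lora_config={                  # parsed into a NetworkConfig by AdapterConfig
            'type': 'lora',
            'linear': 128,             # becomes lora_dim for LoRASpecialNetwork
            'linear_alpha': 128,
            'transformer_only': True,
        },
    )
    assert adapter_config.lora_config is not None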
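
Shape note: CustomAdapter.condition_noisy_latents doubles the channel count of the noisy latents by concatenating the VAE-encoded control image (or random noise when a batch has no control tensor) on dim=1. After Flux's 2x2 patch packing, those doubled channels become the doubled in_features of ImgEmbedder, and it is also why the rearrange now uses c=self.vae.config.latent_channels instead of latent_model_input.shape[1] when unpacking the model output. A standalone sketch, assuming the usual 16 Flux latent channels:

    # Shape sketch (assumed sizes, not patch code): Flux latents (bs, 16, h/8, w/8)
    # become (bs, 32, h/8, w/8) after the control latent is concatenated, and 2x2
    # packing turns 32 channels into 128 features per token.
    import torch

    bs, c, h, w = 1, 16, 64, 64
    noisy_latents = torch.randn(bs, c, h, w)
    control_latent = torch.randn(bs, c, h, w)   # vae-encoded control image (or randn_like fallback)
    conditioned = torch.cat((noisy_latents, control_latent), dim=1)
    print(conditioned.shape)                    # torch.Size([1, 32, 64, 64])
    print(conditioned.shape[1] * 2 * 2)         # 128 -> ImgEmbedder in_channels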
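
Initialization note: ImgEmbedder.from_model doubles x_embedder's input features but initializes the full-rank "LoRA" so the extra control channels start as a no-op: lora_A is the identity and lora_B gets the original weight with zero columns appended plus the original bias, so for an input [x; c] the output equals the original projection of x. A standalone sketch (not the patch code) that checks this property with made-up sizes:

    # Verify the zero-init property of the full-rank x_embedder replacement.
    import torch

    in_features, out_features, mult = 64, 3072, 2
    orig = torch.nn.Linear(in_features, out_features)

    lora_A = torch.nn.Linear(in_features * mult, in_features * mult, bias=False)
    lora_B = torch.nn.Linear(in_features * mult, out_features, bias=True)
    lora_A.weight.data = torch.eye(in_features * mult)                       # identity down projection
    lora_B.weight.data = torch.cat(                                          # [W_orig | 0]
        [orig.weight.data, torch.zeros(out_features, in_features * (mult - 1))], dim=1
    )
    lora_B.bias.data = orig.bias.data.clone()

    x = torch.randn(2, 256, in_features)        # packed image tokens
    ctrl = torch.randn(2, 256, in_features)     # packed control tokens
    out_new = lora_B(lora_A(torch.cat([x, ctrl], dim=-1)))
    print(torch.allclose(out_new, orig(x), atol=1e-5))  # True: control channels start with no effect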
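
Checkpoint layout note: ControlLoraAdapter.get_state_dict returns one flat dict holding both the regular LoRA weights and the full-rank x_embedder replacement, with the latter namespaced under "transformer.x_embedder."; load_weights splits them back apart on that substring. A small sketch of the resulting layout (the non-x_embedder key below is illustrative; only the prefix handling comes from the patch):

    # Sketch of the combined checkpoint layout and how load_weights splits it.
    combined = {
        "transformer.transformer_blocks.0.attn.to_q.lora_down.weight": "...",  # illustrative LoRA key
        "transformer.x_embedder.lora_A.weight": "...",   # full-rank img_in / x_embedder weights
        "transformer.x_embedder.lora_B.weight": "...",
        "transformer.x_embedder.lora_B.bias": "...",
    }

    lora_sd = {k: v for k, v in combined.items() if "x_embedder" not in k}
    x_embedder_sd = {
        k.replace("transformer.x_embedder.", ""): v
        for k, v in combined.items() if "x_embedder" in k
    }
    print(sorted(x_embedder_sd))  # ['lora_A.weight', 'lora_B.bias', 'lora_B.weight']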
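
Sampling note: during validation the pipeline is swapped to diffusers' FluxControlPipeline whenever a control LoRA adapter is attached, and the validation image is passed as control_image. A rough standalone inference sketch of the same call pattern, assuming a diffusers version that ships FluxControlPipeline, a base Flux checkpoint, and a trained control LoRA already exported in a diffusers-loadable format; the paths are hypothetical.

    # Rough inference sketch (assumptions: diffusers >= 0.32 with FluxControlPipeline,
    # and a control LoRA checkpoint that load_lora_weights can consume).
    import torch
    from diffusers import FluxControlPipeline
    from diffusers.utils import load_image

    pipe = FluxControlPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
    ).to("cuda")
    pipe.load_lora_weights("path/to/control_lora.safetensors")     # hypothetical path

    control_image = load_image("path/to/control_condition.png")    # hypothetical path
    image = pipe(
        prompt="a photo of a cat",
        control_image=control_image,
        height=1024,
        width=1024,
        guidance_scale=3.5,
        num_inference_steps=28,
    ).images[0]
    image.save("out.png")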