From 1ad58c581691e9b7c01a3a6f55fdf1ceae322e62 Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Sat, 22 Mar 2025 10:24:52 -0600
Subject: [PATCH] Changed control lora to only have new weights and leave other
 input weights alone for more flexibility when using multiple ones together.

---
 toolkit/models/control_lora_adapter.py | 225 +++++++++++++------------
 1 file changed, 117 insertions(+), 108 deletions(-)

diff --git a/toolkit/models/control_lora_adapter.py b/toolkit/models/control_lora_adapter.py
index 387400ed..532004f7 100644
--- a/toolkit/models/control_lora_adapter.py
+++ b/toolkit/models/control_lora_adapter.py
@@ -29,16 +29,17 @@ class ImgEmbedder(torch.nn.Module):
     def __init__(
         self,
         adapter: 'ControlLoraAdapter',
-        orig_layer: torch.nn.Module,
-        in_channels=128,
-        out_channels=3072,
-        bias=True
+        orig_layer: torch.nn.Linear,
+        in_channels=64,
+        out_channels=3072
     ):
         super().__init__()
+        # only do the weight for the new input. We combine with the original linear layer
+        init = torch.randn(out_channels, in_channels, device=orig_layer.weight.device, dtype=orig_layer.weight.dtype) * 0.01
+        self.weight = torch.nn.Parameter(init)
+
         self.adapter_ref: weakref.ref = weakref.ref(adapter)
         self.orig_layer_ref: weakref.ref = weakref.ref(orig_layer)
-        self.lora_A = torch.nn.Linear(in_channels, in_channels, bias=False) # lora down
-        self.lora_B = torch.nn.Linear(in_channels, out_channels, bias=bias) # lora up

     @classmethod
     def from_model(
@@ -52,27 +53,14 @@ class ImgEmbedder(torch.nn.Module):
         img_embedder = cls(
             adapter,
             orig_layer=x_embedder,
-            in_channels=x_embedder.in_features * num_channel_multiplier, # adding additional control img channels
+            in_channels=x_embedder.in_features * (num_channel_multiplier - 1), # only our new channels
             out_channels=x_embedder.out_features,
-            bias=x_embedder.bias is not None
         )
         # hijack the forward method
         x_embedder._orig_ctrl_lora_forward = x_embedder.forward
         x_embedder.forward = img_embedder.forward

-        dtype = x_embedder.weight.dtype
-        device = x_embedder.weight.device
-
-        # since we are adding control channels, we want those channels to be zero starting out
-        # so they have no effect. It will match lora_B weight and bias, and we concat 0s for the input of the new channels
-        # lora_a needs to be identity so that lora_b output matches lora_a output on init
-        img_embedder.lora_A.weight.data = torch.eye(x_embedder.in_features * num_channel_multiplier).to(dtype=torch.float32, device=device)
-        weight_b = x_embedder.weight.data.clone().to(dtype=torch.float32, device=device)
-        # concat 0s for the new channels
-        weight_b = torch.cat([weight_b, torch.zeros(weight_b.shape[0], weight_b.shape[1] * (num_channel_multiplier - 1)).to(device)], dim=1)
-        img_embedder.lora_B.weight.data = weight_b.clone().to(dtype=torch.float32)
-        img_embedder.lora_B.bias.data = x_embedder.bias.data.clone().to(dtype=torch.float32)
-
+
         # update the config of the transformer
         model.config.in_channels = model.config.in_channels * num_channel_multiplier
         model.config["in_channels"] = model.config.in_channels
@@ -89,18 +77,29 @@ class ImgEmbedder(torch.nn.Module):
     def forward(self, x):
         if not self.is_active:
             # make sure lora is not active
-            self.adapter_ref().control_lora.is_active = False
+            if self.adapter_ref().control_lora is not None:
+                self.adapter_ref().control_lora.is_active = False
             return self.orig_layer_ref()._orig_ctrl_lora_forward(x)

         # make sure lora is active
-        self.adapter_ref().control_lora.is_active = True
+        if self.adapter_ref().control_lora is not None:
+            self.adapter_ref().control_lora.is_active = True

         orig_device = x.device
         orig_dtype = x.dtype
-        x = x.to(self.lora_A.weight.device, dtype=self.lora_A.weight.dtype)
+
+        x = x.to(self.weight.device, dtype=self.weight.dtype)
+
+        orig_weight = self.orig_layer_ref().weight.data.detach()
+        orig_weight = orig_weight.to(self.weight.device, dtype=self.weight.dtype)
+        linear_weight = torch.cat([orig_weight, self.weight], dim=1)
+
+        bias = None
+        if self.orig_layer_ref().bias is not None:
+            bias = self.orig_layer_ref().bias.data.detach().to(self.weight.device, dtype=self.weight.dtype)
+
+        x = torch.nn.functional.linear(x, linear_weight, bias)

-        x = self.lora_A(x)
-        x = self.lora_B(x)
         x = x.to(orig_device, dtype=orig_dtype)
         return x

@@ -120,91 +119,97 @@ class ControlLoraAdapter(torch.nn.Module):
         self.model_config: ModelConfig = sd.model_config
         self.network_config = config.lora_config
         self.train_config = train_config
-        if self.network_config is None:
-            raise ValueError("LoRA config is missing")
-
-        network_kwargs = {} if self.network_config.network_kwargs is None else self.network_config.network_kwargs
-        if hasattr(sd, 'target_lora_modules'):
-            network_kwargs['target_lin_modules'] = self.sd.target_lora_modules
-
-        if 'ignore_if_contains' not in network_kwargs:
-            network_kwargs['ignore_if_contains'] = []
-
-        # always ignore x_embedder
-        network_kwargs['ignore_if_contains'].append('x_embedder')
-
         self.device_torch = sd.device_torch
-        self.control_lora = LoRASpecialNetwork(
-            text_encoder=sd.text_encoder,
-            unet=sd.unet,
-            lora_dim=self.network_config.linear,
-            multiplier=1.0,
-            alpha=self.network_config.linear_alpha,
-            train_unet=self.train_config.train_unet,
-            train_text_encoder=self.train_config.train_text_encoder,
-            conv_lora_dim=self.network_config.conv,
-            conv_alpha=self.network_config.conv_alpha,
-            is_sdxl=self.model_config.is_xl or self.model_config.is_ssd,
-            is_v2=self.model_config.is_v2,
-            is_v3=self.model_config.is_v3,
-            is_pixart=self.model_config.is_pixart,
-            is_auraflow=self.model_config.is_auraflow,
-            is_flux=self.model_config.is_flux,
-            is_lumina2=self.model_config.is_lumina2,
-            is_ssd=self.model_config.is_ssd,
-            is_vega=self.model_config.is_vega,
-            dropout=self.network_config.dropout,
-            use_text_encoder_1=self.model_config.use_text_encoder_1,
-            use_text_encoder_2=self.model_config.use_text_encoder_2,
-            use_bias=False,
-            is_lorm=False,
-            network_config=self.network_config,
-            network_type=self.network_config.type,
-            transformer_only=self.network_config.transformer_only,
-            is_transformer=sd.is_transformer,
-            base_model=sd,
-            **network_kwargs
-        )
-        self.control_lora.force_to(self.device_torch, dtype=torch.float32)
-        self.control_lora._update_torch_multiplier()
-        self.control_lora.apply_to(
-            sd.text_encoder,
-            sd.unet,
-            self.train_config.train_text_encoder,
-            self.train_config.train_unet
-        )
-        self.control_lora.can_merge_in = False
-        self.control_lora.prepare_grad_etc(sd.text_encoder, sd.unet)
-        if self.train_config.gradient_checkpointing:
-            self.control_lora.enable_gradient_checkpointing()
+        self.control_lora = None
+
+        if self.network_config is not None:
+
+            network_kwargs = {} if self.network_config.network_kwargs is None else self.network_config.network_kwargs
+            if hasattr(sd, 'target_lora_modules'):
+                network_kwargs['target_lin_modules'] = self.sd.target_lora_modules
+
+            if 'ignore_if_contains' not in network_kwargs:
+                network_kwargs['ignore_if_contains'] = []
+
+            # always ignore x_embedder
+            network_kwargs['ignore_if_contains'].append('x_embedder')
+
+            self.control_lora = LoRASpecialNetwork(
+                text_encoder=sd.text_encoder,
+                unet=sd.unet,
+                lora_dim=self.network_config.linear,
+                multiplier=1.0,
+                alpha=self.network_config.linear_alpha,
+                train_unet=self.train_config.train_unet,
+                train_text_encoder=self.train_config.train_text_encoder,
+                conv_lora_dim=self.network_config.conv,
+                conv_alpha=self.network_config.conv_alpha,
+                is_sdxl=self.model_config.is_xl or self.model_config.is_ssd,
+                is_v2=self.model_config.is_v2,
+                is_v3=self.model_config.is_v3,
+                is_pixart=self.model_config.is_pixart,
+                is_auraflow=self.model_config.is_auraflow,
+                is_flux=self.model_config.is_flux,
+                is_lumina2=self.model_config.is_lumina2,
+                is_ssd=self.model_config.is_ssd,
+                is_vega=self.model_config.is_vega,
+                dropout=self.network_config.dropout,
+                use_text_encoder_1=self.model_config.use_text_encoder_1,
+                use_text_encoder_2=self.model_config.use_text_encoder_2,
+                use_bias=False,
+                is_lorm=False,
+                network_config=self.network_config,
+                network_type=self.network_config.type,
+                transformer_only=self.network_config.transformer_only,
+                is_transformer=sd.is_transformer,
+                base_model=sd,
+                **network_kwargs
+            )
+            self.control_lora.force_to(self.device_torch, dtype=torch.float32)
+            self.control_lora._update_torch_multiplier()
+            self.control_lora.apply_to(
+                sd.text_encoder,
+                sd.unet,
+                self.train_config.train_text_encoder,
+                self.train_config.train_unet
+            )
+            self.control_lora.can_merge_in = False
+            self.control_lora.prepare_grad_etc(sd.text_encoder, sd.unet)
+            if self.train_config.gradient_checkpointing:
+                self.control_lora.enable_gradient_checkpointing()

         self.x_embedder = ImgEmbedder.from_model(sd.unet, self)
         self.x_embedder.to(self.device_torch)

     def get_params(self):
-        # LyCORIS doesnt have default_lr
-        config = {
-            'text_encoder_lr': self.train_config.lr,
-            'unet_lr': self.train_config.lr,
-        }
-        sig = inspect.signature(self.control_lora.prepare_optimizer_params)
-        if 'default_lr' in sig.parameters:
-            config['default_lr'] = self.train_config.lr
-        if 'learning_rate' in sig.parameters:
-            config['learning_rate'] = self.train_config.lr
-        params_net = self.control_lora.prepare_optimizer_params(
-            **config
-        )
-
-        # we want only tensors here
-        params = []
-        for p in params_net:
-            if isinstance(p, dict):
-                params += p["params"]
-            elif isinstance(p, torch.Tensor):
-                params.append(p)
-            elif isinstance(p, list):
-                params += p
+        if self.control_lora is not None:
+            config = {
+                'text_encoder_lr': self.train_config.lr,
+                'unet_lr': self.train_config.lr,
+            }
+            sig = inspect.signature(self.control_lora.prepare_optimizer_params)
+            if 'default_lr' in sig.parameters:
+                config['default_lr'] = self.train_config.lr
+            if 'learning_rate' in sig.parameters:
+                config['learning_rate'] = self.train_config.lr
+            params_net = self.control_lora.prepare_optimizer_params(
+                **config
+            )
+
+            # we want only tensors here
+            params = []
+            for p in params_net:
+                if isinstance(p, dict):
+                    params += p["params"]
+                elif isinstance(p, torch.Tensor):
+                    params.append(p)
+                elif isinstance(p, list):
+                    params += p
+        else:
+            params = []
+
+        # make sure the embedder is float32
+        self.x_embedder.to(torch.float32)

         params += list(self.x_embedder.parameters())

@@ -223,11 +228,15 @@ class ControlLoraAdapter(torch.nn.Module):
                 lora_sd[key] = value

         # todo process state dict before loading
-        self.control_lora.load_weights(lora_sd)
-        self.x_embedder.load_state_dict(img_embedder_sd, strict=strict)
+        if self.control_lora is not None:
+            self.control_lora.load_weights(lora_sd)
+        self.x_embedder.load_state_dict(img_embedder_sd, strict=False)

     def get_state_dict(self):
-        lora_sd = self.control_lora.get_state_dict(dtype=torch.float32)
+        if self.control_lora is not None:
+            lora_sd = self.control_lora.get_state_dict(dtype=torch.float32)
+        else:
+            lora_sd = {}
         # todo make sure we match loras elseware.
         img_embedder_sd = self.x_embedder.state_dict()
         for key, value in img_embedder_sd.items():
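For reference, the core of the change: ImgEmbedder no longer carries a LoRA pair spanning the whole expanded x_embedder. It only owns a weight block for the extra control-image channels and concatenates that block with the frozen original weight at forward time. Below is a minimal, self-contained sketch of that combined forward; it is not part of the patch, the dimensions are illustrative stand-ins for Flux's x_embedder (64 in, 3072 out, sequence length 4096 is arbitrary), and the variable names are chosen here for clarity rather than taken from the code.

# Sketch of the combined forward the patched ImgEmbedder performs.
# The patch reads in/out features from the hijacked layer at runtime;
# here they are hard-coded stand-ins.
import torch
import torch.nn.functional as F

in_features = 64            # original packed-latent channels
out_features = 3072         # transformer hidden size
num_channel_multiplier = 2  # latents + one control image

orig_layer = torch.nn.Linear(in_features, out_features)  # frozen base x_embedder

# the adapter only owns weights for the *new* control channels
new_in = in_features * (num_channel_multiplier - 1)
control_weight = torch.nn.Parameter(torch.randn(out_features, new_in) * 0.01)

# input already has the control channels concatenated on the last dim
x = torch.randn(1, 4096, in_features * num_channel_multiplier)

# concatenate [original | control] along the input dim and run a single linear
combined = torch.cat([orig_layer.weight.detach(), control_weight], dim=1)
out = F.linear(x, combined, orig_layer.bias.detach())
print(out.shape)  # torch.Size([1, 4096, 3072])

Because the original input columns are left untouched and only the adapter's own column block is stored and trained, several such adapters can each contribute their own extra channels to the same x_embedder, which is the flexibility the commit message refers to.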