mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-02 12:11:16 +00:00
Changed control lora to only have new weights and leave other input weights alone for more flexability of using multiple ones together.
This commit is contained in:
@@ -29,16 +29,17 @@ class ImgEmbedder(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
adapter: 'ControlLoraAdapter',
|
||||
orig_layer: torch.nn.Module,
|
||||
in_channels=128,
|
||||
out_channels=3072,
|
||||
bias=True
|
||||
orig_layer: torch.nn.Linear,
|
||||
in_channels=64,
|
||||
out_channels=3072
|
||||
):
|
||||
super().__init__()
|
||||
# only do the weight for the new input. We combine with the original linear layer
|
||||
init = torch.randn(out_channels, in_channels, device=orig_layer.weight.device, dtype=orig_layer.weight.dtype) * 0.01
|
||||
self.weight = torch.nn.Parameter(init)
|
||||
|
||||
self.adapter_ref: weakref.ref = weakref.ref(adapter)
|
||||
self.orig_layer_ref: weakref.ref = weakref.ref(orig_layer)
|
||||
self.lora_A = torch.nn.Linear(in_channels, in_channels, bias=False) # lora down
|
||||
self.lora_B = torch.nn.Linear(in_channels, out_channels, bias=bias) # lora up
|
||||
|
||||
@classmethod
|
||||
def from_model(
|
||||
@@ -52,26 +53,13 @@ class ImgEmbedder(torch.nn.Module):
|
||||
img_embedder = cls(
|
||||
adapter,
|
||||
orig_layer=x_embedder,
|
||||
in_channels=x_embedder.in_features * num_channel_multiplier, # adding additional control img channels
|
||||
in_channels=x_embedder.in_features * (num_channel_multiplier - 1), # only our new channels
|
||||
out_channels=x_embedder.out_features,
|
||||
bias=x_embedder.bias is not None
|
||||
)
|
||||
|
||||
# hijack the forward method
|
||||
x_embedder._orig_ctrl_lora_forward = x_embedder.forward
|
||||
x_embedder.forward = img_embedder.forward
|
||||
dtype = x_embedder.weight.dtype
|
||||
device = x_embedder.weight.device
|
||||
|
||||
# since we are adding control channels, we want those channels to be zero starting out
|
||||
# so they have no effect. It will match lora_B weight and bias, and we concat 0s for the input of the new channels
|
||||
# lora_a needs to be identity so that lora_b output matches lora_a output on init
|
||||
img_embedder.lora_A.weight.data = torch.eye(x_embedder.in_features * num_channel_multiplier).to(dtype=torch.float32, device=device)
|
||||
weight_b = x_embedder.weight.data.clone().to(dtype=torch.float32, device=device)
|
||||
# concat 0s for the new channels
|
||||
weight_b = torch.cat([weight_b, torch.zeros(weight_b.shape[0], weight_b.shape[1] * (num_channel_multiplier - 1)).to(device)], dim=1)
|
||||
img_embedder.lora_B.weight.data = weight_b.clone().to(dtype=torch.float32)
|
||||
img_embedder.lora_B.bias.data = x_embedder.bias.data.clone().to(dtype=torch.float32)
|
||||
|
||||
# update the config of the transformer
|
||||
model.config.in_channels = model.config.in_channels * num_channel_multiplier
|
||||
@@ -89,18 +77,29 @@ class ImgEmbedder(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
if not self.is_active:
|
||||
# make sure lora is not active
|
||||
if self.adapter_ref().control_lora is not None:
|
||||
self.adapter_ref().control_lora.is_active = False
|
||||
return self.orig_layer_ref()._orig_ctrl_lora_forward(x)
|
||||
|
||||
# make sure lora is active
|
||||
if self.adapter_ref().control_lora is not None:
|
||||
self.adapter_ref().control_lora.is_active = True
|
||||
|
||||
orig_device = x.device
|
||||
orig_dtype = x.dtype
|
||||
x = x.to(self.lora_A.weight.device, dtype=self.lora_A.weight.dtype)
|
||||
|
||||
x = self.lora_A(x)
|
||||
x = self.lora_B(x)
|
||||
x = x.to(self.weight.device, dtype=self.weight.dtype)
|
||||
|
||||
orig_weight = self.orig_layer_ref().weight.data.detach()
|
||||
orig_weight = orig_weight.to(self.weight.device, dtype=self.weight.dtype)
|
||||
linear_weight = torch.cat([orig_weight, self.weight], dim=1)
|
||||
|
||||
bias = None
|
||||
if self.orig_layer_ref().bias is not None:
|
||||
bias = self.orig_layer_ref().bias.data.detach().to(self.weight.device, dtype=self.weight.dtype)
|
||||
|
||||
x = torch.nn.functional.linear(x, linear_weight, bias)
|
||||
|
||||
x = x.to(orig_device, dtype=orig_dtype)
|
||||
return x
|
||||
|
||||
@@ -120,8 +119,10 @@ class ControlLoraAdapter(torch.nn.Module):
|
||||
self.model_config: ModelConfig = sd.model_config
|
||||
self.network_config = config.lora_config
|
||||
self.train_config = train_config
|
||||
if self.network_config is None:
|
||||
raise ValueError("LoRA config is missing")
|
||||
self.device_torch = sd.device_torch
|
||||
self.control_lora = None
|
||||
|
||||
if self.network_config is not None:
|
||||
|
||||
network_kwargs = {} if self.network_config.network_kwargs is None else self.network_config.network_kwargs
|
||||
if hasattr(sd, 'target_lora_modules'):
|
||||
@@ -133,7 +134,6 @@ class ControlLoraAdapter(torch.nn.Module):
|
||||
# always ignore x_embedder
|
||||
network_kwargs['ignore_if_contains'].append('x_embedder')
|
||||
|
||||
self.device_torch = sd.device_torch
|
||||
self.control_lora = LoRASpecialNetwork(
|
||||
text_encoder=sd.text_encoder,
|
||||
unet=sd.unet,
|
||||
@@ -182,7 +182,7 @@ class ControlLoraAdapter(torch.nn.Module):
|
||||
self.x_embedder.to(self.device_torch)
|
||||
|
||||
def get_params(self):
|
||||
# LyCORIS doesnt have default_lr
|
||||
if self.control_lora is not None:
|
||||
config = {
|
||||
'text_encoder_lr': self.train_config.lr,
|
||||
'unet_lr': self.train_config.lr,
|
||||
@@ -205,6 +205,11 @@ class ControlLoraAdapter(torch.nn.Module):
|
||||
params.append(p)
|
||||
elif isinstance(p, list):
|
||||
params += p
|
||||
else:
|
||||
params = []
|
||||
|
||||
# make sure the embedder is float32
|
||||
self.x_embedder.to(torch.float32)
|
||||
|
||||
params += list(self.x_embedder.parameters())
|
||||
|
||||
@@ -223,11 +228,15 @@ class ControlLoraAdapter(torch.nn.Module):
|
||||
lora_sd[key] = value
|
||||
|
||||
# todo process state dict before loading
|
||||
if self.control_lora is not None:
|
||||
self.control_lora.load_weights(lora_sd)
|
||||
self.x_embedder.load_state_dict(img_embedder_sd, strict=strict)
|
||||
self.x_embedder.load_state_dict(img_embedder_sd, strict=False)
|
||||
|
||||
def get_state_dict(self):
|
||||
if self.control_lora is not None:
|
||||
lora_sd = self.control_lora.get_state_dict(dtype=torch.float32)
|
||||
else:
|
||||
lora_sd = {}
|
||||
# todo make sure we match loras elseware.
|
||||
img_embedder_sd = self.x_embedder.state_dict()
|
||||
for key, value in img_embedder_sd.items():
|
||||
|
||||
Reference in New Issue
Block a user