mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-27 00:49:47 +00:00
Added ability to target parts of lora for ilora
This commit is contained in:
@@ -194,6 +194,9 @@ class AdapterConfig:
|
||||
# for ilora
|
||||
self.head_dim: int = kwargs.get('head_dim', 1024)
|
||||
self.num_heads: int = kwargs.get('num_heads', 1)
|
||||
self.ilora_down: bool = kwargs.get('ilora_down', True)
|
||||
self.ilora_mid: bool = kwargs.get('ilora_mid', True)
|
||||
self.ilora_up: bool = kwargs.get('ilora_up', True)
|
||||
|
||||
|
||||
class EmbeddingConfig:
|
||||
|
||||
@@ -148,7 +148,8 @@ class CustomAdapter(torch.nn.Module):
|
||||
vision_hidden_size=vision_hidden_size,
|
||||
head_dim=self.config.head_dim,
|
||||
num_heads=self.config.num_heads,
|
||||
sd=self.sd_ref()
|
||||
sd=self.sd_ref(),
|
||||
config=self.config
|
||||
)
|
||||
elif self.adapter_type == 'text_encoder':
|
||||
if self.config.text_encoder_arch == 't5':
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import math
|
||||
import weakref
|
||||
|
||||
from toolkit.config_modules import AdapterConfig
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import TYPE_CHECKING, List, Dict, Any
|
||||
@@ -130,14 +131,23 @@ class InstantLoRAMidModule(torch.nn.Module):
|
||||
self.index = index
|
||||
self.lora_module_ref = weakref.ref(lora_module)
|
||||
self.instant_lora_module_ref = weakref.ref(instant_lora_module)
|
||||
|
||||
self.do_up = instant_lora_module.config.ilora_up
|
||||
self.do_down = instant_lora_module.config.ilora_down
|
||||
self.do_mid = instant_lora_module.config.ilora_mid
|
||||
|
||||
self.down_dim = self.down_shape[1] if self.do_down else 0
|
||||
self.mid_dim = self.up_shape[1] if self.do_mid else 0
|
||||
self.out_dim = self.up_shape[0] if self.do_down else 0
|
||||
|
||||
self.embed = None
|
||||
|
||||
def down_forward(self, x, *args, **kwargs):
|
||||
if not self.do_down:
|
||||
return self.lora_module_ref().lora_down.orig_forward(x, *args, **kwargs)
|
||||
# get the embed
|
||||
self.embed = self.instant_lora_module_ref().img_embeds[self.index]
|
||||
in_dim = self.down_shape[1]
|
||||
down_weight = self.embed[:, :in_dim]
|
||||
down_weight = self.embed[:, :self.down_dim]
|
||||
|
||||
batch_size = x.shape[0]
|
||||
|
||||
@@ -169,41 +179,58 @@ class InstantLoRAMidModule(torch.nn.Module):
|
||||
|
||||
|
||||
def up_forward(self, x, *args, **kwargs):
|
||||
if not self.do_up and not self.do_mid:
|
||||
return self.lora_module_ref().lora_up.orig_forward(x, *args, **kwargs)
|
||||
# get the embed
|
||||
self.embed = self.instant_lora_module_ref().img_embeds[self.index]
|
||||
in_dim = self.down_shape[1]
|
||||
mid_dim = self.up_shape[1]
|
||||
out_dim = self.up_shape[0]
|
||||
mid_weight = self.embed[:, in_dim:in_dim+mid_dim]
|
||||
up_weight = self.embed[:, -out_dim:]
|
||||
if self.do_mid:
|
||||
mid_weight = self.embed[:, self.down_dim:self.down_dim+self.mid_dim]
|
||||
else:
|
||||
mid_weight = None
|
||||
if self.do_up:
|
||||
up_weight = self.embed[:, -self.out_dim:]
|
||||
else:
|
||||
up_weight = None
|
||||
|
||||
batch_size = x.shape[0]
|
||||
|
||||
# unconditional
|
||||
if up_weight.shape[0] * 2 == batch_size:
|
||||
up_weight = torch.cat([up_weight] * 2, dim=0)
|
||||
mid_weight = torch.cat([mid_weight] * 2, dim=0)
|
||||
if up_weight is not None:
|
||||
if up_weight.shape[0] * 2 == batch_size:
|
||||
up_weight = torch.cat([up_weight] * 2, dim=0)
|
||||
if mid_weight is not None:
|
||||
if mid_weight.shape[0] * 2 == batch_size:
|
||||
mid_weight = torch.cat([mid_weight] * 2, dim=0)
|
||||
|
||||
try:
|
||||
if len(x.shape) == 4:
|
||||
# conv
|
||||
up_weight = up_weight.view(batch_size, -1, 1, 1)
|
||||
mid_weight = mid_weight.view(batch_size, -1, 1, 1)
|
||||
if up_weight is not None:
|
||||
up_weight = up_weight.view(batch_size, -1, 1, 1)
|
||||
if mid_weight is not None:
|
||||
mid_weight = mid_weight.view(batch_size, -1, 1, 1)
|
||||
if x.shape[1] != mid_weight.shape[1]:
|
||||
raise ValueError(f"Up weight shape not understood: {up_weight.shape} {x.shape}")
|
||||
elif len(x.shape) == 2:
|
||||
up_weight = up_weight.view(batch_size, -1)
|
||||
mid_weight = mid_weight.view(batch_size, -1)
|
||||
if up_weight is not None:
|
||||
up_weight = up_weight.view(batch_size, -1)
|
||||
if mid_weight is not None:
|
||||
mid_weight = mid_weight.view(batch_size, -1)
|
||||
if x.shape[1] != mid_weight.shape[1]:
|
||||
raise ValueError(f"Up weight shape not understood: {up_weight.shape} {x.shape}")
|
||||
else:
|
||||
up_weight = up_weight.view(batch_size, 1, -1)
|
||||
mid_weight = mid_weight.view(batch_size, 1, -1)
|
||||
if up_weight is not None:
|
||||
up_weight = up_weight.view(batch_size, 1, -1)
|
||||
if mid_weight is not None:
|
||||
mid_weight = mid_weight.view(batch_size, 1, -1)
|
||||
if x.shape[2] != mid_weight.shape[2]:
|
||||
raise ValueError(f"Up weight shape not understood: {up_weight.shape} {x.shape}")
|
||||
# apply mid weight first
|
||||
if mid_weight is not None:
|
||||
x = x * mid_weight
|
||||
x = self.lora_module_ref().lora_up.orig_forward(x, *args, **kwargs)
|
||||
x = x * up_weight
|
||||
if up_weight is not None:
|
||||
x = x * up_weight
|
||||
except Exception as e:
|
||||
print(e)
|
||||
raise ValueError(f"Up weight shape not understood: {up_weight.shape} {x.shape}")
|
||||
@@ -220,7 +247,8 @@ class InstantLoRAModule(torch.nn.Module):
|
||||
vision_tokens: int,
|
||||
head_dim: int,
|
||||
num_heads: int, # number of heads in the resampler
|
||||
sd: 'StableDiffusion'
|
||||
sd: 'StableDiffusion',
|
||||
config: AdapterConfig
|
||||
):
|
||||
super(InstantLoRAModule, self).__init__()
|
||||
# self.linear = torch.nn.Linear(2, 1)
|
||||
@@ -230,6 +258,8 @@ class InstantLoRAModule(torch.nn.Module):
|
||||
self.vision_tokens = vision_tokens
|
||||
self.head_dim = head_dim
|
||||
self.num_heads = num_heads
|
||||
|
||||
self.config: AdapterConfig = config
|
||||
|
||||
# stores the projection vector. Grabbed by modules
|
||||
self.img_embeds: List[torch.Tensor] = None
|
||||
@@ -260,9 +290,9 @@ class InstantLoRAModule(torch.nn.Module):
|
||||
# linear weight shape is (out_features, in_features)
|
||||
|
||||
# just doing in dim and out dim
|
||||
in_dim = down_shape[1]
|
||||
mid_dim = down_shape[0]
|
||||
out_dim = up_shape[0]
|
||||
in_dim = down_shape[1] if self.config.ilora_down else 0
|
||||
mid_dim = down_shape[0] if self.config.ilora_mid else 0
|
||||
out_dim = up_shape[0] if self.config.ilora_up else 0
|
||||
module_size = in_dim + mid_dim + out_dim
|
||||
|
||||
|
||||
@@ -377,5 +407,8 @@ class InstantLoRAModule(torch.nn.Module):
|
||||
"head_dim": self.head_dim,
|
||||
"vision_tokens": self.vision_tokens,
|
||||
"output_size": self.output_size,
|
||||
"do_up": self.config.ilora_up,
|
||||
"do_mid": self.config.ilora_mid,
|
||||
"do_down": self.config.ilora_down,
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user