mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-28 18:21:16 +00:00
Added ability to target parts of lora for ilora
This commit is contained in:
@@ -194,6 +194,9 @@ class AdapterConfig:
|
|||||||
# for ilora
|
# for ilora
|
||||||
self.head_dim: int = kwargs.get('head_dim', 1024)
|
self.head_dim: int = kwargs.get('head_dim', 1024)
|
||||||
self.num_heads: int = kwargs.get('num_heads', 1)
|
self.num_heads: int = kwargs.get('num_heads', 1)
|
||||||
|
self.ilora_down: bool = kwargs.get('ilora_down', True)
|
||||||
|
self.ilora_mid: bool = kwargs.get('ilora_mid', True)
|
||||||
|
self.ilora_up: bool = kwargs.get('ilora_up', True)
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingConfig:
|
class EmbeddingConfig:
|
||||||
|
|||||||
@@ -148,7 +148,8 @@ class CustomAdapter(torch.nn.Module):
|
|||||||
vision_hidden_size=vision_hidden_size,
|
vision_hidden_size=vision_hidden_size,
|
||||||
head_dim=self.config.head_dim,
|
head_dim=self.config.head_dim,
|
||||||
num_heads=self.config.num_heads,
|
num_heads=self.config.num_heads,
|
||||||
sd=self.sd_ref()
|
sd=self.sd_ref(),
|
||||||
|
config=self.config
|
||||||
)
|
)
|
||||||
elif self.adapter_type == 'text_encoder':
|
elif self.adapter_type == 'text_encoder':
|
||||||
if self.config.text_encoder_arch == 't5':
|
if self.config.text_encoder_arch == 't5':
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
import math
|
import math
|
||||||
import weakref
|
import weakref
|
||||||
|
|
||||||
|
from toolkit.config_modules import AdapterConfig
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from typing import TYPE_CHECKING, List, Dict, Any
|
from typing import TYPE_CHECKING, List, Dict, Any
|
||||||
@@ -130,14 +131,23 @@ class InstantLoRAMidModule(torch.nn.Module):
|
|||||||
self.index = index
|
self.index = index
|
||||||
self.lora_module_ref = weakref.ref(lora_module)
|
self.lora_module_ref = weakref.ref(lora_module)
|
||||||
self.instant_lora_module_ref = weakref.ref(instant_lora_module)
|
self.instant_lora_module_ref = weakref.ref(instant_lora_module)
|
||||||
|
|
||||||
|
self.do_up = instant_lora_module.config.ilora_up
|
||||||
|
self.do_down = instant_lora_module.config.ilora_down
|
||||||
|
self.do_mid = instant_lora_module.config.ilora_mid
|
||||||
|
|
||||||
|
self.down_dim = self.down_shape[1] if self.do_down else 0
|
||||||
|
self.mid_dim = self.up_shape[1] if self.do_mid else 0
|
||||||
|
self.out_dim = self.up_shape[0] if self.do_down else 0
|
||||||
|
|
||||||
self.embed = None
|
self.embed = None
|
||||||
|
|
||||||
def down_forward(self, x, *args, **kwargs):
|
def down_forward(self, x, *args, **kwargs):
|
||||||
|
if not self.do_down:
|
||||||
|
return self.lora_module_ref().lora_down.orig_forward(x, *args, **kwargs)
|
||||||
# get the embed
|
# get the embed
|
||||||
self.embed = self.instant_lora_module_ref().img_embeds[self.index]
|
self.embed = self.instant_lora_module_ref().img_embeds[self.index]
|
||||||
in_dim = self.down_shape[1]
|
down_weight = self.embed[:, :self.down_dim]
|
||||||
down_weight = self.embed[:, :in_dim]
|
|
||||||
|
|
||||||
batch_size = x.shape[0]
|
batch_size = x.shape[0]
|
||||||
|
|
||||||
@@ -169,41 +179,58 @@ class InstantLoRAMidModule(torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
def up_forward(self, x, *args, **kwargs):
|
def up_forward(self, x, *args, **kwargs):
|
||||||
|
if not self.do_up and not self.do_mid:
|
||||||
|
return self.lora_module_ref().lora_up.orig_forward(x, *args, **kwargs)
|
||||||
# get the embed
|
# get the embed
|
||||||
self.embed = self.instant_lora_module_ref().img_embeds[self.index]
|
self.embed = self.instant_lora_module_ref().img_embeds[self.index]
|
||||||
in_dim = self.down_shape[1]
|
if self.do_mid:
|
||||||
mid_dim = self.up_shape[1]
|
mid_weight = self.embed[:, self.down_dim:self.down_dim+self.mid_dim]
|
||||||
out_dim = self.up_shape[0]
|
else:
|
||||||
mid_weight = self.embed[:, in_dim:in_dim+mid_dim]
|
mid_weight = None
|
||||||
up_weight = self.embed[:, -out_dim:]
|
if self.do_up:
|
||||||
|
up_weight = self.embed[:, -self.out_dim:]
|
||||||
|
else:
|
||||||
|
up_weight = None
|
||||||
|
|
||||||
batch_size = x.shape[0]
|
batch_size = x.shape[0]
|
||||||
|
|
||||||
# unconditional
|
# unconditional
|
||||||
if up_weight.shape[0] * 2 == batch_size:
|
if up_weight is not None:
|
||||||
up_weight = torch.cat([up_weight] * 2, dim=0)
|
if up_weight.shape[0] * 2 == batch_size:
|
||||||
mid_weight = torch.cat([mid_weight] * 2, dim=0)
|
up_weight = torch.cat([up_weight] * 2, dim=0)
|
||||||
|
if mid_weight is not None:
|
||||||
|
if mid_weight.shape[0] * 2 == batch_size:
|
||||||
|
mid_weight = torch.cat([mid_weight] * 2, dim=0)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if len(x.shape) == 4:
|
if len(x.shape) == 4:
|
||||||
# conv
|
# conv
|
||||||
up_weight = up_weight.view(batch_size, -1, 1, 1)
|
if up_weight is not None:
|
||||||
mid_weight = mid_weight.view(batch_size, -1, 1, 1)
|
up_weight = up_weight.view(batch_size, -1, 1, 1)
|
||||||
|
if mid_weight is not None:
|
||||||
|
mid_weight = mid_weight.view(batch_size, -1, 1, 1)
|
||||||
if x.shape[1] != mid_weight.shape[1]:
|
if x.shape[1] != mid_weight.shape[1]:
|
||||||
raise ValueError(f"Up weight shape not understood: {up_weight.shape} {x.shape}")
|
raise ValueError(f"Up weight shape not understood: {up_weight.shape} {x.shape}")
|
||||||
elif len(x.shape) == 2:
|
elif len(x.shape) == 2:
|
||||||
up_weight = up_weight.view(batch_size, -1)
|
if up_weight is not None:
|
||||||
mid_weight = mid_weight.view(batch_size, -1)
|
up_weight = up_weight.view(batch_size, -1)
|
||||||
|
if mid_weight is not None:
|
||||||
|
mid_weight = mid_weight.view(batch_size, -1)
|
||||||
if x.shape[1] != mid_weight.shape[1]:
|
if x.shape[1] != mid_weight.shape[1]:
|
||||||
raise ValueError(f"Up weight shape not understood: {up_weight.shape} {x.shape}")
|
raise ValueError(f"Up weight shape not understood: {up_weight.shape} {x.shape}")
|
||||||
else:
|
else:
|
||||||
up_weight = up_weight.view(batch_size, 1, -1)
|
if up_weight is not None:
|
||||||
mid_weight = mid_weight.view(batch_size, 1, -1)
|
up_weight = up_weight.view(batch_size, 1, -1)
|
||||||
|
if mid_weight is not None:
|
||||||
|
mid_weight = mid_weight.view(batch_size, 1, -1)
|
||||||
if x.shape[2] != mid_weight.shape[2]:
|
if x.shape[2] != mid_weight.shape[2]:
|
||||||
raise ValueError(f"Up weight shape not understood: {up_weight.shape} {x.shape}")
|
raise ValueError(f"Up weight shape not understood: {up_weight.shape} {x.shape}")
|
||||||
# apply mid weight first
|
# apply mid weight first
|
||||||
|
if mid_weight is not None:
|
||||||
|
x = x * mid_weight
|
||||||
x = self.lora_module_ref().lora_up.orig_forward(x, *args, **kwargs)
|
x = self.lora_module_ref().lora_up.orig_forward(x, *args, **kwargs)
|
||||||
x = x * up_weight
|
if up_weight is not None:
|
||||||
|
x = x * up_weight
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
print(e)
|
||||||
raise ValueError(f"Up weight shape not understood: {up_weight.shape} {x.shape}")
|
raise ValueError(f"Up weight shape not understood: {up_weight.shape} {x.shape}")
|
||||||
@@ -220,7 +247,8 @@ class InstantLoRAModule(torch.nn.Module):
|
|||||||
vision_tokens: int,
|
vision_tokens: int,
|
||||||
head_dim: int,
|
head_dim: int,
|
||||||
num_heads: int, # number of heads in the resampler
|
num_heads: int, # number of heads in the resampler
|
||||||
sd: 'StableDiffusion'
|
sd: 'StableDiffusion',
|
||||||
|
config: AdapterConfig
|
||||||
):
|
):
|
||||||
super(InstantLoRAModule, self).__init__()
|
super(InstantLoRAModule, self).__init__()
|
||||||
# self.linear = torch.nn.Linear(2, 1)
|
# self.linear = torch.nn.Linear(2, 1)
|
||||||
@@ -230,6 +258,8 @@ class InstantLoRAModule(torch.nn.Module):
|
|||||||
self.vision_tokens = vision_tokens
|
self.vision_tokens = vision_tokens
|
||||||
self.head_dim = head_dim
|
self.head_dim = head_dim
|
||||||
self.num_heads = num_heads
|
self.num_heads = num_heads
|
||||||
|
|
||||||
|
self.config: AdapterConfig = config
|
||||||
|
|
||||||
# stores the projection vector. Grabbed by modules
|
# stores the projection vector. Grabbed by modules
|
||||||
self.img_embeds: List[torch.Tensor] = None
|
self.img_embeds: List[torch.Tensor] = None
|
||||||
@@ -260,9 +290,9 @@ class InstantLoRAModule(torch.nn.Module):
|
|||||||
# linear weight shape is (out_features, in_features)
|
# linear weight shape is (out_features, in_features)
|
||||||
|
|
||||||
# just doing in dim and out dim
|
# just doing in dim and out dim
|
||||||
in_dim = down_shape[1]
|
in_dim = down_shape[1] if self.config.ilora_down else 0
|
||||||
mid_dim = down_shape[0]
|
mid_dim = down_shape[0] if self.config.ilora_mid else 0
|
||||||
out_dim = up_shape[0]
|
out_dim = up_shape[0] if self.config.ilora_up else 0
|
||||||
module_size = in_dim + mid_dim + out_dim
|
module_size = in_dim + mid_dim + out_dim
|
||||||
|
|
||||||
|
|
||||||
@@ -377,5 +407,8 @@ class InstantLoRAModule(torch.nn.Module):
|
|||||||
"head_dim": self.head_dim,
|
"head_dim": self.head_dim,
|
||||||
"vision_tokens": self.vision_tokens,
|
"vision_tokens": self.vision_tokens,
|
||||||
"output_size": self.output_size,
|
"output_size": self.output_size,
|
||||||
|
"do_up": self.config.ilora_up,
|
||||||
|
"do_mid": self.config.ilora_mid,
|
||||||
|
"do_down": self.config.ilora_down,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user