Added ability to target individual parts of the LoRA (down, mid, up) for ilora

Jaret Burkett
2024-07-20 22:45:52 +00:00
parent 4c249cf607
commit c2c4e8cf34
3 changed files with 59 additions and 22 deletions


@@ -194,6 +194,9 @@ class AdapterConfig:
         # for ilora
         self.head_dim: int = kwargs.get('head_dim', 1024)
         self.num_heads: int = kwargs.get('num_heads', 1)
+        self.ilora_down: bool = kwargs.get('ilora_down', True)
+        self.ilora_mid: bool = kwargs.get('ilora_mid', True)
+        self.ilora_up: bool = kwargs.get('ilora_up', True)


 class EmbeddingConfig:
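
All three new flags default to True, so existing configs keep projecting every part of the LoRA. A minimal sketch of how they might be set (assuming, as the kwargs.get calls suggest, that AdapterConfig is built from plain keyword arguments; the surrounding values are only illustrative):

config = AdapterConfig(
    head_dim=1024,
    num_heads=1,
    ilora_down=False,  # leave lora_down untouched; its original forward runs
    ilora_mid=True,    # project a per-rank mid scale
    ilora_up=True,     # project a per-output-channel up scale
)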


@@ -148,7 +148,8 @@ class CustomAdapter(torch.nn.Module):
                 vision_hidden_size=vision_hidden_size,
                 head_dim=self.config.head_dim,
                 num_heads=self.config.num_heads,
-                sd=self.sd_ref()
+                sd=self.sd_ref(),
+                config=self.config
             )
         elif self.adapter_type == 'text_encoder':
             if self.config.text_encoder_arch == 't5':


@@ -1,6 +1,7 @@
 import math
 import weakref
+from toolkit.config_modules import AdapterConfig
 import torch
 import torch.nn as nn
 from typing import TYPE_CHECKING, List, Dict, Any
@@ -130,14 +131,23 @@ class InstantLoRAMidModule(torch.nn.Module):
         self.index = index
         self.lora_module_ref = weakref.ref(lora_module)
         self.instant_lora_module_ref = weakref.ref(instant_lora_module)
+        self.do_up = instant_lora_module.config.ilora_up
+        self.do_down = instant_lora_module.config.ilora_down
+        self.do_mid = instant_lora_module.config.ilora_mid
+        self.down_dim = self.down_shape[1] if self.do_down else 0
+        self.mid_dim = self.up_shape[1] if self.do_mid else 0
+        self.out_dim = self.up_shape[0] if self.do_up else 0

         self.embed = None

     def down_forward(self, x, *args, **kwargs):
+        if not self.do_down:
+            return self.lora_module_ref().lora_down.orig_forward(x, *args, **kwargs)
         # get the embed
         self.embed = self.instant_lora_module_ref().img_embeds[self.index]
-        in_dim = self.down_shape[1]
-        down_weight = self.embed[:, :in_dim]
+        down_weight = self.embed[:, :self.down_dim]

         batch_size = x.shape[0]
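
The per-module embedding each adapted layer receives is laid out as [down | mid | up]; the down_dim / mid_dim / out_dim fields added above collapse any untargeted segment to zero width, so the slices taken here and in up_forward below still line up. A self-contained sketch with hypothetical sizes (an adapted Linear(320, 640) with LoRA rank 4):

import torch

in_features, rank, out_features = 320, 4, 640        # hypothetical adapted layer
ilora_down, ilora_mid, ilora_up = False, True, True   # lora_down left static here

down_dim = in_features if ilora_down else 0
mid_dim = rank if ilora_mid else 0
out_dim = out_features if ilora_up else 0

embed = torch.randn(1, down_dim + mid_dim + out_dim)  # what the projection head emits

down_weight = embed[:, :down_dim] if ilora_down else None
mid_weight = embed[:, down_dim:down_dim + mid_dim] if ilora_mid else None
up_weight = embed[:, -out_dim:] if ilora_up else None
print(mid_weight.shape, up_weight.shape)              # torch.Size([1, 4]) torch.Size([1, 640])
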
@@ -169,41 +179,58 @@ class InstantLoRAMidModule(torch.nn.Module):
     def up_forward(self, x, *args, **kwargs):
+        if not self.do_up and not self.do_mid:
+            return self.lora_module_ref().lora_up.orig_forward(x, *args, **kwargs)
         # get the embed
         self.embed = self.instant_lora_module_ref().img_embeds[self.index]
-        in_dim = self.down_shape[1]
-        mid_dim = self.up_shape[1]
-        out_dim = self.up_shape[0]
-        mid_weight = self.embed[:, in_dim:in_dim+mid_dim]
-        up_weight = self.embed[:, -out_dim:]
+        if self.do_mid:
+            mid_weight = self.embed[:, self.down_dim:self.down_dim+self.mid_dim]
+        else:
+            mid_weight = None
+        if self.do_up:
+            up_weight = self.embed[:, -self.out_dim:]
+        else:
+            up_weight = None

         batch_size = x.shape[0]

         # unconditional
-        if up_weight.shape[0] * 2 == batch_size:
-            up_weight = torch.cat([up_weight] * 2, dim=0)
-            mid_weight = torch.cat([mid_weight] * 2, dim=0)
+        if up_weight is not None:
+            if up_weight.shape[0] * 2 == batch_size:
+                up_weight = torch.cat([up_weight] * 2, dim=0)
+        if mid_weight is not None:
+            if mid_weight.shape[0] * 2 == batch_size:
+                mid_weight = torch.cat([mid_weight] * 2, dim=0)

         try:
             if len(x.shape) == 4:
                 # conv
-                up_weight = up_weight.view(batch_size, -1, 1, 1)
-                mid_weight = mid_weight.view(batch_size, -1, 1, 1)
+                if up_weight is not None:
+                    up_weight = up_weight.view(batch_size, -1, 1, 1)
+                if mid_weight is not None:
+                    mid_weight = mid_weight.view(batch_size, -1, 1, 1)
                 if x.shape[1] != mid_weight.shape[1]:
                     raise ValueError(f"Up weight shape not understood: {up_weight.shape} {x.shape}")
             elif len(x.shape) == 2:
-                up_weight = up_weight.view(batch_size, -1)
-                mid_weight = mid_weight.view(batch_size, -1)
+                if up_weight is not None:
+                    up_weight = up_weight.view(batch_size, -1)
+                if mid_weight is not None:
+                    mid_weight = mid_weight.view(batch_size, -1)
                 if x.shape[1] != mid_weight.shape[1]:
                     raise ValueError(f"Up weight shape not understood: {up_weight.shape} {x.shape}")
             else:
-                up_weight = up_weight.view(batch_size, 1, -1)
-                mid_weight = mid_weight.view(batch_size, 1, -1)
+                if up_weight is not None:
+                    up_weight = up_weight.view(batch_size, 1, -1)
+                if mid_weight is not None:
+                    mid_weight = mid_weight.view(batch_size, 1, -1)
                 if x.shape[2] != mid_weight.shape[2]:
                     raise ValueError(f"Up weight shape not understood: {up_weight.shape} {x.shape}")

+            # apply mid weight first
+            if mid_weight is not None:
+                x = x * mid_weight
             x = self.lora_module_ref().lora_up.orig_forward(x, *args, **kwargs)
-            x = x * up_weight
+            if up_weight is not None:
+                x = x * up_weight
         except Exception as e:
             print(e)
             raise ValueError(f"Up weight shape not understood: {up_weight.shape} {x.shape}")
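
The resulting application order in up_forward is: scale the rank activations by mid_weight, run the wrapped lora_up, then scale its output channels by up_weight, skipping whichever factor is not targeted (with both flags off, the early return above hands control back to the wrapped lora_up directly). A minimal sketch of that gating; apply_ilora_up and lora_up_fwd are illustrative names, not part of the codebase:

import torch

def apply_ilora_up(x, lora_up_fwd, mid_weight=None, up_weight=None):
    # mid scales the rank activations before the wrapped lora_up runs,
    # up scales the channels of its output; None means "not targeted"
    if mid_weight is not None:
        x = x * mid_weight
    x = lora_up_fwd(x)
    if up_weight is not None:
        x = x * up_weight
    return x

# toy check with a rank-4 LoRA up projection to 8 channels
lora_up = torch.nn.Linear(4, 8, bias=False)
x = torch.randn(2, 16, 4)   # (batch, tokens, rank)
mid = torch.randn(2, 1, 4)  # per-rank scale
up = torch.randn(2, 1, 8)   # per-output-channel scale
print(apply_ilora_up(x, lora_up, mid_weight=mid, up_weight=up).shape)  # torch.Size([2, 16, 8])
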
@@ -220,7 +247,8 @@ class InstantLoRAModule(torch.nn.Module):
             vision_tokens: int,
             head_dim: int,
             num_heads: int,  # number of heads in the resampler
-            sd: 'StableDiffusion'
+            sd: 'StableDiffusion',
+            config: AdapterConfig
     ):
         super(InstantLoRAModule, self).__init__()
         # self.linear = torch.nn.Linear(2, 1)
@@ -230,6 +258,8 @@ class InstantLoRAModule(torch.nn.Module):
         self.vision_tokens = vision_tokens
         self.head_dim = head_dim
         self.num_heads = num_heads
+        self.config: AdapterConfig = config

         # stores the projection vector. Grabbed by modules
         self.img_embeds: List[torch.Tensor] = None
@@ -260,9 +290,9 @@ class InstantLoRAModule(torch.nn.Module):
             # linear weight shape is (out_features, in_features)
             # just doing in dim and out dim
-            in_dim = down_shape[1]
-            mid_dim = down_shape[0]
-            out_dim = up_shape[0]
+            in_dim = down_shape[1] if self.config.ilora_down else 0
+            mid_dim = down_shape[0] if self.config.ilora_mid else 0
+            out_dim = up_shape[0] if self.config.ilora_up else 0

             module_size = in_dim + mid_dim + out_dim
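
As a worked example with hypothetical shapes: a rank-4 LoRA on a Linear layer with 320 inputs and 640 outputs has down_shape = (4, 320) and up_shape = (640, 4), so the projection head must emit 320 + 4 + 640 = 964 values for that module when all parts are targeted, 644 with ilora_down=False, and only 4 if just ilora_mid is left enabled.
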
@@ -377,5 +407,8 @@ class InstantLoRAModule(torch.nn.Module):
             "head_dim": self.head_dim,
             "vision_tokens": self.vision_tokens,
             "output_size": self.output_size,
+            "do_up": self.config.ilora_up,
+            "do_mid": self.config.ilora_mid,
+            "do_down": self.config.ilora_down,
         }