Add a mergeable linear to the mid of ilora

This commit is contained in:
Jaret Burkett
2024-07-20 21:17:53 -06:00
parent c51235c486
commit 6e92922c14

View File

@@ -9,8 +9,9 @@ from toolkit.models.clip_fusion import ZipperBlock
from toolkit.models.zipper_resampler import ZipperModule, ZipperResampler
import sys
from toolkit.paths import REPOS_ROOT
sys.path.append(REPOS_ROOT)
from ipadapter.ip_adapter.resampler import Resampler
from ipadapter.ip_adapter.resampler import Resampler
from collections import OrderedDict
if TYPE_CHECKING:
@@ -41,6 +42,7 @@ class MLP(nn.Module):
x = x + residual
return x
class LoRAGenerator(torch.nn.Module):
def __init__(
self,
@@ -65,7 +67,8 @@ class LoRAGenerator(torch.nn.Module):
self.lin_in = nn.Linear(input_size, hidden_size)
self.mlp_blocks = nn.Sequential(*[
MLP(hidden_size, hidden_size, hidden_size, dropout=dropout, use_residual=True) for _ in range(num_mlp_layers)
MLP(hidden_size, hidden_size, hidden_size, dropout=dropout, use_residual=True) for _ in
range(num_mlp_layers)
])
self.head = nn.Linear(hidden_size, head_size, bias=False)
self.norm = nn.LayerNorm(head_size)
@@ -131,11 +134,11 @@ class InstantLoRAMidModule(torch.nn.Module):
self.index = index
self.lora_module_ref = weakref.ref(lora_module)
self.instant_lora_module_ref = weakref.ref(instant_lora_module)
self.do_up = instant_lora_module.config.ilora_up
self.do_down = instant_lora_module.config.ilora_down
self.do_mid = instant_lora_module.config.ilora_mid
self.down_dim = self.down_shape[1] if self.do_down else 0
self.mid_dim = self.up_shape[1] if self.do_mid else 0
self.out_dim = self.up_shape[0] if self.do_up else 0
@@ -177,67 +180,74 @@ class InstantLoRAMidModule(torch.nn.Module):
return x
# NOTE(review): this span is a diff-rendered hunk — removed (old) and added
# (new) lines are interleaved with no +/- markers and all indentation was
# stripped, so several statements below appear twice (the plain
# `up_weight = ...` lines directly after their guarded
# `if up_weight is not None:` versions). Do not treat this text as runnable;
# recover the live version from the underlying commit.
# Purpose (from the visible code): a replacement forward for lora_up that
# rescales activations with per-sample "mid" and "up" weights sliced from
# self.embed, around a call to the original lora_up forward.
def up_forward(self, x, *args, **kwargs):
if not self.do_up and not self.do_mid:
# do mid here
x = self.mid_forward(x, *args, **kwargs)
if not self.do_up:
return self.lora_module_ref().lora_up.orig_forward(x, *args, **kwargs)
# get the embed
# per-module slice of the adapter image embeddings, cached on self
self.embed = self.instant_lora_module_ref().img_embeds[self.index]
if self.do_mid:
# mid weights sit between the down and up segments of the embed vector
mid_weight = self.embed[:, self.down_dim:self.down_dim+self.mid_dim]
else:
mid_weight = None
if self.do_up:
# up weights are the trailing out_dim columns of the embed vector
up_weight = self.embed[:, -self.out_dim:]
else:
up_weight = None
up_weight = self.embed[:, -self.out_dim:]
batch_size = x.shape[0]
# unconditional
# NOTE(review): when the embeds cover exactly half the batch they are
# duplicated — presumably the cond/uncond halves of CFG; TODO confirm.
if up_weight is not None:
if up_weight.shape[0] * 2 == batch_size:
up_weight = torch.cat([up_weight] * 2, dim=0)
if mid_weight is not None:
if mid_weight.shape[0] * 2 == batch_size:
mid_weight = torch.cat([mid_weight] * 2, dim=0)
if up_weight.shape[0] * 2 == batch_size:
up_weight = torch.cat([up_weight] * 2, dim=0)
try:
if len(x.shape) == 4:
# conv
# rank-4 activations: broadcast weights as (batch, channels, 1, 1)
if up_weight is not None:
up_weight = up_weight.view(batch_size, -1, 1, 1)
if mid_weight is not None:
mid_weight = mid_weight.view(batch_size, -1, 1, 1)
if x.shape[1] != mid_weight.shape[1]:
raise ValueError(f"Up weight shape not understood: {up_weight.shape} {x.shape}")
up_weight = up_weight.view(batch_size, -1, 1, 1)
elif len(x.shape) == 2:
if up_weight is not None:
up_weight = up_weight.view(batch_size, -1)
if mid_weight is not None:
mid_weight = mid_weight.view(batch_size, -1)
if x.shape[1] != mid_weight.shape[1]:
raise ValueError(f"Up weight shape not understood: {up_weight.shape} {x.shape}")
up_weight = up_weight.view(batch_size, -1)
else:
# fallback: rank-3 activations — presumably (batch, tokens, dim);
# TODO confirm against callers
if up_weight is not None:
up_weight = up_weight.view(batch_size, 1, -1)
if mid_weight is not None:
mid_weight = mid_weight.view(batch_size, 1, -1)
if x.shape[2] != mid_weight.shape[2]:
raise ValueError(f"Up weight shape not understood: {up_weight.shape} {x.shape}")
# apply mid weight first
if mid_weight is not None:
x = x * mid_weight
up_weight = up_weight.view(batch_size, 1, -1)
x = self.lora_module_ref().lora_up.orig_forward(x, *args, **kwargs)
if up_weight is not None:
x = x * up_weight
x = x * up_weight
except Exception as e:
# NOTE(review): broad catch re-raised as ValueError; the original
# exception is only printed, not chained — consider `raise ... from e`.
print(e)
raise ValueError(f"Up weight shape not understood: {up_weight.shape} {x.shape}")
return x
def mid_forward(self, x, *args, **kwargs):
    """Apply a per-sample square "mid" matrix between lora_down and lora_up.

    When ``ilora_mid`` is disabled this simply defers to the original
    lora_down forward. Otherwise each sample in the batch receives its own
    ``(mid_dim, mid_dim)`` weight, sliced from this module's portion of the
    image embedding, applied as a 1x1 convolution for rank-4 (conv)
    activations or as a plain matmul otherwise.
    """
    if not self.do_mid:
        return self.lora_module_ref().lora_down.orig_forward(x, *args, **kwargs)

    batch_size = x.shape[0]
    # Cache this module's slice of the adapter image embeddings on self.
    self.embed = self.instant_lora_module_ref().img_embeds[self.index]
    # The mid block occupies mid_dim**2 columns right after the down segment.
    mid_weight = self.embed[:, self.down_dim:self.down_dim + self.mid_dim * self.mid_dim]

    # The embeds may cover only half the batch (the unconditional half is
    # missing); duplicate them so every sample has a weight.
    if mid_weight.shape[0] * 2 == batch_size:
        mid_weight = torch.cat([mid_weight] * 2, dim=0)

    per_sample_weights = torch.chunk(mid_weight, batch_size, dim=0)
    per_sample_inputs = torch.chunk(x, batch_size, dim=0)

    results = []
    for idx in range(batch_size):
        weight = per_sample_weights[idx]
        sample = per_sample_inputs[idx]
        if len(sample.shape) == 4:
            # Conv activations: weight becomes (out, in, kH, kW). The padding
            # branch only fires for a 3-wide kernel — kept for parity with the
            # original, though the 1x1 view here never produces one.
            weight = weight.view(self.mid_dim, self.mid_dim, 1, 1)
            pad = 1 if weight.shape[-1] == 3 else 0
            sample = nn.functional.conv2d(sample, weight, padding=pad)
        else:
            # Linear path: matmul against the transposed square weight.
            weight = weight.view(self.mid_dim, self.mid_dim)
            sample = sample @ weight.T
        results.append(sample)

    return torch.cat(results, dim=0)
class InstantLoRAModule(torch.nn.Module):
@@ -246,7 +256,7 @@ class InstantLoRAModule(torch.nn.Module):
vision_hidden_size: int,
vision_tokens: int,
head_dim: int,
num_heads: int, # number of heads in the resampler
num_heads: int, # number of heads in the resampler
sd: 'StableDiffusion',
config: AdapterConfig
):
@@ -258,7 +268,7 @@ class InstantLoRAModule(torch.nn.Module):
self.vision_tokens = vision_tokens
self.head_dim = head_dim
self.num_heads = num_heads
self.config: AdapterConfig = config
# stores the projection vector. Grabbed by modules
@@ -291,11 +301,10 @@ class InstantLoRAModule(torch.nn.Module):
# just doing in dim and out dim
in_dim = down_shape[1] if self.config.ilora_down else 0
mid_dim = down_shape[0] if self.config.ilora_mid else 0
mid_dim = down_shape[0] * down_shape[0] if self.config.ilora_mid else 0
out_dim = up_shape[0] if self.config.ilora_up else 0
module_size = in_dim + mid_dim + out_dim
output_size += module_size
self.embed_lengths.append(module_size)
@@ -317,7 +326,6 @@ class InstantLoRAModule(torch.nn.Module):
lora_module.lora_up.orig_forward = lora_module.lora_up.forward
lora_module.lora_up.forward = instant_module.up_forward
self.output_size = output_size
number_formatted_output_size = "{:,}".format(output_size)
@@ -377,7 +385,6 @@ class InstantLoRAModule(torch.nn.Module):
# print("No keymap found. Using default names")
# return
def forward(self, img_embeds):
# expand token rank if only rank 2
if len(img_embeds.shape) == 2:
@@ -394,10 +401,9 @@ class InstantLoRAModule(torch.nn.Module):
# get all the slices
start = 0
for length in self.embed_lengths:
self.img_embeds.append(img_embeds[:, start:start+length])
self.img_embeds.append(img_embeds[:, start:start + length])
start += length
def get_additional_save_metadata(self) -> Dict[str, Any]:
# save the weight mapping
return {
@@ -411,4 +417,3 @@ class InstantLoRAModule(torch.nn.Module):
"do_mid": self.config.ilora_mid,
"do_down": self.config.ilora_down,
}