mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-02-23 13:53:57 +00:00
Reworked ilora arch
This commit is contained in:
381
toolkit/models/ilora2.py
Normal file
381
toolkit/models/ilora2.py
Normal file
@@ -0,0 +1,381 @@
|
||||
import math
|
||||
import weakref
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import TYPE_CHECKING, List, Dict, Any
|
||||
from toolkit.models.clip_fusion import ZipperBlock
|
||||
from toolkit.models.zipper_resampler import ZipperModule, ZipperResampler
|
||||
import sys
|
||||
from toolkit.paths import REPOS_ROOT
|
||||
sys.path.append(REPOS_ROOT)
|
||||
from ipadapter.ip_adapter.resampler import Resampler
|
||||
from collections import OrderedDict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from toolkit.lora_special import LoRAModule
|
||||
from toolkit.stable_diffusion_model import StableDiffusion
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, in_dim, out_dim, hidden_dim, dropout=0.1, use_residual=True):
|
||||
super().__init__()
|
||||
if use_residual:
|
||||
assert in_dim == out_dim
|
||||
self.layernorm = nn.LayerNorm(in_dim)
|
||||
self.fc1 = nn.Linear(in_dim, hidden_dim)
|
||||
self.fc2 = nn.Linear(hidden_dim, out_dim)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.use_residual = use_residual
|
||||
self.act_fn = nn.GELU()
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
x = self.layernorm(x)
|
||||
x = self.fc1(x)
|
||||
x = self.act_fn(x)
|
||||
x = self.fc2(x)
|
||||
x = self.dropout(x)
|
||||
if self.use_residual:
|
||||
x = x + residual
|
||||
return x
|
||||
|
||||
class LoRAGenerator(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int = 768, # projection dimension
|
||||
hidden_size: int = 768,
|
||||
head_size: int = 512,
|
||||
num_heads: int = 1,
|
||||
num_mlp_layers: int = 1,
|
||||
output_size: int = 768,
|
||||
dropout: float = 0.0
|
||||
):
|
||||
super().__init__()
|
||||
self.input_size = input_size
|
||||
self.num_heads = num_heads
|
||||
self.simple = False
|
||||
|
||||
self.output_size = output_size
|
||||
|
||||
if self.simple:
|
||||
self.head = nn.Linear(input_size, head_size, bias=False)
|
||||
else:
|
||||
self.lin_in = nn.Linear(input_size, hidden_size)
|
||||
|
||||
self.mlp_blocks = nn.Sequential(*[
|
||||
MLP(hidden_size, hidden_size, hidden_size, dropout=dropout, use_residual=True) for _ in range(num_mlp_layers)
|
||||
])
|
||||
self.head = nn.Linear(hidden_size, head_size, bias=False)
|
||||
self.norm = nn.LayerNorm(head_size)
|
||||
|
||||
if num_heads == 1:
|
||||
self.output = nn.Linear(head_size, self.output_size)
|
||||
# for each output block. multiply weights by 0.01
|
||||
with torch.no_grad():
|
||||
self.output.weight.data *= 0.01
|
||||
else:
|
||||
head_output_size = output_size // num_heads
|
||||
self.outputs = nn.ModuleList([nn.Linear(head_size, head_output_size) for _ in range(num_heads)])
|
||||
# for each output block. multiply weights by 0.01
|
||||
with torch.no_grad():
|
||||
for output in self.outputs:
|
||||
output.weight.data *= 0.01
|
||||
|
||||
# allow get device
|
||||
@property
|
||||
def device(self):
|
||||
return next(self.parameters()).device
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return next(self.parameters()).dtype
|
||||
|
||||
def forward(self, embedding):
|
||||
if len(embedding.shape) == 2:
|
||||
embedding = embedding.unsqueeze(1)
|
||||
|
||||
x = embedding
|
||||
|
||||
if not self.simple:
|
||||
x = self.lin_in(embedding)
|
||||
x = self.mlp_blocks(x)
|
||||
x = self.head(x)
|
||||
x = self.norm(x)
|
||||
|
||||
if self.num_heads == 1:
|
||||
x = self.output(x)
|
||||
else:
|
||||
out_chunks = torch.chunk(x, self.num_heads, dim=1)
|
||||
x = []
|
||||
for out_layer, chunk in zip(self.outputs, out_chunks):
|
||||
x.append(out_layer(chunk))
|
||||
x = torch.cat(x, dim=-1)
|
||||
|
||||
return x.squeeze(1)
|
||||
|
||||
|
||||
class InstantLoRAMidModule(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
index: int,
|
||||
lora_module: 'LoRAModule',
|
||||
instant_lora_module: 'InstantLoRAModule',
|
||||
up_shape: list = None,
|
||||
down_shape: list = None,
|
||||
):
|
||||
super(InstantLoRAMidModule, self).__init__()
|
||||
self.up_shape = up_shape
|
||||
self.down_shape = down_shape
|
||||
self.index = index
|
||||
self.lora_module_ref = weakref.ref(lora_module)
|
||||
self.instant_lora_module_ref = weakref.ref(instant_lora_module)
|
||||
|
||||
self.embed = None
|
||||
|
||||
def down_forward(self, x, *args, **kwargs):
|
||||
# get the embed
|
||||
self.embed = self.instant_lora_module_ref().img_embeds[self.index]
|
||||
in_dim = self.down_shape[1]
|
||||
down_weight = self.embed[:, :in_dim]
|
||||
|
||||
batch_size = x.shape[0]
|
||||
|
||||
# unconditional
|
||||
if down_weight.shape[0] * 2 == batch_size:
|
||||
down_weight = torch.cat([down_weight] * 2, dim=0)
|
||||
|
||||
try:
|
||||
if len(x.shape) == 4:
|
||||
# conv
|
||||
down_weight = down_weight.view(batch_size, -1, 1, 1)
|
||||
if x.shape[1] != down_weight.shape[1]:
|
||||
raise ValueError(f"Down weight shape not understood: {down_weight.shape} {x.shape}")
|
||||
elif len(x.shape) == 2:
|
||||
down_weight = down_weight.view(batch_size, -1)
|
||||
if x.shape[1] != down_weight.shape[1]:
|
||||
raise ValueError(f"Down weight shape not understood: {down_weight.shape} {x.shape}")
|
||||
else:
|
||||
down_weight = down_weight.view(batch_size, 1, -1)
|
||||
if x.shape[2] != down_weight.shape[2]:
|
||||
raise ValueError(f"Down weight shape not understood: {down_weight.shape} {x.shape}")
|
||||
x = x * down_weight
|
||||
x = self.lora_module_ref().lora_down.orig_forward(x, *args, **kwargs)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
raise ValueError(f"Down weight shape not understood: {down_weight.shape} {x.shape}")
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def up_forward(self, x, *args, **kwargs):
|
||||
# get the embed
|
||||
self.embed = self.instant_lora_module_ref().img_embeds[self.index]
|
||||
in_dim = self.down_shape[1]
|
||||
mid_dim = self.up_shape[1]
|
||||
out_dim = self.up_shape[0]
|
||||
mid_weight = self.embed[:, in_dim:in_dim+mid_dim]
|
||||
up_weight = self.embed[:, -out_dim:]
|
||||
|
||||
batch_size = x.shape[0]
|
||||
|
||||
# unconditional
|
||||
if up_weight.shape[0] * 2 == batch_size:
|
||||
up_weight = torch.cat([up_weight] * 2, dim=0)
|
||||
mid_weight = torch.cat([mid_weight] * 2, dim=0)
|
||||
|
||||
try:
|
||||
if len(x.shape) == 4:
|
||||
# conv
|
||||
up_weight = up_weight.view(batch_size, -1, 1, 1)
|
||||
mid_weight = mid_weight.view(batch_size, -1, 1, 1)
|
||||
if x.shape[1] != mid_weight.shape[1]:
|
||||
raise ValueError(f"Up weight shape not understood: {up_weight.shape} {x.shape}")
|
||||
elif len(x.shape) == 2:
|
||||
up_weight = up_weight.view(batch_size, -1)
|
||||
mid_weight = mid_weight.view(batch_size, -1)
|
||||
if x.shape[1] != mid_weight.shape[1]:
|
||||
raise ValueError(f"Up weight shape not understood: {up_weight.shape} {x.shape}")
|
||||
else:
|
||||
up_weight = up_weight.view(batch_size, 1, -1)
|
||||
mid_weight = mid_weight.view(batch_size, 1, -1)
|
||||
if x.shape[2] != mid_weight.shape[2]:
|
||||
raise ValueError(f"Up weight shape not understood: {up_weight.shape} {x.shape}")
|
||||
# apply mid weight first
|
||||
x = self.lora_module_ref().lora_up.orig_forward(x, *args, **kwargs)
|
||||
x = x * up_weight
|
||||
except Exception as e:
|
||||
print(e)
|
||||
raise ValueError(f"Up weight shape not understood: {up_weight.shape} {x.shape}")
|
||||
|
||||
return x
|
||||
|
||||
|
||||
|
||||
|
||||
class InstantLoRAModule(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
vision_hidden_size: int,
|
||||
vision_tokens: int,
|
||||
head_dim: int,
|
||||
num_heads: int, # number of heads in the resampler
|
||||
sd: 'StableDiffusion'
|
||||
):
|
||||
super(InstantLoRAModule, self).__init__()
|
||||
# self.linear = torch.nn.Linear(2, 1)
|
||||
self.sd_ref = weakref.ref(sd)
|
||||
self.dim = sd.network.lora_dim
|
||||
self.vision_hidden_size = vision_hidden_size
|
||||
self.vision_tokens = vision_tokens
|
||||
self.head_dim = head_dim
|
||||
self.num_heads = num_heads
|
||||
|
||||
# stores the projection vector. Grabbed by modules
|
||||
self.img_embeds: List[torch.Tensor] = None
|
||||
|
||||
# disable merging in. It is slower on inference
|
||||
self.sd_ref().network.can_merge_in = False
|
||||
|
||||
self.ilora_modules = torch.nn.ModuleList()
|
||||
|
||||
lora_modules = self.sd_ref().network.get_all_modules()
|
||||
|
||||
output_size = 0
|
||||
|
||||
self.embed_lengths = []
|
||||
self.weight_mapping = []
|
||||
|
||||
for idx, lora_module in enumerate(lora_modules):
|
||||
module_dict = lora_module.state_dict()
|
||||
down_shape = list(module_dict['lora_down.weight'].shape)
|
||||
up_shape = list(module_dict['lora_up.weight'].shape)
|
||||
|
||||
self.weight_mapping.append([lora_module.lora_name, [down_shape, up_shape]])
|
||||
|
||||
#
|
||||
# module_size = math.prod(down_shape) + math.prod(up_shape)
|
||||
|
||||
# conv weight shape is (out_channels, in_channels, kernel_size, kernel_size)
|
||||
# linear weight shape is (out_features, in_features)
|
||||
|
||||
# just doing in dim and out dim
|
||||
in_dim = down_shape[1]
|
||||
mid_dim = down_shape[0]
|
||||
out_dim = up_shape[0]
|
||||
module_size = in_dim + mid_dim + out_dim
|
||||
|
||||
|
||||
output_size += module_size
|
||||
self.embed_lengths.append(module_size)
|
||||
|
||||
# add a new mid module that will take the original forward and add a vector to it
|
||||
# this will be used to add the vector to the original forward
|
||||
instant_module = InstantLoRAMidModule(
|
||||
idx,
|
||||
lora_module,
|
||||
self,
|
||||
up_shape=up_shape,
|
||||
down_shape=down_shape
|
||||
)
|
||||
|
||||
self.ilora_modules.append(instant_module)
|
||||
|
||||
# replace the LoRA forwards
|
||||
lora_module.lora_down.orig_forward = lora_module.lora_down.forward
|
||||
lora_module.lora_down.forward = instant_module.down_forward
|
||||
lora_module.lora_up.orig_forward = lora_module.lora_up.forward
|
||||
lora_module.lora_up.forward = instant_module.up_forward
|
||||
|
||||
|
||||
self.output_size = output_size
|
||||
|
||||
number_formatted_output_size = "{:,}".format(output_size)
|
||||
|
||||
print(f" ILORA output size: {number_formatted_output_size}")
|
||||
|
||||
# if not evenly divisible, error
|
||||
if self.output_size % self.num_heads != 0:
|
||||
raise ValueError("Output size must be divisible by the number of heads")
|
||||
|
||||
self.head_output_size = self.output_size // self.num_heads
|
||||
|
||||
if vision_tokens > 1:
|
||||
self.resampler = Resampler(
|
||||
dim=vision_hidden_size,
|
||||
depth=4,
|
||||
dim_head=64,
|
||||
heads=12,
|
||||
num_queries=num_heads, # output tokens
|
||||
embedding_dim=vision_hidden_size,
|
||||
max_seq_len=vision_tokens,
|
||||
output_dim=head_dim,
|
||||
apply_pos_emb=True, # this is new
|
||||
ff_mult=4
|
||||
)
|
||||
|
||||
self.proj_module = LoRAGenerator(
|
||||
input_size=head_dim,
|
||||
hidden_size=head_dim,
|
||||
head_size=head_dim,
|
||||
num_mlp_layers=1,
|
||||
num_heads=self.num_heads,
|
||||
output_size=self.output_size,
|
||||
)
|
||||
|
||||
self.migrate_weight_mapping()
|
||||
|
||||
def migrate_weight_mapping(self):
|
||||
return
|
||||
# # changes the names of the modules to common ones
|
||||
# keymap = self.sd_ref().network.get_keymap()
|
||||
# save_keymap = {}
|
||||
# if keymap is not None:
|
||||
# for ldm_key, diffusers_key in keymap.items():
|
||||
# # invert them
|
||||
# save_keymap[diffusers_key] = ldm_key
|
||||
#
|
||||
# new_keymap = {}
|
||||
# for key, value in self.weight_mapping:
|
||||
# if key in save_keymap:
|
||||
# new_keymap[save_keymap[key]] = value
|
||||
# else:
|
||||
# print(f"Key {key} not found in keymap")
|
||||
# new_keymap[key] = value
|
||||
# self.weight_mapping = new_keymap
|
||||
# else:
|
||||
# print("No keymap found. Using default names")
|
||||
# return
|
||||
|
||||
|
||||
def forward(self, img_embeds):
|
||||
# expand token rank if only rank 2
|
||||
if len(img_embeds.shape) == 2:
|
||||
img_embeds = img_embeds.unsqueeze(1)
|
||||
|
||||
# resample the image embeddings
|
||||
img_embeds = self.resampler(img_embeds)
|
||||
img_embeds = self.proj_module(img_embeds)
|
||||
if len(img_embeds.shape) == 3:
|
||||
# merge the heads
|
||||
img_embeds = img_embeds.mean(dim=1)
|
||||
|
||||
self.img_embeds = []
|
||||
# get all the slices
|
||||
start = 0
|
||||
for length in self.embed_lengths:
|
||||
self.img_embeds.append(img_embeds[:, start:start+length])
|
||||
start += length
|
||||
|
||||
|
||||
def get_additional_save_metadata(self) -> Dict[str, Any]:
|
||||
# save the weight mapping
|
||||
return {
|
||||
"weight_mapping": self.weight_mapping,
|
||||
"num_heads": self.num_heads,
|
||||
"vision_hidden_size": self.vision_hidden_size,
|
||||
"head_dim": self.head_dim,
|
||||
"vision_tokens": self.vision_tokens,
|
||||
"output_size": self.output_size,
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user