import math
import weakref

import torch
import torch.nn as nn
from typing import TYPE_CHECKING, List, Dict, Any

from toolkit.models.clip_fusion import ZipperBlock
from toolkit.models.zipper_resampler import ZipperModule, ZipperResampler
import sys
from collections import OrderedDict

if TYPE_CHECKING:
    from toolkit.lora_special import LoRAModule
    from toolkit.stable_diffusion_model import StableDiffusion


class TransformerBlock(nn.Module):
    def __init__(self, d_model, nhead, dim_feedforward):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.ReLU(),
            nn.Linear(dim_feedforward, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

    def forward(self, x, cross_attn_input):
        # Self-attention
        attn_output, _ = self.self_attn(x, x, x)
        x = self.norm1(x + attn_output)

        # Cross-attention
        cross_attn_output, _ = self.cross_attn(x, cross_attn_input, cross_attn_input)
        x = self.norm2(x + cross_attn_output)

        # Feed-forward
        ff_output = self.feed_forward(x)
        x = self.norm3(x + ff_output)

        return x

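
# Shape sketch for TransformerBlock (illustrative values only, not used at runtime):
#
#   block = TransformerBlock(d_model=1024, nhead=16, dim_feedforward=4096)
#   x = torch.randn(2, 1, 1024)       # query tokens: (batch, seq, d_model)
#   ctx = torch.randn(2, 77, 1024)    # cross-attention context: (batch, seq, d_model)
#   out = block(x, ctx)               # -> (2, 1, 1024), same shape as x
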
class InstantLoRAMidModule(torch.nn.Module):
    def __init__(
            self,
            index: int,
            lora_module: 'LoRAModule',
            instant_lora_module: 'InstantLoRAModule',
            up_shape: list = None,
            down_shape: list = None,
    ):
        super(InstantLoRAMidModule, self).__init__()
        self.up_shape = up_shape
        self.down_shape = down_shape
        self.index = index
        self.lora_module_ref = weakref.ref(lora_module)
        self.instant_lora_module_ref = weakref.ref(instant_lora_module)

        self.embed = None

    def down_forward(self, x, *args, **kwargs):
        # get the embed
        self.embed = self.instant_lora_module_ref().img_embeds[self.index]
        down_size = math.prod(self.down_shape)
        down_weight = self.embed[:, :down_size]

        batch_size = x.shape[0]

        # unconditional
        if down_weight.shape[0] * 2 == batch_size:
            down_weight = torch.cat([down_weight] * 2, dim=0)

        weight_chunks = torch.chunk(down_weight, batch_size, dim=0)
        x_chunks = torch.chunk(x, batch_size, dim=0)

        x_out = []
        for i in range(batch_size):
            weight_chunk = weight_chunks[i]
            x_chunk = x_chunks[i]
            # reshape
            weight_chunk = weight_chunk.view(self.down_shape)
            # check if is conv or linear
            if len(weight_chunk.shape) == 4:
                padding = 0
                if weight_chunk.shape[-1] == 3:
                    padding = 1
                x_chunk = nn.functional.conv2d(x_chunk, weight_chunk, padding=padding)
            else:
                # run a simple linear layer with the down weight
                x_chunk = x_chunk @ weight_chunk.T
            x_out.append(x_chunk)
        x = torch.cat(x_out, dim=0)
        return x

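    # NOTE (assumption): the `shape[0] * 2 == batch_size` check above appears to
    # handle classifier-free guidance, where x stacks the unconditional and
    # conditional halves in one batch while only one image embedding was produced
    # per sample; duplicating the generated weight keeps both halves aligned.
    # E.g. with an embed batch of 1 and x of shape (2, seq, in_features), the
    # single generated down weight is repeated so each half uses the same projection.
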
    def up_forward(self, x, *args, **kwargs):
        self.embed = self.instant_lora_module_ref().img_embeds[self.index]
        up_size = math.prod(self.up_shape)
        up_weight = self.embed[:, -up_size:]

        batch_size = x.shape[0]

        # unconditional
        if up_weight.shape[0] * 2 == batch_size:
            up_weight = torch.cat([up_weight] * 2, dim=0)

        weight_chunks = torch.chunk(up_weight, batch_size, dim=0)
        x_chunks = torch.chunk(x, batch_size, dim=0)

        x_out = []
        for i in range(batch_size):
            weight_chunk = weight_chunks[i]
            x_chunk = x_chunks[i]
            # reshape
            weight_chunk = weight_chunk.view(self.up_shape)
            # check if is conv or linear
            if len(weight_chunk.shape) == 4:
                padding = 0
                if weight_chunk.shape[-1] == 3:
                    padding = 1
                x_chunk = nn.functional.conv2d(x_chunk, weight_chunk, padding=padding)
            else:
                # run a simple linear layer with the up weight
                x_chunk = x_chunk @ weight_chunk.T
            x_out.append(x_chunk)
        x = torch.cat(x_out, dim=0)
        return x

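
# Layout note (worked example, not actual runtime values): each module's slice of
# the generated embedding holds prod(down_shape) + prod(up_shape) numbers; the
# first prod(down_shape) values are viewed as the down weight and the last
# prod(up_shape) values as the up weight. For a linear LoRA layer with
# down_shape = [4, 320] and up_shape = [320, 4], that is 1280 + 1280 = 2560
# values per sample for this module.
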
# Initialize the network
# num_blocks = 8
# d_model = 1024  # Adjust as needed
# nhead = 16  # Adjust as needed
# dim_feedforward = 4096  # Adjust as needed
# latent_dim = 1695744


class LoRAFormer(torch.nn.Module):
    def __init__(
            self,
            num_blocks,
            d_model=1024,
            nhead=16,
            dim_feedforward=4096,
            sd: 'StableDiffusion' = None,
    ):
        super(LoRAFormer, self).__init__()
        # self.linear = torch.nn.Linear(2, 1)
        self.sd_ref = weakref.ref(sd)
        self.dim = sd.network.lora_dim

        # stores the projected weight vectors; grabbed by the mid modules
        self.img_embeds: List[torch.Tensor] = None

        # disable merging in. It is slower on inference
        self.sd_ref().network.can_merge_in = False

        self.ilora_modules = torch.nn.ModuleList()

        lora_modules = self.sd_ref().network.get_all_modules()

        output_size = 0

        self.embed_lengths = []
        self.weight_mapping = []

        for idx, lora_module in enumerate(lora_modules):
            module_dict = lora_module.state_dict()
            down_shape = list(module_dict['lora_down.weight'].shape)
            up_shape = list(module_dict['lora_up.weight'].shape)

            self.weight_mapping.append([lora_module.lora_name, [down_shape, up_shape]])

            module_size = math.prod(down_shape) + math.prod(up_shape)
            output_size += module_size
            self.embed_lengths.append(module_size)

            # add a mid module that generates this layer's LoRA weights from the
            # image embedding instead of using the stored LoRA weights
            instant_module = InstantLoRAMidModule(
                idx,
                lora_module,
                self,
                up_shape=up_shape,
                down_shape=down_shape
            )

            self.ilora_modules.append(instant_module)

            # replace the LoRA forwards
            lora_module.lora_down.forward = instant_module.down_forward
            lora_module.lora_up.forward = instant_module.up_forward

        self.output_size = output_size

        self.latent = nn.Parameter(torch.randn(1, output_size))
        self.latent_proj = nn.Linear(output_size, d_model)
        self.blocks = nn.ModuleList([
            TransformerBlock(d_model, nhead, dim_feedforward)
            for _ in range(num_blocks)
        ])
        self.final_proj = nn.Linear(d_model, output_size)

        self.migrate_weight_mapping()

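    # Rough sizing note (illustrative, using the commented example values above):
    # with d_model = 1024 and output_size = 1695744, latent_proj and final_proj
    # are Linear(1695744, 1024) and Linear(1024, 1695744), i.e. roughly 1.74e9
    # weights each, which dominates the transformer blocks. The real output_size
    # depends on the attached network's module shapes.
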
    def migrate_weight_mapping(self):
        return
        # # changes the names of the modules to common ones
        # keymap = self.sd_ref().network.get_keymap()
        # save_keymap = {}
        # if keymap is not None:
        #     for ldm_key, diffusers_key in keymap.items():
        #         # invert them
        #         save_keymap[diffusers_key] = ldm_key
        #
        #     new_keymap = {}
        #     for key, value in self.weight_mapping:
        #         if key in save_keymap:
        #             new_keymap[save_keymap[key]] = value
        #         else:
        #             print(f"Key {key} not found in keymap")
        #             new_keymap[key] = value
        #     self.weight_mapping = new_keymap
        # else:
        #     print("No keymap found. Using default names")
        #     return

    def forward(self, img_embeds):
        # expand token rank if only rank 2
        if len(img_embeds.shape) == 2:
            img_embeds = img_embeds.unsqueeze(1)

        batch_size = img_embeds.shape[0]

        # project the learned latent to the transformer width and repeat it per batch item
        x = self.latent_proj(self.latent)
        x = x.unsqueeze(0).repeat(batch_size, 1, 1)

        # run the transformer blocks, cross-attending to the image embeddings
        # (img_embeds is expected to already have a hidden size of d_model)
        for block in self.blocks:
            x = block(x, img_embeds)

        # project back to the flattened LoRA weight vector
        img_embeds = self.final_proj(x)

        if len(img_embeds.shape) == 3:
            # merge the heads
            img_embeds = img_embeds.mean(dim=1)

        self.img_embeds = []
        # get all the slices
        start = 0
        for length in self.embed_lengths:
            self.img_embeds.append(img_embeds[:, start:start + length])
            start += length

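    # Illustrative call order (the `vision_encoder` name is hypothetical and not
    # part of this module); the image embedding must already be d_model wide
    # because the blocks cross-attend to it directly:
    #
    #   img_embeds = vision_encoder(image)   # (batch, tokens, d_model)
    #   lora_former(img_embeds)              # fills self.img_embeds with per-module slices
    #   # ... then run the usual diffusion forward pass; the patched
    #   # lora_down / lora_up forwards read their slices from self.img_embeds
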
    def get_additional_save_metadata(self) -> Dict[str, Any]:
        # save the weight mapping
        return {
            "weight_mapping": self.weight_mapping,
            "output_size": self.output_size,
            # num_heads / vision_hidden_size / head_dim / vision_tokens are not
            # defined on LoRAFormer, so they are not saved here
        }
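

# Example construction (sketch; assumes `sd` is a loaded StableDiffusion with a
# LoRA network attached, mirroring the commented defaults above):
#
#   lora_former = LoRAFormer(
#       num_blocks=8,
#       d_model=1024,
#       nhead=16,
#       dim_feedforward=4096,
#       sd=sd,
#   )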