# ai-toolkit/toolkit/models/ilora.py
import weakref
import torch
import torch.nn as nn
from typing import TYPE_CHECKING
from toolkit.models.clip_fusion import ZipperBlock
from toolkit.models.zipper_resampler import ZipperModule, ZipperResampler
import sys
from toolkit.paths import REPOS_ROOT
sys.path.append(REPOS_ROOT)
from ipadapter.ip_adapter.resampler import Resampler
if TYPE_CHECKING:
    from toolkit.lora_special import LoRAModule
    from toolkit.stable_diffusion_model import StableDiffusion

class InstantLoRAMidModule(torch.nn.Module):
    def __init__(
        self,
        dim: int,
        index: int,
        lora_module: 'LoRAModule',
        instant_lora_module: 'InstantLoRAModule'
    ):
        super(InstantLoRAMidModule, self).__init__()
        self.dim = dim
        self.index = index
        # weak references avoid reference cycles between the LoRA module,
        # the parent InstantLoRAModule, and this mid module
        self.lora_module_ref = weakref.ref(lora_module)
        self.instant_lora_module_ref = weakref.ref(instant_lora_module)

    def forward(self, x, *args, **kwargs):
        # get the projected image embeddings stored on the parent InstantLoRAModule
        img_embeds = self.instant_lora_module_ref().img_embeds
        # take the token assigned to this module's index
        scaler = img_embeds[:, self.index, :]
        # drop the token dim if it is still present
        scaler = scaler.squeeze(1)
        # double up if the batch on x is 2x the size (cfg: cond + uncond)
        if x.shape[0] // 2 == scaler.shape[0]:
            scaler = torch.cat([scaler, scaler], dim=0)
        try:
            # add a sequence dim so the scaler broadcasts over tokens
            if len(x.shape) == 3:
                scaler = scaler.unsqueeze(1)
        except Exception as e:
            print(e)
            print(x.shape)
            print(scaler.shape)
            raise e
        # apply tanh to limit values to -1 to 1
        # scaler = torch.tanh(scaler)
        return x * (scaler + 1.0)
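
# Note on the scaling above: for a mid activation x of shape
# (batch, seq_len, dim) and a per-module vector scaler of shape (batch, dim),
# the module returns x * (scaler + 1.0). A scaler of all zeros therefore
# leaves the LoRA mid output unchanged, while negative values suppress it and
# positive values amplify it.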

class InstantLoRAModule(torch.nn.Module):
    def __init__(
        self,
        vision_hidden_size: int,
        vision_tokens: int,
        sd: 'StableDiffusion'
    ):
        super(InstantLoRAModule, self).__init__()
        # self.linear = torch.nn.Linear(2, 1)
        self.sd_ref = weakref.ref(sd)
        self.dim = sd.network.lora_dim
        self.vision_hidden_size = vision_hidden_size
        self.vision_tokens = vision_tokens

        # stores the projected image embeddings. Grabbed by the mid modules
        self.img_embeds: torch.Tensor = None

        # disable merging in. It is slower on inference
        self.sd_ref().network.can_merge_in = False

        self.ilora_modules = torch.nn.ModuleList()

        lora_modules = self.sd_ref().network.get_all_modules()

        # resample the vision output so each LoRA module gets one token
        # the size of its dim, which its mid module can multiply by
        # self.resampler = ZipperResampler(
        #     in_size=self.vision_hidden_size,
        #     in_tokens=self.vision_tokens,
        #     out_size=self.dim,
        #     out_tokens=len(lora_modules),
        #     hidden_size=self.vision_hidden_size,
        #     hidden_tokens=self.vision_tokens,
        #     num_blocks=1,
        # )

        # heads = 20
        heads = 12
        dim = 1280
        output_dim = self.dim

        self.resampler = Resampler(
            dim=dim,
            depth=4,
            dim_head=64,
            heads=heads,
            num_queries=len(lora_modules),
            embedding_dim=self.vision_hidden_size,
            max_seq_len=self.vision_tokens,
            output_dim=output_dim,
            ff_mult=4
        )
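        # Assuming the IP-Adapter Resampler follows its usual contract, the
        # projected embeddings should come out as
        # (batch, len(lora_modules), self.dim): one scaling vector per LoRA
        # module, which each InstantLoRAMidModule indexes by its position.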

        for idx, lora_module in enumerate(lora_modules):
            # add a mid module that scales the original LoRA forward by the
            # vector produced for this index
            mid_module = InstantLoRAMidModule(
                self.dim,
                idx,
                lora_module,
                self
            )
            self.ilora_modules.append(mid_module)
            # replace the LoRA lora_mid with this module's forward
            lora_module.lora_mid = mid_module.forward

    def forward(self, img_embeds):
        # expand to (batch, 1, hidden) if the embeds only have rank 2
        if len(img_embeds.shape) == 2:
            img_embeds = img_embeds.unsqueeze(1)

        img_embeds = self.resampler(img_embeds)
        # stash the projected embeddings for the mid modules to grab
        self.img_embeds = img_embeds
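
# Usage sketch (illustrative; `sd` and `image_embeds` are assumed to come
# from the surrounding toolkit: a StableDiffusion wrapper with a LoRA network
# attached and CLIP vision hidden states of shape
# (batch, vision_tokens, vision_hidden_size)):
#
#     ilora = InstantLoRAModule(
#         vision_hidden_size=1280,
#         vision_tokens=257,
#         sd=sd,
#     )
#     ilora(image_embeds)  # populates ilora.img_embeds
#     # the patched lora_mid hooks then scale each LoRA module's mid
#     # activations on the next UNet forward pass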