mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 02:31:17 +00:00
Many bug fixes. Ip adapter bug fixes. Added noise to unconditional, it works better. added an ilora adapter for 1 shotting LoRAs
This commit is contained in:
104
toolkit/models/ilora.py
Normal file
104
toolkit/models/ilora.py
Normal file
@@ -0,0 +1,104 @@
|
||||
import weakref
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import TYPE_CHECKING
|
||||
from toolkit.models.clip_fusion import ZipperBlock
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from toolkit.lora_special import LoRAModule
|
||||
from toolkit.stable_diffusion_model import StableDiffusion
|
||||
|
||||
|
||||
class InstantLoRAMidModule(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
vision_tokens: int,
|
||||
vision_hidden_size: int,
|
||||
lora_module: 'LoRAModule',
|
||||
instant_lora_module: 'InstantLoRAModule'
|
||||
):
|
||||
super(InstantLoRAMidModule, self).__init__()
|
||||
self.dim = dim
|
||||
self.vision_tokens = vision_tokens
|
||||
self.vision_hidden_size = vision_hidden_size
|
||||
self.lora_module_ref = weakref.ref(lora_module)
|
||||
self.instant_lora_module_ref = weakref.ref(instant_lora_module)
|
||||
|
||||
self.zip = ZipperBlock(
|
||||
in_size=self.vision_hidden_size,
|
||||
in_tokens=self.vision_tokens,
|
||||
out_size=self.dim,
|
||||
out_tokens=1,
|
||||
hidden_size=self.dim,
|
||||
hidden_tokens=self.vision_tokens
|
||||
)
|
||||
|
||||
def forward(self, x, *args, **kwargs):
|
||||
# get the vector
|
||||
img_embeds = self.instant_lora_module_ref().img_embeds
|
||||
# project it
|
||||
scaler = self.zip(img_embeds) # (batch_size, 1, dim)
|
||||
|
||||
# remove the channel dim
|
||||
scaler = scaler.squeeze(1)
|
||||
|
||||
# double up if batch is 2x the size on x (cfg)
|
||||
if x.shape[0] // 2 == scaler.shape[0]:
|
||||
scaler = torch.cat([scaler, scaler], dim=0)
|
||||
|
||||
# multiply it by the scaler
|
||||
try:
|
||||
# reshape if needed
|
||||
if len(x.shape) == 3:
|
||||
scaler = scaler.unsqueeze(1)
|
||||
x = x * scaler
|
||||
except Exception as e:
|
||||
print(e)
|
||||
print(x.shape)
|
||||
print(scaler.shape)
|
||||
raise e
|
||||
# apply tanh to limit values to -1 to 1
|
||||
scaler = torch.tanh(scaler)
|
||||
return x * scaler
|
||||
|
||||
|
||||
class InstantLoRAModule(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
vision_hidden_size: int,
|
||||
vision_tokens: int,
|
||||
sd: 'StableDiffusion'
|
||||
):
|
||||
super(InstantLoRAModule, self).__init__()
|
||||
self.linear = torch.nn.Linear(2, 1)
|
||||
self.sd_ref = weakref.ref(sd)
|
||||
self.dim = sd.network.lora_dim
|
||||
self.vision_hidden_size = vision_hidden_size
|
||||
self.vision_tokens = vision_tokens
|
||||
|
||||
# stores the projection vector. Grabbed by modules
|
||||
self.img_embeds: torch.Tensor = None
|
||||
|
||||
# disable merging in. It is slower on inference
|
||||
self.sd_ref().network.can_merge_in = False
|
||||
|
||||
self.ilora_modules = torch.nn.ModuleList()
|
||||
|
||||
lora_modules = self.sd_ref().network.get_all_modules()
|
||||
|
||||
for lora_module in lora_modules:
|
||||
# add a new mid module that will take the original forward and add a vector to it
|
||||
# this will be used to add the vector to the original forward
|
||||
mid_module = InstantLoRAMidModule(self.dim, self.vision_tokens, self.vision_hidden_size, lora_module, self)
|
||||
|
||||
self.ilora_modules.append(mid_module)
|
||||
# replace the LoRA lora_mid
|
||||
lora_module.lora_mid = mid_module.forward
|
||||
|
||||
# add a new mid module that will take the original forward and add a vector to it
|
||||
# this will be used to add the vector to the original forward
|
||||
|
||||
def forward(self, x):
|
||||
return self.linear(x)
|
||||
Reference in New Issue
Block a user