mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 11:11:37 +00:00
Added working ilora trainer
This commit is contained in:
@@ -46,7 +46,7 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
|
|
||||||
def __init__(self, process_id: int, job, config: OrderedDict, **kwargs):
|
def __init__(self, process_id: int, job, config: OrderedDict, **kwargs):
|
||||||
super().__init__(process_id, job, config, **kwargs)
|
super().__init__(process_id, job, config, **kwargs)
|
||||||
self.assistant_adapter: Union['T2IAdapter', None]
|
self.assistant_adapter: Union['T2IAdapter', 'ControlNetModel', None]
|
||||||
self.do_prior_prediction = False
|
self.do_prior_prediction = False
|
||||||
self.do_long_prompts = False
|
self.do_long_prompts = False
|
||||||
self.do_guided_loss = False
|
self.do_guided_loss = False
|
||||||
@@ -76,10 +76,18 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
if self.train_config.adapter_assist_name_or_path is not None:
|
if self.train_config.adapter_assist_name_or_path is not None:
|
||||||
adapter_path = self.train_config.adapter_assist_name_or_path
|
adapter_path = self.train_config.adapter_assist_name_or_path
|
||||||
|
|
||||||
# dont name this adapter since we are not training it
|
if self.train_config.adapter_assist_type == "t2i":
|
||||||
self.assistant_adapter = T2IAdapter.from_pretrained(
|
# dont name this adapter since we are not training it
|
||||||
adapter_path, torch_dtype=get_torch_dtype(self.train_config.dtype), varient="fp16"
|
self.assistant_adapter = T2IAdapter.from_pretrained(
|
||||||
).to(self.device_torch)
|
adapter_path, torch_dtype=get_torch_dtype(self.train_config.dtype)
|
||||||
|
).to(self.device_torch)
|
||||||
|
elif self.train_config.adapter_assist_type == "control_net":
|
||||||
|
self.assistant_adapter = ControlNetModel.from_pretrained(
|
||||||
|
adapter_path, torch_dtype=get_torch_dtype(self.train_config.dtype)
|
||||||
|
).to(self.device_torch, dtype=get_torch_dtype(self.train_config.dtype))
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown adapter assist type {self.train_config.adapter_assist_type}")
|
||||||
|
|
||||||
self.assistant_adapter.eval()
|
self.assistant_adapter.eval()
|
||||||
self.assistant_adapter.requires_grad_(False)
|
self.assistant_adapter.requires_grad_(False)
|
||||||
flush()
|
flush()
|
||||||
@@ -955,10 +963,10 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
adapter_strength_max = 1.0
|
adapter_strength_max = 1.0
|
||||||
else:
|
else:
|
||||||
# training with assistance, we want it low
|
# training with assistance, we want it low
|
||||||
adapter_strength_min = 0.4
|
# adapter_strength_min = 0.4
|
||||||
adapter_strength_max = 0.7
|
# adapter_strength_max = 0.7
|
||||||
# adapter_strength_min = 0.9
|
adapter_strength_min = 0.9
|
||||||
# adapter_strength_max = 1.1
|
adapter_strength_max = 1.1
|
||||||
|
|
||||||
adapter_conditioning_scale = torch.rand(
|
adapter_conditioning_scale = torch.rand(
|
||||||
(1,), device=self.device_torch, dtype=dtype
|
(1,), device=self.device_torch, dtype=dtype
|
||||||
|
|||||||
@@ -380,8 +380,17 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
self.update_training_metadata()
|
self.update_training_metadata()
|
||||||
filename = f'{self.job.name}{step_num}.safetensors'
|
filename = f'{self.job.name}{step_num}.safetensors'
|
||||||
file_path = os.path.join(self.save_root, filename)
|
file_path = os.path.join(self.save_root, filename)
|
||||||
|
|
||||||
|
save_meta = copy.deepcopy(self.meta)
|
||||||
|
# get extra meta
|
||||||
|
if self.adapter is not None and isinstance(self.adapter, CustomAdapter):
|
||||||
|
additional_save_meta = self.adapter.get_additional_save_metadata()
|
||||||
|
if additional_save_meta is not None:
|
||||||
|
for key, value in additional_save_meta.items():
|
||||||
|
save_meta[key] = value
|
||||||
|
|
||||||
# prepare meta
|
# prepare meta
|
||||||
save_meta = get_meta_for_safetensors(self.meta, self.job.name)
|
save_meta = get_meta_for_safetensors(save_meta, self.job.name)
|
||||||
if not self.is_fine_tuning:
|
if not self.is_fine_tuning:
|
||||||
if self.network is not None:
|
if self.network is not None:
|
||||||
lora_name = self.job.name
|
lora_name = self.job.name
|
||||||
|
|||||||
@@ -244,6 +244,7 @@ class TrainConfig:
|
|||||||
self.start_step = kwargs.get('start_step', None)
|
self.start_step = kwargs.get('start_step', None)
|
||||||
self.free_u = kwargs.get('free_u', False)
|
self.free_u = kwargs.get('free_u', False)
|
||||||
self.adapter_assist_name_or_path: Optional[str] = kwargs.get('adapter_assist_name_or_path', None)
|
self.adapter_assist_name_or_path: Optional[str] = kwargs.get('adapter_assist_name_or_path', None)
|
||||||
|
self.adapter_assist_type: Optional[str] = kwargs.get('adapter_assist_type', 't2i') # t2i, control_net
|
||||||
self.noise_multiplier = kwargs.get('noise_multiplier', 1.0)
|
self.noise_multiplier = kwargs.get('noise_multiplier', 1.0)
|
||||||
self.img_multiplier = kwargs.get('img_multiplier', 1.0)
|
self.img_multiplier = kwargs.get('img_multiplier', 1.0)
|
||||||
self.noisy_latent_multiplier = kwargs.get('noisy_latent_multiplier', 1.0)
|
self.noisy_latent_multiplier = kwargs.get('noisy_latent_multiplier', 1.0)
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ from toolkit.saving import load_ip_adapter_model, load_custom_adapter_model
|
|||||||
from toolkit.train_tools import get_torch_dtype
|
from toolkit.train_tools import get_torch_dtype
|
||||||
|
|
||||||
sys.path.append(REPOS_ROOT)
|
sys.path.append(REPOS_ROOT)
|
||||||
from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List, Optional
|
from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List, Optional, Dict
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from ipadapter.ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor, IPAttnProcessor2_0, \
|
from ipadapter.ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor, IPAttnProcessor2_0, \
|
||||||
AttnProcessor2_0
|
AttnProcessor2_0
|
||||||
@@ -145,6 +145,7 @@ class CustomAdapter(torch.nn.Module):
|
|||||||
self.ilora_module = InstantLoRAModule(
|
self.ilora_module = InstantLoRAModule(
|
||||||
vision_tokens=vision_tokens,
|
vision_tokens=vision_tokens,
|
||||||
vision_hidden_size=vision_hidden_size,
|
vision_hidden_size=vision_hidden_size,
|
||||||
|
head_dim=1024,
|
||||||
sd=self.sd_ref()
|
sd=self.sd_ref()
|
||||||
)
|
)
|
||||||
elif self.adapter_type == 'text_encoder':
|
elif self.adapter_type == 'text_encoder':
|
||||||
@@ -875,3 +876,8 @@ class CustomAdapter(torch.nn.Module):
|
|||||||
self.vision_encoder.enable_gradient_checkpointing()
|
self.vision_encoder.enable_gradient_checkpointing()
|
||||||
elif hasattr(self.vision_encoder, 'gradient_checkpointing'):
|
elif hasattr(self.vision_encoder, 'gradient_checkpointing'):
|
||||||
self.vision_encoder.gradient_checkpointing = True
|
self.vision_encoder.gradient_checkpointing = True
|
||||||
|
|
||||||
|
def get_additional_save_metadata(self) -> Dict[str, Any]:
|
||||||
|
if self.config.type == 'ilora':
|
||||||
|
return self.ilora_module.get_additional_save_metadata()
|
||||||
|
return {}
|
||||||
@@ -249,92 +249,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
|
|||||||
skipped = []
|
skipped = []
|
||||||
attached_modules = []
|
attached_modules = []
|
||||||
for name, module in root_module.named_modules():
|
for name, module in root_module.named_modules():
|
||||||
if is_unet:
|
if module.__class__.__name__ in target_replace_modules:
|
||||||
module_name = module.__class__.__name__
|
|
||||||
if module not in attached_modules:
|
|
||||||
# if module.__class__.__name__ in target_replace_modules:
|
|
||||||
# for child_name, child_module in module.named_modules():
|
|
||||||
is_linear = module_name == 'LoRACompatibleLinear'
|
|
||||||
is_conv2d = module_name == 'LoRACompatibleConv'
|
|
||||||
# check if attn in name
|
|
||||||
is_attention = "attentions" in name
|
|
||||||
if not is_attention and attn_only:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if is_linear and self.lora_dim is None:
|
|
||||||
continue
|
|
||||||
if is_conv2d and self.conv_lora_dim is None:
|
|
||||||
continue
|
|
||||||
|
|
||||||
is_conv2d_1x1 = is_conv2d and module.kernel_size == (1, 1)
|
|
||||||
|
|
||||||
if is_conv2d_1x1:
|
|
||||||
pass
|
|
||||||
|
|
||||||
skip = False
|
|
||||||
if any([word in name for word in self.ignore_if_contains]):
|
|
||||||
skip = True
|
|
||||||
|
|
||||||
# see if it is over threshold
|
|
||||||
if count_parameters(module) < parameter_threshold:
|
|
||||||
skip = True
|
|
||||||
|
|
||||||
if (is_linear or is_conv2d) and not skip:
|
|
||||||
lora_name = prefix + "." + name
|
|
||||||
lora_name = lora_name.replace(".", "_")
|
|
||||||
|
|
||||||
dim = None
|
|
||||||
alpha = None
|
|
||||||
|
|
||||||
if modules_dim is not None:
|
|
||||||
# モジュール指定あり
|
|
||||||
if lora_name in modules_dim:
|
|
||||||
dim = modules_dim[lora_name]
|
|
||||||
alpha = modules_alpha[lora_name]
|
|
||||||
elif is_unet and block_dims is not None:
|
|
||||||
# U-Netでblock_dims指定あり
|
|
||||||
block_idx = get_block_index(lora_name)
|
|
||||||
if is_linear or is_conv2d_1x1:
|
|
||||||
dim = block_dims[block_idx]
|
|
||||||
alpha = block_alphas[block_idx]
|
|
||||||
elif conv_block_dims is not None:
|
|
||||||
dim = conv_block_dims[block_idx]
|
|
||||||
alpha = conv_block_alphas[block_idx]
|
|
||||||
else:
|
|
||||||
# 通常、すべて対象とする
|
|
||||||
if is_linear or is_conv2d_1x1:
|
|
||||||
dim = self.lora_dim
|
|
||||||
alpha = self.alpha
|
|
||||||
elif self.conv_lora_dim is not None:
|
|
||||||
dim = self.conv_lora_dim
|
|
||||||
alpha = self.conv_alpha
|
|
||||||
else:
|
|
||||||
dim = None
|
|
||||||
alpha = None
|
|
||||||
|
|
||||||
if dim is None or dim == 0:
|
|
||||||
# skipした情報を出力
|
|
||||||
if is_linear or is_conv2d_1x1 or (
|
|
||||||
self.conv_lora_dim is not None or conv_block_dims is not None):
|
|
||||||
skipped.append(lora_name)
|
|
||||||
continue
|
|
||||||
|
|
||||||
lora = module_class(
|
|
||||||
lora_name,
|
|
||||||
module,
|
|
||||||
self.multiplier,
|
|
||||||
dim,
|
|
||||||
alpha,
|
|
||||||
dropout=dropout,
|
|
||||||
rank_dropout=rank_dropout,
|
|
||||||
module_dropout=module_dropout,
|
|
||||||
network=self,
|
|
||||||
parent=module,
|
|
||||||
use_bias=use_bias,
|
|
||||||
)
|
|
||||||
loras.append(lora)
|
|
||||||
attached_modules.append(module)
|
|
||||||
elif module.__class__.__name__ in target_replace_modules:
|
|
||||||
for child_name, child_module in module.named_modules():
|
for child_name, child_module in module.named_modules():
|
||||||
is_linear = child_module.__class__.__name__ in LINEAR_MODULES
|
is_linear = child_module.__class__.__name__ in LINEAR_MODULES
|
||||||
is_conv2d = child_module.__class__.__name__ in CONV_MODULES
|
is_conv2d = child_module.__class__.__name__ in CONV_MODULES
|
||||||
|
|||||||
@@ -1,97 +1,170 @@
|
|||||||
|
import math
|
||||||
import weakref
|
import weakref
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING, List, Dict, Any
|
||||||
from toolkit.models.clip_fusion import ZipperBlock
|
from toolkit.models.clip_fusion import ZipperBlock
|
||||||
from toolkit.models.zipper_resampler import ZipperModule, ZipperResampler
|
from toolkit.models.zipper_resampler import ZipperModule, ZipperResampler
|
||||||
import sys
|
import sys
|
||||||
from toolkit.paths import REPOS_ROOT
|
from toolkit.paths import REPOS_ROOT
|
||||||
sys.path.append(REPOS_ROOT)
|
sys.path.append(REPOS_ROOT)
|
||||||
from ipadapter.ip_adapter.resampler import Resampler
|
from ipadapter.ip_adapter.resampler import Resampler
|
||||||
|
from collections import OrderedDict
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from toolkit.lora_special import LoRAModule
|
from toolkit.lora_special import LoRAModule
|
||||||
from toolkit.stable_diffusion_model import StableDiffusion
|
from toolkit.stable_diffusion_model import StableDiffusion
|
||||||
|
|
||||||
|
|
||||||
class ILoRAProjModule(torch.nn.Module):
|
class MLP(nn.Module):
|
||||||
def __init__(self, num_modules=1, dim=4, embeddings_dim=512):
|
def __init__(self, in_dim, out_dim, hidden_dim, dropout=0.1, use_residual=True):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
if use_residual:
|
||||||
self.num_modules = num_modules
|
assert in_dim == out_dim
|
||||||
self.num_dim = dim
|
self.layernorm = nn.LayerNorm(in_dim)
|
||||||
|
self.fc1 = nn.Linear(in_dim, hidden_dim)
|
||||||
self.proj = torch.nn.Sequential(
|
self.fc2 = nn.Linear(hidden_dim, out_dim)
|
||||||
torch.nn.LayerNorm(embeddings_dim),
|
self.dropout = nn.Dropout(dropout)
|
||||||
torch.nn.Linear(embeddings_dim, embeddings_dim * 2),
|
self.use_residual = use_residual
|
||||||
torch.nn.GELU(),
|
self.act_fn = nn.GELU()
|
||||||
torch.nn.Linear(embeddings_dim * 2, embeddings_dim * 2),
|
|
||||||
torch.nn.LayerNorm(embeddings_dim * 2),
|
|
||||||
|
|
||||||
torch.nn.Linear(embeddings_dim * 2, embeddings_dim * 4),
|
|
||||||
torch.nn.GELU(),
|
|
||||||
torch.nn.Linear(embeddings_dim * 4, num_modules * dim),
|
|
||||||
torch.nn.LayerNorm(num_modules * dim),
|
|
||||||
)
|
|
||||||
# Initialize the last linear layer weights near zero
|
|
||||||
torch.nn.init.uniform_(self.proj[-2].weight, a=-0.01, b=0.01)
|
|
||||||
torch.nn.init.zeros_(self.proj[-2].bias)
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = self.proj(x)
|
residual = x
|
||||||
x = x.reshape(-1, self.num_modules, self.num_dim)
|
x = self.layernorm(x)
|
||||||
|
x = self.fc1(x)
|
||||||
|
x = self.act_fn(x)
|
||||||
|
x = self.fc2(x)
|
||||||
|
x = self.dropout(x)
|
||||||
|
if self.use_residual:
|
||||||
|
x = x + residual
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
class LoRAGenerator(torch.nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
input_size: int = 768, # projection dimension
|
||||||
|
hidden_size: int = 768,
|
||||||
|
head_size: int = 512,
|
||||||
|
num_mlp_layers: int = 1,
|
||||||
|
output_size: int = 768,
|
||||||
|
dropout: float = 0.5
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.input_size = input_size
|
||||||
|
|
||||||
|
self.output_size = output_size
|
||||||
|
self.lin_in = nn.Linear(input_size, hidden_size)
|
||||||
|
|
||||||
|
self.mlp_blocks = nn.Sequential(*[
|
||||||
|
MLP(hidden_size, hidden_size, hidden_size, dropout=dropout, use_residual=True) for _ in range(num_mlp_layers)
|
||||||
|
])
|
||||||
|
self.head = nn.Linear(hidden_size, head_size, bias=False)
|
||||||
|
self.norm = nn.LayerNorm(head_size)
|
||||||
|
|
||||||
|
self.flatten = nn.Flatten()
|
||||||
|
self.output = nn.Linear(head_size, self.output_size)
|
||||||
|
# for each output block. multiply weights by 0.01
|
||||||
|
with torch.no_grad():
|
||||||
|
self.output.weight.data *= 0.01
|
||||||
|
|
||||||
|
# allow get device
|
||||||
|
@property
|
||||||
|
def device(self):
|
||||||
|
return next(self.parameters()).device
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dtype(self):
|
||||||
|
return next(self.parameters()).dtype
|
||||||
|
|
||||||
|
def forward(self, embedding):
|
||||||
|
if len(embedding.shape) == 2:
|
||||||
|
embedding = embedding.unsqueeze(1)
|
||||||
|
|
||||||
|
x = self.lin_in(embedding)
|
||||||
|
x = self.mlp_blocks(x)
|
||||||
|
x = self.head(x)
|
||||||
|
x = self.norm(x)
|
||||||
|
|
||||||
|
head_output = x
|
||||||
|
|
||||||
|
x = self.output(head_output)
|
||||||
|
return x.squeeze(1)
|
||||||
|
|
||||||
|
|
||||||
class InstantLoRAMidModule(torch.nn.Module):
|
class InstantLoRAMidModule(torch.nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
dim: int,
|
|
||||||
index: int,
|
index: int,
|
||||||
lora_module: 'LoRAModule',
|
lora_module: 'LoRAModule',
|
||||||
instant_lora_module: 'InstantLoRAModule'
|
instant_lora_module: 'InstantLoRAModule',
|
||||||
|
up_shape: list = None,
|
||||||
|
down_shape: list = None,
|
||||||
):
|
):
|
||||||
super(InstantLoRAMidModule, self).__init__()
|
super(InstantLoRAMidModule, self).__init__()
|
||||||
self.dim = dim
|
self.up_shape = up_shape
|
||||||
|
self.down_shape = down_shape
|
||||||
self.index = index
|
self.index = index
|
||||||
self.lora_module_ref = weakref.ref(lora_module)
|
self.lora_module_ref = weakref.ref(lora_module)
|
||||||
self.instant_lora_module_ref = weakref.ref(instant_lora_module)
|
self.instant_lora_module_ref = weakref.ref(instant_lora_module)
|
||||||
|
|
||||||
def forward(self, x, *args, **kwargs):
|
self.embed = None
|
||||||
# get the vector
|
|
||||||
img_embeds = self.instant_lora_module_ref().img_embeds
|
|
||||||
# project it
|
|
||||||
scaler = img_embeds[:, self.index, :]
|
|
||||||
|
|
||||||
# remove the channel dim (index)
|
def down_forward(self, x, *args, **kwargs):
|
||||||
scaler = scaler.squeeze(1)
|
# get the embed
|
||||||
|
self.embed = self.instant_lora_module_ref().img_embeds[self.index]
|
||||||
|
down_size = math.prod(self.down_shape)
|
||||||
|
down_weight = self.embed[:, :down_size]
|
||||||
|
|
||||||
|
batch_size = x.shape[0]
|
||||||
|
|
||||||
|
# unconditional
|
||||||
|
if down_weight.shape[0] * 2 == batch_size:
|
||||||
|
down_weight = torch.cat([down_weight] * 2, dim=0)
|
||||||
|
|
||||||
|
weight_chunks = torch.chunk(down_weight, batch_size, dim=0)
|
||||||
|
x_chunks = torch.chunk(x, batch_size, dim=0)
|
||||||
|
|
||||||
|
x_out = []
|
||||||
|
for i in range(batch_size):
|
||||||
|
weight_chunk = weight_chunks[i]
|
||||||
|
x_chunk = x_chunks[i]
|
||||||
|
# reshape
|
||||||
|
weight_chunk = weight_chunk.view(self.down_shape)
|
||||||
|
# run a simple lenear layer with the down weight
|
||||||
|
x_chunk = x_chunk @ weight_chunk.T
|
||||||
|
x_out.append(x_chunk)
|
||||||
|
x = torch.cat(x_out, dim=0)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def up_forward(self, x, *args, **kwargs):
|
||||||
|
self.embed = self.instant_lora_module_ref().img_embeds[self.index]
|
||||||
|
up_size = math.prod(self.up_shape)
|
||||||
|
up_weight = self.embed[:, -up_size:]
|
||||||
|
|
||||||
|
batch_size = x.shape[0]
|
||||||
|
|
||||||
|
# unconditional
|
||||||
|
if up_weight.shape[0] * 2 == batch_size:
|
||||||
|
up_weight = torch.cat([up_weight] * 2, dim=0)
|
||||||
|
|
||||||
|
weight_chunks = torch.chunk(up_weight, batch_size, dim=0)
|
||||||
|
x_chunks = torch.chunk(x, batch_size, dim=0)
|
||||||
|
|
||||||
|
x_out = []
|
||||||
|
for i in range(batch_size):
|
||||||
|
weight_chunk = weight_chunks[i]
|
||||||
|
x_chunk = x_chunks[i]
|
||||||
|
# reshape
|
||||||
|
weight_chunk = weight_chunk.view(self.up_shape)
|
||||||
|
# run a simple lenear layer with the down weight
|
||||||
|
x_chunk = x_chunk @ weight_chunk.T
|
||||||
|
x_out.append(x_chunk)
|
||||||
|
x = torch.cat(x_out, dim=0)
|
||||||
|
return x
|
||||||
|
|
||||||
# double up if batch is 2x the size on x (cfg)
|
|
||||||
if x.shape[0] // 2 == scaler.shape[0]:
|
|
||||||
scaler = torch.cat([scaler, scaler], dim=0)
|
|
||||||
|
|
||||||
# multiply it by the scaler
|
|
||||||
try:
|
|
||||||
# reshape if needed
|
|
||||||
if len(x.shape) == 3:
|
|
||||||
scaler = scaler.unsqueeze(1)
|
|
||||||
if len(x.shape) == 4:
|
|
||||||
scaler = scaler.unsqueeze(-1).unsqueeze(-1)
|
|
||||||
except Exception as e:
|
|
||||||
print(e)
|
|
||||||
print(x.shape)
|
|
||||||
print(scaler.shape)
|
|
||||||
raise e
|
|
||||||
# apply tanh to limit values to -1 to 1
|
|
||||||
# scaler = torch.tanh(scaler)
|
|
||||||
try:
|
|
||||||
return x * scaler
|
|
||||||
except Exception as e:
|
|
||||||
print(e)
|
|
||||||
print(x.shape)
|
|
||||||
print(scaler.shape)
|
|
||||||
raise e
|
|
||||||
|
|
||||||
|
|
||||||
class InstantLoRAModule(torch.nn.Module):
|
class InstantLoRAModule(torch.nn.Module):
|
||||||
@@ -99,6 +172,7 @@ class InstantLoRAModule(torch.nn.Module):
|
|||||||
self,
|
self,
|
||||||
vision_hidden_size: int,
|
vision_hidden_size: int,
|
||||||
vision_tokens: int,
|
vision_tokens: int,
|
||||||
|
head_dim: int,
|
||||||
sd: 'StableDiffusion'
|
sd: 'StableDiffusion'
|
||||||
):
|
):
|
||||||
super(InstantLoRAModule, self).__init__()
|
super(InstantLoRAModule, self).__init__()
|
||||||
@@ -107,9 +181,10 @@ class InstantLoRAModule(torch.nn.Module):
|
|||||||
self.dim = sd.network.lora_dim
|
self.dim = sd.network.lora_dim
|
||||||
self.vision_hidden_size = vision_hidden_size
|
self.vision_hidden_size = vision_hidden_size
|
||||||
self.vision_tokens = vision_tokens
|
self.vision_tokens = vision_tokens
|
||||||
|
self.head_dim = head_dim
|
||||||
|
|
||||||
# stores the projection vector. Grabbed by modules
|
# stores the projection vector. Grabbed by modules
|
||||||
self.img_embeds: torch.Tensor = None
|
self.img_embeds: List[torch.Tensor] = None
|
||||||
|
|
||||||
# disable merging in. It is slower on inference
|
# disable merging in. It is slower on inference
|
||||||
self.sd_ref().network.can_merge_in = False
|
self.sd_ref().network.can_merge_in = False
|
||||||
@@ -118,58 +193,109 @@ class InstantLoRAModule(torch.nn.Module):
|
|||||||
|
|
||||||
lora_modules = self.sd_ref().network.get_all_modules()
|
lora_modules = self.sd_ref().network.get_all_modules()
|
||||||
|
|
||||||
# resample the output so each module gets one token with a size of its dim so we can multiply by that
|
output_size = 0
|
||||||
# self.resampler = ZipperResampler(
|
|
||||||
# in_size=self.vision_hidden_size,
|
self.embed_lengths = []
|
||||||
# in_tokens=self.vision_tokens,
|
self.weight_mapping = []
|
||||||
# out_size=self.dim,
|
|
||||||
# out_tokens=len(lora_modules),
|
|
||||||
# hidden_size=self.vision_hidden_size,
|
|
||||||
# hidden_tokens=self.vision_tokens,
|
|
||||||
# num_blocks=1,
|
|
||||||
# )
|
|
||||||
# heads = 20
|
|
||||||
# heads = 12
|
|
||||||
# dim = 1280
|
|
||||||
# output_dim = self.dim
|
|
||||||
self.proj_module = ILoRAProjModule(
|
|
||||||
num_modules=len(lora_modules),
|
|
||||||
dim=self.dim,
|
|
||||||
embeddings_dim=self.vision_hidden_size,
|
|
||||||
)
|
|
||||||
# self.resampler = Resampler(
|
|
||||||
# dim=dim,
|
|
||||||
# depth=4,
|
|
||||||
# dim_head=64,
|
|
||||||
# heads=heads,
|
|
||||||
# num_queries=len(lora_modules),
|
|
||||||
# embedding_dim=self.vision_hidden_size,
|
|
||||||
# max_seq_len=self.vision_tokens,
|
|
||||||
# output_dim=output_dim,
|
|
||||||
# ff_mult=4
|
|
||||||
# )
|
|
||||||
|
|
||||||
for idx, lora_module in enumerate(lora_modules):
|
for idx, lora_module in enumerate(lora_modules):
|
||||||
|
module_dict = lora_module.state_dict()
|
||||||
|
down_shape = list(module_dict['lora_down.weight'].shape)
|
||||||
|
up_shape = list(module_dict['lora_up.weight'].shape)
|
||||||
|
|
||||||
|
self.weight_mapping.append([lora_module.lora_name, [down_shape, up_shape]])
|
||||||
|
|
||||||
|
module_size = math.prod(down_shape) + math.prod(up_shape)
|
||||||
|
output_size += module_size
|
||||||
|
self.embed_lengths.append(module_size)
|
||||||
|
|
||||||
|
|
||||||
# add a new mid module that will take the original forward and add a vector to it
|
# add a new mid module that will take the original forward and add a vector to it
|
||||||
# this will be used to add the vector to the original forward
|
# this will be used to add the vector to the original forward
|
||||||
mid_module = InstantLoRAMidModule(
|
instant_module = InstantLoRAMidModule(
|
||||||
self.dim,
|
|
||||||
idx,
|
idx,
|
||||||
lora_module,
|
lora_module,
|
||||||
self
|
self,
|
||||||
|
up_shape=up_shape,
|
||||||
|
down_shape=down_shape
|
||||||
)
|
)
|
||||||
|
|
||||||
self.ilora_modules.append(mid_module)
|
self.ilora_modules.append(instant_module)
|
||||||
# replace the LoRA lora_mid
|
|
||||||
lora_module.lora_mid = mid_module.forward
|
# replace the LoRA forwards
|
||||||
|
lora_module.lora_down.forward = instant_module.down_forward
|
||||||
|
lora_module.lora_up.forward = instant_module.up_forward
|
||||||
|
|
||||||
|
|
||||||
|
self.output_size = output_size
|
||||||
|
|
||||||
|
if vision_tokens > 1:
|
||||||
|
self.resampler = Resampler(
|
||||||
|
dim=vision_hidden_size,
|
||||||
|
depth=4,
|
||||||
|
dim_head=64,
|
||||||
|
heads=12,
|
||||||
|
num_queries=1, # output tokens
|
||||||
|
embedding_dim=vision_hidden_size,
|
||||||
|
max_seq_len=vision_tokens,
|
||||||
|
output_dim=head_dim,
|
||||||
|
ff_mult=4
|
||||||
|
)
|
||||||
|
|
||||||
|
self.proj_module = LoRAGenerator(
|
||||||
|
input_size=head_dim,
|
||||||
|
hidden_size=head_dim,
|
||||||
|
head_size=head_dim,
|
||||||
|
num_mlp_layers=1,
|
||||||
|
output_size=self.output_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.migrate_weight_mapping()
|
||||||
|
|
||||||
|
def migrate_weight_mapping(self):
|
||||||
|
# changes the names of the modules to common ones
|
||||||
|
keymap = self.sd_ref().network.get_keymap()
|
||||||
|
save_keymap = {}
|
||||||
|
if keymap is not None:
|
||||||
|
for ldm_key, diffusers_key in keymap.items():
|
||||||
|
# invert them
|
||||||
|
save_keymap[diffusers_key] = ldm_key
|
||||||
|
|
||||||
|
new_keymap = {}
|
||||||
|
for key, value in self.weight_mapping:
|
||||||
|
if key in save_keymap:
|
||||||
|
new_keymap[save_keymap[key]] = value
|
||||||
|
else:
|
||||||
|
print(f"Key {key} not found in keymap")
|
||||||
|
new_keymap[key] = value
|
||||||
|
self.weight_mapping = new_keymap
|
||||||
|
else:
|
||||||
|
print("No keymap found. Using default names")
|
||||||
|
return
|
||||||
|
|
||||||
# add a new mid module that will take the original forward and add a vector to it
|
|
||||||
# this will be used to add the vector to the original forward
|
|
||||||
|
|
||||||
def forward(self, img_embeds):
|
def forward(self, img_embeds):
|
||||||
# expand token rank if only rank 2
|
# expand token rank if only rank 2
|
||||||
if len(img_embeds.shape) == 2:
|
if len(img_embeds.shape) == 2:
|
||||||
img_embeds = img_embeds.unsqueeze(1)
|
img_embeds = img_embeds.unsqueeze(1)
|
||||||
img_embeds = self.proj_module(img_embeds)
|
|
||||||
self.img_embeds = img_embeds
|
# resample the image embeddings
|
||||||
|
img_embeds = self.resampler(img_embeds)
|
||||||
|
img_embeds = self.proj_module(img_embeds)
|
||||||
|
if len(img_embeds.shape) == 3:
|
||||||
|
img_embeds = img_embeds.squeeze(1)
|
||||||
|
|
||||||
|
self.img_embeds = []
|
||||||
|
# get all the slices
|
||||||
|
start = 0
|
||||||
|
for length in self.embed_lengths:
|
||||||
|
self.img_embeds.append(img_embeds[:, start:start+length])
|
||||||
|
start += length
|
||||||
|
|
||||||
|
|
||||||
|
def get_additional_save_metadata(self) -> Dict[str, Any]:
|
||||||
|
# save the weight mapping
|
||||||
|
return {
|
||||||
|
"weight_mapping": self.weight_mapping
|
||||||
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user