From cb5d28cba9f4818641e77b6c7bbdc0f59a9a8838 Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Wed, 12 Jun 2024 09:33:45 -0600
Subject: [PATCH] Added working ilora trainer

---
 extensions_built_in/sd_trainer/SDTrainer.py |  26 +-
 jobs/process/BaseSDTrainProcess.py          |  11 +-
 toolkit/config_modules.py                   |   1 +
 toolkit/custom_adapter.py                   |   8 +-
 toolkit/lora_special.py                     |  87 +-----
 toolkit/models/ilora.py                     | 324 ++++++++++++++------
 6 files changed, 261 insertions(+), 196 deletions(-)

diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py
index c8f55f1a..9d3aeec3 100644
--- a/extensions_built_in/sd_trainer/SDTrainer.py
+++ b/extensions_built_in/sd_trainer/SDTrainer.py
@@ -46,7 +46,7 @@ class SDTrainer(BaseSDTrainProcess):
 
     def __init__(self, process_id: int, job, config: OrderedDict, **kwargs):
         super().__init__(process_id, job, config, **kwargs)
-        self.assistant_adapter: Union['T2IAdapter', None]
+        self.assistant_adapter: Union['T2IAdapter', 'ControlNetModel', None]
         self.do_prior_prediction = False
         self.do_long_prompts = False
         self.do_guided_loss = False
@@ -76,10 +76,18 @@ class SDTrainer(BaseSDTrainProcess):
 
         if self.train_config.adapter_assist_name_or_path is not None:
             adapter_path = self.train_config.adapter_assist_name_or_path
-            # dont name this adapter since we are not training it
-            self.assistant_adapter = T2IAdapter.from_pretrained(
-                adapter_path, torch_dtype=get_torch_dtype(self.train_config.dtype), varient="fp16"
-            ).to(self.device_torch)
+            if self.train_config.adapter_assist_type == "t2i":
+                # dont name this adapter since we are not training it
+                self.assistant_adapter = T2IAdapter.from_pretrained(
+                    adapter_path, torch_dtype=get_torch_dtype(self.train_config.dtype)
+                ).to(self.device_torch)
+            elif self.train_config.adapter_assist_type == "control_net":
+                self.assistant_adapter = ControlNetModel.from_pretrained(
+                    adapter_path, torch_dtype=get_torch_dtype(self.train_config.dtype)
+                ).to(self.device_torch, dtype=get_torch_dtype(self.train_config.dtype))
+            else:
+                raise ValueError(f"Unknown adapter assist type {self.train_config.adapter_assist_type}")
+
             self.assistant_adapter.eval()
             self.assistant_adapter.requires_grad_(False)
             flush()
@@ -955,10 +963,10 @@ class SDTrainer(BaseSDTrainProcess):
                     adapter_strength_max = 1.0
                 else:
                     # training with assistance, we want it low
-                    adapter_strength_min = 0.4
-                    adapter_strength_max = 0.7
-                    # adapter_strength_min = 0.9
-                    # adapter_strength_max = 1.1
+                    # adapter_strength_min = 0.4
+                    # adapter_strength_max = 0.7
+                    adapter_strength_min = 0.9
+                    adapter_strength_max = 1.1
 
                 adapter_conditioning_scale = torch.rand(
                     (1,), device=self.device_torch, dtype=dtype
diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index 85674770..34f5d09f 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -380,8 +380,17 @@ class BaseSDTrainProcess(BaseTrainProcess):
         self.update_training_metadata()
         filename = f'{self.job.name}{step_num}.safetensors'
         file_path = os.path.join(self.save_root, filename)
+
+        save_meta = copy.deepcopy(self.meta)
+        # get extra meta
+        if self.adapter is not None and isinstance(self.adapter, CustomAdapter):
+            additional_save_meta = self.adapter.get_additional_save_metadata()
+            if additional_save_meta is not None:
+                for key, value in additional_save_meta.items():
+                    save_meta[key] = value
+
         # prepare meta
-        save_meta = get_meta_for_safetensors(self.meta, self.job.name)
+        save_meta = get_meta_for_safetensors(save_meta, self.job.name)
         if not self.is_fine_tuning:
             if self.network is not None:
                 lora_name = self.job.name
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index 4d291317..bd948913 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -244,6 +244,7 @@ class TrainConfig:
         self.start_step = kwargs.get('start_step', None)
         self.free_u = kwargs.get('free_u', False)
         self.adapter_assist_name_or_path: Optional[str] = kwargs.get('adapter_assist_name_or_path', None)
+        self.adapter_assist_type: Optional[str] = kwargs.get('adapter_assist_type', 't2i')  # t2i, control_net
         self.noise_multiplier = kwargs.get('noise_multiplier', 1.0)
         self.img_multiplier = kwargs.get('img_multiplier', 1.0)
         self.noisy_latent_multiplier = kwargs.get('noisy_latent_multiplier', 1.0)
diff --git a/toolkit/custom_adapter.py b/toolkit/custom_adapter.py
index 32e6984b..e953df82 100644
--- a/toolkit/custom_adapter.py
+++ b/toolkit/custom_adapter.py
@@ -19,7 +19,7 @@ from toolkit.saving import load_ip_adapter_model, load_custom_adapter_model
 from toolkit.train_tools import get_torch_dtype
 
 sys.path.append(REPOS_ROOT)
-from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List, Optional
+from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List, Optional, Dict
 from collections import OrderedDict
 from ipadapter.ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor, IPAttnProcessor2_0, \
     AttnProcessor2_0
@@ -145,6 +145,7 @@ class CustomAdapter(torch.nn.Module):
                 self.ilora_module = InstantLoRAModule(
                     vision_tokens=vision_tokens,
                     vision_hidden_size=vision_hidden_size,
+                    head_dim=1024,
                     sd=self.sd_ref()
                 )
             elif self.adapter_type == 'text_encoder':
@@ -875,3 +876,8 @@ class CustomAdapter(torch.nn.Module):
             self.vision_encoder.enable_gradient_checkpointing()
         elif hasattr(self.vision_encoder, 'gradient_checkpointing'):
             self.vision_encoder.gradient_checkpointing = True
+
+    def get_additional_save_metadata(self) -> Dict[str, Any]:
+        if self.config.type == 'ilora':
+            return self.ilora_module.get_additional_save_metadata()
+        return {}
\ No newline at end of file
diff --git a/toolkit/lora_special.py b/toolkit/lora_special.py
index 9d912b5d..409fef2d 100644
--- a/toolkit/lora_special.py
+++ b/toolkit/lora_special.py
@@ -249,92 +249,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
         skipped = []
         attached_modules = []
         for name, module in root_module.named_modules():
-            if is_unet:
-                module_name = module.__class__.__name__
-                if module not in attached_modules:
-                    # if module.__class__.__name__ in target_replace_modules:
-                    #     for child_name, child_module in module.named_modules():
-                    is_linear = module_name == 'LoRACompatibleLinear'
-                    is_conv2d = module_name == 'LoRACompatibleConv'
-                    # check if attn in name
-                    is_attention = "attentions" in name
-                    if not is_attention and attn_only:
-                        continue
-
-                    if is_linear and self.lora_dim is None:
-                        continue
-                    if is_conv2d and self.conv_lora_dim is None:
-                        continue
-
-                    is_conv2d_1x1 = is_conv2d and module.kernel_size == (1, 1)
-
-                    if is_conv2d_1x1:
-                        pass
-
-                    skip = False
-                    if any([word in name for word in self.ignore_if_contains]):
-                        skip = True
-
-                    # see if it is over threshold
-                    if count_parameters(module) < parameter_threshold:
-                        skip = True
-
-                    if (is_linear or is_conv2d) and not skip:
-                        lora_name = prefix + "." + name
-                        lora_name = lora_name.replace(".", "_")
-
-                        dim = None
-                        alpha = None
-
-                        if modules_dim is not None:
-                            # モジュール指定あり
-                            if lora_name in modules_dim:
-                                dim = modules_dim[lora_name]
-                                alpha = modules_alpha[lora_name]
-                        elif is_unet and block_dims is not None:
-                            # U-Netでblock_dims指定あり
-                            block_idx = get_block_index(lora_name)
-                            if is_linear or is_conv2d_1x1:
-                                dim = block_dims[block_idx]
-                                alpha = block_alphas[block_idx]
-                            elif conv_block_dims is not None:
-                                dim = conv_block_dims[block_idx]
-                                alpha = conv_block_alphas[block_idx]
-                        else:
-                            # 通常、すべて対象とする
-                            if is_linear or is_conv2d_1x1:
-                                dim = self.lora_dim
-                                alpha = self.alpha
-                            elif self.conv_lora_dim is not None:
-                                dim = self.conv_lora_dim
-                                alpha = self.conv_alpha
-                            else:
-                                dim = None
-                                alpha = None
-
-                        if dim is None or dim == 0:
-                            # skipした情報を出力
-                            if is_linear or is_conv2d_1x1 or (
-                                    self.conv_lora_dim is not None or conv_block_dims is not None):
-                                skipped.append(lora_name)
-                            continue
-
-                        lora = module_class(
-                            lora_name,
-                            module,
-                            self.multiplier,
-                            dim,
-                            alpha,
-                            dropout=dropout,
-                            rank_dropout=rank_dropout,
-                            module_dropout=module_dropout,
-                            network=self,
-                            parent=module,
-                            use_bias=use_bias,
-                        )
-                        loras.append(lora)
-                        attached_modules.append(module)
-            elif module.__class__.__name__ in target_replace_modules:
+            if module.__class__.__name__ in target_replace_modules:
                 for child_name, child_module in module.named_modules():
                     is_linear = child_module.__class__.__name__ in LINEAR_MODULES
                     is_conv2d = child_module.__class__.__name__ in CONV_MODULES
diff --git a/toolkit/models/ilora.py b/toolkit/models/ilora.py
index 03459ba8..93478acc 100644
--- a/toolkit/models/ilora.py
+++ b/toolkit/models/ilora.py
@@ -1,97 +1,170 @@
+import math
 import weakref
 import torch
 import torch.nn as nn
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, List, Dict, Any
 from toolkit.models.clip_fusion import ZipperBlock
 from toolkit.models.zipper_resampler import ZipperModule, ZipperResampler
 import sys
 from toolkit.paths import REPOS_ROOT
 sys.path.append(REPOS_ROOT)
 from ipadapter.ip_adapter.resampler import Resampler
+from collections import OrderedDict
 
 if TYPE_CHECKING:
     from toolkit.lora_special import LoRAModule
     from toolkit.stable_diffusion_model import StableDiffusion
 
 
-class ILoRAProjModule(torch.nn.Module):
-    def __init__(self, num_modules=1, dim=4, embeddings_dim=512):
+class MLP(nn.Module):
+    def __init__(self, in_dim, out_dim, hidden_dim, dropout=0.1, use_residual=True):
         super().__init__()
-
-        self.num_modules = num_modules
-        self.num_dim = dim
-
-        self.proj = torch.nn.Sequential(
-            torch.nn.LayerNorm(embeddings_dim),
-            torch.nn.Linear(embeddings_dim, embeddings_dim * 2),
-            torch.nn.GELU(),
-            torch.nn.Linear(embeddings_dim * 2, embeddings_dim * 2),
-            torch.nn.LayerNorm(embeddings_dim * 2),
-
-            torch.nn.Linear(embeddings_dim * 2, embeddings_dim * 4),
-            torch.nn.GELU(),
-            torch.nn.Linear(embeddings_dim * 4, num_modules * dim),
-            torch.nn.LayerNorm(num_modules * dim),
-        )
-        # Initialize the last linear layer weights near zero
-        torch.nn.init.uniform_(self.proj[-2].weight, a=-0.01, b=0.01)
-        torch.nn.init.zeros_(self.proj[-2].bias)
+        if use_residual:
+            assert in_dim == out_dim
+        self.layernorm = nn.LayerNorm(in_dim)
+        self.fc1 = nn.Linear(in_dim, hidden_dim)
+        self.fc2 = nn.Linear(hidden_dim, out_dim)
+        self.dropout = nn.Dropout(dropout)
+        self.use_residual = use_residual
+        self.act_fn = nn.GELU()
 
     def forward(self, x):
-        x = self.proj(x)
-        x = x.reshape(-1, self.num_modules, self.num_dim)
+        residual = x
+        x = self.layernorm(x)
+        x = self.fc1(x)
+        x = self.act_fn(x)
+        x = self.fc2(x)
+        x = self.dropout(x)
+        if self.use_residual:
+            x = x + residual
         return x
 
 
+class LoRAGenerator(torch.nn.Module):
+    def __init__(
+            self,
+            input_size: int = 768,  # projection dimension
+            hidden_size: int = 768,
+            head_size: int = 512,
+            num_mlp_layers: int = 1,
+            output_size: int = 768,
+            dropout: float = 0.5
+    ):
+        super().__init__()
+        self.input_size = input_size
+
+        self.output_size = output_size
+        self.lin_in = nn.Linear(input_size, hidden_size)
+
+        self.mlp_blocks = nn.Sequential(*[
+            MLP(hidden_size, hidden_size, hidden_size, dropout=dropout, use_residual=True) for _ in range(num_mlp_layers)
+        ])
+        self.head = nn.Linear(hidden_size, head_size, bias=False)
+        self.norm = nn.LayerNorm(head_size)
+
+        self.flatten = nn.Flatten()
+        self.output = nn.Linear(head_size, self.output_size)
+        # for each output block, multiply weights by 0.01
+        with torch.no_grad():
+            self.output.weight.data *= 0.01
+
+    # allow get device
+    @property
+    def device(self):
+        return next(self.parameters()).device
+
+    @property
+    def dtype(self):
+        return next(self.parameters()).dtype
+
+    def forward(self, embedding):
+        if len(embedding.shape) == 2:
+            embedding = embedding.unsqueeze(1)
+
+        x = self.lin_in(embedding)
+        x = self.mlp_blocks(x)
+        x = self.head(x)
+        x = self.norm(x)
+
+        head_output = x
+
+        x = self.output(head_output)
+        return x.squeeze(1)
+
 
 class InstantLoRAMidModule(torch.nn.Module):
     def __init__(
             self,
-            dim: int,
             index: int,
             lora_module: 'LoRAModule',
-            instant_lora_module: 'InstantLoRAModule'
+            instant_lora_module: 'InstantLoRAModule',
+            up_shape: list = None,
+            down_shape: list = None,
     ):
         super(InstantLoRAMidModule, self).__init__()
-        self.dim = dim
+        self.up_shape = up_shape
+        self.down_shape = down_shape
         self.index = index
         self.lora_module_ref = weakref.ref(lora_module)
         self.instant_lora_module_ref = weakref.ref(instant_lora_module)
 
-    def forward(self, x, *args, **kwargs):
-        # get the vector
-        img_embeds = self.instant_lora_module_ref().img_embeds
-        # project it
-        scaler = img_embeds[:, self.index, :]
+        self.embed = None
 
-        # remove the channel dim (index)
-        scaler = scaler.squeeze(1)
+    def down_forward(self, x, *args, **kwargs):
+        # get the embed
+        self.embed = self.instant_lora_module_ref().img_embeds[self.index]
+        down_size = math.prod(self.down_shape)
+        down_weight = self.embed[:, :down_size]
+
+        batch_size = x.shape[0]
+
+        # unconditional
+        if down_weight.shape[0] * 2 == batch_size:
+            down_weight = torch.cat([down_weight] * 2, dim=0)
+
+        weight_chunks = torch.chunk(down_weight, batch_size, dim=0)
+        x_chunks = torch.chunk(x, batch_size, dim=0)
+
+        x_out = []
+        for i in range(batch_size):
+            weight_chunk = weight_chunks[i]
+            x_chunk = x_chunks[i]
+            # reshape
+            weight_chunk = weight_chunk.view(self.down_shape)
+            # run a simple linear layer with the down weight
+            x_chunk = x_chunk @ weight_chunk.T
+            x_out.append(x_chunk)
+        x = torch.cat(x_out, dim=0)
+        return x
+
+
+    def up_forward(self, x, *args, **kwargs):
+        self.embed = self.instant_lora_module_ref().img_embeds[self.index]
+        up_size = math.prod(self.up_shape)
+        up_weight = self.embed[:, -up_size:]
+
+        batch_size = x.shape[0]
+
+        # unconditional
+        if up_weight.shape[0] * 2 == batch_size:
+            up_weight = torch.cat([up_weight] * 2, dim=0)
+
+        weight_chunks = torch.chunk(up_weight, batch_size, dim=0)
+        x_chunks = torch.chunk(x, batch_size, dim=0)
+
+        x_out = []
+        for i in range(batch_size):
+            weight_chunk = weight_chunks[i]
+            x_chunk = x_chunks[i]
+            # reshape
+            weight_chunk = weight_chunk.view(self.up_shape)
+            # run a simple linear layer with the up weight
+            x_chunk = x_chunk @ weight_chunk.T
+            x_out.append(x_chunk)
+        x = torch.cat(x_out, dim=0)
+        return x
 
-        # double up if batch is 2x the size on x (cfg)
-        if x.shape[0] // 2 == scaler.shape[0]:
-            scaler = torch.cat([scaler, scaler], dim=0)
-        # multiply it by the scaler
-        try:
-            # reshape if needed
-            if len(x.shape) == 3:
-                scaler = scaler.unsqueeze(1)
-            if len(x.shape) == 4:
-                scaler = scaler.unsqueeze(-1).unsqueeze(-1)
-        except Exception as e:
-            print(e)
-            print(x.shape)
-            print(scaler.shape)
-            raise e
-        # apply tanh to limit values to -1 to 1
-        # scaler = torch.tanh(scaler)
-        try:
-            return x * scaler
-        except Exception as e:
-            print(e)
-            print(x.shape)
-            print(scaler.shape)
-            raise e
 
 class InstantLoRAModule(torch.nn.Module):
@@ -99,6 +172,7 @@ class InstantLoRAModule(torch.nn.Module):
             self,
             vision_hidden_size: int,
             vision_tokens: int,
+            head_dim: int,
             sd: 'StableDiffusion'
     ):
         super(InstantLoRAModule, self).__init__()
@@ -107,9 +181,10 @@ class InstantLoRAModule(torch.nn.Module):
         self.dim = sd.network.lora_dim
         self.vision_hidden_size = vision_hidden_size
         self.vision_tokens = vision_tokens
+        self.head_dim = head_dim
 
         # stores the projection vector. Grabbed by modules
-        self.img_embeds: torch.Tensor = None
+        self.img_embeds: List[torch.Tensor] = None
 
         # disable merging in. It is slower on inference
         self.sd_ref().network.can_merge_in = False
@@ -118,58 +193,109 @@ class InstantLoRAModule(torch.nn.Module):
 
         lora_modules = self.sd_ref().network.get_all_modules()
 
-        # resample the output so each module gets one token with a size of its dim so we can multiply by that
-        # self.resampler = ZipperResampler(
-        #     in_size=self.vision_hidden_size,
-        #     in_tokens=self.vision_tokens,
-        #     out_size=self.dim,
-        #     out_tokens=len(lora_modules),
-        #     hidden_size=self.vision_hidden_size,
-        #     hidden_tokens=self.vision_tokens,
-        #     num_blocks=1,
-        # )
-        # heads = 20
-        # heads = 12
-        # dim = 1280
-        # output_dim = self.dim
-        self.proj_module = ILoRAProjModule(
-            num_modules=len(lora_modules),
-            dim=self.dim,
-            embeddings_dim=self.vision_hidden_size,
-        )
-        # self.resampler = Resampler(
-        #     dim=dim,
-        #     depth=4,
-        #     dim_head=64,
-        #     heads=heads,
-        #     num_queries=len(lora_modules),
-        #     embedding_dim=self.vision_hidden_size,
-        #     max_seq_len=self.vision_tokens,
-        #     output_dim=output_dim,
-        #     ff_mult=4
-        # )
+        output_size = 0
+
+        self.embed_lengths = []
+        self.weight_mapping = []
 
         for idx, lora_module in enumerate(lora_modules):
+            module_dict = lora_module.state_dict()
+            down_shape = list(module_dict['lora_down.weight'].shape)
+            up_shape = list(module_dict['lora_up.weight'].shape)
+
+            self.weight_mapping.append([lora_module.lora_name, [down_shape, up_shape]])
+
+            module_size = math.prod(down_shape) + math.prod(up_shape)
+            output_size += module_size
+            self.embed_lengths.append(module_size)
+
             # add a new mid module that will take the original forward and add a vector to it
             # this will be used to add the vector to the original forward
-            mid_module = InstantLoRAMidModule(
-                self.dim,
+            instant_module = InstantLoRAMidModule(
                 idx,
                 lora_module,
-                self
+                self,
+                up_shape=up_shape,
+                down_shape=down_shape
             )
-            self.ilora_modules.append(mid_module)
-            # replace the LoRA lora_mid
-            lora_module.lora_mid = mid_module.forward
+            self.ilora_modules.append(instant_module)
+
+            # replace the LoRA forwards
+            lora_module.lora_down.forward = instant_module.down_forward
+            lora_module.lora_up.forward = instant_module.up_forward
+
+
+        self.output_size = output_size
+
+        if vision_tokens > 1:
+            self.resampler = Resampler(
+                dim=vision_hidden_size,
+                depth=4,
+                dim_head=64,
+                heads=12,
+                num_queries=1,  # output tokens
+                embedding_dim=vision_hidden_size,
+                max_seq_len=vision_tokens,
+                output_dim=head_dim,
+                ff_mult=4
+            )
+
+        self.proj_module = LoRAGenerator(
+            input_size=head_dim,
+            hidden_size=head_dim,
+            head_size=head_dim,
+            num_mlp_layers=1,
+            output_size=self.output_size,
+        )
+
+        self.migrate_weight_mapping()
+
+    def migrate_weight_mapping(self):
+        # changes the names of the modules to common ones
+        keymap = self.sd_ref().network.get_keymap()
+        save_keymap = {}
+        if keymap is not None:
+            for ldm_key, diffusers_key in keymap.items():
+                # invert them
+                save_keymap[diffusers_key] = ldm_key
+
+            new_keymap = {}
+            for key, value in self.weight_mapping:
+                if key in save_keymap:
+                    new_keymap[save_keymap[key]] = value
+                else:
+                    print(f"Key {key} not found in keymap")
+                    new_keymap[key] = value
+            self.weight_mapping = new_keymap
+        else:
+            print("No keymap found. Using default names")
+            return
 
-        # add a new mid module that will take the original forward and add a vector to it
-        # this will be used to add the vector to the original forward
 
     def forward(self, img_embeds):
         # expand token rank if only rank 2
         if len(img_embeds.shape) == 2:
             img_embeds = img_embeds.unsqueeze(1)
 
-        img_embeds = self.proj_module(img_embeds)
-        self.img_embeds = img_embeds
+        # resample the image embeddings
+        img_embeds = self.resampler(img_embeds)
+        img_embeds = self.proj_module(img_embeds)
+        if len(img_embeds.shape) == 3:
+            img_embeds = img_embeds.squeeze(1)
+
+        self.img_embeds = []
+        # get all the slices
+        start = 0
+        for length in self.embed_lengths:
+            self.img_embeds.append(img_embeds[:, start:start + length])
+            start += length
+
+
+    def get_additional_save_metadata(self) -> Dict[str, Any]:
+        # save the weight mapping
+        return {
+            "weight_mapping": self.weight_mapping
+        }
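
Editorial note, not part of the patch: the core change is that LoRAGenerator now emits one flat weight vector per image, which InstantLoRAModule slices by embed_lengths and each InstantLoRAMidModule reshapes into its lora_down/lora_up matrices before applying them per sample. The sketch below mirrors down_forward/up_forward with made-up shapes (rank 4, a 320-wide linear); it imports nothing from the toolkit and is only meant to illustrate the slicing and per-sample matmul.

# Illustration only; shapes and tensors here are hypothetical stand-ins.
import math
import torch

down_shape = [4, 320]   # lora_down.weight shape: (rank, in_features)
up_shape = [320, 4]     # lora_up.weight shape: (out_features, rank)
module_size = math.prod(down_shape) + math.prod(up_shape)

batch_size = 2
# stand-in for one module's slice of the LoRAGenerator output
flat_weights = torch.randn(batch_size, module_size)

# head of the slice is the down weight, tail is the up weight
down_weight = flat_weights[:, :math.prod(down_shape)]
up_weight = flat_weights[:, -math.prod(up_shape):]

x = torch.randn(batch_size, 77, 320)  # hypothetical activations entering lora_down

out = []
for i in range(batch_size):
    w_down = down_weight[i].view(down_shape)  # (4, 320)
    w_up = up_weight[i].view(up_shape)        # (320, 4)
    h = x[i] @ w_down.T                       # (77, 4)
    out.append(h @ w_up.T)                    # (77, 320)
out = torch.stack(out, dim=0)                 # per-sample LoRA delta, shape (2, 77, 320)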