From 1bd94f0f0160c38f48b527e47d20c30be92b1d87 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Fri, 23 Feb 2024 05:55:41 -0700 Subject: [PATCH] Added early DoRA support, but will change shortly. Dont use right now. --- extensions_built_in/sd_trainer/SDTrainer.py | 4 +- jobs/process/BaseSDTrainProcess.py | 1 + toolkit/lora_special.py | 6 ++ toolkit/models/DoRA.py | 98 +++++++++++++++++++++ toolkit/network_mixins.py | 36 +++++++- 5 files changed, 140 insertions(+), 5 deletions(-) create mode 100644 toolkit/models/DoRA.py diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index bd7aed65..a7ae422c 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -1206,8 +1206,8 @@ class SDTrainer(BaseSDTrainProcess): has_been_preprocessed=True, quad_count=quad_count ) - else: - raise ValueError("Adapter images now must be loaded with dataloader or be a reg image") + # else: + # raise ValueError("Adapter images now must be loaded with dataloader or be a reg image") if not self.adapter_config.train_image_encoder: # we are not training the image encoder, so we need to detach the embeds diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index 37834a4d..c47c89ca 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -1192,6 +1192,7 @@ class BaseSDTrainProcess(BaseTrainProcess): use_bias=is_lorm, is_lorm=is_lorm, network_config=self.network_config, + network_type=self.network_config.type, **network_kwargs ) diff --git a/toolkit/lora_special.py b/toolkit/lora_special.py index 744111bf..9d912b5d 100644 --- a/toolkit/lora_special.py +++ b/toolkit/lora_special.py @@ -15,6 +15,7 @@ from .paths import SD_SCRIPTS_ROOT sys.path.append(SD_SCRIPTS_ROOT) from networks.lora import LoRANetwork, get_block_index +from toolkit.models.DoRA import DoRAModule from torch.utils.checkpoint import checkpoint @@ -159,6 +160,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork): attn_only: bool = False, target_lin_modules=LoRANetwork.UNET_TARGET_REPLACE_MODULE, target_conv_modules=LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3, + network_type: str = "lora", **kwargs ) -> None: """ @@ -199,6 +201,10 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork): self.is_sdxl = is_sdxl self.is_v2 = is_v2 self.is_pixart = is_pixart + self.network_type = network_type + if self.network_type.lower() == "dora": + self.module_class = DoRAModule + module_class = DoRAModule if modules_dim is not None: print(f"create LoRA network from weights") diff --git a/toolkit/models/DoRA.py b/toolkit/models/DoRA.py new file mode 100644 index 00000000..0f010352 --- /dev/null +++ b/toolkit/models/DoRA.py @@ -0,0 +1,98 @@ +#based off https://github.com/catid/dora/blob/main/dora.py +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import TYPE_CHECKING, Union, List + +from toolkit.network_mixins import ToolkitModuleMixin, ExtractableModuleMixin + +if TYPE_CHECKING: + from toolkit.lora_special import LoRASpecialNetwork + +# diffusers specific stuff +LINEAR_MODULES = [ + 'Linear', + 'LoRACompatibleLinear' + # 'GroupNorm', +] +CONV_MODULES = [ + 'Conv2d', + 'LoRACompatibleConv' +] + + +class DoRAModule(ToolkitModuleMixin, ExtractableModuleMixin, torch.nn.Module): + # def __init__(self, d_in, d_out, rank=4, weight=None, bias=None): + def __init__( + self, + lora_name, + org_module: torch.nn.Module, + multiplier=1.0, + lora_dim=4, 
+ alpha=1, + dropout=None, + rank_dropout=None, + module_dropout=None, + network: 'LoRASpecialNetwork' = None, + use_bias: bool = False, + **kwargs + ): + self.can_merge_in = False + """if alpha == 0 or None, alpha is rank (no scaling).""" + ToolkitModuleMixin.__init__(self, network=network) + torch.nn.Module.__init__(self) + self.lora_name = lora_name + self.scalar = torch.tensor(1.0) + + self.lora_dim = lora_dim + + if org_module.__class__.__name__ in CONV_MODULES: + raise NotImplementedError("Convolutional layers are not supported yet") + + if type(alpha) == torch.Tensor: + alpha = alpha.detach().float().numpy() # without casting, bf16 causes error + alpha = self.lora_dim if alpha is None or alpha == 0 else alpha + self.scale = alpha / self.lora_dim + # self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える eng: treat as constant + + self.multiplier: Union[float, List[float]] = multiplier + # wrap the original module so it doesn't get weights updated + self.org_module = [org_module] + self.dropout = dropout + self.rank_dropout = rank_dropout + self.module_dropout = module_dropout + self.is_checkpointing = False + + # m = Magnitude column-wise across output dimension + self.magnitude = nn.Parameter(self.get_orig_weight().norm(p=2, dim=0, keepdim=True)) + + d_out = org_module.out_features + d_in = org_module.in_features + + std_dev = 1 / torch.sqrt(torch.tensor(self.lora_dim).float()) + self.lora_up = nn.Parameter(torch.randn(d_out, self.lora_dim) * std_dev) + self.lora_down = nn.Parameter(torch.zeros(self.lora_dim, d_in)) + + def apply_to(self): + self.org_forward = self.org_module[0].forward + self.org_module[0].forward = self.forward + # del self.org_module + + def get_orig_weight(self): + return self.org_module[0].weight.data.detach() + + def get_orig_bias(self): + if hasattr(self.org_module[0], 'bias') and self.org_module[0].bias is not None: + return self.org_module[0].bias.data.detach() + return None + + def dora_forward(self, x, *args, **kwargs): + lora = torch.matmul(self.lora_up, self.lora_down) + adapted = self.get_orig_weight() + lora + column_norm = adapted.norm(p=2, dim=0, keepdim=True) + norm_adapted = adapted / column_norm + calc_weights = self.magnitude * norm_adapted + return F.linear(x, calc_weights, self.get_orig_bias()) + diff --git a/toolkit/network_mixins.py b/toolkit/network_mixins.py index bfc6bc90..cb3fc747 100644 --- a/toolkit/network_mixins.py +++ b/toolkit/network_mixins.py @@ -19,9 +19,10 @@ if TYPE_CHECKING: from toolkit.lycoris_special import LycorisSpecialNetwork, LoConSpecialModule from toolkit.lora_special import LoRASpecialNetwork, LoRAModule from toolkit.stable_diffusion_model import StableDiffusion + from toolkit.models.DoRA import DoRAModule Network = Union['LycorisSpecialNetwork', 'LoRASpecialNetwork'] -Module = Union['LoConSpecialModule', 'LoRAModule'] +Module = Union['LoConSpecialModule', 'LoRAModule', 'DoRAModule'] LINEAR_MODULES = [ 'Linear', @@ -247,6 +248,10 @@ class ToolkitModuleMixin: # network is not active, avoid doing anything return self.org_forward(x, *args, **kwargs) + if self.__class__.__name__ == "DoRAModule": + # return dora forward + return self.dora_forward(x, *args, **kwargs) + org_forwarded = self.org_forward(x, *args, **kwargs) lora_output = self._call_forward(x) multiplier = self.network_ref().torch_multiplier @@ -276,6 +281,8 @@ class ToolkitModuleMixin: @torch.no_grad() def merge_in(self: Module, merge_weight=1.0): + if not self.can_merge_in: + return # get up/down weight up_weight = 
self.lora_up.weight.clone().float() down_weight = self.lora_down.weight.clone().float() @@ -400,6 +407,23 @@ class ToolkitNetworkMixin: # get keymap from weights keymap = get_lora_keymap_from_model_keymap(keymap) + # upgrade keymaps for DoRA + if self.network_type.lower() == 'dora': + if keymap is not None: + new_keymap = {} + for ldm_key, diffusers_key in keymap.items(): + ldm_key = ldm_key.replace('.alpha', '.magnitude') + ldm_key = ldm_key.replace('.lora_down.weight', '.lora_down') + ldm_key = ldm_key.replace('.lora_up.weight', '.lora_up') + + diffusers_key = diffusers_key.replace('.alpha', '.magnitude') + diffusers_key = diffusers_key.replace('.lora_down.weight', '.lora_down') + diffusers_key = diffusers_key.replace('.lora_up.weight', '.lora_up') + + new_keymap[ldm_key] = diffusers_key + + keymap = new_keymap + return keymap def save_weights( @@ -489,8 +513,12 @@ class ToolkitNetworkMixin: multiplier = self._multiplier # get first module first_module = self.get_all_modules()[0] - device = first_module.lora_down.weight.device - dtype = first_module.lora_down.weight.dtype + if self.network_type.lower() == 'dora': + device = first_module.lora_down.device + dtype = first_module.lora_down.dtype + else: + device = first_module.lora_down.weight.device + dtype = first_module.lora_down.weight.dtype with torch.no_grad(): tensor_multiplier = None if isinstance(multiplier, int) or isinstance(multiplier, float): @@ -559,6 +587,8 @@ class ToolkitNetworkMixin: self._update_checkpointing() def merge_in(self, merge_weight=1.0): + if self.network_type.lower() == 'dora': + return self.is_merged_in = True for module in self.get_all_modules(): module.merge_in(merge_weight)
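
Note on the new `DoRAModule.dora_forward` path: the adapted weight is rebuilt on every call as the frozen base weight plus the low-rank product, normalized column-wise, and rescaled by the learned `magnitude` vector. Below is a minimal, self-contained sketch of that recomposition for a plain `nn.Linear`, useful for sanity-checking the math outside the toolkit. The class name `DoRALinearSketch` and the shapes used are illustrative only and are not part of the toolkit's API.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DoRALinearSketch(nn.Module):
    """Illustrates the weight recomposition done in DoRAModule.dora_forward:
    W' = magnitude * (W + lora_up @ lora_down) / ||W + lora_up @ lora_down||_col
    """

    def __init__(self, base: nn.Linear, rank: int = 4):
        super().__init__()
        d_out, d_in = base.out_features, base.in_features
        # frozen copies of the original weight/bias (buffers, so they get no gradients)
        self.register_buffer("base_weight", base.weight.detach().clone())
        self.register_buffer(
            "base_bias",
            base.bias.detach().clone() if base.bias is not None else None,
        )
        # magnitude: column-wise L2 norm of the original weight, one value per input column
        self.magnitude = nn.Parameter(self.base_weight.norm(p=2, dim=0, keepdim=True))
        # low-rank update; lora_down starts at zero so the product is zero at step 0
        std = 1.0 / torch.sqrt(torch.tensor(float(rank)))
        self.lora_up = nn.Parameter(torch.randn(d_out, rank) * std)
        self.lora_down = nn.Parameter(torch.zeros(rank, d_in))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        adapted = self.base_weight + self.lora_up @ self.lora_down   # W + (up @ down)
        column_norm = adapted.norm(p=2, dim=0, keepdim=True)         # per-column scale
        calc_weight = self.magnitude * (adapted / column_norm)       # re-apply learned magnitude
        return F.linear(x, calc_weight, self.base_bias)


# Because lora_down is zero-initialized and magnitude starts at the column norm of W,
# the sketch reproduces the frozen base layer exactly before any training step.
base = nn.Linear(64, 32)
dora = DoRALinearSketch(base, rank=4)
x = torch.randn(2, 64)
assert torch.allclose(dora(x), base(x), atol=1e-5)
```

This also makes the merge restriction in the patch concrete: since the effective weight depends on a per-column normalization of `W + up @ down`, it cannot be folded back into the base weight as a simple additive delta, which is why `can_merge_in = False` and `merge_in` short-circuits for DoRA networks.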