mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-13 14:39:50 +00:00
Added early DoRA support, but it will change shortly. Don't use it right now.
@@ -1206,8 +1206,8 @@ class SDTrainer(BaseSDTrainProcess):
                        has_been_preprocessed=True,
                        quad_count=quad_count
                    )
                else:
                    raise ValueError("Adapter images now must be loaded with dataloader or be a reg image")
                # else:
                #     raise ValueError("Adapter images now must be loaded with dataloader or be a reg image")

                if not self.adapter_config.train_image_encoder:
                    # we are not training the image encoder, so we need to detach the embeds
@@ -1192,6 +1192,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
            use_bias=is_lorm,
            is_lorm=is_lorm,
            network_config=self.network_config,
            network_type=self.network_config.type,
            **network_kwargs
        )
@@ -15,6 +15,7 @@ from .paths import SD_SCRIPTS_ROOT
sys.path.append(SD_SCRIPTS_ROOT)

from networks.lora import LoRANetwork, get_block_index
from toolkit.models.DoRA import DoRAModule

from torch.utils.checkpoint import checkpoint
@@ -159,6 +160,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
            attn_only: bool = False,
            target_lin_modules=LoRANetwork.UNET_TARGET_REPLACE_MODULE,
            target_conv_modules=LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3,
            network_type: str = "lora",
            **kwargs
    ) -> None:
        """
@@ -199,6 +201,10 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
        self.is_sdxl = is_sdxl
        self.is_v2 = is_v2
        self.is_pixart = is_pixart
        self.network_type = network_type
        if self.network_type.lower() == "dora":
            self.module_class = DoRAModule
            module_class = DoRAModule

        if modules_dim is not None:
            print(f"create LoRA network from weights")
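Given the wiring above (network_type comes straight from self.network_config.type and switches module_class to DoRAModule), selecting DoRA should only require changing the network type in the run config. A minimal sketch, assuming the existing LoRA config keys are otherwise unchanged:

# hypothetical config fragment, shown as a Python dict; only the "dora" value is new
network_config = {
    "type": "dora",      # anything other than "dora" falls through to the LoRA module
    "linear": 16,        # rank (lora_dim), as for LoRA
    "linear_alpha": 16,  # alpha, as for LoRA
}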
toolkit/models/DoRA.py (Normal file, 98 lines)
@@ -0,0 +1,98 @@
# based off https://github.com/catid/dora/blob/main/dora.py
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import TYPE_CHECKING, Union, List

from toolkit.network_mixins import ToolkitModuleMixin, ExtractableModuleMixin

if TYPE_CHECKING:
    from toolkit.lora_special import LoRASpecialNetwork

# diffusers specific stuff
LINEAR_MODULES = [
    'Linear',
    'LoRACompatibleLinear'
    # 'GroupNorm',
]
CONV_MODULES = [
    'Conv2d',
    'LoRACompatibleConv'
]


class DoRAModule(ToolkitModuleMixin, ExtractableModuleMixin, torch.nn.Module):
    # def __init__(self, d_in, d_out, rank=4, weight=None, bias=None):
    def __init__(
            self,
            lora_name,
            org_module: torch.nn.Module,
            multiplier=1.0,
            lora_dim=4,
            alpha=1,
            dropout=None,
            rank_dropout=None,
            module_dropout=None,
            network: 'LoRASpecialNetwork' = None,
            use_bias: bool = False,
            **kwargs
    ):
        """if alpha == 0 or None, alpha is rank (no scaling)."""
        self.can_merge_in = False
        ToolkitModuleMixin.__init__(self, network=network)
        torch.nn.Module.__init__(self)
        self.lora_name = lora_name
        self.scalar = torch.tensor(1.0)

        self.lora_dim = lora_dim

        if org_module.__class__.__name__ in CONV_MODULES:
            raise NotImplementedError("Convolutional layers are not supported yet")

        if type(alpha) == torch.Tensor:
            alpha = alpha.detach().float().numpy()  # without casting, bf16 causes error
        alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
        self.scale = alpha / self.lora_dim
        # self.register_buffer("alpha", torch.tensor(alpha))  # can be treated as a constant

        self.multiplier: Union[float, List[float]] = multiplier
        # wrap the original module so it doesn't get weights updated
        self.org_module = [org_module]
        self.dropout = dropout
        self.rank_dropout = rank_dropout
        self.module_dropout = module_dropout
        self.is_checkpointing = False

        # m = magnitude, column-wise across the output dimension
        self.magnitude = nn.Parameter(self.get_orig_weight().norm(p=2, dim=0, keepdim=True))

        d_out = org_module.out_features
        d_in = org_module.in_features

        std_dev = 1 / torch.sqrt(torch.tensor(self.lora_dim).float())
        self.lora_up = nn.Parameter(torch.randn(d_out, self.lora_dim) * std_dev)
        self.lora_down = nn.Parameter(torch.zeros(self.lora_dim, d_in))

    def apply_to(self):
        self.org_forward = self.org_module[0].forward
        self.org_module[0].forward = self.forward
        # del self.org_module

    def get_orig_weight(self):
        return self.org_module[0].weight.data.detach()

    def get_orig_bias(self):
        if hasattr(self.org_module[0], 'bias') and self.org_module[0].bias is not None:
            return self.org_module[0].bias.data.detach()
        return None

    def dora_forward(self, x, *args, **kwargs):
        lora = torch.matmul(self.lora_up, self.lora_down)
        adapted = self.get_orig_weight() + lora
        column_norm = adapted.norm(p=2, dim=0, keepdim=True)
        norm_adapted = adapted / column_norm
        calc_weights = self.magnitude * norm_adapted
        return F.linear(x, calc_weights, self.get_orig_bias())
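dora_forward implements the DoRA reparameterization W' = m * (W0 + up @ down) / ||W0 + up @ down||_col: the frozen weight plus the low-rank update is normalized per column, then rescaled by the learned magnitude vector. A minimal, self-contained sanity check of that identity with plain tensors (shapes and names assumed; not the module above):

import torch

d_in, d_out, rank = 8, 4, 2
w0 = torch.randn(d_out, d_in)                     # frozen base weight
lora_up = torch.randn(d_out, rank) / rank ** 0.5  # trainable direction update (B)
lora_down = torch.zeros(rank, d_in)               # trainable, zero-init (A)
magnitude = w0.norm(p=2, dim=0, keepdim=True)     # trainable, init from base column norms

adapted = w0 + lora_up @ lora_down                # W0 + BA
direction = adapted / adapted.norm(p=2, dim=0, keepdim=True)
w_prime = magnitude * direction                   # m * unit direction

# with lora_down zero-initialized, the adapted weight reproduces W0 exactly,
# so a fresh DoRA module starts as an identity wrapper around the base layer
assert torch.allclose(w_prime, w0, atol=1e-5)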
@@ -19,9 +19,10 @@ if TYPE_CHECKING:
    from toolkit.lycoris_special import LycorisSpecialNetwork, LoConSpecialModule
    from toolkit.lora_special import LoRASpecialNetwork, LoRAModule
    from toolkit.stable_diffusion_model import StableDiffusion
    from toolkit.models.DoRA import DoRAModule

Network = Union['LycorisSpecialNetwork', 'LoRASpecialNetwork']
Module = Union['LoConSpecialModule', 'LoRAModule']
Module = Union['LoConSpecialModule', 'LoRAModule', 'DoRAModule']

LINEAR_MODULES = [
    'Linear',
@@ -247,6 +248,10 @@ class ToolkitModuleMixin:
            # network is not active, avoid doing anything
            return self.org_forward(x, *args, **kwargs)

        if self.__class__.__name__ == "DoRAModule":
            # return dora forward
            return self.dora_forward(x, *args, **kwargs)

        org_forwarded = self.org_forward(x, *args, **kwargs)
        lora_output = self._call_forward(x)
        multiplier = self.network_ref().torch_multiplier
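Worth noting the structural difference between the two paths: the DoRA branch replaces the wrapped module's forward outright, since the reparameterized weight already contains the base W0, whereas the LoRA path below adds a scaled low-rank residual on top of org_forwarded. Matching on the class name rather than using isinstance presumably avoids importing DoRAModule into this mixin and the circular import that would invite.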
@@ -276,6 +281,8 @@ class ToolkitModuleMixin:
    @torch.no_grad()
    def merge_in(self: Module, merge_weight=1.0):
        if not self.can_merge_in:
            return
        # get up/down weight
        up_weight = self.lora_up.weight.clone().float()
        down_weight = self.lora_down.weight.clone().float()
@@ -400,6 +407,23 @@ class ToolkitNetworkMixin:
        # get keymap from weights
        keymap = get_lora_keymap_from_model_keymap(keymap)

        # upgrade keymaps for DoRA
        if self.network_type.lower() == 'dora':
            if keymap is not None:
                new_keymap = {}
                for ldm_key, diffusers_key in keymap.items():
                    ldm_key = ldm_key.replace('.alpha', '.magnitude')
                    ldm_key = ldm_key.replace('.lora_down.weight', '.lora_down')
                    ldm_key = ldm_key.replace('.lora_up.weight', '.lora_up')

                    diffusers_key = diffusers_key.replace('.alpha', '.magnitude')
                    diffusers_key = diffusers_key.replace('.lora_down.weight', '.lora_down')
                    diffusers_key = diffusers_key.replace('.lora_up.weight', '.lora_up')

                    new_keymap[ldm_key] = diffusers_key

                keymap = new_keymap

        return keymap

    def save_weights(
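The key renames above track how DoRAModule stores its tensors: lora_down and lora_up are bare nn.Parameter objects rather than nn.Linear layers, so their state-dict keys carry no trailing .weight, and the per-column norm is saved under .magnitude where LoRA would save .alpha. For illustration (the key string itself is a made-up example):

old_key = "lora_unet_down_blocks_0_attentions_0.lora_down.weight"
new_key = old_key.replace(".lora_down.weight", ".lora_down")
assert new_key == "lora_unet_down_blocks_0_attentions_0.lora_down"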
@@ -489,8 +513,12 @@ class ToolkitNetworkMixin:
        multiplier = self._multiplier
        # get first module
        first_module = self.get_all_modules()[0]
        device = first_module.lora_down.weight.device
        dtype = first_module.lora_down.weight.dtype
        if self.network_type.lower() == 'dora':
            device = first_module.lora_down.device
            dtype = first_module.lora_down.dtype
        else:
            device = first_module.lora_down.weight.device
            dtype = first_module.lora_down.weight.dtype
        with torch.no_grad():
            tensor_multiplier = None
            if isinstance(multiplier, int) or isinstance(multiplier, float):
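The same Parameter-vs-Linear distinction drives this branch: for LoRA, lora_down is an nn.Linear whose device and dtype hang off .weight, while DoRA's lora_down is the tensor itself. A minimal sketch of the difference:

import torch
import torch.nn as nn

dora_down = nn.Parameter(torch.zeros(4, 8))  # DoRA: bare Parameter
lora_down = nn.Linear(8, 4)                  # LoRA: a Linear layer

print(dora_down.device, dora_down.dtype)                 # attributes live on the tensor
print(lora_down.weight.device, lora_down.weight.dtype)   # one level deeper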
@@ -559,6 +587,8 @@ class ToolkitNetworkMixin:
        self._update_checkpointing()

    def merge_in(self, merge_weight=1.0):
        if self.network_type.lower() == 'dora':
            return
        self.is_merged_in = True
        for module in self.get_all_modules():
            module.merge_in(merge_weight)
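Skipping merge_in for DoRA networks is consistent with can_merge_in = False on the module: DoRA rebuilds the full weight as magnitude * (W0 + up @ down) / ||W0 + up @ down|| on every forward, so there is no purely additive delta that can be folded into the base weights the way a plain LoRA product can. Presumably a dedicated merge that bakes in the normalized, rescaled weight would come later.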