# based heavily on https://github.com/KohakuBlueleaf/LyCORIS/blob/eb460098187f752a5d66406d3affade6f0a07ece/lycoris/modules/lokr.py
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from toolkit.network_mixins import ToolkitModuleMixin

from typing import TYPE_CHECKING, Union, List

if TYPE_CHECKING:
    from toolkit.lora_special import LoRASpecialNetwork

# 4, build custom backward function
# -


def factorization(dimension: int, factor: int = -1) -> tuple[int, int]:
    '''
    Return a tuple of two values: the input dimension decomposed by the number
    closest to factor. The second value is greater than or equal to the first.

    In LoRA with Kronecker product, the first value is for the weight scale
    and the second value is for the weight.

    Because of the non-commutative property A⊗B ≠ B⊗A, the meaning of the two
    matrices is slightly different.

    examples)
    factor
            -1              2               4               8               16 ...
    127  -> 127, 1     127 -> 127, 1   127 -> 127, 1   127 -> 127, 1   127 -> 127, 1
    128  -> 16, 8      128 -> 64, 2    128 -> 32, 4    128 -> 16, 8    128 -> 16, 8
    250  -> 125, 2     250 -> 125, 2   250 -> 125, 2   250 -> 125, 2   250 -> 125, 2
    360  -> 45, 8      360 -> 180, 2   360 -> 90, 4    360 -> 45, 8    360 -> 45, 8
    512  -> 32, 16     512 -> 256, 2   512 -> 128, 4   512 -> 64, 8    512 -> 32, 16
    1024 -> 32, 32     1024 -> 512, 2  1024 -> 256, 4  1024 -> 128, 8  1024 -> 64, 16
    '''
    if factor > 0 and (dimension % factor) == 0:
        m = factor
        n = dimension // factor
        return m, n
    if factor == -1:
        factor = dimension
    m, n = 1, dimension
    length = m + n
    while m < n:
        new_m = m + 1
        while dimension % new_m != 0:
            new_m += 1
        new_n = dimension // new_m
        if new_m + new_n > length or new_m > factor:
            break
        else:
            m, n = new_m, new_n
    if m > n:
        n, m = m, n
    return m, n
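# Quick sanity check for factorization() (illustrative sketch, not executed):
# the two returned values always multiply back to the input dimension.
#
#     for dim in (128, 360, 512, 1024):
#         m, n = factorization(dim)
#         assert m * n == dim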
def make_weight_cp(t, wa, wb):
    rebuild2 = torch.einsum('i j k l, i p, j r -> p r k l', t, wa, wb)  # [c, d, k1, k2]
    return rebuild2


def make_kron(w1, w2, scale):
    if len(w2.shape) == 4:
        w1 = w1.unsqueeze(2).unsqueeze(2)
    w2 = w2.contiguous()
    rebuild = torch.kron(w1, w2)

    return rebuild * scale
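# Shape note for make_kron() above: torch.kron of an (a, c) matrix with a
# (b, d) matrix yields an (a*b, c*d) matrix, so the rebuilt LoKr weight has
# the shape of the original layer. For Conv2d weights of shape (b, d, k1, k2),
# w1 is unsqueezed to rank 4 first, giving an (a*b, c*d, k1, k2) result.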
class LokrModule(ToolkitModuleMixin, nn.Module):
    """
    modified from kohya-ss/sd-scripts/networks/lora:LoRAModule
    and from KohakuBlueleaf/LyCORIS/lycoris:loha:LoHaModule
    and from KohakuBlueleaf/LyCORIS/lycoris:locon:LoconModule
    """

    def __init__(
        self,
        lora_name,
        org_module: nn.Module,
        multiplier=1.0,
        lora_dim=4,
        alpha=1,
        dropout=0.,
        rank_dropout=0.,
        module_dropout=0.,
        use_cp=False,
        decompose_both=False,
        network: 'LoRASpecialNetwork' = None,
        factor: int = -1,  # factorization factor
        **kwargs,
    ):
        """ if alpha == 0 or None, alpha is rank (no scaling). """
        ToolkitModuleMixin.__init__(self, network=network)
        torch.nn.Module.__init__(self)
        factor = int(factor)
        self.lora_name = lora_name
        self.lora_dim = lora_dim
        self.cp = False
        self.use_w1 = False
        self.use_w2 = False

        self.shape = org_module.weight.shape
        if org_module.__class__.__name__ == 'Conv2d':
            in_dim = org_module.in_channels
            k_size = org_module.kernel_size
            out_dim = org_module.out_channels

            in_m, in_n = factorization(in_dim, factor)
            out_l, out_k = factorization(out_dim, factor)
            shape = ((out_l, out_k), (in_m, in_n), *k_size)  # ((a, b), (c, d), *k_size)

            self.cp = use_cp and k_size != (1, 1)
            if decompose_both and lora_dim < max(shape[0][0], shape[1][0]) / 2:
                self.lokr_w1_a = nn.Parameter(torch.empty(shape[0][0], lora_dim))
                self.lokr_w1_b = nn.Parameter(torch.empty(lora_dim, shape[1][0]))
            else:
                self.use_w1 = True
                self.lokr_w1 = nn.Parameter(torch.empty(shape[0][0], shape[1][0]))  # a*c, 1-mode

            if lora_dim >= max(shape[0][1], shape[1][1]) / 2:
                self.use_w2 = True
                self.lokr_w2 = nn.Parameter(torch.empty(shape[0][1], shape[1][1], *k_size))
            elif self.cp:
                self.lokr_t2 = nn.Parameter(torch.empty(lora_dim, lora_dim, shape[2], shape[3]))
                self.lokr_w2_a = nn.Parameter(torch.empty(lora_dim, shape[0][1]))  # b, 1-mode
                self.lokr_w2_b = nn.Parameter(torch.empty(lora_dim, shape[1][1]))  # d, 2-mode
            else:  # Conv2d not cp
                # bigger part. weight and LoRA. [b, dim] x [dim, d*k1*k2]
                self.lokr_w2_a = nn.Parameter(torch.empty(shape[0][1], lora_dim))
                self.lokr_w2_b = nn.Parameter(torch.empty(lora_dim, shape[1][1] * shape[2] * shape[3]))
                # w1 ⊗ (w2_a x w2_b) = (a, b)⊗((c, dim)x(dim, d*k1*k2)) = (a, b)⊗(c, d*k1*k2) = (ac, bd*k1*k2)

            self.op = F.conv2d
            self.extra_args = {
                "stride": org_module.stride,
                "padding": org_module.padding,
                "dilation": org_module.dilation,
                "groups": org_module.groups
            }

        else:  # Linear
            in_dim = org_module.in_features
            out_dim = org_module.out_features

            in_m, in_n = factorization(in_dim, factor)
            out_l, out_k = factorization(out_dim, factor)
            shape = ((out_l, out_k), (in_m, in_n))  # ((a, b), (c, d)), out_dim = a*c, in_dim = b*d

            # smaller part. weight scale
            if decompose_both and lora_dim < max(shape[0][0], shape[1][0]) / 2:
                self.lokr_w1_a = nn.Parameter(torch.empty(shape[0][0], lora_dim))
                self.lokr_w1_b = nn.Parameter(torch.empty(lora_dim, shape[1][0]))
            else:
                self.use_w1 = True
                self.lokr_w1 = nn.Parameter(torch.empty(shape[0][0], shape[1][0]))  # a*c, 1-mode

            if lora_dim < max(shape[0][1], shape[1][1]) / 2:
                # bigger part. weight and LoRA. [b, dim] x [dim, d]
                self.lokr_w2_a = nn.Parameter(torch.empty(shape[0][1], lora_dim))
                self.lokr_w2_b = nn.Parameter(torch.empty(lora_dim, shape[1][1]))
                # w1 ⊗ (w2_a x w2_b) = (a, b)⊗((c, dim)x(dim, d)) = (a, b)⊗(c, d) = (ac, bd)
            else:
                self.use_w2 = True
                self.lokr_w2 = nn.Parameter(torch.empty(shape[0][1], shape[1][1]))

            self.op = F.linear
            self.extra_args = {}

        self.dropout = dropout
        if dropout:
            print("[WARN] LoHa/LoKr haven't implemented normal dropout yet.")
        self.rank_dropout = rank_dropout
        self.module_dropout = module_dropout

        if isinstance(alpha, torch.Tensor):
            alpha = alpha.detach().float().numpy()  # without casting, bf16 causes error
        alpha = lora_dim if alpha is None or alpha == 0 else alpha
        if self.use_w2 and self.use_w1:
            # use scale = 1
            alpha = lora_dim
        self.scale = alpha / self.lora_dim
        self.register_buffer('alpha', torch.tensor(alpha))  # can be treated as a constant
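        # Net effect of the scaling: the delta added in forward() is
        # multiplier * (alpha / lora_dim) * kron(w1, w2). When both factors are
        # stored directly (use_w1 and use_w2), alpha is pinned to lora_dim above,
        # so the extra scale collapses to 1.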
        if self.use_w2:
            torch.nn.init.constant_(self.lokr_w2, 0)
        else:
            if self.cp:
                torch.nn.init.kaiming_uniform_(self.lokr_t2, a=math.sqrt(5))
            torch.nn.init.kaiming_uniform_(self.lokr_w2_a, a=math.sqrt(5))
            torch.nn.init.constant_(self.lokr_w2_b, 0)

        if self.use_w1:
            torch.nn.init.kaiming_uniform_(self.lokr_w1, a=math.sqrt(5))
        else:
            torch.nn.init.kaiming_uniform_(self.lokr_w1_a, a=math.sqrt(5))
            torch.nn.init.kaiming_uniform_(self.lokr_w1_b, a=math.sqrt(5))

        self.multiplier = multiplier
        self.org_module = [org_module]
        weight = make_kron(
            self.lokr_w1 if self.use_w1 else self.lokr_w1_a @ self.lokr_w1_b,
            (self.lokr_w2 if self.use_w2
             else make_weight_cp(self.lokr_t2, self.lokr_w2_a, self.lokr_w2_b) if self.cp
             else self.lokr_w2_a @ self.lokr_w2_b),
            torch.tensor(self.multiplier * self.scale)
        )
        assert torch.sum(torch.isnan(weight)) == 0, "weight is nan"

    # Same as locon.py
    def apply_to(self):
        self.org_forward = self.org_module[0].forward
        self.org_module[0].forward = self.forward

    def get_weight(self, orig_weight=None):
        weight = make_kron(
            self.lokr_w1 if self.use_w1 else self.lokr_w1_a @ self.lokr_w1_b,
            (self.lokr_w2 if self.use_w2
             else make_weight_cp(self.lokr_t2, self.lokr_w2_a, self.lokr_w2_b) if self.cp
             else self.lokr_w2_a @ self.lokr_w2_b),
            torch.tensor(self.scale)
        )
        if orig_weight is not None:
            weight = weight.reshape(orig_weight.shape)
        if self.training and self.rank_dropout:
            drop = torch.rand(weight.size(0)) < self.rank_dropout
            weight *= drop.view(-1, *[1] * len(weight.shape[1:])).to(weight.device)
        return weight

    @torch.no_grad()
    def apply_max_norm(self, max_norm, device=None):
        orig_norm = self.get_weight().norm()
        norm = torch.clamp(orig_norm, max_norm / 2)
        desired = torch.clamp(norm, max=max_norm)
        ratio = desired.cpu() / norm.cpu()

        scaled = ratio.item() != 1.0
        if scaled:
            modules = (4 - self.use_w1 - self.use_w2 + (not self.use_w2 and self.cp))
            if self.use_w1:
                self.lokr_w1 *= ratio ** (1 / modules)
            else:
                self.lokr_w1_a *= ratio ** (1 / modules)
                self.lokr_w1_b *= ratio ** (1 / modules)

            if self.use_w2:
                self.lokr_w2 *= ratio ** (1 / modules)
            else:
                if self.cp:
                    self.lokr_t2 *= ratio ** (1 / modules)
                self.lokr_w2_a *= ratio ** (1 / modules)
                self.lokr_w2_b *= ratio ** (1 / modules)

        return scaled, orig_norm * ratio
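    # get_weight() already folds in the alpha / lora_dim scale; forward() below
    # additionally multiplies the delta by self.multiplier before adding it to
    # the frozen base weight of the wrapped module.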
    def forward(self, x):
        if self.module_dropout and self.training:
            if torch.rand(1) < self.module_dropout:
                # module dropout: fall back to the original (frozen) weights
                return self.op(
                    x,
                    self.org_module[0].weight.data,
                    None if self.org_module[0].bias is None else self.org_module[0].bias.data,
                    **self.extra_args
                )
        weight = (
            self.org_module[0].weight.data
            + self.get_weight(self.org_module[0].weight.data) * self.multiplier
        )
        bias = None if self.org_module[0].bias is None else self.org_module[0].bias.data
        return self.op(
            x,
            weight.view(self.shape),
            bias,
            **self.extra_args
        )
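# Minimal smoke-test sketch (not part of the original module): rebuild a LoKr
# delta weight for a hypothetical 1024x768 Linear layer using only the helpers
# above, without instantiating LokrModule. The layer size and lora_dim are
# illustrative assumptions, and running this requires the toolkit package to be
# importable (the module-level imports above depend on it).
if __name__ == "__main__":
    out_l, out_k = factorization(1024)       # output dimension factors
    in_m, in_n = factorization(768)          # input dimension factors
    lora_dim = 4
    w1 = torch.randn(out_l, in_m)            # small Kronecker factor
    w2_a = torch.randn(out_k, lora_dim)      # low-rank factors of the larger block
    w2_b = torch.randn(lora_dim, in_n)
    delta = make_kron(w1, w2_a @ w2_b, 1.0)  # shape: (out_l*out_k, in_m*in_n)
    assert delta.shape == (1024, 768)
    print("LoKr delta weight shape:", tuple(delta.shape))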