# based heavily on https://github.com/KohakuBlueleaf/LyCORIS/blob/eb460098187f752a5d66406d3affade6f0a07ece/lycoris/modules/lokr.py

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from toolkit.network_mixins import ToolkitModuleMixin

from typing import TYPE_CHECKING, Union, List

from optimum.quanto import QBytesTensor, QTensor

if TYPE_CHECKING:
    from toolkit.lora_special import LoRASpecialNetwork

def factorization(dimension: int, factor: int = -1) -> tuple[int, int]:
    '''
    Return a tuple of two values: the input dimension decomposed into the pair
    of factors closest to `factor`. The second value is greater than or equal
    to the first.

    In LoRA with Kronecker product, the first value is used for the weight
    scale and the second value for the weight itself.

    Because the Kronecker product is non-commutative (A⊗B ≠ B⊗A), the meaning
    of the two matrices is slightly different.

    Examples:
    factor
            -1                2                4                8                16    ...
    127  -> 127, 1     127  -> 127, 1    127  -> 127, 1   127  -> 127, 1   127  -> 127, 1
    128  -> 16, 8      128  -> 64, 2     128  -> 32, 4    128  -> 16, 8    128  -> 16, 8
    250  -> 125, 2     250  -> 125, 2    250  -> 125, 2   250  -> 125, 2   250  -> 125, 2
    360  -> 45, 8      360  -> 180, 2    360  -> 90, 4    360  -> 45, 8    360  -> 45, 8
    512  -> 32, 16     512  -> 256, 2    512  -> 128, 4   512  -> 64, 8    512  -> 32, 16
    1024 -> 32, 32     1024 -> 512, 2    1024 -> 256, 4   1024 -> 128, 8   1024 -> 64, 16
    '''

    if factor > 0 and (dimension % factor) == 0:
        m = factor
        n = dimension // factor
        return m, n
    if factor == -1:
        factor = dimension
    m, n = 1, dimension
    length = m + n
    while m < n:
        new_m = m + 1
        while dimension % new_m != 0:
            new_m += 1
        new_n = dimension // new_m
        if new_m + new_n > length or new_m > factor:
            break
        else:
            m, n = new_m, new_n
    if m > n:
        n, m = m, n
    return m, n
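# Illustrative examples of the factorization above (not from upstream; values
# computed by hand from this implementation, returned smaller factor first):
#   factorization(1024)    -> (32, 32)   # factor=-1 searches for the most "square" split
#   factorization(128)     -> (8, 16)
#   factorization(640, 8)  -> (8, 80)    # factor divides the dimension exactly
# LokrModule below uses these pairs to express a weight delta as a Kronecker
# product of one small block and one larger (optionally low-rank) block.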


def make_weight_cp(t, wa, wb):
    rebuild2 = torch.einsum('i j k l, i p, j r -> p r k l',
                            t, wa, wb)  # [c, d, k1, k2]
    return rebuild2
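# Shape sketch for make_weight_cp (a reading aid, assuming the parameter shapes
# created in LokrModule below): t is [dim, dim, k1, k2], wa is [dim, out_k] and
# wb is [dim, in_n]; the einsum contracts both rank axes and returns an
# [out_k, in_n, k1, k2] block, i.e. the CP-decomposed half of the Kronecker pair.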


def make_kron(w1, w2, scale):
    if len(w2.shape) == 4:
        w1 = w1.unsqueeze(2).unsqueeze(2)
    w2 = w2.contiguous()
    rebuild = torch.kron(w1, w2)

    return rebuild * scale
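# Shape sketch (illustrative): torch.kron of w1 [a, b] with w2 [c, d] yields an
# [a*c, b*d] matrix, which the caller reshapes back to the original weight
# shape. For conv weights, w1 is unsqueezed to [a, b, 1, 1] so the Kronecker
# product with w2 [c, d, k1, k2] gives [a*c, b*d, k1, k2].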


class LokrModule(ToolkitModuleMixin, nn.Module):
    def __init__(
        self,
        lora_name,
        org_module: nn.Module,
        multiplier=1.0,
        lora_dim=4,
        alpha=1,
        dropout=0.,
        rank_dropout=0.,
        module_dropout=0.,
        use_cp=False,
        decompose_both=False,
        network: 'LoRASpecialNetwork' = None,
        factor: int = -1,  # factorization factor
        **kwargs,
    ):
        """ if alpha == 0 or None, alpha is rank (no scaling). """
        ToolkitModuleMixin.__init__(self, network=network)
        torch.nn.Module.__init__(self)
        factor = int(factor)
        self.lora_name = lora_name
        self.lora_dim = lora_dim
        self.cp = False
        self.use_w1 = False
        self.use_w2 = False
        self.can_merge_in = True

        self.shape = org_module.weight.shape
        if org_module.__class__.__name__ == 'Conv2d':
            in_dim = org_module.in_channels
            k_size = org_module.kernel_size
            out_dim = org_module.out_channels

            in_m, in_n = factorization(in_dim, factor)
            out_l, out_k = factorization(out_dim, factor)
            # ((a, b), (c, d), *k_size)
            shape = ((out_l, out_k), (in_m, in_n), *k_size)

            self.cp = use_cp and k_size != (1, 1)
            if decompose_both and lora_dim < max(shape[0][0], shape[1][0]) / 2:
                self.lokr_w1_a = nn.Parameter(
                    torch.empty(shape[0][0], lora_dim))
                self.lokr_w1_b = nn.Parameter(
                    torch.empty(lora_dim, shape[1][0]))
            else:
                self.use_w1 = True
                self.lokr_w1 = nn.Parameter(torch.empty(
                    shape[0][0], shape[1][0]))  # a*c, 1-mode

            if lora_dim >= max(shape[0][1], shape[1][1]) / 2:
                self.use_w2 = True
                self.lokr_w2 = nn.Parameter(torch.empty(
                    shape[0][1], shape[1][1], *k_size))
            elif self.cp:
                self.lokr_t2 = nn.Parameter(torch.empty(
                    lora_dim, lora_dim, shape[2], shape[3]))
                self.lokr_w2_a = nn.Parameter(
                    torch.empty(lora_dim, shape[0][1]))  # b, 1-mode
                self.lokr_w2_b = nn.Parameter(
                    torch.empty(lora_dim, shape[1][1]))  # d, 2-mode
            else:  # Conv2d not cp
                # bigger part. weight and LoRA. [b, dim] x [dim, d*k1*k2]
                self.lokr_w2_a = nn.Parameter(
                    torch.empty(shape[0][1], lora_dim))
                self.lokr_w2_b = nn.Parameter(torch.empty(
                    lora_dim, shape[1][1] * shape[2] * shape[3]))
                # w1 ⊗ (w2_a x w2_b) = (a, b)⊗((c, dim)x(dim, d*k1*k2)) = (a, b)⊗(c, d*k1*k2) = (ac, bd*k1*k2)

            self.op = F.conv2d
            self.extra_args = {
                "stride": org_module.stride,
                "padding": org_module.padding,
                "dilation": org_module.dilation,
                "groups": org_module.groups
            }

        else:  # Linear
            in_dim = org_module.in_features
            out_dim = org_module.out_features

            in_m, in_n = factorization(in_dim, factor)
            out_l, out_k = factorization(out_dim, factor)
            # ((a, b), (c, d)), out_dim = a*c, in_dim = b*d
            shape = ((out_l, out_k), (in_m, in_n))

            # smaller part. weight scale
            if decompose_both and lora_dim < max(shape[0][0], shape[1][0]) / 2:
                self.lokr_w1_a = nn.Parameter(
                    torch.empty(shape[0][0], lora_dim))
                self.lokr_w1_b = nn.Parameter(
                    torch.empty(lora_dim, shape[1][0]))
            else:
                self.use_w1 = True
                self.lokr_w1 = nn.Parameter(torch.empty(
                    shape[0][0], shape[1][0]))  # a*c, 1-mode

            if lora_dim < max(shape[0][1], shape[1][1]) / 2:
                # bigger part. weight and LoRA. [b, dim] x [dim, d]
                self.lokr_w2_a = nn.Parameter(
                    torch.empty(shape[0][1], lora_dim))
                self.lokr_w2_b = nn.Parameter(
                    torch.empty(lora_dim, shape[1][1]))
                # w1 ⊗ (w2_a x w2_b) = (a, b)⊗((c, dim)x(dim, d)) = (a, b)⊗(c, d) = (ac, bd)
            else:
                self.use_w2 = True
                self.lokr_w2 = nn.Parameter(
                    torch.empty(shape[0][1], shape[1][1]))

            self.op = F.linear
            self.extra_args = {}

        self.dropout = dropout
        if dropout:
            print("[WARN] LoKr does not implement normal dropout yet.")
        self.rank_dropout = rank_dropout
        self.module_dropout = module_dropout

        if isinstance(alpha, torch.Tensor):
            alpha = alpha.detach().float().numpy()  # without casting, bf16 causes error
        alpha = lora_dim if alpha is None or alpha == 0 else alpha
        if self.use_w2 and self.use_w1:
            # use scale = 1
            alpha = lora_dim
        self.scale = alpha / self.lora_dim
        self.register_buffer('alpha', torch.tensor(alpha))  # treat as constant

        if self.use_w2:
            torch.nn.init.constant_(self.lokr_w2, 0)
        else:
            if self.cp:
                torch.nn.init.kaiming_uniform_(self.lokr_t2, a=math.sqrt(5))
            torch.nn.init.kaiming_uniform_(self.lokr_w2_a, a=math.sqrt(5))
            torch.nn.init.constant_(self.lokr_w2_b, 0)

        if self.use_w1:
            torch.nn.init.kaiming_uniform_(self.lokr_w1, a=math.sqrt(5))
        else:
            torch.nn.init.kaiming_uniform_(self.lokr_w1_a, a=math.sqrt(5))
            torch.nn.init.kaiming_uniform_(self.lokr_w1_b, a=math.sqrt(5))

        self.multiplier = multiplier
        self.org_module = [org_module]
        weight = make_kron(
            self.lokr_w1 if self.use_w1 else self.lokr_w1_a @ self.lokr_w1_b,
            (self.lokr_w2 if self.use_w2
             else make_weight_cp(self.lokr_t2, self.lokr_w2_a, self.lokr_w2_b) if self.cp
             else self.lokr_w2_a @ self.lokr_w2_b),
            torch.tensor(self.multiplier * self.scale)
        )
        assert torch.sum(torch.isnan(weight)) == 0, "weight is nan"
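        # Worked example (illustrative, not from upstream): wrapping an
        # nn.Linear(3072, 3072) with factor=-1 and lora_dim=16 gives
        # factorization(3072) == (48, 64) for both dims, so
        #   lokr_w1   : (48, 48)                          "small" Kronecker factor
        #   lokr_w2_a : (64, 16), lokr_w2_b : (16, 64)    low-rank "large" factor
        # and make_kron rebuilds a (3072, 3072) delta from roughly 4.4k trainable
        # parameters instead of the ~9.4M of the full weight.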

    # Same as locon.py
    def apply_to(self):
        self.org_forward = self.org_module[0].forward
        self.org_module[0].forward = self.forward
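    # Note: once apply_to() has run, calls to the wrapped layer are routed
    # through this module's forward() (expected to come from ToolkitModuleMixin
    # and to dispatch to _call_forward() below); the original forward is kept
    # in self.org_forward.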

    def get_weight(self, orig_weight=None):
        weight = make_kron(
            self.lokr_w1 if self.use_w1 else self.lokr_w1_a @ self.lokr_w1_b,
            (self.lokr_w2 if self.use_w2
             else make_weight_cp(self.lokr_t2, self.lokr_w2_a, self.lokr_w2_b) if self.cp
             else self.lokr_w2_a @ self.lokr_w2_b),
            torch.tensor(self.scale)
        )
        if orig_weight is not None:
            weight = weight.reshape(orig_weight.shape)
        if self.training and self.rank_dropout:
            drop = torch.rand(weight.size(0)) < self.rank_dropout
            # unpack the singleton dims so the mask broadcasts over the non-row axes
            weight *= drop.view(-1, *[1] * len(weight.shape[1:])).to(weight.device)
        return weight
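    # get_weight() returns the dense LoKr delta (w1 ⊗ w2, scaled by alpha/dim);
    # rank_dropout, when enabled during training, applies a per-output-row
    # Boolean mask to that delta rather than element-wise dropout.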

    @torch.no_grad()
    def merge_in(self, merge_weight=1.0):
        if not self.can_merge_in:
            return

        # extract weight from org_module
        org_sd = self.org_module[0].state_dict()
        # todo find a way to merge in weights when doing quantized model
        if 'weight._data' in org_sd:
            # quantized weight
            return

        weight_key = "weight"
        if 'weight._data' in org_sd:
            # quantized weight
            weight_key = "weight._data"

        orig_dtype = org_sd[weight_key].dtype
        weight = org_sd[weight_key].float()

        scale = self.scale
        # handle the trainable scalar method that locon uses
        if hasattr(self, 'scalar'):
            scale = scale * self.scalar

        lokr_weight = self.get_weight(weight)

        merged_weight = (
            weight
            + (lokr_weight * merge_weight).to(weight.device, dtype=weight.dtype)
        )

        # set weight to org_module
        org_sd[weight_key] = merged_weight.to(orig_dtype)
        self.org_module[0].load_state_dict(org_sd)

    def get_orig_weight(self):
        weight = self.org_module[0].weight
        if isinstance(weight, QTensor) or isinstance(weight, QBytesTensor):
            return weight.dequantize().data.detach()
        else:
            return weight.data.detach()

    def get_orig_bias(self):
        if hasattr(self.org_module[0], 'bias') and self.org_module[0].bias is not None:
            if isinstance(self.org_module[0].bias, QTensor) or isinstance(self.org_module[0].bias, QBytesTensor):
                return self.org_module[0].bias.dequantize().data.detach()
            else:
                return self.org_module[0].bias.data.detach()
        return None

    def _call_forward(self, x):
        if isinstance(x, QTensor) or isinstance(x, QBytesTensor):
            x = x.dequantize()

        orig_dtype = x.dtype

        orig_weight = self.get_orig_weight()
        lokr_weight = self.get_weight(orig_weight).to(dtype=orig_weight.dtype)
        multiplier = self.network_ref().torch_multiplier

        if x.dtype != orig_weight.dtype:
            x = x.to(dtype=orig_weight.dtype)

        # we do not currently support split batch multipliers for lokr. Just do a mean
        multiplier = torch.mean(multiplier)

        weight = (
            orig_weight
            + lokr_weight * multiplier
        )
        bias = self.get_orig_bias()
        if bias is not None:
            bias = bias.to(weight.device, dtype=weight.dtype)
        output = self.op(
            x,
            weight.view(self.shape),
            bias,
            **self.extra_args
        )
        return output.to(orig_dtype)
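# Usage sketch (illustrative only; assumes a parent network object such as
# LoRASpecialNetwork that supplies network_ref().torch_multiplier; the names
# "example_lokr" and "my_network" are hypothetical):
#
#   base = nn.Linear(3072, 3072)
#   lokr = LokrModule("example_lokr", base, multiplier=1.0, lora_dim=16,
#                     alpha=16, network=my_network, factor=-1)
#   lokr.apply_to()      # base.forward now routes through the LoKr module
#   ...                  # train the lokr_* parameters
#   lokr.merge_in(1.0)   # fold the learned delta back into base.weight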