Merge branch 'main' into wan21

This commit is contained in:
Jaret Burkett
2025-03-04 00:31:57 -07:00
11 changed files with 603 additions and 132 deletions

View File

@@ -135,6 +135,15 @@ class NetworkConfig:
self.conv = 4
self.transformer_only = kwargs.get('transformer_only', True)
self.lokr_full_rank = kwargs.get('lokr_full_rank', False)
if self.lokr_full_rank and self.type.lower() == 'lokr':
self.linear = 9999999999
self.linear_alpha = 9999999999
self.conv = 9999999999
self.conv_alpha = 9999999999
# -1 automatically finds the largest factor
self.lokr_factor = kwargs.get('lokr_factor', -1)
AdapterTypes = Literal['t2i', 'ip', 'ip+', 'clip', 'ilora', 'photo_maker', 'control_net']
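For context on the sentinel values above: inside LokrModule (later in this diff), lora_dim is compared against the factorized weight shapes, so an absurdly large dim makes every check pick the full-matrix branch. A standalone sketch of that decision logic, with invented shape values (not the toolkit API):

# Illustration only: why a huge lora_dim means "full rank" for LoKr.
# shape = ((a, b), (c, d)) would normally come from factorization(); values here are made up.
lora_dim = 9999999999            # sentinel set when lokr_full_rank is enabled
shape = ((8, 16), (12, 32))
decompose_both = False

use_w1 = not (decompose_both and lora_dim < max(shape[0][0], shape[1][0]) / 2)
use_w2 = lora_dim >= max(shape[0][1], shape[1][1]) / 2
print(use_w1, use_w2)  # True True -> full lokr_w1 and lokr_w2 matrices, no low-rank pairs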

View File

@@ -231,12 +231,18 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
if self.network_type.lower() == "dora":
self.module_class = DoRAModule
module_class = DoRAModule
elif self.network_type.lower() == "lokr":
self.module_class = LokrModule
module_class = LokrModule
self.network_config: NetworkConfig = kwargs.get("network_config", None)
self.peft_format = peft_format
# always do peft for flux only for now
if self.is_flux or self.is_v3 or self.is_lumina2:
self.peft_format = True
# don't do peft format for lokr
if self.network_type.lower() != "lokr":
self.peft_format = True
if self.peft_format:
# no alpha for peft
@@ -338,8 +344,9 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
if (is_linear or is_conv2d) and not skip:
if self.only_if_contains is not None:
if not any([word in clean_name for word in self.only_if_contains]) and not any([word in lora_name for word in self.only_if_contains]):
continue
dim = None
alpha = None
@@ -373,6 +380,11 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
self.conv_lora_dim is not None or conv_block_dims is not None):
skipped.append(lora_name)
continue
module_kwargs = {}
if self.network_type.lower() == "lokr":
module_kwargs["factor"] = self.network_config.lokr_factor
lora = module_class(
lora_name,
@@ -386,10 +398,16 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
network=self,
parent=module,
use_bias=use_bias,
**module_kwargs
)
loras.append(lora)
if self.network_type.lower() == "lokr":
try:
lora_shape_dict[lora_name] = [list(lora.lokr_w1.weight.shape), list(lora.lokr_w2.weight.shape)]
except:
pass
else:
lora_shape_dict[lora_name] = [list(lora.lora_down.weight.shape), list(lora.lora_up.weight.shape)]
return loras, skipped
text_encoders = text_encoder if type(text_encoder) == list else [text_encoder]

View File

@@ -10,24 +10,23 @@ from toolkit.network_mixins import ToolkitModuleMixin
from typing import TYPE_CHECKING, Union, List
from optimum.quanto import QBytesTensor, QTensor
if TYPE_CHECKING:
from toolkit.lora_special import LoRASpecialNetwork
# 4, build custom backward function
# -
def factorization(dimension: int, factor: int = -1) -> tuple[int, int]:
'''
return a tuple of two values: the input dimension decomposed by the number closest to factor
second value is higher than or equal to the first value.
In LoRA with Kronecker Product, the first value is a value for the weight scale.
The second value is a value for the weight.
Because of the non-commutative property, A⊗B ≠ B⊗A, the meanings of the two matrices are slightly different.
examples)
factor
-1 2 4 8 16 ...
@@ -38,7 +37,7 @@ def factorization(dimension: int, factor:int=-1) -> tuple[int, int]:
512 -> 32, 16 512 -> 256, 2 512 -> 128, 4 512 -> 64, 8 512 -> 32, 16
1024 -> 32, 32 1024 -> 512, 2 1024 -> 256, 4 1024 -> 128, 8 1024 -> 64, 16
'''
if factor > 0 and (dimension % factor) == 0:
m = factor
n = dimension // factor
@@ -47,12 +46,12 @@ def factorization(dimension: int, factor:int=-1) -> tuple[int, int]:
factor = dimension
m, n = 1, dimension
length = m + n
while m < n:
new_m = m + 1
while dimension % new_m != 0:
new_m += 1
new_n = dimension // new_m
if new_m + new_n > length or new_m > factor:
break
else:
m, n = new_m, new_n
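Not part of the commit, just a usage sketch of the function above: the two factors always multiply back to the input dimension, and a positive factor caps the smaller side (per the docstring's factor columns).

# Usage sketch: assumes factorization() defined above is in scope.
for dim in (127, 128, 512, 1024, 3072):
    m, n = factorization(dim)          # default factor=-1: most balanced divisor pair
    assert m * n == dim
    print(dim, '->', m, n)

m, n = factorization(1024, factor=8)   # e.g. 1024 -> 128, 8 in the docstring table above
assert m * n == 1024 and min(m, n) == 8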
@@ -62,7 +61,8 @@ def factorization(dimension: int, factor:int=-1) -> tuple[int, int]:
def make_weight_cp(t, wa, wb):
rebuild2 = torch.einsum('i j k l, i p, j r -> p r k l',
t, wa, wb)  # [c, d, k1, k2]
return rebuild2
@@ -71,31 +71,25 @@ def make_kron(w1, w2, scale):
w1 = w1.unsqueeze(2).unsqueeze(2)
w2 = w2.contiguous()
rebuild = torch.kron(w1, w2)
return rebuild*scale
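For a Linear layer this is a plain Kronecker product: an (a, b) factor times a (c, d) factor reconstructs an (a*c, b*d) delta, which is exactly the (out_dim, in_dim) shape the two factorization() calls below are chosen to produce. A quick standalone shape check, independent of the toolkit code:

import torch

a, b, c, d = 4, 8, 16, 24
w1 = torch.randn(a, b)           # "weight scale" factor
w2 = torch.randn(c, d)           # "weight" factor
delta = torch.kron(w1, w2)       # what make_kron() computes for Linear layers (times scale)
assert delta.shape == (a * c, b * d)
print(w1.numel() + w2.numel(), 'params stand in for', delta.numel())  # 416 vs 12288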
class LokrModule(ToolkitModuleMixin, nn.Module):
"""
modified from kohya-ss/sd-scripts/networks/lora:LoRAModule
and from KohakuBlueleaf/LyCORIS/lycoris:loha:LoHaModule
and from KohakuBlueleaf/LyCORIS/lycoris:locon:LoconModule
"""
def __init__(
self,
lora_name,
org_module: nn.Module,
multiplier=1.0,
lora_dim=4,
alpha=1,
dropout=0.,
rank_dropout=0.,
module_dropout=0.,
use_cp=False,
decompose_both=False,
network: 'LoRASpecialNetwork' = None,
factor: int = -1,  # factorization factor
**kwargs,
):
""" if alpha == 0 or None, alpha is rank (no scaling). """
@@ -107,38 +101,49 @@ class LokrModule(ToolkitModuleMixin, nn.Module):
self.cp = False
self.use_w1 = False
self.use_w2 = False
self.can_merge_in = True
self.shape = org_module.weight.shape
if org_module.__class__.__name__ == 'Conv2d':
in_dim = org_module.in_channels
k_size = org_module.kernel_size
out_dim = org_module.out_channels
in_m, in_n = factorization(in_dim, factor)
out_l, out_k = factorization(out_dim, factor)
# ((a, b), (c, d), *k_size)
shape = ((out_l, out_k), (in_m, in_n), *k_size)
self.cp = use_cp and k_size != (1, 1)
if decompose_both and lora_dim < max(shape[0][0], shape[1][0])/2:
self.lokr_w1_a = nn.Parameter(
torch.empty(shape[0][0], lora_dim))
self.lokr_w1_b = nn.Parameter(
torch.empty(lora_dim, shape[1][0]))
else:
self.use_w1 = True
self.lokr_w1 = nn.Parameter(torch.empty(
shape[0][0], shape[1][0])) # a*c, 1-mode
if lora_dim >= max(shape[0][1], shape[1][1])/2:
self.use_w2 = True
self.lokr_w2 = nn.Parameter(torch.empty(
shape[0][1], shape[1][1], *k_size))
elif self.cp:
self.lokr_t2 = nn.Parameter(torch.empty(
lora_dim, lora_dim, shape[2], shape[3]))
self.lokr_w2_a = nn.Parameter(
torch.empty(lora_dim, shape[0][1])) # b, 1-mode
self.lokr_w2_b = nn.Parameter(
torch.empty(lora_dim, shape[1][1])) # d, 2-mode
else: # Conv2d not cp
# bigger part. weight and LoRA. [b, dim] x [dim, d*k1*k2]
self.lokr_w2_a = nn.Parameter(
torch.empty(shape[0][1], lora_dim))
self.lokr_w2_b = nn.Parameter(torch.empty(
lora_dim, shape[1][1]*shape[2]*shape[3]))
# w1 ⊗ (w2_a x w2_b) = (a, b)⊗((c, dim)x(dim, d*k1*k2)) = (a, b)⊗(c, d*k1*k2) = (ac, bd*k1*k2)
self.op = F.conv2d
self.extra_args = {
"stride": org_module.stride,
@@ -147,48 +152,55 @@ class LokrModule(ToolkitModuleMixin, nn.Module):
"groups": org_module.groups
}
else:  # Linear
in_dim = org_module.in_features
out_dim = org_module.out_features
in_m, in_n = factorization(in_dim, factor)
out_l, out_k = factorization(out_dim, factor)
# ((a, b), (c, d)), out_dim = a*c, in_dim = b*d
shape = ((out_l, out_k), (in_m, in_n))
# smaller part. weight scale
if decompose_both and lora_dim < max(shape[0][0], shape[1][0])/2:
self.lokr_w1_a = nn.Parameter(
torch.empty(shape[0][0], lora_dim))
self.lokr_w1_b = nn.Parameter(
torch.empty(lora_dim, shape[1][0]))
else:
self.use_w1 = True
self.lokr_w1 = nn.Parameter(torch.empty(
shape[0][0], shape[1][0])) # a*c, 1-mode
if lora_dim < max(shape[0][1], shape[1][1])/2:
# bigger part. weight and LoRA. [b, dim] x [dim, d]
self.lokr_w2_a = nn.Parameter(
torch.empty(shape[0][1], lora_dim))
self.lokr_w2_b = nn.Parameter(
torch.empty(lora_dim, shape[1][1]))
# w1 ⊗ (w2_a x w2_b) = (a, b)⊗((c, dim)x(dim, d)) = (a, b)⊗(c, d) = (ac, bd)
else:
self.use_w2 = True
self.lokr_w2 = nn.Parameter(
torch.empty(shape[0][1], shape[1][1]))
self.op = F.linear
self.extra_args = {}
self.dropout = dropout
if dropout:
print("[WARN]LoHa/LoKr haven't implemented normal dropout yet.")
print("[WARN]LoKr haven't implemented normal dropout yet.")
self.rank_dropout = rank_dropout
self.module_dropout = module_dropout
if isinstance(alpha, torch.Tensor):
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
alpha = lora_dim if alpha is None or alpha == 0 else alpha
if self.use_w2 and self.use_w1:
# use scale = 1
alpha = lora_dim
self.scale = alpha / self.lora_dim
self.register_buffer('alpha', torch.tensor(alpha))  # treat as constant
if self.use_w2:
torch.nn.init.constant_(self.lokr_w2, 0)
@@ -197,7 +209,7 @@ class LokrModule(ToolkitModuleMixin, nn.Module):
torch.nn.init.kaiming_uniform_(self.lokr_t2, a=math.sqrt(5))
torch.nn.init.kaiming_uniform_(self.lokr_w2_a, a=math.sqrt(5))
torch.nn.init.constant_(self.lokr_w2_b, 0)
if self.use_w1:
torch.nn.init.kaiming_uniform_(self.lokr_w1, a=math.sqrt(5))
else:
@@ -208,8 +220,8 @@ class LokrModule(ToolkitModuleMixin, nn.Module):
self.org_module = [org_module]
weight = make_kron(
self.lokr_w1 if self.use_w1 else self.lokr_w1_a@self.lokr_w1_b,
(self.lokr_w2 if self.use_w2
else make_weight_cp(self.lokr_t2, self.lokr_w2_a, self.lokr_w2_b) if self.cp
else self.lokr_w2_a@self.lokr_w2_b),
torch.tensor(self.multiplier * self.scale)
)
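A consequence of the initialization above worth noting: one factor on the w2 side is always zero-initialized, so the Kronecker product, and with it the whole LoKr delta, is exactly zero at the start of training and the wrapped layer initially behaves like the original module. A two-line standalone check:

import torch

w1 = torch.randn(4, 8)                                # kaiming-initialized side
w2 = torch.zeros(16, 24)                              # zero-initialized side (lokr_w2 / lokr_w2_b)
assert torch.kron(w1, w2).abs().max().item() == 0.0   # delta weight starts at exactly zero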
@@ -219,12 +231,12 @@ class LokrModule(ToolkitModuleMixin, nn.Module):
def apply_to(self):
self.org_forward = self.org_module[0].forward
self.org_module[0].forward = self.forward
def get_weight(self, orig_weight=None):
weight = make_kron(
self.lokr_w1 if self.use_w1 else self.lokr_w1_a@self.lokr_w1_b,
(self.lokr_w2 if self.use_w2
else make_weight_cp(self.lokr_t2, self.lokr_w2_a, self.lokr_w2_b) if self.cp
else self.lokr_w2_a@self.lokr_w2_b),
torch.tensor(self.scale)
)
@@ -232,51 +244,88 @@ class LokrModule(ToolkitModuleMixin, nn.Module):
weight = weight.reshape(orig_weight.shape)
if self.training and self.rank_dropout:
drop = torch.rand(weight.size(0)) < self.rank_dropout
weight *= drop.view(-1, [1] *
len(weight.shape[1:])).to(weight.device)
return weight
@torch.no_grad()
def apply_max_norm(self, max_norm, device=None):
orig_norm = self.get_weight().norm()
norm = torch.clamp(orig_norm, max_norm/2)
desired = torch.clamp(norm, max=max_norm)
ratio = desired.cpu()/norm.cpu()
scaled = ratio.item() != 1.0
if scaled:
modules = (4 - self.use_w1 - self.use_w2 + (not self.use_w2 and self.cp))
if self.use_w1:
self.lokr_w1 *= ratio**(1/modules)
else:
self.lokr_w1_a *= ratio**(1/modules)
self.lokr_w1_b *= ratio**(1/modules)
if self.use_w2:
self.lokr_w2 *= ratio**(1/modules)
else:
if self.cp:
self.lokr_t2 *= ratio**(1/modules)
self.lokr_w2_a *= ratio**(1/modules)
self.lokr_w2_b *= ratio**(1/modules)
return scaled, orig_norm*ratio
def merge_in(self, merge_weight=1.0):
if not self.can_merge_in:
return
# extract weight from org_module
org_sd = self.org_module[0].state_dict()
# todo find a way to merge in weights when doing quantized model
if 'weight._data' in org_sd:
# quantized weight
return
weight_key = "weight"
if 'weight._data' in org_sd:
# quantized weight
weight_key = "weight._data"
orig_dtype = org_sd[weight_key].dtype
weight = org_sd[weight_key].float()
scale = self.scale
# handle the trainable scalar method that locon uses
if hasattr(self, 'scalar'):
scale = scale * self.scalar
lokr_weight = self.get_weight(weight)
merged_weight = (
weight
+ (lokr_weight * merge_weight).to(weight.device, dtype=weight.dtype)
)
# set weight to org_module
org_sd[weight_key] = merged_weight.to(orig_dtype)
self.org_module[0].load_state_dict(org_sd)
def get_orig_weight(self):
weight = self.org_module[0].weight
if isinstance(weight, QTensor) or isinstance(weight, QBytesTensor):
return weight.dequantize().data.detach()
else:
return weight.data.detach()
def get_orig_bias(self):
if hasattr(self.org_module[0], 'bias') and self.org_module[0].bias is not None:
if isinstance(self.org_module[0].bias, QTensor) or isinstance(self.org_module[0].bias, QBytesTensor):
return self.org_module[0].bias.dequantize().data.detach()
else:
return self.org_module[0].bias.data.detach()
return None
def _call_forward(self, x):
if isinstance(x, QTensor) or isinstance(x, QBytesTensor):
x = x.dequantize()
orig_dtype = x.dtype
orig_weight = self.get_orig_weight()
lokr_weight = self.get_weight(orig_weight).to(dtype=orig_weight.dtype)
multiplier = self.network_ref().torch_multiplier
if x.dtype != orig_weight.dtype:
x = x.to(dtype=orig_weight.dtype)
# we do not currently support split batch multipliers for lokr. Just do a mean
multiplier = torch.mean(multiplier)
weight = (
orig_weight
+ lokr_weight * multiplier
)
bias = self.get_orig_bias()
if bias is not None:
bias = bias.to(weight.device, dtype=weight.dtype)
output = self.op(
x,
weight.view(self.shape),
bias,
**self.extra_args
)
return output.to(orig_dtype)
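In effect, _call_forward rebuilds a dense delta and runs the original op once on the summed weight, instead of running a second low-rank path and adding activations the way a conventional LoRA forward does. A toy Linear-layer equivalent (standalone sketch, not the toolkit API; scale and dtype handling omitted):

import torch
import torch.nn.functional as F

x = torch.randn(2, 192)
base = torch.nn.Linear(192, 64)                    # stand-in for org_module
w1, w2 = torch.randn(4, 8), torch.randn(16, 24)    # LoKr factors: kron -> (64, 192)
multiplier = 0.8

delta = torch.kron(w1, w2)                         # roughly what get_weight() returns
weight = base.weight + delta * multiplier          # merged on the fly, as above
out = F.linear(x, weight, base.bias)               # single call to the original op
print(out.shape)                                   # torch.Size([2, 64])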

View File

@@ -272,6 +272,9 @@ class ToolkitModuleMixin:
# if self.__class__.__name__ == "DoRAModule":
# # return dora forward
# return self.dora_forward(x, *args, **kwargs)
if self.__class__.__name__ == "LokrModule":
return self._call_forward(x)
org_forwarded = self.org_forward(x, *args, **kwargs)
@@ -540,6 +543,17 @@ class ToolkitNetworkMixin:
new_save_dict[new_key] = value
save_dict = new_save_dict
if self.network_type.lower() == "lokr":
new_save_dict = {}
for key, value in save_dict.items():
# lora_transformer_transformer_blocks_7_attn_to_v.lokr_w1 to lycoris_transformer_blocks_7_attn_to_v.lokr_w1
new_key = key
new_key = new_key.replace('lora_transformer_', 'lycoris_')
new_save_dict[new_key] = value
save_dict = new_save_dict
if metadata is None:
metadata = OrderedDict()
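The renaming above is a pure prefix swap, so saved LoKr weights come out with LyCORIS-style keys; the loader hunk a few lines down applies the inverse swap. A small sketch of the round trip (key string taken from the comment in this diff; real module names depend on the model):

save_key = 'lora_transformer_transformer_blocks_7_attn_to_v.lokr_w1'
saved_as = save_key.replace('lora_transformer_', 'lycoris_')
# -> 'lycoris_transformer_blocks_7_attn_to_v.lokr_w1'

loaded_as = saved_as.replace('lycoris_', 'lora_transformer_')
assert loaded_as == save_key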
@@ -585,6 +599,10 @@ class ToolkitNetworkMixin:
load_key = load_key.replace('.', '$$')
load_key = load_key.replace('$$lora_down$$', '.lora_down.')
load_key = load_key.replace('$$lora_up$$', '.lora_up.')
if self.network_type.lower() == "lokr":
# lycoris_transformer_blocks_7_attn_to_v.lokr_w1 to lora_transformer_transformer_blocks_7_attn_to_v.lokr_w1
load_key = load_key.replace('lycoris_', 'lora_transformer_')
load_sd[load_key] = value
@@ -616,9 +634,22 @@ class ToolkitNetworkMixin:
# without having to set it in every single module every time it changes
multiplier = self._multiplier
# get first module
try:
first_module = self.get_all_modules()[0]
except IndexError:
raise ValueError("There are not any lora modules in this network. Check your config and try again")
if hasattr(first_module, 'lora_down'):
device = first_module.lora_down.weight.device
dtype = first_module.lora_down.weight.dtype
elif hasattr(first_module, 'lokr_w1'):
device = first_module.lokr_w1.device
dtype = first_module.lokr_w1.dtype
elif hasattr(first_module, 'lokr_w1_a'):
device = first_module.lokr_w1_a.device
dtype = first_module.lokr_w1_a.dtype
else:
raise ValueError("Unknown module type")
with torch.no_grad():
tensor_multiplier = None
if isinstance(multiplier, int) or isinstance(multiplier, float):

View File

@@ -1385,7 +1385,8 @@ class StableDiffusion:
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image)
self.adapter(conditional_clip_embeds)
if self.adapter is not None and isinstance(self.adapter, CustomAdapter) \
and gen_config.adapter_image_path is not None:
# handle conditioning the prompts
gen_config.prompt = self.adapter.condition_prompt(
gen_config.prompt,
@@ -1439,7 +1440,7 @@ class StableDiffusion:
conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds, is_unconditional=False)
unconditional_embeds = self.adapter(unconditional_embeds, unconditional_clip_embeds, is_unconditional=True)
if self.adapter is not None and isinstance(self.adapter, CustomAdapter) and validation_image is not None:
conditional_embeds = self.adapter.condition_encoded_embeds(
tensors_0_1=validation_image,
prompt_embeds=conditional_embeds,