Mirror of https://github.com/ostris/ai-toolkit.git
Merge branch 'main' into wan21
@@ -135,6 +135,15 @@ class NetworkConfig:
        self.conv = 4

        self.transformer_only = kwargs.get('transformer_only', True)

        self.lokr_full_rank = kwargs.get('lokr_full_rank', False)
        if self.lokr_full_rank and self.type.lower() == 'lokr':
            self.linear = 9999999999
            self.linear_alpha = 9999999999
            self.conv = 9999999999
            self.conv_alpha = 9999999999
        # -1 automatically finds the largest factor
        self.lokr_factor = kwargs.get('lokr_factor', -1)


AdapterTypes = Literal['t2i', 'ip', 'ip+', 'clip', 'ilora', 'photo_maker', 'control_net']
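A minimal usage sketch (not part of the diff) of how the new LoKr kwargs above might be supplied; the NetworkConfig import path and constructor behavior are assumptions based on this hunk:

import_path_note = "assumed"  # hypothetical; adjust to the real config module location
from toolkit.config_modules import NetworkConfig  # assumed import path

net_cfg = NetworkConfig(
    type='lokr',
    lokr_full_rank=True,  # per the hunk above, forces linear/conv dims to the 9999999999 sentinel
    lokr_factor=-1,       # -1 lets factorization() pick the largest factor pair
)
# Expected per the code above (assuming kwargs pass straight through):
# net_cfg.linear == 9999999999 and net_cfg.lokr_factor == -1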
@@ -231,12 +231,18 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
        if self.network_type.lower() == "dora":
            self.module_class = DoRAModule
            module_class = DoRAModule
        elif self.network_type.lower() == "lokr":
            self.module_class = LokrModule
            module_class = LokrModule
        self.network_config: NetworkConfig = kwargs.get("network_config", None)

        self.peft_format = peft_format

        # always do peft for flux only for now
        if self.is_flux or self.is_v3 or self.is_lumina2:
            self.peft_format = True
            # don't do peft format for lokr
            if self.network_type.lower() != "lokr":
                self.peft_format = True

        if self.peft_format:
            # no alpha for peft
||||
@@ -338,8 +344,9 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
            if (is_linear or is_conv2d) and not skip:

                if self.only_if_contains is not None and not any([word in clean_name for word in self.only_if_contains]):
                    continue
                if self.only_if_contains is not None:
                    if not any([word in clean_name for word in self.only_if_contains]) and not any([word in lora_name for word in self.only_if_contains]):
                        continue

                dim = None
                alpha = None
@@ -373,6 +380,11 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
                        self.conv_lora_dim is not None or conv_block_dims is not None):
                    skipped.append(lora_name)
                    continue

                module_kwargs = {}

                if self.network_type.lower() == "lokr":
                    module_kwargs["factor"] = self.network_config.lokr_factor

                lora = module_class(
                    lora_name,
@@ -386,10 +398,16 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
                    network=self,
                    parent=module,
                    use_bias=use_bias,
                    **module_kwargs
                )
                loras.append(lora)
                lora_shape_dict[lora_name] = [list(lora.lora_down.weight.shape), list(lora.lora_up.weight.shape)]
                if self.network_type.lower() == "lokr":
                    try:
                        lora_shape_dict[lora_name] = [list(lora.lokr_w1.weight.shape), list(lora.lokr_w2.weight.shape)]
                    except:
                        pass
                else:
                    lora_shape_dict[lora_name] = [list(lora.lora_down.weight.shape), list(lora.lora_up.weight.shape)]
        return loras, skipped

        text_encoders = text_encoder if type(text_encoder) == list else [text_encoder]
@@ -10,24 +10,23 @@ from toolkit.network_mixins import ToolkitModuleMixin
from typing import TYPE_CHECKING, Union, List

from optimum.quanto import QBytesTensor, QTensor

if TYPE_CHECKING:
    from toolkit.lora_special import LoRASpecialNetwork

# 4, build custom backward function
# -


def factorization(dimension: int, factor:int=-1) -> tuple[int, int]:
def factorization(dimension: int, factor: int = -1) -> tuple[int, int]:
    '''
    return a tuple of two values of the input dimension decomposed by the number closest to factor;
    the second value is higher than or equal to the first value.

    In LoRA with Kronecker Product, the first value is a value for weight scale,
    the second value is a value for weight.

    Because of the non-commutative property, A⊗B ≠ B⊗A. The meaning of the two matrices is slightly different.

    examples)
    factor
        -1              2              4              8              16 ...
@@ -38,7 +37,7 @@ def factorization(dimension: int, factor:int=-1) -> tuple[int, int]:
    512 -> 32, 16   512 -> 256, 2   512 -> 128, 4   512 -> 64, 8   512 -> 32, 16
    1024 -> 32, 32  1024 -> 512, 2  1024 -> 256, 4  1024 -> 128, 8  1024 -> 64, 16
    '''

    if factor > 0 and (dimension % factor) == 0:
        m = factor
        n = dimension // factor
@@ -47,12 +46,12 @@ def factorization(dimension: int, factor:int=-1) -> tuple[int, int]:
        factor = dimension
    m, n = 1, dimension
    length = m + n
    while m<n:
    while m < n:
        new_m = m + 1
        while dimension%new_m != 0:
        while dimension % new_m != 0:
            new_m += 1
        new_n = dimension // new_m
        if new_m + new_n > length or new_m>factor:
        if new_m + new_n > length or new_m > factor:
            break
        else:
            m, n = new_m, new_n
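A hedged usage sketch of factorization() (not part of the diff); the module path is hypothetical and the example values are taken from the docstring table above:

from toolkit.lokr import factorization  # hypothetical module path for this file

out_l, out_k = factorization(1024, -1)  # docstring lists 1024 -> 32, 32 for factor -1
in_m, in_n = factorization(512, -1)     # docstring lists 512 -> 32, 16 for factor -1

# For a Linear(512, 1024), LokrModule builds shape = ((out_l, out_k), (in_m, in_n)):
# lokr_w1 gets shape (out_l, in_m), lokr_w2 gets shape (out_k, in_n), and their
# Kronecker product recovers the full (out_dim, in_dim) weight.
assert out_l * out_k == 1024 and in_m * in_n == 512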
@@ -62,7 +61,8 @@ def factorization(dimension: int, factor:int=-1) -> tuple[int, int]:
def make_weight_cp(t, wa, wb):
    rebuild2 = torch.einsum('i j k l, i p, j r -> p r k l', t, wa, wb)  # [c, d, k1, k2]
    rebuild2 = torch.einsum('i j k l, i p, j r -> p r k l',
                            t, wa, wb)  # [c, d, k1, k2]
    return rebuild2
@@ -71,31 +71,25 @@ def make_kron(w1, w2, scale):
    w1 = w1.unsqueeze(2).unsqueeze(2)
    w2 = w2.contiguous()
    rebuild = torch.kron(w1, w2)

    return rebuild*scale


class LokrModule(ToolkitModuleMixin, nn.Module):
    """
    modified from kohya-ss/sd-scripts/networks/lora:LoRAModule
    and from KohakuBlueleaf/LyCORIS/lycoris:loha:LoHaModule
    and from KohakuBlueleaf/LyCORIS/lycoris:locon:LoconModule
    """

    def __init__(
        self,
        lora_name,
        org_module: nn.Module,
        multiplier=1.0,
        lora_dim=4,
        alpha=1,
        dropout=0.,
        rank_dropout=0.,
        module_dropout=0.,
        use_cp=False,
        decompose_both = False,
        decompose_both=False,
        network: 'LoRASpecialNetwork' = None,
        factor:int=-1,  # factorization factor
        factor: int = -1,  # factorization factor
        **kwargs,
    ):
        """ if alpha == 0 or None, alpha is rank (no scaling). """
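A small sketch (not from the diff) of what make_kron() computes for the Linear case, with hypothetical factor shapes:

import torch

w1 = torch.randn(32, 32)  # (a, c): the "weight scale" factor, lokr_w1
w2 = torch.randn(32, 16)  # (b, d): the larger factor, lokr_w2
scale = 1.0

full = torch.kron(w1, w2) * scale  # what make_kron returns for 2-D w2 (no unsqueeze needed)
print(full.shape)                  # torch.Size([1024, 512]) == (a*b, c*d)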
@@ -107,38 +101,49 @@ class LokrModule(ToolkitModuleMixin, nn.Module):
        self.cp = False
        self.use_w1 = False
        self.use_w2 = False
        self.can_merge_in = True

        self.shape = org_module.weight.shape
        if org_module.__class__.__name__ == 'Conv2d':
            in_dim = org_module.in_channels
            k_size = org_module.kernel_size
            out_dim = org_module.out_channels

            in_m, in_n = factorization(in_dim, factor)
            out_l, out_k = factorization(out_dim, factor)
            shape = ((out_l, out_k), (in_m, in_n), *k_size)  # ((a, b), (c, d), *k_size)

            self.cp = use_cp and k_size!=(1, 1)
            # ((a, b), (c, d), *k_size)
            shape = ((out_l, out_k), (in_m, in_n), *k_size)

            self.cp = use_cp and k_size != (1, 1)
            if decompose_both and lora_dim < max(shape[0][0], shape[1][0])/2:
                self.lokr_w1_a = nn.Parameter(torch.empty(shape[0][0], lora_dim))
                self.lokr_w1_b = nn.Parameter(torch.empty(lora_dim, shape[1][0]))
                self.lokr_w1_a = nn.Parameter(
                    torch.empty(shape[0][0], lora_dim))
                self.lokr_w1_b = nn.Parameter(
                    torch.empty(lora_dim, shape[1][0]))
            else:
                self.use_w1 = True
                self.lokr_w1 = nn.Parameter(torch.empty(shape[0][0], shape[1][0]))  # a*c, 1-mode
                self.lokr_w1 = nn.Parameter(torch.empty(
                    shape[0][0], shape[1][0]))  # a*c, 1-mode

            if lora_dim >= max(shape[0][1], shape[1][1])/2:
                self.use_w2 = True
                self.lokr_w2 = nn.Parameter(torch.empty(shape[0][1], shape[1][1], *k_size))
                self.lokr_w2 = nn.Parameter(torch.empty(
                    shape[0][1], shape[1][1], *k_size))
            elif self.cp:
                self.lokr_t2 = nn.Parameter(torch.empty(lora_dim, lora_dim, shape[2], shape[3]))
                self.lokr_w2_a = nn.Parameter(torch.empty(lora_dim, shape[0][1]))  # b, 1-mode
                self.lokr_w2_b = nn.Parameter(torch.empty(lora_dim, shape[1][1]))  # d, 2-mode
            else: # Conv2d not cp
                self.lokr_t2 = nn.Parameter(torch.empty(
                    lora_dim, lora_dim, shape[2], shape[3]))
                self.lokr_w2_a = nn.Parameter(
                    torch.empty(lora_dim, shape[0][1]))  # b, 1-mode
                self.lokr_w2_b = nn.Parameter(
                    torch.empty(lora_dim, shape[1][1]))  # d, 2-mode
            else:  # Conv2d not cp
                # bigger part. weight and LoRA. [b, dim] x [dim, d*k1*k2]
                self.lokr_w2_a = nn.Parameter(torch.empty(shape[0][1], lora_dim))
                self.lokr_w2_b = nn.Parameter(torch.empty(lora_dim, shape[1][1]*shape[2]*shape[3]))
                self.lokr_w2_a = nn.Parameter(
                    torch.empty(shape[0][1], lora_dim))
                self.lokr_w2_b = nn.Parameter(torch.empty(
                    lora_dim, shape[1][1]*shape[2]*shape[3]))
                # w1 ⊗ (w2_a x w2_b) = (a, b)⊗((c, dim)x(dim, d*k1*k2)) = (a, b)⊗(c, d*k1*k2) = (ac, bd*k1*k2)

            self.op = F.conv2d
            self.extra_args = {
                "stride": org_module.stride,
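A hedged shape check (not part of the diff) for the Conv2d branch above, using hypothetical factorizations of a 128x64x3x3 kernel:

import torch

out_l, out_k = 8, 16      # example factorization of out_channels=128
in_m, in_n = 8, 8         # example factorization of in_channels=64
k_size = (3, 3)

w1 = torch.randn(out_l, in_m)           # lokr_w1, the 2-D "scale" factor
w2 = torch.randn(out_k, in_n, *k_size)  # lokr_w2 in the full-rank (use_w2) case
full = torch.kron(w1.unsqueeze(2).unsqueeze(2), w2.contiguous())
print(full.shape)  # torch.Size([128, 64, 3, 3]) == org_module.weight.shape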
@@ -147,48 +152,55 @@ class LokrModule(ToolkitModuleMixin, nn.Module):
                "groups": org_module.groups
            }

        else: # Linear
        else:  # Linear
            in_dim = org_module.in_features
            out_dim = org_module.out_features

            in_m, in_n = factorization(in_dim, factor)
            out_l, out_k = factorization(out_dim, factor)
            shape = ((out_l, out_k), (in_m, in_n))  # ((a, b), (c, d)), out_dim = a*c, in_dim = b*d

            # ((a, b), (c, d)), out_dim = a*c, in_dim = b*d
            shape = ((out_l, out_k), (in_m, in_n))

            # smaller part. weight scale
            if decompose_both and lora_dim < max(shape[0][0], shape[1][0])/2:
                self.lokr_w1_a = nn.Parameter(torch.empty(shape[0][0], lora_dim))
                self.lokr_w1_b = nn.Parameter(torch.empty(lora_dim, shape[1][0]))
                self.lokr_w1_a = nn.Parameter(
                    torch.empty(shape[0][0], lora_dim))
                self.lokr_w1_b = nn.Parameter(
                    torch.empty(lora_dim, shape[1][0]))
            else:
                self.use_w1 = True
                self.lokr_w1 = nn.Parameter(torch.empty(shape[0][0], shape[1][0]))  # a*c, 1-mode
                self.lokr_w1 = nn.Parameter(torch.empty(
                    shape[0][0], shape[1][0]))  # a*c, 1-mode

            if lora_dim < max(shape[0][1], shape[1][1])/2:
                # bigger part. weight and LoRA. [b, dim] x [dim, d]
                self.lokr_w2_a = nn.Parameter(torch.empty(shape[0][1], lora_dim))
                self.lokr_w2_b = nn.Parameter(torch.empty(lora_dim, shape[1][1]))
                self.lokr_w2_a = nn.Parameter(
                    torch.empty(shape[0][1], lora_dim))
                self.lokr_w2_b = nn.Parameter(
                    torch.empty(lora_dim, shape[1][1]))
                # w1 ⊗ (w2_a x w2_b) = (a, b)⊗((c, dim)x(dim, d)) = (a, b)⊗(c, d) = (ac, bd)
            else:
                self.use_w2 = True
                self.lokr_w2 = nn.Parameter(torch.empty(shape[0][1], shape[1][1]))
                self.lokr_w2 = nn.Parameter(
                    torch.empty(shape[0][1], shape[1][1]))

            self.op = F.linear
            self.extra_args = {}

        self.dropout = dropout
        if dropout:
            print("[WARN]LoHa/LoKr haven't implemented normal dropout yet.")
            print("[WARN]LoKr haven't implemented normal dropout yet.")
        self.rank_dropout = rank_dropout
        self.module_dropout = module_dropout

        if isinstance(alpha, torch.Tensor):
            alpha = alpha.detach().float().numpy()  # without casting, bf16 causes error
        alpha = lora_dim if alpha is None or alpha == 0 else alpha
        if self.use_w2 and self.use_w1:
            #use scale = 1
            # use scale = 1
            alpha = lora_dim
        self.scale = alpha / self.lora_dim
        self.register_buffer('alpha', torch.tensor(alpha))  # 定数として扱える
        self.register_buffer('alpha', torch.tensor(alpha))  # treat as constant

        if self.use_w2:
            torch.nn.init.constant_(self.lokr_w2, 0)
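A short illustration (not part of the diff) of the alpha/scale rule above:

lora_dim, alpha = 4, 1
scale = alpha / lora_dim  # 0.25 when a factor is decomposed with rank lora_dim

use_w1 = use_w2 = True    # both factors stored full rank
if use_w2 and use_w1:
    alpha = lora_dim      # no low-rank decomposition left to compensate for
    scale = alpha / lora_dim
print(scale)              # 1.0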
@@ -197,7 +209,7 @@ class LokrModule(ToolkitModuleMixin, nn.Module):
            torch.nn.init.kaiming_uniform_(self.lokr_t2, a=math.sqrt(5))
            torch.nn.init.kaiming_uniform_(self.lokr_w2_a, a=math.sqrt(5))
            torch.nn.init.constant_(self.lokr_w2_b, 0)

        if self.use_w1:
            torch.nn.init.kaiming_uniform_(self.lokr_w1, a=math.sqrt(5))
        else:
@@ -208,8 +220,8 @@ class LokrModule(ToolkitModuleMixin, nn.Module):
        self.org_module = [org_module]
        weight = make_kron(
            self.lokr_w1 if self.use_w1 else self.lokr_w1_a@self.lokr_w1_b,
            (self.lokr_w2 if self.use_w2
             else make_weight_cp(self.lokr_t2, self.lokr_w2_a, self.lokr_w2_b) if self.cp
             else self.lokr_w2_a@self.lokr_w2_b),
            torch.tensor(self.multiplier * self.scale)
        )
@@ -219,12 +231,12 @@ class LokrModule(ToolkitModuleMixin, nn.Module):
    def apply_to(self):
        self.org_forward = self.org_module[0].forward
        self.org_module[0].forward = self.forward

    def get_weight(self, orig_weight = None):

    def get_weight(self, orig_weight=None):
        weight = make_kron(
            self.lokr_w1 if self.use_w1 else self.lokr_w1_a@self.lokr_w1_b,
            (self.lokr_w2 if self.use_w2
             else make_weight_cp(self.lokr_t2, self.lokr_w2_a, self.lokr_w2_b) if self.cp
             else self.lokr_w2_a@self.lokr_w2_b),
            torch.tensor(self.scale)
        )
@@ -232,51 +244,88 @@ class LokrModule(ToolkitModuleMixin, nn.Module):
        weight = weight.reshape(orig_weight.shape)
        if self.training and self.rank_dropout:
            drop = torch.rand(weight.size(0)) < self.rank_dropout
            weight *= drop.view(-1, [1]*len(weight.shape[1:])).to(weight.device)
            weight *= drop.view(-1, [1] *
                                len(weight.shape[1:])).to(weight.device)
        return weight

    @torch.no_grad()
    def apply_max_norm(self, max_norm, device=None):
        orig_norm = self.get_weight().norm()
        norm = torch.clamp(orig_norm, max_norm/2)
        desired = torch.clamp(norm, max=max_norm)
        ratio = desired.cpu()/norm.cpu()

        scaled = ratio.item() != 1.0
        if scaled:
            modules = (4 - self.use_w1 - self.use_w2 + (not self.use_w2 and self.cp))
            if self.use_w1:
                self.lokr_w1 *= ratio**(1/modules)
            else:
                self.lokr_w1_a *= ratio**(1/modules)
                self.lokr_w1_b *= ratio**(1/modules)

            if self.use_w2:
                self.lokr_w2 *= ratio**(1/modules)
            else:
                if self.cp:
                    self.lokr_t2 *= ratio**(1/modules)
                self.lokr_w2_a *= ratio**(1/modules)
                self.lokr_w2_b *= ratio**(1/modules)

        return scaled, orig_norm*ratio

    def merge_in(self, merge_weight=1.0):
        if not self.can_merge_in:
            return

    def forward(self, x):
        if self.module_dropout and self.training:
            if torch.rand(1) < self.module_dropout:
                return self.op(
                    x,
                    self.org_module[0].weight.data,
                    None if self.org_module[0].bias is None else self.org_module[0].bias.data
                )
        weight = (
            self.org_module[0].weight.data
            + self.get_weight(self.org_module[0].weight.data) * self.multiplier
        # extract weight from org_module
        org_sd = self.org_module[0].state_dict()
        # todo find a way to merge in weights when doing quantized model
        if 'weight._data' in org_sd:
            # quantized weight
            return

        weight_key = "weight"
        if 'weight._data' in org_sd:
            # quantized weight
            weight_key = "weight._data"

        orig_dtype = org_sd[weight_key].dtype
        weight = org_sd[weight_key].float()

        scale = self.scale
        # handle trainable scaler method locon does
        if hasattr(self, 'scalar'):
            scale = scale * self.scalar

        lokr_weight = self.get_weight(weight)

        merged_weight = (
            weight
            + (lokr_weight * merge_weight).to(weight.device, dtype=weight.dtype)
        )
        bias = None if self.org_module[0].bias is None else self.org_module[0].bias.data
        return self.op(
            x,

        # set weight to org_module
        org_sd[weight_key] = merged_weight.to(orig_dtype)
        self.org_module[0].load_state_dict(org_sd)

    def get_orig_weight(self):
        weight = self.org_module[0].weight
        if isinstance(weight, QTensor) or isinstance(weight, QBytesTensor):
            return weight.dequantize().data.detach()
        else:
            return weight.data.detach()

    def get_orig_bias(self):
        if hasattr(self.org_module[0], 'bias') and self.org_module[0].bias is not None:
            if isinstance(self.org_module[0].bias, QTensor) or isinstance(self.org_module[0].bias, QBytesTensor):
                return self.org_module[0].bias.dequantize().data.detach()
            else:
                return self.org_module[0].bias.data.detach()
        return None

    def _call_forward(self, x):
        if isinstance(x, QTensor) or isinstance(x, QBytesTensor):
            x = x.dequantize()

        orig_dtype = x.dtype

        orig_weight = self.get_orig_weight()
        lokr_weight = self.get_weight(orig_weight).to(dtype=orig_weight.dtype)
        multiplier = self.network_ref().torch_multiplier

        if x.dtype != orig_weight.dtype:
            x = x.to(dtype=orig_weight.dtype)

        # we do not currently support split batch multipliers for lokr. Just do a mean
        multiplier = torch.mean(multiplier)

        weight = (
            orig_weight
            + lokr_weight * multiplier
        )
        bias = self.get_orig_bias()
        if bias is not None:
            bias = bias.to(weight.device, dtype=weight.dtype)
        output = self.op(
            x,
            weight.view(self.shape),
            bias,
            **self.extra_args
        )
        )
        return output.to(orig_dtype)
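A hedged sketch (not part of the diff) of the effective computation _call_forward() performs for a Linear layer:

import torch
import torch.nn.functional as F

x = torch.randn(2, 512)
w_orig = torch.randn(1024, 512)      # stands in for get_orig_weight()
lokr_delta = torch.randn(1024, 512)  # stands in for get_weight(w_orig)
multiplier = torch.tensor(1.0)       # network_ref().torch_multiplier, averaged above

y = F.linear(x, w_orig + lokr_delta * multiplier)  # bias omitted for brevity
print(y.shape)  # torch.Size([2, 1024])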
@@ -272,6 +272,9 @@ class ToolkitModuleMixin:
        # if self.__class__.__name__ == "DoRAModule":
        #     # return dora forward
        #     return self.dora_forward(x, *args, **kwargs)

        if self.__class__.__name__ == "LokrModule":
            return self._call_forward(x)

        org_forwarded = self.org_forward(x, *args, **kwargs)
@@ -540,6 +543,17 @@ class ToolkitNetworkMixin:
                new_save_dict[new_key] = value

            save_dict = new_save_dict

        if self.network_type.lower() == "lokr":
            new_save_dict = {}
            for key, value in save_dict.items():
                # lora_transformer_transformer_blocks_7_attn_to_v.lokr_w1 to lycoris_transformer_blocks_7_attn_to_v.lokr_w1
                new_key = key
                new_key = new_key.replace('lora_transformer_', 'lycoris_')
                new_save_dict[new_key] = value

            save_dict = new_save_dict

        if metadata is None:
            metadata = OrderedDict()
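A quick illustration (not part of the diff) of the LoKr save-key remapping above:

key = 'lora_transformer_transformer_blocks_7_attn_to_v.lokr_w1'
print(key.replace('lora_transformer_', 'lycoris_'))
# -> 'lycoris_transformer_blocks_7_attn_to_v.lokr_w1'
# load_weights() applies the inverse replace ('lycoris_' -> 'lora_transformer_').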
@@ -585,6 +599,10 @@ class ToolkitNetworkMixin:
            load_key = load_key.replace('.', '$$')
            load_key = load_key.replace('$$lora_down$$', '.lora_down.')
            load_key = load_key.replace('$$lora_up$$', '.lora_up.')

            if self.network_type.lower() == "lokr":
                # lycoris_transformer_blocks_7_attn_to_v.lokr_w1 to lora_transformer_transformer_blocks_7_attn_to_v.lokr_w1
                load_key = load_key.replace('lycoris_', 'lora_transformer_')

            load_sd[load_key] = value
@@ -616,9 +634,22 @@ class ToolkitNetworkMixin:
        # without having to set it in every single module every time it changes
        multiplier = self._multiplier
        # get first module
        first_module = self.get_all_modules()[0]
        device = first_module.lora_down.weight.device
        dtype = first_module.lora_down.weight.dtype
        try:
            first_module = self.get_all_modules()[0]
        except IndexError:
            raise ValueError("There are not any lora modules in this network. Check your config and try again")

        if hasattr(first_module, 'lora_down'):
            device = first_module.lora_down.weight.device
            dtype = first_module.lora_down.weight.dtype
        elif hasattr(first_module, 'lokr_w1'):
            device = first_module.lokr_w1.device
            dtype = first_module.lokr_w1.dtype
        elif hasattr(first_module, 'lokr_w1_a'):
            device = first_module.lokr_w1_a.device
            dtype = first_module.lokr_w1_a.dtype
        else:
            raise ValueError("Unknown module type")
        with torch.no_grad():
            tensor_multiplier = None
            if isinstance(multiplier, int) or isinstance(multiplier, float):
@@ -1385,7 +1385,8 @@ class StableDiffusion:
                conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image)
                self.adapter(conditional_clip_embeds)

        if self.adapter is not None and isinstance(self.adapter, CustomAdapter):
        if self.adapter is not None and isinstance(self.adapter, CustomAdapter) \
                and gen_config.adapter_image_path is not None:
            # handle condition the prompts
            gen_config.prompt = self.adapter.condition_prompt(
                gen_config.prompt,
@@ -1439,7 +1440,7 @@ class StableDiffusion:
                conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds, is_unconditional=False)
                unconditional_embeds = self.adapter(unconditional_embeds, unconditional_clip_embeds, is_unconditional=True)

        if self.adapter is not None and isinstance(self.adapter, CustomAdapter):
        if self.adapter is not None and isinstance(self.adapter, CustomAdapter) and validation_image is not None:
            conditional_embeds = self.adapter.condition_encoded_embeds(
                tensors_0_1=validation_image,
                prompt_embeds=conditional_embeds,