rework lora and patching system

and dora etc - backend rework is 60% finished
And I also removed the webui’s extremely annoying lora filter from model versions.
This commit is contained in:
layerdiffusion
2024-08-02 13:45:06 -07:00
parent f367b07282
commit d1b8a2676d
18 changed files with 1668 additions and 1615 deletions

View File

@@ -0,0 +1,134 @@
UNET_MAP_ATTENTIONS = {
"proj_in.weight",
"proj_in.bias",
"proj_out.weight",
"proj_out.bias",
"norm.weight",
"norm.bias",
}
TRANSFORMER_BLOCKS = {
"norm1.weight",
"norm1.bias",
"norm2.weight",
"norm2.bias",
"norm3.weight",
"norm3.bias",
"attn1.to_q.weight",
"attn1.to_k.weight",
"attn1.to_v.weight",
"attn1.to_out.0.weight",
"attn1.to_out.0.bias",
"attn2.to_q.weight",
"attn2.to_k.weight",
"attn2.to_v.weight",
"attn2.to_out.0.weight",
"attn2.to_out.0.bias",
"ff.net.0.proj.weight",
"ff.net.0.proj.bias",
"ff.net.2.weight",
"ff.net.2.bias",
}
UNET_MAP_RESNET = {
"in_layers.2.weight": "conv1.weight",
"in_layers.2.bias": "conv1.bias",
"emb_layers.1.weight": "time_emb_proj.weight",
"emb_layers.1.bias": "time_emb_proj.bias",
"out_layers.3.weight": "conv2.weight",
"out_layers.3.bias": "conv2.bias",
"skip_connection.weight": "conv_shortcut.weight",
"skip_connection.bias": "conv_shortcut.bias",
"in_layers.0.weight": "norm1.weight",
"in_layers.0.bias": "norm1.bias",
"out_layers.0.weight": "norm2.weight",
"out_layers.0.bias": "norm2.bias",
}
UNET_MAP_BASIC = {
("label_emb.0.0.weight", "class_embedding.linear_1.weight"),
("label_emb.0.0.bias", "class_embedding.linear_1.bias"),
("label_emb.0.2.weight", "class_embedding.linear_2.weight"),
("label_emb.0.2.bias", "class_embedding.linear_2.bias"),
("label_emb.0.0.weight", "add_embedding.linear_1.weight"),
("label_emb.0.0.bias", "add_embedding.linear_1.bias"),
("label_emb.0.2.weight", "add_embedding.linear_2.weight"),
("label_emb.0.2.bias", "add_embedding.linear_2.bias"),
("input_blocks.0.0.weight", "conv_in.weight"),
("input_blocks.0.0.bias", "conv_in.bias"),
("out.0.weight", "conv_norm_out.weight"),
("out.0.bias", "conv_norm_out.bias"),
("out.2.weight", "conv_out.weight"),
("out.2.bias", "conv_out.bias"),
("time_embed.0.weight", "time_embedding.linear_1.weight"),
("time_embed.0.bias", "time_embedding.linear_1.bias"),
("time_embed.2.weight", "time_embedding.linear_2.weight"),
("time_embed.2.bias", "time_embedding.linear_2.bias")
}
def unet_to_diffusers(unet_config):
if "num_res_blocks" not in unet_config:
return {}
num_res_blocks = unet_config["num_res_blocks"]
channel_mult = unet_config["channel_mult"]
transformer_depth = unet_config["transformer_depth"][:]
transformer_depth_output = unet_config["transformer_depth_output"][:]
num_blocks = len(channel_mult)
transformers_mid = unet_config.get("transformer_depth_middle", None)
diffusers_unet_map = {}
for x in range(num_blocks):
n = 1 + (num_res_blocks[x] + 1) * x
for i in range(num_res_blocks[x]):
for b in UNET_MAP_RESNET:
diffusers_unet_map["down_blocks.{}.resnets.{}.{}".format(x, i, UNET_MAP_RESNET[b])] = "input_blocks.{}.0.{}".format(n, b)
num_transformers = transformer_depth.pop(0)
if num_transformers > 0:
for b in UNET_MAP_ATTENTIONS:
diffusers_unet_map["down_blocks.{}.attentions.{}.{}".format(x, i, b)] = "input_blocks.{}.1.{}".format(n, b)
for t in range(num_transformers):
for b in TRANSFORMER_BLOCKS:
diffusers_unet_map["down_blocks.{}.attentions.{}.transformer_blocks.{}.{}".format(x, i, t, b)] = "input_blocks.{}.1.transformer_blocks.{}.{}".format(n, t, b)
n += 1
for k in ["weight", "bias"]:
diffusers_unet_map["down_blocks.{}.downsamplers.0.conv.{}".format(x, k)] = "input_blocks.{}.0.op.{}".format(n, k)
i = 0
for b in UNET_MAP_ATTENTIONS:
diffusers_unet_map["mid_block.attentions.{}.{}".format(i, b)] = "middle_block.1.{}".format(b)
for t in range(transformers_mid):
for b in TRANSFORMER_BLOCKS:
diffusers_unet_map["mid_block.attentions.{}.transformer_blocks.{}.{}".format(i, t, b)] = "middle_block.1.transformer_blocks.{}.{}".format(t, b)
for i, n in enumerate([0, 2]):
for b in UNET_MAP_RESNET:
diffusers_unet_map["mid_block.resnets.{}.{}".format(i, UNET_MAP_RESNET[b])] = "middle_block.{}.{}".format(n, b)
num_res_blocks = list(reversed(num_res_blocks))
for x in range(num_blocks):
n = (num_res_blocks[x] + 1) * x
l = num_res_blocks[x] + 1
for i in range(l):
c = 0
for b in UNET_MAP_RESNET:
diffusers_unet_map["up_blocks.{}.resnets.{}.{}".format(x, i, UNET_MAP_RESNET[b])] = "output_blocks.{}.0.{}".format(n, b)
c += 1
num_transformers = transformer_depth_output.pop()
if num_transformers > 0:
c += 1
for b in UNET_MAP_ATTENTIONS:
diffusers_unet_map["up_blocks.{}.attentions.{}.{}".format(x, i, b)] = "output_blocks.{}.1.{}".format(n, b)
for t in range(num_transformers):
for b in TRANSFORMER_BLOCKS:
diffusers_unet_map["up_blocks.{}.attentions.{}.transformer_blocks.{}.{}".format(x, i, t, b)] = "output_blocks.{}.1.transformer_blocks.{}.{}".format(n, t, b)
if i == l - 1:
for k in ["weight", "bias"]:
diffusers_unet_map["up_blocks.{}.upsamplers.0.conv.{}".format(x, k)] = "output_blocks.{}.{}.conv.{}".format(n, c, k)
n += 1
for k in UNET_MAP_BASIC:
diffusers_unet_map[k[1]] = k[0]
return diffusers_unet_map

500
backend/patcher/base.py Normal file
View File

@@ -0,0 +1,500 @@
# Model Patching API Template Extracted From ComfyUI
# The actual implementation for those APIs are from Forge, implemented from scratch (after forge-v1.0.1),
# and may have certain level of differences.
import torch
import copy
import inspect
from backend import memory_management, utils
extra_weight_calculators = {}
def weight_decompose(dora_scale, weight, lora_diff, alpha, strength):
dora_scale = memory_management.cast_to_device(dora_scale, weight.device, torch.float32)
lora_diff *= alpha
weight_calc = weight + lora_diff.type(weight.dtype)
weight_norm = (
weight_calc.transpose(0, 1)
.reshape(weight_calc.shape[1], -1)
.norm(dim=1, keepdim=True)
.reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
.transpose(0, 1)
)
weight_calc *= (dora_scale / weight_norm).type(weight.dtype)
if strength != 1.0:
weight_calc -= weight
weight += strength * weight_calc
else:
weight[:] = weight_calc
return weight
def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None):
to = model_options["transformer_options"].copy()
if "patches_replace" not in to:
to["patches_replace"] = {}
else:
to["patches_replace"] = to["patches_replace"].copy()
if name not in to["patches_replace"]:
to["patches_replace"][name] = {}
else:
to["patches_replace"][name] = to["patches_replace"][name].copy()
if transformer_index is not None:
block = (block_name, number, transformer_index)
else:
block = (block_name, number)
to["patches_replace"][name][block] = patch
model_options["transformer_options"] = to
return model_options
def set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=False):
model_options["sampler_post_cfg_function"] = model_options.get("sampler_post_cfg_function", []) + [post_cfg_function]
if disable_cfg1_optimization:
model_options["disable_cfg1_optimization"] = True
return model_options
def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_cfg1_optimization=False):
model_options["sampler_pre_cfg_function"] = model_options.get("sampler_pre_cfg_function", []) + [pre_cfg_function]
if disable_cfg1_optimization:
model_options["disable_cfg1_optimization"] = True
return model_options
class ModelPatcher:
def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False):
self.size = size
self.model = model
self.patches = {}
self.backup = {}
self.object_patches = {}
self.object_patches_backup = {}
self.model_options = {"transformer_options": {}}
self.model_size()
self.load_device = load_device
self.offload_device = offload_device
if current_device is None:
self.current_device = self.offload_device
else:
self.current_device = current_device
self.weight_inplace_update = weight_inplace_update
def model_size(self):
if self.size > 0:
return self.size
self.size = memory_management.module_size(self.model)
return self.size
def clone(self):
n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device, weight_inplace_update=self.weight_inplace_update)
n.patches = {}
for k in self.patches:
n.patches[k] = self.patches[k][:]
n.object_patches = self.object_patches.copy()
n.model_options = copy.deepcopy(self.model_options)
return n
def is_clone(self, other):
if hasattr(other, 'model') and self.model is other.model:
return True
return False
def memory_required(self, input_shape):
return self.model.memory_required(input_shape=input_shape)
def set_model_sampler_cfg_function(self, sampler_cfg_function, disable_cfg1_optimization=False):
if len(inspect.signature(sampler_cfg_function).parameters) == 3:
self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) # Old way
else:
self.model_options["sampler_cfg_function"] = sampler_cfg_function
if disable_cfg1_optimization:
self.model_options["disable_cfg1_optimization"] = True
def set_model_sampler_post_cfg_function(self, post_cfg_function, disable_cfg1_optimization=False):
self.model_options = set_model_options_post_cfg_function(self.model_options, post_cfg_function, disable_cfg1_optimization)
def set_model_sampler_pre_cfg_function(self, pre_cfg_function, disable_cfg1_optimization=False):
self.model_options = set_model_options_pre_cfg_function(self.model_options, pre_cfg_function, disable_cfg1_optimization)
def set_model_unet_function_wrapper(self, unet_wrapper_function):
self.model_options["model_function_wrapper"] = unet_wrapper_function
def set_model_vae_encode_wrapper(self, wrapper_function):
self.model_options["model_vae_encode_wrapper"] = wrapper_function
def set_model_vae_decode_wrapper(self, wrapper_function):
self.model_options["model_vae_decode_wrapper"] = wrapper_function
def set_model_vae_regulation(self, vae_regulation):
self.model_options["model_vae_regulation"] = vae_regulation
def set_model_denoise_mask_function(self, denoise_mask_function):
self.model_options["denoise_mask_function"] = denoise_mask_function
def set_model_patch(self, patch, name):
to = self.model_options["transformer_options"]
if "patches" not in to:
to["patches"] = {}
to["patches"][name] = to["patches"].get(name, []) + [patch]
def set_model_patch_replace(self, patch, name, block_name, number, transformer_index=None):
self.model_options = set_model_options_patch_replace(self.model_options, patch, name, block_name, number, transformer_index=transformer_index)
def set_model_attn1_patch(self, patch):
self.set_model_patch(patch, "attn1_patch")
def set_model_attn2_patch(self, patch):
self.set_model_patch(patch, "attn2_patch")
def set_model_attn1_replace(self, patch, block_name, number, transformer_index=None):
self.set_model_patch_replace(patch, "attn1", block_name, number, transformer_index)
def set_model_attn2_replace(self, patch, block_name, number, transformer_index=None):
self.set_model_patch_replace(patch, "attn2", block_name, number, transformer_index)
def set_model_attn1_output_patch(self, patch):
self.set_model_patch(patch, "attn1_output_patch")
def set_model_attn2_output_patch(self, patch):
self.set_model_patch(patch, "attn2_output_patch")
def set_model_input_block_patch(self, patch):
self.set_model_patch(patch, "input_block_patch")
def set_model_input_block_patch_after_skip(self, patch):
self.set_model_patch(patch, "input_block_patch_after_skip")
def set_model_output_block_patch(self, patch):
self.set_model_patch(patch, "output_block_patch")
def add_object_patch(self, name, obj):
self.object_patches[name] = obj
def get_model_object(self, name):
if name in self.object_patches:
return self.object_patches[name]
else:
if name in self.object_patches_backup:
return self.object_patches_backup[name]
else:
return utils.get_attr(self.model, name)
def model_patches_to(self, device):
to = self.model_options["transformer_options"]
if "patches" in to:
patches = to["patches"]
for name in patches:
patch_list = patches[name]
for i in range(len(patch_list)):
if hasattr(patch_list[i], "to"):
patch_list[i] = patch_list[i].to(device)
if "patches_replace" in to:
patches = to["patches_replace"]
for name in patches:
patch_list = patches[name]
for k in patch_list:
if hasattr(patch_list[k], "to"):
patch_list[k] = patch_list[k].to(device)
if "model_function_wrapper" in self.model_options:
wrap_func = self.model_options["model_function_wrapper"]
if hasattr(wrap_func, "to"):
self.model_options["model_function_wrapper"] = wrap_func.to(device)
def model_dtype(self):
if hasattr(self.model, "get_dtype"):
return self.model.get_dtype()
def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
p = set()
model_sd = self.model.state_dict()
for k in patches:
offset = None
function = None
if isinstance(k, str):
key = k
else:
offset = k[1]
key = k[0]
if len(k) > 2:
function = k[2]
if key in model_sd:
p.add(k)
current_patches = self.patches.get(key, [])
current_patches.append((strength_patch, patches[k], strength_model, offset, function))
self.patches[key] = current_patches
return list(p)
def get_key_patches(self, filter_prefix=None):
memory_management.unload_model_clones(self)
model_sd = self.model_state_dict()
p = {}
for k in model_sd:
if filter_prefix is not None:
if not k.startswith(filter_prefix):
continue
if k in self.patches:
p[k] = [model_sd[k]] + self.patches[k]
else:
p[k] = (model_sd[k],)
return p
def model_state_dict(self, filter_prefix=None):
sd = self.model.state_dict()
keys = list(sd.keys())
if filter_prefix is not None:
for k in keys:
if not k.startswith(filter_prefix):
sd.pop(k)
return sd
def patch_model(self, device_to=None, patch_weights=True):
for k in self.object_patches:
old = utils.get_attr(self.model, k)
if k not in self.object_patches_backup:
self.object_patches_backup[k] = old
utils.set_attr_raw(self.model, k, self.object_patches[k])
if patch_weights:
model_sd = self.model_state_dict()
for key in self.patches:
if key not in model_sd:
print("could not patch. key doesn't exist in model:", key)
continue
weight = model_sd[key]
inplace_update = self.weight_inplace_update
if key not in self.backup:
self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update)
if device_to is not None:
temp_weight = memory_management.cast_to_device(weight, device_to, torch.float32, copy=True)
else:
temp_weight = weight.to(torch.float32, copy=True)
out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
if inplace_update:
utils.copy_to_param(self.model, key, out_weight)
else:
utils.set_attr(self.model, key, out_weight)
del temp_weight
if device_to is not None:
self.model.to(device_to)
self.current_device = device_to
return self.model
def calculate_weight(self, patches, weight, key):
for p in patches:
strength = p[0]
v = p[1]
strength_model = p[2]
offset = p[3]
function = p[4]
if function is None:
function = lambda a: a
old_weight = None
if offset is not None:
old_weight = weight
weight = weight.narrow(offset[0], offset[1], offset[2])
if strength_model != 1.0:
weight *= strength_model
if isinstance(v, list):
v = (self.calculate_weight(v[1:], v[0].clone(), key),)
patch_type = ''
if len(v) == 1:
patch_type = "diff"
elif len(v) == 2:
patch_type = v[0]
v = v[1]
if patch_type == "diff":
w1 = v[0]
if strength != 0.0:
if w1.shape != weight.shape:
if w1.ndim == weight.ndim == 4:
new_shape = [max(n, m) for n, m in zip(weight.shape, w1.shape)]
print(f'Merged with {key} channel changed to {new_shape}')
new_diff = strength * memory_management.cast_to_device(w1, weight.device, weight.dtype)
new_weight = torch.zeros(size=new_shape).to(weight)
new_weight[:weight.shape[0], :weight.shape[1], :weight.shape[2], :weight.shape[3]] = weight
new_weight[:new_diff.shape[0], :new_diff.shape[1], :new_diff.shape[2], :new_diff.shape[3]] += new_diff
new_weight = new_weight.contiguous().clone()
weight = new_weight
else:
print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
else:
weight += strength * memory_management.cast_to_device(w1, weight.device, weight.dtype)
elif patch_type == "lora":
mat1 = memory_management.cast_to_device(v[0], weight.device, torch.float32)
mat2 = memory_management.cast_to_device(v[1], weight.device, torch.float32)
dora_scale = v[4]
if v[2] is not None:
alpha = v[2] / mat2.shape[0]
else:
alpha = 1.0
if v[3] is not None:
mat3 = memory_management.cast_to_device(v[3], weight.device, torch.float32)
final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
try:
lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
if dora_scale is not None:
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
print("ERROR {} {} {}".format(patch_type, key, e))
elif patch_type == "lokr":
w1 = v[0]
w2 = v[1]
w1_a = v[3]
w1_b = v[4]
w2_a = v[5]
w2_b = v[6]
t2 = v[7]
dora_scale = v[8]
dim = None
if w1 is None:
dim = w1_b.shape[0]
w1 = torch.mm(memory_management.cast_to_device(w1_a, weight.device, torch.float32),
memory_management.cast_to_device(w1_b, weight.device, torch.float32))
else:
w1 = memory_management.cast_to_device(w1, weight.device, torch.float32)
if w2 is None:
dim = w2_b.shape[0]
if t2 is None:
w2 = torch.mm(memory_management.cast_to_device(w2_a, weight.device, torch.float32),
memory_management.cast_to_device(w2_b, weight.device, torch.float32))
else:
w2 = torch.einsum('i j k l, j r, i p -> p r k l',
memory_management.cast_to_device(t2, weight.device, torch.float32),
memory_management.cast_to_device(w2_b, weight.device, torch.float32),
memory_management.cast_to_device(w2_a, weight.device, torch.float32))
else:
w2 = memory_management.cast_to_device(w2, weight.device, torch.float32)
if len(w2.shape) == 4:
w1 = w1.unsqueeze(2).unsqueeze(2)
if v[2] is not None and dim is not None:
alpha = v[2] / dim
else:
alpha = 1.0
try:
lora_diff = torch.kron(w1, w2).reshape(weight.shape)
if dora_scale is not None:
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
print("ERROR {} {} {}".format(patch_type, key, e))
elif patch_type == "loha":
w1a = v[0]
w1b = v[1]
if v[2] is not None:
alpha = v[2] / w1b.shape[0]
else:
alpha = 1.0
w2a = v[3]
w2b = v[4]
dora_scale = v[7]
if v[5] is not None:
t1 = v[5]
t2 = v[6]
m1 = torch.einsum('i j k l, j r, i p -> p r k l',
memory_management.cast_to_device(t1, weight.device, torch.float32),
memory_management.cast_to_device(w1b, weight.device, torch.float32),
memory_management.cast_to_device(w1a, weight.device, torch.float32))
m2 = torch.einsum('i j k l, j r, i p -> p r k l',
memory_management.cast_to_device(t2, weight.device, torch.float32),
memory_management.cast_to_device(w2b, weight.device, torch.float32),
memory_management.cast_to_device(w2a, weight.device, torch.float32))
else:
m1 = torch.mm(memory_management.cast_to_device(w1a, weight.device, torch.float32),
memory_management.cast_to_device(w1b, weight.device, torch.float32))
m2 = torch.mm(memory_management.cast_to_device(w2a, weight.device, torch.float32),
memory_management.cast_to_device(w2b, weight.device, torch.float32))
try:
lora_diff = (m1 * m2).reshape(weight.shape)
if dora_scale is not None:
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
print("ERROR {} {} {}".format(patch_type, key, e))
elif patch_type == "glora":
if v[4] is not None:
alpha = v[4] / v[0].shape[0]
else:
alpha = 1.0
dora_scale = v[5]
a1 = memory_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, torch.float32)
a2 = memory_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, torch.float32)
b1 = memory_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, torch.float32)
b2 = memory_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, torch.float32)
try:
lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)).reshape(weight.shape)
if dora_scale is not None:
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
print("ERROR {} {} {}".format(patch_type, key, e))
elif patch_type in extra_weight_calculators:
weight = extra_weight_calculators[patch_type](weight, strength, v)
else:
print("patch type not recognized {} {}".format(patch_type, key))
if old_weight is not None:
weight = old_weight
return weight
def unpatch_model(self, device_to=None):
keys = list(self.backup.keys())
if self.weight_inplace_update:
for k in keys:
utils.copy_to_param(self.model, k, self.backup[k])
else:
for k in keys:
utils.set_attr(self.model, k, self.backup[k])
self.backup = {}
if device_to is not None:
self.model.to(device_to)
self.current_device = device_to
keys = list(self.object_patches_backup.keys())
for k in keys:
utils.set_attr_raw(self.model, k, self.object_patches_backup[k])
self.object_patches_backup = {}

293
backend/patcher/lora.py Normal file
View File

@@ -0,0 +1,293 @@
# LoRA Implementation Collection form ComfyUI
# Modified by Forge to support greedy loading (load a set or wrong/correct loras to a model and only preserve the correct ones),
# which is important to webui experience
from backend.misc.diffusers_state_dict import unet_to_diffusers
LORA_CLIP_MAP = {
"mlp.fc1": "mlp_fc1",
"mlp.fc2": "mlp_fc2",
"self_attn.k_proj": "self_attn_k_proj",
"self_attn.q_proj": "self_attn_q_proj",
"self_attn.v_proj": "self_attn_v_proj",
"self_attn.out_proj": "self_attn_out_proj",
}
def load_lora(lora, to_load):
patch_dict = {}
loaded_keys = set()
for x in to_load:
alpha_name = "{}.alpha".format(x)
alpha = None
if alpha_name in lora.keys():
alpha = lora[alpha_name].item()
loaded_keys.add(alpha_name)
dora_scale_name = "{}.dora_scale".format(x)
dora_scale = None
if dora_scale_name in lora.keys():
dora_scale = lora[dora_scale_name]
loaded_keys.add(dora_scale_name)
regular_lora = "{}.lora_up.weight".format(x)
diffusers_lora = "{}_lora.up.weight".format(x)
diffusers2_lora = "{}.lora_B.weight".format(x)
diffusers3_lora = "{}.lora.up.weight".format(x)
transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
A_name = None
if regular_lora in lora.keys():
A_name = regular_lora
B_name = "{}.lora_down.weight".format(x)
mid_name = "{}.lora_mid.weight".format(x)
elif diffusers_lora in lora.keys():
A_name = diffusers_lora
B_name = "{}_lora.down.weight".format(x)
mid_name = None
elif diffusers2_lora in lora.keys():
A_name = diffusers2_lora
B_name = "{}.lora_A.weight".format(x)
mid_name = None
elif diffusers3_lora in lora.keys():
A_name = diffusers3_lora
B_name = "{}.lora.down.weight".format(x)
mid_name = None
elif transformers_lora in lora.keys():
A_name = transformers_lora
B_name = "{}.lora_linear_layer.down.weight".format(x)
mid_name = None
if A_name is not None:
mid = None
if mid_name is not None and mid_name in lora.keys():
mid = lora[mid_name]
loaded_keys.add(mid_name)
patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid, dora_scale))
loaded_keys.add(A_name)
loaded_keys.add(B_name)
######## loha
hada_w1_a_name = "{}.hada_w1_a".format(x)
hada_w1_b_name = "{}.hada_w1_b".format(x)
hada_w2_a_name = "{}.hada_w2_a".format(x)
hada_w2_b_name = "{}.hada_w2_b".format(x)
hada_t1_name = "{}.hada_t1".format(x)
hada_t2_name = "{}.hada_t2".format(x)
if hada_w1_a_name in lora.keys():
hada_t1 = None
hada_t2 = None
if hada_t1_name in lora.keys():
hada_t1 = lora[hada_t1_name]
hada_t2 = lora[hada_t2_name]
loaded_keys.add(hada_t1_name)
loaded_keys.add(hada_t2_name)
patch_dict[to_load[x]] = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2, dora_scale))
loaded_keys.add(hada_w1_a_name)
loaded_keys.add(hada_w1_b_name)
loaded_keys.add(hada_w2_a_name)
loaded_keys.add(hada_w2_b_name)
######## lokr
lokr_w1_name = "{}.lokr_w1".format(x)
lokr_w2_name = "{}.lokr_w2".format(x)
lokr_w1_a_name = "{}.lokr_w1_a".format(x)
lokr_w1_b_name = "{}.lokr_w1_b".format(x)
lokr_t2_name = "{}.lokr_t2".format(x)
lokr_w2_a_name = "{}.lokr_w2_a".format(x)
lokr_w2_b_name = "{}.lokr_w2_b".format(x)
lokr_w1 = None
if lokr_w1_name in lora.keys():
lokr_w1 = lora[lokr_w1_name]
loaded_keys.add(lokr_w1_name)
lokr_w2 = None
if lokr_w2_name in lora.keys():
lokr_w2 = lora[lokr_w2_name]
loaded_keys.add(lokr_w2_name)
lokr_w1_a = None
if lokr_w1_a_name in lora.keys():
lokr_w1_a = lora[lokr_w1_a_name]
loaded_keys.add(lokr_w1_a_name)
lokr_w1_b = None
if lokr_w1_b_name in lora.keys():
lokr_w1_b = lora[lokr_w1_b_name]
loaded_keys.add(lokr_w1_b_name)
lokr_w2_a = None
if lokr_w2_a_name in lora.keys():
lokr_w2_a = lora[lokr_w2_a_name]
loaded_keys.add(lokr_w2_a_name)
lokr_w2_b = None
if lokr_w2_b_name in lora.keys():
lokr_w2_b = lora[lokr_w2_b_name]
loaded_keys.add(lokr_w2_b_name)
lokr_t2 = None
if lokr_t2_name in lora.keys():
lokr_t2 = lora[lokr_t2_name]
loaded_keys.add(lokr_t2_name)
if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale))
# glora
a1_name = "{}.a1.weight".format(x)
a2_name = "{}.a2.weight".format(x)
b1_name = "{}.b1.weight".format(x)
b2_name = "{}.b2.weight".format(x)
if a1_name in lora:
patch_dict[to_load[x]] = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha, dora_scale))
loaded_keys.add(a1_name)
loaded_keys.add(a2_name)
loaded_keys.add(b1_name)
loaded_keys.add(b2_name)
w_norm_name = "{}.w_norm".format(x)
b_norm_name = "{}.b_norm".format(x)
w_norm = lora.get(w_norm_name, None)
b_norm = lora.get(b_norm_name, None)
if w_norm is not None:
loaded_keys.add(w_norm_name)
patch_dict[to_load[x]] = ("diff", (w_norm,))
if b_norm is not None:
loaded_keys.add(b_norm_name)
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (b_norm,))
diff_name = "{}.diff".format(x)
diff_weight = lora.get(diff_name, None)
if diff_weight is not None:
patch_dict[to_load[x]] = ("diff", (diff_weight,))
loaded_keys.add(diff_name)
diff_bias_name = "{}.diff_b".format(x)
diff_bias = lora.get(diff_bias_name, None)
if diff_bias is not None:
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (diff_bias,))
loaded_keys.add(diff_bias_name)
remaining_dict = {x: y for x, y in lora.items() if x not in loaded_keys}
return patch_dict, remaining_dict
def model_lora_keys_clip(model, key_map={}):
sdk = model.state_dict().keys()
text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
clip_l_present = False
for b in range(32): # TODO: clean up
for c in LORA_CLIP_MAP:
k = "clip_h.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
if k in sdk:
lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c])
key_map[lora_key] = k
lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c])
key_map[lora_key] = k
lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c)
key_map[lora_key] = k
k = "clip_l.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
if k in sdk:
lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c])
key_map[lora_key] = k
lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c])
key_map[lora_key] = k
clip_l_present = True
lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c)
key_map[lora_key] = k
k = "clip_g.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
if k in sdk:
if clip_l_present:
lora_key = "lora_te2_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c])
key_map[lora_key] = k
lora_key = "text_encoder_2.text_model.encoder.layers.{}.{}".format(b, c)
key_map[lora_key] = k
else:
lora_key = "lora_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c])
key_map[lora_key] = k
lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c)
key_map[lora_key] = k
lora_key = "lora_prior_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c])
key_map[lora_key] = k
for k in sdk: # OneTrainer SD3 lora
if k.startswith("t5xxl.transformer.") and k.endswith(".weight"):
l_key = k[len("t5xxl.transformer."):-len(".weight")]
lora_key = "lora_te3_{}".format(l_key.replace(".", "_"))
key_map[lora_key] = k
k = "clip_g.transformer.text_projection.weight"
if k in sdk:
key_map["lora_prior_te_text_projection"] = k
key_map["lora_te2_text_projection"] = k
k = "clip_l.transformer.text_projection.weight"
if k in sdk:
key_map["lora_te1_text_projection"] = k
return key_map
def model_lora_keys_unet(model, key_map={}):
sd = model.state_dict()
sdk = sd.keys()
for k in sdk:
if k.startswith("diffusion_model.") and k.endswith(".weight"):
key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
key_map["lora_unet_{}".format(key_lora)] = k
key_map["lora_prior_unet_{}".format(key_lora)] = k
diffusers_keys = unet_to_diffusers(model.diffusion_model.legacy_config)
for k in diffusers_keys:
if k.endswith(".weight"):
unet_key = "diffusion_model.{}".format(diffusers_keys[k])
key_lora = k[:-len(".weight")].replace(".", "_")
key_map["lora_unet_{}".format(key_lora)] = unet_key
diffusers_lora_prefix = ["", "unet."]
for p in diffusers_lora_prefix:
diffusers_lora_key = "{}{}".format(p, k[:-len(".weight")].replace(".to_", ".processor.to_"))
if diffusers_lora_key.endswith(".to_out.0"):
diffusers_lora_key = diffusers_lora_key[:-2]
key_map[diffusers_lora_key] = unet_key
# TODO:
# if isinstance(model, xxxx.modules.model_base.SD3): # Diffusers lora SD3
# diffusers_keys = xxxx.modules.utils.mmdit_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
# for k in diffusers_keys:
# if k.endswith(".weight"):
# to = diffusers_keys[k]
# key_lora = "transformer.{}".format(k[:-len(".weight")]) # regular diffusers sd3 lora format
# key_map[key_lora] = to
#
# key_lora = "base_model.model.{}".format(k[:-len(".weight")]) # format for flash-sd3 lora and others?
# key_map[key_lora] = to
#
# key_lora = "lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_")) # OneTrainer lora
# key_map[key_lora] = to
#
# if isinstance(model, xxxx.modules.model_base.AuraFlow): # Diffusers lora AuraFlow
# diffusers_keys = xxxx.modules.utils.auraflow_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
# for k in diffusers_keys:
# if k.endswith(".weight"):
# to = diffusers_keys[k]
# key_lora = "transformer.{}".format(k[:-len(".weight")]) # simpletrainer and probably regular diffusers lora format
# key_map[key_lora] = to
#
# if isinstance(model, xxxx.modules.model_base.HunyuanDiT):
# for k in sdk:
# if k.startswith("diffusion_model.") and k.endswith(".weight"):
# key_lora = k[len("diffusion_model."):-len(".weight")]
# key_map["base_model.model.{}".format(key_lora)] = k # official hunyuan lora format
return key_map

30
backend/utils.py Normal file
View File

@@ -0,0 +1,30 @@
import torch
def set_attr(obj, attr, value):
attrs = attr.split(".")
for name in attrs[:-1]:
obj = getattr(obj, name)
setattr(obj, attrs[-1], torch.nn.Parameter(value, requires_grad=False))
def set_attr_raw(obj, attr, value):
attrs = attr.split(".")
for name in attrs[:-1]:
obj = getattr(obj, name)
setattr(obj, attrs[-1], value)
def copy_to_param(obj, attr, value):
attrs = attr.split(".")
for name in attrs[:-1]:
obj = getattr(obj, name)
prev = getattr(obj, attrs[-1])
prev.data.copy_(value)
def get_attr(obj, attr):
attrs = attr.split(".")
for name in attrs:
obj = getattr(obj, name)
return obj

View File

@@ -1,6 +0,0 @@
class LoraPatches:
def __init__(self):
pass
def undo(self):
pass

View File

@@ -1,228 +0,0 @@
from __future__ import annotations
import os
from collections import namedtuple
import enum
import torch.nn as nn
import torch.nn.functional as F
from modules import sd_models, cache, errors, hashes, shared
import modules.models.sd3.mmdit
NetworkWeights = namedtuple('NetworkWeights', ['network_key', 'sd_key', 'w', 'sd_module'])
metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}
class SdVersion(enum.Enum):
Unknown = 1
SD1 = 2
SD2 = 3
SDXL = 4
class NetworkOnDisk:
def __init__(self, name, filename):
self.name = name
self.filename = filename
self.metadata = {}
self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors"
def read_metadata():
metadata = sd_models.read_metadata_from_safetensors(filename)
return metadata
if self.is_safetensors:
try:
self.metadata = cache.cached_data_for_file('safetensors-metadata', "lora/" + self.name, filename, read_metadata)
except Exception as e:
errors.display(e, f"reading lora {filename}")
if self.metadata:
m = {}
for k, v in sorted(self.metadata.items(), key=lambda x: metadata_tags_order.get(x[0], 999)):
m[k] = v
self.metadata = m
self.alias = self.metadata.get('ss_output_name', self.name)
self.hash = None
self.shorthash = None
self.set_hash(
self.metadata.get('sshs_model_hash') or
hashes.sha256_from_cache(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or
''
)
self.sd_version = self.detect_version()
def detect_version(self):
if str(self.metadata.get('ss_base_model_version', "")).startswith("sdxl_"):
return SdVersion.SDXL
elif str(self.metadata.get('ss_v2', "")) == "True":
return SdVersion.SD2
elif len(self.metadata):
return SdVersion.SD1
return SdVersion.Unknown
def set_hash(self, v):
self.hash = v
self.shorthash = self.hash[0:12]
if self.shorthash:
import networks
networks.available_network_hash_lookup[self.shorthash] = self
def read_hash(self):
if not self.hash:
self.set_hash(hashes.sha256(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or '')
def get_alias(self):
import networks
if shared.opts.lora_preferred_name == "Filename" or self.alias.lower() in networks.forbidden_network_aliases:
return self.name
else:
return self.alias
class Network: # LoraModule
def __init__(self, name, network_on_disk: NetworkOnDisk):
self.name = name
self.network_on_disk = network_on_disk
self.te_multiplier = 1.0
self.unet_multiplier = 1.0
self.dyn_dim = None
self.modules = {}
self.bundle_embeddings = {}
self.mtime = None
self.mentioned_name = None
"""the text that was used to add the network to prompt - can be either name or an alias"""
class ModuleType:
def create_module(self, net: Network, weights: NetworkWeights) -> Network | None:
return None
class NetworkModule:
def __init__(self, net: Network, weights: NetworkWeights):
self.network = net
self.network_key = weights.network_key
self.sd_key = weights.sd_key
self.sd_module = weights.sd_module
if isinstance(self.sd_module, modules.models.sd3.mmdit.QkvLinear):
s = self.sd_module.weight.shape
self.shape = (s[0] // 3, s[1])
elif hasattr(self.sd_module, 'weight'):
self.shape = self.sd_module.weight.shape
elif isinstance(self.sd_module, nn.MultiheadAttention):
# For now, only self-attn use Pytorch's MHA
# So assume all qkvo proj have same shape
self.shape = self.sd_module.out_proj.weight.shape
else:
self.shape = None
self.ops = None
self.extra_kwargs = {}
if isinstance(self.sd_module, nn.Conv2d):
self.ops = F.conv2d
self.extra_kwargs = {
'stride': self.sd_module.stride,
'padding': self.sd_module.padding
}
elif isinstance(self.sd_module, nn.Linear):
self.ops = F.linear
elif isinstance(self.sd_module, nn.LayerNorm):
self.ops = F.layer_norm
self.extra_kwargs = {
'normalized_shape': self.sd_module.normalized_shape,
'eps': self.sd_module.eps
}
elif isinstance(self.sd_module, nn.GroupNorm):
self.ops = F.group_norm
self.extra_kwargs = {
'num_groups': self.sd_module.num_groups,
'eps': self.sd_module.eps
}
self.dim = None
self.bias = weights.w.get("bias")
self.alpha = weights.w["alpha"].item() if "alpha" in weights.w else None
self.scale = weights.w["scale"].item() if "scale" in weights.w else None
self.dora_scale = weights.w.get("dora_scale", None)
self.dora_norm_dims = len(self.shape) - 1
def multiplier(self):
if 'transformer' in self.sd_key[:20]:
return self.network.te_multiplier
else:
return self.network.unet_multiplier
def calc_scale(self):
if self.scale is not None:
return self.scale
if self.dim is not None and self.alpha is not None:
return self.alpha / self.dim
return 1.0
def apply_weight_decompose(self, updown, orig_weight):
# Match the device/dtype
orig_weight = orig_weight.to(updown.dtype)
dora_scale = self.dora_scale.to(device=orig_weight.device, dtype=updown.dtype)
updown = updown.to(orig_weight.device)
merged_scale1 = updown + orig_weight
merged_scale1_norm = (
merged_scale1.transpose(0, 1)
.reshape(merged_scale1.shape[1], -1)
.norm(dim=1, keepdim=True)
.reshape(merged_scale1.shape[1], *[1] * self.dora_norm_dims)
.transpose(0, 1)
)
dora_merged = (
merged_scale1 * (dora_scale / merged_scale1_norm)
)
final_updown = dora_merged - orig_weight
return final_updown
def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None):
if self.bias is not None:
updown = updown.reshape(self.bias.shape)
updown += self.bias.to(orig_weight.device, dtype=updown.dtype)
updown = updown.reshape(output_shape)
if len(output_shape) == 4:
updown = updown.reshape(output_shape)
if orig_weight.size().numel() == updown.size().numel():
updown = updown.reshape(orig_weight.shape)
if ex_bias is not None:
ex_bias = ex_bias * self.multiplier()
updown = updown * self.calc_scale()
if self.dora_scale is not None:
updown = self.apply_weight_decompose(updown, orig_weight)
return updown * self.multiplier(), ex_bias
def calc_updown(self, target):
raise NotImplementedError()
def forward(self, x, y):
"""A general forward implementation for all modules"""
if self.ops is None:
raise NotImplementedError()
else:
updown, ex_bias = self.calc_updown(self.sd_module.weight)
return y + self.ops(x, weight=updown, bias=ex_bias, **self.extra_kwargs)

View File

@@ -1,62 +1,62 @@
from modules import extra_networks, shared
import networks
class ExtraNetworkLora(extra_networks.ExtraNetwork):
def __init__(self):
super().__init__('lora')
self.errors = {}
"""mapping of network names to the number of errors the network had during operation"""
remove_symbols = str.maketrans('', '', ":,")
def activate(self, p, params_list):
additional = shared.opts.sd_lora
self.errors.clear()
if additional != "None" and additional in networks.available_networks and not any(x for x in params_list if x.items[0] == additional):
p.all_prompts = [x + f"<lora:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
names = []
te_multipliers = []
unet_multipliers = []
dyn_dims = []
for params in params_list:
assert params.items
names.append(params.positional[0])
te_multiplier = float(params.positional[1]) if len(params.positional) > 1 else 1.0
te_multiplier = float(params.named.get("te", te_multiplier))
unet_multiplier = float(params.positional[2]) if len(params.positional) > 2 else te_multiplier
unet_multiplier = float(params.named.get("unet", unet_multiplier))
dyn_dim = int(params.positional[3]) if len(params.positional) > 3 else None
dyn_dim = int(params.named["dyn"]) if "dyn" in params.named else dyn_dim
te_multipliers.append(te_multiplier)
unet_multipliers.append(unet_multiplier)
dyn_dims.append(dyn_dim)
networks.load_networks(names, te_multipliers, unet_multipliers, dyn_dims)
if shared.opts.lora_add_hashes_to_infotext:
if not getattr(p, "is_hr_pass", False) or not hasattr(p, "lora_hashes"):
p.lora_hashes = {}
for item in networks.loaded_networks:
if item.network_on_disk.shorthash and item.mentioned_name:
p.lora_hashes[item.mentioned_name.translate(self.remove_symbols)] = item.network_on_disk.shorthash
if p.lora_hashes:
p.extra_generation_params["Lora hashes"] = ', '.join(f'{k}: {v}' for k, v in p.lora_hashes.items())
def deactivate(self, p):
if self.errors:
p.comment("Networks with errors: " + ", ".join(f"{k} ({v})" for k, v in self.errors.items()))
self.errors.clear()
from modules import extra_networks, shared
import networks
class ExtraNetworkLora(extra_networks.ExtraNetwork):
def __init__(self):
super().__init__('lora')
self.errors = {}
"""mapping of network names to the number of errors the network had during operation"""
remove_symbols = str.maketrans('', '', ":,")
def activate(self, p, params_list):
additional = shared.opts.sd_lora
self.errors.clear()
if additional != "None" and additional in networks.available_networks and not any(x for x in params_list if x.items[0] == additional):
p.all_prompts = [x + f"<lora:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
names = []
te_multipliers = []
unet_multipliers = []
dyn_dims = []
for params in params_list:
assert params.items
names.append(params.positional[0])
te_multiplier = float(params.positional[1]) if len(params.positional) > 1 else 1.0
te_multiplier = float(params.named.get("te", te_multiplier))
unet_multiplier = float(params.positional[2]) if len(params.positional) > 2 else te_multiplier
unet_multiplier = float(params.named.get("unet", unet_multiplier))
dyn_dim = int(params.positional[3]) if len(params.positional) > 3 else None
dyn_dim = int(params.named["dyn"]) if "dyn" in params.named else dyn_dim
te_multipliers.append(te_multiplier)
unet_multipliers.append(unet_multiplier)
dyn_dims.append(dyn_dim)
networks.load_networks(names, te_multipliers, unet_multipliers, dyn_dims)
if shared.opts.lora_add_hashes_to_infotext:
if not getattr(p, "is_hr_pass", False) or not hasattr(p, "lora_hashes"):
p.lora_hashes = {}
for item in networks.loaded_networks:
if item.network_on_disk.shorthash and item.mentioned_name:
p.lora_hashes[item.mentioned_name.translate(self.remove_symbols)] = item.network_on_disk.shorthash
if p.lora_hashes:
p.extra_generation_params["Lora hashes"] = ', '.join(f'{k}: {v}' for k, v in p.lora_hashes.items())
def deactivate(self, p):
if self.errors:
p.comment("Networks with errors: " + ", ".join(f"{k} ({v})" for k, v in self.errors.items()))
self.errors.clear()

View File

@@ -1,9 +1,9 @@
import networks
list_available_loras = networks.list_available_networks
available_loras = networks.available_networks
available_lora_aliases = networks.available_network_aliases
available_lora_hash_lookup = networks.available_network_hash_lookup
forbidden_lora_aliases = networks.forbidden_network_aliases
loaded_loras = networks.loaded_networks
import networks
list_available_loras = networks.list_available_networks
available_loras = networks.available_networks
available_lora_aliases = networks.available_network_aliases
available_lora_hash_lookup = networks.available_network_hash_lookup
forbidden_lora_aliases = networks.forbidden_network_aliases
loaded_loras = networks.loaded_networks

View File

@@ -0,0 +1,74 @@
import os
from modules import sd_models, cache, errors, hashes, shared
metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}
class NetworkOnDisk:
def __init__(self, name, filename):
self.name = name
self.filename = filename
self.metadata = {}
self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors"
def read_metadata():
metadata = sd_models.read_metadata_from_safetensors(filename)
return metadata
if self.is_safetensors:
try:
self.metadata = cache.cached_data_for_file('safetensors-metadata', "lora/" + self.name, filename, read_metadata)
except Exception as e:
errors.display(e, f"reading lora {filename}")
if self.metadata:
m = {}
for k, v in sorted(self.metadata.items(), key=lambda x: metadata_tags_order.get(x[0], 999)):
m[k] = v
self.metadata = m
self.alias = self.metadata.get('ss_output_name', self.name)
self.hash = None
self.shorthash = None
self.set_hash(
self.metadata.get('sshs_model_hash') or
hashes.sha256_from_cache(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or
''
)
def set_hash(self, v):
self.hash = v
self.shorthash = self.hash[0:12]
if self.shorthash:
import networks
networks.available_network_hash_lookup[self.shorthash] = self
def read_hash(self):
if not self.hash:
self.set_hash(hashes.sha256(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or '')
def get_alias(self):
import networks
if shared.opts.lora_preferred_name == "Filename" or self.alias.lower() in networks.forbidden_network_aliases:
return self.name
else:
return self.alias
class Network:
def __init__(self, name, network_on_disk: NetworkOnDisk):
self.name = name
self.network_on_disk = network_on_disk
self.te_multiplier = 1.0
self.unet_multiplier = 1.0
self.dyn_dim = None
self.modules = {}
self.bundle_embeddings = {}
self.mtime = None
self.mentioned_name = None

View File

@@ -1,272 +1,171 @@
from __future__ import annotations
import gradio as gr
import logging
import os
import re
import lora_patches
import functools
import network
import torch
from typing import Union
from modules import shared, sd_models, errors, scripts
from ldm_patched.modules.utils import load_torch_file
from ldm_patched.modules.sd import load_lora_for_models
@functools.lru_cache(maxsize=5)
def load_lora_state_dict(filename):
return load_torch_file(filename, safe_load=True)
def convert_diffusers_name_to_compvis(key, is_sd2):
pass
def assign_network_names_to_compvis_modules(sd_model):
pass
class BundledTIHash(str):
def __init__(self, hash_str):
self.hash = hash_str
def __str__(self):
return self.hash if shared.opts.lora_bundled_ti_to_infotext else ''
def load_network(name, network_on_disk):
net = network.Network(name, network_on_disk)
net.mtime = os.path.getmtime(network_on_disk.filename)
return net
def purge_networks_from_memory():
pass
def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
global lora_state_dict_cache
current_sd = sd_models.model_data.get_sd_model()
if current_sd is None:
return
loaded_networks.clear()
unavailable_networks = []
for name in names:
if name.lower() in forbidden_network_aliases and available_networks.get(name) is None:
unavailable_networks.append(name)
elif available_network_aliases.get(name) is None:
unavailable_networks.append(name)
if unavailable_networks:
update_available_networks_by_names(unavailable_networks)
networks_on_disk = [available_networks.get(name, None) if name.lower() in forbidden_network_aliases else available_network_aliases.get(name, None) for name in names]
if any(x is None for x in networks_on_disk):
list_available_networks()
networks_on_disk = [available_networks.get(name, None) if name.lower() in forbidden_network_aliases else available_network_aliases.get(name, None) for name in names]
for i, (network_on_disk, name) in enumerate(zip(networks_on_disk, names)):
try:
net = load_network(name, network_on_disk)
except Exception as e:
errors.display(e, f"loading network {network_on_disk.filename}")
continue
net.mentioned_name = name
network_on_disk.read_hash()
loaded_networks.append(net)
compiled_lora_targets = []
for a, b, c in zip(networks_on_disk, unet_multipliers, te_multipliers):
compiled_lora_targets.append([a.filename, b, c])
compiled_lora_targets_hash = str(compiled_lora_targets)
if current_sd.current_lora_hash == compiled_lora_targets_hash:
return
current_sd.current_lora_hash = compiled_lora_targets_hash
current_sd.forge_objects.unet = current_sd.forge_objects_original.unet
current_sd.forge_objects.clip = current_sd.forge_objects_original.clip
for filename, strength_model, strength_clip in compiled_lora_targets:
lora_sd = load_lora_state_dict(filename)
current_sd.forge_objects.unet, current_sd.forge_objects.clip = load_lora_for_models(
current_sd.forge_objects.unet, current_sd.forge_objects.clip, lora_sd, strength_model, strength_clip,
filename=filename)
current_sd.forge_objects_after_applying_lora = current_sd.forge_objects.shallow_copy()
return
def allowed_layer_without_weight(layer):
if isinstance(layer, torch.nn.LayerNorm) and not layer.elementwise_affine:
return True
return False
def store_weights_backup(weight):
if weight is None:
return None
return weight.to(devices.cpu, copy=True)
def restore_weights_backup(obj, field, weight):
if weight is None:
setattr(obj, field, None)
return
getattr(obj, field).copy_(weight)
def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
pass
def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
pass
def network_forward(org_module, input, original_forward):
pass
def network_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
pass
def network_Linear_forward(self, input):
pass
def network_Linear_load_state_dict(self, *args, **kwargs):
pass
def network_Conv2d_forward(self, input):
pass
def network_Conv2d_load_state_dict(self, *args, **kwargs):
pass
def network_GroupNorm_forward(self, input):
pass
def network_GroupNorm_load_state_dict(self, *args, **kwargs):
pass
def network_LayerNorm_forward(self, input):
pass
def network_LayerNorm_load_state_dict(self, *args, **kwargs):
pass
def network_MultiheadAttention_forward(self, *args, **kwargs):
pass
def network_MultiheadAttention_load_state_dict(self, *args, **kwargs):
pass
def process_network_files(names: list[str] | None = None):
candidates = list(shared.walk_files(shared.cmd_opts.lora_dir, allowed_extensions=[".pt", ".ckpt", ".safetensors"]))
for filename in candidates:
if os.path.isdir(filename):
continue
name = os.path.splitext(os.path.basename(filename))[0]
# if names is provided, only load networks with names in the list
if names and name not in names:
continue
try:
entry = network.NetworkOnDisk(name, filename)
except OSError: # should catch FileNotFoundError and PermissionError etc.
errors.report(f"Failed to load network {name} from {filename}", exc_info=True)
continue
available_networks[name] = entry
if entry.alias in available_network_aliases:
forbidden_network_aliases[entry.alias.lower()] = 1
available_network_aliases[name] = entry
available_network_aliases[entry.alias] = entry
def update_available_networks_by_names(names: list[str]):
process_network_files(names)
def list_available_networks():
available_networks.clear()
available_network_aliases.clear()
forbidden_network_aliases.clear()
available_network_hash_lookup.clear()
forbidden_network_aliases.update({"none": 1, "Addams": 1})
os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)
process_network_files()
re_network_name = re.compile(r"(.*)\s*\([0-9a-fA-F]+\)")
def infotext_pasted(infotext, params):
if "AddNet Module 1" in [x[1] for x in scripts.scripts_txt2img.infotext_fields]:
return # if the other extension is active, it will handle those fields, no need to do anything
added = []
for k in params:
if not k.startswith("AddNet Model "):
continue
num = k[13:]
if params.get("AddNet Module " + num) != "LoRA":
continue
name = params.get("AddNet Model " + num)
if name is None:
continue
m = re_network_name.match(name)
if m:
name = m.group(1)
multiplier = params.get("AddNet Weight A " + num, "1.0")
added.append(f"<lora:{name}:{multiplier}>")
if added:
params["Prompt"] += "\n" + "".join(added)
originals: lora_patches.LoraPatches = None
extra_network_lora = None
available_networks = {}
available_network_aliases = {}
loaded_networks = []
loaded_bundle_embeddings = {}
networks_in_memory = {}
available_network_hash_lookup = {}
forbidden_network_aliases = {}
list_available_networks()
from __future__ import annotations
import gradio as gr
import logging
import os
import re
import functools
import network
import torch
from typing import Union
from modules import shared, sd_models, errors, scripts
from ldm_patched.modules.utils import load_torch_file
from ldm_patched.modules.sd import load_lora_for_models
@functools.lru_cache(maxsize=5)
def load_lora_state_dict(filename):
return load_torch_file(filename, safe_load=True)
def load_network(name, network_on_disk):
net = network.Network(name, network_on_disk)
net.mtime = os.path.getmtime(network_on_disk.filename)
return net
def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
global lora_state_dict_cache
current_sd = sd_models.model_data.get_sd_model()
if current_sd is None:
return
loaded_networks.clear()
unavailable_networks = []
for name in names:
if name.lower() in forbidden_network_aliases and available_networks.get(name) is None:
unavailable_networks.append(name)
elif available_network_aliases.get(name) is None:
unavailable_networks.append(name)
if unavailable_networks:
update_available_networks_by_names(unavailable_networks)
networks_on_disk = [available_networks.get(name, None) if name.lower() in forbidden_network_aliases else available_network_aliases.get(name, None) for name in names]
if any(x is None for x in networks_on_disk):
list_available_networks()
networks_on_disk = [available_networks.get(name, None) if name.lower() in forbidden_network_aliases else available_network_aliases.get(name, None) for name in names]
for i, (network_on_disk, name) in enumerate(zip(networks_on_disk, names)):
try:
net = load_network(name, network_on_disk)
except Exception as e:
errors.display(e, f"loading network {network_on_disk.filename}")
continue
net.mentioned_name = name
network_on_disk.read_hash()
loaded_networks.append(net)
compiled_lora_targets = []
for a, b, c in zip(networks_on_disk, unet_multipliers, te_multipliers):
compiled_lora_targets.append([a.filename, b, c])
compiled_lora_targets_hash = str(compiled_lora_targets)
if current_sd.current_lora_hash == compiled_lora_targets_hash:
return
current_sd.current_lora_hash = compiled_lora_targets_hash
current_sd.forge_objects.unet = current_sd.forge_objects_original.unet
current_sd.forge_objects.clip = current_sd.forge_objects_original.clip
for filename, strength_model, strength_clip in compiled_lora_targets:
lora_sd = load_lora_state_dict(filename)
current_sd.forge_objects.unet, current_sd.forge_objects.clip = load_lora_for_models(
current_sd.forge_objects.unet, current_sd.forge_objects.clip, lora_sd, strength_model, strength_clip,
filename=filename)
current_sd.forge_objects_after_applying_lora = current_sd.forge_objects.shallow_copy()
return
def process_network_files(names: list[str] | None = None):
candidates = list(shared.walk_files(shared.cmd_opts.lora_dir, allowed_extensions=[".pt", ".ckpt", ".safetensors"]))
for filename in candidates:
if os.path.isdir(filename):
continue
name = os.path.splitext(os.path.basename(filename))[0]
# if names is provided, only load networks with names in the list
if names and name not in names:
continue
try:
entry = network.NetworkOnDisk(name, filename)
except OSError: # should catch FileNotFoundError and PermissionError etc.
errors.report(f"Failed to load network {name} from {filename}", exc_info=True)
continue
available_networks[name] = entry
if entry.alias in available_network_aliases:
forbidden_network_aliases[entry.alias.lower()] = 1
available_network_aliases[name] = entry
available_network_aliases[entry.alias] = entry
def update_available_networks_by_names(names: list[str]):
process_network_files(names)
def list_available_networks():
available_networks.clear()
available_network_aliases.clear()
forbidden_network_aliases.clear()
available_network_hash_lookup.clear()
forbidden_network_aliases.update({"none": 1, "Addams": 1})
os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)
process_network_files()
re_network_name = re.compile(r"(.*)\s*\([0-9a-fA-F]+\)")
def infotext_pasted(infotext, params):
if "AddNet Module 1" in [x[1] for x in scripts.scripts_txt2img.infotext_fields]:
return # if the other extension is active, it will handle those fields, no need to do anything
added = []
for k in params:
if not k.startswith("AddNet Model "):
continue
num = k[13:]
if params.get("AddNet Module " + num) != "LoRA":
continue
name = params.get("AddNet Model " + num)
if name is None:
continue
m = re_network_name.match(name)
if m:
name = m.group(1)
multiplier = params.get("AddNet Weight A " + num, "1.0")
added.append(f"<lora:{name}:{multiplier}>")
if added:
params["Prompt"] += "\n" + "".join(added)
extra_network_lora = None
available_networks = {}
available_network_aliases = {}
loaded_networks = []
loaded_bundle_embeddings = {}
networks_in_memory = {}
available_network_hash_lookup = {}
forbidden_network_aliases = {}
list_available_networks()

View File

@@ -1,8 +1,8 @@
import os
from modules import paths
from modules.paths_internal import normalized_filepath
def preload(parser):
parser.add_argument("--lora-dir", type=normalized_filepath, help="Path to directory with Lora networks.", default=os.path.join(paths.models_path, 'Lora'))
parser.add_argument("--lyco-dir-backcompat", type=normalized_filepath, help="Path to directory with LyCORIS networks (for backawards compatibility; can also use --lyco-dir).", default=os.path.join(paths.models_path, 'LyCORIS'))
import os
from modules import paths
from modules.paths_internal import normalized_filepath
def preload(parser):
parser.add_argument("--lora-dir", type=normalized_filepath, help="Path to directory with Lora networks.", default=os.path.join(paths.models_path, 'Lora'))
parser.add_argument("--lyco-dir-backcompat", type=normalized_filepath, help="Path to directory with LyCORIS networks (for backawards compatibility; can also use --lyco-dir).", default=os.path.join(paths.models_path, 'LyCORIS'))

View File

@@ -1,101 +1,90 @@
import re
import gradio as gr
from fastapi import FastAPI
import network
import networks
import lora # noqa:F401
import lora_patches
import extra_networks_lora
import ui_extra_networks_lora
from modules import script_callbacks, ui_extra_networks, extra_networks, shared
def unload():
networks.originals.undo()
def before_ui():
ui_extra_networks.register_page(ui_extra_networks_lora.ExtraNetworksPageLora())
networks.extra_network_lora = extra_networks_lora.ExtraNetworkLora()
extra_networks.register_extra_network(networks.extra_network_lora)
networks.originals = lora_patches.LoraPatches()
script_callbacks.on_model_loaded(networks.assign_network_names_to_compvis_modules)
script_callbacks.on_script_unloaded(unload)
script_callbacks.on_before_ui(before_ui)
script_callbacks.on_infotext_pasted(networks.infotext_pasted)
shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
"sd_lora": shared.OptionInfo("None", "Add network to prompt", gr.Dropdown, lambda: {"choices": ["None", *networks.available_networks]}, refresh=networks.list_available_networks),
"lora_preferred_name": shared.OptionInfo("Alias from file", "When adding to prompt, refer to Lora by", gr.Radio, {"choices": ["Alias from file", "Filename"]}),
"lora_add_hashes_to_infotext": shared.OptionInfo(True, "Add Lora hashes to infotext"),
"lora_bundled_ti_to_infotext": shared.OptionInfo(True, "Add Lora name as TI hashes for bundled Textual Inversion").info('"Add Textual Inversion hashes to infotext" needs to be enabled'),
"lora_filter_disabled": shared.OptionInfo(True, "Always show all networks on the Lora page").info("otherwise, those detected as for incompatible version of Stable Diffusion will be hidden"),
"lora_hide_unknown_for_versions": shared.OptionInfo([], "Hide networks of unknown versions for model versions", gr.CheckboxGroup, {"choices": ["SD1", "SD2", "SDXL"]}),
"lora_in_memory_limit": shared.OptionInfo(0, "Number of Lora networks to keep cached in memory", gr.Number, {"precision": 0}),
"lora_not_found_warning_console": shared.OptionInfo(False, "Lora not found warning in console"),
"lora_not_found_gradio_warning": shared.OptionInfo(False, "Lora not found warning popup in webui"),
}))
shared.options_templates.update(shared.options_section(('compatibility', "Compatibility"), {
"lora_functional": shared.OptionInfo(False, "Lora/Networks: use old method that takes longer when you have multiple Loras active and produces same results as kohya-ss/sd-webui-additional-networks extension"),
}))
def create_lora_json(obj: network.NetworkOnDisk):
return {
"name": obj.name,
"alias": obj.alias,
"path": obj.filename,
"metadata": obj.metadata,
}
def api_networks(_: gr.Blocks, app: FastAPI):
@app.get("/sdapi/v1/loras")
async def get_loras():
return [create_lora_json(obj) for obj in networks.available_networks.values()]
@app.post("/sdapi/v1/refresh-loras")
async def refresh_loras():
return networks.list_available_networks()
script_callbacks.on_app_started(api_networks)
re_lora = re.compile("<lora:([^:]+):")
def infotext_pasted(infotext, d):
hashes = d.get("Lora hashes")
if not hashes:
return
hashes = [x.strip().split(':', 1) for x in hashes.split(",")]
hashes = {x[0].strip().replace(",", ""): x[1].strip() for x in hashes}
def network_replacement(m):
alias = m.group(1)
shorthash = hashes.get(alias)
if shorthash is None:
return m.group(0)
network_on_disk = networks.available_network_hash_lookup.get(shorthash)
if network_on_disk is None:
return m.group(0)
return f'<lora:{network_on_disk.get_alias()}:'
d["Prompt"] = re.sub(re_lora, network_replacement, d["Prompt"])
script_callbacks.on_infotext_pasted(infotext_pasted)
shared.opts.onchange("lora_in_memory_limit", networks.purge_networks_from_memory)
import re
import gradio as gr
from fastapi import FastAPI
import network
import networks
import lora # noqa:F401
import extra_networks_lora
import ui_extra_networks_lora
from modules import script_callbacks, ui_extra_networks, extra_networks, shared
def before_ui():
ui_extra_networks.register_page(ui_extra_networks_lora.ExtraNetworksPageLora())
networks.extra_network_lora = extra_networks_lora.ExtraNetworkLora()
extra_networks.register_extra_network(networks.extra_network_lora)
script_callbacks.on_before_ui(before_ui)
script_callbacks.on_infotext_pasted(networks.infotext_pasted)
shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
"sd_lora": shared.OptionInfo("None", "Add network to prompt", gr.Dropdown, lambda: {"choices": ["None", *networks.available_networks]}, refresh=networks.list_available_networks),
"lora_preferred_name": shared.OptionInfo("Alias from file", "When adding to prompt, refer to Lora by", gr.Radio, {"choices": ["Alias from file", "Filename"]}),
"lora_add_hashes_to_infotext": shared.OptionInfo(True, "Add Lora hashes to infotext"),
"lora_bundled_ti_to_infotext": shared.OptionInfo(True, "Add Lora name as TI hashes for bundled Textual Inversion").info('"Add Textual Inversion hashes to infotext" needs to be enabled'),
"lora_filter_disabled": shared.OptionInfo(True, "Always show all networks on the Lora page").info("otherwise, those detected as for incompatible version of Stable Diffusion will be hidden"),
"lora_hide_unknown_for_versions": shared.OptionInfo([], "Hide networks of unknown versions for model versions", gr.CheckboxGroup, {"choices": ["SD1", "SD2", "SDXL"]}),
"lora_in_memory_limit": shared.OptionInfo(0, "Number of Lora networks to keep cached in memory", gr.Number, {"precision": 0}),
"lora_not_found_warning_console": shared.OptionInfo(False, "Lora not found warning in console"),
"lora_not_found_gradio_warning": shared.OptionInfo(False, "Lora not found warning popup in webui"),
}))
shared.options_templates.update(shared.options_section(('compatibility', "Compatibility"), {
"lora_functional": shared.OptionInfo(False, "Lora/Networks: use old method that takes longer when you have multiple Loras active and produces same results as kohya-ss/sd-webui-additional-networks extension"),
}))
def create_lora_json(obj: network.NetworkOnDisk):
return {
"name": obj.name,
"alias": obj.alias,
"path": obj.filename,
"metadata": obj.metadata,
}
def api_networks(_: gr.Blocks, app: FastAPI):
@app.get("/sdapi/v1/loras")
async def get_loras():
return [create_lora_json(obj) for obj in networks.available_networks.values()]
@app.post("/sdapi/v1/refresh-loras")
async def refresh_loras():
return networks.list_available_networks()
script_callbacks.on_app_started(api_networks)
re_lora = re.compile("<lora:([^:]+):")
def infotext_pasted(infotext, d):
hashes = d.get("Lora hashes")
if not hashes:
return
hashes = [x.strip().split(':', 1) for x in hashes.split(",")]
hashes = {x[0].strip().replace(",", ""): x[1].strip() for x in hashes}
def network_replacement(m):
alias = m.group(1)
shorthash = hashes.get(alias)
if shorthash is None:
return m.group(0)
network_on_disk = networks.available_network_hash_lookup.get(shorthash)
if network_on_disk is None:
return m.group(0)
return f'<lora:{network_on_disk.get_alias()}:'
d["Prompt"] = re.sub(re_lora, network_replacement, d["Prompt"])
script_callbacks.on_infotext_pasted(infotext_pasted)

View File

@@ -1,226 +1,226 @@
import datetime
import html
import random
import gradio as gr
import re
from modules import ui_extra_networks_user_metadata
def is_non_comma_tagset(tags):
average_tag_length = sum(len(x) for x in tags.keys()) / len(tags)
return average_tag_length >= 16
re_word = re.compile(r"[-_\w']+")
re_comma = re.compile(r" *, *")
def build_tags(metadata):
tags = {}
ss_tag_frequency = metadata.get("ss_tag_frequency", {})
if ss_tag_frequency is not None and hasattr(ss_tag_frequency, 'items'):
for _, tags_dict in ss_tag_frequency.items():
for tag, tag_count in tags_dict.items():
tag = tag.strip()
tags[tag] = tags.get(tag, 0) + int(tag_count)
if tags and is_non_comma_tagset(tags):
new_tags = {}
for text, text_count in tags.items():
for word in re.findall(re_word, text):
if len(word) < 3:
continue
new_tags[word] = new_tags.get(word, 0) + text_count
tags = new_tags
ordered_tags = sorted(tags.keys(), key=tags.get, reverse=True)
return [(tag, tags[tag]) for tag in ordered_tags]
class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor):
def __init__(self, ui, tabname, page):
super().__init__(ui, tabname, page)
self.select_sd_version = None
self.taginfo = None
self.edit_activation_text = None
self.slider_preferred_weight = None
self.edit_notes = None
def save_lora_user_metadata(self, name, desc, sd_version, activation_text, preferred_weight, negative_text, notes):
user_metadata = self.get_user_metadata(name)
user_metadata["description"] = desc
user_metadata["sd version"] = sd_version
user_metadata["activation text"] = activation_text
user_metadata["preferred weight"] = preferred_weight
user_metadata["negative text"] = negative_text
user_metadata["notes"] = notes
self.write_user_metadata(name, user_metadata)
def get_metadata_table(self, name):
table = super().get_metadata_table(name)
item = self.page.items.get(name, {})
metadata = item.get("metadata") or {}
keys = {
'ss_output_name': "Output name:",
'ss_sd_model_name': "Model:",
'ss_clip_skip': "Clip skip:",
'ss_network_module': "Kohya module:",
}
for key, label in keys.items():
value = metadata.get(key, None)
if value is not None and str(value) != "None":
table.append((label, html.escape(value)))
ss_training_started_at = metadata.get('ss_training_started_at')
if ss_training_started_at:
table.append(("Date trained:", datetime.datetime.utcfromtimestamp(float(ss_training_started_at)).strftime('%Y-%m-%d %H:%M')))
ss_bucket_info = metadata.get("ss_bucket_info")
if ss_bucket_info and "buckets" in ss_bucket_info:
resolutions = {}
for _, bucket in ss_bucket_info["buckets"].items():
resolution = bucket["resolution"]
resolution = f'{resolution[1]}x{resolution[0]}'
resolutions[resolution] = resolutions.get(resolution, 0) + int(bucket["count"])
resolutions_list = sorted(resolutions.keys(), key=resolutions.get, reverse=True)
resolutions_text = html.escape(", ".join(resolutions_list[0:4]))
if len(resolutions) > 4:
resolutions_text += ", ..."
resolutions_text = f"<span title='{html.escape(', '.join(resolutions_list))}'>{resolutions_text}</span>"
table.append(('Resolutions:' if len(resolutions_list) > 1 else 'Resolution:', resolutions_text))
image_count = 0
for _, params in metadata.get("ss_dataset_dirs", {}).items():
image_count += int(params.get("img_count", 0))
if image_count:
table.append(("Dataset size:", image_count))
return table
def put_values_into_components(self, name):
user_metadata = self.get_user_metadata(name)
values = super().put_values_into_components(name)
item = self.page.items.get(name, {})
metadata = item.get("metadata") or {}
tags = build_tags(metadata)
gradio_tags = [(tag, str(count)) for tag, count in tags[0:24]]
return [
*values[0:5],
item.get("sd_version", "Unknown"),
gr.HighlightedText.update(value=gradio_tags, visible=True if tags else False),
user_metadata.get('activation text', ''),
float(user_metadata.get('preferred weight', 0.0)),
user_metadata.get('negative text', ''),
gr.update(visible=True if tags else False),
gr.update(value=self.generate_random_prompt_from_tags(tags), visible=True if tags else False),
]
def generate_random_prompt(self, name):
item = self.page.items.get(name, {})
metadata = item.get("metadata") or {}
tags = build_tags(metadata)
return self.generate_random_prompt_from_tags(tags)
def generate_random_prompt_from_tags(self, tags):
max_count = None
res = []
for tag, count in tags:
if not max_count:
max_count = count
v = random.random() * max_count
if count > v:
for x in "({[]})":
tag = tag.replace(x, '\\' + x)
res.append(tag)
return ", ".join(sorted(res))
def create_extra_default_items_in_left_column(self):
# this would be a lot better as gr.Radio but I can't make it work
self.select_sd_version = gr.Dropdown(['SD1', 'SD2', 'SDXL', 'Unknown'], value='Unknown', label='Stable Diffusion version', interactive=True)
def create_editor(self):
self.create_default_editor_elems()
self.taginfo = gr.HighlightedText(label="Training dataset tags")
self.edit_activation_text = gr.Text(label='Activation text', info="Will be added to prompt along with Lora")
self.slider_preferred_weight = gr.Slider(label='Preferred weight', info="Set to 0 to disable", minimum=0.0, maximum=2.0, step=0.01)
self.edit_negative_text = gr.Text(label='Negative prompt', info="Will be added to negative prompts")
with gr.Row() as row_random_prompt:
with gr.Column(scale=8):
random_prompt = gr.Textbox(label='Random prompt', lines=4, max_lines=4, interactive=False)
with gr.Column(scale=1, min_width=120):
generate_random_prompt = gr.Button('Generate', size="lg", scale=1)
self.edit_notes = gr.TextArea(label='Notes', lines=4)
generate_random_prompt.click(fn=self.generate_random_prompt, inputs=[self.edit_name_input], outputs=[random_prompt], show_progress=False)
def select_tag(activation_text, evt: gr.SelectData):
tag = evt.value[0]
words = re.split(re_comma, activation_text)
if tag in words:
words = [x for x in words if x != tag and x.strip()]
return ", ".join(words)
return activation_text + ", " + tag if activation_text else tag
self.taginfo.select(fn=select_tag, inputs=[self.edit_activation_text], outputs=[self.edit_activation_text], show_progress=False)
self.create_default_buttons()
viewed_components = [
self.edit_name,
self.edit_description,
self.html_filedata,
self.html_preview,
self.edit_notes,
self.select_sd_version,
self.taginfo,
self.edit_activation_text,
self.slider_preferred_weight,
self.edit_negative_text,
row_random_prompt,
random_prompt,
]
self.button_edit\
.click(fn=self.put_values_into_components, inputs=[self.edit_name_input], outputs=viewed_components)\
.then(fn=lambda: gr.update(visible=True), inputs=[], outputs=[self.box])
edited_components = [
self.edit_description,
self.select_sd_version,
self.edit_activation_text,
self.slider_preferred_weight,
self.edit_negative_text,
self.edit_notes,
]
self.setup_save_handler(self.button_save, self.save_lora_user_metadata, edited_components)
import datetime
import html
import random
import gradio as gr
import re
from modules import ui_extra_networks_user_metadata
def is_non_comma_tagset(tags):
average_tag_length = sum(len(x) for x in tags.keys()) / len(tags)
return average_tag_length >= 16
re_word = re.compile(r"[-_\w']+")
re_comma = re.compile(r" *, *")
def build_tags(metadata):
tags = {}
ss_tag_frequency = metadata.get("ss_tag_frequency", {})
if ss_tag_frequency is not None and hasattr(ss_tag_frequency, 'items'):
for _, tags_dict in ss_tag_frequency.items():
for tag, tag_count in tags_dict.items():
tag = tag.strip()
tags[tag] = tags.get(tag, 0) + int(tag_count)
if tags and is_non_comma_tagset(tags):
new_tags = {}
for text, text_count in tags.items():
for word in re.findall(re_word, text):
if len(word) < 3:
continue
new_tags[word] = new_tags.get(word, 0) + text_count
tags = new_tags
ordered_tags = sorted(tags.keys(), key=tags.get, reverse=True)
return [(tag, tags[tag]) for tag in ordered_tags]
class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor):
def __init__(self, ui, tabname, page):
super().__init__(ui, tabname, page)
self.select_sd_version = None
self.taginfo = None
self.edit_activation_text = None
self.slider_preferred_weight = None
self.edit_notes = None
def save_lora_user_metadata(self, name, desc, sd_version, activation_text, preferred_weight, negative_text, notes):
user_metadata = self.get_user_metadata(name)
user_metadata["description"] = desc
user_metadata["sd version"] = sd_version
user_metadata["activation text"] = activation_text
user_metadata["preferred weight"] = preferred_weight
user_metadata["negative text"] = negative_text
user_metadata["notes"] = notes
self.write_user_metadata(name, user_metadata)
def get_metadata_table(self, name):
table = super().get_metadata_table(name)
item = self.page.items.get(name, {})
metadata = item.get("metadata") or {}
keys = {
'ss_output_name': "Output name:",
'ss_sd_model_name': "Model:",
'ss_clip_skip': "Clip skip:",
'ss_network_module': "Kohya module:",
}
for key, label in keys.items():
value = metadata.get(key, None)
if value is not None and str(value) != "None":
table.append((label, html.escape(value)))
ss_training_started_at = metadata.get('ss_training_started_at')
if ss_training_started_at:
table.append(("Date trained:", datetime.datetime.utcfromtimestamp(float(ss_training_started_at)).strftime('%Y-%m-%d %H:%M')))
ss_bucket_info = metadata.get("ss_bucket_info")
if ss_bucket_info and "buckets" in ss_bucket_info:
resolutions = {}
for _, bucket in ss_bucket_info["buckets"].items():
resolution = bucket["resolution"]
resolution = f'{resolution[1]}x{resolution[0]}'
resolutions[resolution] = resolutions.get(resolution, 0) + int(bucket["count"])
resolutions_list = sorted(resolutions.keys(), key=resolutions.get, reverse=True)
resolutions_text = html.escape(", ".join(resolutions_list[0:4]))
if len(resolutions) > 4:
resolutions_text += ", ..."
resolutions_text = f"<span title='{html.escape(', '.join(resolutions_list))}'>{resolutions_text}</span>"
table.append(('Resolutions:' if len(resolutions_list) > 1 else 'Resolution:', resolutions_text))
image_count = 0
for _, params in metadata.get("ss_dataset_dirs", {}).items():
image_count += int(params.get("img_count", 0))
if image_count:
table.append(("Dataset size:", image_count))
return table
def put_values_into_components(self, name):
user_metadata = self.get_user_metadata(name)
values = super().put_values_into_components(name)
item = self.page.items.get(name, {})
metadata = item.get("metadata") or {}
tags = build_tags(metadata)
gradio_tags = [(tag, str(count)) for tag, count in tags[0:24]]
return [
*values[0:5],
item.get("sd_version", "Unknown"),
gr.HighlightedText.update(value=gradio_tags, visible=True if tags else False),
user_metadata.get('activation text', ''),
float(user_metadata.get('preferred weight', 0.0)),
user_metadata.get('negative text', ''),
gr.update(visible=True if tags else False),
gr.update(value=self.generate_random_prompt_from_tags(tags), visible=True if tags else False),
]
def generate_random_prompt(self, name):
item = self.page.items.get(name, {})
metadata = item.get("metadata") or {}
tags = build_tags(metadata)
return self.generate_random_prompt_from_tags(tags)
def generate_random_prompt_from_tags(self, tags):
max_count = None
res = []
for tag, count in tags:
if not max_count:
max_count = count
v = random.random() * max_count
if count > v:
for x in "({[]})":
tag = tag.replace(x, '\\' + x)
res.append(tag)
return ", ".join(sorted(res))
def create_extra_default_items_in_left_column(self):
# this would be a lot better as gr.Radio but I can't make it work
self.select_sd_version = gr.Dropdown(['SD1', 'SD2', 'SDXL', 'Unknown'], value='Unknown', label='Stable Diffusion version', interactive=True)
def create_editor(self):
self.create_default_editor_elems()
self.taginfo = gr.HighlightedText(label="Training dataset tags")
self.edit_activation_text = gr.Text(label='Activation text', info="Will be added to prompt along with Lora")
self.slider_preferred_weight = gr.Slider(label='Preferred weight', info="Set to 0 to disable", minimum=0.0, maximum=2.0, step=0.01)
self.edit_negative_text = gr.Text(label='Negative prompt', info="Will be added to negative prompts")
with gr.Row() as row_random_prompt:
with gr.Column(scale=8):
random_prompt = gr.Textbox(label='Random prompt', lines=4, max_lines=4, interactive=False)
with gr.Column(scale=1, min_width=120):
generate_random_prompt = gr.Button('Generate', size="lg", scale=1)
self.edit_notes = gr.TextArea(label='Notes', lines=4)
generate_random_prompt.click(fn=self.generate_random_prompt, inputs=[self.edit_name_input], outputs=[random_prompt], show_progress=False)
def select_tag(activation_text, evt: gr.SelectData):
tag = evt.value[0]
words = re.split(re_comma, activation_text)
if tag in words:
words = [x for x in words if x != tag and x.strip()]
return ", ".join(words)
return activation_text + ", " + tag if activation_text else tag
self.taginfo.select(fn=select_tag, inputs=[self.edit_activation_text], outputs=[self.edit_activation_text], show_progress=False)
self.create_default_buttons()
viewed_components = [
self.edit_name,
self.edit_description,
self.html_filedata,
self.html_preview,
self.edit_notes,
self.select_sd_version,
self.taginfo,
self.edit_activation_text,
self.slider_preferred_weight,
self.edit_negative_text,
row_random_prompt,
random_prompt,
]
self.button_edit\
.click(fn=self.put_values_into_components, inputs=[self.edit_name_input], outputs=viewed_components)\
.then(fn=lambda: gr.update(visible=True), inputs=[], outputs=[self.box])
edited_components = [
self.edit_description,
self.select_sd_version,
self.edit_activation_text,
self.slider_preferred_weight,
self.edit_negative_text,
self.edit_notes,
]
self.setup_save_handler(self.button_save, self.save_lora_user_metadata, edited_components)

View File

@@ -1,90 +1,69 @@
import os
import network
import networks
from modules import shared, ui_extra_networks
from modules.ui_extra_networks import quote_js
from ui_edit_user_metadata import LoraUserMetadataEditor
class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
def __init__(self):
super().__init__('Lora')
def refresh(self):
networks.list_available_networks()
def create_item(self, name, index=None, enable_filter=True):
lora_on_disk = networks.available_networks.get(name)
if lora_on_disk is None:
return
path, ext = os.path.splitext(lora_on_disk.filename)
alias = lora_on_disk.get_alias()
search_terms = [self.search_terms_from_path(lora_on_disk.filename)]
if lora_on_disk.hash:
search_terms.append(lora_on_disk.hash)
item = {
"name": name,
"filename": lora_on_disk.filename,
"shorthash": lora_on_disk.shorthash,
"preview": self.find_preview(path) or self.find_embedded_preview(path, name, lora_on_disk.metadata),
"description": self.find_description(path),
"search_terms": search_terms,
"local_preview": f"{path}.{shared.opts.samples_format}",
"metadata": lora_on_disk.metadata,
"sort_keys": {'default': index, **self.get_sort_keys(lora_on_disk.filename)},
"sd_version": lora_on_disk.sd_version.name,
}
self.read_user_metadata(item)
activation_text = item["user_metadata"].get("activation text")
preferred_weight = item["user_metadata"].get("preferred weight", 0.0)
item["prompt"] = quote_js(f"<lora:{alias}:") + " + " + (str(preferred_weight) if preferred_weight else "opts.extra_networks_default_multiplier") + " + " + quote_js(">")
if activation_text:
item["prompt"] += " + " + quote_js(" " + activation_text)
negative_prompt = item["user_metadata"].get("negative text")
item["negative_prompt"] = quote_js("")
if negative_prompt:
item["negative_prompt"] = quote_js('(' + negative_prompt + ':1)')
sd_version = item["user_metadata"].get("sd version")
if sd_version in network.SdVersion.__members__:
item["sd_version"] = sd_version
sd_version = network.SdVersion[sd_version]
else:
sd_version = lora_on_disk.sd_version
if shared.opts.lora_filter_disabled or not enable_filter or not shared.sd_model:
pass
elif sd_version == network.SdVersion.Unknown:
model_version = network.SdVersion.SDXL if shared.sd_model.is_sdxl else network.SdVersion.SD2 if shared.sd_model.is_sd2 else network.SdVersion.SD1
if model_version.name in shared.opts.lora_hide_unknown_for_versions:
return None
elif shared.sd_model.is_sdxl and sd_version != network.SdVersion.SDXL:
return None
elif shared.sd_model.is_sd2 and sd_version != network.SdVersion.SD2:
return None
elif shared.sd_model.is_sd1 and sd_version != network.SdVersion.SD1:
return None
return item
def list_items(self):
# instantiate a list to protect against concurrent modification
names = list(networks.available_networks)
for index, name in enumerate(names):
item = self.create_item(name, index)
if item is not None:
yield item
def allowed_directories_for_previews(self):
return [shared.cmd_opts.lora_dir]
def create_user_metadata_editor(self, ui, tabname):
return LoraUserMetadataEditor(ui, tabname, self)
import os
import network
import networks
from modules import shared, ui_extra_networks
from modules.ui_extra_networks import quote_js
from ui_edit_user_metadata import LoraUserMetadataEditor
class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
def __init__(self):
super().__init__('Lora')
def refresh(self):
networks.list_available_networks()
def create_item(self, name, index=None, enable_filter=True):
lora_on_disk = networks.available_networks.get(name)
if lora_on_disk is None:
return
path, ext = os.path.splitext(lora_on_disk.filename)
alias = lora_on_disk.get_alias()
search_terms = [self.search_terms_from_path(lora_on_disk.filename)]
if lora_on_disk.hash:
search_terms.append(lora_on_disk.hash)
item = {
"name": name,
"filename": lora_on_disk.filename,
"shorthash": lora_on_disk.shorthash,
"preview": self.find_preview(path) or self.find_embedded_preview(path, name, lora_on_disk.metadata),
"description": self.find_description(path),
"search_terms": search_terms,
"local_preview": f"{path}.{shared.opts.samples_format}",
"metadata": lora_on_disk.metadata,
"sort_keys": {'default': index, **self.get_sort_keys(lora_on_disk.filename)},
}
self.read_user_metadata(item)
activation_text = item["user_metadata"].get("activation text")
preferred_weight = item["user_metadata"].get("preferred weight", 0.0)
item["prompt"] = quote_js(f"<lora:{alias}:") + " + " + (str(preferred_weight) if preferred_weight else "opts.extra_networks_default_multiplier") + " + " + quote_js(">")
if activation_text:
item["prompt"] += " + " + quote_js(" " + activation_text)
negative_prompt = item["user_metadata"].get("negative text")
item["negative_prompt"] = quote_js("")
if negative_prompt:
item["negative_prompt"] = quote_js('(' + negative_prompt + ':1)')
return item
def list_items(self):
# instantiate a list to protect against concurrent modification
names = list(networks.available_networks)
for index, name in enumerate(names):
item = self.create_item(name, index)
if item is not None:
yield item
def allowed_directories_for_previews(self):
return [shared.cmd_opts.lora_dir]
def create_user_metadata_editor(self, ui, tabname):
return LoraUserMetadataEditor(ui, tabname, self)

View File

@@ -1,226 +1 @@
# 1st edit by https://github.com/comfyanonymous/ComfyUI
# 2nd edit by Forge Official
import ldm_patched.modules.utils
LORA_CLIP_MAP = {
"mlp.fc1": "mlp_fc1",
"mlp.fc2": "mlp_fc2",
"self_attn.k_proj": "self_attn_k_proj",
"self_attn.q_proj": "self_attn_q_proj",
"self_attn.v_proj": "self_attn_v_proj",
"self_attn.out_proj": "self_attn_out_proj",
}
def load_lora(lora, to_load):
patch_dict = {}
loaded_keys = set()
for x in to_load:
alpha_name = "{}.alpha".format(x)
alpha = None
if alpha_name in lora.keys():
alpha = lora[alpha_name].item()
loaded_keys.add(alpha_name)
regular_lora = "{}.lora_up.weight".format(x)
diffusers_lora = "{}_lora.up.weight".format(x)
transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
A_name = None
if regular_lora in lora.keys():
A_name = regular_lora
B_name = "{}.lora_down.weight".format(x)
mid_name = "{}.lora_mid.weight".format(x)
elif diffusers_lora in lora.keys():
A_name = diffusers_lora
B_name = "{}_lora.down.weight".format(x)
mid_name = None
elif transformers_lora in lora.keys():
A_name = transformers_lora
B_name ="{}.lora_linear_layer.down.weight".format(x)
mid_name = None
if A_name is not None:
mid = None
if mid_name is not None and mid_name in lora.keys():
mid = lora[mid_name]
loaded_keys.add(mid_name)
patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid))
loaded_keys.add(A_name)
loaded_keys.add(B_name)
######## loha
hada_w1_a_name = "{}.hada_w1_a".format(x)
hada_w1_b_name = "{}.hada_w1_b".format(x)
hada_w2_a_name = "{}.hada_w2_a".format(x)
hada_w2_b_name = "{}.hada_w2_b".format(x)
hada_t1_name = "{}.hada_t1".format(x)
hada_t2_name = "{}.hada_t2".format(x)
if hada_w1_a_name in lora.keys():
hada_t1 = None
hada_t2 = None
if hada_t1_name in lora.keys():
hada_t1 = lora[hada_t1_name]
hada_t2 = lora[hada_t2_name]
loaded_keys.add(hada_t1_name)
loaded_keys.add(hada_t2_name)
patch_dict[to_load[x]] = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2))
loaded_keys.add(hada_w1_a_name)
loaded_keys.add(hada_w1_b_name)
loaded_keys.add(hada_w2_a_name)
loaded_keys.add(hada_w2_b_name)
######## lokr
lokr_w1_name = "{}.lokr_w1".format(x)
lokr_w2_name = "{}.lokr_w2".format(x)
lokr_w1_a_name = "{}.lokr_w1_a".format(x)
lokr_w1_b_name = "{}.lokr_w1_b".format(x)
lokr_t2_name = "{}.lokr_t2".format(x)
lokr_w2_a_name = "{}.lokr_w2_a".format(x)
lokr_w2_b_name = "{}.lokr_w2_b".format(x)
lokr_w1 = None
if lokr_w1_name in lora.keys():
lokr_w1 = lora[lokr_w1_name]
loaded_keys.add(lokr_w1_name)
lokr_w2 = None
if lokr_w2_name in lora.keys():
lokr_w2 = lora[lokr_w2_name]
loaded_keys.add(lokr_w2_name)
lokr_w1_a = None
if lokr_w1_a_name in lora.keys():
lokr_w1_a = lora[lokr_w1_a_name]
loaded_keys.add(lokr_w1_a_name)
lokr_w1_b = None
if lokr_w1_b_name in lora.keys():
lokr_w1_b = lora[lokr_w1_b_name]
loaded_keys.add(lokr_w1_b_name)
lokr_w2_a = None
if lokr_w2_a_name in lora.keys():
lokr_w2_a = lora[lokr_w2_a_name]
loaded_keys.add(lokr_w2_a_name)
lokr_w2_b = None
if lokr_w2_b_name in lora.keys():
lokr_w2_b = lora[lokr_w2_b_name]
loaded_keys.add(lokr_w2_b_name)
lokr_t2 = None
if lokr_t2_name in lora.keys():
lokr_t2 = lora[lokr_t2_name]
loaded_keys.add(lokr_t2_name)
if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2))
#glora
a1_name = "{}.a1.weight".format(x)
a2_name = "{}.a2.weight".format(x)
b1_name = "{}.b1.weight".format(x)
b2_name = "{}.b2.weight".format(x)
if a1_name in lora:
patch_dict[to_load[x]] = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha))
loaded_keys.add(a1_name)
loaded_keys.add(a2_name)
loaded_keys.add(b1_name)
loaded_keys.add(b2_name)
w_norm_name = "{}.w_norm".format(x)
b_norm_name = "{}.b_norm".format(x)
w_norm = lora.get(w_norm_name, None)
b_norm = lora.get(b_norm_name, None)
if w_norm is not None:
loaded_keys.add(w_norm_name)
patch_dict[to_load[x]] = ("diff", (w_norm,))
if b_norm is not None:
loaded_keys.add(b_norm_name)
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (b_norm,))
diff_name = "{}.diff".format(x)
diff_weight = lora.get(diff_name, None)
if diff_weight is not None:
patch_dict[to_load[x]] = ("diff", (diff_weight,))
loaded_keys.add(diff_name)
diff_bias_name = "{}.diff_b".format(x)
diff_bias = lora.get(diff_bias_name, None)
if diff_bias is not None:
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (diff_bias,))
loaded_keys.add(diff_bias_name)
remaining_dict = {x: y for x, y in lora.items() if x not in loaded_keys}
return patch_dict, remaining_dict
def model_lora_keys_clip(model, key_map={}):
sdk = model.state_dict().keys()
text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
clip_l_present = False
for b in range(32): #TODO: clean up
for c in LORA_CLIP_MAP:
k = "clip_h.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
if k in sdk:
lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c])
key_map[lora_key] = k
lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c])
key_map[lora_key] = k
lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c) #diffusers lora
key_map[lora_key] = k
k = "clip_l.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
if k in sdk:
lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c])
key_map[lora_key] = k
lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base
key_map[lora_key] = k
clip_l_present = True
lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c) #diffusers lora
key_map[lora_key] = k
k = "clip_g.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
if k in sdk:
if clip_l_present:
lora_key = "lora_te2_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base
key_map[lora_key] = k
lora_key = "text_encoder_2.text_model.encoder.layers.{}.{}".format(b, c) #diffusers lora
key_map[lora_key] = k
else:
lora_key = "lora_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #TODO: test if this is correct for SDXL-Refiner
key_map[lora_key] = k
lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c) #diffusers lora
key_map[lora_key] = k
return key_map
def model_lora_keys_unet(model, key_map={}):
sdk = model.state_dict().keys()
for k in sdk:
if k.startswith("diffusion_model.") and k.endswith(".weight"):
key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
key_map["lora_unet_{}".format(key_lora)] = k
diffusers_keys = ldm_patched.modules.utils.unet_to_diffusers(model.diffusion_model.legacy_config)
for k in diffusers_keys:
if k.endswith(".weight"):
unet_key = "diffusion_model.{}".format(diffusers_keys[k])
key_lora = k[:-len(".weight")].replace(".", "_")
key_map["lora_unet_{}".format(key_lora)] = unet_key
diffusers_lora_prefix = ["", "unet."]
for p in diffusers_lora_prefix:
diffusers_lora_key = "{}{}".format(p, k[:-len(".weight")].replace(".to_", ".processor.to_"))
if diffusers_lora_key.endswith(".to_out.0"):
diffusers_lora_key = diffusers_lora_key[:-2]
key_map[diffusers_lora_key] = unet_key
return key_map
from backend.patcher.lora import *

View File

@@ -1,386 +1 @@
# 1st edit by https://github.com/comfyanonymous/ComfyUI
# 2nd edit by Forge Official
import torch
import copy
import inspect
import ldm_patched.modules.utils
import ldm_patched.modules.model_management
extra_weight_calculators = {}
class ModelPatcher:
def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False):
self.size = size
self.model = model
self.patches = {}
self.backup = {}
self.object_patches = {}
self.object_patches_backup = {}
self.model_options = {"transformer_options":{}}
self.model_size()
self.load_device = load_device
self.offload_device = offload_device
if current_device is None:
self.current_device = self.offload_device
else:
self.current_device = current_device
self.weight_inplace_update = weight_inplace_update
def model_size(self):
if self.size > 0:
return self.size
model_sd = self.model.state_dict()
self.size = ldm_patched.modules.model_management.module_size(self.model)
self.model_keys = set(model_sd.keys())
return self.size
def clone(self):
n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device, weight_inplace_update=self.weight_inplace_update)
n.patches = {}
for k in self.patches:
n.patches[k] = self.patches[k][:]
n.object_patches = self.object_patches.copy()
n.model_options = copy.deepcopy(self.model_options)
n.model_keys = self.model_keys
return n
def is_clone(self, other):
if hasattr(other, 'model') and self.model is other.model:
return True
return False
def memory_required(self, input_shape):
return self.model.memory_required(input_shape=input_shape)
def set_model_sampler_cfg_function(self, sampler_cfg_function, disable_cfg1_optimization=False):
if len(inspect.signature(sampler_cfg_function).parameters) == 3:
self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way
else:
self.model_options["sampler_cfg_function"] = sampler_cfg_function
if disable_cfg1_optimization:
self.model_options["disable_cfg1_optimization"] = True
def set_model_sampler_post_cfg_function(self, post_cfg_function, disable_cfg1_optimization=False):
self.model_options["sampler_post_cfg_function"] = self.model_options.get("sampler_post_cfg_function", []) + [post_cfg_function]
if disable_cfg1_optimization:
self.model_options["disable_cfg1_optimization"] = True
def set_model_unet_function_wrapper(self, unet_wrapper_function):
self.model_options["model_function_wrapper"] = unet_wrapper_function
def set_model_vae_encode_wrapper(self, wrapper_function):
self.model_options["model_vae_encode_wrapper"] = wrapper_function
def set_model_vae_decode_wrapper(self, wrapper_function):
self.model_options["model_vae_decode_wrapper"] = wrapper_function
def set_model_vae_regulation(self, vae_regulation):
self.model_options["model_vae_regulation"] = vae_regulation
def set_model_patch(self, patch, name):
to = self.model_options["transformer_options"]
if "patches" not in to:
to["patches"] = {}
to["patches"][name] = to["patches"].get(name, []) + [patch]
def set_model_patch_replace(self, patch, name, block_name, number, transformer_index=None):
to = self.model_options["transformer_options"]
if "patches_replace" not in to:
to["patches_replace"] = {}
if name not in to["patches_replace"]:
to["patches_replace"][name] = {}
if transformer_index is not None:
block = (block_name, number, transformer_index)
else:
block = (block_name, number)
to["patches_replace"][name][block] = patch
def set_model_attn1_patch(self, patch):
self.set_model_patch(patch, "attn1_patch")
def set_model_attn2_patch(self, patch):
self.set_model_patch(patch, "attn2_patch")
def set_model_attn1_replace(self, patch, block_name, number, transformer_index=None):
self.set_model_patch_replace(patch, "attn1", block_name, number, transformer_index)
def set_model_attn2_replace(self, patch, block_name, number, transformer_index=None):
self.set_model_patch_replace(patch, "attn2", block_name, number, transformer_index)
def set_model_attn1_output_patch(self, patch):
self.set_model_patch(patch, "attn1_output_patch")
def set_model_attn2_output_patch(self, patch):
self.set_model_patch(patch, "attn2_output_patch")
def set_model_input_block_patch(self, patch):
self.set_model_patch(patch, "input_block_patch")
def set_model_input_block_patch_after_skip(self, patch):
self.set_model_patch(patch, "input_block_patch_after_skip")
def set_model_output_block_patch(self, patch):
self.set_model_patch(patch, "output_block_patch")
def add_object_patch(self, name, obj):
self.object_patches[name] = obj
def model_patches_to(self, device):
to = self.model_options["transformer_options"]
if "patches" in to:
patches = to["patches"]
for name in patches:
patch_list = patches[name]
for i in range(len(patch_list)):
if hasattr(patch_list[i], "to"):
patch_list[i] = patch_list[i].to(device)
if "patches_replace" in to:
patches = to["patches_replace"]
for name in patches:
patch_list = patches[name]
for k in patch_list:
if hasattr(patch_list[k], "to"):
patch_list[k] = patch_list[k].to(device)
if "model_function_wrapper" in self.model_options:
wrap_func = self.model_options["model_function_wrapper"]
if hasattr(wrap_func, "to"):
self.model_options["model_function_wrapper"] = wrap_func.to(device)
def model_dtype(self):
if hasattr(self.model, "get_dtype"):
return self.model.get_dtype()
def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
p = set()
for k in patches:
if k in self.model_keys:
p.add(k)
current_patches = self.patches.get(k, [])
current_patches.append((strength_patch, patches[k], strength_model))
self.patches[k] = current_patches
return list(p)
def get_key_patches(self, filter_prefix=None):
ldm_patched.modules.model_management.unload_model_clones(self)
model_sd = self.model_state_dict()
p = {}
for k in model_sd:
if filter_prefix is not None:
if not k.startswith(filter_prefix):
continue
if k in self.patches:
p[k] = [model_sd[k]] + self.patches[k]
else:
p[k] = (model_sd[k],)
return p
def model_state_dict(self, filter_prefix=None):
sd = self.model.state_dict()
keys = list(sd.keys())
if filter_prefix is not None:
for k in keys:
if not k.startswith(filter_prefix):
sd.pop(k)
return sd
def patch_model(self, device_to=None, patch_weights=True):
for k in self.object_patches:
old = ldm_patched.modules.utils.get_attr(self.model, k)
if k not in self.object_patches_backup:
self.object_patches_backup[k] = old
ldm_patched.modules.utils.set_attr_raw(self.model, k, self.object_patches[k])
if patch_weights:
model_sd = self.model_state_dict()
for key in self.patches:
if key not in model_sd:
print("could not patch. key doesn't exist in model:", key)
continue
weight = model_sd[key]
inplace_update = self.weight_inplace_update
if key not in self.backup:
self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update)
if device_to is not None:
temp_weight = ldm_patched.modules.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
else:
temp_weight = weight.to(torch.float32, copy=True)
out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
if inplace_update:
ldm_patched.modules.utils.copy_to_param(self.model, key, out_weight)
else:
ldm_patched.modules.utils.set_attr(self.model, key, out_weight)
del temp_weight
if device_to is not None:
self.model.to(device_to)
self.current_device = device_to
return self.model
def calculate_weight(self, patches, weight, key):
for p in patches:
alpha = p[0]
v = p[1]
strength_model = p[2]
if strength_model != 1.0:
weight *= strength_model
if isinstance(v, list):
v = (self.calculate_weight(v[1:], v[0].clone(), key), )
if len(v) == 1:
patch_type = "diff"
elif len(v) == 2:
patch_type = v[0]
v = v[1]
if patch_type == "diff":
w1 = v[0]
if alpha != 0.0:
if w1.shape != weight.shape:
if w1.ndim == weight.ndim == 4:
new_shape = [max(n, m) for n, m in zip(weight.shape, w1.shape)]
print(f'Merged with {key} channel changed to {new_shape}')
new_diff = alpha * ldm_patched.modules.model_management.cast_to_device(w1, weight.device, weight.dtype)
new_weight = torch.zeros(size=new_shape).to(weight)
new_weight[:weight.shape[0], :weight.shape[1], :weight.shape[2], :weight.shape[3]] = weight
new_weight[:new_diff.shape[0], :new_diff.shape[1], :new_diff.shape[2], :new_diff.shape[3]] += new_diff
new_weight = new_weight.contiguous().clone()
weight = new_weight
else:
print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
else:
weight += alpha * ldm_patched.modules.model_management.cast_to_device(w1, weight.device, weight.dtype)
elif patch_type == "lora": #lora/locon
mat1 = ldm_patched.modules.model_management.cast_to_device(v[0], weight.device, torch.float32)
mat2 = ldm_patched.modules.model_management.cast_to_device(v[1], weight.device, torch.float32)
if v[2] is not None:
alpha *= v[2] / mat2.shape[0]
if v[3] is not None:
#locon mid weights, hopefully the math is fine because I didn't properly test it
mat3 = ldm_patched.modules.model_management.cast_to_device(v[3], weight.device, torch.float32)
final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
try:
weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype)
except Exception as e:
print("ERROR", key, e)
elif patch_type == "lokr":
w1 = v[0]
w2 = v[1]
w1_a = v[3]
w1_b = v[4]
w2_a = v[5]
w2_b = v[6]
t2 = v[7]
dim = None
if w1 is None:
dim = w1_b.shape[0]
w1 = torch.mm(ldm_patched.modules.model_management.cast_to_device(w1_a, weight.device, torch.float32),
ldm_patched.modules.model_management.cast_to_device(w1_b, weight.device, torch.float32))
else:
w1 = ldm_patched.modules.model_management.cast_to_device(w1, weight.device, torch.float32)
if w2 is None:
dim = w2_b.shape[0]
if t2 is None:
w2 = torch.mm(ldm_patched.modules.model_management.cast_to_device(w2_a, weight.device, torch.float32),
ldm_patched.modules.model_management.cast_to_device(w2_b, weight.device, torch.float32))
else:
w2 = torch.einsum('i j k l, j r, i p -> p r k l',
ldm_patched.modules.model_management.cast_to_device(t2, weight.device, torch.float32),
ldm_patched.modules.model_management.cast_to_device(w2_b, weight.device, torch.float32),
ldm_patched.modules.model_management.cast_to_device(w2_a, weight.device, torch.float32))
else:
w2 = ldm_patched.modules.model_management.cast_to_device(w2, weight.device, torch.float32)
if len(w2.shape) == 4:
w1 = w1.unsqueeze(2).unsqueeze(2)
if v[2] is not None and dim is not None:
alpha *= v[2] / dim
try:
weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype)
except Exception as e:
print("ERROR", key, e)
elif patch_type == "loha":
w1a = v[0]
w1b = v[1]
if v[2] is not None:
alpha *= v[2] / w1b.shape[0]
w2a = v[3]
w2b = v[4]
if v[5] is not None: #cp decomposition
t1 = v[5]
t2 = v[6]
m1 = torch.einsum('i j k l, j r, i p -> p r k l',
ldm_patched.modules.model_management.cast_to_device(t1, weight.device, torch.float32),
ldm_patched.modules.model_management.cast_to_device(w1b, weight.device, torch.float32),
ldm_patched.modules.model_management.cast_to_device(w1a, weight.device, torch.float32))
m2 = torch.einsum('i j k l, j r, i p -> p r k l',
ldm_patched.modules.model_management.cast_to_device(t2, weight.device, torch.float32),
ldm_patched.modules.model_management.cast_to_device(w2b, weight.device, torch.float32),
ldm_patched.modules.model_management.cast_to_device(w2a, weight.device, torch.float32))
else:
m1 = torch.mm(ldm_patched.modules.model_management.cast_to_device(w1a, weight.device, torch.float32),
ldm_patched.modules.model_management.cast_to_device(w1b, weight.device, torch.float32))
m2 = torch.mm(ldm_patched.modules.model_management.cast_to_device(w2a, weight.device, torch.float32),
ldm_patched.modules.model_management.cast_to_device(w2b, weight.device, torch.float32))
try:
weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
except Exception as e:
print("ERROR", key, e)
elif patch_type == "glora":
if v[4] is not None:
alpha *= v[4] / v[0].shape[0]
a1 = ldm_patched.modules.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, torch.float32)
a2 = ldm_patched.modules.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, torch.float32)
b1 = ldm_patched.modules.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, torch.float32)
b2 = ldm_patched.modules.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, torch.float32)
weight += ((torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)) * alpha).reshape(weight.shape).type(weight.dtype)
elif patch_type in extra_weight_calculators:
weight = extra_weight_calculators[patch_type](weight, alpha, v)
else:
print("patch type not recognized", patch_type, key)
return weight
def unpatch_model(self, device_to=None):
keys = list(self.backup.keys())
if self.weight_inplace_update:
for k in keys:
ldm_patched.modules.utils.copy_to_param(self.model, k, self.backup[k])
else:
for k in keys:
ldm_patched.modules.utils.set_attr(self.model, k, self.backup[k])
self.backup = {}
if device_to is not None:
self.model.to(device_to)
self.current_device = device_to
keys = list(self.object_patches_backup.keys())
for k in keys:
ldm_patched.modules.utils.set_attr_raw(self.model, k, self.object_patches_backup[k])
self.object_patches_backup = {}
from backend.patcher.base import *

View File

@@ -25,7 +25,6 @@ class UnetPatcher(ModelPatcher):
n.object_patches = self.object_patches.copy()
n.model_options = copy.deepcopy(self.model_options)
n.model_keys = self.model_keys
n.controlnet_linked_list = self.controlnet_linked_list
n.extra_preserved_memory_during_sampling = self.extra_preserved_memory_during_sampling
n.extra_model_patchers_during_sampling = self.extra_model_patchers_during_sampling.copy()