Mirror of https://github.com/lllyasviel/stable-diffusion-webui-forge.git, synced 2026-04-30 11:11:15 +00:00

rework lora and patching system (DoRA etc.)

The backend rework is 60% finished. I also removed the webui's extremely annoying filtering of loras by model version.
backend/misc/diffusers_state_dict.py (new file, 134 lines)
@@ -0,0 +1,134 @@
UNET_MAP_ATTENTIONS = {
    "proj_in.weight",
    "proj_in.bias",
    "proj_out.weight",
    "proj_out.bias",
    "norm.weight",
    "norm.bias",
}

TRANSFORMER_BLOCKS = {
    "norm1.weight",
    "norm1.bias",
    "norm2.weight",
    "norm2.bias",
    "norm3.weight",
    "norm3.bias",
    "attn1.to_q.weight",
    "attn1.to_k.weight",
    "attn1.to_v.weight",
    "attn1.to_out.0.weight",
    "attn1.to_out.0.bias",
    "attn2.to_q.weight",
    "attn2.to_k.weight",
    "attn2.to_v.weight",
    "attn2.to_out.0.weight",
    "attn2.to_out.0.bias",
    "ff.net.0.proj.weight",
    "ff.net.0.proj.bias",
    "ff.net.2.weight",
    "ff.net.2.bias",
}

UNET_MAP_RESNET = {
    "in_layers.2.weight": "conv1.weight",
    "in_layers.2.bias": "conv1.bias",
    "emb_layers.1.weight": "time_emb_proj.weight",
    "emb_layers.1.bias": "time_emb_proj.bias",
    "out_layers.3.weight": "conv2.weight",
    "out_layers.3.bias": "conv2.bias",
    "skip_connection.weight": "conv_shortcut.weight",
    "skip_connection.bias": "conv_shortcut.bias",
    "in_layers.0.weight": "norm1.weight",
    "in_layers.0.bias": "norm1.bias",
    "out_layers.0.weight": "norm2.weight",
    "out_layers.0.bias": "norm2.bias",
}

UNET_MAP_BASIC = {
    ("label_emb.0.0.weight", "class_embedding.linear_1.weight"),
    ("label_emb.0.0.bias", "class_embedding.linear_1.bias"),
    ("label_emb.0.2.weight", "class_embedding.linear_2.weight"),
    ("label_emb.0.2.bias", "class_embedding.linear_2.bias"),
    ("label_emb.0.0.weight", "add_embedding.linear_1.weight"),
    ("label_emb.0.0.bias", "add_embedding.linear_1.bias"),
    ("label_emb.0.2.weight", "add_embedding.linear_2.weight"),
    ("label_emb.0.2.bias", "add_embedding.linear_2.bias"),
    ("input_blocks.0.0.weight", "conv_in.weight"),
    ("input_blocks.0.0.bias", "conv_in.bias"),
    ("out.0.weight", "conv_norm_out.weight"),
    ("out.0.bias", "conv_norm_out.bias"),
    ("out.2.weight", "conv_out.weight"),
    ("out.2.bias", "conv_out.bias"),
    ("time_embed.0.weight", "time_embedding.linear_1.weight"),
    ("time_embed.0.bias", "time_embedding.linear_1.bias"),
    ("time_embed.2.weight", "time_embedding.linear_2.weight"),
    ("time_embed.2.bias", "time_embedding.linear_2.bias")
}


def unet_to_diffusers(unet_config):
    if "num_res_blocks" not in unet_config:
        return {}
    num_res_blocks = unet_config["num_res_blocks"]
    channel_mult = unet_config["channel_mult"]
    transformer_depth = unet_config["transformer_depth"][:]
    transformer_depth_output = unet_config["transformer_depth_output"][:]
    num_blocks = len(channel_mult)

    transformers_mid = unet_config.get("transformer_depth_middle", None)

    diffusers_unet_map = {}
    for x in range(num_blocks):
        n = 1 + (num_res_blocks[x] + 1) * x
        for i in range(num_res_blocks[x]):
            for b in UNET_MAP_RESNET:
                diffusers_unet_map["down_blocks.{}.resnets.{}.{}".format(x, i, UNET_MAP_RESNET[b])] = "input_blocks.{}.0.{}".format(n, b)
            num_transformers = transformer_depth.pop(0)
            if num_transformers > 0:
                for b in UNET_MAP_ATTENTIONS:
                    diffusers_unet_map["down_blocks.{}.attentions.{}.{}".format(x, i, b)] = "input_blocks.{}.1.{}".format(n, b)
                for t in range(num_transformers):
                    for b in TRANSFORMER_BLOCKS:
                        diffusers_unet_map["down_blocks.{}.attentions.{}.transformer_blocks.{}.{}".format(x, i, t, b)] = "input_blocks.{}.1.transformer_blocks.{}.{}".format(n, t, b)
            n += 1
        for k in ["weight", "bias"]:
            diffusers_unet_map["down_blocks.{}.downsamplers.0.conv.{}".format(x, k)] = "input_blocks.{}.0.op.{}".format(n, k)

    i = 0
    for b in UNET_MAP_ATTENTIONS:
        diffusers_unet_map["mid_block.attentions.{}.{}".format(i, b)] = "middle_block.1.{}".format(b)
    for t in range(transformers_mid):
        for b in TRANSFORMER_BLOCKS:
            diffusers_unet_map["mid_block.attentions.{}.transformer_blocks.{}.{}".format(i, t, b)] = "middle_block.1.transformer_blocks.{}.{}".format(t, b)

    for i, n in enumerate([0, 2]):
        for b in UNET_MAP_RESNET:
            diffusers_unet_map["mid_block.resnets.{}.{}".format(i, UNET_MAP_RESNET[b])] = "middle_block.{}.{}".format(n, b)

    num_res_blocks = list(reversed(num_res_blocks))
    for x in range(num_blocks):
        n = (num_res_blocks[x] + 1) * x
        l = num_res_blocks[x] + 1
        for i in range(l):
            c = 0
            for b in UNET_MAP_RESNET:
                diffusers_unet_map["up_blocks.{}.resnets.{}.{}".format(x, i, UNET_MAP_RESNET[b])] = "output_blocks.{}.0.{}".format(n, b)
            c += 1
            num_transformers = transformer_depth_output.pop()
            if num_transformers > 0:
                c += 1
                for b in UNET_MAP_ATTENTIONS:
                    diffusers_unet_map["up_blocks.{}.attentions.{}.{}".format(x, i, b)] = "output_blocks.{}.1.{}".format(n, b)
                for t in range(num_transformers):
                    for b in TRANSFORMER_BLOCKS:
                        diffusers_unet_map["up_blocks.{}.attentions.{}.transformer_blocks.{}.{}".format(x, i, t, b)] = "output_blocks.{}.1.transformer_blocks.{}.{}".format(n, t, b)
            if i == l - 1:
                for k in ["weight", "bias"]:
                    diffusers_unet_map["up_blocks.{}.upsamplers.0.conv.{}".format(x, k)] = "output_blocks.{}.{}.conv.{}".format(n, c, k)
            n += 1

    for k in UNET_MAP_BASIC:
        diffusers_unet_map[k[1]] = k[0]

    return diffusers_unet_map
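As a usage sketch, `unet_to_diffusers` returns a map from diffusers-style keys to ldm-style keys. The toy config below is invented for illustration (real configs come from the loaded UNet), but tracing the function with it gives:

    toy_config = {
        "num_res_blocks": [1],
        "channel_mult": [1],
        "transformer_depth": [1],
        "transformer_depth_output": [1, 1],
        "transformer_depth_middle": 1,
    }
    mapping = unet_to_diffusers(toy_config)
    # e.g. mapping["down_blocks.0.resnets.0.conv1.weight"]
    #      == "input_blocks.1.0.in_layers.2.weight"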
backend/patcher/base.py (new file, 500 lines)
@@ -0,0 +1,500 @@
# Model Patching API Template Extracted From ComfyUI
# The actual implementations of these APIs are from Forge, written from scratch (after forge-v1.0.1),
# and may differ from ComfyUI to some degree.

import torch
import copy
import inspect

from backend import memory_management, utils

extra_weight_calculators = {}


def weight_decompose(dora_scale, weight, lora_diff, alpha, strength):
    dora_scale = memory_management.cast_to_device(dora_scale, weight.device, torch.float32)
    lora_diff *= alpha
    weight_calc = weight + lora_diff.type(weight.dtype)
    weight_norm = (
        weight_calc.transpose(0, 1)
        .reshape(weight_calc.shape[1], -1)
        .norm(dim=1, keepdim=True)
        .reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
        .transpose(0, 1)
    )

    weight_calc *= (dora_scale / weight_norm).type(weight.dtype)
    if strength != 1.0:
        weight_calc -= weight
        weight += strength * weight_calc
    else:
        weight[:] = weight_calc
    return weight
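# Usage sketch (illustrative values; assumes a 2D (out, in) weight): the norm
# above collapses to per-input-column norms of shape (1, in), so DoRA rescales
# each column of (weight + alpha * lora_diff) to the learned magnitude stored
# in dora_scale, then blends with the original weight by `strength`, e.g.:
#     w = torch.randn(4, 3)
#     diff = torch.randn(4, 1) @ torch.randn(1, 3)   # rank-1 LoRA delta
#     scale = w.norm(dim=0, keepdim=True)            # (1, 3) magnitude vector
#     out = weight_decompose(scale, w.clone(), diff, alpha=1.0, strength=1.0)
#     assert out.shape == w.shape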


def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None):
    to = model_options["transformer_options"].copy()

    if "patches_replace" not in to:
        to["patches_replace"] = {}
    else:
        to["patches_replace"] = to["patches_replace"].copy()

    if name not in to["patches_replace"]:
        to["patches_replace"][name] = {}
    else:
        to["patches_replace"][name] = to["patches_replace"][name].copy()

    if transformer_index is not None:
        block = (block_name, number, transformer_index)
    else:
        block = (block_name, number)
    to["patches_replace"][name][block] = patch
    model_options["transformer_options"] = to
    return model_options


def set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=False):
    model_options["sampler_post_cfg_function"] = model_options.get("sampler_post_cfg_function", []) + [post_cfg_function]
    if disable_cfg1_optimization:
        model_options["disable_cfg1_optimization"] = True
    return model_options


def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_cfg1_optimization=False):
    model_options["sampler_pre_cfg_function"] = model_options.get("sampler_pre_cfg_function", []) + [pre_cfg_function]
    if disable_cfg1_optimization:
        model_options["disable_cfg1_optimization"] = True
    return model_options
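# Hook sketch (the args keys "denoised", "cond_denoised", "cond_scale" follow
# the ComfyUI convention; that convention is an assumption here, not something
# this diff shows):
#     def my_post_cfg(args):
#         d, cd = args["denoised"], args["cond_denoised"]
#         return d + 0.1 * (cd - d)   # nudge the CFG result toward the cond branch
#     opts = set_model_options_post_cfg_function({"transformer_options": {}}, my_post_cfg)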


class ModelPatcher:
    def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False):
        self.size = size
        self.model = model
        self.patches = {}
        self.backup = {}
        self.object_patches = {}
        self.object_patches_backup = {}
        self.model_options = {"transformer_options": {}}
        self.model_size()
        self.load_device = load_device
        self.offload_device = offload_device
        if current_device is None:
            self.current_device = self.offload_device
        else:
            self.current_device = current_device

        self.weight_inplace_update = weight_inplace_update

    def model_size(self):
        if self.size > 0:
            return self.size
        self.size = memory_management.module_size(self.model)
        return self.size

    def clone(self):
        n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device, weight_inplace_update=self.weight_inplace_update)
        n.patches = {}
        for k in self.patches:
            n.patches[k] = self.patches[k][:]

        n.object_patches = self.object_patches.copy()
        n.model_options = copy.deepcopy(self.model_options)
        return n

    def is_clone(self, other):
        if hasattr(other, 'model') and self.model is other.model:
            return True
        return False

    def memory_required(self, input_shape):
        return self.model.memory_required(input_shape=input_shape)

    def set_model_sampler_cfg_function(self, sampler_cfg_function, disable_cfg1_optimization=False):
        if len(inspect.signature(sampler_cfg_function).parameters) == 3:
            self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"])  # Old way
        else:
            self.model_options["sampler_cfg_function"] = sampler_cfg_function
        if disable_cfg1_optimization:
            self.model_options["disable_cfg1_optimization"] = True

    def set_model_sampler_post_cfg_function(self, post_cfg_function, disable_cfg1_optimization=False):
        self.model_options = set_model_options_post_cfg_function(self.model_options, post_cfg_function, disable_cfg1_optimization)

    def set_model_sampler_pre_cfg_function(self, pre_cfg_function, disable_cfg1_optimization=False):
        self.model_options = set_model_options_pre_cfg_function(self.model_options, pre_cfg_function, disable_cfg1_optimization)

    def set_model_unet_function_wrapper(self, unet_wrapper_function):
        self.model_options["model_function_wrapper"] = unet_wrapper_function

    def set_model_vae_encode_wrapper(self, wrapper_function):
        self.model_options["model_vae_encode_wrapper"] = wrapper_function

    def set_model_vae_decode_wrapper(self, wrapper_function):
        self.model_options["model_vae_decode_wrapper"] = wrapper_function

    def set_model_vae_regulation(self, vae_regulation):
        self.model_options["model_vae_regulation"] = vae_regulation

    def set_model_denoise_mask_function(self, denoise_mask_function):
        self.model_options["denoise_mask_function"] = denoise_mask_function

    def set_model_patch(self, patch, name):
        to = self.model_options["transformer_options"]
        if "patches" not in to:
            to["patches"] = {}
        to["patches"][name] = to["patches"].get(name, []) + [patch]

    def set_model_patch_replace(self, patch, name, block_name, number, transformer_index=None):
        self.model_options = set_model_options_patch_replace(self.model_options, patch, name, block_name, number, transformer_index=transformer_index)

    def set_model_attn1_patch(self, patch):
        self.set_model_patch(patch, "attn1_patch")

    def set_model_attn2_patch(self, patch):
        self.set_model_patch(patch, "attn2_patch")

    def set_model_attn1_replace(self, patch, block_name, number, transformer_index=None):
        self.set_model_patch_replace(patch, "attn1", block_name, number, transformer_index)

    def set_model_attn2_replace(self, patch, block_name, number, transformer_index=None):
        self.set_model_patch_replace(patch, "attn2", block_name, number, transformer_index)

    def set_model_attn1_output_patch(self, patch):
        self.set_model_patch(patch, "attn1_output_patch")

    def set_model_attn2_output_patch(self, patch):
        self.set_model_patch(patch, "attn2_output_patch")

    def set_model_input_block_patch(self, patch):
        self.set_model_patch(patch, "input_block_patch")

    def set_model_input_block_patch_after_skip(self, patch):
        self.set_model_patch(patch, "input_block_patch_after_skip")

    def set_model_output_block_patch(self, patch):
        self.set_model_patch(patch, "output_block_patch")

    def add_object_patch(self, name, obj):
        self.object_patches[name] = obj

    def get_model_object(self, name):
        if name in self.object_patches:
            return self.object_patches[name]
        else:
            if name in self.object_patches_backup:
                return self.object_patches_backup[name]
            else:
                return utils.get_attr(self.model, name)

    def model_patches_to(self, device):
        to = self.model_options["transformer_options"]
        if "patches" in to:
            patches = to["patches"]
            for name in patches:
                patch_list = patches[name]
                for i in range(len(patch_list)):
                    if hasattr(patch_list[i], "to"):
                        patch_list[i] = patch_list[i].to(device)
        if "patches_replace" in to:
            patches = to["patches_replace"]
            for name in patches:
                patch_list = patches[name]
                for k in patch_list:
                    if hasattr(patch_list[k], "to"):
                        patch_list[k] = patch_list[k].to(device)
        if "model_function_wrapper" in self.model_options:
            wrap_func = self.model_options["model_function_wrapper"]
            if hasattr(wrap_func, "to"):
                self.model_options["model_function_wrapper"] = wrap_func.to(device)

    def model_dtype(self):
        if hasattr(self.model, "get_dtype"):
            return self.model.get_dtype()

    def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
        p = set()
        model_sd = self.model.state_dict()
        for k in patches:
            offset = None
            function = None
            if isinstance(k, str):
                key = k
            else:
                offset = k[1]
                key = k[0]
                if len(k) > 2:
                    function = k[2]

            if key in model_sd:
                p.add(k)
                current_patches = self.patches.get(key, [])
                current_patches.append((strength_patch, patches[k], strength_model, offset, function))
                self.patches[key] = current_patches

        return list(p)
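    # Patch-format sketch (key and tensor invented): each entry appended above
    # is the 5-tuple (strength_patch, value, strength_model, offset, function);
    # a bare 1-tuple value is later treated as a "diff" patch by
    # calculate_weight, e.g.:
    #     diff = torch.zeros_like(model_sd["diffusion_model.out.2.weight"])
    #     patcher.add_patches({"diffusion_model.out.2.weight": (diff,)}, strength_patch=0.5)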

    def get_key_patches(self, filter_prefix=None):
        memory_management.unload_model_clones(self)
        model_sd = self.model_state_dict()
        p = {}
        for k in model_sd:
            if filter_prefix is not None:
                if not k.startswith(filter_prefix):
                    continue
            if k in self.patches:
                p[k] = [model_sd[k]] + self.patches[k]
            else:
                p[k] = (model_sd[k],)
        return p

    def model_state_dict(self, filter_prefix=None):
        sd = self.model.state_dict()
        keys = list(sd.keys())
        if filter_prefix is not None:
            for k in keys:
                if not k.startswith(filter_prefix):
                    sd.pop(k)
        return sd

    def patch_model(self, device_to=None, patch_weights=True):
        for k in self.object_patches:
            old = utils.get_attr(self.model, k)
            if k not in self.object_patches_backup:
                self.object_patches_backup[k] = old
            utils.set_attr_raw(self.model, k, self.object_patches[k])

        if patch_weights:
            model_sd = self.model_state_dict()
            for key in self.patches:
                if key not in model_sd:
                    print("could not patch. key doesn't exist in model:", key)
                    continue

                weight = model_sd[key]

                inplace_update = self.weight_inplace_update

                if key not in self.backup:
                    self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update)

                if device_to is not None:
                    temp_weight = memory_management.cast_to_device(weight, device_to, torch.float32, copy=True)
                else:
                    temp_weight = weight.to(torch.float32, copy=True)
                out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
                if inplace_update:
                    utils.copy_to_param(self.model, key, out_weight)
                else:
                    utils.set_attr(self.model, key, out_weight)
                del temp_weight

            if device_to is not None:
                self.model.to(device_to)
                self.current_device = device_to

        return self.model

    def calculate_weight(self, patches, weight, key):
        for p in patches:
            strength = p[0]
            v = p[1]
            strength_model = p[2]
            offset = p[3]
            function = p[4]
            if function is None:
                function = lambda a: a

            old_weight = None
            if offset is not None:
                old_weight = weight
                weight = weight.narrow(offset[0], offset[1], offset[2])

            if strength_model != 1.0:
                weight *= strength_model

            if isinstance(v, list):
                v = (self.calculate_weight(v[1:], v[0].clone(), key),)

            patch_type = ''

            if len(v) == 1:
                patch_type = "diff"
            elif len(v) == 2:
                patch_type = v[0]
                v = v[1]

            if patch_type == "diff":
                w1 = v[0]
                if strength != 0.0:
                    if w1.shape != weight.shape:
                        if w1.ndim == weight.ndim == 4:
                            new_shape = [max(n, m) for n, m in zip(weight.shape, w1.shape)]
                            print(f'Merged with {key} channel changed to {new_shape}')
                            new_diff = strength * memory_management.cast_to_device(w1, weight.device, weight.dtype)
                            new_weight = torch.zeros(size=new_shape).to(weight)
                            new_weight[:weight.shape[0], :weight.shape[1], :weight.shape[2], :weight.shape[3]] = weight
                            new_weight[:new_diff.shape[0], :new_diff.shape[1], :new_diff.shape[2], :new_diff.shape[3]] += new_diff
                            new_weight = new_weight.contiguous().clone()
                            weight = new_weight
                        else:
                            print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
                    else:
                        weight += strength * memory_management.cast_to_device(w1, weight.device, weight.dtype)
            elif patch_type == "lora":
                mat1 = memory_management.cast_to_device(v[0], weight.device, torch.float32)
                mat2 = memory_management.cast_to_device(v[1], weight.device, torch.float32)
                dora_scale = v[4]
                if v[2] is not None:
                    alpha = v[2] / mat2.shape[0]
                else:
                    alpha = 1.0

                if v[3] is not None:
                    mat3 = memory_management.cast_to_device(v[3], weight.device, torch.float32)
                    final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
                    mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
                try:
                    lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
                    if dora_scale is not None:
                        weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
                    else:
                        weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
                except Exception as e:
                    print("ERROR {} {} {}".format(patch_type, key, e))
            elif patch_type == "lokr":
                w1 = v[0]
                w2 = v[1]
                w1_a = v[3]
                w1_b = v[4]
                w2_a = v[5]
                w2_b = v[6]
                t2 = v[7]
                dora_scale = v[8]
                dim = None

                if w1 is None:
                    dim = w1_b.shape[0]
                    w1 = torch.mm(memory_management.cast_to_device(w1_a, weight.device, torch.float32),
                                  memory_management.cast_to_device(w1_b, weight.device, torch.float32))
                else:
                    w1 = memory_management.cast_to_device(w1, weight.device, torch.float32)

                if w2 is None:
                    dim = w2_b.shape[0]
                    if t2 is None:
                        w2 = torch.mm(memory_management.cast_to_device(w2_a, weight.device, torch.float32),
                                      memory_management.cast_to_device(w2_b, weight.device, torch.float32))
                    else:
                        w2 = torch.einsum('i j k l, j r, i p -> p r k l',
                                          memory_management.cast_to_device(t2, weight.device, torch.float32),
                                          memory_management.cast_to_device(w2_b, weight.device, torch.float32),
                                          memory_management.cast_to_device(w2_a, weight.device, torch.float32))
                else:
                    w2 = memory_management.cast_to_device(w2, weight.device, torch.float32)

                if len(w2.shape) == 4:
                    w1 = w1.unsqueeze(2).unsqueeze(2)
                if v[2] is not None and dim is not None:
                    alpha = v[2] / dim
                else:
                    alpha = 1.0

                try:
                    lora_diff = torch.kron(w1, w2).reshape(weight.shape)
                    if dora_scale is not None:
                        weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
                    else:
                        weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
                except Exception as e:
                    print("ERROR {} {} {}".format(patch_type, key, e))
            elif patch_type == "loha":
                w1a = v[0]
                w1b = v[1]
                if v[2] is not None:
                    alpha = v[2] / w1b.shape[0]
                else:
                    alpha = 1.0

                w2a = v[3]
                w2b = v[4]
                dora_scale = v[7]
                if v[5] is not None:
                    t1 = v[5]
                    t2 = v[6]
                    m1 = torch.einsum('i j k l, j r, i p -> p r k l',
                                      memory_management.cast_to_device(t1, weight.device, torch.float32),
                                      memory_management.cast_to_device(w1b, weight.device, torch.float32),
                                      memory_management.cast_to_device(w1a, weight.device, torch.float32))

                    m2 = torch.einsum('i j k l, j r, i p -> p r k l',
                                      memory_management.cast_to_device(t2, weight.device, torch.float32),
                                      memory_management.cast_to_device(w2b, weight.device, torch.float32),
                                      memory_management.cast_to_device(w2a, weight.device, torch.float32))
                else:
                    m1 = torch.mm(memory_management.cast_to_device(w1a, weight.device, torch.float32),
                                  memory_management.cast_to_device(w1b, weight.device, torch.float32))
                    m2 = torch.mm(memory_management.cast_to_device(w2a, weight.device, torch.float32),
                                  memory_management.cast_to_device(w2b, weight.device, torch.float32))

                try:
                    lora_diff = (m1 * m2).reshape(weight.shape)
                    if dora_scale is not None:
                        weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
                    else:
                        weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
                except Exception as e:
                    print("ERROR {} {} {}".format(patch_type, key, e))
            elif patch_type == "glora":
                if v[4] is not None:
                    alpha = v[4] / v[0].shape[0]
                else:
                    alpha = 1.0

                dora_scale = v[5]

                a1 = memory_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, torch.float32)
                a2 = memory_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, torch.float32)
                b1 = memory_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, torch.float32)
                b2 = memory_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, torch.float32)

                try:
                    lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)).reshape(weight.shape)
                    if dora_scale is not None:
                        weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
                    else:
                        weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
                except Exception as e:
                    print("ERROR {} {} {}".format(patch_type, key, e))
            elif patch_type in extra_weight_calculators:
                weight = extra_weight_calculators[patch_type](weight, strength, v)
            else:
                print("patch type not recognized {} {}".format(patch_type, key))

            if old_weight is not None:
                weight = old_weight

        return weight
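    # Worked rule for the plain "lora" branch above (shapes invented): with
    # mat1 of shape (out, rank), mat2 of shape (rank, in), and the stored
    # network alpha divided by rank = mat2.shape[0], the update is
    #     weight += strength * (alpha / rank) * (mat1 @ mat2)
    # while the dora_scale path instead renormalizes via weight_decompose.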

    def unpatch_model(self, device_to=None):
        keys = list(self.backup.keys())

        if self.weight_inplace_update:
            for k in keys:
                utils.copy_to_param(self.model, k, self.backup[k])
        else:
            for k in keys:
                utils.set_attr(self.model, k, self.backup[k])

        self.backup = {}

        if device_to is not None:
            self.model.to(device_to)
            self.current_device = device_to

        keys = list(self.object_patches_backup.keys())
        for k in keys:
            utils.set_attr_raw(self.model, k, self.object_patches_backup[k])

        self.object_patches_backup = {}
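Putting backend/patcher/base.py together, a hedged sketch of the intended patch/unpatch lifecycle (the model object, devices, and patch dict are placeholders; lora_patch_dict stands for the output of load_lora below):

    patcher = ModelPatcher(model, load_device=torch.device("cuda"), offload_device=torch.device("cpu"))
    work = patcher.clone()                              # clones share weights but not patch lists
    work.add_patches(lora_patch_dict, strength_patch=0.8)
    work.patch_model(device_to=torch.device("cuda"))    # bakes patches into weights, keeping backups
    # ... run sampling with work.model ...
    work.unpatch_model(device_to=torch.device("cpu"))   # restores the backed-up weights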
backend/patcher/lora.py (new file, 293 lines)
@@ -0,0 +1,293 @@
# LoRA Implementation Collection from ComfyUI
# Modified by Forge to support greedy loading (load a set of wrong/correct loras onto a model and only preserve the correct ones),
# which is important to the webui experience.

from backend.misc.diffusers_state_dict import unet_to_diffusers


LORA_CLIP_MAP = {
    "mlp.fc1": "mlp_fc1",
    "mlp.fc2": "mlp_fc2",
    "self_attn.k_proj": "self_attn_k_proj",
    "self_attn.q_proj": "self_attn_q_proj",
    "self_attn.v_proj": "self_attn_v_proj",
    "self_attn.out_proj": "self_attn_out_proj",
}


def load_lora(lora, to_load):
    patch_dict = {}
    loaded_keys = set()
    for x in to_load:
        alpha_name = "{}.alpha".format(x)
        alpha = None
        if alpha_name in lora.keys():
            alpha = lora[alpha_name].item()
            loaded_keys.add(alpha_name)

        dora_scale_name = "{}.dora_scale".format(x)
        dora_scale = None
        if dora_scale_name in lora.keys():
            dora_scale = lora[dora_scale_name]
            loaded_keys.add(dora_scale_name)

        regular_lora = "{}.lora_up.weight".format(x)
        diffusers_lora = "{}_lora.up.weight".format(x)
        diffusers2_lora = "{}.lora_B.weight".format(x)
        diffusers3_lora = "{}.lora.up.weight".format(x)
        transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
        A_name = None

        if regular_lora in lora.keys():
            A_name = regular_lora
            B_name = "{}.lora_down.weight".format(x)
            mid_name = "{}.lora_mid.weight".format(x)
        elif diffusers_lora in lora.keys():
            A_name = diffusers_lora
            B_name = "{}_lora.down.weight".format(x)
            mid_name = None
        elif diffusers2_lora in lora.keys():
            A_name = diffusers2_lora
            B_name = "{}.lora_A.weight".format(x)
            mid_name = None
        elif diffusers3_lora in lora.keys():
            A_name = diffusers3_lora
            B_name = "{}.lora.down.weight".format(x)
            mid_name = None
        elif transformers_lora in lora.keys():
            A_name = transformers_lora
            B_name = "{}.lora_linear_layer.down.weight".format(x)
            mid_name = None

        if A_name is not None:
            mid = None
            if mid_name is not None and mid_name in lora.keys():
                mid = lora[mid_name]
                loaded_keys.add(mid_name)
            patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid, dora_scale))
            loaded_keys.add(A_name)
            loaded_keys.add(B_name)

        ######## loha
        hada_w1_a_name = "{}.hada_w1_a".format(x)
        hada_w1_b_name = "{}.hada_w1_b".format(x)
        hada_w2_a_name = "{}.hada_w2_a".format(x)
        hada_w2_b_name = "{}.hada_w2_b".format(x)
        hada_t1_name = "{}.hada_t1".format(x)
        hada_t2_name = "{}.hada_t2".format(x)
        if hada_w1_a_name in lora.keys():
            hada_t1 = None
            hada_t2 = None
            if hada_t1_name in lora.keys():
                hada_t1 = lora[hada_t1_name]
                hada_t2 = lora[hada_t2_name]
                loaded_keys.add(hada_t1_name)
                loaded_keys.add(hada_t2_name)

            patch_dict[to_load[x]] = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2, dora_scale))
            loaded_keys.add(hada_w1_a_name)
            loaded_keys.add(hada_w1_b_name)
            loaded_keys.add(hada_w2_a_name)
            loaded_keys.add(hada_w2_b_name)

        ######## lokr
        lokr_w1_name = "{}.lokr_w1".format(x)
        lokr_w2_name = "{}.lokr_w2".format(x)
        lokr_w1_a_name = "{}.lokr_w1_a".format(x)
        lokr_w1_b_name = "{}.lokr_w1_b".format(x)
        lokr_t2_name = "{}.lokr_t2".format(x)
        lokr_w2_a_name = "{}.lokr_w2_a".format(x)
        lokr_w2_b_name = "{}.lokr_w2_b".format(x)

        lokr_w1 = None
        if lokr_w1_name in lora.keys():
            lokr_w1 = lora[lokr_w1_name]
            loaded_keys.add(lokr_w1_name)

        lokr_w2 = None
        if lokr_w2_name in lora.keys():
            lokr_w2 = lora[lokr_w2_name]
            loaded_keys.add(lokr_w2_name)

        lokr_w1_a = None
        if lokr_w1_a_name in lora.keys():
            lokr_w1_a = lora[lokr_w1_a_name]
            loaded_keys.add(lokr_w1_a_name)

        lokr_w1_b = None
        if lokr_w1_b_name in lora.keys():
            lokr_w1_b = lora[lokr_w1_b_name]
            loaded_keys.add(lokr_w1_b_name)

        lokr_w2_a = None
        if lokr_w2_a_name in lora.keys():
            lokr_w2_a = lora[lokr_w2_a_name]
            loaded_keys.add(lokr_w2_a_name)

        lokr_w2_b = None
        if lokr_w2_b_name in lora.keys():
            lokr_w2_b = lora[lokr_w2_b_name]
            loaded_keys.add(lokr_w2_b_name)

        lokr_t2 = None
        if lokr_t2_name in lora.keys():
            lokr_t2 = lora[lokr_t2_name]
            loaded_keys.add(lokr_t2_name)

        if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
            patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale))

        # glora
        a1_name = "{}.a1.weight".format(x)
        a2_name = "{}.a2.weight".format(x)
        b1_name = "{}.b1.weight".format(x)
        b2_name = "{}.b2.weight".format(x)
        if a1_name in lora:
            patch_dict[to_load[x]] = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha, dora_scale))
            loaded_keys.add(a1_name)
            loaded_keys.add(a2_name)
            loaded_keys.add(b1_name)
            loaded_keys.add(b2_name)

        w_norm_name = "{}.w_norm".format(x)
        b_norm_name = "{}.b_norm".format(x)
        w_norm = lora.get(w_norm_name, None)
        b_norm = lora.get(b_norm_name, None)

        if w_norm is not None:
            loaded_keys.add(w_norm_name)
            patch_dict[to_load[x]] = ("diff", (w_norm,))
            if b_norm is not None:
                loaded_keys.add(b_norm_name)
                patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (b_norm,))

        diff_name = "{}.diff".format(x)
        diff_weight = lora.get(diff_name, None)
        if diff_weight is not None:
            patch_dict[to_load[x]] = ("diff", (diff_weight,))
            loaded_keys.add(diff_name)

        diff_bias_name = "{}.diff_b".format(x)
        diff_bias = lora.get(diff_bias_name, None)
        if diff_bias is not None:
            patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (diff_bias,))
            loaded_keys.add(diff_bias_name)

    remaining_dict = {x: y for x, y in lora.items() if x not in loaded_keys}
    return patch_dict, remaining_dict


def model_lora_keys_clip(model, key_map={}):
    sdk = model.state_dict().keys()

    text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
    clip_l_present = False
    for b in range(32):  # TODO: clean up
        for c in LORA_CLIP_MAP:
            k = "clip_h.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
            if k in sdk:
                lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c])
                key_map[lora_key] = k
                lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c])
                key_map[lora_key] = k
                lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c)
                key_map[lora_key] = k

            k = "clip_l.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
            if k in sdk:
                lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c])
                key_map[lora_key] = k
                lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c])
                key_map[lora_key] = k
                clip_l_present = True
                lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c)
                key_map[lora_key] = k

            k = "clip_g.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
            if k in sdk:
                if clip_l_present:
                    lora_key = "lora_te2_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c])
                    key_map[lora_key] = k
                    lora_key = "text_encoder_2.text_model.encoder.layers.{}.{}".format(b, c)
                    key_map[lora_key] = k
                else:
                    lora_key = "lora_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c])
                    key_map[lora_key] = k
                    lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c)
                    key_map[lora_key] = k
                    lora_key = "lora_prior_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c])
                    key_map[lora_key] = k

    for k in sdk:  # OneTrainer SD3 lora
        if k.startswith("t5xxl.transformer.") and k.endswith(".weight"):
            l_key = k[len("t5xxl.transformer."):-len(".weight")]
            lora_key = "lora_te3_{}".format(l_key.replace(".", "_"))
            key_map[lora_key] = k

    k = "clip_g.transformer.text_projection.weight"
    if k in sdk:
        key_map["lora_prior_te_text_projection"] = k
        key_map["lora_te2_text_projection"] = k

    k = "clip_l.transformer.text_projection.weight"
    if k in sdk:
        key_map["lora_te1_text_projection"] = k

    return key_map


def model_lora_keys_unet(model, key_map={}):
    sd = model.state_dict()
    sdk = sd.keys()

    for k in sdk:
        if k.startswith("diffusion_model.") and k.endswith(".weight"):
            key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
            key_map["lora_unet_{}".format(key_lora)] = k
            key_map["lora_prior_unet_{}".format(key_lora)] = k

    diffusers_keys = unet_to_diffusers(model.diffusion_model.legacy_config)
    for k in diffusers_keys:
        if k.endswith(".weight"):
            unet_key = "diffusion_model.{}".format(diffusers_keys[k])
            key_lora = k[:-len(".weight")].replace(".", "_")
            key_map["lora_unet_{}".format(key_lora)] = unet_key

            diffusers_lora_prefix = ["", "unet."]
            for p in diffusers_lora_prefix:
                diffusers_lora_key = "{}{}".format(p, k[:-len(".weight")].replace(".to_", ".processor.to_"))
                if diffusers_lora_key.endswith(".to_out.0"):
                    diffusers_lora_key = diffusers_lora_key[:-2]
                key_map[diffusers_lora_key] = unet_key

    # TODO:

    # if isinstance(model, xxxx.modules.model_base.SD3):  # Diffusers lora SD3
    #     diffusers_keys = xxxx.modules.utils.mmdit_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
    #     for k in diffusers_keys:
    #         if k.endswith(".weight"):
    #             to = diffusers_keys[k]
    #             key_lora = "transformer.{}".format(k[:-len(".weight")])  # regular diffusers sd3 lora format
    #             key_map[key_lora] = to
    #
    #             key_lora = "base_model.model.{}".format(k[:-len(".weight")])  # format for flash-sd3 lora and others?
    #             key_map[key_lora] = to
    #
    #             key_lora = "lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_"))  # OneTrainer lora
    #             key_map[key_lora] = to
    #
    # if isinstance(model, xxxx.modules.model_base.AuraFlow):  # Diffusers lora AuraFlow
    #     diffusers_keys = xxxx.modules.utils.auraflow_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
    #     for k in diffusers_keys:
    #         if k.endswith(".weight"):
    #             to = diffusers_keys[k]
    #             key_lora = "transformer.{}".format(k[:-len(".weight")])  # simpletrainer and probably regular diffusers lora format
    #             key_map[key_lora] = to
    #
    # if isinstance(model, xxxx.modules.model_base.HunyuanDiT):
    #     for k in sdk:
    #         if k.startswith("diffusion_model.") and k.endswith(".weight"):
    #             key_lora = k[len("diffusion_model."):-len(".weight")]
    #             key_map["base_model.model.{}".format(key_lora)] = k  # official hunyuan lora format

    return key_map
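A hedged end-to-end sketch of the greedy loading flow this file enables (unet_patcher and lora_state_dict are placeholders):

    from backend.patcher.lora import load_lora, model_lora_keys_unet

    key_map = model_lora_keys_unet(unet_patcher.model)   # lora key prefix -> model key
    patch_dict, remaining_dict = load_lora(lora_state_dict, key_map)
    unet_patcher.add_patches(patch_dict, strength_patch=0.8)
    # Greedy loading: keys that did not match this model come back in
    # remaining_dict, so they can be tried against another component
    # (e.g. the text encoder) instead of failing the whole lora.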
backend/utils.py (new file, 30 lines)
@@ -0,0 +1,30 @@
import torch


def set_attr(obj, attr, value):
    attrs = attr.split(".")
    for name in attrs[:-1]:
        obj = getattr(obj, name)
    setattr(obj, attrs[-1], torch.nn.Parameter(value, requires_grad=False))


def set_attr_raw(obj, attr, value):
    attrs = attr.split(".")
    for name in attrs[:-1]:
        obj = getattr(obj, name)
    setattr(obj, attrs[-1], value)


def copy_to_param(obj, attr, value):
    attrs = attr.split(".")
    for name in attrs[:-1]:
        obj = getattr(obj, name)
    prev = getattr(obj, attrs[-1])
    prev.data.copy_(value)


def get_attr(obj, attr):
    attrs = attr.split(".")
    for name in attrs:
        obj = getattr(obj, name)
    return obj
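A quick self-contained check of these dotted-path helpers (toy module, illustrative only):

    net = torch.nn.Sequential(torch.nn.Linear(2, 2))
    w = get_attr(net, "0.weight")                       # same object as net[0].weight
    set_attr(net, "0.weight", torch.zeros(2, 2))        # swaps in a new frozen Parameter
    copy_to_param(net, "0.weight", torch.ones(2, 2))    # writes in place via .data.copy_
    assert bool((get_attr(net, "0.weight") == 1).all())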
@@ -1,6 +0,0 @@
class LoraPatches:
    def __init__(self):
        pass

    def undo(self):
        pass
@@ -1,228 +0,0 @@
from __future__ import annotations
import os
from collections import namedtuple
import enum

import torch.nn as nn
import torch.nn.functional as F

from modules import sd_models, cache, errors, hashes, shared
import modules.models.sd3.mmdit

NetworkWeights = namedtuple('NetworkWeights', ['network_key', 'sd_key', 'w', 'sd_module'])

metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}


class SdVersion(enum.Enum):
    Unknown = 1
    SD1 = 2
    SD2 = 3
    SDXL = 4


class NetworkOnDisk:
    def __init__(self, name, filename):
        self.name = name
        self.filename = filename
        self.metadata = {}
        self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors"

        def read_metadata():
            metadata = sd_models.read_metadata_from_safetensors(filename)

            return metadata

        if self.is_safetensors:
            try:
                self.metadata = cache.cached_data_for_file('safetensors-metadata', "lora/" + self.name, filename, read_metadata)
            except Exception as e:
                errors.display(e, f"reading lora {filename}")

        if self.metadata:
            m = {}
            for k, v in sorted(self.metadata.items(), key=lambda x: metadata_tags_order.get(x[0], 999)):
                m[k] = v

            self.metadata = m

        self.alias = self.metadata.get('ss_output_name', self.name)

        self.hash = None
        self.shorthash = None
        self.set_hash(
            self.metadata.get('sshs_model_hash') or
            hashes.sha256_from_cache(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or
            ''
        )

        self.sd_version = self.detect_version()

    def detect_version(self):
        if str(self.metadata.get('ss_base_model_version', "")).startswith("sdxl_"):
            return SdVersion.SDXL
        elif str(self.metadata.get('ss_v2', "")) == "True":
            return SdVersion.SD2
        elif len(self.metadata):
            return SdVersion.SD1

        return SdVersion.Unknown

    def set_hash(self, v):
        self.hash = v
        self.shorthash = self.hash[0:12]

        if self.shorthash:
            import networks
            networks.available_network_hash_lookup[self.shorthash] = self

    def read_hash(self):
        if not self.hash:
            self.set_hash(hashes.sha256(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or '')

    def get_alias(self):
        import networks
        if shared.opts.lora_preferred_name == "Filename" or self.alias.lower() in networks.forbidden_network_aliases:
            return self.name
        else:
            return self.alias


class Network:  # LoraModule
    def __init__(self, name, network_on_disk: NetworkOnDisk):
        self.name = name
        self.network_on_disk = network_on_disk
        self.te_multiplier = 1.0
        self.unet_multiplier = 1.0
        self.dyn_dim = None
        self.modules = {}
        self.bundle_embeddings = {}
        self.mtime = None

        self.mentioned_name = None
        """the text that was used to add the network to prompt - can be either name or an alias"""


class ModuleType:
    def create_module(self, net: Network, weights: NetworkWeights) -> Network | None:
        return None


class NetworkModule:
    def __init__(self, net: Network, weights: NetworkWeights):
        self.network = net
        self.network_key = weights.network_key
        self.sd_key = weights.sd_key
        self.sd_module = weights.sd_module

        if isinstance(self.sd_module, modules.models.sd3.mmdit.QkvLinear):
            s = self.sd_module.weight.shape
            self.shape = (s[0] // 3, s[1])
        elif hasattr(self.sd_module, 'weight'):
            self.shape = self.sd_module.weight.shape
        elif isinstance(self.sd_module, nn.MultiheadAttention):
            # For now, only self-attn use Pytorch's MHA
            # So assume all qkvo proj have same shape
            self.shape = self.sd_module.out_proj.weight.shape
        else:
            self.shape = None

        self.ops = None
        self.extra_kwargs = {}
        if isinstance(self.sd_module, nn.Conv2d):
            self.ops = F.conv2d
            self.extra_kwargs = {
                'stride': self.sd_module.stride,
                'padding': self.sd_module.padding
            }
        elif isinstance(self.sd_module, nn.Linear):
            self.ops = F.linear
        elif isinstance(self.sd_module, nn.LayerNorm):
            self.ops = F.layer_norm
            self.extra_kwargs = {
                'normalized_shape': self.sd_module.normalized_shape,
                'eps': self.sd_module.eps
            }
        elif isinstance(self.sd_module, nn.GroupNorm):
            self.ops = F.group_norm
            self.extra_kwargs = {
                'num_groups': self.sd_module.num_groups,
                'eps': self.sd_module.eps
            }

        self.dim = None
        self.bias = weights.w.get("bias")
        self.alpha = weights.w["alpha"].item() if "alpha" in weights.w else None
        self.scale = weights.w["scale"].item() if "scale" in weights.w else None

        self.dora_scale = weights.w.get("dora_scale", None)
        self.dora_norm_dims = len(self.shape) - 1

    def multiplier(self):
        if 'transformer' in self.sd_key[:20]:
            return self.network.te_multiplier
        else:
            return self.network.unet_multiplier

    def calc_scale(self):
        if self.scale is not None:
            return self.scale
        if self.dim is not None and self.alpha is not None:
            return self.alpha / self.dim

        return 1.0

    def apply_weight_decompose(self, updown, orig_weight):
        # Match the device/dtype
        orig_weight = orig_weight.to(updown.dtype)
        dora_scale = self.dora_scale.to(device=orig_weight.device, dtype=updown.dtype)
        updown = updown.to(orig_weight.device)

        merged_scale1 = updown + orig_weight
        merged_scale1_norm = (
            merged_scale1.transpose(0, 1)
            .reshape(merged_scale1.shape[1], -1)
            .norm(dim=1, keepdim=True)
            .reshape(merged_scale1.shape[1], *[1] * self.dora_norm_dims)
            .transpose(0, 1)
        )

        dora_merged = (
            merged_scale1 * (dora_scale / merged_scale1_norm)
        )
        final_updown = dora_merged - orig_weight
        return final_updown

    def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None):
        if self.bias is not None:
            updown = updown.reshape(self.bias.shape)
            updown += self.bias.to(orig_weight.device, dtype=updown.dtype)
            updown = updown.reshape(output_shape)

        if len(output_shape) == 4:
            updown = updown.reshape(output_shape)

        if orig_weight.size().numel() == updown.size().numel():
            updown = updown.reshape(orig_weight.shape)

        if ex_bias is not None:
            ex_bias = ex_bias * self.multiplier()

        updown = updown * self.calc_scale()

        if self.dora_scale is not None:
            updown = self.apply_weight_decompose(updown, orig_weight)

        return updown * self.multiplier(), ex_bias

    def calc_updown(self, target):
        raise NotImplementedError()

    def forward(self, x, y):
        """A general forward implementation for all modules"""
        if self.ops is None:
            raise NotImplementedError()
        else:
            updown, ex_bias = self.calc_updown(self.sd_module.weight)
            return y + self.ops(x, weight=updown, bias=ex_bias, **self.extra_kwargs)
@@ -1,62 +1,62 @@
from modules import extra_networks, shared
import networks


class ExtraNetworkLora(extra_networks.ExtraNetwork):
    def __init__(self):
        super().__init__('lora')

        self.errors = {}
        """mapping of network names to the number of errors the network had during operation"""

    remove_symbols = str.maketrans('', '', ":,")

    def activate(self, p, params_list):
        additional = shared.opts.sd_lora

        self.errors.clear()

        if additional != "None" and additional in networks.available_networks and not any(x for x in params_list if x.items[0] == additional):
            p.all_prompts = [x + f"<lora:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
            params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))

        names = []
        te_multipliers = []
        unet_multipliers = []
        dyn_dims = []
        for params in params_list:
            assert params.items

            names.append(params.positional[0])

            te_multiplier = float(params.positional[1]) if len(params.positional) > 1 else 1.0
            te_multiplier = float(params.named.get("te", te_multiplier))

            unet_multiplier = float(params.positional[2]) if len(params.positional) > 2 else te_multiplier
            unet_multiplier = float(params.named.get("unet", unet_multiplier))

            dyn_dim = int(params.positional[3]) if len(params.positional) > 3 else None
            dyn_dim = int(params.named["dyn"]) if "dyn" in params.named else dyn_dim

            te_multipliers.append(te_multiplier)
            unet_multipliers.append(unet_multiplier)
            dyn_dims.append(dyn_dim)

        networks.load_networks(names, te_multipliers, unet_multipliers, dyn_dims)

        if shared.opts.lora_add_hashes_to_infotext:
            if not getattr(p, "is_hr_pass", False) or not hasattr(p, "lora_hashes"):
                p.lora_hashes = {}

            for item in networks.loaded_networks:
                if item.network_on_disk.shorthash and item.mentioned_name:
                    p.lora_hashes[item.mentioned_name.translate(self.remove_symbols)] = item.network_on_disk.shorthash

            if p.lora_hashes:
                p.extra_generation_params["Lora hashes"] = ', '.join(f'{k}: {v}' for k, v in p.lora_hashes.items())

    def deactivate(self, p):
        if self.errors:
            p.comment("Networks with errors: " + ", ".join(f"{k} ({v})" for k, v in self.errors.items()))

            self.errors.clear()
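For reference, the prompt forms this activate handler parses (the lora name and values are illustrative):

    <lora:my_lora:0.8>                      # te and unet multipliers both 0.8
    <lora:my_lora:0.8:0.5>                  # te=0.8, unet=0.5
    <lora:my_lora:0.8:0.5:16>               # optional fourth positional: dyn dim
    <lora:my_lora:te=0.8:unet=0.5:dyn=16>   # named overrides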
@@ -1,9 +1,9 @@
import networks

list_available_loras = networks.list_available_networks

available_loras = networks.available_networks
available_lora_aliases = networks.available_network_aliases
available_lora_hash_lookup = networks.available_network_hash_lookup
forbidden_lora_aliases = networks.forbidden_network_aliases
loaded_loras = networks.loaded_networks
74
extensions-builtin/sd_forge_lora/network.py
Normal file
@@ -0,0 +1,74 @@
import os

from modules import sd_models, cache, errors, hashes, shared


metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}


class NetworkOnDisk:
    def __init__(self, name, filename):
        self.name = name
        self.filename = filename
        self.metadata = {}
        self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors"

        def read_metadata():
            metadata = sd_models.read_metadata_from_safetensors(filename)

            return metadata

        if self.is_safetensors:
            try:
                self.metadata = cache.cached_data_for_file('safetensors-metadata', "lora/" + self.name, filename, read_metadata)
            except Exception as e:
                errors.display(e, f"reading lora {filename}")

        if self.metadata:
            m = {}
            for k, v in sorted(self.metadata.items(), key=lambda x: metadata_tags_order.get(x[0], 999)):
                m[k] = v

            self.metadata = m

        self.alias = self.metadata.get('ss_output_name', self.name)

        self.hash = None
        self.shorthash = None
        self.set_hash(
            self.metadata.get('sshs_model_hash') or
            hashes.sha256_from_cache(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or
            ''
        )

    def set_hash(self, v):
        self.hash = v
        self.shorthash = self.hash[0:12]

        if self.shorthash:
            import networks
            networks.available_network_hash_lookup[self.shorthash] = self

    def read_hash(self):
        if not self.hash:
            self.set_hash(hashes.sha256(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or '')

    def get_alias(self):
        import networks
        if shared.opts.lora_preferred_name == "Filename" or self.alias.lower() in networks.forbidden_network_aliases:
            return self.name
        else:
            return self.alias


class Network:
    def __init__(self, name, network_on_disk: NetworkOnDisk):
        self.name = name
        self.network_on_disk = network_on_disk
        self.te_multiplier = 1.0
        self.unet_multiplier = 1.0
        self.dyn_dim = None
        self.modules = {}
        self.bundle_embeddings = {}
        self.mtime = None
        self.mentioned_name = None
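Note: `use_addnet_hash=True` above selects the hashing scheme that skips the safetensors JSON header and hashes only the tensor payload, so editing embedded metadata does not change the hash. A minimal sketch of that scheme (not the webui implementation; the file name is illustrative):

import hashlib
import struct


def addnet_hash_safetensors(path):
    # The first 8 bytes of a .safetensors file hold the little-endian length
    # of the JSON header; hash everything after the header.
    h = hashlib.sha256()
    with open(path, "rb") as f:
        header_size = struct.unpack("<Q", f.read(8))[0]
        f.seek(8 + header_size)
        for chunk in iter(lambda: f.read(1024 * 1024), b""):
            h.update(chunk)
    return h.hexdigest()


# shorthash as used by NetworkOnDisk.set_hash would then be:
# addnet_hash_safetensors("some_lora.safetensors")[0:12]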
@@ -1,272 +1,171 @@
from __future__ import annotations
import gradio as gr
import logging
import os
import re

-import lora_patches
import functools
import network
import torch
from typing import Union

from modules import shared, sd_models, errors, scripts
from ldm_patched.modules.utils import load_torch_file
from ldm_patched.modules.sd import load_lora_for_models


@functools.lru_cache(maxsize=5)
def load_lora_state_dict(filename):
    return load_torch_file(filename, safe_load=True)


-def convert_diffusers_name_to_compvis(key, is_sd2):
-    pass
-
-
-def assign_network_names_to_compvis_modules(sd_model):
-    pass
-
-
-class BundledTIHash(str):
-    def __init__(self, hash_str):
-        self.hash = hash_str
-
-    def __str__(self):
-        return self.hash if shared.opts.lora_bundled_ti_to_infotext else ''
-
-
def load_network(name, network_on_disk):
    net = network.Network(name, network_on_disk)
    net.mtime = os.path.getmtime(network_on_disk.filename)

    return net


-def purge_networks_from_memory():
-    pass
-
-
def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
    global lora_state_dict_cache

    current_sd = sd_models.model_data.get_sd_model()
    if current_sd is None:
        return

    loaded_networks.clear()

    unavailable_networks = []
    for name in names:
        if name.lower() in forbidden_network_aliases and available_networks.get(name) is None:
            unavailable_networks.append(name)
        elif available_network_aliases.get(name) is None:
            unavailable_networks.append(name)

    if unavailable_networks:
        update_available_networks_by_names(unavailable_networks)

    networks_on_disk = [available_networks.get(name, None) if name.lower() in forbidden_network_aliases else available_network_aliases.get(name, None) for name in names]
    if any(x is None for x in networks_on_disk):
        list_available_networks()
        networks_on_disk = [available_networks.get(name, None) if name.lower() in forbidden_network_aliases else available_network_aliases.get(name, None) for name in names]

    for i, (network_on_disk, name) in enumerate(zip(networks_on_disk, names)):
        try:
            net = load_network(name, network_on_disk)
        except Exception as e:
            errors.display(e, f"loading network {network_on_disk.filename}")
            continue

        net.mentioned_name = name
        network_on_disk.read_hash()
        loaded_networks.append(net)

    compiled_lora_targets = []
    for a, b, c in zip(networks_on_disk, unet_multipliers, te_multipliers):
        compiled_lora_targets.append([a.filename, b, c])

    compiled_lora_targets_hash = str(compiled_lora_targets)

    if current_sd.current_lora_hash == compiled_lora_targets_hash:
        return

    current_sd.current_lora_hash = compiled_lora_targets_hash
    current_sd.forge_objects.unet = current_sd.forge_objects_original.unet
    current_sd.forge_objects.clip = current_sd.forge_objects_original.clip

    for filename, strength_model, strength_clip in compiled_lora_targets:
        lora_sd = load_lora_state_dict(filename)
        current_sd.forge_objects.unet, current_sd.forge_objects.clip = load_lora_for_models(
            current_sd.forge_objects.unet, current_sd.forge_objects.clip, lora_sd, strength_model, strength_clip,
            filename=filename)

    current_sd.forge_objects_after_applying_lora = current_sd.forge_objects.shallow_copy()
    return


-def allowed_layer_without_weight(layer):
-    if isinstance(layer, torch.nn.LayerNorm) and not layer.elementwise_affine:
-        return True
-
-    return False
-
-
-def store_weights_backup(weight):
-    if weight is None:
-        return None
-
-    return weight.to(devices.cpu, copy=True)
-
-
-def restore_weights_backup(obj, field, weight):
-    if weight is None:
-        setattr(obj, field, None)
-        return
-
-    getattr(obj, field).copy_(weight)
-
-
-def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
-    pass
-
-
-def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
-    pass
-
-
-def network_forward(org_module, input, original_forward):
-    pass
-
-
-def network_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
-    pass
-
-
-def network_Linear_forward(self, input):
-    pass
-
-
-def network_Linear_load_state_dict(self, *args, **kwargs):
-    pass
-
-
-def network_Conv2d_forward(self, input):
-    pass
-
-
-def network_Conv2d_load_state_dict(self, *args, **kwargs):
-    pass
-
-
-def network_GroupNorm_forward(self, input):
-    pass
-
-
-def network_GroupNorm_load_state_dict(self, *args, **kwargs):
-    pass
-
-
-def network_LayerNorm_forward(self, input):
-    pass
-
-
-def network_LayerNorm_load_state_dict(self, *args, **kwargs):
-    pass
-
-
-def network_MultiheadAttention_forward(self, *args, **kwargs):
-    pass
-
-
-def network_MultiheadAttention_load_state_dict(self, *args, **kwargs):
-    pass
-
-
def process_network_files(names: list[str] | None = None):
    candidates = list(shared.walk_files(shared.cmd_opts.lora_dir, allowed_extensions=[".pt", ".ckpt", ".safetensors"]))
    for filename in candidates:
        if os.path.isdir(filename):
            continue

        name = os.path.splitext(os.path.basename(filename))[0]
        # if names is provided, only load networks with names in the list
        if names and name not in names:
            continue

        try:
            entry = network.NetworkOnDisk(name, filename)
        except OSError:  # should catch FileNotFoundError and PermissionError etc.
            errors.report(f"Failed to load network {name} from {filename}", exc_info=True)
            continue

        available_networks[name] = entry

        if entry.alias in available_network_aliases:
            forbidden_network_aliases[entry.alias.lower()] = 1

        available_network_aliases[name] = entry
        available_network_aliases[entry.alias] = entry


def update_available_networks_by_names(names: list[str]):
    process_network_files(names)


def list_available_networks():
    available_networks.clear()
    available_network_aliases.clear()
    forbidden_network_aliases.clear()
    available_network_hash_lookup.clear()
    forbidden_network_aliases.update({"none": 1, "Addams": 1})

    os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)

    process_network_files()


re_network_name = re.compile(r"(.*)\s*\([0-9a-fA-F]+\)")


def infotext_pasted(infotext, params):
    if "AddNet Module 1" in [x[1] for x in scripts.scripts_txt2img.infotext_fields]:
        return  # if the other extension is active, it will handle those fields, no need to do anything

    added = []

    for k in params:
        if not k.startswith("AddNet Model "):
            continue

        num = k[13:]

        if params.get("AddNet Module " + num) != "LoRA":
            continue

        name = params.get("AddNet Model " + num)
        if name is None:
            continue

        m = re_network_name.match(name)
        if m:
            name = m.group(1)

        multiplier = params.get("AddNet Weight A " + num, "1.0")

        added.append(f"<lora:{name}:{multiplier}>")

    if added:
        params["Prompt"] += "\n" + "".join(added)


-originals: lora_patches.LoraPatches = None
-
extra_network_lora = None

available_networks = {}
available_network_aliases = {}
loaded_networks = []
loaded_bundle_embeddings = {}
networks_in_memory = {}
available_network_hash_lookup = {}
forbidden_network_aliases = {}

list_available_networks()
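Note: the reworked `load_networks` re-patches only when the requested set of (filename, unet strength, clip strength) triples changes; the stringified list acts as a cheap cache key stored on the model. The pattern in isolation (names illustrative, not the forge code):

class FakeModel:  # illustrative stand-in for the loaded sd model
    current_lora_hash = None


def apply_loras(model, targets):
    key = str(targets)
    if model.current_lora_hash == key:
        return "skipped"  # same LoRAs at same strengths: nothing to do
    model.current_lora_hash = key
    # ...reset unet/clip to the originals, then re-apply every target...
    return "patched"


m = FakeModel()
assert apply_loras(m, [["detail.safetensors", 0.8, 0.6]]) == "patched"
assert apply_loras(m, [["detail.safetensors", 0.8, 0.6]]) == "skipped"
assert apply_loras(m, [["detail.safetensors", 1.0, 0.6]]) == "patched"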
@@ -1,8 +1,8 @@
import os
from modules import paths
from modules.paths_internal import normalized_filepath


def preload(parser):
    parser.add_argument("--lora-dir", type=normalized_filepath, help="Path to directory with Lora networks.", default=os.path.join(paths.models_path, 'Lora'))
    parser.add_argument("--lyco-dir-backcompat", type=normalized_filepath, help="Path to directory with LyCORIS networks (for backwards compatibility; can also use --lyco-dir).", default=os.path.join(paths.models_path, 'LyCORIS'))
@@ -1,101 +1,90 @@
import re

import gradio as gr
from fastapi import FastAPI

import network
import networks
import lora  # noqa:F401
-import lora_patches
import extra_networks_lora
import ui_extra_networks_lora
from modules import script_callbacks, ui_extra_networks, extra_networks, shared


-def unload():
-    networks.originals.undo()
-
-
def before_ui():
    ui_extra_networks.register_page(ui_extra_networks_lora.ExtraNetworksPageLora())

    networks.extra_network_lora = extra_networks_lora.ExtraNetworkLora()
    extra_networks.register_extra_network(networks.extra_network_lora)


-networks.originals = lora_patches.LoraPatches()
-
-script_callbacks.on_model_loaded(networks.assign_network_names_to_compvis_modules)
-script_callbacks.on_script_unloaded(unload)
script_callbacks.on_before_ui(before_ui)
script_callbacks.on_infotext_pasted(networks.infotext_pasted)


shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
    "sd_lora": shared.OptionInfo("None", "Add network to prompt", gr.Dropdown, lambda: {"choices": ["None", *networks.available_networks]}, refresh=networks.list_available_networks),
    "lora_preferred_name": shared.OptionInfo("Alias from file", "When adding to prompt, refer to Lora by", gr.Radio, {"choices": ["Alias from file", "Filename"]}),
    "lora_add_hashes_to_infotext": shared.OptionInfo(True, "Add Lora hashes to infotext"),
    "lora_bundled_ti_to_infotext": shared.OptionInfo(True, "Add Lora name as TI hashes for bundled Textual Inversion").info('"Add Textual Inversion hashes to infotext" needs to be enabled'),
    "lora_filter_disabled": shared.OptionInfo(True, "Always show all networks on the Lora page").info("otherwise, those detected as for incompatible version of Stable Diffusion will be hidden"),
    "lora_hide_unknown_for_versions": shared.OptionInfo([], "Hide networks of unknown versions for model versions", gr.CheckboxGroup, {"choices": ["SD1", "SD2", "SDXL"]}),
    "lora_in_memory_limit": shared.OptionInfo(0, "Number of Lora networks to keep cached in memory", gr.Number, {"precision": 0}),
    "lora_not_found_warning_console": shared.OptionInfo(False, "Lora not found warning in console"),
    "lora_not_found_gradio_warning": shared.OptionInfo(False, "Lora not found warning popup in webui"),
}))


shared.options_templates.update(shared.options_section(('compatibility', "Compatibility"), {
    "lora_functional": shared.OptionInfo(False, "Lora/Networks: use old method that takes longer when you have multiple Loras active and produces same results as kohya-ss/sd-webui-additional-networks extension"),
}))


def create_lora_json(obj: network.NetworkOnDisk):
    return {
        "name": obj.name,
        "alias": obj.alias,
        "path": obj.filename,
        "metadata": obj.metadata,
    }


def api_networks(_: gr.Blocks, app: FastAPI):
    @app.get("/sdapi/v1/loras")
    async def get_loras():
        return [create_lora_json(obj) for obj in networks.available_networks.values()]

    @app.post("/sdapi/v1/refresh-loras")
    async def refresh_loras():
        return networks.list_available_networks()


script_callbacks.on_app_started(api_networks)

re_lora = re.compile("<lora:([^:]+):")


def infotext_pasted(infotext, d):
    hashes = d.get("Lora hashes")
    if not hashes:
        return

    hashes = [x.strip().split(':', 1) for x in hashes.split(",")]
    hashes = {x[0].strip().replace(",", ""): x[1].strip() for x in hashes}

    def network_replacement(m):
        alias = m.group(1)
        shorthash = hashes.get(alias)
        if shorthash is None:
            return m.group(0)

        network_on_disk = networks.available_network_hash_lookup.get(shorthash)
        if network_on_disk is None:
            return m.group(0)

        return f'<lora:{network_on_disk.get_alias()}:'

    d["Prompt"] = re.sub(re_lora, network_replacement, d["Prompt"])


script_callbacks.on_infotext_pasted(infotext_pasted)

-shared.opts.onchange("lora_in_memory_limit", networks.purge_networks_from_memory)
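Note: `infotext_pasted` above restores renamed LoRAs when pasting generation parameters, by resolving the short hash recorded in "Lora hashes" back to whatever the file is called locally. A self-contained illustration (the lookup tables are made up):

import re

re_lora = re.compile("<lora:([^:]+):")

pasted_hashes = {"detail": "a1b2c3d4e5f6"}           # parsed from the "Lora hashes" infotext field
hash_lookup = {"a1b2c3d4e5f6": "detail_tweaker_v2"}  # stands in for available_network_hash_lookup


def network_replacement(m):
    name = hash_lookup.get(pasted_hashes.get(m.group(1)))
    return f"<lora:{name}:" if name else m.group(0)


prompt = "masterpiece <lora:detail:0.7>"
assert re.sub(re_lora, network_replacement, prompt) == "masterpiece <lora:detail_tweaker_v2:0.7>"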
@@ -1,226 +1,226 @@
import datetime
import html
import random

import gradio as gr
import re

from modules import ui_extra_networks_user_metadata


def is_non_comma_tagset(tags):
    average_tag_length = sum(len(x) for x in tags.keys()) / len(tags)

    return average_tag_length >= 16


re_word = re.compile(r"[-_\w']+")
re_comma = re.compile(r" *, *")


def build_tags(metadata):
    tags = {}

    ss_tag_frequency = metadata.get("ss_tag_frequency", {})
    if ss_tag_frequency is not None and hasattr(ss_tag_frequency, 'items'):
        for _, tags_dict in ss_tag_frequency.items():
            for tag, tag_count in tags_dict.items():
                tag = tag.strip()
                tags[tag] = tags.get(tag, 0) + int(tag_count)

    if tags and is_non_comma_tagset(tags):
        new_tags = {}

        for text, text_count in tags.items():
            for word in re.findall(re_word, text):
                if len(word) < 3:
                    continue

                new_tags[word] = new_tags.get(word, 0) + text_count

        tags = new_tags

    ordered_tags = sorted(tags.keys(), key=tags.get, reverse=True)

    return [(tag, tags[tag]) for tag in ordered_tags]


class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor):
    def __init__(self, ui, tabname, page):
        super().__init__(ui, tabname, page)

        self.select_sd_version = None

        self.taginfo = None
        self.edit_activation_text = None
        self.slider_preferred_weight = None
        self.edit_notes = None

    def save_lora_user_metadata(self, name, desc, sd_version, activation_text, preferred_weight, negative_text, notes):
        user_metadata = self.get_user_metadata(name)
        user_metadata["description"] = desc
        user_metadata["sd version"] = sd_version
        user_metadata["activation text"] = activation_text
        user_metadata["preferred weight"] = preferred_weight
        user_metadata["negative text"] = negative_text
        user_metadata["notes"] = notes

        self.write_user_metadata(name, user_metadata)

    def get_metadata_table(self, name):
        table = super().get_metadata_table(name)
        item = self.page.items.get(name, {})
        metadata = item.get("metadata") or {}

        keys = {
            'ss_output_name': "Output name:",
            'ss_sd_model_name': "Model:",
            'ss_clip_skip': "Clip skip:",
            'ss_network_module': "Kohya module:",
        }

        for key, label in keys.items():
            value = metadata.get(key, None)
            if value is not None and str(value) != "None":
                table.append((label, html.escape(value)))

        ss_training_started_at = metadata.get('ss_training_started_at')
        if ss_training_started_at:
            table.append(("Date trained:", datetime.datetime.utcfromtimestamp(float(ss_training_started_at)).strftime('%Y-%m-%d %H:%M')))

        ss_bucket_info = metadata.get("ss_bucket_info")
        if ss_bucket_info and "buckets" in ss_bucket_info:
            resolutions = {}
            for _, bucket in ss_bucket_info["buckets"].items():
                resolution = bucket["resolution"]
                resolution = f'{resolution[1]}x{resolution[0]}'

                resolutions[resolution] = resolutions.get(resolution, 0) + int(bucket["count"])

            resolutions_list = sorted(resolutions.keys(), key=resolutions.get, reverse=True)
            resolutions_text = html.escape(", ".join(resolutions_list[0:4]))
            if len(resolutions) > 4:
                resolutions_text += ", ..."
                resolutions_text = f"<span title='{html.escape(', '.join(resolutions_list))}'>{resolutions_text}</span>"

            table.append(('Resolutions:' if len(resolutions_list) > 1 else 'Resolution:', resolutions_text))

        image_count = 0
        for _, params in metadata.get("ss_dataset_dirs", {}).items():
            image_count += int(params.get("img_count", 0))

        if image_count:
            table.append(("Dataset size:", image_count))

        return table

    def put_values_into_components(self, name):
        user_metadata = self.get_user_metadata(name)
        values = super().put_values_into_components(name)

        item = self.page.items.get(name, {})
        metadata = item.get("metadata") or {}

        tags = build_tags(metadata)
        gradio_tags = [(tag, str(count)) for tag, count in tags[0:24]]

        return [
            *values[0:5],
            item.get("sd_version", "Unknown"),
            gr.HighlightedText.update(value=gradio_tags, visible=True if tags else False),
            user_metadata.get('activation text', ''),
            float(user_metadata.get('preferred weight', 0.0)),
            user_metadata.get('negative text', ''),
            gr.update(visible=True if tags else False),
            gr.update(value=self.generate_random_prompt_from_tags(tags), visible=True if tags else False),
        ]

    def generate_random_prompt(self, name):
        item = self.page.items.get(name, {})
        metadata = item.get("metadata") or {}
        tags = build_tags(metadata)

        return self.generate_random_prompt_from_tags(tags)

    def generate_random_prompt_from_tags(self, tags):
        max_count = None
        res = []
        for tag, count in tags:
            if not max_count:
                max_count = count

            v = random.random() * max_count
            if count > v:
                for x in "({[]})":
                    tag = tag.replace(x, '\\' + x)
                res.append(tag)

        return ", ".join(sorted(res))

    def create_extra_default_items_in_left_column(self):

        # this would be a lot better as gr.Radio but I can't make it work
        self.select_sd_version = gr.Dropdown(['SD1', 'SD2', 'SDXL', 'Unknown'], value='Unknown', label='Stable Diffusion version', interactive=True)

    def create_editor(self):
        self.create_default_editor_elems()

        self.taginfo = gr.HighlightedText(label="Training dataset tags")
        self.edit_activation_text = gr.Text(label='Activation text', info="Will be added to prompt along with Lora")
        self.slider_preferred_weight = gr.Slider(label='Preferred weight', info="Set to 0 to disable", minimum=0.0, maximum=2.0, step=0.01)
        self.edit_negative_text = gr.Text(label='Negative prompt', info="Will be added to negative prompts")
        with gr.Row() as row_random_prompt:
            with gr.Column(scale=8):
                random_prompt = gr.Textbox(label='Random prompt', lines=4, max_lines=4, interactive=False)

            with gr.Column(scale=1, min_width=120):
                generate_random_prompt = gr.Button('Generate', size="lg", scale=1)

        self.edit_notes = gr.TextArea(label='Notes', lines=4)

        generate_random_prompt.click(fn=self.generate_random_prompt, inputs=[self.edit_name_input], outputs=[random_prompt], show_progress=False)

        def select_tag(activation_text, evt: gr.SelectData):
            tag = evt.value[0]

            words = re.split(re_comma, activation_text)
            if tag in words:
                words = [x for x in words if x != tag and x.strip()]
                return ", ".join(words)

            return activation_text + ", " + tag if activation_text else tag

        self.taginfo.select(fn=select_tag, inputs=[self.edit_activation_text], outputs=[self.edit_activation_text], show_progress=False)

        self.create_default_buttons()

        viewed_components = [
            self.edit_name,
            self.edit_description,
            self.html_filedata,
            self.html_preview,
            self.edit_notes,
            self.select_sd_version,
            self.taginfo,
            self.edit_activation_text,
            self.slider_preferred_weight,
            self.edit_negative_text,
            row_random_prompt,
            random_prompt,
        ]

        self.button_edit\
            .click(fn=self.put_values_into_components, inputs=[self.edit_name_input], outputs=viewed_components)\
            .then(fn=lambda: gr.update(visible=True), inputs=[], outputs=[self.box])

        edited_components = [
            self.edit_description,
            self.select_sd_version,
            self.edit_activation_text,
            self.slider_preferred_weight,
            self.edit_negative_text,
            self.edit_notes,
        ]

        self.setup_save_handler(self.button_save, self.save_lora_user_metadata, edited_components)
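Note: `generate_random_prompt_from_tags` keeps each tag with probability proportional to its training frequency (count / max_count, with tags pre-sorted descending by `build_tags`). The sampling step in isolation, as a minimal sketch:

import random


def sample_tags(tags):
    # tags: list of (tag, count) sorted by count descending, as build_tags returns
    if not tags:
        return []
    max_count = tags[0][1]
    # a tag with count == max_count always survives; rarer tags survive proportionally less often
    return sorted(tag for tag, count in tags if count > random.random() * max_count)


print(", ".join(sample_tags([("1girl", 100), ("smile", 60), ("hat", 5)])))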
@@ -1,90 +1,69 @@
import os

import network
import networks

from modules import shared, ui_extra_networks
from modules.ui_extra_networks import quote_js
from ui_edit_user_metadata import LoraUserMetadataEditor


class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
    def __init__(self):
        super().__init__('Lora')

    def refresh(self):
        networks.list_available_networks()

    def create_item(self, name, index=None, enable_filter=True):
        lora_on_disk = networks.available_networks.get(name)
        if lora_on_disk is None:
            return

        path, ext = os.path.splitext(lora_on_disk.filename)

        alias = lora_on_disk.get_alias()

        search_terms = [self.search_terms_from_path(lora_on_disk.filename)]
        if lora_on_disk.hash:
            search_terms.append(lora_on_disk.hash)
        item = {
            "name": name,
            "filename": lora_on_disk.filename,
            "shorthash": lora_on_disk.shorthash,
            "preview": self.find_preview(path) or self.find_embedded_preview(path, name, lora_on_disk.metadata),
            "description": self.find_description(path),
            "search_terms": search_terms,
            "local_preview": f"{path}.{shared.opts.samples_format}",
            "metadata": lora_on_disk.metadata,
            "sort_keys": {'default': index, **self.get_sort_keys(lora_on_disk.filename)},
-            "sd_version": lora_on_disk.sd_version.name,
        }

        self.read_user_metadata(item)
        activation_text = item["user_metadata"].get("activation text")
        preferred_weight = item["user_metadata"].get("preferred weight", 0.0)
        item["prompt"] = quote_js(f"<lora:{alias}:") + " + " + (str(preferred_weight) if preferred_weight else "opts.extra_networks_default_multiplier") + " + " + quote_js(">")

        if activation_text:
            item["prompt"] += " + " + quote_js(" " + activation_text)

        negative_prompt = item["user_metadata"].get("negative text")
        item["negative_prompt"] = quote_js("")
        if negative_prompt:
            item["negative_prompt"] = quote_js('(' + negative_prompt + ':1)')

-        sd_version = item["user_metadata"].get("sd version")
-        if sd_version in network.SdVersion.__members__:
-            item["sd_version"] = sd_version
-            sd_version = network.SdVersion[sd_version]
-        else:
-            sd_version = lora_on_disk.sd_version
-
-        if shared.opts.lora_filter_disabled or not enable_filter or not shared.sd_model:
-            pass
-        elif sd_version == network.SdVersion.Unknown:
-            model_version = network.SdVersion.SDXL if shared.sd_model.is_sdxl else network.SdVersion.SD2 if shared.sd_model.is_sd2 else network.SdVersion.SD1
-            if model_version.name in shared.opts.lora_hide_unknown_for_versions:
-                return None
-        elif shared.sd_model.is_sdxl and sd_version != network.SdVersion.SDXL:
-            return None
-        elif shared.sd_model.is_sd2 and sd_version != network.SdVersion.SD2:
-            return None
-        elif shared.sd_model.is_sd1 and sd_version != network.SdVersion.SD1:
-            return None
-
        return item

    def list_items(self):
        # instantiate a list to protect against concurrent modification
        names = list(networks.available_networks)
        for index, name in enumerate(names):
            item = self.create_item(name, index)
            if item is not None:
                yield item

    def allowed_directories_for_previews(self):
        return [shared.cmd_opts.lora_dir]

    def create_user_metadata_editor(self, ui, tabname):
        return LoraUserMetadataEditor(ui, tabname, self)
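Note: `item["prompt"]` above is not a prompt string but a JavaScript expression that the extra-networks page evaluates in the browser; `quote_js` emits JS string literals and the bare `+` pieces concatenate at click time. A rough sketch of what gets assembled (a `json.dumps`-based `quote_js` is only an approximation of the real helper):

import json


def quote_js(s):  # approximation of modules.ui_extra_networks.quote_js
    return json.dumps(s)


alias, preferred_weight = "detail_tweaker", 0.7
prompt = quote_js(f"<lora:{alias}:") + " + " + (str(preferred_weight) if preferred_weight else "opts.extra_networks_default_multiplier") + " + " + quote_js(">")
print(prompt)  # "<lora:detail_tweaker:" + 0.7 + ">"  (evaluated client-side)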
@@ -1,226 +1 @@
|
|||||||
# 1st edit by https://github.com/comfyanonymous/ComfyUI
|
from backend.patcher.lora import *
|
||||||
# 2nd edit by Forge Official
|
|
||||||
|
|
||||||
|
|
||||||
import ldm_patched.modules.utils
|
|
||||||
|
|
||||||
LORA_CLIP_MAP = {
|
|
||||||
"mlp.fc1": "mlp_fc1",
|
|
||||||
"mlp.fc2": "mlp_fc2",
|
|
||||||
"self_attn.k_proj": "self_attn_k_proj",
|
|
||||||
"self_attn.q_proj": "self_attn_q_proj",
|
|
||||||
"self_attn.v_proj": "self_attn_v_proj",
|
|
||||||
"self_attn.out_proj": "self_attn_out_proj",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def load_lora(lora, to_load):
|
|
||||||
patch_dict = {}
|
|
||||||
loaded_keys = set()
|
|
||||||
for x in to_load:
|
|
||||||
alpha_name = "{}.alpha".format(x)
|
|
||||||
alpha = None
|
|
||||||
if alpha_name in lora.keys():
|
|
||||||
alpha = lora[alpha_name].item()
|
|
||||||
loaded_keys.add(alpha_name)
|
|
||||||
|
|
||||||
regular_lora = "{}.lora_up.weight".format(x)
|
|
||||||
diffusers_lora = "{}_lora.up.weight".format(x)
|
|
||||||
transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
|
|
||||||
A_name = None
|
|
||||||
|
|
||||||
if regular_lora in lora.keys():
|
|
||||||
A_name = regular_lora
|
|
||||||
B_name = "{}.lora_down.weight".format(x)
|
|
||||||
mid_name = "{}.lora_mid.weight".format(x)
|
|
||||||
elif diffusers_lora in lora.keys():
|
|
||||||
A_name = diffusers_lora
|
|
||||||
B_name = "{}_lora.down.weight".format(x)
|
|
||||||
mid_name = None
|
|
||||||
elif transformers_lora in lora.keys():
|
|
||||||
A_name = transformers_lora
|
|
||||||
B_name ="{}.lora_linear_layer.down.weight".format(x)
|
|
||||||
mid_name = None
|
|
||||||
|
|
||||||
if A_name is not None:
|
|
||||||
mid = None
|
|
||||||
if mid_name is not None and mid_name in lora.keys():
|
|
||||||
mid = lora[mid_name]
|
|
||||||
loaded_keys.add(mid_name)
|
|
||||||
patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid))
|
|
||||||
loaded_keys.add(A_name)
|
|
||||||
loaded_keys.add(B_name)
|
|
||||||
|
|
||||||
|
|
||||||
######## loha
|
|
||||||
hada_w1_a_name = "{}.hada_w1_a".format(x)
|
|
||||||
hada_w1_b_name = "{}.hada_w1_b".format(x)
|
|
||||||
hada_w2_a_name = "{}.hada_w2_a".format(x)
|
|
||||||
hada_w2_b_name = "{}.hada_w2_b".format(x)
|
|
||||||
hada_t1_name = "{}.hada_t1".format(x)
|
|
||||||
hada_t2_name = "{}.hada_t2".format(x)
|
|
||||||
if hada_w1_a_name in lora.keys():
|
|
||||||
hada_t1 = None
|
|
||||||
hada_t2 = None
|
|
||||||
if hada_t1_name in lora.keys():
|
|
||||||
hada_t1 = lora[hada_t1_name]
|
|
||||||
hada_t2 = lora[hada_t2_name]
|
|
||||||
loaded_keys.add(hada_t1_name)
|
|
||||||
loaded_keys.add(hada_t2_name)
|
|
||||||
|
|
||||||
patch_dict[to_load[x]] = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2))
|
|
||||||
loaded_keys.add(hada_w1_a_name)
|
|
||||||
loaded_keys.add(hada_w1_b_name)
|
|
||||||
loaded_keys.add(hada_w2_a_name)
|
|
||||||
loaded_keys.add(hada_w2_b_name)
|
|
||||||
|
|
||||||
|
|
||||||
######## lokr
|
|
||||||
lokr_w1_name = "{}.lokr_w1".format(x)
|
|
||||||
lokr_w2_name = "{}.lokr_w2".format(x)
|
|
||||||
lokr_w1_a_name = "{}.lokr_w1_a".format(x)
|
|
||||||
lokr_w1_b_name = "{}.lokr_w1_b".format(x)
|
|
||||||
lokr_t2_name = "{}.lokr_t2".format(x)
|
|
||||||
lokr_w2_a_name = "{}.lokr_w2_a".format(x)
|
|
||||||
lokr_w2_b_name = "{}.lokr_w2_b".format(x)
|
|
||||||
|
|
||||||
lokr_w1 = None
|
|
||||||
if lokr_w1_name in lora.keys():
|
|
||||||
lokr_w1 = lora[lokr_w1_name]
|
|
||||||
loaded_keys.add(lokr_w1_name)
|
|
||||||
|
|
||||||
lokr_w2 = None
|
|
||||||
if lokr_w2_name in lora.keys():
|
|
||||||
lokr_w2 = lora[lokr_w2_name]
|
|
||||||
loaded_keys.add(lokr_w2_name)
|
|
||||||
|
|
||||||
lokr_w1_a = None
|
|
||||||
if lokr_w1_a_name in lora.keys():
|
|
||||||
lokr_w1_a = lora[lokr_w1_a_name]
|
|
||||||
loaded_keys.add(lokr_w1_a_name)
|
|
||||||
|
|
||||||
lokr_w1_b = None
|
|
||||||
if lokr_w1_b_name in lora.keys():
|
|
||||||
lokr_w1_b = lora[lokr_w1_b_name]
|
|
||||||
loaded_keys.add(lokr_w1_b_name)
|
|
||||||
|
|
||||||
lokr_w2_a = None
|
|
||||||
if lokr_w2_a_name in lora.keys():
|
|
||||||
lokr_w2_a = lora[lokr_w2_a_name]
|
|
||||||
loaded_keys.add(lokr_w2_a_name)
|
|
||||||
|
|
||||||
lokr_w2_b = None
|
|
||||||
if lokr_w2_b_name in lora.keys():
|
|
||||||
lokr_w2_b = lora[lokr_w2_b_name]
|
|
||||||
loaded_keys.add(lokr_w2_b_name)
|
|
||||||
|
|
||||||
lokr_t2 = None
|
|
||||||
if lokr_t2_name in lora.keys():
|
|
||||||
lokr_t2 = lora[lokr_t2_name]
|
|
||||||
loaded_keys.add(lokr_t2_name)
|
|
||||||
|
|
||||||
if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
|
|
||||||
patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2))
|
|
||||||
|
|
||||||
#glora
|
|
||||||
a1_name = "{}.a1.weight".format(x)
|
|
||||||
a2_name = "{}.a2.weight".format(x)
|
|
||||||
b1_name = "{}.b1.weight".format(x)
|
|
||||||
b2_name = "{}.b2.weight".format(x)
|
|
||||||
if a1_name in lora:
|
|
||||||
patch_dict[to_load[x]] = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha))
|
|
||||||
loaded_keys.add(a1_name)
|
|
||||||
loaded_keys.add(a2_name)
|
|
||||||
loaded_keys.add(b1_name)
|
|
||||||
loaded_keys.add(b2_name)
|
|
||||||
|
|
||||||
w_norm_name = "{}.w_norm".format(x)
|
|
||||||
b_norm_name = "{}.b_norm".format(x)
|
|
||||||
w_norm = lora.get(w_norm_name, None)
|
|
||||||
b_norm = lora.get(b_norm_name, None)
|
|
||||||
|
|
||||||
if w_norm is not None:
|
|
||||||
loaded_keys.add(w_norm_name)
|
|
||||||
patch_dict[to_load[x]] = ("diff", (w_norm,))
|
|
||||||
if b_norm is not None:
|
|
||||||
loaded_keys.add(b_norm_name)
|
|
||||||
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (b_norm,))
|
|
||||||
|
|
||||||
diff_name = "{}.diff".format(x)
|
|
||||||
diff_weight = lora.get(diff_name, None)
|
|
||||||
if diff_weight is not None:
|
|
||||||
patch_dict[to_load[x]] = ("diff", (diff_weight,))
|
|
||||||
loaded_keys.add(diff_name)
|
|
||||||
|
|
||||||
diff_bias_name = "{}.diff_b".format(x)
|
|
||||||
diff_bias = lora.get(diff_bias_name, None)
|
|
||||||
if diff_bias is not None:
|
|
||||||
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (diff_bias,))
|
|
||||||
loaded_keys.add(diff_bias_name)
|
|
||||||
|
|
||||||
remaining_dict = {x: y for x, y in lora.items() if x not in loaded_keys}
|
|
||||||
return patch_dict, remaining_dict
|
|
||||||
|
|
||||||


def model_lora_keys_clip(model, key_map={}):
    sdk = model.state_dict().keys()

    text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
    clip_l_present = False
    for b in range(32):  # TODO: clean up
        for c in LORA_CLIP_MAP:
            k = "clip_h.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
            if k in sdk:
                lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c])
                key_map[lora_key] = k
                lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c])
                key_map[lora_key] = k
                lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c)  # diffusers lora
                key_map[lora_key] = k

            k = "clip_l.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
            if k in sdk:
                lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c])
                key_map[lora_key] = k
                lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c])  # SDXL base
                key_map[lora_key] = k
                clip_l_present = True
                lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c)  # diffusers lora
                key_map[lora_key] = k

            k = "clip_g.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
            if k in sdk:
                if clip_l_present:
                    lora_key = "lora_te2_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c])  # SDXL base
                    key_map[lora_key] = k
                    lora_key = "text_encoder_2.text_model.encoder.layers.{}.{}".format(b, c)  # diffusers lora
                    key_map[lora_key] = k
                else:
                    lora_key = "lora_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c])  # TODO: test if this is correct for SDXL-Refiner
                    key_map[lora_key] = k
                    lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c)  # diffusers lora
                    key_map[lora_key] = k

    return key_map
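
# Illustrative sketch (hypothetical values, not in the original file): assuming
# LORA_CLIP_MAP flattens "self_attn.q_proj" to "self_attn_q_proj", a model
# exposing clip_l gains entries such as:
#
#     key_map = model_lora_keys_clip(clip_object)
#     # "lora_te1_text_model_encoder_layers_0_self_attn_q_proj"
#     #   -> "clip_l.transformer.text_model.encoder.layers.0.self_attn.q_proj.weight"
#     # "text_encoder.text_model.encoder.layers.0.self_attn.q_proj"  (diffusers)
#     #   -> same target key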


def model_lora_keys_unet(model, key_map={}):
    sdk = model.state_dict().keys()

    for k in sdk:
        if k.startswith("diffusion_model.") and k.endswith(".weight"):
            key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
            key_map["lora_unet_{}".format(key_lora)] = k

    diffusers_keys = ldm_patched.modules.utils.unet_to_diffusers(model.diffusion_model.legacy_config)
    for k in diffusers_keys:
        if k.endswith(".weight"):
            unet_key = "diffusion_model.{}".format(diffusers_keys[k])
            key_lora = k[:-len(".weight")].replace(".", "_")
            key_map["lora_unet_{}".format(key_lora)] = unet_key

            diffusers_lora_prefix = ["", "unet."]
            for p in diffusers_lora_prefix:
                diffusers_lora_key = "{}{}".format(p, k[:-len(".weight")].replace(".to_", ".processor.to_"))
                if diffusers_lora_key.endswith(".to_out.0"):
                    diffusers_lora_key = diffusers_lora_key[:-2]
                key_map[diffusers_lora_key] = unet_key
    return key_map
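
# Illustrative sketch (not in the original file): a native key such as
# "diffusion_model.input_blocks.1.1.proj_in.weight" becomes reachable as
# "lora_unet_input_blocks_1_1_proj_in", while the unet_to_diffusers()
# translation adds diffusers-style aliases such as
# "lora_unet_down_blocks_0_attentions_0_proj_in" (plus ".processor.to_*"
# attention-processor variants) pointing at the same tensor.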

@@ -1,386 +1 @@
# 1st edit by https://github.com/comfyanonymous/ComfyUI
from backend.patcher.base import *
# 2nd edit by Forge Official

import torch
import copy
import inspect

import ldm_patched.modules.utils
import ldm_patched.modules.model_management


extra_weight_calculators = {}


class ModelPatcher:
    def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False):
        self.size = size
        self.model = model
        self.patches = {}
        self.backup = {}
        self.object_patches = {}
        self.object_patches_backup = {}
        self.model_options = {"transformer_options": {}}
        self.model_size()
        self.load_device = load_device
        self.offload_device = offload_device
        if current_device is None:
            self.current_device = self.offload_device
        else:
            self.current_device = current_device

        self.weight_inplace_update = weight_inplace_update
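
    # Usage sketch (illustrative, not part of the original class): the patcher
    # wraps a module without copying its weights; device choices here are
    # assumptions for the example.
    #
    #     patcher = ModelPatcher(unet, load_device=torch.device("cuda"),
    #                            offload_device=torch.device("cpu"))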

    def model_size(self):
        if self.size > 0:
            return self.size
        model_sd = self.model.state_dict()
        self.size = ldm_patched.modules.model_management.module_size(self.model)
        self.model_keys = set(model_sd.keys())
        return self.size

    def clone(self):
        n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device, weight_inplace_update=self.weight_inplace_update)
        n.patches = {}
        for k in self.patches:
            n.patches[k] = self.patches[k][:]

        n.object_patches = self.object_patches.copy()
        n.model_options = copy.deepcopy(self.model_options)
        n.model_keys = self.model_keys
        return n

    def is_clone(self, other):
        if hasattr(other, 'model') and self.model is other.model:
            return True
        return False
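
    # Note (illustrative): clone() shares the underlying module but copies each
    # per-key patch list, so patches added to the clone do not leak back:
    #
    #     n = patcher.clone()
    #     n.add_patches(patch_dict, strength_patch=0.5)  # `patcher` unaffected
    #     assert patcher.is_clone(n)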

    def memory_required(self, input_shape):
        return self.model.memory_required(input_shape=input_shape)

    def set_model_sampler_cfg_function(self, sampler_cfg_function, disable_cfg1_optimization=False):
        if len(inspect.signature(sampler_cfg_function).parameters) == 3:
            self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"])  # Old way
        else:
            self.model_options["sampler_cfg_function"] = sampler_cfg_function
        if disable_cfg1_optimization:
            self.model_options["disable_cfg1_optimization"] = True

    def set_model_sampler_post_cfg_function(self, post_cfg_function, disable_cfg1_optimization=False):
        self.model_options["sampler_post_cfg_function"] = self.model_options.get("sampler_post_cfg_function", []) + [post_cfg_function]
        if disable_cfg1_optimization:
            self.model_options["disable_cfg1_optimization"] = True

    def set_model_unet_function_wrapper(self, unet_wrapper_function):
        self.model_options["model_function_wrapper"] = unet_wrapper_function

    def set_model_vae_encode_wrapper(self, wrapper_function):
        self.model_options["model_vae_encode_wrapper"] = wrapper_function

    def set_model_vae_decode_wrapper(self, wrapper_function):
        self.model_options["model_vae_decode_wrapper"] = wrapper_function

    def set_model_vae_regulation(self, vae_regulation):
        self.model_options["model_vae_regulation"] = vae_regulation

    def set_model_patch(self, patch, name):
        to = self.model_options["transformer_options"]
        if "patches" not in to:
            to["patches"] = {}
        to["patches"][name] = to["patches"].get(name, []) + [patch]

    def set_model_patch_replace(self, patch, name, block_name, number, transformer_index=None):
        to = self.model_options["transformer_options"]
        if "patches_replace" not in to:
            to["patches_replace"] = {}
        if name not in to["patches_replace"]:
            to["patches_replace"][name] = {}
        if transformer_index is not None:
            block = (block_name, number, transformer_index)
        else:
            block = (block_name, number)
        to["patches_replace"][name][block] = patch
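
    # Sketch (illustrative, block/attention names assumed): replace-patches are
    # keyed by (block_name, number[, transformer_index]); the helpers below are
    # thin wrappers, e.g. swapping the cross-attention of input block 4:
    #
    #     patcher.set_model_attn2_replace(my_attn_function, "input", 4)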

    def set_model_attn1_patch(self, patch):
        self.set_model_patch(patch, "attn1_patch")

    def set_model_attn2_patch(self, patch):
        self.set_model_patch(patch, "attn2_patch")

    def set_model_attn1_replace(self, patch, block_name, number, transformer_index=None):
        self.set_model_patch_replace(patch, "attn1", block_name, number, transformer_index)

    def set_model_attn2_replace(self, patch, block_name, number, transformer_index=None):
        self.set_model_patch_replace(patch, "attn2", block_name, number, transformer_index)

    def set_model_attn1_output_patch(self, patch):
        self.set_model_patch(patch, "attn1_output_patch")

    def set_model_attn2_output_patch(self, patch):
        self.set_model_patch(patch, "attn2_output_patch")

    def set_model_input_block_patch(self, patch):
        self.set_model_patch(patch, "input_block_patch")

    def set_model_input_block_patch_after_skip(self, patch):
        self.set_model_patch(patch, "input_block_patch_after_skip")

    def set_model_output_block_patch(self, patch):
        self.set_model_patch(patch, "output_block_patch")

    def add_object_patch(self, name, obj):
        self.object_patches[name] = obj

    def model_patches_to(self, device):
        to = self.model_options["transformer_options"]
        if "patches" in to:
            patches = to["patches"]
            for name in patches:
                patch_list = patches[name]
                for i in range(len(patch_list)):
                    if hasattr(patch_list[i], "to"):
                        patch_list[i] = patch_list[i].to(device)
        if "patches_replace" in to:
            patches = to["patches_replace"]
            for name in patches:
                patch_list = patches[name]
                for k in patch_list:
                    if hasattr(patch_list[k], "to"):
                        patch_list[k] = patch_list[k].to(device)
        if "model_function_wrapper" in self.model_options:
            wrap_func = self.model_options["model_function_wrapper"]
            if hasattr(wrap_func, "to"):
                self.model_options["model_function_wrapper"] = wrap_func.to(device)

    def model_dtype(self):
        if hasattr(self.model, "get_dtype"):
            return self.model.get_dtype()

    def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
        p = set()
        for k in patches:
            if k in self.model_keys:
                p.add(k)
                current_patches = self.patches.get(k, [])
                current_patches.append((strength_patch, patches[k], strength_model))
                self.patches[k] = current_patches

        return list(p)
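
    # Sketch (illustrative): `patches` maps model keys to payloads such as the
    # ("lokr", ...) tuples built by the LoRA loader. Patches accumulate -- a
    # second call appends to each key's list rather than replacing it:
    #
    #     patched_keys = patcher.add_patches(patch_dict, strength_patch=0.8)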

    def get_key_patches(self, filter_prefix=None):
        ldm_patched.modules.model_management.unload_model_clones(self)
        model_sd = self.model_state_dict()
        p = {}
        for k in model_sd:
            if filter_prefix is not None:
                if not k.startswith(filter_prefix):
                    continue
            if k in self.patches:
                p[k] = [model_sd[k]] + self.patches[k]
            else:
                p[k] = (model_sd[k],)
        return p

    def model_state_dict(self, filter_prefix=None):
        sd = self.model.state_dict()
        keys = list(sd.keys())
        if filter_prefix is not None:
            for k in keys:
                if not k.startswith(filter_prefix):
                    sd.pop(k)
        return sd

    def patch_model(self, device_to=None, patch_weights=True):
        for k in self.object_patches:
            old = ldm_patched.modules.utils.get_attr(self.model, k)
            if k not in self.object_patches_backup:
                self.object_patches_backup[k] = old
            ldm_patched.modules.utils.set_attr_raw(self.model, k, self.object_patches[k])

        if patch_weights:
            model_sd = self.model_state_dict()
            for key in self.patches:
                if key not in model_sd:
                    print("could not patch. key doesn't exist in model:", key)
                    continue

                weight = model_sd[key]

                inplace_update = self.weight_inplace_update

                if key not in self.backup:
                    self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update)

                if device_to is not None:
                    temp_weight = ldm_patched.modules.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
                else:
                    temp_weight = weight.to(torch.float32, copy=True)
                out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
                if inplace_update:
                    ldm_patched.modules.utils.copy_to_param(self.model, key, out_weight)
                else:
                    ldm_patched.modules.utils.set_attr(self.model, key, out_weight)
                del temp_weight

            if device_to is not None:
                self.model.to(device_to)
                self.current_device = device_to

        return self.model
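
    # Note (illustrative): patch_model() snapshots every weight it touches onto
    # the offload device before writing the patched tensor, which is what makes
    # unpatch_model() an exact restore rather than a recomputation:
    #
    #     patcher.patch_model(device_to=torch.device("cuda"))
    #     # ... run inference ...
    #     patcher.unpatch_model(device_to=torch.device("cpu"))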

    def calculate_weight(self, patches, weight, key):
        for p in patches:
            alpha = p[0]
            v = p[1]
            strength_model = p[2]

            if strength_model != 1.0:
                weight *= strength_model

            if isinstance(v, list):
                v = (self.calculate_weight(v[1:], v[0].clone(), key), )

            if len(v) == 1:
                patch_type = "diff"
            elif len(v) == 2:
                patch_type = v[0]
                v = v[1]

            if patch_type == "diff":
                w1 = v[0]
                if alpha != 0.0:
                    if w1.shape != weight.shape:
                        if w1.ndim == weight.ndim == 4:
                            new_shape = [max(n, m) for n, m in zip(weight.shape, w1.shape)]
                            print(f'Merged with {key} channel changed to {new_shape}')
                            new_diff = alpha * ldm_patched.modules.model_management.cast_to_device(w1, weight.device, weight.dtype)
                            new_weight = torch.zeros(size=new_shape).to(weight)
                            new_weight[:weight.shape[0], :weight.shape[1], :weight.shape[2], :weight.shape[3]] = weight
                            new_weight[:new_diff.shape[0], :new_diff.shape[1], :new_diff.shape[2], :new_diff.shape[3]] += new_diff
                            new_weight = new_weight.contiguous().clone()
                            weight = new_weight
                        else:
                            print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
                    else:
                        weight += alpha * ldm_patched.modules.model_management.cast_to_device(w1, weight.device, weight.dtype)
            elif patch_type == "lora":  # lora/locon
                mat1 = ldm_patched.modules.model_management.cast_to_device(v[0], weight.device, torch.float32)
                mat2 = ldm_patched.modules.model_management.cast_to_device(v[1], weight.device, torch.float32)
                if v[2] is not None:
                    alpha *= v[2] / mat2.shape[0]
                if v[3] is not None:
                    # locon mid weights, hopefully the math is fine because I didn't properly test it
                    mat3 = ldm_patched.modules.model_management.cast_to_device(v[3], weight.device, torch.float32)
                    final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
                    mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
                try:
                    weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype)
                except Exception as e:
                    print("ERROR", key, e)
            elif patch_type == "lokr":
                w1 = v[0]
                w2 = v[1]
                w1_a = v[3]
                w1_b = v[4]
                w2_a = v[5]
                w2_b = v[6]
                t2 = v[7]
                dim = None

                if w1 is None:
                    dim = w1_b.shape[0]
                    w1 = torch.mm(ldm_patched.modules.model_management.cast_to_device(w1_a, weight.device, torch.float32),
                                  ldm_patched.modules.model_management.cast_to_device(w1_b, weight.device, torch.float32))
                else:
                    w1 = ldm_patched.modules.model_management.cast_to_device(w1, weight.device, torch.float32)

                if w2 is None:
                    dim = w2_b.shape[0]
                    if t2 is None:
                        w2 = torch.mm(ldm_patched.modules.model_management.cast_to_device(w2_a, weight.device, torch.float32),
                                      ldm_patched.modules.model_management.cast_to_device(w2_b, weight.device, torch.float32))
                    else:
                        w2 = torch.einsum('i j k l, j r, i p -> p r k l',
                                          ldm_patched.modules.model_management.cast_to_device(t2, weight.device, torch.float32),
                                          ldm_patched.modules.model_management.cast_to_device(w2_b, weight.device, torch.float32),
                                          ldm_patched.modules.model_management.cast_to_device(w2_a, weight.device, torch.float32))
                else:
                    w2 = ldm_patched.modules.model_management.cast_to_device(w2, weight.device, torch.float32)

                if len(w2.shape) == 4:
                    w1 = w1.unsqueeze(2).unsqueeze(2)
                if v[2] is not None and dim is not None:
                    alpha *= v[2] / dim

                try:
                    weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype)
                except Exception as e:
                    print("ERROR", key, e)
            elif patch_type == "loha":
                w1a = v[0]
                w1b = v[1]
                if v[2] is not None:
                    alpha *= v[2] / w1b.shape[0]
                w2a = v[3]
                w2b = v[4]
                if v[5] is not None:  # cp decomposition
                    t1 = v[5]
                    t2 = v[6]
                    m1 = torch.einsum('i j k l, j r, i p -> p r k l',
                                      ldm_patched.modules.model_management.cast_to_device(t1, weight.device, torch.float32),
                                      ldm_patched.modules.model_management.cast_to_device(w1b, weight.device, torch.float32),
                                      ldm_patched.modules.model_management.cast_to_device(w1a, weight.device, torch.float32))

                    m2 = torch.einsum('i j k l, j r, i p -> p r k l',
                                      ldm_patched.modules.model_management.cast_to_device(t2, weight.device, torch.float32),
                                      ldm_patched.modules.model_management.cast_to_device(w2b, weight.device, torch.float32),
                                      ldm_patched.modules.model_management.cast_to_device(w2a, weight.device, torch.float32))
                else:
                    m1 = torch.mm(ldm_patched.modules.model_management.cast_to_device(w1a, weight.device, torch.float32),
                                  ldm_patched.modules.model_management.cast_to_device(w1b, weight.device, torch.float32))
                    m2 = torch.mm(ldm_patched.modules.model_management.cast_to_device(w2a, weight.device, torch.float32),
                                  ldm_patched.modules.model_management.cast_to_device(w2b, weight.device, torch.float32))

                try:
                    weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
                except Exception as e:
                    print("ERROR", key, e)
            elif patch_type == "glora":
                if v[4] is not None:
                    alpha *= v[4] / v[0].shape[0]

                a1 = ldm_patched.modules.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, torch.float32)
                a2 = ldm_patched.modules.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, torch.float32)
                b1 = ldm_patched.modules.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, torch.float32)
                b2 = ldm_patched.modules.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, torch.float32)

                weight += ((torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)) * alpha).reshape(weight.shape).type(weight.dtype)
            elif patch_type in extra_weight_calculators:
                weight = extra_weight_calculators[patch_type](weight, alpha, v)
            else:
                print("patch type not recognized", patch_type, key)

        return weight
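
    # Worked example (illustrative): for a plain "lora" patch the update above
    # is W' = W + alpha * (up @ down), reshaped back to W's shape (alpha is
    # further scaled by the stored network alpha over the rank when present).
    # A self-contained numeric check with made-up shapes:
    #
    #     W = torch.zeros(320, 768)
    #     up, down = torch.randn(320, 8), torch.randn(8, 768)
    #     W2 = W + 0.8 * torch.mm(up, down)  # rank-8 update, same shape as W
    #
    # "lokr" combines its factors with torch.kron instead, and "loha" uses the
    # Hadamard product (m1 * m2) of two low-rank reconstructions.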

    def unpatch_model(self, device_to=None):
        keys = list(self.backup.keys())

        if self.weight_inplace_update:
            for k in keys:
                ldm_patched.modules.utils.copy_to_param(self.model, k, self.backup[k])
        else:
            for k in keys:
                ldm_patched.modules.utils.set_attr(self.model, k, self.backup[k])

        self.backup = {}

        if device_to is not None:
            self.model.to(device_to)
            self.current_device = device_to

        keys = list(self.object_patches_backup.keys())
        for k in keys:
            ldm_patched.modules.utils.set_attr_raw(self.model, k, self.object_patches_backup[k])

        self.object_patches_backup = {}

@@ -25,7 +25,6 @@ class UnetPatcher(ModelPatcher):
        n.object_patches = self.object_patches.copy()
        n.model_options = copy.deepcopy(self.model_options)
        n.model_keys = self.model_keys
        n.controlnet_linked_list = self.controlnet_linked_list
        n.extra_preserved_memory_during_sampling = self.extra_preserved_memory_during_sampling
        n.extra_model_patchers_during_sampling = self.extra_model_patchers_during_sampling.copy()