Mirror of https://github.com/lllyasviel/stable-diffusion-webui-forge.git
Synced 2026-05-01 11:41:23 +00:00

rework lora and patching system, dora, etc. - backend rework is 60% finished
Also removed the webui's extremely annoying lora filter based on model versions.
134  backend/misc/diffusers_state_dict.py  Normal file
@@ -0,0 +1,134 @@
UNET_MAP_ATTENTIONS = {
    "proj_in.weight",
    "proj_in.bias",
    "proj_out.weight",
    "proj_out.bias",
    "norm.weight",
    "norm.bias",
}

TRANSFORMER_BLOCKS = {
    "norm1.weight",
    "norm1.bias",
    "norm2.weight",
    "norm2.bias",
    "norm3.weight",
    "norm3.bias",
    "attn1.to_q.weight",
    "attn1.to_k.weight",
    "attn1.to_v.weight",
    "attn1.to_out.0.weight",
    "attn1.to_out.0.bias",
    "attn2.to_q.weight",
    "attn2.to_k.weight",
    "attn2.to_v.weight",
    "attn2.to_out.0.weight",
    "attn2.to_out.0.bias",
    "ff.net.0.proj.weight",
    "ff.net.0.proj.bias",
    "ff.net.2.weight",
    "ff.net.2.bias",
}

UNET_MAP_RESNET = {
    "in_layers.2.weight": "conv1.weight",
    "in_layers.2.bias": "conv1.bias",
    "emb_layers.1.weight": "time_emb_proj.weight",
    "emb_layers.1.bias": "time_emb_proj.bias",
    "out_layers.3.weight": "conv2.weight",
    "out_layers.3.bias": "conv2.bias",
    "skip_connection.weight": "conv_shortcut.weight",
    "skip_connection.bias": "conv_shortcut.bias",
    "in_layers.0.weight": "norm1.weight",
    "in_layers.0.bias": "norm1.bias",
    "out_layers.0.weight": "norm2.weight",
    "out_layers.0.bias": "norm2.bias",
}

UNET_MAP_BASIC = {
    ("label_emb.0.0.weight", "class_embedding.linear_1.weight"),
    ("label_emb.0.0.bias", "class_embedding.linear_1.bias"),
    ("label_emb.0.2.weight", "class_embedding.linear_2.weight"),
    ("label_emb.0.2.bias", "class_embedding.linear_2.bias"),
    ("label_emb.0.0.weight", "add_embedding.linear_1.weight"),
    ("label_emb.0.0.bias", "add_embedding.linear_1.bias"),
    ("label_emb.0.2.weight", "add_embedding.linear_2.weight"),
    ("label_emb.0.2.bias", "add_embedding.linear_2.bias"),
    ("input_blocks.0.0.weight", "conv_in.weight"),
    ("input_blocks.0.0.bias", "conv_in.bias"),
    ("out.0.weight", "conv_norm_out.weight"),
    ("out.0.bias", "conv_norm_out.bias"),
    ("out.2.weight", "conv_out.weight"),
    ("out.2.bias", "conv_out.bias"),
    ("time_embed.0.weight", "time_embedding.linear_1.weight"),
    ("time_embed.0.bias", "time_embedding.linear_1.bias"),
    ("time_embed.2.weight", "time_embedding.linear_2.weight"),
    ("time_embed.2.bias", "time_embedding.linear_2.bias")
}


def unet_to_diffusers(unet_config):
    if "num_res_blocks" not in unet_config:
        return {}
    num_res_blocks = unet_config["num_res_blocks"]
    channel_mult = unet_config["channel_mult"]
    transformer_depth = unet_config["transformer_depth"][:]
    transformer_depth_output = unet_config["transformer_depth_output"][:]
    num_blocks = len(channel_mult)

    transformers_mid = unet_config.get("transformer_depth_middle", None)

    diffusers_unet_map = {}
    for x in range(num_blocks):
        n = 1 + (num_res_blocks[x] + 1) * x
        for i in range(num_res_blocks[x]):
            for b in UNET_MAP_RESNET:
                diffusers_unet_map["down_blocks.{}.resnets.{}.{}".format(x, i, UNET_MAP_RESNET[b])] = "input_blocks.{}.0.{}".format(n, b)
            num_transformers = transformer_depth.pop(0)
            if num_transformers > 0:
                for b in UNET_MAP_ATTENTIONS:
                    diffusers_unet_map["down_blocks.{}.attentions.{}.{}".format(x, i, b)] = "input_blocks.{}.1.{}".format(n, b)
                for t in range(num_transformers):
                    for b in TRANSFORMER_BLOCKS:
                        diffusers_unet_map["down_blocks.{}.attentions.{}.transformer_blocks.{}.{}".format(x, i, t, b)] = "input_blocks.{}.1.transformer_blocks.{}.{}".format(n, t, b)
            n += 1
        for k in ["weight", "bias"]:
            diffusers_unet_map["down_blocks.{}.downsamplers.0.conv.{}".format(x, k)] = "input_blocks.{}.0.op.{}".format(n, k)

    i = 0
    for b in UNET_MAP_ATTENTIONS:
        diffusers_unet_map["mid_block.attentions.{}.{}".format(i, b)] = "middle_block.1.{}".format(b)
    for t in range(transformers_mid):
        for b in TRANSFORMER_BLOCKS:
            diffusers_unet_map["mid_block.attentions.{}.transformer_blocks.{}.{}".format(i, t, b)] = "middle_block.1.transformer_blocks.{}.{}".format(t, b)

    for i, n in enumerate([0, 2]):
        for b in UNET_MAP_RESNET:
            diffusers_unet_map["mid_block.resnets.{}.{}".format(i, UNET_MAP_RESNET[b])] = "middle_block.{}.{}".format(n, b)

    num_res_blocks = list(reversed(num_res_blocks))
    for x in range(num_blocks):
        n = (num_res_blocks[x] + 1) * x
        l = num_res_blocks[x] + 1
        for i in range(l):
            c = 0
            for b in UNET_MAP_RESNET:
                diffusers_unet_map["up_blocks.{}.resnets.{}.{}".format(x, i, UNET_MAP_RESNET[b])] = "output_blocks.{}.0.{}".format(n, b)
            c += 1
            num_transformers = transformer_depth_output.pop()
            if num_transformers > 0:
                c += 1
                for b in UNET_MAP_ATTENTIONS:
                    diffusers_unet_map["up_blocks.{}.attentions.{}.{}".format(x, i, b)] = "output_blocks.{}.1.{}".format(n, b)
                for t in range(num_transformers):
                    for b in TRANSFORMER_BLOCKS:
                        diffusers_unet_map["up_blocks.{}.attentions.{}.transformer_blocks.{}.{}".format(x, i, t, b)] = "output_blocks.{}.1.transformer_blocks.{}.{}".format(n, t, b)
            if i == l - 1:
                for k in ["weight", "bias"]:
                    diffusers_unet_map["up_blocks.{}.upsamplers.0.conv.{}".format(x, k)] = "output_blocks.{}.{}.conv.{}".format(n, c, k)
            n += 1

    for k in UNET_MAP_BASIC:
        diffusers_unet_map[k[1]] = k[0]

    return diffusers_unet_map
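For orientation, the mapping above is keyed by diffusers-style names and points back at the internal ldm-style names. A minimal sketch of calling it; the config values below are illustrative placeholders, not taken from this commit:

from backend.misc.diffusers_state_dict import unet_to_diffusers

# Illustrative SD1.x-like UNet config; real configs come from the loaded model.
example_config = {
    "num_res_blocks": [2, 2, 2, 2],
    "channel_mult": [1, 2, 4, 4],
    "transformer_depth": [1, 1, 1, 1, 1, 1, 0, 0],
    "transformer_depth_output": [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
    "transformer_depth_middle": 1,
}

mapping = unet_to_diffusers(example_config)
print(mapping["conv_in.weight"])  # -> "input_blocks.0.0.weight"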
500  backend/patcher/base.py  Normal file
@@ -0,0 +1,500 @@
# Model Patching API template extracted from ComfyUI.
# The actual implementations of these APIs are Forge's own, written from scratch (after forge-v1.0.1),
# and may differ from ComfyUI to some degree.

import torch
import copy
import inspect

from backend import memory_management, utils

extra_weight_calculators = {}


def weight_decompose(dora_scale, weight, lora_diff, alpha, strength):
    dora_scale = memory_management.cast_to_device(dora_scale, weight.device, torch.float32)
    lora_diff *= alpha
    weight_calc = weight + lora_diff.type(weight.dtype)
    weight_norm = (
        weight_calc.transpose(0, 1)
        .reshape(weight_calc.shape[1], -1)
        .norm(dim=1, keepdim=True)
        .reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
        .transpose(0, 1)
    )

    weight_calc *= (dora_scale / weight_norm).type(weight.dtype)
    if strength != 1.0:
        weight_calc -= weight
        weight += strength * weight_calc
    else:
        weight[:] = weight_calc
    return weight
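For readers unfamiliar with DoRA: weight_decompose merges the LoRA delta into the weight and then rescales each input column so its magnitude follows the trained dora_scale. A standalone toy sketch of the same math (shapes are made up, not from this commit):

import torch

weight = torch.randn(8, 4)                 # (out_features, in_features)
lora_diff = 0.01 * torch.randn(8, 4)       # alpha-scaled LoRA delta
dora_scale = 0.5 + torch.rand(1, 4)        # learned per-column magnitude

merged = weight + lora_diff
# per-column norm, computed the same way as weight_norm above
col_norm = merged.transpose(0, 1).reshape(merged.shape[1], -1).norm(dim=1, keepdim=True).transpose(0, 1)
decomposed = merged * (dora_scale / col_norm)  # columns rescaled to follow dora_scale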
def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None):
    to = model_options["transformer_options"].copy()

    if "patches_replace" not in to:
        to["patches_replace"] = {}
    else:
        to["patches_replace"] = to["patches_replace"].copy()

    if name not in to["patches_replace"]:
        to["patches_replace"][name] = {}
    else:
        to["patches_replace"][name] = to["patches_replace"][name].copy()

    if transformer_index is not None:
        block = (block_name, number, transformer_index)
    else:
        block = (block_name, number)
    to["patches_replace"][name][block] = patch
    model_options["transformer_options"] = to
    return model_options


def set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=False):
    model_options["sampler_post_cfg_function"] = model_options.get("sampler_post_cfg_function", []) + [post_cfg_function]
    if disable_cfg1_optimization:
        model_options["disable_cfg1_optimization"] = True
    return model_options


def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_cfg1_optimization=False):
    model_options["sampler_pre_cfg_function"] = model_options.get("sampler_pre_cfg_function", []) + [pre_cfg_function]
    if disable_cfg1_optimization:
        model_options["disable_cfg1_optimization"] = True
    return model_options
class ModelPatcher:
    def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False):
        self.size = size
        self.model = model
        self.patches = {}
        self.backup = {}
        self.object_patches = {}
        self.object_patches_backup = {}
        self.model_options = {"transformer_options": {}}
        self.model_size()
        self.load_device = load_device
        self.offload_device = offload_device
        if current_device is None:
            self.current_device = self.offload_device
        else:
            self.current_device = current_device

        self.weight_inplace_update = weight_inplace_update

    def model_size(self):
        if self.size > 0:
            return self.size
        self.size = memory_management.module_size(self.model)
        return self.size

    def clone(self):
        n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device, weight_inplace_update=self.weight_inplace_update)
        n.patches = {}
        for k in self.patches:
            n.patches[k] = self.patches[k][:]

        n.object_patches = self.object_patches.copy()
        n.model_options = copy.deepcopy(self.model_options)
        return n

    def is_clone(self, other):
        if hasattr(other, 'model') and self.model is other.model:
            return True
        return False

    def memory_required(self, input_shape):
        return self.model.memory_required(input_shape=input_shape)

    def set_model_sampler_cfg_function(self, sampler_cfg_function, disable_cfg1_optimization=False):
        if len(inspect.signature(sampler_cfg_function).parameters) == 3:
            self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"])  # Old way
        else:
            self.model_options["sampler_cfg_function"] = sampler_cfg_function
        if disable_cfg1_optimization:
            self.model_options["disable_cfg1_optimization"] = True

    def set_model_sampler_post_cfg_function(self, post_cfg_function, disable_cfg1_optimization=False):
        self.model_options = set_model_options_post_cfg_function(self.model_options, post_cfg_function, disable_cfg1_optimization)

    def set_model_sampler_pre_cfg_function(self, pre_cfg_function, disable_cfg1_optimization=False):
        self.model_options = set_model_options_pre_cfg_function(self.model_options, pre_cfg_function, disable_cfg1_optimization)

    def set_model_unet_function_wrapper(self, unet_wrapper_function):
        self.model_options["model_function_wrapper"] = unet_wrapper_function

    def set_model_vae_encode_wrapper(self, wrapper_function):
        self.model_options["model_vae_encode_wrapper"] = wrapper_function

    def set_model_vae_decode_wrapper(self, wrapper_function):
        self.model_options["model_vae_decode_wrapper"] = wrapper_function

    def set_model_vae_regulation(self, vae_regulation):
        self.model_options["model_vae_regulation"] = vae_regulation

    def set_model_denoise_mask_function(self, denoise_mask_function):
        self.model_options["denoise_mask_function"] = denoise_mask_function

    def set_model_patch(self, patch, name):
        to = self.model_options["transformer_options"]
        if "patches" not in to:
            to["patches"] = {}
        to["patches"][name] = to["patches"].get(name, []) + [patch]

    def set_model_patch_replace(self, patch, name, block_name, number, transformer_index=None):
        self.model_options = set_model_options_patch_replace(self.model_options, patch, name, block_name, number, transformer_index=transformer_index)

    def set_model_attn1_patch(self, patch):
        self.set_model_patch(patch, "attn1_patch")

    def set_model_attn2_patch(self, patch):
        self.set_model_patch(patch, "attn2_patch")

    def set_model_attn1_replace(self, patch, block_name, number, transformer_index=None):
        self.set_model_patch_replace(patch, "attn1", block_name, number, transformer_index)

    def set_model_attn2_replace(self, patch, block_name, number, transformer_index=None):
        self.set_model_patch_replace(patch, "attn2", block_name, number, transformer_index)

    def set_model_attn1_output_patch(self, patch):
        self.set_model_patch(patch, "attn1_output_patch")

    def set_model_attn2_output_patch(self, patch):
        self.set_model_patch(patch, "attn2_output_patch")

    def set_model_input_block_patch(self, patch):
        self.set_model_patch(patch, "input_block_patch")

    def set_model_input_block_patch_after_skip(self, patch):
        self.set_model_patch(patch, "input_block_patch_after_skip")

    def set_model_output_block_patch(self, patch):
        self.set_model_patch(patch, "output_block_patch")

    def add_object_patch(self, name, obj):
        self.object_patches[name] = obj

    def get_model_object(self, name):
        if name in self.object_patches:
            return self.object_patches[name]
        else:
            if name in self.object_patches_backup:
                return self.object_patches_backup[name]
            else:
                return utils.get_attr(self.model, name)

    def model_patches_to(self, device):
        to = self.model_options["transformer_options"]
        if "patches" in to:
            patches = to["patches"]
            for name in patches:
                patch_list = patches[name]
                for i in range(len(patch_list)):
                    if hasattr(patch_list[i], "to"):
                        patch_list[i] = patch_list[i].to(device)
        if "patches_replace" in to:
            patches = to["patches_replace"]
            for name in patches:
                patch_list = patches[name]
                for k in patch_list:
                    if hasattr(patch_list[k], "to"):
                        patch_list[k] = patch_list[k].to(device)
        if "model_function_wrapper" in self.model_options:
            wrap_func = self.model_options["model_function_wrapper"]
            if hasattr(wrap_func, "to"):
                self.model_options["model_function_wrapper"] = wrap_func.to(device)

    def model_dtype(self):
        if hasattr(self.model, "get_dtype"):
            return self.model.get_dtype()

    def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
        p = set()
        model_sd = self.model.state_dict()
        for k in patches:
            offset = None
            function = None
            if isinstance(k, str):
                key = k
            else:
                offset = k[1]
                key = k[0]
                if len(k) > 2:
                    function = k[2]

            if key in model_sd:
                p.add(k)
                current_patches = self.patches.get(key, [])
                current_patches.append((strength_patch, patches[k], strength_model, offset, function))
                self.patches[key] = current_patches

        return list(p)

    def get_key_patches(self, filter_prefix=None):
        memory_management.unload_model_clones(self)
        model_sd = self.model_state_dict()
        p = {}
        for k in model_sd:
            if filter_prefix is not None:
                if not k.startswith(filter_prefix):
                    continue
            if k in self.patches:
                p[k] = [model_sd[k]] + self.patches[k]
            else:
                p[k] = (model_sd[k],)
        return p

    def model_state_dict(self, filter_prefix=None):
        sd = self.model.state_dict()
        keys = list(sd.keys())
        if filter_prefix is not None:
            for k in keys:
                if not k.startswith(filter_prefix):
                    sd.pop(k)
        return sd

    def patch_model(self, device_to=None, patch_weights=True):
        for k in self.object_patches:
            old = utils.get_attr(self.model, k)
            if k not in self.object_patches_backup:
                self.object_patches_backup[k] = old
            utils.set_attr_raw(self.model, k, self.object_patches[k])

        if patch_weights:
            model_sd = self.model_state_dict()
            for key in self.patches:
                if key not in model_sd:
                    print("could not patch. key doesn't exist in model:", key)
                    continue

                weight = model_sd[key]

                inplace_update = self.weight_inplace_update

                if key not in self.backup:
                    self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update)

                if device_to is not None:
                    temp_weight = memory_management.cast_to_device(weight, device_to, torch.float32, copy=True)
                else:
                    temp_weight = weight.to(torch.float32, copy=True)
                out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
                if inplace_update:
                    utils.copy_to_param(self.model, key, out_weight)
                else:
                    utils.set_attr(self.model, key, out_weight)
                del temp_weight

            if device_to is not None:
                self.model.to(device_to)
                self.current_device = device_to

        return self.model
    def calculate_weight(self, patches, weight, key):
        for p in patches:
            strength = p[0]
            v = p[1]
            strength_model = p[2]
            offset = p[3]
            function = p[4]
            if function is None:
                function = lambda a: a

            old_weight = None
            if offset is not None:
                old_weight = weight
                weight = weight.narrow(offset[0], offset[1], offset[2])

            if strength_model != 1.0:
                weight *= strength_model

            if isinstance(v, list):
                v = (self.calculate_weight(v[1:], v[0].clone(), key),)

            patch_type = ''

            if len(v) == 1:
                patch_type = "diff"
            elif len(v) == 2:
                patch_type = v[0]
                v = v[1]

            if patch_type == "diff":
                w1 = v[0]
                if strength != 0.0:
                    if w1.shape != weight.shape:
                        if w1.ndim == weight.ndim == 4:
                            new_shape = [max(n, m) for n, m in zip(weight.shape, w1.shape)]
                            print(f'Merged with {key} channel changed to {new_shape}')
                            new_diff = strength * memory_management.cast_to_device(w1, weight.device, weight.dtype)
                            new_weight = torch.zeros(size=new_shape).to(weight)
                            new_weight[:weight.shape[0], :weight.shape[1], :weight.shape[2], :weight.shape[3]] = weight
                            new_weight[:new_diff.shape[0], :new_diff.shape[1], :new_diff.shape[2], :new_diff.shape[3]] += new_diff
                            new_weight = new_weight.contiguous().clone()
                            weight = new_weight
                        else:
                            print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
                    else:
                        weight += strength * memory_management.cast_to_device(w1, weight.device, weight.dtype)
            elif patch_type == "lora":
                mat1 = memory_management.cast_to_device(v[0], weight.device, torch.float32)
                mat2 = memory_management.cast_to_device(v[1], weight.device, torch.float32)
                dora_scale = v[4]
                if v[2] is not None:
                    alpha = v[2] / mat2.shape[0]
                else:
                    alpha = 1.0

                if v[3] is not None:
                    mat3 = memory_management.cast_to_device(v[3], weight.device, torch.float32)
                    final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
                    mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
                try:
                    lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
                    if dora_scale is not None:
                        weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
                    else:
                        weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
                except Exception as e:
                    print("ERROR {} {} {}".format(patch_type, key, e))
            elif patch_type == "lokr":
                w1 = v[0]
                w2 = v[1]
                w1_a = v[3]
                w1_b = v[4]
                w2_a = v[5]
                w2_b = v[6]
                t2 = v[7]
                dora_scale = v[8]
                dim = None

                if w1 is None:
                    dim = w1_b.shape[0]
                    w1 = torch.mm(memory_management.cast_to_device(w1_a, weight.device, torch.float32),
                                  memory_management.cast_to_device(w1_b, weight.device, torch.float32))
                else:
                    w1 = memory_management.cast_to_device(w1, weight.device, torch.float32)

                if w2 is None:
                    dim = w2_b.shape[0]
                    if t2 is None:
                        w2 = torch.mm(memory_management.cast_to_device(w2_a, weight.device, torch.float32),
                                      memory_management.cast_to_device(w2_b, weight.device, torch.float32))
                    else:
                        w2 = torch.einsum('i j k l, j r, i p -> p r k l',
                                          memory_management.cast_to_device(t2, weight.device, torch.float32),
                                          memory_management.cast_to_device(w2_b, weight.device, torch.float32),
                                          memory_management.cast_to_device(w2_a, weight.device, torch.float32))
                else:
                    w2 = memory_management.cast_to_device(w2, weight.device, torch.float32)

                if len(w2.shape) == 4:
                    w1 = w1.unsqueeze(2).unsqueeze(2)
                if v[2] is not None and dim is not None:
                    alpha = v[2] / dim
                else:
                    alpha = 1.0

                try:
                    lora_diff = torch.kron(w1, w2).reshape(weight.shape)
                    if dora_scale is not None:
                        weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
                    else:
                        weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
                except Exception as e:
                    print("ERROR {} {} {}".format(patch_type, key, e))
            elif patch_type == "loha":
                w1a = v[0]
                w1b = v[1]
                if v[2] is not None:
                    alpha = v[2] / w1b.shape[0]
                else:
                    alpha = 1.0

                w2a = v[3]
                w2b = v[4]
                dora_scale = v[7]
                if v[5] is not None:
                    t1 = v[5]
                    t2 = v[6]
                    m1 = torch.einsum('i j k l, j r, i p -> p r k l',
                                      memory_management.cast_to_device(t1, weight.device, torch.float32),
                                      memory_management.cast_to_device(w1b, weight.device, torch.float32),
                                      memory_management.cast_to_device(w1a, weight.device, torch.float32))

                    m2 = torch.einsum('i j k l, j r, i p -> p r k l',
                                      memory_management.cast_to_device(t2, weight.device, torch.float32),
                                      memory_management.cast_to_device(w2b, weight.device, torch.float32),
                                      memory_management.cast_to_device(w2a, weight.device, torch.float32))
                else:
                    m1 = torch.mm(memory_management.cast_to_device(w1a, weight.device, torch.float32),
                                  memory_management.cast_to_device(w1b, weight.device, torch.float32))
                    m2 = torch.mm(memory_management.cast_to_device(w2a, weight.device, torch.float32),
                                  memory_management.cast_to_device(w2b, weight.device, torch.float32))

                try:
                    lora_diff = (m1 * m2).reshape(weight.shape)
                    if dora_scale is not None:
                        weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
                    else:
                        weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
                except Exception as e:
                    print("ERROR {} {} {}".format(patch_type, key, e))
            elif patch_type == "glora":
                if v[4] is not None:
                    alpha = v[4] / v[0].shape[0]
                else:
                    alpha = 1.0

                dora_scale = v[5]

                a1 = memory_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, torch.float32)
                a2 = memory_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, torch.float32)
                b1 = memory_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, torch.float32)
                b2 = memory_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, torch.float32)

                try:
                    lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)).reshape(weight.shape)
                    if dora_scale is not None:
                        weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
                    else:
                        weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
                except Exception as e:
                    print("ERROR {} {} {}".format(patch_type, key, e))
            elif patch_type in extra_weight_calculators:
                weight = extra_weight_calculators[patch_type](weight, strength, v)
            else:
                print("patch type not recognized {} {}".format(patch_type, key))

            if old_weight is not None:
                weight = old_weight

        return weight
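For reference, the "lora" branch above reduces to the familiar low-rank update weight += strength * alpha * (up @ down). A toy sketch with made-up shapes (not the commit's code):

import torch

rank = 2
up = torch.randn(8, rank)     # corresponds to v[0] / mat1
down = torch.randn(rank, 4)   # corresponds to v[1] / mat2
alpha = 4.0 / rank            # stored alpha divided by rank, as in the branch above
weight = torch.randn(8, 4)

lora_diff = torch.mm(up, down).reshape(weight.shape)
weight = weight + (0.8 * alpha) * lora_diff   # strength = 0.8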
    def unpatch_model(self, device_to=None):
        keys = list(self.backup.keys())

        if self.weight_inplace_update:
            for k in keys:
                utils.copy_to_param(self.model, k, self.backup[k])
        else:
            for k in keys:
                utils.set_attr(self.model, k, self.backup[k])

        self.backup = {}

        if device_to is not None:
            self.model.to(device_to)
            self.current_device = device_to

        keys = list(self.object_patches_backup.keys())
        for k in keys:
            utils.set_attr_raw(self.model, k, self.object_patches_backup[k])

        self.object_patches_backup = {}
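A rough sketch of how this class is meant to be driven; names like unet, load_device, offload_device and patch_dict are placeholders, and the exact call sites live elsewhere in the backend:

patcher = ModelPatcher(unet, load_device=load_device, offload_device=offload_device)
lora_patcher = patcher.clone()                        # clones share the model but not the patch lists
lora_patcher.add_patches(patch_dict, strength_patch=0.8)
lora_patcher.patch_model(device_to=load_device)       # back up originals, write patched weights
# ... run the model ...
lora_patcher.unpatch_model(device_to=offload_device)  # restore weight and object backups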
293  backend/patcher/lora.py  Normal file
@@ -0,0 +1,293 @@
# LoRA implementation collection from ComfyUI.
# Modified by Forge to support greedy loading (load a set of wrong/correct loras onto a model and only keep the ones that actually apply),
# which is important for the webui experience.

from backend.misc.diffusers_state_dict import unet_to_diffusers


LORA_CLIP_MAP = {
    "mlp.fc1": "mlp_fc1",
    "mlp.fc2": "mlp_fc2",
    "self_attn.k_proj": "self_attn_k_proj",
    "self_attn.q_proj": "self_attn_q_proj",
    "self_attn.v_proj": "self_attn_v_proj",
    "self_attn.out_proj": "self_attn_out_proj",
}


def load_lora(lora, to_load):
    patch_dict = {}
    loaded_keys = set()
    for x in to_load:
        alpha_name = "{}.alpha".format(x)
        alpha = None
        if alpha_name in lora.keys():
            alpha = lora[alpha_name].item()
            loaded_keys.add(alpha_name)

        dora_scale_name = "{}.dora_scale".format(x)
        dora_scale = None
        if dora_scale_name in lora.keys():
            dora_scale = lora[dora_scale_name]
            loaded_keys.add(dora_scale_name)

        regular_lora = "{}.lora_up.weight".format(x)
        diffusers_lora = "{}_lora.up.weight".format(x)
        diffusers2_lora = "{}.lora_B.weight".format(x)
        diffusers3_lora = "{}.lora.up.weight".format(x)
        transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
        A_name = None

        if regular_lora in lora.keys():
            A_name = regular_lora
            B_name = "{}.lora_down.weight".format(x)
            mid_name = "{}.lora_mid.weight".format(x)
        elif diffusers_lora in lora.keys():
            A_name = diffusers_lora
            B_name = "{}_lora.down.weight".format(x)
            mid_name = None
        elif diffusers2_lora in lora.keys():
            A_name = diffusers2_lora
            B_name = "{}.lora_A.weight".format(x)
            mid_name = None
        elif diffusers3_lora in lora.keys():
            A_name = diffusers3_lora
            B_name = "{}.lora.down.weight".format(x)
            mid_name = None
        elif transformers_lora in lora.keys():
            A_name = transformers_lora
            B_name = "{}.lora_linear_layer.down.weight".format(x)
            mid_name = None

        if A_name is not None:
            mid = None
            if mid_name is not None and mid_name in lora.keys():
                mid = lora[mid_name]
                loaded_keys.add(mid_name)
            patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid, dora_scale))
            loaded_keys.add(A_name)
            loaded_keys.add(B_name)

        ######## loha
        hada_w1_a_name = "{}.hada_w1_a".format(x)
        hada_w1_b_name = "{}.hada_w1_b".format(x)
        hada_w2_a_name = "{}.hada_w2_a".format(x)
        hada_w2_b_name = "{}.hada_w2_b".format(x)
        hada_t1_name = "{}.hada_t1".format(x)
        hada_t2_name = "{}.hada_t2".format(x)
        if hada_w1_a_name in lora.keys():
            hada_t1 = None
            hada_t2 = None
            if hada_t1_name in lora.keys():
                hada_t1 = lora[hada_t1_name]
                hada_t2 = lora[hada_t2_name]
                loaded_keys.add(hada_t1_name)
                loaded_keys.add(hada_t2_name)

            patch_dict[to_load[x]] = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2, dora_scale))
            loaded_keys.add(hada_w1_a_name)
            loaded_keys.add(hada_w1_b_name)
            loaded_keys.add(hada_w2_a_name)
            loaded_keys.add(hada_w2_b_name)

        ######## lokr
        lokr_w1_name = "{}.lokr_w1".format(x)
        lokr_w2_name = "{}.lokr_w2".format(x)
        lokr_w1_a_name = "{}.lokr_w1_a".format(x)
        lokr_w1_b_name = "{}.lokr_w1_b".format(x)
        lokr_t2_name = "{}.lokr_t2".format(x)
        lokr_w2_a_name = "{}.lokr_w2_a".format(x)
        lokr_w2_b_name = "{}.lokr_w2_b".format(x)

        lokr_w1 = None
        if lokr_w1_name in lora.keys():
            lokr_w1 = lora[lokr_w1_name]
            loaded_keys.add(lokr_w1_name)

        lokr_w2 = None
        if lokr_w2_name in lora.keys():
            lokr_w2 = lora[lokr_w2_name]
            loaded_keys.add(lokr_w2_name)

        lokr_w1_a = None
        if lokr_w1_a_name in lora.keys():
            lokr_w1_a = lora[lokr_w1_a_name]
            loaded_keys.add(lokr_w1_a_name)

        lokr_w1_b = None
        if lokr_w1_b_name in lora.keys():
            lokr_w1_b = lora[lokr_w1_b_name]
            loaded_keys.add(lokr_w1_b_name)

        lokr_w2_a = None
        if lokr_w2_a_name in lora.keys():
            lokr_w2_a = lora[lokr_w2_a_name]
            loaded_keys.add(lokr_w2_a_name)

        lokr_w2_b = None
        if lokr_w2_b_name in lora.keys():
            lokr_w2_b = lora[lokr_w2_b_name]
            loaded_keys.add(lokr_w2_b_name)

        lokr_t2 = None
        if lokr_t2_name in lora.keys():
            lokr_t2 = lora[lokr_t2_name]
            loaded_keys.add(lokr_t2_name)

        if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
            patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale))

        # glora
        a1_name = "{}.a1.weight".format(x)
        a2_name = "{}.a2.weight".format(x)
        b1_name = "{}.b1.weight".format(x)
        b2_name = "{}.b2.weight".format(x)
        if a1_name in lora:
            patch_dict[to_load[x]] = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha, dora_scale))
            loaded_keys.add(a1_name)
            loaded_keys.add(a2_name)
            loaded_keys.add(b1_name)
            loaded_keys.add(b2_name)

        w_norm_name = "{}.w_norm".format(x)
        b_norm_name = "{}.b_norm".format(x)
        w_norm = lora.get(w_norm_name, None)
        b_norm = lora.get(b_norm_name, None)

        if w_norm is not None:
            loaded_keys.add(w_norm_name)
            patch_dict[to_load[x]] = ("diff", (w_norm,))
            if b_norm is not None:
                loaded_keys.add(b_norm_name)
                patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (b_norm,))

        diff_name = "{}.diff".format(x)
        diff_weight = lora.get(diff_name, None)
        if diff_weight is not None:
            patch_dict[to_load[x]] = ("diff", (diff_weight,))
            loaded_keys.add(diff_name)

        diff_bias_name = "{}.diff_b".format(x)
        diff_bias = lora.get(diff_bias_name, None)
        if diff_bias is not None:
            patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (diff_bias,))
            loaded_keys.add(diff_bias_name)

    remaining_dict = {x: y for x, y in lora.items() if x not in loaded_keys}
    return patch_dict, remaining_dict
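As a reading aid, load_lora returns patch entries in the format ModelPatcher.calculate_weight expects, plus everything it could not match. A hypothetical probe, where lora_sd and to_load are placeholders and to_load is built by the key-map functions below:

patch_dict, remaining = load_lora(lora_sd, to_load)
model_key, (patch_type, tensors) = next(iter(patch_dict.items()))
# e.g. patch_type == "lora" and
# tensors == (up_weight, down_weight, alpha_or_None, mid_or_None, dora_scale_or_None)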
def model_lora_keys_clip(model, key_map={}):
    sdk = model.state_dict().keys()

    text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
    clip_l_present = False
    for b in range(32):  # TODO: clean up
        for c in LORA_CLIP_MAP:
            k = "clip_h.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
            if k in sdk:
                lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c])
                key_map[lora_key] = k
                lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c])
                key_map[lora_key] = k
                lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c)
                key_map[lora_key] = k

            k = "clip_l.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
            if k in sdk:
                lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c])
                key_map[lora_key] = k
                lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c])
                key_map[lora_key] = k
                clip_l_present = True
                lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c)
                key_map[lora_key] = k

            k = "clip_g.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
            if k in sdk:
                if clip_l_present:
                    lora_key = "lora_te2_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c])
                    key_map[lora_key] = k
                    lora_key = "text_encoder_2.text_model.encoder.layers.{}.{}".format(b, c)
                    key_map[lora_key] = k
                else:
                    lora_key = "lora_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c])
                    key_map[lora_key] = k
                    lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c)
                    key_map[lora_key] = k
                    lora_key = "lora_prior_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c])
                    key_map[lora_key] = k

    for k in sdk:  # OneTrainer SD3 lora
        if k.startswith("t5xxl.transformer.") and k.endswith(".weight"):
            l_key = k[len("t5xxl.transformer."):-len(".weight")]
            lora_key = "lora_te3_{}".format(l_key.replace(".", "_"))
            key_map[lora_key] = k

    k = "clip_g.transformer.text_projection.weight"
    if k in sdk:
        key_map["lora_prior_te_text_projection"] = k
        key_map["lora_te2_text_projection"] = k

    k = "clip_l.transformer.text_projection.weight"
    if k in sdk:
        key_map["lora_te1_text_projection"] = k

    return key_map


def model_lora_keys_unet(model, key_map={}):
    sd = model.state_dict()
    sdk = sd.keys()

    for k in sdk:
        if k.startswith("diffusion_model.") and k.endswith(".weight"):
            key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
            key_map["lora_unet_{}".format(key_lora)] = k
            key_map["lora_prior_unet_{}".format(key_lora)] = k

    diffusers_keys = unet_to_diffusers(model.diffusion_model.legacy_config)
    for k in diffusers_keys:
        if k.endswith(".weight"):
            unet_key = "diffusion_model.{}".format(diffusers_keys[k])
            key_lora = k[:-len(".weight")].replace(".", "_")
            key_map["lora_unet_{}".format(key_lora)] = unet_key

            diffusers_lora_prefix = ["", "unet."]
            for p in diffusers_lora_prefix:
                diffusers_lora_key = "{}{}".format(p, k[:-len(".weight")].replace(".to_", ".processor.to_"))
                if diffusers_lora_key.endswith(".to_out.0"):
                    diffusers_lora_key = diffusers_lora_key[:-2]
                key_map[diffusers_lora_key] = unet_key

    # TODO:

    # if isinstance(model, xxxx.modules.model_base.SD3):  # Diffusers lora SD3
    #     diffusers_keys = xxxx.modules.utils.mmdit_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
    #     for k in diffusers_keys:
    #         if k.endswith(".weight"):
    #             to = diffusers_keys[k]
    #             key_lora = "transformer.{}".format(k[:-len(".weight")])  # regular diffusers sd3 lora format
    #             key_map[key_lora] = to
    #
    #             key_lora = "base_model.model.{}".format(k[:-len(".weight")])  # format for flash-sd3 lora and others?
    #             key_map[key_lora] = to
    #
    #             key_lora = "lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_"))  # OneTrainer lora
    #             key_map[key_lora] = to
    #
    # if isinstance(model, xxxx.modules.model_base.AuraFlow):  # Diffusers lora AuraFlow
    #     diffusers_keys = xxxx.modules.utils.auraflow_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
    #     for k in diffusers_keys:
    #         if k.endswith(".weight"):
    #             to = diffusers_keys[k]
    #             key_lora = "transformer.{}".format(k[:-len(".weight")])  # simpletrainer and probably regular diffusers lora format
    #             key_map[key_lora] = to
    #
    # if isinstance(model, xxxx.modules.model_base.HunyuanDiT):
    #     for k in sdk:
    #         if k.startswith("diffusion_model.") and k.endswith(".weight"):
    #             key_lora = k[len("diffusion_model."):-len(".weight")]
    #             key_map["base_model.model.{}".format(key_lora)] = k  # official hunyuan lora format

    return key_map
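A rough sketch of the greedy loading described in the header comment, combining both key-map builders; the unet and clip objects below are placeholders:

key_map = model_lora_keys_unet(unet_wrapper, {})
key_map = model_lora_keys_clip(clip_wrapper, key_map)

patch_dict, remaining = load_lora(lora_state_dict, key_map)
# Matched keys go into patch_dict; whatever is left in `remaining` simply did not
# apply to this model and can be ignored or reported instead of failing the load.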
30  backend/utils.py  Normal file
@@ -0,0 +1,30 @@
import torch


def set_attr(obj, attr, value):
    attrs = attr.split(".")
    for name in attrs[:-1]:
        obj = getattr(obj, name)
    setattr(obj, attrs[-1], torch.nn.Parameter(value, requires_grad=False))


def set_attr_raw(obj, attr, value):
    attrs = attr.split(".")
    for name in attrs[:-1]:
        obj = getattr(obj, name)
    setattr(obj, attrs[-1], value)


def copy_to_param(obj, attr, value):
    attrs = attr.split(".")
    for name in attrs[:-1]:
        obj = getattr(obj, name)
    prev = getattr(obj, attrs[-1])
    prev.data.copy_(value)


def get_attr(obj, attr):
    attrs = attr.split(".")
    for name in attrs:
        obj = getattr(obj, name)
    return obj
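The distinction between these helpers is what the patcher's weight_inplace_update flag toggles: set_attr swaps in a fresh Parameter object, while copy_to_param writes into the existing tensor's storage. A small sketch (the toy module is illustrative):

import torch
import torch.nn as nn
from backend import utils

net = nn.Sequential(nn.Linear(4, 4))
w = utils.get_attr(net, "0.weight")                      # walk a dotted attribute path
utils.set_attr(net, "0.weight", torch.zeros(4, 4))       # replace the Parameter object
utils.copy_to_param(net, "0.weight", torch.ones(4, 4))   # copy in place into the current Parameter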
@@ -1,6 +0,0 @@
class LoraPatches:
    def __init__(self):
        pass

    def undo(self):
        pass
@@ -1,228 +0,0 @@
from __future__ import annotations
import os
from collections import namedtuple
import enum

import torch.nn as nn
import torch.nn.functional as F

from modules import sd_models, cache, errors, hashes, shared
import modules.models.sd3.mmdit

NetworkWeights = namedtuple('NetworkWeights', ['network_key', 'sd_key', 'w', 'sd_module'])

metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}


class SdVersion(enum.Enum):
    Unknown = 1
    SD1 = 2
    SD2 = 3
    SDXL = 4


class NetworkOnDisk:
    def __init__(self, name, filename):
        self.name = name
        self.filename = filename
        self.metadata = {}
        self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors"

        def read_metadata():
            metadata = sd_models.read_metadata_from_safetensors(filename)

            return metadata

        if self.is_safetensors:
            try:
                self.metadata = cache.cached_data_for_file('safetensors-metadata', "lora/" + self.name, filename, read_metadata)
            except Exception as e:
                errors.display(e, f"reading lora {filename}")

        if self.metadata:
            m = {}
            for k, v in sorted(self.metadata.items(), key=lambda x: metadata_tags_order.get(x[0], 999)):
                m[k] = v

            self.metadata = m

        self.alias = self.metadata.get('ss_output_name', self.name)

        self.hash = None
        self.shorthash = None
        self.set_hash(
            self.metadata.get('sshs_model_hash') or
            hashes.sha256_from_cache(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or
            ''
        )

        self.sd_version = self.detect_version()

    def detect_version(self):
        if str(self.metadata.get('ss_base_model_version', "")).startswith("sdxl_"):
            return SdVersion.SDXL
        elif str(self.metadata.get('ss_v2', "")) == "True":
            return SdVersion.SD2
        elif len(self.metadata):
            return SdVersion.SD1

        return SdVersion.Unknown

    def set_hash(self, v):
        self.hash = v
        self.shorthash = self.hash[0:12]

        if self.shorthash:
            import networks
            networks.available_network_hash_lookup[self.shorthash] = self

    def read_hash(self):
        if not self.hash:
            self.set_hash(hashes.sha256(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or '')

    def get_alias(self):
        import networks
        if shared.opts.lora_preferred_name == "Filename" or self.alias.lower() in networks.forbidden_network_aliases:
            return self.name
        else:
            return self.alias


class Network:  # LoraModule
    def __init__(self, name, network_on_disk: NetworkOnDisk):
        self.name = name
        self.network_on_disk = network_on_disk
        self.te_multiplier = 1.0
        self.unet_multiplier = 1.0
        self.dyn_dim = None
        self.modules = {}
        self.bundle_embeddings = {}
        self.mtime = None

        self.mentioned_name = None
        """the text that was used to add the network to prompt - can be either name or an alias"""


class ModuleType:
    def create_module(self, net: Network, weights: NetworkWeights) -> Network | None:
        return None


class NetworkModule:
    def __init__(self, net: Network, weights: NetworkWeights):
        self.network = net
        self.network_key = weights.network_key
        self.sd_key = weights.sd_key
        self.sd_module = weights.sd_module

        if isinstance(self.sd_module, modules.models.sd3.mmdit.QkvLinear):
            s = self.sd_module.weight.shape
            self.shape = (s[0] // 3, s[1])
        elif hasattr(self.sd_module, 'weight'):
            self.shape = self.sd_module.weight.shape
        elif isinstance(self.sd_module, nn.MultiheadAttention):
            # For now, only self-attention uses PyTorch's MHA,
            # so assume all qkvo projections have the same shape.
            self.shape = self.sd_module.out_proj.weight.shape
        else:
            self.shape = None

        self.ops = None
        self.extra_kwargs = {}
        if isinstance(self.sd_module, nn.Conv2d):
            self.ops = F.conv2d
            self.extra_kwargs = {
                'stride': self.sd_module.stride,
                'padding': self.sd_module.padding
            }
        elif isinstance(self.sd_module, nn.Linear):
            self.ops = F.linear
        elif isinstance(self.sd_module, nn.LayerNorm):
            self.ops = F.layer_norm
            self.extra_kwargs = {
                'normalized_shape': self.sd_module.normalized_shape,
                'eps': self.sd_module.eps
            }
        elif isinstance(self.sd_module, nn.GroupNorm):
            self.ops = F.group_norm
            self.extra_kwargs = {
                'num_groups': self.sd_module.num_groups,
                'eps': self.sd_module.eps
            }

        self.dim = None
        self.bias = weights.w.get("bias")
        self.alpha = weights.w["alpha"].item() if "alpha" in weights.w else None
        self.scale = weights.w["scale"].item() if "scale" in weights.w else None

        self.dora_scale = weights.w.get("dora_scale", None)
        self.dora_norm_dims = len(self.shape) - 1

    def multiplier(self):
        if 'transformer' in self.sd_key[:20]:
            return self.network.te_multiplier
        else:
            return self.network.unet_multiplier

    def calc_scale(self):
        if self.scale is not None:
            return self.scale
        if self.dim is not None and self.alpha is not None:
            return self.alpha / self.dim

        return 1.0

    def apply_weight_decompose(self, updown, orig_weight):
        # Match the device/dtype
        orig_weight = orig_weight.to(updown.dtype)
        dora_scale = self.dora_scale.to(device=orig_weight.device, dtype=updown.dtype)
        updown = updown.to(orig_weight.device)

        merged_scale1 = updown + orig_weight
        merged_scale1_norm = (
            merged_scale1.transpose(0, 1)
            .reshape(merged_scale1.shape[1], -1)
            .norm(dim=1, keepdim=True)
            .reshape(merged_scale1.shape[1], *[1] * self.dora_norm_dims)
            .transpose(0, 1)
        )

        dora_merged = (
            merged_scale1 * (dora_scale / merged_scale1_norm)
        )
        final_updown = dora_merged - orig_weight
        return final_updown

    def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None):
        if self.bias is not None:
            updown = updown.reshape(self.bias.shape)
            updown += self.bias.to(orig_weight.device, dtype=updown.dtype)
            updown = updown.reshape(output_shape)

        if len(output_shape) == 4:
            updown = updown.reshape(output_shape)

        if orig_weight.size().numel() == updown.size().numel():
            updown = updown.reshape(orig_weight.shape)

        if ex_bias is not None:
            ex_bias = ex_bias * self.multiplier()

        updown = updown * self.calc_scale()

        if self.dora_scale is not None:
            updown = self.apply_weight_decompose(updown, orig_weight)

        return updown * self.multiplier(), ex_bias

    def calc_updown(self, target):
        raise NotImplementedError()

    def forward(self, x, y):
        """A general forward implementation for all modules"""
        if self.ops is None:
            raise NotImplementedError()
        else:
            updown, ex_bias = self.calc_updown(self.sd_module.weight)
            return y + self.ops(x, weight=updown, bias=ex_bias, **self.extra_kwargs)
74  extensions-builtin/sd_forge_lora/network.py  Normal file
@@ -0,0 +1,74 @@
import os

from modules import sd_models, cache, errors, hashes, shared


metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}


class NetworkOnDisk:
    def __init__(self, name, filename):
        self.name = name
        self.filename = filename
        self.metadata = {}
        self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors"

        def read_metadata():
            metadata = sd_models.read_metadata_from_safetensors(filename)

            return metadata

        if self.is_safetensors:
            try:
                self.metadata = cache.cached_data_for_file('safetensors-metadata', "lora/" + self.name, filename, read_metadata)
            except Exception as e:
                errors.display(e, f"reading lora {filename}")

        if self.metadata:
            m = {}
            for k, v in sorted(self.metadata.items(), key=lambda x: metadata_tags_order.get(x[0], 999)):
                m[k] = v

            self.metadata = m

        self.alias = self.metadata.get('ss_output_name', self.name)

        self.hash = None
        self.shorthash = None
        self.set_hash(
            self.metadata.get('sshs_model_hash') or
            hashes.sha256_from_cache(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or
            ''
        )

    def set_hash(self, v):
        self.hash = v
        self.shorthash = self.hash[0:12]

        if self.shorthash:
            import networks
            networks.available_network_hash_lookup[self.shorthash] = self

    def read_hash(self):
        if not self.hash:
            self.set_hash(hashes.sha256(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or '')

    def get_alias(self):
        import networks
        if shared.opts.lora_preferred_name == "Filename" or self.alias.lower() in networks.forbidden_network_aliases:
            return self.name
        else:
            return self.alias


class Network:
    def __init__(self, name, network_on_disk: NetworkOnDisk):
        self.name = name
        self.network_on_disk = network_on_disk
        self.te_multiplier = 1.0
        self.unet_multiplier = 1.0
        self.dyn_dim = None
        self.modules = {}
        self.bundle_embeddings = {}
        self.mtime = None
        self.mentioned_name = None
@@ -4,7 +4,6 @@ import logging
import os
import re

import lora_patches
import functools
import network
@@ -21,22 +20,6 @@ def load_lora_state_dict(filename):
    return load_torch_file(filename, safe_load=True)


def convert_diffusers_name_to_compvis(key, is_sd2):
    pass


def assign_network_names_to_compvis_modules(sd_model):
    pass


class BundledTIHash(str):
    def __init__(self, hash_str):
        self.hash = hash_str

    def __str__(self):
        return self.hash if shared.opts.lora_bundled_ti_to_infotext else ''


def load_network(name, network_on_disk):
    net = network.Network(name, network_on_disk)
    net.mtime = os.path.getmtime(network_on_disk.filename)
@@ -44,10 +27,6 @@ def load_network(name, network_on_disk):
    return net


def purge_networks_from_memory():
    pass


def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
    global lora_state_dict_cache
@@ -105,84 +84,6 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
        return


def allowed_layer_without_weight(layer):
    if isinstance(layer, torch.nn.LayerNorm) and not layer.elementwise_affine:
        return True

    return False


def store_weights_backup(weight):
    if weight is None:
        return None

    return weight.to(devices.cpu, copy=True)


def restore_weights_backup(obj, field, weight):
    if weight is None:
        setattr(obj, field, None)
        return

    getattr(obj, field).copy_(weight)


def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
    pass


def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
    pass


def network_forward(org_module, input, original_forward):
    pass


def network_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
    pass


def network_Linear_forward(self, input):
    pass


def network_Linear_load_state_dict(self, *args, **kwargs):
    pass


def network_Conv2d_forward(self, input):
    pass


def network_Conv2d_load_state_dict(self, *args, **kwargs):
    pass


def network_GroupNorm_forward(self, input):
    pass


def network_GroupNorm_load_state_dict(self, *args, **kwargs):
    pass


def network_LayerNorm_forward(self, input):
    pass


def network_LayerNorm_load_state_dict(self, *args, **kwargs):
    pass


def network_MultiheadAttention_forward(self, *args, **kwargs):
    pass


def network_MultiheadAttention_load_state_dict(self, *args, **kwargs):
    pass


def process_network_files(names: list[str] | None = None):
    candidates = list(shared.walk_files(shared.cmd_opts.lora_dir, allowed_extensions=[".pt", ".ckpt", ".safetensors"]))
    for filename in candidates:
@@ -257,8 +158,6 @@ def infotext_pasted(infotext, params):
        params["Prompt"] += "\n" + "".join(added)


originals: lora_patches.LoraPatches = None

extra_network_lora = None

available_networks = {}
@@ -6,16 +6,11 @@ from fastapi import FastAPI
import network
import networks
import lora  # noqa:F401
import lora_patches
import extra_networks_lora
import ui_extra_networks_lora
from modules import script_callbacks, ui_extra_networks, extra_networks, shared


def unload():
    networks.originals.undo()


def before_ui():
    ui_extra_networks.register_page(ui_extra_networks_lora.ExtraNetworksPageLora())
@@ -23,10 +18,6 @@ def before_ui():
    extra_networks.register_extra_network(networks.extra_network_lora)


networks.originals = lora_patches.LoraPatches()

script_callbacks.on_model_loaded(networks.assign_network_names_to_compvis_modules)
script_callbacks.on_script_unloaded(unload)
script_callbacks.on_before_ui(before_ui)
script_callbacks.on_infotext_pasted(networks.infotext_pasted)
@@ -97,5 +88,3 @@ def infotext_pasted(infotext, d):
script_callbacks.on_infotext_pasted(infotext_pasted)

shared.opts.onchange("lora_in_memory_limit", networks.purge_networks_from_memory)
@@ -37,7 +37,6 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
            "local_preview": f"{path}.{shared.opts.samples_format}",
            "metadata": lora_on_disk.metadata,
            "sort_keys": {'default': index, **self.get_sort_keys(lora_on_disk.filename)},
            "sd_version": lora_on_disk.sd_version.name,
        }

        self.read_user_metadata(item)
@@ -53,26 +52,6 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
        if negative_prompt:
            item["negative_prompt"] = quote_js('(' + negative_prompt + ':1)')

        sd_version = item["user_metadata"].get("sd version")
        if sd_version in network.SdVersion.__members__:
            item["sd_version"] = sd_version
            sd_version = network.SdVersion[sd_version]
        else:
            sd_version = lora_on_disk.sd_version

        if shared.opts.lora_filter_disabled or not enable_filter or not shared.sd_model:
            pass
        elif sd_version == network.SdVersion.Unknown:
            model_version = network.SdVersion.SDXL if shared.sd_model.is_sdxl else network.SdVersion.SD2 if shared.sd_model.is_sd2 else network.SdVersion.SD1
            if model_version.name in shared.opts.lora_hide_unknown_for_versions:
                return None
        elif shared.sd_model.is_sdxl and sd_version != network.SdVersion.SDXL:
            return None
        elif shared.sd_model.is_sd2 and sd_version != network.SdVersion.SD2:
            return None
        elif shared.sd_model.is_sd1 and sd_version != network.SdVersion.SD1:
            return None

        return item

    def list_items(self):
@@ -1,226 +1 @@
# 1st edit by https://github.com/comfyanonymous/ComfyUI
# 2nd edit by Forge Official


import ldm_patched.modules.utils

LORA_CLIP_MAP = {
    "mlp.fc1": "mlp_fc1",
    "mlp.fc2": "mlp_fc2",
    "self_attn.k_proj": "self_attn_k_proj",
    "self_attn.q_proj": "self_attn_q_proj",
    "self_attn.v_proj": "self_attn_v_proj",
    "self_attn.out_proj": "self_attn_out_proj",
}


def load_lora(lora, to_load):
    patch_dict = {}
    loaded_keys = set()
    for x in to_load:
        alpha_name = "{}.alpha".format(x)
        alpha = None
        if alpha_name in lora.keys():
            alpha = lora[alpha_name].item()
            loaded_keys.add(alpha_name)

        regular_lora = "{}.lora_up.weight".format(x)
        diffusers_lora = "{}_lora.up.weight".format(x)
        transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
        A_name = None

        if regular_lora in lora.keys():
            A_name = regular_lora
            B_name = "{}.lora_down.weight".format(x)
            mid_name = "{}.lora_mid.weight".format(x)
        elif diffusers_lora in lora.keys():
            A_name = diffusers_lora
            B_name = "{}_lora.down.weight".format(x)
            mid_name = None
        elif transformers_lora in lora.keys():
            A_name = transformers_lora
            B_name = "{}.lora_linear_layer.down.weight".format(x)
            mid_name = None

        if A_name is not None:
            mid = None
            if mid_name is not None and mid_name in lora.keys():
                mid = lora[mid_name]
                loaded_keys.add(mid_name)
            patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid))
            loaded_keys.add(A_name)
            loaded_keys.add(B_name)


        ######## loha
        hada_w1_a_name = "{}.hada_w1_a".format(x)
        hada_w1_b_name = "{}.hada_w1_b".format(x)
        hada_w2_a_name = "{}.hada_w2_a".format(x)
        hada_w2_b_name = "{}.hada_w2_b".format(x)
        hada_t1_name = "{}.hada_t1".format(x)
        hada_t2_name = "{}.hada_t2".format(x)
        if hada_w1_a_name in lora.keys():
            hada_t1 = None
            hada_t2 = None
            if hada_t1_name in lora.keys():
                hada_t1 = lora[hada_t1_name]
                hada_t2 = lora[hada_t2_name]
                loaded_keys.add(hada_t1_name)
                loaded_keys.add(hada_t2_name)

            patch_dict[to_load[x]] = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2))
            loaded_keys.add(hada_w1_a_name)
            loaded_keys.add(hada_w1_b_name)
            loaded_keys.add(hada_w2_a_name)
            loaded_keys.add(hada_w2_b_name)


        ######## lokr
        lokr_w1_name = "{}.lokr_w1".format(x)
        lokr_w2_name = "{}.lokr_w2".format(x)
        lokr_w1_a_name = "{}.lokr_w1_a".format(x)
        lokr_w1_b_name = "{}.lokr_w1_b".format(x)
        lokr_t2_name = "{}.lokr_t2".format(x)
        lokr_w2_a_name = "{}.lokr_w2_a".format(x)
        lokr_w2_b_name = "{}.lokr_w2_b".format(x)

        lokr_w1 = None
        if lokr_w1_name in lora.keys():
            lokr_w1 = lora[lokr_w1_name]
            loaded_keys.add(lokr_w1_name)

        lokr_w2 = None
        if lokr_w2_name in lora.keys():
            lokr_w2 = lora[lokr_w2_name]
            loaded_keys.add(lokr_w2_name)

        lokr_w1_a = None
        if lokr_w1_a_name in lora.keys():
            lokr_w1_a = lora[lokr_w1_a_name]
            loaded_keys.add(lokr_w1_a_name)

        lokr_w1_b = None
        if lokr_w1_b_name in lora.keys():
            lokr_w1_b = lora[lokr_w1_b_name]
            loaded_keys.add(lokr_w1_b_name)

        lokr_w2_a = None
        if lokr_w2_a_name in lora.keys():
            lokr_w2_a = lora[lokr_w2_a_name]
            loaded_keys.add(lokr_w2_a_name)

        lokr_w2_b = None
        if lokr_w2_b_name in lora.keys():
            lokr_w2_b = lora[lokr_w2_b_name]
            loaded_keys.add(lokr_w2_b_name)

        lokr_t2 = None
        if lokr_t2_name in lora.keys():
            lokr_t2 = lora[lokr_t2_name]
            loaded_keys.add(lokr_t2_name)

        if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
            patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2))

        # glora
        a1_name = "{}.a1.weight".format(x)
        a2_name = "{}.a2.weight".format(x)
        b1_name = "{}.b1.weight".format(x)
        b2_name = "{}.b2.weight".format(x)
        if a1_name in lora:
            patch_dict[to_load[x]] = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha))
loaded_keys.add(a1_name)
|
||||
loaded_keys.add(a2_name)
|
||||
loaded_keys.add(b1_name)
|
||||
loaded_keys.add(b2_name)
|
||||
|
||||
w_norm_name = "{}.w_norm".format(x)
|
||||
b_norm_name = "{}.b_norm".format(x)
|
||||
w_norm = lora.get(w_norm_name, None)
|
||||
b_norm = lora.get(b_norm_name, None)
|
||||
|
||||
if w_norm is not None:
|
||||
loaded_keys.add(w_norm_name)
|
||||
patch_dict[to_load[x]] = ("diff", (w_norm,))
|
||||
if b_norm is not None:
|
||||
loaded_keys.add(b_norm_name)
|
||||
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (b_norm,))
|
||||
|
||||
diff_name = "{}.diff".format(x)
|
||||
diff_weight = lora.get(diff_name, None)
|
||||
if diff_weight is not None:
|
||||
patch_dict[to_load[x]] = ("diff", (diff_weight,))
|
||||
loaded_keys.add(diff_name)
|
||||
|
||||
diff_bias_name = "{}.diff_b".format(x)
|
||||
diff_bias = lora.get(diff_bias_name, None)
|
||||
if diff_bias is not None:
|
||||
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (diff_bias,))
|
||||
loaded_keys.add(diff_bias_name)
|
||||
|
||||
remaining_dict = {x: y for x, y in lora.items() if x not in loaded_keys}
|
||||
return patch_dict, remaining_dict
|
||||
|
||||
def model_lora_keys_clip(model, key_map={}):
|
||||
sdk = model.state_dict().keys()
|
||||
|
||||
text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
|
||||
clip_l_present = False
|
||||
for b in range(32): #TODO: clean up
|
||||
for c in LORA_CLIP_MAP:
|
||||
k = "clip_h.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
|
||||
if k in sdk:
|
||||
lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c])
|
||||
key_map[lora_key] = k
|
||||
lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c])
|
||||
key_map[lora_key] = k
|
||||
lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c) #diffusers lora
|
||||
key_map[lora_key] = k
|
||||
|
||||
k = "clip_l.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
|
||||
if k in sdk:
|
||||
lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c])
|
||||
key_map[lora_key] = k
|
||||
lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base
|
||||
key_map[lora_key] = k
|
||||
clip_l_present = True
|
||||
lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c) #diffusers lora
|
||||
key_map[lora_key] = k
|
||||
|
||||
k = "clip_g.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
|
||||
if k in sdk:
|
||||
if clip_l_present:
|
||||
lora_key = "lora_te2_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base
|
||||
key_map[lora_key] = k
|
||||
lora_key = "text_encoder_2.text_model.encoder.layers.{}.{}".format(b, c) #diffusers lora
|
||||
key_map[lora_key] = k
|
||||
else:
|
||||
lora_key = "lora_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #TODO: test if this is correct for SDXL-Refiner
|
||||
key_map[lora_key] = k
|
||||
lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c) #diffusers lora
|
||||
key_map[lora_key] = k
|
||||
|
||||
return key_map
|
||||
|
||||
def model_lora_keys_unet(model, key_map={}):
|
||||
sdk = model.state_dict().keys()
|
||||
|
||||
for k in sdk:
|
||||
if k.startswith("diffusion_model.") and k.endswith(".weight"):
|
||||
key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
|
||||
key_map["lora_unet_{}".format(key_lora)] = k
|
||||
|
||||
diffusers_keys = ldm_patched.modules.utils.unet_to_diffusers(model.diffusion_model.legacy_config)
|
||||
for k in diffusers_keys:
|
||||
if k.endswith(".weight"):
|
||||
unet_key = "diffusion_model.{}".format(diffusers_keys[k])
|
||||
key_lora = k[:-len(".weight")].replace(".", "_")
|
||||
key_map["lora_unet_{}".format(key_lora)] = unet_key
|
||||
|
||||
diffusers_lora_prefix = ["", "unet."]
|
||||
for p in diffusers_lora_prefix:
|
||||
diffusers_lora_key = "{}{}".format(p, k[:-len(".weight")].replace(".to_", ".processor.to_"))
|
||||
if diffusers_lora_key.endswith(".to_out.0"):
|
||||
diffusers_lora_key = diffusers_lora_key[:-2]
|
||||
key_map[diffusers_lora_key] = unet_key
|
||||
return key_map
|
||||
from backend.patcher.lora import *
|
||||
|
||||
@@ -1,386 +1 @@
|
||||
# 1st edit by https://github.com/comfyanonymous/ComfyUI
|
||||
# 2nd edit by Forge Official
|
||||
|
||||
|
||||
import torch
|
||||
import copy
|
||||
import inspect
|
||||
|
||||
import ldm_patched.modules.utils
|
||||
import ldm_patched.modules.model_management
|
||||
|
||||
|
||||
extra_weight_calculators = {}
|
||||
|
||||
|
||||
class ModelPatcher:
|
||||
def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False):
|
||||
self.size = size
|
||||
self.model = model
|
||||
self.patches = {}
|
||||
self.backup = {}
|
||||
self.object_patches = {}
|
||||
self.object_patches_backup = {}
|
||||
self.model_options = {"transformer_options":{}}
|
||||
self.model_size()
|
||||
self.load_device = load_device
|
||||
self.offload_device = offload_device
|
||||
if current_device is None:
|
||||
self.current_device = self.offload_device
|
||||
else:
|
||||
self.current_device = current_device
|
||||
|
||||
self.weight_inplace_update = weight_inplace_update
|
||||
|
||||
def model_size(self):
|
||||
if self.size > 0:
|
||||
return self.size
|
||||
model_sd = self.model.state_dict()
|
||||
self.size = ldm_patched.modules.model_management.module_size(self.model)
|
||||
self.model_keys = set(model_sd.keys())
|
||||
return self.size
|
||||
|
||||
def clone(self):
|
||||
n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device, weight_inplace_update=self.weight_inplace_update)
|
||||
n.patches = {}
|
||||
for k in self.patches:
|
||||
n.patches[k] = self.patches[k][:]
|
||||
|
||||
n.object_patches = self.object_patches.copy()
|
||||
n.model_options = copy.deepcopy(self.model_options)
|
||||
n.model_keys = self.model_keys
|
||||
return n
|
||||
|
||||
def is_clone(self, other):
|
||||
if hasattr(other, 'model') and self.model is other.model:
|
||||
return True
|
||||
return False
|
||||
|
||||
def memory_required(self, input_shape):
|
||||
return self.model.memory_required(input_shape=input_shape)
|
||||
|
||||
def set_model_sampler_cfg_function(self, sampler_cfg_function, disable_cfg1_optimization=False):
|
||||
if len(inspect.signature(sampler_cfg_function).parameters) == 3:
|
||||
self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way
|
||||
else:
|
||||
self.model_options["sampler_cfg_function"] = sampler_cfg_function
|
||||
if disable_cfg1_optimization:
|
||||
self.model_options["disable_cfg1_optimization"] = True
|
||||
|
||||
def set_model_sampler_post_cfg_function(self, post_cfg_function, disable_cfg1_optimization=False):
|
||||
self.model_options["sampler_post_cfg_function"] = self.model_options.get("sampler_post_cfg_function", []) + [post_cfg_function]
|
||||
if disable_cfg1_optimization:
|
||||
self.model_options["disable_cfg1_optimization"] = True
|
||||
|
||||
def set_model_unet_function_wrapper(self, unet_wrapper_function):
|
||||
self.model_options["model_function_wrapper"] = unet_wrapper_function
|
||||
|
||||
def set_model_vae_encode_wrapper(self, wrapper_function):
|
||||
self.model_options["model_vae_encode_wrapper"] = wrapper_function
|
||||
|
||||
def set_model_vae_decode_wrapper(self, wrapper_function):
|
||||
self.model_options["model_vae_decode_wrapper"] = wrapper_function
|
||||
|
||||
def set_model_vae_regulation(self, vae_regulation):
|
||||
self.model_options["model_vae_regulation"] = vae_regulation
|
||||
|
||||
def set_model_patch(self, patch, name):
|
||||
to = self.model_options["transformer_options"]
|
||||
if "patches" not in to:
|
||||
to["patches"] = {}
|
||||
to["patches"][name] = to["patches"].get(name, []) + [patch]
|
||||
|
||||
def set_model_patch_replace(self, patch, name, block_name, number, transformer_index=None):
|
||||
to = self.model_options["transformer_options"]
|
||||
if "patches_replace" not in to:
|
||||
to["patches_replace"] = {}
|
||||
if name not in to["patches_replace"]:
|
||||
to["patches_replace"][name] = {}
|
||||
if transformer_index is not None:
|
||||
block = (block_name, number, transformer_index)
|
||||
else:
|
||||
block = (block_name, number)
|
||||
to["patches_replace"][name][block] = patch
|
||||
|
||||
def set_model_attn1_patch(self, patch):
|
||||
self.set_model_patch(patch, "attn1_patch")
|
||||
|
||||
def set_model_attn2_patch(self, patch):
|
||||
self.set_model_patch(patch, "attn2_patch")
|
||||
|
||||
def set_model_attn1_replace(self, patch, block_name, number, transformer_index=None):
|
||||
self.set_model_patch_replace(patch, "attn1", block_name, number, transformer_index)
|
||||
|
||||
def set_model_attn2_replace(self, patch, block_name, number, transformer_index=None):
|
||||
self.set_model_patch_replace(patch, "attn2", block_name, number, transformer_index)
|
||||
|
||||
def set_model_attn1_output_patch(self, patch):
|
||||
self.set_model_patch(patch, "attn1_output_patch")
|
||||
|
||||
def set_model_attn2_output_patch(self, patch):
|
||||
self.set_model_patch(patch, "attn2_output_patch")
|
||||
|
||||
def set_model_input_block_patch(self, patch):
|
||||
self.set_model_patch(patch, "input_block_patch")
|
||||
|
||||
def set_model_input_block_patch_after_skip(self, patch):
|
||||
self.set_model_patch(patch, "input_block_patch_after_skip")
|
||||
|
||||
def set_model_output_block_patch(self, patch):
|
||||
self.set_model_patch(patch, "output_block_patch")
|
||||
|
||||
def add_object_patch(self, name, obj):
|
||||
self.object_patches[name] = obj
|
||||
|
||||
def model_patches_to(self, device):
|
||||
to = self.model_options["transformer_options"]
|
||||
if "patches" in to:
|
||||
patches = to["patches"]
|
||||
for name in patches:
|
||||
patch_list = patches[name]
|
||||
for i in range(len(patch_list)):
|
||||
if hasattr(patch_list[i], "to"):
|
||||
patch_list[i] = patch_list[i].to(device)
|
||||
if "patches_replace" in to:
|
||||
patches = to["patches_replace"]
|
||||
for name in patches:
|
||||
patch_list = patches[name]
|
||||
for k in patch_list:
|
||||
if hasattr(patch_list[k], "to"):
|
||||
patch_list[k] = patch_list[k].to(device)
|
||||
if "model_function_wrapper" in self.model_options:
|
||||
wrap_func = self.model_options["model_function_wrapper"]
|
||||
if hasattr(wrap_func, "to"):
|
||||
self.model_options["model_function_wrapper"] = wrap_func.to(device)
|
||||
|
||||
def model_dtype(self):
|
||||
if hasattr(self.model, "get_dtype"):
|
||||
return self.model.get_dtype()
|
||||
|
||||
def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
|
||||
p = set()
|
||||
for k in patches:
|
||||
if k in self.model_keys:
|
||||
p.add(k)
|
||||
current_patches = self.patches.get(k, [])
|
||||
current_patches.append((strength_patch, patches[k], strength_model))
|
||||
self.patches[k] = current_patches
|
||||
|
||||
return list(p)
|
||||
|
||||
def get_key_patches(self, filter_prefix=None):
|
||||
ldm_patched.modules.model_management.unload_model_clones(self)
|
||||
model_sd = self.model_state_dict()
|
||||
p = {}
|
||||
for k in model_sd:
|
||||
if filter_prefix is not None:
|
||||
if not k.startswith(filter_prefix):
|
||||
continue
|
||||
if k in self.patches:
|
||||
p[k] = [model_sd[k]] + self.patches[k]
|
||||
else:
|
||||
p[k] = (model_sd[k],)
|
||||
return p
|
||||
|
||||
def model_state_dict(self, filter_prefix=None):
|
||||
sd = self.model.state_dict()
|
||||
keys = list(sd.keys())
|
||||
if filter_prefix is not None:
|
||||
for k in keys:
|
||||
if not k.startswith(filter_prefix):
|
||||
sd.pop(k)
|
||||
return sd
|
||||
|
||||
def patch_model(self, device_to=None, patch_weights=True):
|
||||
for k in self.object_patches:
|
||||
old = ldm_patched.modules.utils.get_attr(self.model, k)
|
||||
if k not in self.object_patches_backup:
|
||||
self.object_patches_backup[k] = old
|
||||
ldm_patched.modules.utils.set_attr_raw(self.model, k, self.object_patches[k])
|
||||
|
||||
if patch_weights:
|
||||
model_sd = self.model_state_dict()
|
||||
for key in self.patches:
|
||||
if key not in model_sd:
|
||||
print("could not patch. key doesn't exist in model:", key)
|
||||
continue
|
||||
|
||||
weight = model_sd[key]
|
||||
|
||||
inplace_update = self.weight_inplace_update
|
||||
|
||||
if key not in self.backup:
|
||||
self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update)
|
||||
|
||||
if device_to is not None:
|
||||
temp_weight = ldm_patched.modules.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
|
||||
else:
|
||||
temp_weight = weight.to(torch.float32, copy=True)
|
||||
out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
|
||||
if inplace_update:
|
||||
ldm_patched.modules.utils.copy_to_param(self.model, key, out_weight)
|
||||
else:
|
||||
ldm_patched.modules.utils.set_attr(self.model, key, out_weight)
|
||||
del temp_weight
|
||||
|
||||
if device_to is not None:
|
||||
self.model.to(device_to)
|
||||
self.current_device = device_to
|
||||
|
||||
return self.model
|
||||
|
||||
def calculate_weight(self, patches, weight, key):
|
||||
for p in patches:
|
||||
alpha = p[0]
|
||||
v = p[1]
|
||||
strength_model = p[2]
|
||||
|
||||
if strength_model != 1.0:
|
||||
weight *= strength_model
|
||||
|
||||
if isinstance(v, list):
|
||||
v = (self.calculate_weight(v[1:], v[0].clone(), key), )
|
||||
|
||||
if len(v) == 1:
|
||||
patch_type = "diff"
|
||||
elif len(v) == 2:
|
||||
patch_type = v[0]
|
||||
v = v[1]
|
||||
|
||||
if patch_type == "diff":
|
||||
w1 = v[0]
|
||||
if alpha != 0.0:
|
||||
if w1.shape != weight.shape:
|
||||
if w1.ndim == weight.ndim == 4:
|
||||
new_shape = [max(n, m) for n, m in zip(weight.shape, w1.shape)]
|
||||
print(f'Merged with {key} channel changed to {new_shape}')
|
||||
new_diff = alpha * ldm_patched.modules.model_management.cast_to_device(w1, weight.device, weight.dtype)
|
||||
new_weight = torch.zeros(size=new_shape).to(weight)
|
||||
new_weight[:weight.shape[0], :weight.shape[1], :weight.shape[2], :weight.shape[3]] = weight
|
||||
new_weight[:new_diff.shape[0], :new_diff.shape[1], :new_diff.shape[2], :new_diff.shape[3]] += new_diff
|
||||
new_weight = new_weight.contiguous().clone()
|
||||
weight = new_weight
|
||||
else:
|
||||
print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
|
||||
else:
|
||||
weight += alpha * ldm_patched.modules.model_management.cast_to_device(w1, weight.device, weight.dtype)
|
||||
elif patch_type == "lora": #lora/locon
|
||||
mat1 = ldm_patched.modules.model_management.cast_to_device(v[0], weight.device, torch.float32)
|
||||
mat2 = ldm_patched.modules.model_management.cast_to_device(v[1], weight.device, torch.float32)
|
||||
if v[2] is not None:
|
||||
alpha *= v[2] / mat2.shape[0]
|
||||
if v[3] is not None:
|
||||
#locon mid weights, hopefully the math is fine because I didn't properly test it
|
||||
mat3 = ldm_patched.modules.model_management.cast_to_device(v[3], weight.device, torch.float32)
|
||||
final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
|
||||
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
|
||||
try:
|
||||
weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype)
|
||||
except Exception as e:
|
||||
print("ERROR", key, e)
|
||||
elif patch_type == "lokr":
|
||||
w1 = v[0]
|
||||
w2 = v[1]
|
||||
w1_a = v[3]
|
||||
w1_b = v[4]
|
||||
w2_a = v[5]
|
||||
w2_b = v[6]
|
||||
t2 = v[7]
|
||||
dim = None
|
||||
|
||||
if w1 is None:
|
||||
dim = w1_b.shape[0]
|
||||
w1 = torch.mm(ldm_patched.modules.model_management.cast_to_device(w1_a, weight.device, torch.float32),
|
||||
ldm_patched.modules.model_management.cast_to_device(w1_b, weight.device, torch.float32))
|
||||
else:
|
||||
w1 = ldm_patched.modules.model_management.cast_to_device(w1, weight.device, torch.float32)
|
||||
|
||||
if w2 is None:
|
||||
dim = w2_b.shape[0]
|
||||
if t2 is None:
|
||||
w2 = torch.mm(ldm_patched.modules.model_management.cast_to_device(w2_a, weight.device, torch.float32),
|
||||
ldm_patched.modules.model_management.cast_to_device(w2_b, weight.device, torch.float32))
|
||||
else:
|
||||
w2 = torch.einsum('i j k l, j r, i p -> p r k l',
|
||||
ldm_patched.modules.model_management.cast_to_device(t2, weight.device, torch.float32),
|
||||
ldm_patched.modules.model_management.cast_to_device(w2_b, weight.device, torch.float32),
|
||||
ldm_patched.modules.model_management.cast_to_device(w2_a, weight.device, torch.float32))
|
||||
else:
|
||||
w2 = ldm_patched.modules.model_management.cast_to_device(w2, weight.device, torch.float32)
|
||||
|
||||
if len(w2.shape) == 4:
|
||||
w1 = w1.unsqueeze(2).unsqueeze(2)
|
||||
if v[2] is not None and dim is not None:
|
||||
alpha *= v[2] / dim
|
||||
|
||||
try:
|
||||
weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype)
|
||||
except Exception as e:
|
||||
print("ERROR", key, e)
|
||||
elif patch_type == "loha":
|
||||
w1a = v[0]
|
||||
w1b = v[1]
|
||||
if v[2] is not None:
|
||||
alpha *= v[2] / w1b.shape[0]
|
||||
w2a = v[3]
|
||||
w2b = v[4]
|
||||
if v[5] is not None: #cp decomposition
|
||||
t1 = v[5]
|
||||
t2 = v[6]
|
||||
m1 = torch.einsum('i j k l, j r, i p -> p r k l',
|
||||
ldm_patched.modules.model_management.cast_to_device(t1, weight.device, torch.float32),
|
||||
ldm_patched.modules.model_management.cast_to_device(w1b, weight.device, torch.float32),
|
||||
ldm_patched.modules.model_management.cast_to_device(w1a, weight.device, torch.float32))
|
||||
|
||||
m2 = torch.einsum('i j k l, j r, i p -> p r k l',
|
||||
ldm_patched.modules.model_management.cast_to_device(t2, weight.device, torch.float32),
|
||||
ldm_patched.modules.model_management.cast_to_device(w2b, weight.device, torch.float32),
|
||||
ldm_patched.modules.model_management.cast_to_device(w2a, weight.device, torch.float32))
|
||||
else:
|
||||
m1 = torch.mm(ldm_patched.modules.model_management.cast_to_device(w1a, weight.device, torch.float32),
|
||||
ldm_patched.modules.model_management.cast_to_device(w1b, weight.device, torch.float32))
|
||||
m2 = torch.mm(ldm_patched.modules.model_management.cast_to_device(w2a, weight.device, torch.float32),
|
||||
ldm_patched.modules.model_management.cast_to_device(w2b, weight.device, torch.float32))
|
||||
|
||||
try:
|
||||
weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
|
||||
except Exception as e:
|
||||
print("ERROR", key, e)
|
||||
elif patch_type == "glora":
|
||||
if v[4] is not None:
|
||||
alpha *= v[4] / v[0].shape[0]
|
||||
|
||||
a1 = ldm_patched.modules.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, torch.float32)
|
||||
a2 = ldm_patched.modules.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, torch.float32)
|
||||
b1 = ldm_patched.modules.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, torch.float32)
|
||||
b2 = ldm_patched.modules.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, torch.float32)
|
||||
|
||||
weight += ((torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)) * alpha).reshape(weight.shape).type(weight.dtype)
|
||||
elif patch_type in extra_weight_calculators:
|
||||
weight = extra_weight_calculators[patch_type](weight, alpha, v)
|
||||
else:
|
||||
print("patch type not recognized", patch_type, key)
|
||||
|
||||
return weight
|
||||
|
||||
def unpatch_model(self, device_to=None):
|
||||
keys = list(self.backup.keys())
|
||||
|
||||
if self.weight_inplace_update:
|
||||
for k in keys:
|
||||
ldm_patched.modules.utils.copy_to_param(self.model, k, self.backup[k])
|
||||
else:
|
||||
for k in keys:
|
||||
ldm_patched.modules.utils.set_attr(self.model, k, self.backup[k])
|
||||
|
||||
self.backup = {}
|
||||
|
||||
if device_to is not None:
|
||||
self.model.to(device_to)
|
||||
self.current_device = device_to
|
||||
|
||||
keys = list(self.object_patches_backup.keys())
|
||||
for k in keys:
|
||||
ldm_patched.modules.utils.set_attr_raw(self.model, k, self.object_patches_backup[k])
|
||||
|
||||
self.object_patches_backup = {}
|
||||
from backend.patcher.base import *
|
||||
|
||||
@@ -25,7 +25,6 @@ class UnetPatcher(ModelPatcher):
|
||||
|
||||
n.object_patches = self.object_patches.copy()
|
||||
n.model_options = copy.deepcopy(self.model_options)
|
||||
n.model_keys = self.model_keys
|
||||
n.controlnet_linked_list = self.controlnet_linked_list
|
||||
n.extra_preserved_memory_during_sampling = self.extra_preserved_memory_during_sampling
|
||||
n.extra_model_patchers_during_sampling = self.extra_model_patchers_during_sampling.copy()
|
||||
|
||||
Reference in New Issue
Block a user