mirror of https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-01-26 19:09:45 +00:00
some extra lora support, inc. new glora (#2715)
* support new glora (via ComfyUI)
* support BFL FluxTools loras (mostly via ComfyUI)
* also support using loras (like Hyper, Turbo) with FluxTools models
@@ -129,10 +129,15 @@ def merge_lora_to_weight(patches, weight, key="online_lora", computation_dtype=t
                     print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
                 else:
                     weight += strength * memory_management.cast_to_device(w1, weight.device, weight.dtype)
 
+        elif patch_type == "set":
+            weight.copy_(v[0])
+
         elif patch_type == "lora":
             mat1 = memory_management.cast_to_device(v[0], weight.device, computation_dtype)
             mat2 = memory_management.cast_to_device(v[1], weight.device, computation_dtype)
             dora_scale = v[4]
 
             if v[2] is not None:
                 alpha = v[2] / mat2.shape[0]
             else:
@@ -142,12 +147,26 @@ def merge_lora_to_weight(patches, weight, key="online_lora", computation_dtype=t
                 mat3 = memory_management.cast_to_device(v[3], weight.device, computation_dtype)
                 final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
                 mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
 
             try:
-                lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
+                lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))
+
+                try:
+                    lora_diff = lora_diff.reshape(weight.shape)
+                except:
+                    if weight.shape[1] < lora_diff.shape[1]:
+                        expand_factor = (lora_diff.shape[1] - weight.shape[1])
+                        weight = torch.nn.functional.pad(weight, (0, expand_factor), mode='constant', value=0)
+                    elif weight.shape[1] > lora_diff.shape[1]:
+                        # expand factor should be 1*64 (for FluxTools Canny or Depth), or 5*64 (for FluxTools Fill)
+                        expand_factor = (weight.shape[1] - lora_diff.shape[1])
+                        lora_diff = torch.nn.functional.pad(lora_diff, (0, expand_factor), mode='constant', value=0)
+
                 if dora_scale is not None:
                     weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype))
                 else:
                     weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
 
             except Exception as e:
                 print("ERROR {} {} {}".format(patch_type, key, e))
                 raise e
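For context on the padding branch above: FluxTools checkpoints (Canny, Depth, Fill) widen an input projection, typically img_in, on its input-channel axis, so a lora_diff computed from a LoRA trained on base Flux can be narrower than the weight it is merged into, and vice versa. A minimal standalone sketch of the same idea, with shapes that are assumptions for illustration only (3072 output features, 64 vs. 64 + 5*64 input channels), not values taken from this diff:

import torch

# toy tensors only; shapes are illustrative assumptions
weight = torch.zeros(3072, 384)      # e.g. a FluxTools Fill-style projection (64 + 5*64 inputs)
lora_diff = torch.randn(3072, 64)    # delta from a LoRA trained against the 64-channel base layer

if weight.shape[1] > lora_diff.shape[1]:
    # zero-pad the delta on the input axis so it only touches the original channels
    expand_factor = weight.shape[1] - lora_diff.shape[1]
    lora_diff = torch.nn.functional.pad(lora_diff, (0, expand_factor), mode='constant', value=0)

merged = weight + 1.0 * lora_diff    # strength/alpha scaling omitted
print(merged.shape)                  # torch.Size([3072, 384])

The converse branch in the diff (weight narrower than lora_diff) appears to cover the opposite direction, padding the model weight itself so a wider lora can still be merged.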
@@ -236,23 +255,45 @@ def merge_lora_to_weight(patches, weight, key="online_lora", computation_dtype=t
             except Exception as e:
                 print("ERROR {} {} {}".format(patch_type, key, e))
                 raise e
-        elif patch_type == "glora":
-            if v[4] is not None:
-                alpha = v[4] / v[0].shape[0]
-            else:
-                alpha = 1.0
 
+        elif patch_type == "glora":
             dora_scale = v[5]
 
+            old_glora = False
+            if v[3].shape[1] == v[2].shape[0] == v[0].shape[0] == v[1].shape[1]:
+                old_glora = True
+
+            if v[3].shape[0] == v[2].shape[1] == v[0].shape[1] == v[1].shape[0]:
+                if old_glora and v[1].shape[0] == weight.shape[0] and weight.shape[0] == weight.shape[1]:
+                    pass
+                else:
+                    old_glora = False
+
             a1 = memory_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, computation_dtype)
             a2 = memory_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, computation_dtype)
             b1 = memory_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, computation_dtype)
             b2 = memory_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, computation_dtype)
 
+            if v[4] is None:
+                alpha = 1.0
+            else:
+                if old_glora:
+                    alpha = v[4] / v[0].shape[0]
+                else:
+                    alpha = v[4] / v[1].shape[0]
+
             try:
-                lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)).reshape(weight.shape)
+                if old_glora:
+                    lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=computation_dtype), a2), a1)).reshape(weight.shape) #old lycoris glora
+                else:
+                    if weight.dim() > 2:
+                        lora_diff = torch.einsum("o i ..., i j -> o j ...", torch.einsum("o i ..., i j -> o j ...", weight.to(dtype=computation_dtype), a1), a2).reshape(weight.shape)
+                    else:
+                        lora_diff = torch.mm(torch.mm(weight.to(dtype=computation_dtype), a1), a2).reshape(weight.shape)
+                    lora_diff += torch.mm(b1, b2).reshape(weight.shape)
 
                 if dora_scale is not None:
-                    weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype))
+                    weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype, function)
                 else:
                     weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
             except Exception as e:
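The old_glora flag above distinguishes the original LyCORIS GLoRA layout, where the delta is b2·b1 + (W·a2)·a1 on flattened weights, from the newer layout, where it is (W·a1)·a2 + b1·b2. A rough numeric sketch of the newer 2D path, with arbitrary toy shapes that are not taken from any real model:

import torch

W  = torch.randn(8, 6)    # weight (out=8, in=6); toy values
a1 = torch.randn(6, 4)    # v[0]
a2 = torch.randn(4, 6)    # v[1]
b1 = torch.randn(8, 4)    # v[2]
b2 = torch.randn(4, 6)    # v[3]

# newer GLoRA layout, 2D case (mirrors the weight.dim() <= 2 branch above)
lora_diff = torch.mm(torch.mm(W, a1), a2).reshape(W.shape)
lora_diff += torch.mm(b1, b2).reshape(W.shape)
print(lora_diff.shape)    # torch.Size([8, 6]) -- same shape as W, ready to be scaled and added

With these shapes the checks above would leave old_glora as False, so alpha would be derived from v[1].shape[0] (the rank, 4 here).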
@@ -299,7 +340,7 @@ class LoraLoader:
         self.loaded_hash = str([])
 
     @torch.inference_mode()
-    def refresh(self, lora_patches, offload_device=torch.device('cpu'), force_refresh = False):
+    def refresh(self, lora_patches, offload_device=torch.device('cpu'), force_refresh=False):
         hashes = str(list(lora_patches.keys()))
 
         if hashes == self.loaded_hash and not force_refresh:
@@ -30,6 +30,19 @@ LORA_CLIP_MAP = {
 
 
 def load_lora(lora, to_load):
+    # BFL loras for Flux; from ComfyUI: comfy/lora_convert.py
+    def convert_lora_bfl_control(sd):
+        import torch
+        sd_out = {}
+        for k in sd:
+            k_to = "diffusion_model.{}".format(k.replace(".lora_B.bias", ".diff_b").replace("_norm.scale", "_norm.scale.set_weight"))
+            sd_out[k_to] = sd[k]
+
+        return sd_out
+
+    if "img_in.lora_A.weight" in lora and "single_blocks.0.norm.key_norm.scale" in lora:
+        lora = convert_lora_bfl_control(lora)
+
     patch_dict = {}
     loaded_keys = set()
     for x in to_load:
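To illustrate what the nested convert_lora_bfl_control helper does to key names, here is a self-contained restatement of its rename logic on a few illustrative BFL-style keys (only the suffix handling matters; the tensors themselves are untouched):

keys = [
    "double_blocks.0.img_attn.qkv.lora_A.weight",
    "double_blocks.0.img_attn.qkv.lora_B.bias",
    "single_blocks.0.norm.key_norm.scale",
]
for k in keys:
    k_to = "diffusion_model.{}".format(
        k.replace(".lora_B.bias", ".diff_b").replace("_norm.scale", "_norm.scale.set_weight"))
    print(k, "->", k_to)
# double_blocks.0.img_attn.qkv.lora_A.weight  -> diffusion_model.double_blocks.0.img_attn.qkv.lora_A.weight
# double_blocks.0.img_attn.qkv.lora_B.bias    -> diffusion_model.double_blocks.0.img_attn.qkv.diff_b
# single_blocks.0.norm.key_norm.scale         -> diffusion_model.single_blocks.0.norm.key_norm.scale.set_weight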
@@ -189,6 +202,12 @@ def load_lora(lora, to_load):
             patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (diff_bias,))
             loaded_keys.add(diff_bias_name)
 
+        set_weight_name = "{}.set_weight".format(x)
+        set_weight = lora.get(set_weight_name, None)
+        if set_weight is not None:
+            patch_dict[to_load[x]] = ("set", (set_weight,))
+            loaded_keys.add(set_weight_name)
+
     remaining_dict = {x: y for x, y in lora.items() if x not in loaded_keys}
     return patch_dict, remaining_dict
 
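These ("set", ...) patches pair with the new elif patch_type == "set" branch in merge_lora_to_weight: instead of adding a scaled delta to the weight, the stored tensor overwrites it via weight.copy_(), which is what the _norm.scale.set_weight keys produced by the BFL conversion call for. A tiny sketch of the two behaviours, detached from the loader:

import torch

weight = torch.ones(4)

# "diff"-style patch: add a (strength-scaled) delta
weight += 1.0 * torch.full((4,), 0.25)
print(weight)            # tensor([1.2500, 1.2500, 1.2500, 1.2500])

# "set"-style patch: replace the values outright
weight.copy_(torch.full((4,), 0.5))
print(weight)            # tensor([0.5000, 0.5000, 0.5000, 0.5000])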
@@ -269,11 +288,13 @@ def model_lora_keys_unet(model, key_map={}):
     sdk = sd.keys()
 
     for k in sdk:
-        if k.startswith("diffusion_model.") and k.endswith(".weight"):
-            key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
-            key_map["lora_unet_{}".format(key_lora)] = k
-            key_map["lora_prior_unet_{}".format(key_lora)] = k #cascade lora: TODO put lora key prefix in the model config
-            key_map["{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names
+        if k.startswith("diffusion_model."):
+            if k.endswith(".weight"):
+                key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
+                key_map["lora_unet_{}".format(key_lora)] = k
+                key_map["{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names
+            else:
+                key_map["{}".format(k)] = k #generic lora format for not .weight without any weird key names
 
     diffusers_keys = utils.unet_to_diffusers(model.diffusion_model.config)
     for k in diffusers_keys:
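The reworked loop differs from the old one in two ways: the lora_prior_unet (Cascade) alias is dropped here, and keys that do not end in .weight now get a passthrough mapping, so that set_weight and diff_b patches targeting norm scales and biases can resolve. A small standalone run of the same mapping logic on illustrative state-dict keys:

key_map = {}
sdk = [
    "diffusion_model.double_blocks.0.img_attn.qkv.weight",
    "diffusion_model.single_blocks.0.norm.key_norm.scale",
]
for k in sdk:
    if k.startswith("diffusion_model."):
        if k.endswith(".weight"):
            key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
            key_map["lora_unet_{}".format(key_lora)] = k
            key_map[k[:-len(".weight")]] = k
        else:
            key_map[k] = k

# key_map now holds, e.g.:
#   "lora_unet_double_blocks_0_img_attn_qkv"              -> "...img_attn.qkv.weight"
#   "diffusion_model.double_blocks.0.img_attn.qkv"        -> "...img_attn.qkv.weight"
#   "diffusion_model.single_blocks.0.norm.key_norm.scale" -> itself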
@@ -281,7 +302,8 @@ def model_lora_keys_unet(model, key_map={}):
             unet_key = "diffusion_model.{}".format(diffusers_keys[k])
             key_lora = k[:-len(".weight")].replace(".", "_")
             key_map["lora_unet_{}".format(key_lora)] = unet_key
+            key_map["lycoris_{}".format(key_lora)] = unet_key #simpletuner lycoris format
 
             diffusers_lora_prefix = ["", "unet."]
             for p in diffusers_lora_prefix:
                 diffusers_lora_key = "{}{}".format(p, k[:-len(".weight")].replace(".to_", ".processor.to_"))
@@ -289,19 +311,19 @@ def model_lora_keys_unet(model, key_map={}):
                     diffusers_lora_key = diffusers_lora_key[:-2]
                 key_map[diffusers_lora_key] = unet_key
 
-    # if isinstance(model, comfy.model_base.SD3): #Diffusers lora SD3
-    #     diffusers_keys = utils.mmdit_to_diffusers(model.diffusion_model.config, output_prefix="diffusion_model.")
-    #     for k in diffusers_keys:
-    #         if k.endswith(".weight"):
-    #             to = diffusers_keys[k]
-    #             key_lora = "transformer.{}".format(k[:-len(".weight")]) #regular diffusers sd3 lora format
-    #             key_map[key_lora] = to
-    #
-    #             key_lora = "base_model.model.{}".format(k[:-len(".weight")]) #format for flash-sd3 lora and others?
-    #             key_map[key_lora] = to
-    #
-    #             key_lora = "lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_")) #OneTrainer lora
-    #             key_map[key_lora] = to
+    # if 'stable-diffusion-3' in model.config.huggingface_repo.lower(): #Diffusers lora SD3
+    #     diffusers_keys = utils.mmdit_to_diffusers(model.diffusion_model.config, output_prefix="diffusion_model.")
+    #     for k in diffusers_keys:
+    #         if k.endswith(".weight"):
+    #             to = diffusers_keys[k]
+    #             key_lora = "transformer.{}".format(k[:-len(".weight")]) #regular diffusers sd3 lora format
+    #             key_map[key_lora] = to
+
+    #             key_lora = "base_model.model.{}".format(k[:-len(".weight")]) #format for flash-sd3 lora and others?
+    #             key_map[key_lora] = to
+
+    #             key_lora = "lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_")) #OneTrainer lora
+    #             key_map[key_lora] = to
     #
     # if isinstance(model, comfy.model_base.AuraFlow): #Diffusers lora AuraFlow
     #     diffusers_keys = utils.auraflow_to_diffusers(model.diffusion_model.config, output_prefix="diffusion_model.")