Mirror of https://github.com/lllyasviel/stable-diffusion-webui-forge.git (synced 2026-02-28 02:33:57 +00:00)

Commit: multiple lora implementation sources
@@ -1,13 +1,15 @@
 import torch

-from backend import memory_management, attention, operations
+from backend import memory_management, attention
 from backend.modules.k_prediction import k_prediction_from_diffusers_scheduler


 class KModel(torch.nn.Module):
-    def __init__(self, model, diffusers_scheduler, k_predictor=None):
+    def __init__(self, model, diffusers_scheduler, k_predictor=None, config=None):
         super().__init__()

+        self.config = config
+
         self.storage_dtype = model.storage_dtype
         self.computation_dtype = model.computation_dtype

@@ -1,293 +1,31 @@
# LoRA Implementation Collection from ComfyUI
# Modified by Forge to support greedy loading (load a set of wrong/correct loras to a model and only preserve the correct ones),
# which is important to the webui experience

from backend.misc.diffusers_state_dict import unet_to_diffusers
import packages_3rdparty.webui_lora_collection.lora as lora_utils_webui
import packages_3rdparty.comfyui_lora_collection.lora as lora_utils_comfyui


LORA_CLIP_MAP = {
    "mlp.fc1": "mlp_fc1",
    "mlp.fc2": "mlp_fc2",
    "self_attn.k_proj": "self_attn_k_proj",
    "self_attn.q_proj": "self_attn_q_proj",
    "self_attn.v_proj": "self_attn_v_proj",
    "self_attn.out_proj": "self_attn_out_proj",
}

class ForgeLoraCollection:
    # TODO
    pass


lora_utils_forge = ForgeLoraCollection()

lora_collection_priority = [lora_utils_forge, lora_utils_webui, lora_utils_comfyui]


def get_function(function_name: str):
    for lora_collection in lora_collection_priority:
        if hasattr(lora_collection, function_name):
            return getattr(lora_collection, function_name)
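The three collections above are probed in priority order; a minimal, self-contained sketch of that dispatch pattern (the class names below are illustrative stand-ins, not the real collections):

class _WebuiLikeCollection:
    @staticmethod
    def load_lora(lora, to_load):
        return ({"demo.weight": ("lora", None)}, {})


class _ComfyLikeCollection:
    @staticmethod
    def load_lora(lora, to_load):
        return ({}, dict(lora))

    @staticmethod
    def model_lora_keys_unet(model, key_map={}):
        return key_map


_demo_priority = [object(), _WebuiLikeCollection(), _ComfyLikeCollection()]  # forge stub first, like above


def _demo_get_function(function_name: str):
    # Same lookup as get_function above: the first collection that defines the name wins.
    for collection in _demo_priority:
        if hasattr(collection, function_name):
            return getattr(collection, function_name)


# load_lora resolves to the webui-like collection; model_lora_keys_unet falls through
# to the comfyui-like one because nothing earlier in the list defines it.
print(_demo_get_function("load_lora")({}, {}))
print(_demo_get_function("model_lora_keys_unet")(None, {}))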
def load_lora(lora, to_load):
|
||||
patch_dict = {}
|
||||
loaded_keys = set()
|
||||
for x in to_load:
|
||||
alpha_name = "{}.alpha".format(x)
|
||||
alpha = None
|
||||
if alpha_name in lora.keys():
|
||||
alpha = lora[alpha_name].item()
|
||||
loaded_keys.add(alpha_name)
|
||||
|
||||
dora_scale_name = "{}.dora_scale".format(x)
|
||||
dora_scale = None
|
||||
if dora_scale_name in lora.keys():
|
||||
dora_scale = lora[dora_scale_name]
|
||||
loaded_keys.add(dora_scale_name)
|
||||
|
||||
regular_lora = "{}.lora_up.weight".format(x)
|
||||
diffusers_lora = "{}_lora.up.weight".format(x)
|
||||
diffusers2_lora = "{}.lora_B.weight".format(x)
|
||||
diffusers3_lora = "{}.lora.up.weight".format(x)
|
||||
transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
|
||||
A_name = None
|
||||
|
||||
if regular_lora in lora.keys():
|
||||
A_name = regular_lora
|
||||
B_name = "{}.lora_down.weight".format(x)
|
||||
mid_name = "{}.lora_mid.weight".format(x)
|
||||
elif diffusers_lora in lora.keys():
|
||||
A_name = diffusers_lora
|
||||
B_name = "{}_lora.down.weight".format(x)
|
||||
mid_name = None
|
||||
elif diffusers2_lora in lora.keys():
|
||||
A_name = diffusers2_lora
|
||||
B_name = "{}.lora_A.weight".format(x)
|
||||
mid_name = None
|
||||
elif diffusers3_lora in lora.keys():
|
||||
A_name = diffusers3_lora
|
||||
B_name = "{}.lora.down.weight".format(x)
|
||||
mid_name = None
|
||||
elif transformers_lora in lora.keys():
|
||||
A_name = transformers_lora
|
||||
B_name = "{}.lora_linear_layer.down.weight".format(x)
|
||||
mid_name = None
|
||||
|
||||
if A_name is not None:
|
||||
mid = None
|
||||
if mid_name is not None and mid_name in lora.keys():
|
||||
mid = lora[mid_name]
|
||||
loaded_keys.add(mid_name)
|
||||
patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid, dora_scale))
|
||||
loaded_keys.add(A_name)
|
||||
loaded_keys.add(B_name)
|
||||
|
||||
######## loha
|
||||
hada_w1_a_name = "{}.hada_w1_a".format(x)
|
||||
hada_w1_b_name = "{}.hada_w1_b".format(x)
|
||||
hada_w2_a_name = "{}.hada_w2_a".format(x)
|
||||
hada_w2_b_name = "{}.hada_w2_b".format(x)
|
||||
hada_t1_name = "{}.hada_t1".format(x)
|
||||
hada_t2_name = "{}.hada_t2".format(x)
|
||||
if hada_w1_a_name in lora.keys():
|
||||
hada_t1 = None
|
||||
hada_t2 = None
|
||||
if hada_t1_name in lora.keys():
|
||||
hada_t1 = lora[hada_t1_name]
|
||||
hada_t2 = lora[hada_t2_name]
|
||||
loaded_keys.add(hada_t1_name)
|
||||
loaded_keys.add(hada_t2_name)
|
||||
|
||||
patch_dict[to_load[x]] = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2, dora_scale))
|
||||
loaded_keys.add(hada_w1_a_name)
|
||||
loaded_keys.add(hada_w1_b_name)
|
||||
loaded_keys.add(hada_w2_a_name)
|
||||
loaded_keys.add(hada_w2_b_name)
|
||||
|
||||
######## lokr
|
||||
lokr_w1_name = "{}.lokr_w1".format(x)
|
||||
lokr_w2_name = "{}.lokr_w2".format(x)
|
||||
lokr_w1_a_name = "{}.lokr_w1_a".format(x)
|
||||
lokr_w1_b_name = "{}.lokr_w1_b".format(x)
|
||||
lokr_t2_name = "{}.lokr_t2".format(x)
|
||||
lokr_w2_a_name = "{}.lokr_w2_a".format(x)
|
||||
lokr_w2_b_name = "{}.lokr_w2_b".format(x)
|
||||
|
||||
lokr_w1 = None
|
||||
if lokr_w1_name in lora.keys():
|
||||
lokr_w1 = lora[lokr_w1_name]
|
||||
loaded_keys.add(lokr_w1_name)
|
||||
|
||||
lokr_w2 = None
|
||||
if lokr_w2_name in lora.keys():
|
||||
lokr_w2 = lora[lokr_w2_name]
|
||||
loaded_keys.add(lokr_w2_name)
|
||||
|
||||
lokr_w1_a = None
|
||||
if lokr_w1_a_name in lora.keys():
|
||||
lokr_w1_a = lora[lokr_w1_a_name]
|
||||
loaded_keys.add(lokr_w1_a_name)
|
||||
|
||||
lokr_w1_b = None
|
||||
if lokr_w1_b_name in lora.keys():
|
||||
lokr_w1_b = lora[lokr_w1_b_name]
|
||||
loaded_keys.add(lokr_w1_b_name)
|
||||
|
||||
lokr_w2_a = None
|
||||
if lokr_w2_a_name in lora.keys():
|
||||
lokr_w2_a = lora[lokr_w2_a_name]
|
||||
loaded_keys.add(lokr_w2_a_name)
|
||||
|
||||
lokr_w2_b = None
|
||||
if lokr_w2_b_name in lora.keys():
|
||||
lokr_w2_b = lora[lokr_w2_b_name]
|
||||
loaded_keys.add(lokr_w2_b_name)
|
||||
|
||||
lokr_t2 = None
|
||||
if lokr_t2_name in lora.keys():
|
||||
lokr_t2 = lora[lokr_t2_name]
|
||||
loaded_keys.add(lokr_t2_name)
|
||||
|
||||
if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
|
||||
patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale))
|
||||
|
||||
# glora
|
||||
a1_name = "{}.a1.weight".format(x)
|
||||
a2_name = "{}.a2.weight".format(x)
|
||||
b1_name = "{}.b1.weight".format(x)
|
||||
b2_name = "{}.b2.weight".format(x)
|
||||
if a1_name in lora:
|
||||
patch_dict[to_load[x]] = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha, dora_scale))
|
||||
loaded_keys.add(a1_name)
|
||||
loaded_keys.add(a2_name)
|
||||
loaded_keys.add(b1_name)
|
||||
loaded_keys.add(b2_name)
|
||||
|
||||
w_norm_name = "{}.w_norm".format(x)
|
||||
b_norm_name = "{}.b_norm".format(x)
|
||||
w_norm = lora.get(w_norm_name, None)
|
||||
b_norm = lora.get(b_norm_name, None)
|
||||
|
||||
if w_norm is not None:
|
||||
loaded_keys.add(w_norm_name)
|
||||
patch_dict[to_load[x]] = ("diff", (w_norm,))
|
||||
if b_norm is not None:
|
||||
loaded_keys.add(b_norm_name)
|
||||
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (b_norm,))
|
||||
|
||||
diff_name = "{}.diff".format(x)
|
||||
diff_weight = lora.get(diff_name, None)
|
||||
if diff_weight is not None:
|
||||
patch_dict[to_load[x]] = ("diff", (diff_weight,))
|
||||
loaded_keys.add(diff_name)
|
||||
|
||||
diff_bias_name = "{}.diff_b".format(x)
|
||||
diff_bias = lora.get(diff_bias_name, None)
|
||||
if diff_bias is not None:
|
||||
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (diff_bias,))
|
||||
loaded_keys.add(diff_bias_name)
|
||||
|
||||
remaining_dict = {x: y for x, y in lora.items() if x not in loaded_keys}
|
||||
patch_dict, remaining_dict = get_function('load_lora')(lora, to_load)
|
||||
return patch_dict, remaining_dict
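For context, a hedged sketch (not Forge's actual patch-application code) of how a ("lora", (up, down, alpha, mid, dora_scale)) entry from patch_dict is conventionally merged into a weight, using the standard scale = alpha / rank rule:

import torch

def apply_lora_delta(weight, up, down, alpha=None, strength=1.0):
    # Standard LoRA merge: W' = W + strength * (alpha / rank) * (up @ down).
    rank = down.shape[0]
    scale = (alpha / rank) if alpha is not None else 1.0
    delta = torch.mm(up.reshape(up.shape[0], -1), down.reshape(down.shape[0], -1))
    return weight + strength * scale * delta.reshape(weight.shape).to(weight.dtype)

demo_w = torch.zeros(16, 16)
demo_up, demo_down = torch.randn(16, 4), torch.randn(4, 16)
print(apply_lora_delta(demo_w, demo_up, demo_down, alpha=4.0).shape)  # torch.Size([16, 16])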
|
||||
|
||||
|
||||
def model_lora_keys_clip(model, key_map={}):
|
||||
sdk = model.state_dict().keys()
|
||||
|
||||
text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
|
||||
clip_l_present = False
|
||||
for b in range(32): # TODO: clean up
|
||||
for c in LORA_CLIP_MAP:
|
||||
k = "clip_h.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
|
||||
if k in sdk:
|
||||
lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c])
|
||||
key_map[lora_key] = k
|
||||
lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c])
|
||||
key_map[lora_key] = k
|
||||
lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c)
|
||||
key_map[lora_key] = k
|
||||
|
||||
k = "clip_l.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
|
||||
if k in sdk:
|
||||
lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c])
|
||||
key_map[lora_key] = k
|
||||
lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c])
|
||||
key_map[lora_key] = k
|
||||
clip_l_present = True
|
||||
lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c)
|
||||
key_map[lora_key] = k
|
||||
|
||||
k = "clip_g.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
|
||||
if k in sdk:
|
||||
if clip_l_present:
|
||||
lora_key = "lora_te2_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c])
|
||||
key_map[lora_key] = k
|
||||
lora_key = "text_encoder_2.text_model.encoder.layers.{}.{}".format(b, c)
|
||||
key_map[lora_key] = k
|
||||
else:
|
||||
lora_key = "lora_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c])
|
||||
key_map[lora_key] = k
|
||||
lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c)
|
||||
key_map[lora_key] = k
|
||||
lora_key = "lora_prior_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c])
|
||||
key_map[lora_key] = k
|
||||
|
||||
for k in sdk: # OneTrainer SD3 lora
|
||||
if k.startswith("t5xxl.transformer.") and k.endswith(".weight"):
|
||||
l_key = k[len("t5xxl.transformer."):-len(".weight")]
|
||||
lora_key = "lora_te3_{}".format(l_key.replace(".", "_"))
|
||||
key_map[lora_key] = k
|
||||
|
||||
k = "clip_g.transformer.text_projection.weight"
|
||||
if k in sdk:
|
||||
key_map["lora_prior_te_text_projection"] = k
|
||||
key_map["lora_te2_text_projection"] = k
|
||||
|
||||
k = "clip_l.transformer.text_projection.weight"
|
||||
if k in sdk:
|
||||
key_map["lora_te1_text_projection"] = k
|
||||
|
||||
return key_map
|
||||
return get_function('model_lora_keys_clip')(model, key_map)
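A small self-contained sketch of the mapping this function produces: kohya-style "lora_te*" keys and diffusers-style keys both point at the same CLIP state_dict entry (the layer count and prefix here are illustrative):

DEMO_CLIP_MAP = {"mlp.fc1": "mlp_fc1", "self_attn.q_proj": "self_attn_q_proj"}

def demo_clip_key_map(num_layers=2, prefix="clip_l"):
    key_map = {}
    for layer in range(num_layers):
        for module, flat in DEMO_CLIP_MAP.items():
            target = "{}.transformer.text_model.encoder.layers.{}.{}.weight".format(prefix, layer, module)
            key_map["lora_te1_text_model_encoder_layers_{}_{}".format(layer, flat)] = target        # kohya / SDXL style
            key_map["text_encoder.text_model.encoder.layers.{}.{}".format(layer, module)] = target  # diffusers style
    return key_map

for lora_key, sd_key in list(demo_clip_key_map().items())[:2]:
    print(lora_key, "->", sd_key)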
|
||||
|
||||
|
||||
def model_lora_keys_unet(model, key_map={}):
|
||||
sd = model.state_dict()
|
||||
sdk = sd.keys()
|
||||
|
||||
for k in sdk:
|
||||
if k.startswith("diffusion_model.") and k.endswith(".weight"):
|
||||
key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
|
||||
key_map["lora_unet_{}".format(key_lora)] = k
|
||||
key_map["lora_prior_unet_{}".format(key_lora)] = k
|
||||
|
||||
diffusers_keys = unet_to_diffusers(model.diffusion_model.config)
|
||||
for k in diffusers_keys:
|
||||
if k.endswith(".weight"):
|
||||
unet_key = "diffusion_model.{}".format(diffusers_keys[k])
|
||||
key_lora = k[:-len(".weight")].replace(".", "_")
|
||||
key_map["lora_unet_{}".format(key_lora)] = unet_key
|
||||
|
||||
diffusers_lora_prefix = ["", "unet."]
|
||||
for p in diffusers_lora_prefix:
|
||||
diffusers_lora_key = "{}{}".format(p, k[:-len(".weight")].replace(".to_", ".processor.to_"))
|
||||
if diffusers_lora_key.endswith(".to_out.0"):
|
||||
diffusers_lora_key = diffusers_lora_key[:-2]
|
||||
key_map[diffusers_lora_key] = unet_key
|
||||
|
||||
# TODO:
|
||||
|
||||
# if isinstance(model, xxxx.modules.model_base.SD3): # Diffusers lora SD3
|
||||
# diffusers_keys = xxxx.modules.utils.mmdit_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
|
||||
# for k in diffusers_keys:
|
||||
# if k.endswith(".weight"):
|
||||
# to = diffusers_keys[k]
|
||||
# key_lora = "transformer.{}".format(k[:-len(".weight")]) # regular diffusers sd3 lora format
|
||||
# key_map[key_lora] = to
|
||||
#
|
||||
# key_lora = "base_model.model.{}".format(k[:-len(".weight")]) # format for flash-sd3 lora and others?
|
||||
# key_map[key_lora] = to
|
||||
#
|
||||
# key_lora = "lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_")) # OneTrainer lora
|
||||
# key_map[key_lora] = to
|
||||
#
|
||||
# if isinstance(model, xxxx.modules.model_base.AuraFlow): # Diffusers lora AuraFlow
|
||||
# diffusers_keys = xxxx.modules.utils.auraflow_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
|
||||
# for k in diffusers_keys:
|
||||
# if k.endswith(".weight"):
|
||||
# to = diffusers_keys[k]
|
||||
# key_lora = "transformer.{}".format(k[:-len(".weight")]) # simpletrainer and probably regular diffusers lora format
|
||||
# key_map[key_lora] = to
|
||||
#
|
||||
# if isinstance(model, xxxx.modules.model_base.HunyuanDiT):
|
||||
# for k in sdk:
|
||||
# if k.startswith("diffusion_model.") and k.endswith(".weight"):
|
||||
# key_lora = k[len("diffusion_model."):-len(".weight")]
|
||||
# key_map["base_model.model.{}".format(key_lora)] = k # official hunyuan lora format
|
||||
|
||||
return key_map
|
||||
return get_function('model_lora_keys_unet')(model, key_map)
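The "lora_unet_*" naming rule used by the underlying implementations is simple enough to show directly; a sketch:

def demo_unet_lora_key(state_dict_key):
    # Strip "diffusion_model." and ".weight", then turn the remaining dots into underscores.
    core = state_dict_key[len("diffusion_model."):-len(".weight")]
    return "lora_unet_{}".format(core.replace(".", "_"))

print(demo_unet_lora_key("diffusion_model.input_blocks.1.1.proj_in.weight"))
# -> lora_unet_input_blocks_1_1_proj_in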
|
||||
|
||||
@@ -8,7 +8,7 @@ from backend.patcher.base import ModelPatcher
 class UnetPatcher(ModelPatcher):
     @classmethod
     def from_model(cls, model, diffusers_scheduler, config, k_predictor=None):
-        model = KModel(model=model, diffusers_scheduler=diffusers_scheduler, k_predictor=k_predictor)
+        model = KModel(model=model, diffusers_scheduler=diffusers_scheduler, k_predictor=k_predictor, config=config)
         return UnetPatcher(
             model,
             load_device=model.diffusion_model.load_device,
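The only change here is that from_model now forwards config into KModel; a hedged illustration of why that matters downstream, where key-mapping code can branch on model metadata such as config.huggingface_repo (as the vendored comfyui collection does for Flux). The objects below are stand-ins:

from types import SimpleNamespace

demo_model = SimpleNamespace(config=SimpleNamespace(huggingface_repo="some-org/flux-variant"))  # illustrative repo name
if "flux" in demo_model.config.huggingface_repo.lower():
    print("would build the Flux lora key map via flux_to_diffusers")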
packages_3rdparty/README.md (new vendored file, 1 line)
@@ -0,0 +1 @@
Please follow the standard of https://github.com/opencv/opencv/tree/315f85d4f484c1e2fa043c73ac3fdd9fc5997ee7/3rdparty when submitting PRs or modifying files.

packages_3rdparty/comfyui_lora_collection/lora.py (new vendored file, 323 lines)
@@ -0,0 +1,323 @@
|
||||
"""
|
||||
This file is part of ComfyUI.
|
||||
Copyright (C) 2024 Comfy
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
|
||||
from packages_3rdparty.comfyui_lora_collection import utils
|
||||
|
||||
|
||||
LORA_CLIP_MAP = {
|
||||
"mlp.fc1": "mlp_fc1",
|
||||
"mlp.fc2": "mlp_fc2",
|
||||
"self_attn.k_proj": "self_attn_k_proj",
|
||||
"self_attn.q_proj": "self_attn_q_proj",
|
||||
"self_attn.v_proj": "self_attn_v_proj",
|
||||
"self_attn.out_proj": "self_attn_out_proj",
|
||||
}
|
||||
|
||||
|
||||
def load_lora(lora, to_load):
|
||||
patch_dict = {}
|
||||
loaded_keys = set()
|
||||
for x in to_load:
|
||||
alpha_name = "{}.alpha".format(x)
|
||||
alpha = None
|
||||
if alpha_name in lora.keys():
|
||||
alpha = lora[alpha_name].item()
|
||||
loaded_keys.add(alpha_name)
|
||||
|
||||
dora_scale_name = "{}.dora_scale".format(x)
|
||||
dora_scale = None
|
||||
if dora_scale_name in lora.keys():
|
||||
dora_scale = lora[dora_scale_name]
|
||||
loaded_keys.add(dora_scale_name)
|
||||
|
||||
regular_lora = "{}.lora_up.weight".format(x)
|
||||
diffusers_lora = "{}_lora.up.weight".format(x)
|
||||
diffusers2_lora = "{}.lora_B.weight".format(x)
|
||||
diffusers3_lora = "{}.lora.up.weight".format(x)
|
||||
transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
|
||||
A_name = None
|
||||
|
||||
if regular_lora in lora.keys():
|
||||
A_name = regular_lora
|
||||
B_name = "{}.lora_down.weight".format(x)
|
||||
mid_name = "{}.lora_mid.weight".format(x)
|
||||
elif diffusers_lora in lora.keys():
|
||||
A_name = diffusers_lora
|
||||
B_name = "{}_lora.down.weight".format(x)
|
||||
mid_name = None
|
||||
elif diffusers2_lora in lora.keys():
|
||||
A_name = diffusers2_lora
|
||||
B_name = "{}.lora_A.weight".format(x)
|
||||
mid_name = None
|
||||
elif diffusers3_lora in lora.keys():
|
||||
A_name = diffusers3_lora
|
||||
B_name = "{}.lora.down.weight".format(x)
|
||||
mid_name = None
|
||||
elif transformers_lora in lora.keys():
|
||||
A_name = transformers_lora
|
||||
B_name = "{}.lora_linear_layer.down.weight".format(x)
|
||||
mid_name = None
|
||||
|
||||
if A_name is not None:
|
||||
mid = None
|
||||
if mid_name is not None and mid_name in lora.keys():
|
||||
mid = lora[mid_name]
|
||||
loaded_keys.add(mid_name)
|
||||
patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid, dora_scale))
|
||||
loaded_keys.add(A_name)
|
||||
loaded_keys.add(B_name)
|
||||
|
||||
|
||||
######## loha
|
||||
hada_w1_a_name = "{}.hada_w1_a".format(x)
|
||||
hada_w1_b_name = "{}.hada_w1_b".format(x)
|
||||
hada_w2_a_name = "{}.hada_w2_a".format(x)
|
||||
hada_w2_b_name = "{}.hada_w2_b".format(x)
|
||||
hada_t1_name = "{}.hada_t1".format(x)
|
||||
hada_t2_name = "{}.hada_t2".format(x)
|
||||
if hada_w1_a_name in lora.keys():
|
||||
hada_t1 = None
|
||||
hada_t2 = None
|
||||
if hada_t1_name in lora.keys():
|
||||
hada_t1 = lora[hada_t1_name]
|
||||
hada_t2 = lora[hada_t2_name]
|
||||
loaded_keys.add(hada_t1_name)
|
||||
loaded_keys.add(hada_t2_name)
|
||||
|
||||
patch_dict[to_load[x]] = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2, dora_scale))
|
||||
loaded_keys.add(hada_w1_a_name)
|
||||
loaded_keys.add(hada_w1_b_name)
|
||||
loaded_keys.add(hada_w2_a_name)
|
||||
loaded_keys.add(hada_w2_b_name)
|
||||
|
||||
|
||||
######## lokr
|
||||
lokr_w1_name = "{}.lokr_w1".format(x)
|
||||
lokr_w2_name = "{}.lokr_w2".format(x)
|
||||
lokr_w1_a_name = "{}.lokr_w1_a".format(x)
|
||||
lokr_w1_b_name = "{}.lokr_w1_b".format(x)
|
||||
lokr_t2_name = "{}.lokr_t2".format(x)
|
||||
lokr_w2_a_name = "{}.lokr_w2_a".format(x)
|
||||
lokr_w2_b_name = "{}.lokr_w2_b".format(x)
|
||||
|
||||
lokr_w1 = None
|
||||
if lokr_w1_name in lora.keys():
|
||||
lokr_w1 = lora[lokr_w1_name]
|
||||
loaded_keys.add(lokr_w1_name)
|
||||
|
||||
lokr_w2 = None
|
||||
if lokr_w2_name in lora.keys():
|
||||
lokr_w2 = lora[lokr_w2_name]
|
||||
loaded_keys.add(lokr_w2_name)
|
||||
|
||||
lokr_w1_a = None
|
||||
if lokr_w1_a_name in lora.keys():
|
||||
lokr_w1_a = lora[lokr_w1_a_name]
|
||||
loaded_keys.add(lokr_w1_a_name)
|
||||
|
||||
lokr_w1_b = None
|
||||
if lokr_w1_b_name in lora.keys():
|
||||
lokr_w1_b = lora[lokr_w1_b_name]
|
||||
loaded_keys.add(lokr_w1_b_name)
|
||||
|
||||
lokr_w2_a = None
|
||||
if lokr_w2_a_name in lora.keys():
|
||||
lokr_w2_a = lora[lokr_w2_a_name]
|
||||
loaded_keys.add(lokr_w2_a_name)
|
||||
|
||||
lokr_w2_b = None
|
||||
if lokr_w2_b_name in lora.keys():
|
||||
lokr_w2_b = lora[lokr_w2_b_name]
|
||||
loaded_keys.add(lokr_w2_b_name)
|
||||
|
||||
lokr_t2 = None
|
||||
if lokr_t2_name in lora.keys():
|
||||
lokr_t2 = lora[lokr_t2_name]
|
||||
loaded_keys.add(lokr_t2_name)
|
||||
|
||||
if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
|
||||
patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale))
|
||||
|
||||
#glora
|
||||
a1_name = "{}.a1.weight".format(x)
|
||||
a2_name = "{}.a2.weight".format(x)
|
||||
b1_name = "{}.b1.weight".format(x)
|
||||
b2_name = "{}.b2.weight".format(x)
|
||||
if a1_name in lora:
|
||||
patch_dict[to_load[x]] = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha, dora_scale))
|
||||
loaded_keys.add(a1_name)
|
||||
loaded_keys.add(a2_name)
|
||||
loaded_keys.add(b1_name)
|
||||
loaded_keys.add(b2_name)
|
||||
|
||||
w_norm_name = "{}.w_norm".format(x)
|
||||
b_norm_name = "{}.b_norm".format(x)
|
||||
w_norm = lora.get(w_norm_name, None)
|
||||
b_norm = lora.get(b_norm_name, None)
|
||||
|
||||
if w_norm is not None:
|
||||
loaded_keys.add(w_norm_name)
|
||||
patch_dict[to_load[x]] = ("diff", (w_norm,))
|
||||
if b_norm is not None:
|
||||
loaded_keys.add(b_norm_name)
|
||||
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (b_norm,))
|
||||
|
||||
diff_name = "{}.diff".format(x)
|
||||
diff_weight = lora.get(diff_name, None)
|
||||
if diff_weight is not None:
|
||||
patch_dict[to_load[x]] = ("diff", (diff_weight,))
|
||||
loaded_keys.add(diff_name)
|
||||
|
||||
diff_bias_name = "{}.diff_b".format(x)
|
||||
diff_bias = lora.get(diff_bias_name, None)
|
||||
if diff_bias is not None:
|
||||
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (diff_bias,))
|
||||
loaded_keys.add(diff_bias_name)
|
||||
|
||||
remaining_dict = {x: y for x, y in lora.items() if x not in loaded_keys}
|
||||
return patch_dict, remaining_dict
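A hedged sketch (not the actual Forge wiring) of how the (patch_dict, remaining_dict) pair supports the "greedy loading" mentioned earlier: keys a collection cannot interpret stay in remaining_dict and can be offered to the next collection, or reported as unmatched:

def demo_greedy_load(lora_state_dict, to_load, load_functions):
    patches = {}
    remaining = dict(lora_state_dict)
    for load_fn in load_functions:           # e.g. the load_lora of each collection, in priority order
        new_patches, remaining = load_fn(remaining, to_load)
        patches.update(new_patches)
        if not remaining:
            break
    return patches, remaining                # leftover keys indicate a wrong or unsupported lora

demo_loaders = [
    lambda sd, tl: ({}, sd),                                          # first collection matches nothing
    lambda sd, tl: ({k: ("diff", (v,)) for k, v in sd.items()}, {}),  # second consumes everything
]
print(demo_greedy_load({"layer.diff": 0.5}, {}, demo_loaders))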
|
||||
|
||||
|
||||
def model_lora_keys_clip(model, key_map={}):
|
||||
sdk = model.state_dict().keys()
|
||||
|
||||
text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
|
||||
clip_l_present = False
|
||||
for b in range(32): #TODO: clean up
|
||||
for c in LORA_CLIP_MAP:
|
||||
k = "clip_h.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
|
||||
if k in sdk:
|
||||
lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c])
|
||||
key_map[lora_key] = k
|
||||
lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c])
|
||||
key_map[lora_key] = k
|
||||
lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c) #diffusers lora
|
||||
key_map[lora_key] = k
|
||||
|
||||
k = "clip_l.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
|
||||
if k in sdk:
|
||||
lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c])
|
||||
key_map[lora_key] = k
|
||||
lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base
|
||||
key_map[lora_key] = k
|
||||
clip_l_present = True
|
||||
lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c) #diffusers lora
|
||||
key_map[lora_key] = k
|
||||
|
||||
k = "clip_g.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
|
||||
if k in sdk:
|
||||
if clip_l_present:
|
||||
lora_key = "lora_te2_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base
|
||||
key_map[lora_key] = k
|
||||
lora_key = "text_encoder_2.text_model.encoder.layers.{}.{}".format(b, c) #diffusers lora
|
||||
key_map[lora_key] = k
|
||||
else:
|
||||
lora_key = "lora_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #TODO: test if this is correct for SDXL-Refiner
|
||||
key_map[lora_key] = k
|
||||
lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c) #diffusers lora
|
||||
key_map[lora_key] = k
|
||||
lora_key = "lora_prior_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #cascade lora: TODO put lora key prefix in the model config
|
||||
key_map[lora_key] = k
|
||||
|
||||
for k in sdk:
|
||||
if k.endswith(".weight"):
|
||||
if k.startswith("t5xxl.transformer."):#OneTrainer SD3 lora
|
||||
l_key = k[len("t5xxl.transformer."):-len(".weight")]
|
||||
lora_key = "lora_te3_{}".format(l_key.replace(".", "_"))
|
||||
key_map[lora_key] = k
|
||||
elif k.startswith("hydit_clip.transformer.bert."): #HunyuanDiT Lora
|
||||
l_key = k[len("hydit_clip.transformer.bert."):-len(".weight")]
|
||||
lora_key = "lora_te1_{}".format(l_key.replace(".", "_"))
|
||||
key_map[lora_key] = k
|
||||
|
||||
|
||||
k = "clip_g.transformer.text_projection.weight"
|
||||
if k in sdk:
|
||||
key_map["lora_prior_te_text_projection"] = k #cascade lora?
|
||||
# key_map["text_encoder.text_projection"] = k #TODO: check if other lora have the text_projection too
|
||||
key_map["lora_te2_text_projection"] = k #OneTrainer SD3 lora
|
||||
|
||||
k = "clip_l.transformer.text_projection.weight"
|
||||
if k in sdk:
|
||||
key_map["lora_te1_text_projection"] = k #OneTrainer SD3 lora, not necessary but omits warning
|
||||
|
||||
return key_map
|
||||
|
||||
|
||||
def model_lora_keys_unet(model, key_map={}):
|
||||
sd = model.state_dict()
|
||||
sdk = sd.keys()
|
||||
|
||||
for k in sdk:
|
||||
if k.startswith("diffusion_model.") and k.endswith(".weight"):
|
||||
key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
|
||||
key_map["lora_unet_{}".format(key_lora)] = k
|
||||
key_map["lora_prior_unet_{}".format(key_lora)] = k #cascade lora: TODO put lora key prefix in the model config
|
||||
key_map["{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names
|
||||
|
||||
diffusers_keys = utils.unet_to_diffusers(model.diffusion_model.config)
|
||||
for k in diffusers_keys:
|
||||
if k.endswith(".weight"):
|
||||
unet_key = "diffusion_model.{}".format(diffusers_keys[k])
|
||||
key_lora = k[:-len(".weight")].replace(".", "_")
|
||||
key_map["lora_unet_{}".format(key_lora)] = unet_key
|
||||
|
||||
diffusers_lora_prefix = ["", "unet."]
|
||||
for p in diffusers_lora_prefix:
|
||||
diffusers_lora_key = "{}{}".format(p, k[:-len(".weight")].replace(".to_", ".processor.to_"))
|
||||
if diffusers_lora_key.endswith(".to_out.0"):
|
||||
diffusers_lora_key = diffusers_lora_key[:-2]
|
||||
key_map[diffusers_lora_key] = unet_key
|
||||
|
||||
# if isinstance(model, comfy.model_base.SD3): #Diffusers lora SD3
|
||||
# diffusers_keys = utils.mmdit_to_diffusers(model.diffusion_model.config, output_prefix="diffusion_model.")
|
||||
# for k in diffusers_keys:
|
||||
# if k.endswith(".weight"):
|
||||
# to = diffusers_keys[k]
|
||||
# key_lora = "transformer.{}".format(k[:-len(".weight")]) #regular diffusers sd3 lora format
|
||||
# key_map[key_lora] = to
|
||||
#
|
||||
# key_lora = "base_model.model.{}".format(k[:-len(".weight")]) #format for flash-sd3 lora and others?
|
||||
# key_map[key_lora] = to
|
||||
#
|
||||
# key_lora = "lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_")) #OneTrainer lora
|
||||
# key_map[key_lora] = to
|
||||
#
|
||||
# if isinstance(model, comfy.model_base.AuraFlow): #Diffusers lora AuraFlow
|
||||
# diffusers_keys = utils.auraflow_to_diffusers(model.diffusion_model.config, output_prefix="diffusion_model.")
|
||||
# for k in diffusers_keys:
|
||||
# if k.endswith(".weight"):
|
||||
# to = diffusers_keys[k]
|
||||
# key_lora = "transformer.{}".format(k[:-len(".weight")]) #simpletrainer and probably regular diffusers lora format
|
||||
# key_map[key_lora] = to
|
||||
#
|
||||
# if isinstance(model, comfy.model_base.HunyuanDiT):
|
||||
# for k in sdk:
|
||||
# if k.startswith("diffusion_model.") and k.endswith(".weight"):
|
||||
# key_lora = k[len("diffusion_model."):-len(".weight")]
|
||||
# key_map["base_model.model.{}".format(key_lora)] = k #official hunyuan lora format
|
||||
|
||||
if 'flux' in model.config.huggingface_repo.lower(): #Diffusers lora Flux
|
||||
diffusers_keys = utils.flux_to_diffusers(model.diffusion_model.config, output_prefix="diffusion_model.")
|
||||
for k in diffusers_keys:
|
||||
if k.endswith(".weight"):
|
||||
to = diffusers_keys[k]
|
||||
key_lora = "transformer.{}".format(k[:-len(".weight")]) #simpletrainer and probably regular diffusers flux lora format
|
||||
key_map[key_lora] = to
|
||||
|
||||
return key_map
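The diffusers attention-processor rewrite applied above (".to_" becomes ".processor.to_", and a trailing ".to_out.0" is trimmed to ".to_out") in isolation, as a sketch:

def demo_diffusers_attn_key(key_without_weight, prefix="unet."):
    k = prefix + key_without_weight.replace(".to_", ".processor.to_")
    if k.endswith(".to_out.0"):
        k = k[:-2]
    return k

print(demo_diffusers_attn_key("mid_block.attentions.0.transformer_blocks.0.attn1.to_q"))
# -> unet.mid_block.attentions.0.transformer_blocks.0.attn1.processor.to_q
print(demo_diffusers_attn_key("mid_block.attentions.0.transformer_blocks.0.attn1.to_out.0"))
# -> unet.mid_block.attentions.0.transformer_blocks.0.attn1.processor.to_out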
|
||||
packages_3rdparty/comfyui_lora_collection/utils.py (new vendored file, 760 lines)
@@ -0,0 +1,760 @@
|
||||
"""
|
||||
This file is part of ComfyUI.
|
||||
Copyright (C) 2024 Comfy
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
|
||||
|
||||
import torch
|
||||
import math
|
||||
import struct
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import itertools
|
||||
|
||||
|
||||
def calculate_parameters(sd, prefix=""):
|
||||
params = 0
|
||||
for k in sd.keys():
|
||||
if k.startswith(prefix):
|
||||
w = sd[k]
|
||||
params += w.nelement()
|
||||
return params
|
||||
|
||||
def weight_dtype(sd, prefix=""):
|
||||
dtypes = {}
|
||||
for k in sd.keys():
|
||||
if k.startswith(prefix):
|
||||
w = sd[k]
|
||||
dtypes[w.dtype] = dtypes.get(w.dtype, 0) + 1
|
||||
|
||||
if len(dtypes) == 0:
|
||||
return None
|
||||
|
||||
return max(dtypes, key=dtypes.get)
|
||||
|
||||
def state_dict_key_replace(state_dict, keys_to_replace):
|
||||
for x in keys_to_replace:
|
||||
if x in state_dict:
|
||||
state_dict[keys_to_replace[x]] = state_dict.pop(x)
|
||||
return state_dict
|
||||
|
||||
def state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=False):
|
||||
if filter_keys:
|
||||
out = {}
|
||||
else:
|
||||
out = state_dict
|
||||
for rp in replace_prefix:
|
||||
replace = list(map(lambda a: (a, "{}{}".format(replace_prefix[rp], a[len(rp):])), filter(lambda a: a.startswith(rp), state_dict.keys())))
|
||||
for x in replace:
|
||||
w = state_dict.pop(x[0])
|
||||
out[x[1]] = w
|
||||
return out
|
||||
|
||||
|
||||
def transformers_convert(sd, prefix_from, prefix_to, number):
|
||||
keys_to_replace = {
|
||||
"{}positional_embedding": "{}embeddings.position_embedding.weight",
|
||||
"{}token_embedding.weight": "{}embeddings.token_embedding.weight",
|
||||
"{}ln_final.weight": "{}final_layer_norm.weight",
|
||||
"{}ln_final.bias": "{}final_layer_norm.bias",
|
||||
}
|
||||
|
||||
for k in keys_to_replace:
|
||||
x = k.format(prefix_from)
|
||||
if x in sd:
|
||||
sd[keys_to_replace[k].format(prefix_to)] = sd.pop(x)
|
||||
|
||||
resblock_to_replace = {
|
||||
"ln_1": "layer_norm1",
|
||||
"ln_2": "layer_norm2",
|
||||
"mlp.c_fc": "mlp.fc1",
|
||||
"mlp.c_proj": "mlp.fc2",
|
||||
"attn.out_proj": "self_attn.out_proj",
|
||||
}
|
||||
|
||||
for resblock in range(number):
|
||||
for x in resblock_to_replace:
|
||||
for y in ["weight", "bias"]:
|
||||
k = "{}transformer.resblocks.{}.{}.{}".format(prefix_from, resblock, x, y)
|
||||
k_to = "{}encoder.layers.{}.{}.{}".format(prefix_to, resblock, resblock_to_replace[x], y)
|
||||
if k in sd:
|
||||
sd[k_to] = sd.pop(k)
|
||||
|
||||
for y in ["weight", "bias"]:
|
||||
k_from = "{}transformer.resblocks.{}.attn.in_proj_{}".format(prefix_from, resblock, y)
|
||||
if k_from in sd:
|
||||
weights = sd.pop(k_from)
|
||||
shape_from = weights.shape[0] // 3
|
||||
for x in range(3):
|
||||
p = ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"]
|
||||
k_to = "{}encoder.layers.{}.{}.{}".format(prefix_to, resblock, p[x], y)
|
||||
sd[k_to] = weights[shape_from*x:shape_from*(x + 1)]
|
||||
|
||||
return sd
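transformers_convert splits CLIP's fused attn.in_proj weight (shape 3*dim x dim) into separate q/k/v projections; the chunking step in isolation:

import torch

dim = 8
in_proj_weight = torch.arange(3 * dim * dim, dtype=torch.float32).reshape(3 * dim, dim)
shape_from = in_proj_weight.shape[0] // 3
q_proj, k_proj, v_proj = (in_proj_weight[shape_from * i:shape_from * (i + 1)] for i in range(3))
print(q_proj.shape, torch.equal(q_proj, in_proj_weight[:dim]))  # torch.Size([8, 8]) True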
|
||||
|
||||
def clip_text_transformers_convert(sd, prefix_from, prefix_to):
|
||||
sd = transformers_convert(sd, prefix_from, "{}text_model.".format(prefix_to), 32)
|
||||
|
||||
tp = "{}text_projection.weight".format(prefix_from)
|
||||
if tp in sd:
|
||||
sd["{}text_projection.weight".format(prefix_to)] = sd.pop(tp)
|
||||
|
||||
tp = "{}text_projection".format(prefix_from)
|
||||
if tp in sd:
|
||||
sd["{}text_projection.weight".format(prefix_to)] = sd.pop(tp).transpose(0, 1).contiguous()
|
||||
return sd
|
||||
|
||||
|
||||
UNET_MAP_ATTENTIONS = {
|
||||
"proj_in.weight",
|
||||
"proj_in.bias",
|
||||
"proj_out.weight",
|
||||
"proj_out.bias",
|
||||
"norm.weight",
|
||||
"norm.bias",
|
||||
}
|
||||
|
||||
TRANSFORMER_BLOCKS = {
|
||||
"norm1.weight",
|
||||
"norm1.bias",
|
||||
"norm2.weight",
|
||||
"norm2.bias",
|
||||
"norm3.weight",
|
||||
"norm3.bias",
|
||||
"attn1.to_q.weight",
|
||||
"attn1.to_k.weight",
|
||||
"attn1.to_v.weight",
|
||||
"attn1.to_out.0.weight",
|
||||
"attn1.to_out.0.bias",
|
||||
"attn2.to_q.weight",
|
||||
"attn2.to_k.weight",
|
||||
"attn2.to_v.weight",
|
||||
"attn2.to_out.0.weight",
|
||||
"attn2.to_out.0.bias",
|
||||
"ff.net.0.proj.weight",
|
||||
"ff.net.0.proj.bias",
|
||||
"ff.net.2.weight",
|
||||
"ff.net.2.bias",
|
||||
}
|
||||
|
||||
UNET_MAP_RESNET = {
|
||||
"in_layers.2.weight": "conv1.weight",
|
||||
"in_layers.2.bias": "conv1.bias",
|
||||
"emb_layers.1.weight": "time_emb_proj.weight",
|
||||
"emb_layers.1.bias": "time_emb_proj.bias",
|
||||
"out_layers.3.weight": "conv2.weight",
|
||||
"out_layers.3.bias": "conv2.bias",
|
||||
"skip_connection.weight": "conv_shortcut.weight",
|
||||
"skip_connection.bias": "conv_shortcut.bias",
|
||||
"in_layers.0.weight": "norm1.weight",
|
||||
"in_layers.0.bias": "norm1.bias",
|
||||
"out_layers.0.weight": "norm2.weight",
|
||||
"out_layers.0.bias": "norm2.bias",
|
||||
}
|
||||
|
||||
UNET_MAP_BASIC = {
|
||||
("label_emb.0.0.weight", "class_embedding.linear_1.weight"),
|
||||
("label_emb.0.0.bias", "class_embedding.linear_1.bias"),
|
||||
("label_emb.0.2.weight", "class_embedding.linear_2.weight"),
|
||||
("label_emb.0.2.bias", "class_embedding.linear_2.bias"),
|
||||
("label_emb.0.0.weight", "add_embedding.linear_1.weight"),
|
||||
("label_emb.0.0.bias", "add_embedding.linear_1.bias"),
|
||||
("label_emb.0.2.weight", "add_embedding.linear_2.weight"),
|
||||
("label_emb.0.2.bias", "add_embedding.linear_2.bias"),
|
||||
("input_blocks.0.0.weight", "conv_in.weight"),
|
||||
("input_blocks.0.0.bias", "conv_in.bias"),
|
||||
("out.0.weight", "conv_norm_out.weight"),
|
||||
("out.0.bias", "conv_norm_out.bias"),
|
||||
("out.2.weight", "conv_out.weight"),
|
||||
("out.2.bias", "conv_out.bias"),
|
||||
("time_embed.0.weight", "time_embedding.linear_1.weight"),
|
||||
("time_embed.0.bias", "time_embedding.linear_1.bias"),
|
||||
("time_embed.2.weight", "time_embedding.linear_2.weight"),
|
||||
("time_embed.2.bias", "time_embedding.linear_2.bias")
|
||||
}
|
||||
|
||||
def unet_to_diffusers(unet_config):
|
||||
if "num_res_blocks" not in unet_config:
|
||||
return {}
|
||||
num_res_blocks = unet_config["num_res_blocks"]
|
||||
channel_mult = unet_config["channel_mult"]
|
||||
transformer_depth = unet_config["transformer_depth"][:]
|
||||
transformer_depth_output = unet_config["transformer_depth_output"][:]
|
||||
num_blocks = len(channel_mult)
|
||||
|
||||
transformers_mid = unet_config.get("transformer_depth_middle", None)
|
||||
|
||||
diffusers_unet_map = {}
|
||||
for x in range(num_blocks):
|
||||
n = 1 + (num_res_blocks[x] + 1) * x
|
||||
for i in range(num_res_blocks[x]):
|
||||
for b in UNET_MAP_RESNET:
|
||||
diffusers_unet_map["down_blocks.{}.resnets.{}.{}".format(x, i, UNET_MAP_RESNET[b])] = "input_blocks.{}.0.{}".format(n, b)
|
||||
num_transformers = transformer_depth.pop(0)
|
||||
if num_transformers > 0:
|
||||
for b in UNET_MAP_ATTENTIONS:
|
||||
diffusers_unet_map["down_blocks.{}.attentions.{}.{}".format(x, i, b)] = "input_blocks.{}.1.{}".format(n, b)
|
||||
for t in range(num_transformers):
|
||||
for b in TRANSFORMER_BLOCKS:
|
||||
diffusers_unet_map["down_blocks.{}.attentions.{}.transformer_blocks.{}.{}".format(x, i, t, b)] = "input_blocks.{}.1.transformer_blocks.{}.{}".format(n, t, b)
|
||||
n += 1
|
||||
for k in ["weight", "bias"]:
|
||||
diffusers_unet_map["down_blocks.{}.downsamplers.0.conv.{}".format(x, k)] = "input_blocks.{}.0.op.{}".format(n, k)
|
||||
|
||||
i = 0
|
||||
for b in UNET_MAP_ATTENTIONS:
|
||||
diffusers_unet_map["mid_block.attentions.{}.{}".format(i, b)] = "middle_block.1.{}".format(b)
|
||||
for t in range(transformers_mid):
|
||||
for b in TRANSFORMER_BLOCKS:
|
||||
diffusers_unet_map["mid_block.attentions.{}.transformer_blocks.{}.{}".format(i, t, b)] = "middle_block.1.transformer_blocks.{}.{}".format(t, b)
|
||||
|
||||
for i, n in enumerate([0, 2]):
|
||||
for b in UNET_MAP_RESNET:
|
||||
diffusers_unet_map["mid_block.resnets.{}.{}".format(i, UNET_MAP_RESNET[b])] = "middle_block.{}.{}".format(n, b)
|
||||
|
||||
num_res_blocks = list(reversed(num_res_blocks))
|
||||
for x in range(num_blocks):
|
||||
n = (num_res_blocks[x] + 1) * x
|
||||
l = num_res_blocks[x] + 1
|
||||
for i in range(l):
|
||||
c = 0
|
||||
for b in UNET_MAP_RESNET:
|
||||
diffusers_unet_map["up_blocks.{}.resnets.{}.{}".format(x, i, UNET_MAP_RESNET[b])] = "output_blocks.{}.0.{}".format(n, b)
|
||||
c += 1
|
||||
num_transformers = transformer_depth_output.pop()
|
||||
if num_transformers > 0:
|
||||
c += 1
|
||||
for b in UNET_MAP_ATTENTIONS:
|
||||
diffusers_unet_map["up_blocks.{}.attentions.{}.{}".format(x, i, b)] = "output_blocks.{}.1.{}".format(n, b)
|
||||
for t in range(num_transformers):
|
||||
for b in TRANSFORMER_BLOCKS:
|
||||
diffusers_unet_map["up_blocks.{}.attentions.{}.transformer_blocks.{}.{}".format(x, i, t, b)] = "output_blocks.{}.1.transformer_blocks.{}.{}".format(n, t, b)
|
||||
if i == l - 1:
|
||||
for k in ["weight", "bias"]:
|
||||
diffusers_unet_map["up_blocks.{}.upsamplers.0.conv.{}".format(x, k)] = "output_blocks.{}.{}.conv.{}".format(n, c, k)
|
||||
n += 1
|
||||
|
||||
for k in UNET_MAP_BASIC:
|
||||
diffusers_unet_map[k[1]] = k[0]
|
||||
|
||||
return diffusers_unet_map
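Usage sketch for unet_to_diffusers with a deliberately tiny, made-up config (real SD configs are larger); it shows the diffusers -> ldm key pairs the function emits. The import path matches where this vendored file lives in the repo:

from packages_3rdparty.comfyui_lora_collection import utils

toy_unet_config = {
    "num_res_blocks": [2, 2],
    "channel_mult": [1, 2],
    "transformer_depth": [0, 0, 1, 1],
    "transformer_depth_output": [0, 0, 0, 1, 1, 1],
    "transformer_depth_middle": 1,
}
mapping = utils.unet_to_diffusers(toy_unet_config)
print(mapping["down_blocks.0.resnets.0.conv1.weight"])  # -> input_blocks.1.0.in_layers.2.weight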
|
||||
|
||||
def swap_scale_shift(weight):
|
||||
shift, scale = weight.chunk(2, dim=0)
|
||||
new_weight = torch.cat([scale, shift], dim=0)
|
||||
return new_weight
|
||||
|
||||
MMDIT_MAP_BASIC = {
|
||||
("context_embedder.bias", "context_embedder.bias"),
|
||||
("context_embedder.weight", "context_embedder.weight"),
|
||||
("t_embedder.mlp.0.bias", "time_text_embed.timestep_embedder.linear_1.bias"),
|
||||
("t_embedder.mlp.0.weight", "time_text_embed.timestep_embedder.linear_1.weight"),
|
||||
("t_embedder.mlp.2.bias", "time_text_embed.timestep_embedder.linear_2.bias"),
|
||||
("t_embedder.mlp.2.weight", "time_text_embed.timestep_embedder.linear_2.weight"),
|
||||
("x_embedder.proj.bias", "pos_embed.proj.bias"),
|
||||
("x_embedder.proj.weight", "pos_embed.proj.weight"),
|
||||
("y_embedder.mlp.0.bias", "time_text_embed.text_embedder.linear_1.bias"),
|
||||
("y_embedder.mlp.0.weight", "time_text_embed.text_embedder.linear_1.weight"),
|
||||
("y_embedder.mlp.2.bias", "time_text_embed.text_embedder.linear_2.bias"),
|
||||
("y_embedder.mlp.2.weight", "time_text_embed.text_embedder.linear_2.weight"),
|
||||
("pos_embed", "pos_embed.pos_embed"),
|
||||
("final_layer.adaLN_modulation.1.bias", "norm_out.linear.bias", swap_scale_shift),
|
||||
("final_layer.adaLN_modulation.1.weight", "norm_out.linear.weight", swap_scale_shift),
|
||||
("final_layer.linear.bias", "proj_out.bias"),
|
||||
("final_layer.linear.weight", "proj_out.weight"),
|
||||
}
|
||||
|
||||
MMDIT_MAP_BLOCK = {
|
||||
("context_block.adaLN_modulation.1.bias", "norm1_context.linear.bias"),
|
||||
("context_block.adaLN_modulation.1.weight", "norm1_context.linear.weight"),
|
||||
("context_block.attn.proj.bias", "attn.to_add_out.bias"),
|
||||
("context_block.attn.proj.weight", "attn.to_add_out.weight"),
|
||||
("context_block.mlp.fc1.bias", "ff_context.net.0.proj.bias"),
|
||||
("context_block.mlp.fc1.weight", "ff_context.net.0.proj.weight"),
|
||||
("context_block.mlp.fc2.bias", "ff_context.net.2.bias"),
|
||||
("context_block.mlp.fc2.weight", "ff_context.net.2.weight"),
|
||||
("x_block.adaLN_modulation.1.bias", "norm1.linear.bias"),
|
||||
("x_block.adaLN_modulation.1.weight", "norm1.linear.weight"),
|
||||
("x_block.attn.proj.bias", "attn.to_out.0.bias"),
|
||||
("x_block.attn.proj.weight", "attn.to_out.0.weight"),
|
||||
("x_block.mlp.fc1.bias", "ff.net.0.proj.bias"),
|
||||
("x_block.mlp.fc1.weight", "ff.net.0.proj.weight"),
|
||||
("x_block.mlp.fc2.bias", "ff.net.2.bias"),
|
||||
("x_block.mlp.fc2.weight", "ff.net.2.weight"),
|
||||
}
|
||||
|
||||
def mmdit_to_diffusers(mmdit_config, output_prefix=""):
|
||||
key_map = {}
|
||||
|
||||
depth = mmdit_config.get("depth", 0)
|
||||
num_blocks = mmdit_config.get("num_blocks", depth)
|
||||
for i in range(num_blocks):
|
||||
block_from = "transformer_blocks.{}".format(i)
|
||||
block_to = "{}joint_blocks.{}".format(output_prefix, i)
|
||||
|
||||
offset = depth * 64
|
||||
|
||||
for end in ("weight", "bias"):
|
||||
k = "{}.attn.".format(block_from)
|
||||
qkv = "{}.x_block.attn.qkv.{}".format(block_to, end)
|
||||
key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, offset))
|
||||
key_map["{}to_k.{}".format(k, end)] = (qkv, (0, offset, offset))
|
||||
key_map["{}to_v.{}".format(k, end)] = (qkv, (0, offset * 2, offset))
|
||||
|
||||
qkv = "{}.context_block.attn.qkv.{}".format(block_to, end)
|
||||
key_map["{}add_q_proj.{}".format(k, end)] = (qkv, (0, 0, offset))
|
||||
key_map["{}add_k_proj.{}".format(k, end)] = (qkv, (0, offset, offset))
|
||||
key_map["{}add_v_proj.{}".format(k, end)] = (qkv, (0, offset * 2, offset))
|
||||
|
||||
for k in MMDIT_MAP_BLOCK:
|
||||
key_map["{}.{}".format(block_from, k[1])] = "{}.{}".format(block_to, k[0])
|
||||
|
||||
map_basic = MMDIT_MAP_BASIC.copy()
|
||||
map_basic.add(("joint_blocks.{}.context_block.adaLN_modulation.1.bias".format(depth - 1), "transformer_blocks.{}.norm1_context.linear.bias".format(depth - 1), swap_scale_shift))
|
||||
map_basic.add(("joint_blocks.{}.context_block.adaLN_modulation.1.weight".format(depth - 1), "transformer_blocks.{}.norm1_context.linear.weight".format(depth - 1), swap_scale_shift))
|
||||
|
||||
for k in map_basic:
|
||||
if len(k) > 2:
|
||||
key_map[k[1]] = ("{}{}".format(output_prefix, k[0]), None, k[2])
|
||||
else:
|
||||
key_map[k[1]] = "{}{}".format(output_prefix, k[0])
|
||||
|
||||
return key_map
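Some entries above map a diffusers key to a (target_key, (dim, offset, size)) tuple rather than a plain string; my reading (hedged) is that the diffusers tensor corresponds to a slice of a fused qkv weight, starting at offset along the given dim. Toy tensors only:

import torch

hidden = 4
fused_qkv = torch.zeros(3 * hidden, hidden)        # e.g. x_block.attn.qkv.weight
to_k = torch.ones(hidden, hidden)                  # e.g. diffusers attn.to_k.weight
dim, offset, size = 0, hidden, hidden              # the (dim, offset, size) triple for the k slice
fused_qkv.narrow(dim, offset, size).copy_(to_k)    # write the slice into the fused weight
print(fused_qkv.sum().item())                      # 16.0 -> only the k block was filled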
|
||||
|
||||
|
||||
def auraflow_to_diffusers(mmdit_config, output_prefix=""):
|
||||
n_double_layers = mmdit_config.get("n_double_layers", 0)
|
||||
n_layers = mmdit_config.get("n_layers", 0)
|
||||
|
||||
key_map = {}
|
||||
for i in range(n_layers):
|
||||
if i < n_double_layers:
|
||||
index = i
|
||||
prefix_from = "joint_transformer_blocks"
|
||||
prefix_to = "{}double_layers".format(output_prefix)
|
||||
block_map = {
|
||||
"attn.to_q.weight": "attn.w2q.weight",
|
||||
"attn.to_k.weight": "attn.w2k.weight",
|
||||
"attn.to_v.weight": "attn.w2v.weight",
|
||||
"attn.to_out.0.weight": "attn.w2o.weight",
|
||||
"attn.add_q_proj.weight": "attn.w1q.weight",
|
||||
"attn.add_k_proj.weight": "attn.w1k.weight",
|
||||
"attn.add_v_proj.weight": "attn.w1v.weight",
|
||||
"attn.to_add_out.weight": "attn.w1o.weight",
|
||||
"ff.linear_1.weight": "mlpX.c_fc1.weight",
|
||||
"ff.linear_2.weight": "mlpX.c_fc2.weight",
|
||||
"ff.out_projection.weight": "mlpX.c_proj.weight",
|
||||
"ff_context.linear_1.weight": "mlpC.c_fc1.weight",
|
||||
"ff_context.linear_2.weight": "mlpC.c_fc2.weight",
|
||||
"ff_context.out_projection.weight": "mlpC.c_proj.weight",
|
||||
"norm1.linear.weight": "modX.1.weight",
|
||||
"norm1_context.linear.weight": "modC.1.weight",
|
||||
}
|
||||
else:
|
||||
index = i - n_double_layers
|
||||
prefix_from = "single_transformer_blocks"
|
||||
prefix_to = "{}single_layers".format(output_prefix)
|
||||
|
||||
block_map = {
|
||||
"attn.to_q.weight": "attn.w1q.weight",
|
||||
"attn.to_k.weight": "attn.w1k.weight",
|
||||
"attn.to_v.weight": "attn.w1v.weight",
|
||||
"attn.to_out.0.weight": "attn.w1o.weight",
|
||||
"norm1.linear.weight": "modCX.1.weight",
|
||||
"ff.linear_1.weight": "mlp.c_fc1.weight",
|
||||
"ff.linear_2.weight": "mlp.c_fc2.weight",
|
||||
"ff.out_projection.weight": "mlp.c_proj.weight"
|
||||
}
|
||||
|
||||
for k in block_map:
|
||||
key_map["{}.{}.{}".format(prefix_from, index, k)] = "{}.{}.{}".format(prefix_to, index, block_map[k])
|
||||
|
||||
MAP_BASIC = {
|
||||
("positional_encoding", "pos_embed.pos_embed"),
|
||||
("register_tokens", "register_tokens"),
|
||||
("t_embedder.mlp.0.weight", "time_step_proj.linear_1.weight"),
|
||||
("t_embedder.mlp.0.bias", "time_step_proj.linear_1.bias"),
|
||||
("t_embedder.mlp.2.weight", "time_step_proj.linear_2.weight"),
|
||||
("t_embedder.mlp.2.bias", "time_step_proj.linear_2.bias"),
|
||||
("cond_seq_linear.weight", "context_embedder.weight"),
|
||||
("init_x_linear.weight", "pos_embed.proj.weight"),
|
||||
("init_x_linear.bias", "pos_embed.proj.bias"),
|
||||
("final_linear.weight", "proj_out.weight"),
|
||||
("modF.1.weight", "norm_out.linear.weight", swap_scale_shift),
|
||||
}
|
||||
|
||||
for k in MAP_BASIC:
|
||||
if len(k) > 2:
|
||||
key_map[k[1]] = ("{}{}".format(output_prefix, k[0]), None, k[2])
|
||||
else:
|
||||
key_map[k[1]] = "{}{}".format(output_prefix, k[0])
|
||||
|
||||
return key_map
|
||||
|
||||
def flux_to_diffusers(mmdit_config, output_prefix=""):
|
||||
n_double_layers = mmdit_config.get("depth", 0)
|
||||
n_single_layers = mmdit_config.get("depth_single_blocks", 0)
|
||||
hidden_size = mmdit_config.get("hidden_size", 0)
|
||||
|
||||
key_map = {}
|
||||
for index in range(n_double_layers):
|
||||
prefix_from = "transformer_blocks.{}".format(index)
|
||||
prefix_to = "{}double_blocks.{}".format(output_prefix, index)
|
||||
|
||||
for end in ("weight", "bias"):
|
||||
k = "{}.attn.".format(prefix_from)
|
||||
qkv = "{}.img_attn.qkv.{}".format(prefix_to, end)
|
||||
key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, hidden_size))
|
||||
key_map["{}to_k.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size))
|
||||
key_map["{}to_v.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size))
|
||||
|
||||
k = "{}.attn.".format(prefix_from)
|
||||
qkv = "{}.txt_attn.qkv.{}".format(prefix_to, end)
|
||||
key_map["{}add_q_proj.{}".format(k, end)] = (qkv, (0, 0, hidden_size))
|
||||
key_map["{}add_k_proj.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size))
|
||||
key_map["{}add_v_proj.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size))
|
||||
|
||||
block_map = {
|
||||
"attn.to_out.0.weight": "img_attn.proj.weight",
|
||||
"attn.to_out.0.bias": "img_attn.proj.bias",
|
||||
"norm1.linear.weight": "img_mod.lin.weight",
|
||||
"norm1.linear.bias": "img_mod.lin.bias",
|
||||
"norm1_context.linear.weight": "txt_mod.lin.weight",
|
||||
"norm1_context.linear.bias": "txt_mod.lin.bias",
|
||||
"attn.to_add_out.weight": "txt_attn.proj.weight",
|
||||
"attn.to_add_out.bias": "txt_attn.proj.bias",
|
||||
"ff.net.0.proj.weight": "img_mlp.0.weight",
|
||||
"ff.net.0.proj.bias": "img_mlp.0.bias",
|
||||
"ff.net.2.weight": "img_mlp.2.weight",
|
||||
"ff.net.2.bias": "img_mlp.2.bias",
|
||||
"ff_context.net.0.proj.weight": "txt_mlp.0.weight",
|
||||
"ff_context.net.0.proj.bias": "txt_mlp.0.bias",
|
||||
"ff_context.net.2.weight": "txt_mlp.2.weight",
|
||||
"ff_context.net.2.bias": "txt_mlp.2.bias",
|
||||
"attn.norm_q.weight": "img_attn.norm.query_norm.scale",
|
||||
"attn.norm_k.weight": "img_attn.norm.key_norm.scale",
|
||||
"attn.norm_added_q.weight": "txt_attn.norm.query_norm.scale",
|
||||
"attn.norm_added_k.weight": "txt_attn.norm.key_norm.scale",
|
||||
}
|
||||
|
||||
for k in block_map:
|
||||
key_map["{}.{}".format(prefix_from, k)] = "{}.{}".format(prefix_to, block_map[k])
|
||||
|
||||
for index in range(n_single_layers):
|
||||
prefix_from = "single_transformer_blocks.{}".format(index)
|
||||
prefix_to = "{}single_blocks.{}".format(output_prefix, index)
|
||||
|
||||
for end in ("weight", "bias"):
|
||||
k = "{}.attn.".format(prefix_from)
|
||||
qkv = "{}.linear1.{}".format(prefix_to, end)
|
||||
key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, hidden_size))
|
||||
key_map["{}to_k.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size))
|
||||
key_map["{}to_v.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size))
|
||||
key_map["{}.proj_mlp.{}".format(prefix_from, end)] = (qkv, (0, hidden_size * 3, hidden_size * 4))
|
||||
|
||||
block_map = {
|
||||
"norm.linear.weight": "modulation.lin.weight",
|
||||
"norm.linear.bias": "modulation.lin.bias",
|
||||
"proj_out.weight": "linear2.weight",
|
||||
"proj_out.bias": "linear2.bias",
|
||||
"attn.norm_q.weight": "norm.query_norm.scale",
|
||||
"attn.norm_k.weight": "norm.key_norm.scale",
|
||||
}
|
||||
|
||||
for k in block_map:
|
||||
key_map["{}.{}".format(prefix_from, k)] = "{}.{}".format(prefix_to, block_map[k])
|
||||
|
||||
MAP_BASIC = {
|
||||
("final_layer.linear.bias", "proj_out.bias"),
|
||||
("final_layer.linear.weight", "proj_out.weight"),
|
||||
("img_in.bias", "x_embedder.bias"),
|
||||
("img_in.weight", "x_embedder.weight"),
|
||||
("time_in.in_layer.bias", "time_text_embed.timestep_embedder.linear_1.bias"),
|
||||
("time_in.in_layer.weight", "time_text_embed.timestep_embedder.linear_1.weight"),
|
||||
("time_in.out_layer.bias", "time_text_embed.timestep_embedder.linear_2.bias"),
|
||||
("time_in.out_layer.weight", "time_text_embed.timestep_embedder.linear_2.weight"),
|
||||
("txt_in.bias", "context_embedder.bias"),
|
||||
("txt_in.weight", "context_embedder.weight"),
|
||||
("vector_in.in_layer.bias", "time_text_embed.text_embedder.linear_1.bias"),
|
||||
("vector_in.in_layer.weight", "time_text_embed.text_embedder.linear_1.weight"),
|
||||
("vector_in.out_layer.bias", "time_text_embed.text_embedder.linear_2.bias"),
|
||||
("vector_in.out_layer.weight", "time_text_embed.text_embedder.linear_2.weight"),
|
||||
("guidance_in.in_layer.bias", "time_text_embed.guidance_embedder.linear_1.bias"),
|
||||
("guidance_in.in_layer.weight", "time_text_embed.guidance_embedder.linear_1.weight"),
|
||||
("guidance_in.out_layer.bias", "time_text_embed.guidance_embedder.linear_2.bias"),
|
||||
("guidance_in.out_layer.weight", "time_text_embed.guidance_embedder.linear_2.weight"),
|
||||
("final_layer.adaLN_modulation.1.bias", "norm_out.linear.bias", swap_scale_shift),
|
||||
("final_layer.adaLN_modulation.1.weight", "norm_out.linear.weight", swap_scale_shift),
|
||||
}
|
||||
|
||||
for k in MAP_BASIC:
|
||||
if len(k) > 2:
|
||||
key_map[k[1]] = ("{}{}".format(output_prefix, k[0]), None, k[2])
|
||||
else:
|
||||
key_map[k[1]] = "{}{}".format(output_prefix, k[0])
|
||||
|
||||
return key_map
|
||||
|
||||
def repeat_to_batch_size(tensor, batch_size, dim=0):
|
||||
if tensor.shape[dim] > batch_size:
|
||||
return tensor.narrow(dim, 0, batch_size)
|
||||
elif tensor.shape[dim] < batch_size:
|
||||
return tensor.repeat(dim * [1] + [math.ceil(batch_size / tensor.shape[dim])] + [1] * (len(tensor.shape) - 1 - dim)).narrow(dim, 0, batch_size)
|
||||
return tensor
|
||||
|
||||
def resize_to_batch_size(tensor, batch_size):
|
||||
in_batch_size = tensor.shape[0]
|
||||
if in_batch_size == batch_size:
|
||||
return tensor
|
||||
|
||||
if batch_size <= 1:
|
||||
return tensor[:batch_size]
|
||||
|
||||
output = torch.empty([batch_size] + list(tensor.shape)[1:], dtype=tensor.dtype, device=tensor.device)
|
||||
if batch_size < in_batch_size:
|
||||
scale = (in_batch_size - 1) / (batch_size - 1)
|
||||
for i in range(batch_size):
|
||||
output[i] = tensor[min(round(i * scale), in_batch_size - 1)]
|
||||
else:
|
||||
scale = in_batch_size / batch_size
|
||||
for i in range(batch_size):
|
||||
output[i] = tensor[min(math.floor((i + 0.5) * scale), in_batch_size - 1)]
|
||||
|
||||
return output
|
||||
|
||||
def convert_sd_to(state_dict, dtype):
|
||||
keys = list(state_dict.keys())
|
||||
for k in keys:
|
||||
state_dict[k] = state_dict[k].to(dtype)
|
||||
return state_dict
|
||||
|
||||
def safetensors_header(safetensors_path, max_size=100*1024*1024):
|
||||
with open(safetensors_path, "rb") as f:
|
||||
header = f.read(8)
|
||||
length_of_header = struct.unpack('<Q', header)[0]
|
||||
if length_of_header > max_size:
|
||||
return None
|
||||
return f.read(length_of_header)
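Usage sketch: the bytes returned by safetensors_header are the file's JSON header (tensor names, dtypes, shapes, offsets), so a lora can be inspected without loading any weights. The file path is hypothetical:

import json
from packages_3rdparty.comfyui_lora_collection import utils

header_bytes = utils.safetensors_header("example_lora.safetensors")  # hypothetical file
if header_bytes is not None:
    header = json.loads(header_bytes)
    print(list(header.keys())[:5])  # first few tensor names (plus optional "__metadata__")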
|
||||
|
||||
def set_attr(obj, attr, value):
|
||||
attrs = attr.split(".")
|
||||
for name in attrs[:-1]:
|
||||
obj = getattr(obj, name)
|
||||
prev = getattr(obj, attrs[-1])
|
||||
setattr(obj, attrs[-1], value)
|
||||
return prev
|
||||
|
||||
def set_attr_param(obj, attr, value):
|
||||
return set_attr(obj, attr, torch.nn.Parameter(value, requires_grad=False))
|
||||
|
||||
def copy_to_param(obj, attr, value):
|
||||
# inplace update tensor instead of replacing it
|
||||
attrs = attr.split(".")
|
||||
for name in attrs[:-1]:
|
||||
obj = getattr(obj, name)
|
||||
prev = getattr(obj, attrs[-1])
|
||||
prev.data.copy_(value)
|
||||
|
||||
def get_attr(obj, attr):
|
||||
attrs = attr.split(".")
|
||||
for name in attrs:
|
||||
obj = getattr(obj, name)
|
||||
return obj
|
||||
|
||||
def bislerp(samples, width, height):
|
||||
def slerp(b1, b2, r):
|
||||
'''slerps batches b1, b2 according to ratio r, batches should be flat e.g. NxC'''
|
||||
|
||||
c = b1.shape[-1]
|
||||
|
||||
#norms
|
||||
b1_norms = torch.norm(b1, dim=-1, keepdim=True)
|
||||
b2_norms = torch.norm(b2, dim=-1, keepdim=True)
|
||||
|
||||
#normalize
|
||||
b1_normalized = b1 / b1_norms
|
||||
b2_normalized = b2 / b2_norms
|
||||
|
||||
#zero when norms are zero
|
||||
b1_normalized[b1_norms.expand(-1,c) == 0.0] = 0.0
|
||||
b2_normalized[b2_norms.expand(-1,c) == 0.0] = 0.0
|
||||
|
||||
#slerp
|
||||
dot = (b1_normalized*b2_normalized).sum(1)
|
||||
omega = torch.acos(dot)
|
||||
so = torch.sin(omega)
|
||||
|
||||
#technically not mathematically correct, but more pleasing?
|
||||
res = (torch.sin((1.0-r.squeeze(1))*omega)/so).unsqueeze(1)*b1_normalized + (torch.sin(r.squeeze(1)*omega)/so).unsqueeze(1) * b2_normalized
|
||||
res *= (b1_norms * (1.0-r) + b2_norms * r).expand(-1,c)
|
||||
|
||||
#edge cases for same or polar opposites
|
||||
res[dot > 1 - 1e-5] = b1[dot > 1 - 1e-5]
|
||||
res[dot < 1e-5 - 1] = (b1 * (1.0-r) + b2 * r)[dot < 1e-5 - 1]
|
||||
return res
|
||||
|
||||
def generate_bilinear_data(length_old, length_new, device):
|
||||
coords_1 = torch.arange(length_old, dtype=torch.float32, device=device).reshape((1,1,1,-1))
|
||||
coords_1 = torch.nn.functional.interpolate(coords_1, size=(1, length_new), mode="bilinear")
|
||||
ratios = coords_1 - coords_1.floor()
|
||||
coords_1 = coords_1.to(torch.int64)
|
||||
|
||||
coords_2 = torch.arange(length_old, dtype=torch.float32, device=device).reshape((1,1,1,-1)) + 1
|
||||
coords_2[:,:,:,-1] -= 1
|
||||
coords_2 = torch.nn.functional.interpolate(coords_2, size=(1, length_new), mode="bilinear")
|
||||
coords_2 = coords_2.to(torch.int64)
|
||||
return ratios, coords_1, coords_2
|
||||
|
||||
orig_dtype = samples.dtype
|
||||
samples = samples.float()
|
||||
n,c,h,w = samples.shape
|
||||
h_new, w_new = (height, width)
|
||||
|
||||
#linear w
|
||||
ratios, coords_1, coords_2 = generate_bilinear_data(w, w_new, samples.device)
|
||||
coords_1 = coords_1.expand((n, c, h, -1))
|
||||
coords_2 = coords_2.expand((n, c, h, -1))
|
||||
ratios = ratios.expand((n, 1, h, -1))
|
||||
|
||||
pass_1 = samples.gather(-1,coords_1).movedim(1, -1).reshape((-1,c))
|
||||
pass_2 = samples.gather(-1,coords_2).movedim(1, -1).reshape((-1,c))
|
||||
ratios = ratios.movedim(1, -1).reshape((-1,1))
|
||||
|
||||
result = slerp(pass_1, pass_2, ratios)
|
||||
result = result.reshape(n, h, w_new, c).movedim(-1, 1)
|
||||
|
||||
#linear h
|
||||
ratios, coords_1, coords_2 = generate_bilinear_data(h, h_new, samples.device)
|
||||
coords_1 = coords_1.reshape((1,1,-1,1)).expand((n, c, -1, w_new))
|
||||
coords_2 = coords_2.reshape((1,1,-1,1)).expand((n, c, -1, w_new))
|
||||
ratios = ratios.reshape((1,1,-1,1)).expand((n, 1, -1, w_new))
|
||||
|
||||
pass_1 = result.gather(-2,coords_1).movedim(1, -1).reshape((-1,c))
|
||||
pass_2 = result.gather(-2,coords_2).movedim(1, -1).reshape((-1,c))
|
||||
ratios = ratios.movedim(1, -1).reshape((-1,1))
|
||||
|
||||
result = slerp(pass_1, pass_2, ratios)
|
||||
result = result.reshape(n, h_new, w_new, c).movedim(-1, 1)
|
||||
return result.to(orig_dtype)


def lanczos(samples, width, height):
    images = [Image.fromarray(np.clip(255. * image.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8)) for image in samples]
    images = [image.resize((width, height), resample=Image.Resampling.LANCZOS) for image in images]
    images = [torch.from_numpy(np.array(image).astype(np.float32) / 255.0).movedim(-1, 0) for image in images]
    result = torch.stack(images)
    return result.to(samples.device, samples.dtype)


def common_upscale(samples, width, height, upscale_method, crop):
    if crop == "center":
        old_width = samples.shape[3]
        old_height = samples.shape[2]
        old_aspect = old_width / old_height
        new_aspect = width / height
        x = 0
        y = 0
        if old_aspect > new_aspect:
            x = round((old_width - old_width * (new_aspect / old_aspect)) / 2)
        elif old_aspect < new_aspect:
            y = round((old_height - old_height * (old_aspect / new_aspect)) / 2)
        s = samples[:,:,y:old_height-y,x:old_width-x]
    else:
        s = samples

    if upscale_method == "bislerp":
        return bislerp(s, width, height)
    elif upscale_method == "lanczos":
        return lanczos(s, width, height)
    else:
        return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)
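
# Illustrative usage sketch (not from the upstream file), showing the
# center-crop path: a batch whose aspect ratio differs from the target is
# cropped first, then resized with the chosen method.
#
#   x = torch.randn(1, 4, 64, 96)                       # latent, H=64, W=96
#   y = common_upscale(x, 96, 96, "bislerp", "center")  # crop width, then upscale
#   # y.shape == (1, 4, 96, 96)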


def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
    return math.ceil((height / (tile_y - overlap))) * math.ceil((width / (tile_x - overlap)))
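
# Worked example (illustrative only): for a 128x96 (WxH) input with 64x64
# tiles and an overlap of 8, the step count is
# ceil(96 / (64 - 8)) * ceil(128 / (64 - 8)) = 2 * 3 = 6 tiles,
# which is the total a ProgressBar would typically be initialised with.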


@torch.inference_mode()
def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_amount=4, out_channels=3, output_device="cpu", pbar=None):
    dims = len(tile)
    output = torch.empty([samples.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), samples.shape[2:])), device=output_device)

    for b in range(samples.shape[0]):
        s = samples[b:b+1]
        out = torch.zeros([s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])), device=output_device)
        out_div = torch.zeros([s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])), device=output_device)

        for it in itertools.product(*map(lambda a: range(0, a[0], a[1] - overlap), zip(s.shape[2:], tile))):
            s_in = s
            upscaled = []

            for d in range(dims):
                pos = max(0, min(s.shape[d + 2] - overlap, it[d]))
                l = min(tile[d], s.shape[d + 2] - pos)
                s_in = s_in.narrow(d + 2, pos, l)
                upscaled.append(round(pos * upscale_amount))
            ps = function(s_in).to(output_device)
            mask = torch.ones_like(ps)
            feather = round(overlap * upscale_amount)
            for t in range(feather):
                for d in range(2, dims + 2):
                    m = mask.narrow(d, t, 1)
                    m *= ((1.0/feather) * (t + 1))
                    m = mask.narrow(d, mask.shape[d] - 1 - t, 1)
                    m *= ((1.0/feather) * (t + 1))

            o = out
            o_d = out_div
            for d in range(dims):
                o = o.narrow(d + 2, upscaled[d], mask.shape[d + 2])
                o_d = o_d.narrow(d + 2, upscaled[d], mask.shape[d + 2])

            o += ps * mask
            o_d += mask

            if pbar is not None:
                pbar.update(1)

        output[b:b+1] = out/out_div
    return output


def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap=8, upscale_amount=4, out_channels=3, output_device="cpu", pbar=None):
    return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap, upscale_amount, out_channels, output_device, pbar)
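
# Illustrative usage sketch (not from the upstream file): a minimal example,
# assuming `function` maps a (1, C, h, w) tile to its upscaled version. Here a
# dummy 4x nearest-neighbour upscaler stands in for a VAE decoder or an
# upscaling model:
#
#   def fake_upscaler(tile):
#       return torch.nn.functional.interpolate(tile, scale_factor=4, mode="nearest")
#
#   img = torch.randn(1, 3, 96, 96)
#   big = tiled_scale(img, fake_upscaler, tile_x=64, tile_y=64, overlap=8, upscale_amount=4)
#   # big.shape == (1, 3, 384, 384); seams are feathered over overlap * upscale_amount pixels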


PROGRESS_BAR_ENABLED = True


def set_progress_bar_enabled(enabled):
    global PROGRESS_BAR_ENABLED
    PROGRESS_BAR_ENABLED = enabled


PROGRESS_BAR_HOOK = None


def set_progress_bar_global_hook(function):
    global PROGRESS_BAR_HOOK
    PROGRESS_BAR_HOOK = function


class ProgressBar:
    def __init__(self, total):
        global PROGRESS_BAR_HOOK
        self.total = total
        self.current = 0
        self.hook = PROGRESS_BAR_HOOK

    def update_absolute(self, value, total=None, preview=None):
        if total is not None:
            self.total = total
        if value > self.total:
            value = self.total
        self.current = value
        if self.hook is not None:
            self.hook(self.current, self.total, preview)

    def update(self, value):
        self.update_absolute(self.current + value)
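
# Illustrative usage sketch (not from the upstream file): how the module-level
# hook and ProgressBar cooperate. A caller registers a hook once, and every
# ProgressBar created afterwards reports through it.
#
#   def my_hook(current, total, preview):
#       print(f"{current}/{total}")
#
#   set_progress_bar_global_hook(my_hook)
#   pbar = ProgressBar(6)
#   pbar.update(1)           # prints "1/6"
#   pbar.update_absolute(6)  # prints "6/6"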

packages_3rdparty/webui_lora_collection/lora.py (vendored, new file, 2 lines)
@@ -0,0 +1,2 @@
# TODO: Implement API

packages_3rdparty/webui_lora_collection/lyco_helpers.py (vendored, new file, 68 lines)
@@ -0,0 +1,68 @@
import torch


def make_weight_cp(t, wa, wb):
    temp = torch.einsum('i j k l, j r -> i r k l', t, wb)
    return torch.einsum('i j k l, i r -> r j k l', temp, wa)


def rebuild_conventional(up, down, shape, dyn_dim=None):
    up = up.reshape(up.size(0), -1)
    down = down.reshape(down.size(0), -1)
    if dyn_dim is not None:
        up = up[:, :dyn_dim]
        down = down[:dyn_dim, :]
    return (up @ down).reshape(shape)
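
# Illustrative shape sketch (not from the upstream file): for a Linear layer
# with weight (out_features, in_features) and rank-r factors, `up` is
# (out_features, r) and `down` is (r, in_features), so up @ down already has
# the full weight shape; dyn_dim would truncate the rank before the matmul.
#
#   up = torch.randn(320, 8)
#   down = torch.randn(8, 768)
#   delta_w = rebuild_conventional(up, down, (320, 768))
#   # delta_w.shape == (320, 768)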


def rebuild_cp_decomposition(up, down, mid):
    up = up.reshape(up.size(0), -1)
    down = down.reshape(down.size(0), -1)
    return torch.einsum('n m k l, i n, m j -> i j k l', mid, up, down)


# copied from https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/lokr.py
def factorization(dimension: int, factor: int = -1) -> tuple[int, int]:
    '''
    Return a tuple of two values that decompose the input dimension, using the number closest to factor.
    The second value is greater than or equal to the first value.

    In LoRA with Kronecker Product, the first value is a value for the weight scale and
    the second value is a value for the weight.

    Because of the non-commutative property, A⊗B ≠ B⊗A, so the meaning of the two matrices is slightly different.

    examples)
    factor
         -1               2               4               8               16 ...
    127  -> 1, 127   127 -> 1, 127   127 -> 1, 127   127 -> 1, 127   127 -> 1, 127
    128  -> 8, 16    128 -> 2, 64    128 -> 4, 32    128 -> 8, 16    128 -> 8, 16
    250  -> 10, 25   250 -> 2, 125   250 -> 2, 125   250 -> 5, 50    250 -> 10, 25
    360  -> 8, 45    360 -> 2, 180   360 -> 4, 90    360 -> 8, 45    360 -> 12, 30
    512  -> 16, 32   512 -> 2, 256   512 -> 4, 128   512 -> 8, 64    512 -> 16, 32
    1024 -> 32, 32   1024 -> 2, 512  1024 -> 4, 256  1024 -> 8, 128  1024 -> 16, 64
    '''

    if factor > 0 and (dimension % factor) == 0:
        m = factor
        n = dimension // factor
        if m > n:
            n, m = m, n
        return m, n
    if factor < 0:
        factor = dimension
    m, n = 1, dimension
    length = m + n
    while m < n:
        new_m = m + 1
        while dimension % new_m != 0:
            new_m += 1
        new_n = dimension // new_m
        if new_m + new_n > length or new_m > factor:
            break
        else:
            m, n = new_m, new_n
    if m > n:
        n, m = m, n
    return m, n
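
# Worked example (illustrative only): the fast path applies when `factor`
# divides `dimension` exactly.
#
#   factorization(128, 8)   # 128 % 8 == 0, so m, n = 8, 16 -> returns (8, 16)
#
# which matches the "128 -> 8, 16" entry in the docstring table above.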

packages_3rdparty/webui_lora_collection/network.py (vendored, new file, 228 lines)
@@ -0,0 +1,228 @@
from __future__ import annotations
import os
from collections import namedtuple
import enum

import torch.nn as nn
import torch.nn.functional as F

from modules import sd_models, cache, errors, hashes, shared
import modules.models.sd3.mmdit

NetworkWeights = namedtuple('NetworkWeights', ['network_key', 'sd_key', 'w', 'sd_module'])

metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}


class SdVersion(enum.Enum):
    Unknown = 1
    SD1 = 2
    SD2 = 3
    SDXL = 4


class NetworkOnDisk:
    def __init__(self, name, filename):
        self.name = name
        self.filename = filename
        self.metadata = {}
        self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors"

        def read_metadata():
            metadata = sd_models.read_metadata_from_safetensors(filename)

            return metadata

        if self.is_safetensors:
            try:
                self.metadata = cache.cached_data_for_file('safetensors-metadata', "lora/" + self.name, filename, read_metadata)
            except Exception as e:
                errors.display(e, f"reading lora {filename}")

        if self.metadata:
            m = {}
            for k, v in sorted(self.metadata.items(), key=lambda x: metadata_tags_order.get(x[0], 999)):
                m[k] = v

            self.metadata = m

        self.alias = self.metadata.get('ss_output_name', self.name)

        self.hash = None
        self.shorthash = None
        self.set_hash(
            self.metadata.get('sshs_model_hash') or
            hashes.sha256_from_cache(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or
            ''
        )

        self.sd_version = self.detect_version()

    def detect_version(self):
        if str(self.metadata.get('ss_base_model_version', "")).startswith("sdxl_"):
            return SdVersion.SDXL
        elif str(self.metadata.get('ss_v2', "")) == "True":
            return SdVersion.SD2
        elif len(self.metadata):
            return SdVersion.SD1

        return SdVersion.Unknown

    def set_hash(self, v):
        self.hash = v
        self.shorthash = self.hash[0:12]

        if self.shorthash:
            import networks
            networks.available_network_hash_lookup[self.shorthash] = self

    def read_hash(self):
        if not self.hash:
            self.set_hash(hashes.sha256(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or '')

    def get_alias(self):
        import networks
        if shared.opts.lora_preferred_name == "Filename" or self.alias.lower() in networks.forbidden_network_aliases:
            return self.name
        else:
            return self.alias


class Network:  # LoraModule
    def __init__(self, name, network_on_disk: NetworkOnDisk):
        self.name = name
        self.network_on_disk = network_on_disk
        self.te_multiplier = 1.0
        self.unet_multiplier = 1.0
        self.dyn_dim = None
        self.modules = {}
        self.bundle_embeddings = {}
        self.mtime = None

        self.mentioned_name = None
        """the text that was used to add the network to prompt - can be either name or an alias"""


class ModuleType:
    def create_module(self, net: Network, weights: NetworkWeights) -> Network | None:
        return None


class NetworkModule:
    def __init__(self, net: Network, weights: NetworkWeights):
        self.network = net
        self.network_key = weights.network_key
        self.sd_key = weights.sd_key
        self.sd_module = weights.sd_module

        if isinstance(self.sd_module, modules.models.sd3.mmdit.QkvLinear):
            s = self.sd_module.weight.shape
            self.shape = (s[0] // 3, s[1])
        elif hasattr(self.sd_module, 'weight'):
            self.shape = self.sd_module.weight.shape
        elif isinstance(self.sd_module, nn.MultiheadAttention):
            # For now, only self-attention uses PyTorch's MHA,
            # so assume all q/k/v/o projections have the same shape
            self.shape = self.sd_module.out_proj.weight.shape
        else:
            self.shape = None

        self.ops = None
        self.extra_kwargs = {}
        if isinstance(self.sd_module, nn.Conv2d):
            self.ops = F.conv2d
            self.extra_kwargs = {
                'stride': self.sd_module.stride,
                'padding': self.sd_module.padding
            }
        elif isinstance(self.sd_module, nn.Linear):
            self.ops = F.linear
        elif isinstance(self.sd_module, nn.LayerNorm):
            self.ops = F.layer_norm
            self.extra_kwargs = {
                'normalized_shape': self.sd_module.normalized_shape,
                'eps': self.sd_module.eps
            }
        elif isinstance(self.sd_module, nn.GroupNorm):
            self.ops = F.group_norm
            self.extra_kwargs = {
                'num_groups': self.sd_module.num_groups,
                'eps': self.sd_module.eps
            }

        self.dim = None
        self.bias = weights.w.get("bias")
        self.alpha = weights.w["alpha"].item() if "alpha" in weights.w else None
        self.scale = weights.w["scale"].item() if "scale" in weights.w else None

        self.dora_scale = weights.w.get("dora_scale", None)
        self.dora_norm_dims = len(self.shape) - 1

    def multiplier(self):
        if 'transformer' in self.sd_key[:20]:
            return self.network.te_multiplier
        else:
            return self.network.unet_multiplier

    def calc_scale(self):
        if self.scale is not None:
            return self.scale
        if self.dim is not None and self.alpha is not None:
            return self.alpha / self.dim

        return 1.0

    def apply_weight_decompose(self, updown, orig_weight):
        # Match the device/dtype
        orig_weight = orig_weight.to(updown.dtype)
        dora_scale = self.dora_scale.to(device=orig_weight.device, dtype=updown.dtype)
        updown = updown.to(orig_weight.device)

        merged_scale1 = updown + orig_weight
        merged_scale1_norm = (
            merged_scale1.transpose(0, 1)
            .reshape(merged_scale1.shape[1], -1)
            .norm(dim=1, keepdim=True)
            .reshape(merged_scale1.shape[1], *[1] * self.dora_norm_dims)
            .transpose(0, 1)
        )

        dora_merged = (
            merged_scale1 * (dora_scale / merged_scale1_norm)
        )
        final_updown = dora_merged - orig_weight
        return final_updown
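
    # Illustrative note (not from the upstream file): a hedged sketch of the
    # DoRA decomposition above for a 2-D Linear weight. The merged weight
    # (W + delta_W) is normalised per column (as computed by the
    # transpose/reshape chain), rescaled by the checkpoint's dora_scale, and W
    # is subtracted again so only the decomposed update is returned.
    #
    #   W = torch.randn(320, 768)
    #   delta_W = torch.randn(320, 768) * 0.01
    #   merged = W + delta_W
    #   col_norm = merged.transpose(0, 1).norm(dim=1, keepdim=True).transpose(0, 1)  # (1, 768)
    #   dora_scale = torch.ones(1, 768)   # stands in for the learned magnitude tensor
    #   final_updown = merged * (dora_scale / col_norm) - W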

    def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None):
        if self.bias is not None:
            updown = updown.reshape(self.bias.shape)
            updown += self.bias.to(orig_weight.device, dtype=updown.dtype)
            updown = updown.reshape(output_shape)

        if len(output_shape) == 4:
            updown = updown.reshape(output_shape)

        if orig_weight.size().numel() == updown.size().numel():
            updown = updown.reshape(orig_weight.shape)

        if ex_bias is not None:
            ex_bias = ex_bias * self.multiplier()

        updown = updown * self.calc_scale()

        if self.dora_scale is not None:
            updown = self.apply_weight_decompose(updown, orig_weight)

        return updown * self.multiplier(), ex_bias

    def calc_updown(self, target):
        raise NotImplementedError()

    def forward(self, x, y):
        """A general forward implementation for all modules"""
        if self.ops is None:
            raise NotImplementedError()
        else:
            updown, ex_bias = self.calc_updown(self.sd_module.weight)
            return y + self.ops(x, weight=updown, bias=ex_bias, **self.extra_kwargs)

packages_3rdparty/webui_lora_collection/network_full.py (vendored, new file, 27 lines)
@@ -0,0 +1,27 @@
import network


class ModuleTypeFull(network.ModuleType):
    def create_module(self, net: network.Network, weights: network.NetworkWeights):
        if all(x in weights.w for x in ["diff"]):
            return NetworkModuleFull(net, weights)

        return None


class NetworkModuleFull(network.NetworkModule):
    def __init__(self, net: network.Network, weights: network.NetworkWeights):
        super().__init__(net, weights)

        self.weight = weights.w.get("diff")
        self.ex_bias = weights.w.get("diff_b")

    def calc_updown(self, orig_weight):
        output_shape = self.weight.shape
        updown = self.weight.to(orig_weight.device)
        if self.ex_bias is not None:
            ex_bias = self.ex_bias.to(orig_weight.device)
        else:
            ex_bias = None

        return self.finalize_updown(updown, orig_weight, output_shape, ex_bias)

packages_3rdparty/webui_lora_collection/network_glora.py (vendored, new file, 33 lines)
@@ -0,0 +1,33 @@

import network


class ModuleTypeGLora(network.ModuleType):
    def create_module(self, net: network.Network, weights: network.NetworkWeights):
        if all(x in weights.w for x in ["a1.weight", "a2.weight", "alpha", "b1.weight", "b2.weight"]):
            return NetworkModuleGLora(net, weights)

        return None


# adapted from https://github.com/KohakuBlueleaf/LyCORIS
class NetworkModuleGLora(network.NetworkModule):
    def __init__(self, net: network.Network, weights: network.NetworkWeights):
        super().__init__(net, weights)

        if hasattr(self.sd_module, 'weight'):
            self.shape = self.sd_module.weight.shape

        self.w1a = weights.w["a1.weight"]
        self.w1b = weights.w["b1.weight"]
        self.w2a = weights.w["a2.weight"]
        self.w2b = weights.w["b2.weight"]

    def calc_updown(self, orig_weight):
        w1a = self.w1a.to(orig_weight.device)
        w1b = self.w1b.to(orig_weight.device)
        w2a = self.w2a.to(orig_weight.device)
        w2b = self.w2b.to(orig_weight.device)

        output_shape = [w1a.size(0), w1b.size(1)]
        updown = ((w2b @ w1b) + ((orig_weight.to(dtype=w1a.dtype) @ w2a) @ w1a))

        return self.finalize_updown(updown, orig_weight, output_shape)

packages_3rdparty/webui_lora_collection/network_hada.py (vendored, new file, 55 lines)
@@ -0,0 +1,55 @@
import lyco_helpers
import network


class ModuleTypeHada(network.ModuleType):
    def create_module(self, net: network.Network, weights: network.NetworkWeights):
        if all(x in weights.w for x in ["hada_w1_a", "hada_w1_b", "hada_w2_a", "hada_w2_b"]):
            return NetworkModuleHada(net, weights)

        return None


class NetworkModuleHada(network.NetworkModule):
    def __init__(self, net: network.Network, weights: network.NetworkWeights):
        super().__init__(net, weights)

        if hasattr(self.sd_module, 'weight'):
            self.shape = self.sd_module.weight.shape

        self.w1a = weights.w["hada_w1_a"]
        self.w1b = weights.w["hada_w1_b"]
        self.dim = self.w1b.shape[0]
        self.w2a = weights.w["hada_w2_a"]
        self.w2b = weights.w["hada_w2_b"]

        self.t1 = weights.w.get("hada_t1")
        self.t2 = weights.w.get("hada_t2")

    def calc_updown(self, orig_weight):
        w1a = self.w1a.to(orig_weight.device)
        w1b = self.w1b.to(orig_weight.device)
        w2a = self.w2a.to(orig_weight.device)
        w2b = self.w2b.to(orig_weight.device)

        output_shape = [w1a.size(0), w1b.size(1)]

        if self.t1 is not None:
            output_shape = [w1a.size(1), w1b.size(1)]
            t1 = self.t1.to(orig_weight.device)
            updown1 = lyco_helpers.make_weight_cp(t1, w1a, w1b)
            output_shape += t1.shape[2:]
        else:
            if len(w1b.shape) == 4:
                output_shape += w1b.shape[2:]
            updown1 = lyco_helpers.rebuild_conventional(w1a, w1b, output_shape)

        if self.t2 is not None:
            t2 = self.t2.to(orig_weight.device)
            updown2 = lyco_helpers.make_weight_cp(t2, w2a, w2b)
        else:
            updown2 = lyco_helpers.rebuild_conventional(w2a, w2b, output_shape)

        updown = updown1 * updown2

        return self.finalize_updown(updown, orig_weight, output_shape)

packages_3rdparty/webui_lora_collection/network_ia3.py (vendored, new file, 30 lines)
@@ -0,0 +1,30 @@
import network


class ModuleTypeIa3(network.ModuleType):
    def create_module(self, net: network.Network, weights: network.NetworkWeights):
        if all(x in weights.w for x in ["weight"]):
            return NetworkModuleIa3(net, weights)

        return None


class NetworkModuleIa3(network.NetworkModule):
    def __init__(self, net: network.Network, weights: network.NetworkWeights):
        super().__init__(net, weights)

        self.w = weights.w["weight"]
        self.on_input = weights.w["on_input"].item()

    def calc_updown(self, orig_weight):
        w = self.w.to(orig_weight.device)

        output_shape = [w.size(0), orig_weight.size(1)]
        if self.on_input:
            output_shape.reverse()
        else:
            w = w.reshape(-1, 1)

        updown = orig_weight * w

        return self.finalize_updown(updown, orig_weight, output_shape)

packages_3rdparty/webui_lora_collection/network_lokr.py (vendored, new file, 64 lines)
@@ -0,0 +1,64 @@
import torch

import lyco_helpers
import network


class ModuleTypeLokr(network.ModuleType):
    def create_module(self, net: network.Network, weights: network.NetworkWeights):
        has_1 = "lokr_w1" in weights.w or ("lokr_w1_a" in weights.w and "lokr_w1_b" in weights.w)
        has_2 = "lokr_w2" in weights.w or ("lokr_w2_a" in weights.w and "lokr_w2_b" in weights.w)
        if has_1 and has_2:
            return NetworkModuleLokr(net, weights)

        return None


def make_kron(orig_shape, w1, w2):
    if len(w2.shape) == 4:
        w1 = w1.unsqueeze(2).unsqueeze(2)
    w2 = w2.contiguous()
    return torch.kron(w1, w2).reshape(orig_shape)
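
# Illustrative shape sketch (not from the upstream file): the Kronecker
# product of an (a, b) factor with a (c, d) factor is an (a*c, b*d) matrix,
# which is why LoKr can cover a full weight matrix with two small factors.
#
#   w1 = torch.randn(8, 16)
#   w2 = torch.randn(40, 48)
#   full = make_kron((320, 768), w1, w2)   # kron -> (8*40, 16*48) == (320, 768)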


class NetworkModuleLokr(network.NetworkModule):
    def __init__(self, net: network.Network, weights: network.NetworkWeights):
        super().__init__(net, weights)

        self.w1 = weights.w.get("lokr_w1")
        self.w1a = weights.w.get("lokr_w1_a")
        self.w1b = weights.w.get("lokr_w1_b")
        self.dim = self.w1b.shape[0] if self.w1b is not None else self.dim
        self.w2 = weights.w.get("lokr_w2")
        self.w2a = weights.w.get("lokr_w2_a")
        self.w2b = weights.w.get("lokr_w2_b")
        self.dim = self.w2b.shape[0] if self.w2b is not None else self.dim
        self.t2 = weights.w.get("lokr_t2")

    def calc_updown(self, orig_weight):
        if self.w1 is not None:
            w1 = self.w1.to(orig_weight.device)
        else:
            w1a = self.w1a.to(orig_weight.device)
            w1b = self.w1b.to(orig_weight.device)
            w1 = w1a @ w1b

        if self.w2 is not None:
            w2 = self.w2.to(orig_weight.device)
        elif self.t2 is None:
            w2a = self.w2a.to(orig_weight.device)
            w2b = self.w2b.to(orig_weight.device)
            w2 = w2a @ w2b
        else:
            t2 = self.t2.to(orig_weight.device)
            w2a = self.w2a.to(orig_weight.device)
            w2b = self.w2b.to(orig_weight.device)
            w2 = lyco_helpers.make_weight_cp(t2, w2a, w2b)

        output_shape = [w1.size(0) * w2.size(0), w1.size(1) * w2.size(1)]
        if len(orig_weight.shape) == 4:
            output_shape = orig_weight.shape

        updown = make_kron(output_shape, w1, w2)

        return self.finalize_updown(updown, orig_weight, output_shape)

packages_3rdparty/webui_lora_collection/network_lora.py (vendored, new file, 94 lines)
@@ -0,0 +1,94 @@
import torch

import lyco_helpers
import modules.models.sd3.mmdit
import network
from modules import devices


class ModuleTypeLora(network.ModuleType):
    def create_module(self, net: network.Network, weights: network.NetworkWeights):
        if all(x in weights.w for x in ["lora_up.weight", "lora_down.weight"]):
            return NetworkModuleLora(net, weights)

        if all(x in weights.w for x in ["lora_A.weight", "lora_B.weight"]):
            w = weights.w.copy()
            weights.w.clear()
            weights.w.update({"lora_up.weight": w["lora_B.weight"], "lora_down.weight": w["lora_A.weight"]})

            return NetworkModuleLora(net, weights)

        return None


class NetworkModuleLora(network.NetworkModule):
    def __init__(self, net: network.Network, weights: network.NetworkWeights):
        super().__init__(net, weights)

        self.up_model = self.create_module(weights.w, "lora_up.weight")
        self.down_model = self.create_module(weights.w, "lora_down.weight")
        self.mid_model = self.create_module(weights.w, "lora_mid.weight", none_ok=True)

        self.dim = weights.w["lora_down.weight"].shape[0]

    def create_module(self, weights, key, none_ok=False):
        weight = weights.get(key)

        if weight is None and none_ok:
            return None

        is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, torch.nn.MultiheadAttention, modules.models.sd3.mmdit.QkvLinear]
        is_conv = type(self.sd_module) in [torch.nn.Conv2d]

        if is_linear:
            weight = weight.reshape(weight.shape[0], -1)
            module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
        elif is_conv and key == "lora_down.weight" or key == "dyn_up":
            if len(weight.shape) == 2:
                weight = weight.reshape(weight.shape[0], -1, 1, 1)

            if weight.shape[2] != 1 or weight.shape[3] != 1:
                module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], self.sd_module.kernel_size, self.sd_module.stride, self.sd_module.padding, bias=False)
            else:
                module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
        elif is_conv and key == "lora_mid.weight":
            module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], self.sd_module.kernel_size, self.sd_module.stride, self.sd_module.padding, bias=False)
        elif is_conv and key == "lora_up.weight" or key == "dyn_down":
            module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
        else:
            raise AssertionError(f'Lora layer {self.network_key} matched a layer with unsupported type: {type(self.sd_module).__name__}')

        with torch.no_grad():
            if weight.shape != module.weight.shape:
                weight = weight.reshape(module.weight.shape)
            module.weight.copy_(weight)

        module.to(device=devices.cpu, dtype=devices.dtype)
        module.weight.requires_grad_(False)

        return module

    def calc_updown(self, orig_weight):
        up = self.up_model.weight.to(orig_weight.device)
        down = self.down_model.weight.to(orig_weight.device)

        output_shape = [up.size(0), down.size(1)]
        if self.mid_model is not None:
            # cp-decomposition
            mid = self.mid_model.weight.to(orig_weight.device)
            updown = lyco_helpers.rebuild_cp_decomposition(up, down, mid)
            output_shape += mid.shape[2:]
        else:
            if len(down.shape) == 4:
                output_shape += down.shape[2:]
            updown = lyco_helpers.rebuild_conventional(up, down, output_shape, self.network.dyn_dim)

        return self.finalize_updown(updown, orig_weight, output_shape)

    def forward(self, x, y):
        self.up_model.to(device=devices.device)
        self.down_model.to(device=devices.device)

        return y + self.up_model(self.down_model(x)) * self.multiplier() * self.calc_scale()
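
# Illustrative note (not from the upstream file): a hedged sketch of what the
# two application paths compute for a plain Linear layer. Either patch the
# weight with (up @ down) * m * s, or add up(down(x)) * m * s at runtime,
# where m is the network multiplier and s is alpha / dim; the two agree up to
# floating-point error.
#
#   x = torch.randn(2, 768)
#   W = torch.randn(320, 768)
#   up, down = torch.randn(320, 8), torch.randn(8, 768)
#   m, s = 0.8, 0.5
#   y_patched = x @ (W + (up @ down) * m * s).T
#   y_runtime = x @ W.T + (x @ down.T) @ up.T * m * s
#   # torch.allclose(y_patched, y_runtime, atol=1e-4) -> True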

packages_3rdparty/webui_lora_collection/network_norm.py (vendored, new file, 28 lines)
@@ -0,0 +1,28 @@
import network


class ModuleTypeNorm(network.ModuleType):
    def create_module(self, net: network.Network, weights: network.NetworkWeights):
        if all(x in weights.w for x in ["w_norm", "b_norm"]):
            return NetworkModuleNorm(net, weights)

        return None


class NetworkModuleNorm(network.NetworkModule):
    def __init__(self, net: network.Network, weights: network.NetworkWeights):
        super().__init__(net, weights)

        self.w_norm = weights.w.get("w_norm")
        self.b_norm = weights.w.get("b_norm")

    def calc_updown(self, orig_weight):
        output_shape = self.w_norm.shape
        updown = self.w_norm.to(orig_weight.device)

        if self.b_norm is not None:
            ex_bias = self.b_norm.to(orig_weight.device)
        else:
            ex_bias = None

        return self.finalize_updown(updown, orig_weight, output_shape, ex_bias)

packages_3rdparty/webui_lora_collection/network_oft.py (vendored, new file, 118 lines)
@@ -0,0 +1,118 @@
import torch
import network
from einops import rearrange


class ModuleTypeOFT(network.ModuleType):
    def create_module(self, net: network.Network, weights: network.NetworkWeights):
        if all(x in weights.w for x in ["oft_blocks"]) or all(x in weights.w for x in ["oft_diag"]):
            return NetworkModuleOFT(net, weights)

        return None


# Supports both kohya-ss' implementation of COFT https://github.com/kohya-ss/sd-scripts/blob/main/networks/oft.py
# and KohakuBlueleaf's implementation of OFT/COFT https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/diag_oft.py
class NetworkModuleOFT(network.NetworkModule):
    def __init__(self, net: network.Network, weights: network.NetworkWeights):

        super().__init__(net, weights)

        self.lin_module = None
        self.org_module: list[torch.Module] = [self.sd_module]

        self.scale = 1.0
        self.is_R = False
        self.is_boft = False

        # kohya-ss/New LyCORIS OFT/BOFT
        if "oft_blocks" in weights.w.keys():
            self.oft_blocks = weights.w["oft_blocks"]  # (num_blocks, block_size, block_size)
            self.alpha = weights.w.get("alpha", None)  # alpha is constraint
            self.dim = self.oft_blocks.shape[0]  # lora dim
        # Old LyCORIS OFT
        elif "oft_diag" in weights.w.keys():
            self.is_R = True
            self.oft_blocks = weights.w["oft_diag"]
            # self.alpha is unused
            self.dim = self.oft_blocks.shape[1]  # (num_blocks, block_size, block_size)

        is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear]
        is_conv = type(self.sd_module) in [torch.nn.Conv2d]
        is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention]  # unsupported

        if is_linear:
            self.out_dim = self.sd_module.out_features
        elif is_conv:
            self.out_dim = self.sd_module.out_channels
        elif is_other_linear:
            self.out_dim = self.sd_module.embed_dim

        # LyCORIS BOFT
        if self.oft_blocks.dim() == 4:
            self.is_boft = True
        self.rescale = weights.w.get('rescale', None)
        if self.rescale is not None and not is_other_linear:
            self.rescale = self.rescale.reshape(-1, *[1]*(self.org_module[0].weight.dim() - 1))

        self.num_blocks = self.dim
        self.block_size = self.out_dim // self.dim
        self.constraint = (0 if self.alpha is None else self.alpha) * self.out_dim
        if self.is_R:
            self.constraint = None
            self.block_size = self.dim
            self.num_blocks = self.out_dim // self.dim
        elif self.is_boft:
            self.boft_m = self.oft_blocks.shape[0]
            self.num_blocks = self.oft_blocks.shape[1]
            self.block_size = self.oft_blocks.shape[2]
            self.boft_b = self.block_size

    def calc_updown(self, orig_weight):
        oft_blocks = self.oft_blocks.to(orig_weight.device)
        eye = torch.eye(self.block_size, device=oft_blocks.device)

        if not self.is_R:
            block_Q = oft_blocks - oft_blocks.transpose(-1, -2)  # ensure skew-symmetric orthogonal matrix
            if self.constraint != 0:
                norm_Q = torch.norm(block_Q.flatten())
                new_norm_Q = torch.clamp(norm_Q, max=self.constraint.to(oft_blocks.device))
                block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8))
            oft_blocks = torch.matmul(eye + block_Q, (eye - block_Q).float().inverse())

        R = oft_blocks.to(orig_weight.device)

        if not self.is_boft:
            # This errors out for MultiheadAttention, might need to be handled up-stream
            merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size)
            merged_weight = torch.einsum(
                'k n m, k n ... -> k m ...',
                R,
                merged_weight
            )
            merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...')
        else:
            # TODO: determine correct value for scale
            scale = 1.0
            m = self.boft_m
            b = self.boft_b
            r_b = b // 2
            inp = orig_weight
            for i in range(m):
                bi = R[i]  # b_num, b_size, b_size
                if i == 0:
                    # Apply multiplier/scale and rescale into first weight
                    bi = bi * scale + (1 - scale) * eye
                inp = rearrange(inp, "(c g k) ... -> (c k g) ...", g=2, k=2**i * r_b)
                inp = rearrange(inp, "(d b) ... -> d b ...", b=b)
                inp = torch.einsum("b i j, b j ... -> b i ...", bi, inp)
                inp = rearrange(inp, "d b ... -> (d b) ...")
                inp = rearrange(inp, "(c k g) ... -> (c g k) ...", g=2, k=2**i * r_b)
            merged_weight = inp

        # Rescale mechanism
        if self.rescale is not None:
            merged_weight = self.rescale.to(merged_weight) * merged_weight

        updown = merged_weight.to(orig_weight.device) - orig_weight.to(merged_weight.dtype)
        output_shape = orig_weight.shape
        return self.finalize_updown(updown, orig_weight, output_shape)