mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-04-30 11:11:15 +00:00
multiple lora implementation sources
This commit is contained in:
@@ -1,13 +1,15 @@
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from backend import memory_management, attention, operations
|
from backend import memory_management, attention
|
||||||
from backend.modules.k_prediction import k_prediction_from_diffusers_scheduler
|
from backend.modules.k_prediction import k_prediction_from_diffusers_scheduler
|
||||||
|
|
||||||
|
|
||||||
class KModel(torch.nn.Module):
|
class KModel(torch.nn.Module):
|
||||||
def __init__(self, model, diffusers_scheduler, k_predictor=None):
|
def __init__(self, model, diffusers_scheduler, k_predictor=None, config=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
self.config = config
|
||||||
|
|
||||||
self.storage_dtype = model.storage_dtype
|
self.storage_dtype = model.storage_dtype
|
||||||
self.computation_dtype = model.computation_dtype
|
self.computation_dtype = model.computation_dtype
|
||||||
|
|
||||||
|
|||||||
@@ -1,293 +1,31 @@
|
|||||||
# LoRA Implementation Collection from ComfyUI
|
import packages_3rdparty.webui_lora_collection.lora as lora_utils_webui
|
||||||
# Modified by Forge to support greedy loading (load a set of wrong/correct LoRAs to a model and only preserve the correct ones),
|
import packages_3rdparty.comfyui_lora_collection.lora as lora_utils_comfyui
|
||||||
# which is important to webui experience
|
|
||||||
|
|
||||||
from backend.misc.diffusers_state_dict import unet_to_diffusers
|
|
||||||
|
|
||||||
|
|
||||||
LORA_CLIP_MAP = {
|
class ForgeLoraCollection:
|
||||||
"mlp.fc1": "mlp_fc1",
|
# TODO
|
||||||
"mlp.fc2": "mlp_fc2",
|
pass
|
||||||
"self_attn.k_proj": "self_attn_k_proj",
|
|
||||||
"self_attn.q_proj": "self_attn_q_proj",
|
|
||||||
"self_attn.v_proj": "self_attn_v_proj",
|
lora_utils_forge = ForgeLoraCollection()
|
||||||
"self_attn.out_proj": "self_attn_out_proj",
|
|
||||||
}
|
lora_collection_priority = [lora_utils_forge, lora_utils_webui, lora_utils_comfyui]
|
||||||
|
|
||||||
|
|
||||||
|
def get_function(function_name: str):
    """Look up ``function_name`` across the registered LoRA collections.

    The entries of ``lora_collection_priority`` are inspected in order and the
    attribute of the first collection that has ``function_name`` is returned.

    Returns:
        The attribute found, or ``None`` when no collection provides it.
    """
    provider = next(
        (candidate for candidate in lora_collection_priority if hasattr(candidate, function_name)),
        None,
    )
    if provider is None:
        # No collection implements the requested function.
        return None
    return getattr(provider, function_name)
|
||||||
|
|
||||||
|
|
||||||
def load_lora(lora, to_load):
|
def load_lora(lora, to_load):
|
||||||
patch_dict = {}
|
patch_dict, remaining_dict = get_function('load_lora')(lora, to_load)
|
||||||
loaded_keys = set()
|
|
||||||
for x in to_load:
|
|
||||||
alpha_name = "{}.alpha".format(x)
|
|
||||||
alpha = None
|
|
||||||
if alpha_name in lora.keys():
|
|
||||||
alpha = lora[alpha_name].item()
|
|
||||||
loaded_keys.add(alpha_name)
|
|
||||||
|
|
||||||
dora_scale_name = "{}.dora_scale".format(x)
|
|
||||||
dora_scale = None
|
|
||||||
if dora_scale_name in lora.keys():
|
|
||||||
dora_scale = lora[dora_scale_name]
|
|
||||||
loaded_keys.add(dora_scale_name)
|
|
||||||
|
|
||||||
regular_lora = "{}.lora_up.weight".format(x)
|
|
||||||
diffusers_lora = "{}_lora.up.weight".format(x)
|
|
||||||
diffusers2_lora = "{}.lora_B.weight".format(x)
|
|
||||||
diffusers3_lora = "{}.lora.up.weight".format(x)
|
|
||||||
transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
|
|
||||||
A_name = None
|
|
||||||
|
|
||||||
if regular_lora in lora.keys():
|
|
||||||
A_name = regular_lora
|
|
||||||
B_name = "{}.lora_down.weight".format(x)
|
|
||||||
mid_name = "{}.lora_mid.weight".format(x)
|
|
||||||
elif diffusers_lora in lora.keys():
|
|
||||||
A_name = diffusers_lora
|
|
||||||
B_name = "{}_lora.down.weight".format(x)
|
|
||||||
mid_name = None
|
|
||||||
elif diffusers2_lora in lora.keys():
|
|
||||||
A_name = diffusers2_lora
|
|
||||||
B_name = "{}.lora_A.weight".format(x)
|
|
||||||
mid_name = None
|
|
||||||
elif diffusers3_lora in lora.keys():
|
|
||||||
A_name = diffusers3_lora
|
|
||||||
B_name = "{}.lora.down.weight".format(x)
|
|
||||||
mid_name = None
|
|
||||||
elif transformers_lora in lora.keys():
|
|
||||||
A_name = transformers_lora
|
|
||||||
B_name = "{}.lora_linear_layer.down.weight".format(x)
|
|
||||||
mid_name = None
|
|
||||||
|
|
||||||
if A_name is not None:
|
|
||||||
mid = None
|
|
||||||
if mid_name is not None and mid_name in lora.keys():
|
|
||||||
mid = lora[mid_name]
|
|
||||||
loaded_keys.add(mid_name)
|
|
||||||
patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid, dora_scale))
|
|
||||||
loaded_keys.add(A_name)
|
|
||||||
loaded_keys.add(B_name)
|
|
||||||
|
|
||||||
######## loha
|
|
||||||
hada_w1_a_name = "{}.hada_w1_a".format(x)
|
|
||||||
hada_w1_b_name = "{}.hada_w1_b".format(x)
|
|
||||||
hada_w2_a_name = "{}.hada_w2_a".format(x)
|
|
||||||
hada_w2_b_name = "{}.hada_w2_b".format(x)
|
|
||||||
hada_t1_name = "{}.hada_t1".format(x)
|
|
||||||
hada_t2_name = "{}.hada_t2".format(x)
|
|
||||||
if hada_w1_a_name in lora.keys():
|
|
||||||
hada_t1 = None
|
|
||||||
hada_t2 = None
|
|
||||||
if hada_t1_name in lora.keys():
|
|
||||||
hada_t1 = lora[hada_t1_name]
|
|
||||||
hada_t2 = lora[hada_t2_name]
|
|
||||||
loaded_keys.add(hada_t1_name)
|
|
||||||
loaded_keys.add(hada_t2_name)
|
|
||||||
|
|
||||||
patch_dict[to_load[x]] = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2, dora_scale))
|
|
||||||
loaded_keys.add(hada_w1_a_name)
|
|
||||||
loaded_keys.add(hada_w1_b_name)
|
|
||||||
loaded_keys.add(hada_w2_a_name)
|
|
||||||
loaded_keys.add(hada_w2_b_name)
|
|
||||||
|
|
||||||
######## lokr
|
|
||||||
lokr_w1_name = "{}.lokr_w1".format(x)
|
|
||||||
lokr_w2_name = "{}.lokr_w2".format(x)
|
|
||||||
lokr_w1_a_name = "{}.lokr_w1_a".format(x)
|
|
||||||
lokr_w1_b_name = "{}.lokr_w1_b".format(x)
|
|
||||||
lokr_t2_name = "{}.lokr_t2".format(x)
|
|
||||||
lokr_w2_a_name = "{}.lokr_w2_a".format(x)
|
|
||||||
lokr_w2_b_name = "{}.lokr_w2_b".format(x)
|
|
||||||
|
|
||||||
lokr_w1 = None
|
|
||||||
if lokr_w1_name in lora.keys():
|
|
||||||
lokr_w1 = lora[lokr_w1_name]
|
|
||||||
loaded_keys.add(lokr_w1_name)
|
|
||||||
|
|
||||||
lokr_w2 = None
|
|
||||||
if lokr_w2_name in lora.keys():
|
|
||||||
lokr_w2 = lora[lokr_w2_name]
|
|
||||||
loaded_keys.add(lokr_w2_name)
|
|
||||||
|
|
||||||
lokr_w1_a = None
|
|
||||||
if lokr_w1_a_name in lora.keys():
|
|
||||||
lokr_w1_a = lora[lokr_w1_a_name]
|
|
||||||
loaded_keys.add(lokr_w1_a_name)
|
|
||||||
|
|
||||||
lokr_w1_b = None
|
|
||||||
if lokr_w1_b_name in lora.keys():
|
|
||||||
lokr_w1_b = lora[lokr_w1_b_name]
|
|
||||||
loaded_keys.add(lokr_w1_b_name)
|
|
||||||
|
|
||||||
lokr_w2_a = None
|
|
||||||
if lokr_w2_a_name in lora.keys():
|
|
||||||
lokr_w2_a = lora[lokr_w2_a_name]
|
|
||||||
loaded_keys.add(lokr_w2_a_name)
|
|
||||||
|
|
||||||
lokr_w2_b = None
|
|
||||||
if lokr_w2_b_name in lora.keys():
|
|
||||||
lokr_w2_b = lora[lokr_w2_b_name]
|
|
||||||
loaded_keys.add(lokr_w2_b_name)
|
|
||||||
|
|
||||||
lokr_t2 = None
|
|
||||||
if lokr_t2_name in lora.keys():
|
|
||||||
lokr_t2 = lora[lokr_t2_name]
|
|
||||||
loaded_keys.add(lokr_t2_name)
|
|
||||||
|
|
||||||
if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
|
|
||||||
patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale))
|
|
||||||
|
|
||||||
# glora
|
|
||||||
a1_name = "{}.a1.weight".format(x)
|
|
||||||
a2_name = "{}.a2.weight".format(x)
|
|
||||||
b1_name = "{}.b1.weight".format(x)
|
|
||||||
b2_name = "{}.b2.weight".format(x)
|
|
||||||
if a1_name in lora:
|
|
||||||
patch_dict[to_load[x]] = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha, dora_scale))
|
|
||||||
loaded_keys.add(a1_name)
|
|
||||||
loaded_keys.add(a2_name)
|
|
||||||
loaded_keys.add(b1_name)
|
|
||||||
loaded_keys.add(b2_name)
|
|
||||||
|
|
||||||
w_norm_name = "{}.w_norm".format(x)
|
|
||||||
b_norm_name = "{}.b_norm".format(x)
|
|
||||||
w_norm = lora.get(w_norm_name, None)
|
|
||||||
b_norm = lora.get(b_norm_name, None)
|
|
||||||
|
|
||||||
if w_norm is not None:
|
|
||||||
loaded_keys.add(w_norm_name)
|
|
||||||
patch_dict[to_load[x]] = ("diff", (w_norm,))
|
|
||||||
if b_norm is not None:
|
|
||||||
loaded_keys.add(b_norm_name)
|
|
||||||
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (b_norm,))
|
|
||||||
|
|
||||||
diff_name = "{}.diff".format(x)
|
|
||||||
diff_weight = lora.get(diff_name, None)
|
|
||||||
if diff_weight is not None:
|
|
||||||
patch_dict[to_load[x]] = ("diff", (diff_weight,))
|
|
||||||
loaded_keys.add(diff_name)
|
|
||||||
|
|
||||||
diff_bias_name = "{}.diff_b".format(x)
|
|
||||||
diff_bias = lora.get(diff_bias_name, None)
|
|
||||||
if diff_bias is not None:
|
|
||||||
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (diff_bias,))
|
|
||||||
loaded_keys.add(diff_bias_name)
|
|
||||||
|
|
||||||
remaining_dict = {x: y for x, y in lora.items() if x not in loaded_keys}
|
|
||||||
return patch_dict, remaining_dict
|
return patch_dict, remaining_dict
|
||||||
|
|
||||||
|
|
||||||
def model_lora_keys_clip(model, key_map={}):
|
def model_lora_keys_clip(model, key_map={}):
|
||||||
sdk = model.state_dict().keys()
|
return get_function('model_lora_keys_clip')(model, key_map)
|
||||||
|
|
||||||
text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
|
|
||||||
clip_l_present = False
|
|
||||||
for b in range(32): # TODO: clean up
|
|
||||||
for c in LORA_CLIP_MAP:
|
|
||||||
k = "clip_h.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
|
|
||||||
if k in sdk:
|
|
||||||
lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c])
|
|
||||||
key_map[lora_key] = k
|
|
||||||
lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c])
|
|
||||||
key_map[lora_key] = k
|
|
||||||
lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c)
|
|
||||||
key_map[lora_key] = k
|
|
||||||
|
|
||||||
k = "clip_l.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
|
|
||||||
if k in sdk:
|
|
||||||
lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c])
|
|
||||||
key_map[lora_key] = k
|
|
||||||
lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c])
|
|
||||||
key_map[lora_key] = k
|
|
||||||
clip_l_present = True
|
|
||||||
lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c)
|
|
||||||
key_map[lora_key] = k
|
|
||||||
|
|
||||||
k = "clip_g.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
|
|
||||||
if k in sdk:
|
|
||||||
if clip_l_present:
|
|
||||||
lora_key = "lora_te2_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c])
|
|
||||||
key_map[lora_key] = k
|
|
||||||
lora_key = "text_encoder_2.text_model.encoder.layers.{}.{}".format(b, c)
|
|
||||||
key_map[lora_key] = k
|
|
||||||
else:
|
|
||||||
lora_key = "lora_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c])
|
|
||||||
key_map[lora_key] = k
|
|
||||||
lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c)
|
|
||||||
key_map[lora_key] = k
|
|
||||||
lora_key = "lora_prior_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c])
|
|
||||||
key_map[lora_key] = k
|
|
||||||
|
|
||||||
for k in sdk: # OneTrainer SD3 lora
|
|
||||||
if k.startswith("t5xxl.transformer.") and k.endswith(".weight"):
|
|
||||||
l_key = k[len("t5xxl.transformer."):-len(".weight")]
|
|
||||||
lora_key = "lora_te3_{}".format(l_key.replace(".", "_"))
|
|
||||||
key_map[lora_key] = k
|
|
||||||
|
|
||||||
k = "clip_g.transformer.text_projection.weight"
|
|
||||||
if k in sdk:
|
|
||||||
key_map["lora_prior_te_text_projection"] = k
|
|
||||||
key_map["lora_te2_text_projection"] = k
|
|
||||||
|
|
||||||
k = "clip_l.transformer.text_projection.weight"
|
|
||||||
if k in sdk:
|
|
||||||
key_map["lora_te1_text_projection"] = k
|
|
||||||
|
|
||||||
return key_map
|
|
||||||
|
|
||||||
|
|
||||||
def model_lora_keys_unet(model, key_map={}):
|
def model_lora_keys_unet(model, key_map={}):
|
||||||
sd = model.state_dict()
|
return get_function('model_lora_keys_unet')(model, key_map)
|
||||||
sdk = sd.keys()
|
|
||||||
|
|
||||||
for k in sdk:
|
|
||||||
if k.startswith("diffusion_model.") and k.endswith(".weight"):
|
|
||||||
key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
|
|
||||||
key_map["lora_unet_{}".format(key_lora)] = k
|
|
||||||
key_map["lora_prior_unet_{}".format(key_lora)] = k
|
|
||||||
|
|
||||||
diffusers_keys = unet_to_diffusers(model.diffusion_model.config)
|
|
||||||
for k in diffusers_keys:
|
|
||||||
if k.endswith(".weight"):
|
|
||||||
unet_key = "diffusion_model.{}".format(diffusers_keys[k])
|
|
||||||
key_lora = k[:-len(".weight")].replace(".", "_")
|
|
||||||
key_map["lora_unet_{}".format(key_lora)] = unet_key
|
|
||||||
|
|
||||||
diffusers_lora_prefix = ["", "unet."]
|
|
||||||
for p in diffusers_lora_prefix:
|
|
||||||
diffusers_lora_key = "{}{}".format(p, k[:-len(".weight")].replace(".to_", ".processor.to_"))
|
|
||||||
if diffusers_lora_key.endswith(".to_out.0"):
|
|
||||||
diffusers_lora_key = diffusers_lora_key[:-2]
|
|
||||||
key_map[diffusers_lora_key] = unet_key
|
|
||||||
|
|
||||||
# TODO:
|
|
||||||
|
|
||||||
# if isinstance(model, xxxx.modules.model_base.SD3): # Diffusers lora SD3
|
|
||||||
# diffusers_keys = xxxx.modules.utils.mmdit_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
|
|
||||||
# for k in diffusers_keys:
|
|
||||||
# if k.endswith(".weight"):
|
|
||||||
# to = diffusers_keys[k]
|
|
||||||
# key_lora = "transformer.{}".format(k[:-len(".weight")]) # regular diffusers sd3 lora format
|
|
||||||
# key_map[key_lora] = to
|
|
||||||
#
|
|
||||||
# key_lora = "base_model.model.{}".format(k[:-len(".weight")]) # format for flash-sd3 lora and others?
|
|
||||||
# key_map[key_lora] = to
|
|
||||||
#
|
|
||||||
# key_lora = "lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_")) # OneTrainer lora
|
|
||||||
# key_map[key_lora] = to
|
|
||||||
#
|
|
||||||
# if isinstance(model, xxxx.modules.model_base.AuraFlow): # Diffusers lora AuraFlow
|
|
||||||
# diffusers_keys = xxxx.modules.utils.auraflow_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
|
|
||||||
# for k in diffusers_keys:
|
|
||||||
# if k.endswith(".weight"):
|
|
||||||
# to = diffusers_keys[k]
|
|
||||||
# key_lora = "transformer.{}".format(k[:-len(".weight")]) # simpletrainer and probably regular diffusers lora format
|
|
||||||
# key_map[key_lora] = to
|
|
||||||
#
|
|
||||||
# if isinstance(model, xxxx.modules.model_base.HunyuanDiT):
|
|
||||||
# for k in sdk:
|
|
||||||
# if k.startswith("diffusion_model.") and k.endswith(".weight"):
|
|
||||||
# key_lora = k[len("diffusion_model."):-len(".weight")]
|
|
||||||
# key_map["base_model.model.{}".format(key_lora)] = k # official hunyuan lora format
|
|
||||||
|
|
||||||
return key_map
|
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from backend.patcher.base import ModelPatcher
|
|||||||
class UnetPatcher(ModelPatcher):
|
class UnetPatcher(ModelPatcher):
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_model(cls, model, diffusers_scheduler, config, k_predictor=None):
|
def from_model(cls, model, diffusers_scheduler, config, k_predictor=None):
|
||||||
model = KModel(model=model, diffusers_scheduler=diffusers_scheduler, k_predictor=k_predictor)
|
model = KModel(model=model, diffusers_scheduler=diffusers_scheduler, k_predictor=k_predictor, config=config)
|
||||||
return UnetPatcher(
|
return UnetPatcher(
|
||||||
model,
|
model,
|
||||||
load_device=model.diffusion_model.load_device,
|
load_device=model.diffusion_model.load_device,
|
||||||
|
|||||||
1
packages_3rdparty/README.md
vendored
Normal file
1
packages_3rdparty/README.md
vendored
Normal file
@@ -0,0 +1 @@
|
|||||||
|
Please follow the standard of https://github.com/opencv/opencv/tree/315f85d4f484c1e2fa043c73ac3fdd9fc5997ee7/3rdparty when submitting a PR or modifying files.
|
||||||
323
packages_3rdparty/comfyui_lora_collection/lora.py
vendored
Normal file
323
packages_3rdparty/comfyui_lora_collection/lora.py
vendored
Normal file
@@ -0,0 +1,323 @@
|
|||||||
|
"""
|
||||||
|
This file is part of ComfyUI.
|
||||||
|
Copyright (C) 2024 Comfy
|
||||||
|
|
||||||
|
This program is free software: you can redistribute it and/or modify
|
||||||
|
it under the terms of the GNU General Public License as published by
|
||||||
|
the Free Software Foundation, either version 3 of the License, or
|
||||||
|
(at your option) any later version.
|
||||||
|
|
||||||
|
This program is distributed in the hope that it will be useful,
|
||||||
|
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
GNU General Public License for more details.
|
||||||
|
|
||||||
|
You should have received a copy of the GNU General Public License
|
||||||
|
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from packages_3rdparty.comfyui_lora_collection import utils
|
||||||
|
|
||||||
|
|
||||||
|
LORA_CLIP_MAP = {
|
||||||
|
"mlp.fc1": "mlp_fc1",
|
||||||
|
"mlp.fc2": "mlp_fc2",
|
||||||
|
"self_attn.k_proj": "self_attn_k_proj",
|
||||||
|
"self_attn.q_proj": "self_attn_q_proj",
|
||||||
|
"self_attn.v_proj": "self_attn_v_proj",
|
||||||
|
"self_attn.out_proj": "self_attn_out_proj",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def load_lora(lora, to_load):
    """Match the entries of a LoRA state dict against the targets in ``to_load``.

    For every key prefix in ``to_load`` the known layouts are probed in a fixed
    order: plain LoRA (several naming schemes), LoHa, LoKr, GLoRA, norm
    weights, then raw ``diff`` / ``diff_b`` entries. When several layouts match
    the same target, the later one overwrites the earlier one.

    Args:
        lora: Mapping of LoRA entry names to their values.
        to_load: Mapping of LoRA key prefixes to the target key they patch.

    Returns:
        A ``(patch_dict, remaining_dict)`` tuple. ``patch_dict`` maps a target
        key to ``(kind, payload)``; ``remaining_dict`` contains every item of
        ``lora`` whose name was not consumed.
    """
    # (up, down, mid) name templates for plain LoRA; the first matching row wins.
    lora_layouts = (
        ("{}.lora_up.weight", "{}.lora_down.weight", "{}.lora_mid.weight"),
        ("{}_lora.up.weight", "{}_lora.down.weight", None),
        ("{}.lora_B.weight", "{}.lora_A.weight", None),
        ("{}.lora.up.weight", "{}.lora.down.weight", None),
        ("{}.lora_linear_layer.up.weight", "{}.lora_linear_layer.down.weight", None),
    )
    # Optional LoKr components, each picked up independently when present.
    lokr_templates = (
        ("w1", "{}.lokr_w1"),
        ("w2", "{}.lokr_w2"),
        ("w1_a", "{}.lokr_w1_a"),
        ("w1_b", "{}.lokr_w1_b"),
        ("w2_a", "{}.lokr_w2_a"),
        ("w2_b", "{}.lokr_w2_b"),
        ("t2", "{}.lokr_t2"),
    )

    patches = {}
    consumed = set()

    for prefix in to_load:
        target = to_load[prefix]

        alpha = None
        alpha_key = "{}.alpha".format(prefix)
        if alpha_key in lora:
            alpha = lora[alpha_key].item()
            consumed.add(alpha_key)

        dora_scale = None
        dora_key = "{}.dora_scale".format(prefix)
        if dora_key in lora:
            dora_scale = lora[dora_key]
            consumed.add(dora_key)

        ######## plain lora
        for up_tpl, down_tpl, mid_tpl in lora_layouts:
            up_key = up_tpl.format(prefix)
            if up_key not in lora:
                continue
            down_key = down_tpl.format(prefix)
            mid = None
            if mid_tpl is not None:
                mid_key = mid_tpl.format(prefix)
                if mid_key in lora:
                    mid = lora[mid_key]
                    consumed.add(mid_key)
            patches[target] = ("lora", (lora[up_key], lora[down_key], alpha, mid, dora_scale))
            consumed.update((up_key, down_key))
            break

        ######## loha
        hada_w1_a_key = "{}.hada_w1_a".format(prefix)
        if hada_w1_a_key in lora:
            hada_w1_b_key = "{}.hada_w1_b".format(prefix)
            hada_w2_a_key = "{}.hada_w2_a".format(prefix)
            hada_w2_b_key = "{}.hada_w2_b".format(prefix)
            hada_t1_key = "{}.hada_t1".format(prefix)
            hada_t1 = hada_t2 = None
            if hada_t1_key in lora:
                # t2 is read unconditionally once t1 exists.
                hada_t2_key = "{}.hada_t2".format(prefix)
                hada_t1 = lora[hada_t1_key]
                hada_t2 = lora[hada_t2_key]
                consumed.update((hada_t1_key, hada_t2_key))

            patches[target] = (
                "loha",
                (
                    lora[hada_w1_a_key],
                    lora[hada_w1_b_key],
                    alpha,
                    lora[hada_w2_a_key],
                    lora[hada_w2_b_key],
                    hada_t1,
                    hada_t2,
                    dora_scale,
                ),
            )
            consumed.update((hada_w1_a_key, hada_w1_b_key, hada_w2_a_key, hada_w2_b_key))

        ######## lokr
        lokr = {}
        for part, tpl in lokr_templates:
            part_key = tpl.format(prefix)
            if part_key in lora:
                lokr[part] = lora[part_key]
                consumed.add(part_key)
            else:
                lokr[part] = None

        if any(lokr[part] is not None for part in ("w1", "w2", "w1_a", "w2_a")):
            patches[target] = (
                "lokr",
                (
                    lokr["w1"],
                    lokr["w2"],
                    alpha,
                    lokr["w1_a"],
                    lokr["w1_b"],
                    lokr["w2_a"],
                    lokr["w2_b"],
                    lokr["t2"],
                    dora_scale,
                ),
            )

        ######## glora
        a1_key = "{}.a1.weight".format(prefix)
        if a1_key in lora:
            a2_key = "{}.a2.weight".format(prefix)
            b1_key = "{}.b1.weight".format(prefix)
            b2_key = "{}.b2.weight".format(prefix)
            patches[target] = ("glora", (lora[a1_key], lora[a2_key], lora[b1_key], lora[b2_key], alpha, dora_scale))
            consumed.update((a1_key, a2_key, b1_key, b2_key))

        ######## norm weights
        w_norm_key = "{}.w_norm".format(prefix)
        b_norm_key = "{}.b_norm".format(prefix)
        w_norm = lora.get(w_norm_key, None)
        b_norm = lora.get(b_norm_key, None)
        if w_norm is not None:
            consumed.add(w_norm_key)
            patches[target] = ("diff", (w_norm,))
            # NOTE(review): b_norm is handled only together with w_norm here;
            # confirm this nesting against the upstream ComfyUI file.
            if b_norm is not None:
                consumed.add(b_norm_key)
                patches["{}.bias".format(to_load[prefix][:-len(".weight")])] = ("diff", (b_norm,))

        ######## raw diffs
        diff_key = "{}.diff".format(prefix)
        diff_weight = lora.get(diff_key, None)
        if diff_weight is not None:
            patches[target] = ("diff", (diff_weight,))
            consumed.add(diff_key)

        diff_bias_key = "{}.diff_b".format(prefix)
        diff_bias = lora.get(diff_bias_key, None)
        if diff_bias is not None:
            patches["{}.bias".format(to_load[prefix][:-len(".weight")])] = ("diff", (diff_bias,))
            consumed.add(diff_bias_key)

    leftover = {name: value for name, value in lora.items() if name not in consumed}
    return patches, leftover
|
||||||
|
|
||||||
|
|
||||||
|
def model_lora_keys_clip(model, key_map=None):
    """Map LoRA key names to the text-encoder state-dict keys of ``model``.

    The keys of ``model.state_dict()`` are scanned for ``clip_h`` / ``clip_l`` /
    ``clip_g`` encoder layers (layer indices 0-31, sub-modules listed in
    ``LORA_CLIP_MAP``), for ``t5xxl`` and ``hydit_clip`` weights, and for the
    ``clip_g`` / ``clip_l`` text projections. For every key found, the LoRA
    naming variants that refer to it are added to ``key_map``.

    Args:
        model: Object exposing ``state_dict()``.
        key_map: Optional dict to extend in place. A fresh dict is created
            when it is omitted.

    Returns:
        ``key_map`` with the added ``lora_key -> state_dict_key`` entries.
    """
    # A literal ``{}`` default is created once and shared by every call that
    # omits the argument, so entries from one model would leak into the next.
    if key_map is None:
        key_map = {}

    sdk = model.state_dict().keys()

    text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
    clip_l_present = False
    for b in range(32):  # TODO: clean up
        for c, c_flat in LORA_CLIP_MAP.items():
            k = "clip_h.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
            if k in sdk:
                key_map[text_model_lora_key.format(b, c_flat)] = k
                key_map["lora_te1_text_model_encoder_layers_{}_{}".format(b, c_flat)] = k
                key_map["text_encoder.text_model.encoder.layers.{}.{}".format(b, c)] = k  # diffusers lora

            k = "clip_l.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
            if k in sdk:
                key_map[text_model_lora_key.format(b, c_flat)] = k
                key_map["lora_te1_text_model_encoder_layers_{}_{}".format(b, c_flat)] = k  # SDXL base
                clip_l_present = True
                key_map["text_encoder.text_model.encoder.layers.{}.{}".format(b, c)] = k  # diffusers lora

            k = "clip_g.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
            if k in sdk:
                if clip_l_present:
                    # clip_l already took the first-encoder names, so clip_g
                    # is registered under the second-encoder names.
                    key_map["lora_te2_text_model_encoder_layers_{}_{}".format(b, c_flat)] = k  # SDXL base
                    key_map["text_encoder_2.text_model.encoder.layers.{}.{}".format(b, c)] = k  # diffusers lora
                else:
                    key_map["lora_te_text_model_encoder_layers_{}_{}".format(b, c_flat)] = k  # TODO: test if this is correct for SDXL-Refiner
                    key_map["text_encoder.text_model.encoder.layers.{}.{}".format(b, c)] = k  # diffusers lora
                    key_map["lora_prior_te_text_model_encoder_layers_{}_{}".format(b, c_flat)] = k  # cascade lora: TODO put lora key prefix in the model config

    for k in sdk:
        if not k.endswith(".weight"):
            continue
        if k.startswith("t5xxl.transformer."):  # OneTrainer SD3 lora
            l_key = k[len("t5xxl.transformer."):-len(".weight")]
            key_map["lora_te3_{}".format(l_key.replace(".", "_"))] = k
        elif k.startswith("hydit_clip.transformer.bert."):  # HunyuanDiT Lora
            l_key = k[len("hydit_clip.transformer.bert."):-len(".weight")]
            key_map["lora_te1_{}".format(l_key.replace(".", "_"))] = k

    k = "clip_g.transformer.text_projection.weight"
    if k in sdk:
        key_map["lora_prior_te_text_projection"] = k  # cascade lora?
        # key_map["text_encoder.text_projection"] = k  # TODO: check if other lora have the text_projection too
        key_map["lora_te2_text_projection"] = k  # OneTrainer SD3 lora

    k = "clip_l.transformer.text_projection.weight"
    if k in sdk:
        key_map["lora_te1_text_projection"] = k  # OneTrainer SD3 lora, not necessary but omits warning

    return key_map
|
||||||
|
|
||||||
|
|
||||||
|
def model_lora_keys_unet(model, key_map=None):
    """Map LoRA key names to the diffusion-model state-dict keys of ``model``.

    Three groups of names are registered:

    * for every ``diffusion_model.*.weight`` key: ``lora_unet_*``,
      ``lora_prior_unet_*`` and the plain key without ``.weight``;
    * the diffusers-style names produced by ``utils.unet_to_diffusers``;
    * when ``model.config.huggingface_repo`` contains ``flux``, the
      ``transformer.*`` names produced by ``utils.flux_to_diffusers``.

    Args:
        model: Object exposing ``state_dict()``, ``diffusion_model.config``
            and, optionally, ``config.huggingface_repo``.
        key_map: Optional dict to extend in place. A fresh dict is created
            when it is omitted.

    Returns:
        ``key_map`` with the added ``lora_key -> state_dict_key`` entries.
    """
    # A literal ``{}`` default is created once and shared by every call that
    # omits the argument, so entries from one model would leak into the next.
    if key_map is None:
        key_map = {}

    sd = model.state_dict()
    sdk = sd.keys()

    for k in sdk:
        if k.startswith("diffusion_model.") and k.endswith(".weight"):
            key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
            key_map["lora_unet_{}".format(key_lora)] = k
            key_map["lora_prior_unet_{}".format(key_lora)] = k  # cascade lora: TODO put lora key prefix in the model config
            key_map["{}".format(k[:-len(".weight")])] = k  # generic lora format without any weird key names

    diffusers_keys = utils.unet_to_diffusers(model.diffusion_model.config)
    for k in diffusers_keys:
        if k.endswith(".weight"):
            unet_key = "diffusion_model.{}".format(diffusers_keys[k])
            key_lora = k[:-len(".weight")].replace(".", "_")
            key_map["lora_unet_{}".format(key_lora)] = unet_key

            diffusers_lora_prefix = ["", "unet."]
            for p in diffusers_lora_prefix:
                diffusers_lora_key = "{}{}".format(p, k[:-len(".weight")].replace(".to_", ".processor.to_"))
                if diffusers_lora_key.endswith(".to_out.0"):
                    # Drop the trailing ".0" so the key ends in ".to_out".
                    diffusers_lora_key = diffusers_lora_key[:-2]
                key_map[diffusers_lora_key] = unet_key

    # Upstream mappings kept for reference, currently disabled:
    # if isinstance(model, comfy.model_base.SD3): #Diffusers lora SD3
    #     diffusers_keys = utils.mmdit_to_diffusers(model.diffusion_model.config, output_prefix="diffusion_model.")
    #     for k in diffusers_keys:
    #         if k.endswith(".weight"):
    #             to = diffusers_keys[k]
    #             key_lora = "transformer.{}".format(k[:-len(".weight")]) #regular diffusers sd3 lora format
    #             key_map[key_lora] = to
    #
    #             key_lora = "base_model.model.{}".format(k[:-len(".weight")]) #format for flash-sd3 lora and others?
    #             key_map[key_lora] = to
    #
    #             key_lora = "lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_")) #OneTrainer lora
    #             key_map[key_lora] = to
    #
    # if isinstance(model, comfy.model_base.AuraFlow): #Diffusers lora AuraFlow
    #     diffusers_keys = utils.auraflow_to_diffusers(model.diffusion_model.config, output_prefix="diffusion_model.")
    #     for k in diffusers_keys:
    #         if k.endswith(".weight"):
    #             to = diffusers_keys[k]
    #             key_lora = "transformer.{}".format(k[:-len(".weight")]) #simpletrainer and probably regular diffusers lora format
    #             key_map[key_lora] = to
    #
    # if isinstance(model, comfy.model_base.HunyuanDiT):
    #     for k in sdk:
    #         if k.startswith("diffusion_model.") and k.endswith(".weight"):
    #             key_lora = k[len("diffusion_model."):-len(".weight")]
    #             key_map["base_model.model.{}".format(key_lora)] = k #official hunyuan lora format

    # The model may carry no config (or no repo name); in that case it is
    # treated as "not Flux" instead of raising AttributeError on ``None``.
    repo_name = getattr(getattr(model, "config", None), "huggingface_repo", None) or ""
    if 'flux' in repo_name.lower():  # Diffusers lora Flux
        diffusers_keys = utils.flux_to_diffusers(model.diffusion_model.config, output_prefix="diffusion_model.")
        for k in diffusers_keys:
            if k.endswith(".weight"):
                to = diffusers_keys[k]
                key_lora = "transformer.{}".format(k[:-len(".weight")])  # simpletrainer and probably regular diffusers flux lora format
                key_map[key_lora] = to

    return key_map
|
||||||
760
packages_3rdparty/comfyui_lora_collection/utils.py
vendored
Normal file
760
packages_3rdparty/comfyui_lora_collection/utils.py
vendored
Normal file
@@ -0,0 +1,760 @@
|
|||||||
|
"""
|
||||||
|
This file is part of ComfyUI.
|
||||||
|
Copyright (C) 2024 Comfy
|
||||||
|
|
||||||
|
This program is free software: you can redistribute it and/or modify
|
||||||
|
it under the terms of the GNU General Public License as published by
|
||||||
|
the Free Software Foundation, either version 3 of the License, or
|
||||||
|
(at your option) any later version.
|
||||||
|
|
||||||
|
This program is distributed in the hope that it will be useful,
|
||||||
|
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
GNU General Public License for more details.
|
||||||
|
|
||||||
|
You should have received a copy of the GNU General Public License
|
||||||
|
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import math
|
||||||
|
import struct
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
import itertools
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_parameters(sd, prefix=""):
    """Count the elements of every entry in ``sd`` whose key starts with ``prefix``.

    Each selected value must provide ``nelement()``. Returns 0 when no key
    matches.
    """
    return sum(sd[name].nelement() for name in sd.keys() if name.startswith(prefix))
|
||||||
|
|
||||||
|
def weight_dtype(sd, prefix=""):
    """Return the most frequent ``dtype`` among entries whose key starts with ``prefix``.

    Returns ``None`` when no key matches. On a tie the dtype that was seen
    first wins.
    """
    counts = {}
    for name in sd.keys():
        if not name.startswith(prefix):
            continue
        dtype = sd[name].dtype
        counts[dtype] = counts.get(dtype, 0) + 1

    if not counts:
        return None

    return max(counts, key=lambda dtype: counts[dtype])
|
||||||
|
|
||||||
|
def state_dict_key_replace(state_dict, keys_to_replace):
    """Rename keys of ``state_dict`` in place and return it.

    ``keys_to_replace`` maps an existing key to its new name; entries whose
    old key is not present are ignored.
    """
    for old_key, new_key in keys_to_replace.items():
        if old_key not in state_dict:
            continue
        state_dict[new_key] = state_dict.pop(old_key)
    return state_dict
|
||||||
|
|
||||||
|
def state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=False):
    """Swap key prefixes according to ``replace_prefix`` (old prefix -> new prefix).

    With ``filter_keys`` false the renamed entries are written back into
    ``state_dict`` and that same dict is returned. With ``filter_keys`` true
    a new dict holding only the renamed entries is returned; the matched
    entries are still removed from ``state_dict``.
    """
    out = {} if filter_keys else state_dict
    for old_prefix in replace_prefix:
        new_prefix = replace_prefix[old_prefix]
        # Snapshot the renames first: the dict is mutated while they are applied.
        renames = [
            (key, "{}{}".format(new_prefix, key[len(old_prefix):]))
            for key in state_dict.keys()
            if key.startswith(old_prefix)
        ]
        for old_key, new_key in renames:
            out[new_key] = state_dict.pop(old_key)
    return out
|
||||||
|
|
||||||
|
|
||||||
|
def transformers_convert(sd, prefix_from, prefix_to, number):
    """Rename transformer keys in ``sd`` in place and return ``sd``.

    Keys under ``prefix_from`` that use the ``transformer.resblocks.N`` layout
    are moved to the ``encoder.layers.N`` layout under ``prefix_to`` for layers
    ``0 .. number - 1``. A fused ``attn.in_proj_{weight,bias}`` entry is cut
    along dim 0 into three equal parts stored as q_proj, k_proj and v_proj.
    Keys that are not present are skipped.
    """
    top_level_renames = (
        ("{}positional_embedding", "{}embeddings.position_embedding.weight"),
        ("{}token_embedding.weight", "{}embeddings.token_embedding.weight"),
        ("{}ln_final.weight", "{}final_layer_norm.weight"),
        ("{}ln_final.bias", "{}final_layer_norm.bias"),
    )
    for src_template, dst_template in top_level_renames:
        src_key = src_template.format(prefix_from)
        if src_key in sd:
            sd[dst_template.format(prefix_to)] = sd.pop(src_key)

    layer_renames = {
        "ln_1": "layer_norm1",
        "ln_2": "layer_norm2",
        "mlp.c_fc": "mlp.fc1",
        "mlp.c_proj": "mlp.fc2",
        "attn.out_proj": "self_attn.out_proj",
    }
    suffixes = ("weight", "bias")
    qkv_names = ("self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj")

    for layer in range(number):
        src_layer = "{}transformer.resblocks.{}".format(prefix_from, layer)
        dst_layer = "{}encoder.layers.{}".format(prefix_to, layer)

        for src_name, dst_name in layer_renames.items():
            for suffix in suffixes:
                src_key = "{}.{}.{}".format(src_layer, src_name, suffix)
                if src_key in sd:
                    sd["{}.{}.{}".format(dst_layer, dst_name, suffix)] = sd.pop(src_key)

        for suffix in suffixes:
            fused_key = "{}.attn.in_proj_{}".format(src_layer, suffix)
            if fused_key not in sd:
                continue
            fused = sd.pop(fused_key)
            third = fused.shape[0] // 3
            for slot, name in enumerate(qkv_names):
                sd["{}.{}.{}".format(dst_layer, name, suffix)] = fused[third * slot:third * (slot + 1)]

    return sd
|
||||||
|
|
||||||
|
def clip_text_transformers_convert(sd, prefix_from, prefix_to):
    """Convert text-encoder keys in ``sd`` from ``prefix_from`` to ``prefix_to``.

    Runs ``transformers_convert`` for 32 layers into ``<prefix_to>text_model.``
    and then moves the text projection to ``<prefix_to>text_projection.weight``.
    A ``text_projection.weight`` entry is moved unchanged; a bare
    ``text_projection`` entry is transposed and made contiguous first.
    """
    sd = transformers_convert(sd, prefix_from, "{}text_model.".format(prefix_to), 32)

    projection_target = "{}text_projection.weight".format(prefix_to)

    weight_key = "{}text_projection.weight".format(prefix_from)
    if weight_key in sd:
        sd[projection_target] = sd.pop(weight_key)

    bare_key = "{}text_projection".format(prefix_from)
    if bare_key in sd:
        sd[projection_target] = sd.pop(bare_key).transpose(0, 1).contiguous()

    return sd
|
||||||
|
|
||||||
|
|
||||||
|
# Parameter names copied verbatim for each attention wrapper.
UNET_MAP_ATTENTIONS = {
    "{}.{}".format(layer, param)
    for layer in ("proj_in", "proj_out", "norm")
    for param in ("weight", "bias")
}

# Parameter names copied verbatim for each transformer block: layers that have
# both weight and bias, plus the q/k/v projections which only list a weight.
TRANSFORMER_BLOCKS = {
    "{}.{}".format(layer, param)
    for layer in (
        "norm1",
        "norm2",
        "norm3",
        "attn1.to_out.0",
        "attn2.to_out.0",
        "ff.net.0.proj",
        "ff.net.2",
    )
    for param in ("weight", "bias")
} | {
    "{}.{}.weight".format(attention, projection)
    for attention in ("attn1", "attn2")
    for projection in ("to_q", "to_k", "to_v")
}

# Resnet parameter names: key is the name used on the "input_blocks" side,
# value is the name used on the "down_blocks"/"up_blocks" side.
UNET_MAP_RESNET = {
    "{}.{}".format(source, param): "{}.{}".format(target, param)
    for source, target in (
        ("in_layers.2", "conv1"),
        ("emb_layers.1", "time_emb_proj"),
        ("out_layers.3", "conv2"),
        ("skip_connection", "conv_shortcut"),
        ("in_layers.0", "norm1"),
        ("out_layers.0", "norm2"),
    )
    for param in ("weight", "bias")
}

# (source name, target name) pairs for parameters outside the repeated blocks.
# "label_emb" is listed under both "class_embedding" and "add_embedding".
UNET_MAP_BASIC = {
    ("{}.{}".format(source, param), "{}.{}".format(target, param))
    for source, target in (
        ("label_emb.0.0", "class_embedding.linear_1"),
        ("label_emb.0.2", "class_embedding.linear_2"),
        ("label_emb.0.0", "add_embedding.linear_1"),
        ("label_emb.0.2", "add_embedding.linear_2"),
        ("input_blocks.0.0", "conv_in"),
        ("out.0", "conv_norm_out"),
        ("out.2", "conv_out"),
        ("time_embed.0", "time_embedding.linear_1"),
        ("time_embed.2", "time_embedding.linear_2"),
    )
    for param in ("weight", "bias")
}
|
||||||
|
|
||||||
|
def unet_to_diffusers(unet_config):
    """Build a dict from ``down_blocks``/``mid_block``/``up_blocks`` style keys
    to ``input_blocks``/``middle_block``/``output_blocks`` style keys.

    Returns ``{}`` when ``unet_config`` has no ``"num_res_blocks"`` entry.
    Otherwise the config must provide ``channel_mult``, ``transformer_depth``
    and ``transformer_depth_output``; the two depth lists are copied before
    being consumed, so the config itself is not modified.

    NOTE(review): ``transformer_depth_middle`` is read with a ``None`` default
    but passed straight to ``range()``, so a config without it raises
    ``TypeError`` -- confirm callers always set it.
    """
    if "num_res_blocks" not in unet_config:
        return {}

    res_block_counts = unet_config["num_res_blocks"]
    level_count = len(unet_config["channel_mult"])
    # Slices copy the lists because entries are popped below.
    down_depths = unet_config["transformer_depth"][:]
    up_depths = unet_config["transformer_depth_output"][:]
    mid_depth = unet_config.get("transformer_depth_middle", None)

    suffixes = ("weight", "bias")
    mapping = {}

    # Down path: the running index skips slot 0 and one extra slot per earlier level.
    for level in range(level_count):
        block_index = 1 + (res_block_counts[level] + 1) * level
        for res in range(res_block_counts[level]):
            for source_name, target_name in UNET_MAP_RESNET.items():
                mapping["down_blocks.{}.resnets.{}.{}".format(level, res, target_name)] = "input_blocks.{}.0.{}".format(block_index, source_name)
            depth = down_depths.pop(0)
            if depth > 0:
                for name in UNET_MAP_ATTENTIONS:
                    mapping["down_blocks.{}.attentions.{}.{}".format(level, res, name)] = "input_blocks.{}.1.{}".format(block_index, name)
                for t in range(depth):
                    for name in TRANSFORMER_BLOCKS:
                        mapping["down_blocks.{}.attentions.{}.transformer_blocks.{}.{}".format(level, res, t, name)] = "input_blocks.{}.1.transformer_blocks.{}.{}".format(block_index, t, name)
            block_index += 1
        for suffix in suffixes:
            mapping["down_blocks.{}.downsamplers.0.conv.{}".format(level, suffix)] = "input_blocks.{}.0.op.{}".format(block_index, suffix)

    # Middle block: one attention (index 0) between two resnets (slots 0 and 2).
    mid_attention = 0
    for name in UNET_MAP_ATTENTIONS:
        mapping["mid_block.attentions.{}.{}".format(mid_attention, name)] = "middle_block.1.{}".format(name)
    for t in range(mid_depth):
        for name in TRANSFORMER_BLOCKS:
            mapping["mid_block.attentions.{}.transformer_blocks.{}.{}".format(mid_attention, t, name)] = "middle_block.1.transformer_blocks.{}.{}".format(t, name)

    for res, slot in enumerate([0, 2]):
        for source_name, target_name in UNET_MAP_RESNET.items():
            mapping["mid_block.resnets.{}.{}".format(res, target_name)] = "middle_block.{}.{}".format(slot, source_name)

    # Up path: walk the res-block counts in reverse; depths are popped from the end.
    res_block_counts = list(reversed(res_block_counts))
    for level in range(level_count):
        block_index = (res_block_counts[level] + 1) * level
        blocks_in_level = res_block_counts[level] + 1
        for res in range(blocks_in_level):
            # Sub-module slot of the upsampler: 1 after the resnet, 2 if attention exists.
            upsampler_slot = 0
            for source_name, target_name in UNET_MAP_RESNET.items():
                mapping["up_blocks.{}.resnets.{}.{}".format(level, res, target_name)] = "output_blocks.{}.0.{}".format(block_index, source_name)
            upsampler_slot += 1
            depth = up_depths.pop()
            if depth > 0:
                upsampler_slot += 1
                for name in UNET_MAP_ATTENTIONS:
                    mapping["up_blocks.{}.attentions.{}.{}".format(level, res, name)] = "output_blocks.{}.1.{}".format(block_index, name)
                for t in range(depth):
                    for name in TRANSFORMER_BLOCKS:
                        mapping["up_blocks.{}.attentions.{}.transformer_blocks.{}.{}".format(level, res, t, name)] = "output_blocks.{}.1.transformer_blocks.{}.{}".format(block_index, t, name)
            if res == blocks_in_level - 1:
                for suffix in suffixes:
                    mapping["up_blocks.{}.upsamplers.0.conv.{}".format(level, suffix)] = "output_blocks.{}.{}.conv.{}".format(block_index, upsampler_slot, suffix)
            block_index += 1

    for source_name, target_name in UNET_MAP_BASIC:
        mapping[target_name] = source_name

    return mapping
|
||||||
|
|
||||||
|
def swap_scale_shift(weight):
    """Return ``weight`` with its two halves along dim 0 exchanged."""
    first_half, second_half = weight.chunk(2, dim=0)
    return torch.cat([second_half, first_half], dim=0)
|
||||||
|
|
||||||
|
# (source name, target name[, transform]) entries for parameters outside the
# repeated blocks. A third element is a callable stored alongside the mapping.
MMDIT_MAP_BASIC = {
    ("{}.{}".format(source, param), "{}.{}".format(target, param))
    for source, target in (
        ("context_embedder", "context_embedder"),
        ("t_embedder.mlp.0", "time_text_embed.timestep_embedder.linear_1"),
        ("t_embedder.mlp.2", "time_text_embed.timestep_embedder.linear_2"),
        ("x_embedder.proj", "pos_embed.proj"),
        ("y_embedder.mlp.0", "time_text_embed.text_embedder.linear_1"),
        ("y_embedder.mlp.2", "time_text_embed.text_embedder.linear_2"),
        ("final_layer.linear", "proj_out"),
    )
    for param in ("bias", "weight")
} | {
    ("pos_embed", "pos_embed.pos_embed"),
    ("final_layer.adaLN_modulation.1.bias", "norm_out.linear.bias", swap_scale_shift),
    ("final_layer.adaLN_modulation.1.weight", "norm_out.linear.weight", swap_scale_shift),
}

# (source name, target name) pairs repeated inside every block.
MMDIT_MAP_BLOCK = {
    ("{}.{}".format(source, param), "{}.{}".format(target, param))
    for source, target in (
        ("context_block.adaLN_modulation.1", "norm1_context.linear"),
        ("context_block.attn.proj", "attn.to_add_out"),
        ("context_block.mlp.fc1", "ff_context.net.0.proj"),
        ("context_block.mlp.fc2", "ff_context.net.2"),
        ("x_block.adaLN_modulation.1", "norm1.linear"),
        ("x_block.attn.proj", "attn.to_out.0"),
        ("x_block.mlp.fc1", "ff.net.0.proj"),
        ("x_block.mlp.fc2", "ff.net.2"),
    )
    for param in ("bias", "weight")
}
|
||||||
|
|
||||||
|
def mmdit_to_diffusers(mmdit_config, output_prefix=""):
    """Build a dict from ``transformer_blocks.N...`` style keys to
    ``<output_prefix>joint_blocks.N...`` style keys.

    Values take one of three forms:
      * a plain target key;
      * ``(target, (0, start, length))`` for q/k/v entries that share one fused
        ``qkv`` target (presumably a dim/start/length slice -- confirm against
        the consumer);
      * ``(target, None, transform)`` for entries that carry a callable.

    ``depth`` (default 0) sets the slice length as ``depth * 64``;
    ``num_blocks`` defaults to ``depth``.
    """
    key_map = {}

    depth = mmdit_config.get("depth", 0)
    num_blocks = mmdit_config.get("num_blocks", depth)
    offset = depth * 64
    slice_starts = (0, offset, offset * 2)
    attention_groups = (
        ("x_block", ("to_q", "to_k", "to_v")),
        ("context_block", ("add_q_proj", "add_k_proj", "add_v_proj")),
    )

    for block in range(num_blocks):
        block_from = "transformer_blocks.{}".format(block)
        block_to = "{}joint_blocks.{}".format(output_prefix, block)

        for end in ("weight", "bias"):
            for group, projections in attention_groups:
                fused = "{}.{}.attn.qkv.{}".format(block_to, group, end)
                for projection, start in zip(projections, slice_starts):
                    key_map["{}.attn.{}.{}".format(block_from, projection, end)] = (fused, (0, start, offset))

        for entry in MMDIT_MAP_BLOCK:
            key_map["{}.{}".format(block_from, entry[1])] = "{}.{}".format(block_to, entry[0])

    # These two entries are written after the per-block loop, so for the last
    # block they replace the plain mapping with one that carries swap_scale_shift.
    last = depth - 1
    map_basic = MMDIT_MAP_BASIC.copy()
    for param in ("bias", "weight"):
        map_basic.add((
            "joint_blocks.{}.context_block.adaLN_modulation.1.{}".format(last, param),
            "transformer_blocks.{}.norm1_context.linear.{}".format(last, param),
            swap_scale_shift,
        ))

    for entry in map_basic:
        target = "{}{}".format(output_prefix, entry[0])
        if len(entry) > 2:
            key_map[entry[1]] = (target, None, entry[2])
        else:
            key_map[entry[1]] = target

    return key_map
|
||||||
|
|
||||||
|
|
||||||
|
def auraflow_to_diffusers(mmdit_config, output_prefix=""):
    """Build a dict from ``joint_transformer_blocks`` / ``single_transformer_blocks``
    style keys to ``<output_prefix>double_layers`` / ``<output_prefix>single_layers``
    style keys.

    The first ``n_double_layers`` of the ``n_layers`` layers use the double
    layout; the rest use the single layout and are re-indexed from 0. Values
    are plain target keys, except ``norm_out.linear.weight`` which maps to
    ``(target, None, swap_scale_shift)``.
    """
    n_double_layers = mmdit_config.get("n_double_layers", 0)
    n_layers = mmdit_config.get("n_layers", 0)

    double_block_map = {
        "attn.to_q.weight": "attn.w2q.weight",
        "attn.to_k.weight": "attn.w2k.weight",
        "attn.to_v.weight": "attn.w2v.weight",
        "attn.to_out.0.weight": "attn.w2o.weight",
        "attn.add_q_proj.weight": "attn.w1q.weight",
        "attn.add_k_proj.weight": "attn.w1k.weight",
        "attn.add_v_proj.weight": "attn.w1v.weight",
        "attn.to_add_out.weight": "attn.w1o.weight",
        "ff.linear_1.weight": "mlpX.c_fc1.weight",
        "ff.linear_2.weight": "mlpX.c_fc2.weight",
        "ff.out_projection.weight": "mlpX.c_proj.weight",
        "ff_context.linear_1.weight": "mlpC.c_fc1.weight",
        "ff_context.linear_2.weight": "mlpC.c_fc2.weight",
        "ff_context.out_projection.weight": "mlpC.c_proj.weight",
        "norm1.linear.weight": "modX.1.weight",
        "norm1_context.linear.weight": "modC.1.weight",
    }
    single_block_map = {
        "attn.to_q.weight": "attn.w1q.weight",
        "attn.to_k.weight": "attn.w1k.weight",
        "attn.to_v.weight": "attn.w1v.weight",
        "attn.to_out.0.weight": "attn.w1o.weight",
        "norm1.linear.weight": "modCX.1.weight",
        "ff.linear_1.weight": "mlp.c_fc1.weight",
        "ff.linear_2.weight": "mlp.c_fc2.weight",
        "ff.out_projection.weight": "mlp.c_proj.weight",
    }
    double_target = "{}double_layers".format(output_prefix)
    single_target = "{}single_layers".format(output_prefix)

    key_map = {}
    for layer in range(n_layers):
        if layer < n_double_layers:
            index, prefix_from, prefix_to, block_map = layer, "joint_transformer_blocks", double_target, double_block_map
        else:
            index, prefix_from, prefix_to, block_map = layer - n_double_layers, "single_transformer_blocks", single_target, single_block_map

        for source_name, target_name in block_map.items():
            key_map["{}.{}.{}".format(prefix_from, index, source_name)] = "{}.{}.{}".format(prefix_to, index, target_name)

    MAP_BASIC = {
        ("positional_encoding", "pos_embed.pos_embed"),
        ("register_tokens", "register_tokens"),
        ("t_embedder.mlp.0.weight", "time_step_proj.linear_1.weight"),
        ("t_embedder.mlp.0.bias", "time_step_proj.linear_1.bias"),
        ("t_embedder.mlp.2.weight", "time_step_proj.linear_2.weight"),
        ("t_embedder.mlp.2.bias", "time_step_proj.linear_2.bias"),
        ("cond_seq_linear.weight", "context_embedder.weight"),
        ("init_x_linear.weight", "pos_embed.proj.weight"),
        ("init_x_linear.bias", "pos_embed.proj.bias"),
        ("final_linear.weight", "proj_out.weight"),
        ("modF.1.weight", "norm_out.linear.weight", swap_scale_shift),
    }

    for entry in MAP_BASIC:
        target = "{}{}".format(output_prefix, entry[0])
        if len(entry) > 2:
            key_map[entry[1]] = (target, None, entry[2])
        else:
            key_map[entry[1]] = target

    return key_map
|
||||||
|
|
||||||
|
def flux_to_diffusers(mmdit_config, output_prefix=""):
    """Build a dict from ``transformer_blocks`` / ``single_transformer_blocks``
    style keys to ``<output_prefix>double_blocks`` / ``<output_prefix>single_blocks``
    style keys.

    Reads ``depth``, ``depth_single_blocks`` and ``hidden_size`` from
    ``mmdit_config`` (each defaults to 0). Values take one of three forms:
      * a plain target key;
      * ``(target, (0, start, length))`` for entries that share a fused target
        (presumably a dim/start/length slice -- confirm against the consumer);
      * ``(target, None, swap_scale_shift)`` for the two ``norm_out.linear``
        entries.
    """
    n_double_layers = mmdit_config.get("depth", 0)
    n_single_layers = mmdit_config.get("depth_single_blocks", 0)
    hidden_size = mmdit_config.get("hidden_size", 0)

    suffixes = ("weight", "bias")
    qkv_starts = (0, hidden_size, hidden_size * 2)
    image_projections = ("to_q", "to_k", "to_v")
    text_projections = ("add_q_proj", "add_k_proj", "add_v_proj")

    double_block_map = {
        "attn.to_out.0.weight": "img_attn.proj.weight",
        "attn.to_out.0.bias": "img_attn.proj.bias",
        "norm1.linear.weight": "img_mod.lin.weight",
        "norm1.linear.bias": "img_mod.lin.bias",
        "norm1_context.linear.weight": "txt_mod.lin.weight",
        "norm1_context.linear.bias": "txt_mod.lin.bias",
        "attn.to_add_out.weight": "txt_attn.proj.weight",
        "attn.to_add_out.bias": "txt_attn.proj.bias",
        "ff.net.0.proj.weight": "img_mlp.0.weight",
        "ff.net.0.proj.bias": "img_mlp.0.bias",
        "ff.net.2.weight": "img_mlp.2.weight",
        "ff.net.2.bias": "img_mlp.2.bias",
        "ff_context.net.0.proj.weight": "txt_mlp.0.weight",
        "ff_context.net.0.proj.bias": "txt_mlp.0.bias",
        "ff_context.net.2.weight": "txt_mlp.2.weight",
        "ff_context.net.2.bias": "txt_mlp.2.bias",
        "attn.norm_q.weight": "img_attn.norm.query_norm.scale",
        "attn.norm_k.weight": "img_attn.norm.key_norm.scale",
        "attn.norm_added_q.weight": "txt_attn.norm.query_norm.scale",
        "attn.norm_added_k.weight": "txt_attn.norm.key_norm.scale",
    }
    single_block_map = {
        "norm.linear.weight": "modulation.lin.weight",
        "norm.linear.bias": "modulation.lin.bias",
        "proj_out.weight": "linear2.weight",
        "proj_out.bias": "linear2.bias",
        "attn.norm_q.weight": "norm.query_norm.scale",
        "attn.norm_k.weight": "norm.key_norm.scale",
    }

    key_map = {}

    for index in range(n_double_layers):
        prefix_from = "transformer_blocks.{}".format(index)
        prefix_to = "{}double_blocks.{}".format(output_prefix, index)

        for end in suffixes:
            for stream, projections in (("img_attn", image_projections), ("txt_attn", text_projections)):
                fused = "{}.{}.qkv.{}".format(prefix_to, stream, end)
                for projection, start in zip(projections, qkv_starts):
                    key_map["{}.attn.{}.{}".format(prefix_from, projection, end)] = (fused, (0, start, hidden_size))

        for source_name, target_name in double_block_map.items():
            key_map["{}.{}".format(prefix_from, source_name)] = "{}.{}".format(prefix_to, target_name)

    for index in range(n_single_layers):
        prefix_from = "single_transformer_blocks.{}".format(index)
        prefix_to = "{}single_blocks.{}".format(output_prefix, index)

        for end in suffixes:
            fused = "{}.linear1.{}".format(prefix_to, end)
            for projection, start in zip(image_projections, qkv_starts):
                key_map["{}.attn.{}.{}".format(prefix_from, projection, end)] = (fused, (0, start, hidden_size))
            # proj_mlp takes the slice after q/k/v in the same fused target.
            key_map["{}.proj_mlp.{}".format(prefix_from, end)] = (fused, (0, hidden_size * 3, hidden_size * 4))

        for source_name, target_name in single_block_map.items():
            key_map["{}.{}".format(prefix_from, source_name)] = "{}.{}".format(prefix_to, target_name)

    MAP_BASIC = {
        ("final_layer.linear.bias", "proj_out.bias"),
        ("final_layer.linear.weight", "proj_out.weight"),
        ("img_in.bias", "x_embedder.bias"),
        ("img_in.weight", "x_embedder.weight"),
        ("time_in.in_layer.bias", "time_text_embed.timestep_embedder.linear_1.bias"),
        ("time_in.in_layer.weight", "time_text_embed.timestep_embedder.linear_1.weight"),
        ("time_in.out_layer.bias", "time_text_embed.timestep_embedder.linear_2.bias"),
        ("time_in.out_layer.weight", "time_text_embed.timestep_embedder.linear_2.weight"),
        ("txt_in.bias", "context_embedder.bias"),
        ("txt_in.weight", "context_embedder.weight"),
        ("vector_in.in_layer.bias", "time_text_embed.text_embedder.linear_1.bias"),
        ("vector_in.in_layer.weight", "time_text_embed.text_embedder.linear_1.weight"),
        ("vector_in.out_layer.bias", "time_text_embed.text_embedder.linear_2.bias"),
        ("vector_in.out_layer.weight", "time_text_embed.text_embedder.linear_2.weight"),
        ("guidance_in.in_layer.bias", "time_text_embed.guidance_embedder.linear_1.bias"),
        ("guidance_in.in_layer.weight", "time_text_embed.guidance_embedder.linear_1.weight"),
        ("guidance_in.out_layer.bias", "time_text_embed.guidance_embedder.linear_2.bias"),
        ("guidance_in.out_layer.weight", "time_text_embed.guidance_embedder.linear_2.weight"),
        ("final_layer.adaLN_modulation.1.bias", "norm_out.linear.bias", swap_scale_shift),
        ("final_layer.adaLN_modulation.1.weight", "norm_out.linear.weight", swap_scale_shift),
    }

    for entry in MAP_BASIC:
        target = "{}{}".format(output_prefix, entry[0])
        if len(entry) > 2:
            key_map[entry[1]] = (target, None, entry[2])
        else:
            key_map[entry[1]] = target

    return key_map
|
||||||
|
|
||||||
|
def repeat_to_batch_size(tensor, batch_size, dim=0):
    """Make ``tensor`` have exactly ``batch_size`` entries along ``dim``.

    A longer tensor is truncated with ``narrow``; a shorter one is tiled along
    ``dim`` and then truncated; a tensor already at ``batch_size`` is returned
    as is.
    """
    current = tensor.shape[dim]
    if current == batch_size:
        return tensor
    if current < batch_size:
        # Repeat only along `dim`; every other axis keeps a factor of 1.
        factors = [1] * dim + [math.ceil(batch_size / current)] + [1] * (len(tensor.shape) - 1 - dim)
        tensor = tensor.repeat(factors)
    return tensor.narrow(dim, 0, batch_size)
|
||||||
|
|
||||||
|
def resize_to_batch_size(tensor, batch_size):
    """Resample ``tensor`` along dim 0 to ``batch_size`` entries.

    Returns ``tensor`` itself when the size already matches, and
    ``tensor[:batch_size]`` when ``batch_size <= 1``. Otherwise each output
    row is a copy of one input row: evenly spaced from first to last when
    shrinking, and the row under each output slot's centre when growing.
    """
    in_batch_size = tensor.shape[0]
    if in_batch_size == batch_size:
        return tensor

    if batch_size <= 1:
        return tensor[:batch_size]

    last = in_batch_size - 1
    if batch_size < in_batch_size:
        scale = last / (batch_size - 1)
        picks = [min(round(i * scale), last) for i in range(batch_size)]
    else:
        scale = in_batch_size / batch_size
        picks = [min(math.floor((i + 0.5) * scale), last) for i in range(batch_size)]

    return torch.stack([tensor[index] for index in picks])
|
||||||
|
|
||||||
|
def convert_sd_to(state_dict, dtype):
    """Replace every value of ``state_dict`` with ``value.to(dtype)`` and return the dict."""
    # Iterate over a snapshot of the keys because entries are reassigned in the loop.
    for name in tuple(state_dict.keys()):
        state_dict[name] = state_dict[name].to(dtype)
    return state_dict
|
||||||
|
|
||||||
|
def safetensors_header(safetensors_path, max_size=100*1024*1024):
    """Read the raw header bytes of the file at ``safetensors_path``.

    The first 8 bytes are unpacked as a little-endian unsigned 64-bit header
    length, and that many following bytes are returned.

    Args:
        safetensors_path: path of the file to read.
        max_size: largest header length, in bytes, that will be read.

    Returns:
        The header bytes, or ``None`` when the file is too short to hold the
        8-byte length prefix or the declared length exceeds ``max_size``.
    """
    with open(safetensors_path, "rb") as f:
        length_prefix = f.read(8)
        # An empty or truncated file cannot be unpacked; report "no header"
        # instead of letting struct.error escape to the caller.
        if len(length_prefix) < 8:
            return None
        length_of_header = struct.unpack('<Q', length_prefix)[0]
        if length_of_header > max_size:
            return None
        return f.read(length_of_header)
|
||||||
|
|
||||||
|
def set_attr(obj, attr, value):
    """Assign ``value`` to the dotted attribute path ``attr`` below ``obj``.

    Returns the value previously stored there. The final attribute must
    already exist, since it is read before being overwritten.
    """
    *parents, leaf = attr.split(".")
    owner = obj
    for name in parents:
        owner = getattr(owner, name)
    previous = getattr(owner, leaf)
    setattr(owner, leaf, value)
    return previous
|
||||||
|
|
||||||
|
def set_attr_param(obj, attr, value):
    """Wrap ``value`` in a ``torch.nn.Parameter`` with ``requires_grad=False``
    and store it at the dotted path ``attr`` via ``set_attr``.

    Returns whatever ``set_attr`` returns.
    """
    frozen = torch.nn.Parameter(value, requires_grad=False)
    return set_attr(obj, attr, frozen)
|
||||||
|
|
||||||
|
def copy_to_param(obj, attr, value):
    """Copy ``value`` into the tensor at the dotted path ``attr`` below ``obj``.

    The existing object is updated in place through ``.data.copy_`` rather
    than being replaced.
    """
    *parents, leaf = attr.split(".")
    owner = obj
    for name in parents:
        owner = getattr(owner, name)
    existing = getattr(owner, leaf)
    existing.data.copy_(value)
|
||||||
|
|
||||||
|
def get_attr(obj, attr):
    """Resolve the dotted attribute path ``attr`` starting at ``obj`` and return the result."""
    current = obj
    for part in attr.split("."):
        current = getattr(current, part)
    return current
|
||||||
|
|
||||||
|
def bislerp(samples, width, height):
    """Resize a 4-D tensor to ``(height, width)`` using spherical interpolation.

    ``samples`` is unpacked as ``(n, c, h, w)``. The last axis is resampled
    first and then the second-to-last; at every output position the two
    neighbouring length-``c`` vectors are blended with ``slerp``. Computation
    runs in float32 and the result is cast back to the input dtype.
    """
    def slerp(b1, b2, r):
        '''slerps batches b1, b2 according to ratio r, batches should be flat e.g. NxC'''
        channels = b1.shape[-1]

        # Vector lengths, kept as (N, 1) so they broadcast over the channels.
        length_1 = torch.norm(b1, dim=-1, keepdim=True)
        length_2 = torch.norm(b2, dim=-1, keepdim=True)

        unit_1 = b1 / length_1
        unit_2 = b2 / length_2

        # Rows whose length is zero are forced to zero after the division.
        unit_1[length_1.expand(-1, channels) == 0.0] = 0.0
        unit_2[length_2.expand(-1, channels) == 0.0] = 0.0

        dot = (unit_1 * unit_2).sum(1)
        omega = torch.acos(dot)
        so = torch.sin(omega)

        # technically not mathematically correct, but more pleasing?
        ratio = r.squeeze(1)
        weight_1 = (torch.sin((1.0 - ratio) * omega) / so).unsqueeze(1)
        weight_2 = (torch.sin(ratio * omega) / so).unsqueeze(1)
        res = weight_1 * unit_1 + weight_2 * unit_2
        res *= (length_1 * (1.0 - r) + length_2 * r).expand(-1, channels)

        # Edge cases: (nearly) identical directions keep b1, (nearly) opposite
        # directions fall back to plain linear interpolation.
        nearly_same = dot > 1 - 1e-5
        res[nearly_same] = b1[nearly_same]
        nearly_opposite = dot < 1e-5 - 1
        res[nearly_opposite] = (b1 * (1.0 - r) + b2 * r)[nearly_opposite]
        return res

    def generate_bilinear_data(length_old, length_new, device):
        """Return (ratios, first indices, second indices) for resampling one axis."""
        positions = torch.arange(length_old, dtype=torch.float32, device=device).reshape((1, 1, 1, -1))

        first = torch.nn.functional.interpolate(positions, size=(1, length_new), mode="bilinear")
        ratios = first - first.floor()
        first = first.to(torch.int64)

        # Second neighbour is the next index, except the last entry which is
        # reduced by one again so it stays inside the axis.
        second = positions + 1
        second[:, :, :, -1] -= 1
        second = torch.nn.functional.interpolate(second, size=(1, length_new), mode="bilinear")
        second = second.to(torch.int64)
        return ratios, first, second

    orig_dtype = samples.dtype
    samples = samples.float()
    n, c, h, w = samples.shape
    h_new, w_new = (height, width)

    # Pass 1: interpolate along the last axis.
    ratios, first, second = generate_bilinear_data(w, w_new, samples.device)
    first = first.expand((n, c, h, -1))
    second = second.expand((n, c, h, -1))
    ratios = ratios.expand((n, 1, h, -1))

    flat_first = samples.gather(-1, first).movedim(1, -1).reshape((-1, c))
    flat_second = samples.gather(-1, second).movedim(1, -1).reshape((-1, c))
    ratios = ratios.movedim(1, -1).reshape((-1, 1))

    result = slerp(flat_first, flat_second, ratios)
    result = result.reshape(n, h, w_new, c).movedim(-1, 1)

    # Pass 2: interpolate along the second-to-last axis.
    ratios, first, second = generate_bilinear_data(h, h_new, samples.device)
    first = first.reshape((1, 1, -1, 1)).expand((n, c, -1, w_new))
    second = second.reshape((1, 1, -1, 1)).expand((n, c, -1, w_new))
    ratios = ratios.reshape((1, 1, -1, 1)).expand((n, 1, -1, w_new))

    flat_first = result.gather(-2, first).movedim(1, -1).reshape((-1, c))
    flat_second = result.gather(-2, second).movedim(1, -1).reshape((-1, c))
    ratios = ratios.movedim(1, -1).reshape((-1, 1))

    result = slerp(flat_first, flat_second, ratios)
    result = result.reshape(n, h_new, w_new, c).movedim(-1, 1)
    return result.to(orig_dtype)
|
||||||
|
|
||||||
|
def lanczos(samples, width, height):
    """Resize every image in ``samples`` to ``(width, height)`` with PIL's Lanczos filter.

    Each image has dim 0 moved last, is scaled by 255, clipped to 0..255 and
    converted to uint8 before resizing, then converted back to float32 / 255
    with the last dim moved back to the front. The stacked result is returned
    on the device and in the dtype of ``samples``.
    """
    resized = []
    for image in samples:
        pixels = np.clip(255. * image.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8)
        picture = Image.fromarray(pixels).resize((width, height), resample=Image.Resampling.LANCZOS)
        restored = torch.from_numpy(np.array(picture).astype(np.float32) / 255.0).movedim(-1, 0)
        resized.append(restored)
    return torch.stack(resized).to(samples.device, samples.dtype)
|
||||||
|
|
||||||
|
def common_upscale(samples, width, height, upscale_method, crop):
    """Resize ``samples`` to ``(height, width)`` with the named method.

    When ``crop == "center"`` the input is first trimmed symmetrically on
    dims 2/3 so its aspect ratio matches the target. ``upscale_method`` of
    ``"bislerp"`` or ``"lanczos"`` dispatches to those helpers; any other
    value is passed as ``mode`` to ``torch.nn.functional.interpolate``.
    """
    source = samples
    if crop == "center":
        old_height = samples.shape[2]
        old_width = samples.shape[3]
        old_aspect = old_width / old_height
        new_aspect = width / height
        trim_x = 0
        trim_y = 0
        if old_aspect > new_aspect:
            # Too wide: drop columns from both sides.
            trim_x = round((old_width - old_width * (new_aspect / old_aspect)) / 2)
        elif old_aspect < new_aspect:
            # Too tall: drop rows from top and bottom.
            trim_y = round((old_height - old_height * (old_aspect / new_aspect)) / 2)
        source = samples[:, :, trim_y:old_height - trim_y, trim_x:old_width - trim_x]

    if upscale_method == "bislerp":
        return bislerp(source, width, height)
    if upscale_method == "lanczos":
        return lanczos(source, width, height)
    return torch.nn.functional.interpolate(source, size=(height, width), mode=upscale_method)
|
||||||
|
|
||||||
|
def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
    """Return ``ceil(height / (tile_y - overlap)) * ceil(width / (tile_x - overlap))``."""
    rows = math.ceil((height / (tile_y - overlap)))
    cols = math.ceil((width / (tile_x - overlap)))
    return rows * cols
|
||||||
|
|
||||||
|
@torch.inference_mode()
def tiled_scale_multidim(samples, function, tile=(64, 64), overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
    """Apply ``function`` tile by tile and blend the results into one output.

    Each batch item is cut into tiles of size ``tile`` along the trailing
    dims (step ``tile - overlap``). ``function`` is called on each tile and
    its result is accumulated at the tile origin times ``upscale_amount``.
    The weight mask ramps linearly over ``round(overlap * upscale_amount)``
    cells at every edge, and the accumulated sum is divided by the
    accumulated weights. ``pbar.update(1)`` is called once per tile when
    ``pbar`` is given. The result lives on ``output_device``.
    """
    dims = len(tile)

    def scaled_shape(batch, spatial):
        # Shape of the upscaled canvas for `batch` items.
        return [batch, out_channels] + [round(extent * upscale_amount) for extent in spatial]

    output = torch.empty(scaled_shape(samples.shape[0], samples.shape[2:]), device=output_device)
    feather = round(overlap * upscale_amount)

    for b in range(samples.shape[0]):
        s = samples[b:b+1]
        canvas_shape = scaled_shape(s.shape[0], s.shape[2:])
        out = torch.zeros(canvas_shape, device=output_device)
        out_div = torch.zeros(canvas_shape, device=output_device)

        tile_starts = [range(0, extent, size - overlap) for extent, size in zip(s.shape[2:], tile)]
        for origin in itertools.product(*tile_starts):
            tile_in = s
            offsets = []
            for d in range(dims):
                axis = d + 2
                pos = max(0, min(s.shape[axis] - overlap, origin[d]))
                length = min(tile[d], s.shape[axis] - pos)
                tile_in = tile_in.narrow(axis, pos, length)
                offsets.append(round(pos * upscale_amount))

            ps = function(tile_in).to(output_device)

            # Weight mask: 1 inside, ramping towards every edge.
            mask = torch.ones_like(ps)
            for t in range(feather):
                ramp = (1.0/feather) * (t + 1)
                for axis in range(2, dims + 2):
                    mask.narrow(axis, t, 1).mul_(ramp)
                    mask.narrow(axis, mask.shape[axis] - 1 - t, 1).mul_(ramp)

            # Views into the accumulators covering just this tile's region.
            target = out
            target_div = out_div
            for d in range(dims):
                axis = d + 2
                target = target.narrow(axis, offsets[d], mask.shape[axis])
                target_div = target_div.narrow(axis, offsets[d], mask.shape[axis])

            target.add_(ps * mask)
            target_div.add_(mask)

            if pbar is not None:
                pbar.update(1)

        output[b:b+1] = out/out_div
    return output
|
||||||
|
|
||||||
|
def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
    """Two-dimensional convenience wrapper around ``tiled_scale_multidim``.

    The tile is passed as ``(tile_y, tile_x)``; every other argument is
    forwarded unchanged.
    """
    return tiled_scale_multidim(
        samples,
        function,
        tile=(tile_y, tile_x),
        overlap=overlap,
        upscale_amount=upscale_amount,
        out_channels=out_channels,
        output_device=output_device,
        pbar=pbar,
    )
|
||||||
|
|
||||||
|
# Module-wide flag; only written here and by set_progress_bar_enabled().
PROGRESS_BAR_ENABLED = True


def set_progress_bar_enabled(enabled):
    """Store ``enabled`` in the module-level ``PROGRESS_BAR_ENABLED`` flag."""
    global PROGRESS_BAR_ENABLED
    PROGRESS_BAR_ENABLED = enabled
|
||||||
|
|
||||||
|
# Callback picked up by ProgressBar instances at construction time (None = no callback).
PROGRESS_BAR_HOOK = None


def set_progress_bar_global_hook(function):
    """Store ``function`` in the module-level ``PROGRESS_BAR_HOOK``."""
    global PROGRESS_BAR_HOOK
    PROGRESS_BAR_HOOK = function
|
||||||
|
|
||||||
|
class ProgressBar:
    """Progress counter that reports to the hook registered when it was created."""

    def __init__(self, total):
        """Start at zero out of ``total`` and capture the current global hook."""
        self.total = total
        self.current = 0
        # The hook is captured once; later changes to the global do not affect this instance.
        self.hook = PROGRESS_BAR_HOOK

    def update_absolute(self, value, total=None, preview=None):
        """Set the progress to ``value`` (capped at the total) and notify the hook.

        A non-``None`` ``total`` replaces the stored total before capping.
        The hook, if any, is called as ``hook(current, total, preview)``.
        """
        if total is not None:
            self.total = total
        self.current = self.total if value > self.total else value
        if self.hook is None:
            return
        self.hook(self.current, self.total, preview)

    def update(self, value):
        """Advance the progress by ``value``."""
        self.update_absolute(self.current + value)
|
||||||
2
packages_3rdparty/webui_lora_collection/lora.py
vendored
Normal file
2
packages_3rdparty/webui_lora_collection/lora.py
vendored
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
# TODO: Implement API
|
||||||
|
|
||||||
68
packages_3rdparty/webui_lora_collection/lyco_helpers.py
vendored
Normal file
68
packages_3rdparty/webui_lora_collection/lyco_helpers.py
vendored
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def make_weight_cp(t, wa, wb):
    """Contract the 4-D core ``t`` with ``wb`` on its second dim, then with ``wa`` on its first.

    The two contractions are done in this order with ``torch.einsum``; the
    last two dims of ``t`` are carried through unchanged.
    """
    contracted_second = torch.einsum('i j k l, j r -> i r k l', t, wb)
    rebuilt = torch.einsum('i j k l, i r -> r j k l', contracted_second, wa)
    return rebuilt
|
||||||
|
|
||||||
|
|
||||||
|
def rebuild_conventional(up, down, shape, dyn_dim=None):
    """Return ``up @ down`` reshaped to ``shape``.

    Both factors are first flattened to 2-D, keeping their leading dim. When
    ``dyn_dim`` is given, only the first ``dyn_dim`` columns of ``up`` and
    rows of ``down`` take part in the product.
    """
    up_2d = up.reshape(up.size(0), -1)
    down_2d = down.reshape(down.size(0), -1)
    if dyn_dim is not None:
        up_2d = up_2d[:, :dyn_dim]
        down_2d = down_2d[:dyn_dim, :]
    product = up_2d @ down_2d
    return product.reshape(shape)
|
||||||
|
|
||||||
|
|
||||||
|
def rebuild_cp_decomposition(up, down, mid):
    """Combine ``up``, ``mid`` and ``down`` into a single 4-D tensor.

    ``up`` and ``down`` are flattened to 2-D (leading dim kept); ``mid`` is
    contracted with ``up`` over its first dim and with ``down`` over its
    second dim, keeping its last two dims.
    """
    up_2d = up.reshape(up.size(0), -1)
    down_2d = down.reshape(down.size(0), -1)
    return torch.einsum('n m k l, i n, m j -> i j k l', mid, up_2d, down_2d)
|
||||||
|
|
||||||
|
|
||||||
|
# copied from https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/lokr.py
|
||||||
|
def factorization(dimension: int, factor:int=-1) -> tuple[int, int]:
    '''
    Return a tuple of two values whose product is `dimension`, decomposed by
    the number closest to `factor`. The second value is greater than or equal
    to the first value.

    In LoRA with Kronecker product, the first value is a value for weight scale,
    the second value is a value for weight.

    Because of the non-commutative property, A⊗B ≠ B⊗A. The meaning of the two
    matrices is slightly different.

    examples)
    factor
        -1               2                4               8               16               ...
    127 -> 1, 127   127 -> 1, 127    127 -> 1, 127   127 -> 1, 127   127 -> 1, 127
    128 -> 8, 16    128 -> 2, 64     128 -> 4, 32    128 -> 8, 16    128 -> 8, 16
    250 -> 10, 25   250 -> 2, 125    250 -> 2, 125   250 -> 5, 50    250 -> 10, 25
    360 -> 8, 45    360 -> 2, 180    360 -> 4, 90    360 -> 8, 45    360 -> 12, 30
    512 -> 16, 32   512 -> 2, 256    512 -> 4, 128   512 -> 8, 64    512 -> 16, 32
    1024 -> 32, 32  1024 -> 2, 512   1024 -> 4, 256  1024 -> 8, 128  1024 -> 16, 64

    NOTE(review): the table is kept as copied from upstream, but tracing this
    implementation gives (18, 20) for factorization(360) and (15, 24) for
    factorization(360, 16) -- the 360 row looks out of date.
    '''
    # Fast path: a positive factor that divides the dimension is used directly.
    if factor > 0 and dimension % factor == 0:
        small, large = sorted((factor, dimension // factor))
        return small, large

    # A negative factor means "no upper limit below the dimension itself".
    limit = dimension if factor < 0 else factor

    small, large = 1, dimension
    max_sum = small + large  # fixed at 1 + dimension; never tightened afterwards
    while small < large:
        # Next divisor of `dimension` strictly above the current small value.
        candidate = small + 1
        while dimension % candidate != 0:
            candidate += 1
        cofactor = dimension // candidate
        if candidate + cofactor > max_sum or candidate > limit:
            break
        small, large = candidate, cofactor

    if small > large:
        small, large = large, small
    return small, large
|
||||||
|
|
||||||
228
packages_3rdparty/webui_lora_collection/network.py
vendored
Normal file
228
packages_3rdparty/webui_lora_collection/network.py
vendored
Normal file
@@ -0,0 +1,228 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
import os
|
||||||
|
from collections import namedtuple
|
||||||
|
import enum
|
||||||
|
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from modules import sd_models, cache, errors, hashes, shared
|
||||||
|
import modules.models.sd3.mmdit
|
||||||
|
|
||||||
|
# Bundle of per-layer data: the key inside the network file, the key of the
# matched model layer, the dict of weights, and the matched module itself.
NetworkWeights = namedtuple(
    'NetworkWeights',
    ['network_key', 'sd_key', 'w', 'sd_module'],
)

# Sort priority for known metadata tags; NetworkOnDisk sorts unknown tags after these (999).
metadata_tags_order = {
    "ss_sd_model_name": 1,
    "ss_resolution": 2,
    "ss_clip_skip": 3,
    "ss_num_train_images": 10,
    "ss_tag_frequency": 20,
}
|
||||||
|
|
||||||
|
|
||||||
|
class SdVersion(enum.Enum):
    """Model family a network file is associated with."""

    Unknown = 1  # no usable metadata
    SD1 = 2
    SD2 = 3
    SDXL = 4
|
||||||
|
|
||||||
|
|
||||||
|
class NetworkOnDisk:
    """Description of a network file: name, path, metadata, hash and detected version."""

    def __init__(self, name, filename):
        """Read (cached) safetensors metadata for ``filename`` and derive alias, hash and version."""
        self.name = name
        self.filename = filename
        self.metadata = {}
        extension = os.path.splitext(filename)[1]
        self.is_safetensors = extension.lower() == ".safetensors"

        def read_metadata():
            return sd_models.read_metadata_from_safetensors(filename)

        if self.is_safetensors:
            try:
                self.metadata = cache.cached_data_for_file('safetensors-metadata', "lora/" + self.name, filename, read_metadata)
            except Exception as e:
                # Best effort: a broken file is reported and treated as having no metadata.
                errors.display(e, f"reading lora {filename}")

        if self.metadata:
            # Known tags first (by their configured priority), everything else afterwards.
            ordered = sorted(self.metadata.items(), key=lambda item: metadata_tags_order.get(item[0], 999))
            self.metadata = dict(ordered)

        self.alias = self.metadata.get('ss_output_name', self.name)

        self.hash = None
        self.shorthash = None
        known_hash = self.metadata.get('sshs_model_hash')
        if not known_hash:
            known_hash = hashes.sha256_from_cache(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors)
        self.set_hash(known_hash or '')

        self.sd_version = self.detect_version()

    def detect_version(self):
        """Map the metadata to an ``SdVersion`` member."""
        base_version = str(self.metadata.get('ss_base_model_version', ""))
        if base_version.startswith("sdxl_"):
            return SdVersion.SDXL
        if str(self.metadata.get('ss_v2', "")) == "True":
            return SdVersion.SD2
        if len(self.metadata):
            return SdVersion.SD1
        return SdVersion.Unknown

    def set_hash(self, v):
        """Store ``v`` as the hash and register a non-empty short hash in the lookup table."""
        self.hash = v
        self.shorthash = self.hash[0:12]

        if not self.shorthash:
            return

        import networks

        networks.available_network_hash_lookup[self.shorthash] = self

    def read_hash(self):
        """Compute and store the file hash if none is set yet."""
        if self.hash:
            return
        digest = hashes.sha256(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors)
        self.set_hash(digest or '')

    def get_alias(self):
        """Return the alias, or the file name when preferred by options or the alias is forbidden."""
        import networks

        use_name = shared.opts.lora_preferred_name == "Filename" or self.alias.lower() in networks.forbidden_network_aliases
        return self.name if use_name else self.alias
|
||||||
|
|
||||||
|
|
||||||
|
class Network:  # LoraModule
    """A loaded network: its source file, strength multipliers and per-layer modules."""

    def __init__(self, name, network_on_disk: NetworkOnDisk):
        """Create an empty network with both multipliers set to 1.0."""
        self.name = name
        self.network_on_disk = network_on_disk

        # Strengths read by NetworkModule.multiplier().
        self.te_multiplier = 1.0
        self.unet_multiplier = 1.0
        self.dyn_dim = None

        self.modules = {}
        self.bundle_embeddings = {}
        self.mtime = None

        # the text that was used to add the network to prompt - can be either name or an alias
        self.mentioned_name = None
|
||||||
|
|
||||||
|
|
||||||
|
class ModuleType:
    """Base class for module factories; this default implementation matches nothing."""

    def create_module(self, net: Network, weights: NetworkWeights) -> Network | None:
        """Return a module built from ``weights``; the base implementation always returns ``None``."""
        return None
|
||||||
|
|
||||||
|
|
||||||
|
class NetworkModule:
    """Base class for one network layer attached to one model module.

    Subclasses implement ``calc_updown``; this class provides scaling,
    optional DoRA weight decomposition and a generic ``forward``.
    """

    def __init__(self, net: Network, weights: NetworkWeights):
        """Record the target module, its weight shape, the matching functional op and scalars."""
        self.network = net
        self.network_key = weights.network_key
        self.sd_key = weights.sd_key
        self.sd_module = weights.sd_module

        if isinstance(self.sd_module, modules.models.sd3.mmdit.QkvLinear):
            # Only a third of the first dim is used as the per-projection shape.
            s = self.sd_module.weight.shape
            self.shape = (s[0] // 3, s[1])
        elif hasattr(self.sd_module, 'weight'):
            self.shape = self.sd_module.weight.shape
        elif isinstance(self.sd_module, nn.MultiheadAttention):
            # For now, only self-attn use Pytorch's MHA
            # So assume all qkvo proj have same shape
            self.shape = self.sd_module.out_proj.weight.shape
        else:
            self.shape = None

        # Functional op (and its fixed keyword arguments) used by forward().
        self.ops = None
        self.extra_kwargs = {}
        if isinstance(self.sd_module, nn.Conv2d):
            self.ops = F.conv2d
            self.extra_kwargs = {
                'stride': self.sd_module.stride,
                'padding': self.sd_module.padding
            }
        elif isinstance(self.sd_module, nn.Linear):
            self.ops = F.linear
        elif isinstance(self.sd_module, nn.LayerNorm):
            self.ops = F.layer_norm
            self.extra_kwargs = {
                'normalized_shape': self.sd_module.normalized_shape,
                'eps': self.sd_module.eps
            }
        elif isinstance(self.sd_module, nn.GroupNorm):
            self.ops = F.group_norm
            self.extra_kwargs = {
                'num_groups': self.sd_module.num_groups,
                'eps': self.sd_module.eps
            }

        self.dim = None
        self.bias = weights.w.get("bias")
        self.alpha = weights.w["alpha"].item() if "alpha" in weights.w else None
        self.scale = weights.w["scale"].item() if "scale" in weights.w else None

        self.dora_scale = weights.w.get("dora_scale", None)
        # The shape is None for modules without a usable weight (fallback
        # branch above); len(None) would make construction fail for them.
        self.dora_norm_dims = len(self.shape) - 1 if self.shape is not None else None

    def multiplier(self):
        """Return the text-encoder multiplier if 'transformer' occurs in the first 20 chars of ``sd_key``, else the unet one."""
        if 'transformer' in self.sd_key[:20]:
            return self.network.te_multiplier
        else:
            return self.network.unet_multiplier

    def calc_scale(self):
        """Return ``scale`` if set, else ``alpha / dim`` when both are set, else 1.0."""
        if self.scale is not None:
            return self.scale
        if self.dim is not None and self.alpha is not None:
            return self.alpha / self.dim

        return 1.0

    def apply_weight_decompose(self, updown, orig_weight):
        """Rescale the merged weight by ``dora_scale / norm`` and return the resulting delta.

        The norm is taken over everything except dim 1 of the merged weight.
        """
        # Match the device/dtype
        orig_weight = orig_weight.to(updown.dtype)
        dora_scale = self.dora_scale.to(device=orig_weight.device, dtype=updown.dtype)
        updown = updown.to(orig_weight.device)

        merged_scale1 = updown + orig_weight
        merged_scale1_norm = (
            merged_scale1.transpose(0, 1)
            .reshape(merged_scale1.shape[1], -1)
            .norm(dim=1, keepdim=True)
            .reshape(merged_scale1.shape[1], *[1] * self.dora_norm_dims)
            .transpose(0, 1)
        )

        dora_merged = (
            merged_scale1 * (dora_scale / merged_scale1_norm)
        )
        final_updown = dora_merged - orig_weight
        return final_updown

    def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None):
        """Apply bias, reshaping, scale, optional DoRA and the multiplier.

        Returns ``(updown, ex_bias)``; ``ex_bias`` is multiplied by the
        multiplier when given and is ``None`` otherwise.
        """
        if self.bias is not None:
            updown = updown.reshape(self.bias.shape)
            # In-place add, as in the original implementation.
            updown += self.bias.to(orig_weight.device, dtype=updown.dtype)
            updown = updown.reshape(output_shape)

        if len(output_shape) == 4:
            updown = updown.reshape(output_shape)

        if orig_weight.size().numel() == updown.size().numel():
            updown = updown.reshape(orig_weight.shape)

        if ex_bias is not None:
            ex_bias = ex_bias * self.multiplier()

        updown = updown * self.calc_scale()

        if self.dora_scale is not None:
            updown = self.apply_weight_decompose(updown, orig_weight)

        return updown * self.multiplier(), ex_bias

    def calc_updown(self, target):
        """Compute ``(updown, ex_bias)`` for the weight ``target``; must be overridden."""
        raise NotImplementedError()

    def forward(self, x, y):
        """A general forward implementation for all modules"""
        if self.ops is None:
            raise NotImplementedError()
        else:
            updown, ex_bias = self.calc_updown(self.sd_module.weight)
            return y + self.ops(x, weight=updown, bias=ex_bias, **self.extra_kwargs)
|
||||||
|
|
||||||
27
packages_3rdparty/webui_lora_collection/network_full.py
vendored
Normal file
27
packages_3rdparty/webui_lora_collection/network_full.py
vendored
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
import network
|
||||||
|
|
||||||
|
|
||||||
|
class ModuleTypeFull(network.ModuleType):
    """Factory for modules that carry a full weight difference."""

    def create_module(self, net: network.Network, weights: network.NetworkWeights):
        """Return a ``NetworkModuleFull`` when ``weights.w`` has a ``"diff"`` entry, else ``None``."""
        if "diff" not in weights.w:
            return None
        return NetworkModuleFull(net, weights)
|
||||||
|
|
||||||
|
|
||||||
|
class NetworkModuleFull(network.NetworkModule):
    """Module whose update is the stored ``diff`` tensor (plus optional ``diff_b`` bias)."""

    def __init__(self, net: network.Network, weights: network.NetworkWeights):
        """Keep references to the ``diff`` and ``diff_b`` entries of ``weights.w``."""
        super().__init__(net, weights)

        self.weight = weights.w.get("diff")
        self.ex_bias = weights.w.get("diff_b")

    def calc_updown(self, orig_weight):
        """Move the stored tensors to ``orig_weight``'s device and finalize them."""
        device = orig_weight.device
        output_shape = self.weight.shape
        updown = self.weight.to(device)
        ex_bias = None if self.ex_bias is None else self.ex_bias.to(device)

        return self.finalize_updown(updown, orig_weight, output_shape, ex_bias)
|
||||||
33
packages_3rdparty/webui_lora_collection/network_glora.py
vendored
Normal file
33
packages_3rdparty/webui_lora_collection/network_glora.py
vendored
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
|
||||||
|
import network
|
||||||
|
|
||||||
|
class ModuleTypeGLora(network.ModuleType):
    """Factory for GLoRA modules."""

    # All of these keys must be present for the weights to be treated as GLoRA.
    def create_module(self, net: network.Network, weights: network.NetworkWeights):
        """Return a ``NetworkModuleGLora`` when every required key is in ``weights.w``, else ``None``."""
        required = ("a1.weight", "a2.weight", "alpha", "b1.weight", "b2.weight")
        for key in required:
            if key not in weights.w:
                return None
        return NetworkModuleGLora(net, weights)
|
||||||
|
|
||||||
|
# adapted from https://github.com/KohakuBlueleaf/LyCORIS
|
||||||
|
class NetworkModuleGLora(network.NetworkModule):
    """GLoRA module: the update combines ``b2 @ b1`` with ``(W @ a2) @ a1``."""

    def __init__(self, net: network.Network, weights: network.NetworkWeights):
        """Store the four factor matrices from ``weights.w``."""
        super().__init__(net, weights)

        if hasattr(self.sd_module, 'weight'):
            self.shape = self.sd_module.weight.shape

        w = weights.w
        self.w1a = w["a1.weight"]
        self.w1b = w["b1.weight"]
        self.w2a = w["a2.weight"]
        self.w2b = w["b2.weight"]

    def calc_updown(self, orig_weight):
        """Build the update on ``orig_weight``'s device and finalize it."""
        device = orig_weight.device
        w1a = self.w1a.to(device)
        w1b = self.w1b.to(device)
        w2a = self.w2a.to(device)
        w2b = self.w2b.to(device)

        output_shape = [w1a.size(0), w1b.size(1)]
        additive = w2b @ w1b
        # The original weight is cast to the factors' dtype before the product.
        projected = (orig_weight.to(dtype = w1a.dtype) @ w2a) @ w1a
        updown = additive + projected

        return self.finalize_updown(updown, orig_weight, output_shape)
|
||||||
55
packages_3rdparty/webui_lora_collection/network_hada.py
vendored
Normal file
55
packages_3rdparty/webui_lora_collection/network_hada.py
vendored
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
import lyco_helpers
|
||||||
|
import network
|
||||||
|
|
||||||
|
|
||||||
|
class ModuleTypeHada(network.ModuleType):
    """Factory for Hadamard-product (``hada_*``) modules."""

    def create_module(self, net: network.Network, weights: network.NetworkWeights):
        """Return a ``NetworkModuleHada`` when all four ``hada_w*`` keys are present, else ``None``."""
        required = ("hada_w1_a", "hada_w1_b", "hada_w2_a", "hada_w2_b")
        for key in required:
            if key not in weights.w:
                return None
        return NetworkModuleHada(net, weights)
|
||||||
|
|
||||||
|
|
||||||
|
class NetworkModuleHada(network.NetworkModule):
    """Module whose update is the element-wise product of two rebuilt matrices."""

    def __init__(self, net: network.Network, weights: network.NetworkWeights):
        """Store the two factor pairs and the optional ``hada_t1`` / ``hada_t2`` cores."""
        super().__init__(net, weights)

        if hasattr(self.sd_module, 'weight'):
            self.shape = self.sd_module.weight.shape

        w = weights.w
        self.w1a = w["hada_w1_a"]
        self.w1b = w["hada_w1_b"]
        self.dim = self.w1b.shape[0]
        self.w2a = w["hada_w2_a"]
        self.w2b = w["hada_w2_b"]

        self.t1 = w.get("hada_t1")
        self.t2 = w.get("hada_t2")

    def calc_updown(self, orig_weight):
        """Rebuild both halves on ``orig_weight``'s device, multiply them and finalize."""
        device = orig_weight.device
        w1a = self.w1a.to(device)
        w1b = self.w1b.to(device)
        w2a = self.w2a.to(device)
        w2b = self.w2b.to(device)

        output_shape = [w1a.size(0), w1b.size(1)]

        if self.t1 is None:
            if len(w1b.shape) == 4:
                output_shape += w1b.shape[2:]
            updown1 = lyco_helpers.rebuild_conventional(w1a, w1b, output_shape)
        else:
            # With a core tensor the first output dim comes from w1a's second dim.
            output_shape = [w1a.size(1), w1b.size(1)]
            t1 = self.t1.to(device)
            updown1 = lyco_helpers.make_weight_cp(t1, w1a, w1b)
            output_shape += t1.shape[2:]

        if self.t2 is None:
            updown2 = lyco_helpers.rebuild_conventional(w2a, w2b, output_shape)
        else:
            t2 = self.t2.to(device)
            updown2 = lyco_helpers.make_weight_cp(t2, w2a, w2b)

        updown = updown1 * updown2

        return self.finalize_updown(updown, orig_weight, output_shape)
|
||||||
30
packages_3rdparty/webui_lora_collection/network_ia3.py
vendored
Normal file
30
packages_3rdparty/webui_lora_collection/network_ia3.py
vendored
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
import network
|
||||||
|
|
||||||
|
|
||||||
|
class ModuleTypeIa3(network.ModuleType):
    """Factory for IA3 modules."""

    def create_module(self, net: network.Network, weights: network.NetworkWeights):
        """Return a ``NetworkModuleIa3`` when ``weights.w`` has a ``"weight"`` entry, else ``None``."""
        if "weight" not in weights.w:
            return None
        return NetworkModuleIa3(net, weights)
|
||||||
|
|
||||||
|
|
||||||
|
class NetworkModuleIa3(network.NetworkModule):
    """Module whose update is the original weight multiplied by a learned vector."""

    def __init__(self, net: network.Network, weights: network.NetworkWeights):
        """Store the scaling vector and the ``on_input`` flag (both read from ``weights.w``)."""
        super().__init__(net, weights)

        self.w = weights.w["weight"]
        self.on_input = weights.w["on_input"].item()

    def calc_updown(self, orig_weight):
        """Multiply ``orig_weight`` by the vector, as a column unless ``on_input`` is set."""
        scaling = self.w.to(orig_weight.device)

        output_shape = [scaling.size(0), orig_weight.size(1)]
        if self.on_input:
            output_shape.reverse()
        else:
            # Column shape so the product broadcasts along dim 1.
            scaling = scaling.reshape(-1, 1)

        updown = orig_weight * scaling

        return self.finalize_updown(updown, orig_weight, output_shape)
|
||||||
64
packages_3rdparty/webui_lora_collection/network_lokr.py
vendored
Normal file
64
packages_3rdparty/webui_lora_collection/network_lokr.py
vendored
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
import lyco_helpers
|
||||||
|
import network
|
||||||
|
|
||||||
|
|
||||||
|
class ModuleTypeLokr(network.ModuleType):
    """Factory for Kronecker-product (``lokr_*``) modules."""

    def create_module(self, net: network.Network, weights: network.NetworkWeights):
        """Return a ``NetworkModuleLokr`` when both factors are available, else ``None``.

        A factor counts as available when it is stored whole (``lokr_wN``) or
        as a pair (``lokr_wN_a`` and ``lokr_wN_b``).
        """
        w = weights.w
        has_1 = "lokr_w1" in w or ("lokr_w1_a" in w and "lokr_w1_b" in w)
        has_2 = "lokr_w2" in w or ("lokr_w2_a" in w and "lokr_w2_b" in w)
        if not (has_1 and has_2):
            return None
        return NetworkModuleLokr(net, weights)
|
||||||
|
|
||||||
|
|
||||||
|
def make_kron(orig_shape, w1, w2):
|
||||||
|
if len(w2.shape) == 4:
|
||||||
|
w1 = w1.unsqueeze(2).unsqueeze(2)
|
||||||
|
w2 = w2.contiguous()
|
||||||
|
return torch.kron(w1, w2).reshape(orig_shape)
|
||||||
|
|
||||||
|
|
||||||
|
class NetworkModuleLokr(network.NetworkModule):
    """Module whose update is the Kronecker product of two (possibly factored) matrices."""

    def __init__(self, net: network.Network, weights: network.NetworkWeights):
        """Store whichever of the whole / factored ``lokr_*`` tensors are present."""
        super().__init__(net, weights)

        w = weights.w
        self.w1 = w.get("lokr_w1")
        self.w1a = w.get("lokr_w1_a")
        self.w1b = w.get("lokr_w1_b")
        if self.w1b is not None:
            self.dim = self.w1b.shape[0]

        self.w2 = w.get("lokr_w2")
        self.w2a = w.get("lokr_w2_a")
        self.w2b = w.get("lokr_w2_b")
        # A factored second matrix takes precedence for the rank.
        if self.w2b is not None:
            self.dim = self.w2b.shape[0]

        self.t2 = w.get("lokr_t2")

    def calc_updown(self, orig_weight):
        """Rebuild both factors on ``orig_weight``'s device, take their Kronecker product and finalize."""
        device = orig_weight.device

        if self.w1 is not None:
            w1 = self.w1.to(device)
        else:
            w1 = self.w1a.to(device) @ self.w1b.to(device)

        if self.w2 is not None:
            w2 = self.w2.to(device)
        elif self.t2 is None:
            w2 = self.w2a.to(device) @ self.w2b.to(device)
        else:
            t2 = self.t2.to(device)
            w2a = self.w2a.to(device)
            w2b = self.w2b.to(device)
            w2 = lyco_helpers.make_weight_cp(t2, w2a, w2b)

        output_shape = [w1.size(0) * w2.size(0), w1.size(1) * w2.size(1)]
        if len(orig_weight.shape) == 4:
            output_shape = orig_weight.shape

        updown = make_kron(output_shape, w1, w2)

        return self.finalize_updown(updown, orig_weight, output_shape)
|
||||||
94
packages_3rdparty/webui_lora_collection/network_lora.py
vendored
Normal file
94
packages_3rdparty/webui_lora_collection/network_lora.py
vendored
Normal file
@@ -0,0 +1,94 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
import lyco_helpers
|
||||||
|
import modules.models.sd3.mmdit
|
||||||
|
import network
|
||||||
|
from modules import devices
|
||||||
|
|
||||||
|
|
||||||
|
class ModuleTypeLora(network.ModuleType):
    """Factory for plain LoRA modules, accepting two key-naming schemes."""

    def create_module(self, net: network.Network, weights: network.NetworkWeights):
        """Return a ``NetworkModuleLora`` for ``lora_up/down`` or ``lora_A/B`` weights, else ``None``.

        For the ``lora_A`` / ``lora_B`` scheme, ``weights.w`` is rewritten in
        place to contain only ``lora_up.weight`` (from B) and
        ``lora_down.weight`` (from A); every other entry is dropped.
        """
        w = weights.w
        if "lora_up.weight" in w and "lora_down.weight" in w:
            return NetworkModuleLora(net, weights)

        if "lora_A.weight" in w and "lora_B.weight" in w:
            snapshot = w.copy()
            renamed = {"lora_up.weight": snapshot["lora_B.weight"], "lora_down.weight": snapshot["lora_A.weight"]}
            weights.w.clear()
            weights.w.update(renamed)

            return NetworkModuleLora(net, weights)

        return None
|
||||||
|
|
||||||
|
|
||||||
|
class NetworkModuleLora(network.NetworkModule):
    """LoRA module holding up/down (and optional mid) layers as torch modules."""

    def __init__(self, net: network.Network, weights: network.NetworkWeights):
        """Build the up, down and optional mid layers; the rank is the down weight's first dim."""
        super().__init__(net, weights)

        w = weights.w
        self.up_model = self.create_module(w, "lora_up.weight")
        self.down_model = self.create_module(w, "lora_down.weight")
        self.mid_model = self.create_module(w, "lora_mid.weight", none_ok=True)

        self.dim = w["lora_down.weight"].shape[0]

    def create_module(self, weights, key, none_ok=False):
        """Wrap ``weights[key]`` in a bias-free Linear or Conv2d matching the target module's type.

        Returns ``None`` when the key is absent and ``none_ok`` is true.
        Raises ``AssertionError`` for unsupported target module types.
        """
        weight = weights.get(key)

        if weight is None and none_ok:
            return None

        # Exact type matches (not isinstance), so subclasses are not accepted.
        target_type = type(self.sd_module)
        is_linear = target_type in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, torch.nn.MultiheadAttention, modules.models.sd3.mmdit.QkvLinear]
        is_conv = target_type in [torch.nn.Conv2d]

        if is_linear:
            weight = weight.reshape(weight.shape[0], -1)
            module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
        elif (is_conv and key == "lora_down.weight") or key == "dyn_up":
            if len(weight.shape) == 2:
                weight = weight.reshape(weight.shape[0], -1, 1, 1)

            if weight.shape[2] != 1 or weight.shape[3] != 1:
                # Non-1x1 kernel: reuse the target convolution's geometry.
                module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], self.sd_module.kernel_size, self.sd_module.stride, self.sd_module.padding, bias=False)
            else:
                module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
        elif is_conv and key == "lora_mid.weight":
            module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], self.sd_module.kernel_size, self.sd_module.stride, self.sd_module.padding, bias=False)
        elif (is_conv and key == "lora_up.weight") or key == "dyn_down":
            module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
        else:
            raise AssertionError(f'Lora layer {self.network_key} matched a layer with unsupported type: {type(self.sd_module).__name__}')

        with torch.no_grad():
            if weight.shape != module.weight.shape:
                weight = weight.reshape(module.weight.shape)
            module.weight.copy_(weight)

        module.to(device=devices.cpu, dtype=devices.dtype)
        module.weight.requires_grad_(False)

        return module

    def calc_updown(self, orig_weight):
        """Rebuild the weight delta from the up/down (and mid) weights and finalize it."""
        device = orig_weight.device
        up = self.up_model.weight.to(device)
        down = self.down_model.weight.to(device)

        output_shape = [up.size(0), down.size(1)]
        if self.mid_model is None:
            if len(down.shape) == 4:
                output_shape += down.shape[2:]
            updown = lyco_helpers.rebuild_conventional(up, down, output_shape, self.network.dyn_dim)
        else:
            # cp-decomposition
            mid = self.mid_model.weight.to(device)
            updown = lyco_helpers.rebuild_cp_decomposition(up, down, mid)
            output_shape += mid.shape[2:]

        return self.finalize_updown(updown, orig_weight, output_shape)

    def forward(self, x, y):
        """Return ``y`` plus the scaled output of the down and up layers applied to ``x``."""
        self.up_model.to(device=devices.device)
        self.down_model.to(device=devices.device)

        delta = self.up_model(self.down_model(x))
        return y + delta * self.multiplier() * self.calc_scale()
|
||||||
|
|
||||||
|
|
||||||
28
packages_3rdparty/webui_lora_collection/network_norm.py
vendored
Normal file
28
packages_3rdparty/webui_lora_collection/network_norm.py
vendored
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
import network
|
||||||
|
|
||||||
|
|
||||||
|
class ModuleTypeNorm(network.ModuleType):
    """Factory for normalization-layer modules."""

    def create_module(self, net: network.Network, weights: network.NetworkWeights):
        """Return a ``NetworkModuleNorm`` when both ``w_norm`` and ``b_norm`` are present, else ``None``."""
        if "w_norm" in weights.w and "b_norm" in weights.w:
            return NetworkModuleNorm(net, weights)
        return None
|
||||||
|
|
||||||
|
|
||||||
|
class NetworkModuleNorm(network.NetworkModule):
    """Module whose update is the stored ``w_norm`` tensor with ``b_norm`` as extra bias."""

    def __init__(self, net: network.Network, weights: network.NetworkWeights):
        """Keep references to the ``w_norm`` and ``b_norm`` entries of ``weights.w``."""
        super().__init__(net, weights)

        self.w_norm = weights.w.get("w_norm")
        self.b_norm = weights.w.get("b_norm")

    def calc_updown(self, orig_weight):
        """Move the stored tensors to ``orig_weight``'s device and finalize them."""
        device = orig_weight.device
        output_shape = self.w_norm.shape
        updown = self.w_norm.to(device)
        ex_bias = None if self.b_norm is None else self.b_norm.to(device)

        return self.finalize_updown(updown, orig_weight, output_shape, ex_bias)
|
||||||
118
packages_3rdparty/webui_lora_collection/network_oft.py
vendored
Normal file
118
packages_3rdparty/webui_lora_collection/network_oft.py
vendored
Normal file
@@ -0,0 +1,118 @@
|
|||||||
|
import torch
|
||||||
|
import network
|
||||||
|
from einops import rearrange
|
||||||
|
|
||||||
|
|
||||||
|
class ModuleTypeOFT(network.ModuleType):
    """Factory for OFT modules."""

    def create_module(self, net: network.Network, weights: network.NetworkWeights):
        """Return a ``NetworkModuleOFT`` when ``oft_blocks`` or ``oft_diag`` is present, else ``None``."""
        if "oft_blocks" in weights.w or "oft_diag" in weights.w:
            return NetworkModuleOFT(net, weights)
        return None
|
||||||
|
|
||||||
|
# Supports both kohya-ss' implementation of COFT https://github.com/kohya-ss/sd-scripts/blob/main/networks/oft.py
|
||||||
|
# and KohakuBlueleaf's implementation of OFT/COFT https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/diag_oft.py
|
||||||
|
class NetworkModuleOFT(network.NetworkModule):
    """Network module applying an OFT / COFT / BOFT block rotation to a weight.

    The stored tensor comes from ``"oft_blocks"`` (kohya-ss / newer LyCORIS) or
    ``"oft_diag"`` (older LyCORIS). A 4-dimensional tensor is treated as
    LyCORIS BOFT.
    """

    def __init__(self, net: network.Network, weights: network.NetworkWeights):
        """Read the OFT tensors from ``weights.w`` and derive the block layout.

        NOTE(review): if ``self.sd_module`` is not exactly ``Linear``,
        ``NonDynamicallyQuantizableLinear``, ``Conv2d`` or
        ``MultiheadAttention``, ``self.out_dim`` is never assigned here and the
        block-size computation below fails on that attribute.
        """
        super().__init__(net, weights)

        self.lin_module = None
        self.org_module: list[torch.nn.Module] = [self.sd_module]

        self.scale = 1.0
        self.is_R = False
        self.is_boft = False

        # kohya-ss/New LyCORIS OFT/BOFT
        if "oft_blocks" in weights.w:
            self.oft_blocks = weights.w["oft_blocks"]  # (num_blocks, block_size, block_size)
            self.alpha = weights.w.get("alpha", None)  # alpha is constraint
            self.dim = self.oft_blocks.shape[0]  # lora dim
        # Old LyCORIS OFT
        elif "oft_diag" in weights.w:
            self.is_R = True
            self.oft_blocks = weights.w["oft_diag"]
            # self.alpha is unused
            self.dim = self.oft_blocks.shape[1]  # (num_blocks, block_size, block_size)

        # Exact type matches on purpose: subclasses are not treated as supported.
        is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear]
        is_conv = type(self.sd_module) in [torch.nn.Conv2d]
        is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention]  # unsupported

        if is_linear:
            self.out_dim = self.sd_module.out_features
        elif is_conv:
            self.out_dim = self.sd_module.out_channels
        elif is_other_linear:
            self.out_dim = self.sd_module.embed_dim

        # LyCORIS BOFT
        if self.oft_blocks.dim() == 4:
            self.is_boft = True
        # Set unconditionally: calc_updown reads self.rescale for every variant.
        self.rescale = weights.w.get('rescale', None)
        if self.rescale is not None and not is_other_linear:
            # Reshape to (-1, 1, ..., 1) so it broadcasts over the weight's trailing dims.
            self.rescale = self.rescale.reshape(-1, *[1] * (self.org_module[0].weight.dim() - 1))

        self.num_blocks = self.dim
        self.block_size = self.out_dim // self.dim
        self.constraint = (0 if self.alpha is None else self.alpha) * self.out_dim
        if self.is_R:
            self.constraint = None
            self.block_size = self.dim
            self.num_blocks = self.out_dim // self.dim
        elif self.is_boft:
            self.boft_m = self.oft_blocks.shape[0]
            self.num_blocks = self.oft_blocks.shape[1]
            self.block_size = self.oft_blocks.shape[2]
            self.boft_b = self.block_size

    def calc_updown(self, orig_weight):
        """Compute the weight delta produced by rotating ``orig_weight``.

        The rotated ("merged") weight is built block-wise, optionally rescaled,
        and the difference to ``orig_weight`` is passed to ``finalize_updown``
        with ``orig_weight.shape`` as the output shape.
        """
        oft_blocks = self.oft_blocks.to(orig_weight.device)
        eye = torch.eye(self.block_size, device=oft_blocks.device)

        if not self.is_R:
            # The stored blocks are raw parameters: make them skew-symmetric,
            # optionally clamp their norm, then apply (I + Q)(I - Q)^-1.
            block_Q = oft_blocks - oft_blocks.transpose(-1, -2)  # ensure skew-symmetric orthogonal matrix
            if self.constraint != 0:
                norm_Q = torch.norm(block_Q.flatten())
                new_norm_Q = torch.clamp(norm_Q, max=self.constraint.to(oft_blocks.device))
                block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8))
            oft_blocks = torch.matmul(eye + block_Q, (eye - block_Q).float().inverse())

        R = oft_blocks.to(orig_weight.device)

        if not self.is_boft:
            # This errors out for MultiheadAttention, might need to be handled up-stream
            merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size)
            merged_weight = torch.einsum(
                'k n m, k n ... -> k m ...',
                R,
                merged_weight
            )
            merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...')
        else:
            # TODO: determine correct value for scale
            scale = 1.0
            m = self.boft_m
            b = self.boft_b
            r_b = b // 2
            inp = orig_weight
            for i in range(m):
                bi = R[i]  # b_num, b_size, b_size
                if i == 0:
                    # Apply multiplier/scale and rescale into first weight
                    bi = bi * scale + (1 - scale) * eye
                # Permute rows, apply this stage's blocks, then undo the permutation.
                inp = rearrange(inp, "(c g k) ... -> (c k g) ...", g=2, k=2**i * r_b)
                inp = rearrange(inp, "(d b) ... -> d b ...", b=b)
                inp = torch.einsum("b i j, b j ... -> b i ...", bi, inp)
                inp = rearrange(inp, "d b ... -> (d b) ...")
                inp = rearrange(inp, "(c k g) ... -> (c g k) ...", g=2, k=2**i * r_b)
            merged_weight = inp

        # Rescale mechanism
        if self.rescale is not None:
            merged_weight = self.rescale.to(merged_weight) * merged_weight

        updown = merged_weight.to(orig_weight.device) - orig_weight.to(merged_weight.dtype)
        output_shape = orig_weight.shape
        return self.finalize_updown(updown, orig_weight, output_shape)
|
||||||
Reference in New Issue
Block a user