diff --git a/backend/modules/k_model.py b/backend/modules/k_model.py index 3c7d5ee8..57f2f2fd 100644 --- a/backend/modules/k_model.py +++ b/backend/modules/k_model.py @@ -1,13 +1,15 @@ import torch -from backend import memory_management, attention, operations +from backend import memory_management, attention from backend.modules.k_prediction import k_prediction_from_diffusers_scheduler class KModel(torch.nn.Module): - def __init__(self, model, diffusers_scheduler, k_predictor=None): + def __init__(self, model, diffusers_scheduler, k_predictor=None, config=None): super().__init__() + self.config = config + self.storage_dtype = model.storage_dtype self.computation_dtype = model.computation_dtype diff --git a/backend/patcher/lora.py b/backend/patcher/lora.py index cc75f5cf..ec87b000 100644 --- a/backend/patcher/lora.py +++ b/backend/patcher/lora.py @@ -1,293 +1,31 @@ -# LoRA Implementation Collection form ComfyUI -# Modified by Forge to support greedy loading (load a set or wrong/correct loras to a model and only preserve the correct ones), -# which is important to webui experience - -from backend.misc.diffusers_state_dict import unet_to_diffusers +import packages_3rdparty.webui_lora_collection.lora as lora_utils_webui +import packages_3rdparty.comfyui_lora_collection.lora as lora_utils_comfyui -LORA_CLIP_MAP = { - "mlp.fc1": "mlp_fc1", - "mlp.fc2": "mlp_fc2", - "self_attn.k_proj": "self_attn_k_proj", - "self_attn.q_proj": "self_attn_q_proj", - "self_attn.v_proj": "self_attn_v_proj", - "self_attn.out_proj": "self_attn_out_proj", -} +class ForgeLoraCollection: + # TODO + pass + + +lora_utils_forge = ForgeLoraCollection() + +lora_collection_priority = [lora_utils_forge, lora_utils_webui, lora_utils_comfyui] + + +def get_function(function_name: str): + for lora_collection in lora_collection_priority: + if hasattr(lora_collection, function_name): + return getattr(lora_collection, function_name) def load_lora(lora, to_load): - patch_dict = {} - loaded_keys = set() - for x in to_load: - alpha_name = "{}.alpha".format(x) - alpha = None - if alpha_name in lora.keys(): - alpha = lora[alpha_name].item() - loaded_keys.add(alpha_name) - - dora_scale_name = "{}.dora_scale".format(x) - dora_scale = None - if dora_scale_name in lora.keys(): - dora_scale = lora[dora_scale_name] - loaded_keys.add(dora_scale_name) - - regular_lora = "{}.lora_up.weight".format(x) - diffusers_lora = "{}_lora.up.weight".format(x) - diffusers2_lora = "{}.lora_B.weight".format(x) - diffusers3_lora = "{}.lora.up.weight".format(x) - transformers_lora = "{}.lora_linear_layer.up.weight".format(x) - A_name = None - - if regular_lora in lora.keys(): - A_name = regular_lora - B_name = "{}.lora_down.weight".format(x) - mid_name = "{}.lora_mid.weight".format(x) - elif diffusers_lora in lora.keys(): - A_name = diffusers_lora - B_name = "{}_lora.down.weight".format(x) - mid_name = None - elif diffusers2_lora in lora.keys(): - A_name = diffusers2_lora - B_name = "{}.lora_A.weight".format(x) - mid_name = None - elif diffusers3_lora in lora.keys(): - A_name = diffusers3_lora - B_name = "{}.lora.down.weight".format(x) - mid_name = None - elif transformers_lora in lora.keys(): - A_name = transformers_lora - B_name = "{}.lora_linear_layer.down.weight".format(x) - mid_name = None - - if A_name is not None: - mid = None - if mid_name is not None and mid_name in lora.keys(): - mid = lora[mid_name] - loaded_keys.add(mid_name) - patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid, dora_scale)) - loaded_keys.add(A_name) - loaded_keys.add(B_name) - - ######## loha - hada_w1_a_name = "{}.hada_w1_a".format(x) - hada_w1_b_name = "{}.hada_w1_b".format(x) - hada_w2_a_name = "{}.hada_w2_a".format(x) - hada_w2_b_name = "{}.hada_w2_b".format(x) - hada_t1_name = "{}.hada_t1".format(x) - hada_t2_name = "{}.hada_t2".format(x) - if hada_w1_a_name in lora.keys(): - hada_t1 = None - hada_t2 = None - if hada_t1_name in lora.keys(): - hada_t1 = lora[hada_t1_name] - hada_t2 = lora[hada_t2_name] - loaded_keys.add(hada_t1_name) - loaded_keys.add(hada_t2_name) - - patch_dict[to_load[x]] = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2, dora_scale)) - loaded_keys.add(hada_w1_a_name) - loaded_keys.add(hada_w1_b_name) - loaded_keys.add(hada_w2_a_name) - loaded_keys.add(hada_w2_b_name) - - ######## lokr - lokr_w1_name = "{}.lokr_w1".format(x) - lokr_w2_name = "{}.lokr_w2".format(x) - lokr_w1_a_name = "{}.lokr_w1_a".format(x) - lokr_w1_b_name = "{}.lokr_w1_b".format(x) - lokr_t2_name = "{}.lokr_t2".format(x) - lokr_w2_a_name = "{}.lokr_w2_a".format(x) - lokr_w2_b_name = "{}.lokr_w2_b".format(x) - - lokr_w1 = None - if lokr_w1_name in lora.keys(): - lokr_w1 = lora[lokr_w1_name] - loaded_keys.add(lokr_w1_name) - - lokr_w2 = None - if lokr_w2_name in lora.keys(): - lokr_w2 = lora[lokr_w2_name] - loaded_keys.add(lokr_w2_name) - - lokr_w1_a = None - if lokr_w1_a_name in lora.keys(): - lokr_w1_a = lora[lokr_w1_a_name] - loaded_keys.add(lokr_w1_a_name) - - lokr_w1_b = None - if lokr_w1_b_name in lora.keys(): - lokr_w1_b = lora[lokr_w1_b_name] - loaded_keys.add(lokr_w1_b_name) - - lokr_w2_a = None - if lokr_w2_a_name in lora.keys(): - lokr_w2_a = lora[lokr_w2_a_name] - loaded_keys.add(lokr_w2_a_name) - - lokr_w2_b = None - if lokr_w2_b_name in lora.keys(): - lokr_w2_b = lora[lokr_w2_b_name] - loaded_keys.add(lokr_w2_b_name) - - lokr_t2 = None - if lokr_t2_name in lora.keys(): - lokr_t2 = lora[lokr_t2_name] - loaded_keys.add(lokr_t2_name) - - if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None): - patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale)) - - # glora - a1_name = "{}.a1.weight".format(x) - a2_name = "{}.a2.weight".format(x) - b1_name = "{}.b1.weight".format(x) - b2_name = "{}.b2.weight".format(x) - if a1_name in lora: - patch_dict[to_load[x]] = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha, dora_scale)) - loaded_keys.add(a1_name) - loaded_keys.add(a2_name) - loaded_keys.add(b1_name) - loaded_keys.add(b2_name) - - w_norm_name = "{}.w_norm".format(x) - b_norm_name = "{}.b_norm".format(x) - w_norm = lora.get(w_norm_name, None) - b_norm = lora.get(b_norm_name, None) - - if w_norm is not None: - loaded_keys.add(w_norm_name) - patch_dict[to_load[x]] = ("diff", (w_norm,)) - if b_norm is not None: - loaded_keys.add(b_norm_name) - patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (b_norm,)) - - diff_name = "{}.diff".format(x) - diff_weight = lora.get(diff_name, None) - if diff_weight is not None: - patch_dict[to_load[x]] = ("diff", (diff_weight,)) - loaded_keys.add(diff_name) - - diff_bias_name = "{}.diff_b".format(x) - diff_bias = lora.get(diff_bias_name, None) - if diff_bias is not None: - patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (diff_bias,)) - loaded_keys.add(diff_bias_name) - - remaining_dict = {x: y for x, y in lora.items() if x not in loaded_keys} + patch_dict, remaining_dict = get_function('load_lora')(lora, to_load) return patch_dict, remaining_dict def model_lora_keys_clip(model, key_map={}): - sdk = model.state_dict().keys() - - text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}" - clip_l_present = False - for b in range(32): # TODO: clean up - for c in LORA_CLIP_MAP: - k = "clip_h.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c) - if k in sdk: - lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c]) - key_map[lora_key] = k - lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) - key_map[lora_key] = k - lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c) - key_map[lora_key] = k - - k = "clip_l.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c) - if k in sdk: - lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c]) - key_map[lora_key] = k - lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) - key_map[lora_key] = k - clip_l_present = True - lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c) - key_map[lora_key] = k - - k = "clip_g.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c) - if k in sdk: - if clip_l_present: - lora_key = "lora_te2_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) - key_map[lora_key] = k - lora_key = "text_encoder_2.text_model.encoder.layers.{}.{}".format(b, c) - key_map[lora_key] = k - else: - lora_key = "lora_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) - key_map[lora_key] = k - lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c) - key_map[lora_key] = k - lora_key = "lora_prior_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) - key_map[lora_key] = k - - for k in sdk: # OneTrainer SD3 lora - if k.startswith("t5xxl.transformer.") and k.endswith(".weight"): - l_key = k[len("t5xxl.transformer."):-len(".weight")] - lora_key = "lora_te3_{}".format(l_key.replace(".", "_")) - key_map[lora_key] = k - - k = "clip_g.transformer.text_projection.weight" - if k in sdk: - key_map["lora_prior_te_text_projection"] = k - key_map["lora_te2_text_projection"] = k - - k = "clip_l.transformer.text_projection.weight" - if k in sdk: - key_map["lora_te1_text_projection"] = k - - return key_map + return get_function('model_lora_keys_clip')(model, key_map) def model_lora_keys_unet(model, key_map={}): - sd = model.state_dict() - sdk = sd.keys() - - for k in sdk: - if k.startswith("diffusion_model.") and k.endswith(".weight"): - key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_") - key_map["lora_unet_{}".format(key_lora)] = k - key_map["lora_prior_unet_{}".format(key_lora)] = k - - diffusers_keys = unet_to_diffusers(model.diffusion_model.config) - for k in diffusers_keys: - if k.endswith(".weight"): - unet_key = "diffusion_model.{}".format(diffusers_keys[k]) - key_lora = k[:-len(".weight")].replace(".", "_") - key_map["lora_unet_{}".format(key_lora)] = unet_key - - diffusers_lora_prefix = ["", "unet."] - for p in diffusers_lora_prefix: - diffusers_lora_key = "{}{}".format(p, k[:-len(".weight")].replace(".to_", ".processor.to_")) - if diffusers_lora_key.endswith(".to_out.0"): - diffusers_lora_key = diffusers_lora_key[:-2] - key_map[diffusers_lora_key] = unet_key - - # TODO: - - # if isinstance(model, xxxx.modules.model_base.SD3): # Diffusers lora SD3 - # diffusers_keys = xxxx.modules.utils.mmdit_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.") - # for k in diffusers_keys: - # if k.endswith(".weight"): - # to = diffusers_keys[k] - # key_lora = "transformer.{}".format(k[:-len(".weight")]) # regular diffusers sd3 lora format - # key_map[key_lora] = to - # - # key_lora = "base_model.model.{}".format(k[:-len(".weight")]) # format for flash-sd3 lora and others? - # key_map[key_lora] = to - # - # key_lora = "lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_")) # OneTrainer lora - # key_map[key_lora] = to - # - # if isinstance(model, xxxx.modules.model_base.AuraFlow): # Diffusers lora AuraFlow - # diffusers_keys = xxxx.modules.utils.auraflow_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.") - # for k in diffusers_keys: - # if k.endswith(".weight"): - # to = diffusers_keys[k] - # key_lora = "transformer.{}".format(k[:-len(".weight")]) # simpletrainer and probably regular diffusers lora format - # key_map[key_lora] = to - # - # if isinstance(model, xxxx.modules.model_base.HunyuanDiT): - # for k in sdk: - # if k.startswith("diffusion_model.") and k.endswith(".weight"): - # key_lora = k[len("diffusion_model."):-len(".weight")] - # key_map["base_model.model.{}".format(key_lora)] = k # official hunyuan lora format - - return key_map + return get_function('model_lora_keys_unet')(model, key_map) diff --git a/backend/patcher/unet.py b/backend/patcher/unet.py index 8e5155a3..e8416f70 100644 --- a/backend/patcher/unet.py +++ b/backend/patcher/unet.py @@ -8,7 +8,7 @@ from backend.patcher.base import ModelPatcher class UnetPatcher(ModelPatcher): @classmethod def from_model(cls, model, diffusers_scheduler, config, k_predictor=None): - model = KModel(model=model, diffusers_scheduler=diffusers_scheduler, k_predictor=k_predictor) + model = KModel(model=model, diffusers_scheduler=diffusers_scheduler, k_predictor=k_predictor, config=config) return UnetPatcher( model, load_device=model.diffusion_model.load_device, diff --git a/packages_3rdparty/README.md b/packages_3rdparty/README.md new file mode 100644 index 00000000..8c3e98af --- /dev/null +++ b/packages_3rdparty/README.md @@ -0,0 +1 @@ +Please follow the standard of https://github.com/opencv/opencv/tree/315f85d4f484c1e2fa043c73ac3fdd9fc5997ee7/3rdparty when PR or modifying files. diff --git a/packages_3rdparty/comfyui_lora_collection/lora.py b/packages_3rdparty/comfyui_lora_collection/lora.py new file mode 100644 index 00000000..48a97b37 --- /dev/null +++ b/packages_3rdparty/comfyui_lora_collection/lora.py @@ -0,0 +1,323 @@ +""" + This file is part of ComfyUI. + Copyright (C) 2024 Comfy + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . +""" + +from packages_3rdparty.comfyui_lora_collection import utils + + +LORA_CLIP_MAP = { + "mlp.fc1": "mlp_fc1", + "mlp.fc2": "mlp_fc2", + "self_attn.k_proj": "self_attn_k_proj", + "self_attn.q_proj": "self_attn_q_proj", + "self_attn.v_proj": "self_attn_v_proj", + "self_attn.out_proj": "self_attn_out_proj", +} + + +def load_lora(lora, to_load): + patch_dict = {} + loaded_keys = set() + for x in to_load: + alpha_name = "{}.alpha".format(x) + alpha = None + if alpha_name in lora.keys(): + alpha = lora[alpha_name].item() + loaded_keys.add(alpha_name) + + dora_scale_name = "{}.dora_scale".format(x) + dora_scale = None + if dora_scale_name in lora.keys(): + dora_scale = lora[dora_scale_name] + loaded_keys.add(dora_scale_name) + + regular_lora = "{}.lora_up.weight".format(x) + diffusers_lora = "{}_lora.up.weight".format(x) + diffusers2_lora = "{}.lora_B.weight".format(x) + diffusers3_lora = "{}.lora.up.weight".format(x) + transformers_lora = "{}.lora_linear_layer.up.weight".format(x) + A_name = None + + if regular_lora in lora.keys(): + A_name = regular_lora + B_name = "{}.lora_down.weight".format(x) + mid_name = "{}.lora_mid.weight".format(x) + elif diffusers_lora in lora.keys(): + A_name = diffusers_lora + B_name = "{}_lora.down.weight".format(x) + mid_name = None + elif diffusers2_lora in lora.keys(): + A_name = diffusers2_lora + B_name = "{}.lora_A.weight".format(x) + mid_name = None + elif diffusers3_lora in lora.keys(): + A_name = diffusers3_lora + B_name = "{}.lora.down.weight".format(x) + mid_name = None + elif transformers_lora in lora.keys(): + A_name = transformers_lora + B_name ="{}.lora_linear_layer.down.weight".format(x) + mid_name = None + + if A_name is not None: + mid = None + if mid_name is not None and mid_name in lora.keys(): + mid = lora[mid_name] + loaded_keys.add(mid_name) + patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid, dora_scale)) + loaded_keys.add(A_name) + loaded_keys.add(B_name) + + + ######## loha + hada_w1_a_name = "{}.hada_w1_a".format(x) + hada_w1_b_name = "{}.hada_w1_b".format(x) + hada_w2_a_name = "{}.hada_w2_a".format(x) + hada_w2_b_name = "{}.hada_w2_b".format(x) + hada_t1_name = "{}.hada_t1".format(x) + hada_t2_name = "{}.hada_t2".format(x) + if hada_w1_a_name in lora.keys(): + hada_t1 = None + hada_t2 = None + if hada_t1_name in lora.keys(): + hada_t1 = lora[hada_t1_name] + hada_t2 = lora[hada_t2_name] + loaded_keys.add(hada_t1_name) + loaded_keys.add(hada_t2_name) + + patch_dict[to_load[x]] = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2, dora_scale)) + loaded_keys.add(hada_w1_a_name) + loaded_keys.add(hada_w1_b_name) + loaded_keys.add(hada_w2_a_name) + loaded_keys.add(hada_w2_b_name) + + + ######## lokr + lokr_w1_name = "{}.lokr_w1".format(x) + lokr_w2_name = "{}.lokr_w2".format(x) + lokr_w1_a_name = "{}.lokr_w1_a".format(x) + lokr_w1_b_name = "{}.lokr_w1_b".format(x) + lokr_t2_name = "{}.lokr_t2".format(x) + lokr_w2_a_name = "{}.lokr_w2_a".format(x) + lokr_w2_b_name = "{}.lokr_w2_b".format(x) + + lokr_w1 = None + if lokr_w1_name in lora.keys(): + lokr_w1 = lora[lokr_w1_name] + loaded_keys.add(lokr_w1_name) + + lokr_w2 = None + if lokr_w2_name in lora.keys(): + lokr_w2 = lora[lokr_w2_name] + loaded_keys.add(lokr_w2_name) + + lokr_w1_a = None + if lokr_w1_a_name in lora.keys(): + lokr_w1_a = lora[lokr_w1_a_name] + loaded_keys.add(lokr_w1_a_name) + + lokr_w1_b = None + if lokr_w1_b_name in lora.keys(): + lokr_w1_b = lora[lokr_w1_b_name] + loaded_keys.add(lokr_w1_b_name) + + lokr_w2_a = None + if lokr_w2_a_name in lora.keys(): + lokr_w2_a = lora[lokr_w2_a_name] + loaded_keys.add(lokr_w2_a_name) + + lokr_w2_b = None + if lokr_w2_b_name in lora.keys(): + lokr_w2_b = lora[lokr_w2_b_name] + loaded_keys.add(lokr_w2_b_name) + + lokr_t2 = None + if lokr_t2_name in lora.keys(): + lokr_t2 = lora[lokr_t2_name] + loaded_keys.add(lokr_t2_name) + + if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None): + patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale)) + + #glora + a1_name = "{}.a1.weight".format(x) + a2_name = "{}.a2.weight".format(x) + b1_name = "{}.b1.weight".format(x) + b2_name = "{}.b2.weight".format(x) + if a1_name in lora: + patch_dict[to_load[x]] = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha, dora_scale)) + loaded_keys.add(a1_name) + loaded_keys.add(a2_name) + loaded_keys.add(b1_name) + loaded_keys.add(b2_name) + + w_norm_name = "{}.w_norm".format(x) + b_norm_name = "{}.b_norm".format(x) + w_norm = lora.get(w_norm_name, None) + b_norm = lora.get(b_norm_name, None) + + if w_norm is not None: + loaded_keys.add(w_norm_name) + patch_dict[to_load[x]] = ("diff", (w_norm,)) + if b_norm is not None: + loaded_keys.add(b_norm_name) + patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (b_norm,)) + + diff_name = "{}.diff".format(x) + diff_weight = lora.get(diff_name, None) + if diff_weight is not None: + patch_dict[to_load[x]] = ("diff", (diff_weight,)) + loaded_keys.add(diff_name) + + diff_bias_name = "{}.diff_b".format(x) + diff_bias = lora.get(diff_bias_name, None) + if diff_bias is not None: + patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (diff_bias,)) + loaded_keys.add(diff_bias_name) + + remaining_dict = {x: y for x, y in lora.items() if x not in loaded_keys} + return patch_dict, remaining_dict + + +def model_lora_keys_clip(model, key_map={}): + sdk = model.state_dict().keys() + + text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}" + clip_l_present = False + for b in range(32): #TODO: clean up + for c in LORA_CLIP_MAP: + k = "clip_h.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c) + if k in sdk: + lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c]) + key_map[lora_key] = k + lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) + key_map[lora_key] = k + lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c) #diffusers lora + key_map[lora_key] = k + + k = "clip_l.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c) + if k in sdk: + lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c]) + key_map[lora_key] = k + lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base + key_map[lora_key] = k + clip_l_present = True + lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c) #diffusers lora + key_map[lora_key] = k + + k = "clip_g.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c) + if k in sdk: + if clip_l_present: + lora_key = "lora_te2_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base + key_map[lora_key] = k + lora_key = "text_encoder_2.text_model.encoder.layers.{}.{}".format(b, c) #diffusers lora + key_map[lora_key] = k + else: + lora_key = "lora_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #TODO: test if this is correct for SDXL-Refiner + key_map[lora_key] = k + lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c) #diffusers lora + key_map[lora_key] = k + lora_key = "lora_prior_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #cascade lora: TODO put lora key prefix in the model config + key_map[lora_key] = k + + for k in sdk: + if k.endswith(".weight"): + if k.startswith("t5xxl.transformer."):#OneTrainer SD3 lora + l_key = k[len("t5xxl.transformer."):-len(".weight")] + lora_key = "lora_te3_{}".format(l_key.replace(".", "_")) + key_map[lora_key] = k + elif k.startswith("hydit_clip.transformer.bert."): #HunyuanDiT Lora + l_key = k[len("hydit_clip.transformer.bert."):-len(".weight")] + lora_key = "lora_te1_{}".format(l_key.replace(".", "_")) + key_map[lora_key] = k + + + k = "clip_g.transformer.text_projection.weight" + if k in sdk: + key_map["lora_prior_te_text_projection"] = k #cascade lora? + # key_map["text_encoder.text_projection"] = k #TODO: check if other lora have the text_projection too + key_map["lora_te2_text_projection"] = k #OneTrainer SD3 lora + + k = "clip_l.transformer.text_projection.weight" + if k in sdk: + key_map["lora_te1_text_projection"] = k #OneTrainer SD3 lora, not necessary but omits warning + + return key_map + + +def model_lora_keys_unet(model, key_map={}): + sd = model.state_dict() + sdk = sd.keys() + + for k in sdk: + if k.startswith("diffusion_model.") and k.endswith(".weight"): + key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_") + key_map["lora_unet_{}".format(key_lora)] = k + key_map["lora_prior_unet_{}".format(key_lora)] = k #cascade lora: TODO put lora key prefix in the model config + key_map["{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names + + diffusers_keys = utils.unet_to_diffusers(model.diffusion_model.config) + for k in diffusers_keys: + if k.endswith(".weight"): + unet_key = "diffusion_model.{}".format(diffusers_keys[k]) + key_lora = k[:-len(".weight")].replace(".", "_") + key_map["lora_unet_{}".format(key_lora)] = unet_key + + diffusers_lora_prefix = ["", "unet."] + for p in diffusers_lora_prefix: + diffusers_lora_key = "{}{}".format(p, k[:-len(".weight")].replace(".to_", ".processor.to_")) + if diffusers_lora_key.endswith(".to_out.0"): + diffusers_lora_key = diffusers_lora_key[:-2] + key_map[diffusers_lora_key] = unet_key + + # if isinstance(model, comfy.model_base.SD3): #Diffusers lora SD3 + # diffusers_keys = utils.mmdit_to_diffusers(model.diffusion_model.config, output_prefix="diffusion_model.") + # for k in diffusers_keys: + # if k.endswith(".weight"): + # to = diffusers_keys[k] + # key_lora = "transformer.{}".format(k[:-len(".weight")]) #regular diffusers sd3 lora format + # key_map[key_lora] = to + # + # key_lora = "base_model.model.{}".format(k[:-len(".weight")]) #format for flash-sd3 lora and others? + # key_map[key_lora] = to + # + # key_lora = "lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_")) #OneTrainer lora + # key_map[key_lora] = to + # + # if isinstance(model, comfy.model_base.AuraFlow): #Diffusers lora AuraFlow + # diffusers_keys = utils.auraflow_to_diffusers(model.diffusion_model.config, output_prefix="diffusion_model.") + # for k in diffusers_keys: + # if k.endswith(".weight"): + # to = diffusers_keys[k] + # key_lora = "transformer.{}".format(k[:-len(".weight")]) #simpletrainer and probably regular diffusers lora format + # key_map[key_lora] = to + # + # if isinstance(model, comfy.model_base.HunyuanDiT): + # for k in sdk: + # if k.startswith("diffusion_model.") and k.endswith(".weight"): + # key_lora = k[len("diffusion_model."):-len(".weight")] + # key_map["base_model.model.{}".format(key_lora)] = k #official hunyuan lora format + + if 'flux' in model.config.huggingface_repo.lower(): #Diffusers lora Flux + diffusers_keys = utils.flux_to_diffusers(model.diffusion_model.config, output_prefix="diffusion_model.") + for k in diffusers_keys: + if k.endswith(".weight"): + to = diffusers_keys[k] + key_lora = "transformer.{}".format(k[:-len(".weight")]) #simpletrainer and probably regular diffusers flux lora format + key_map[key_lora] = to + + return key_map diff --git a/packages_3rdparty/comfyui_lora_collection/utils.py b/packages_3rdparty/comfyui_lora_collection/utils.py new file mode 100644 index 00000000..d60dea55 --- /dev/null +++ b/packages_3rdparty/comfyui_lora_collection/utils.py @@ -0,0 +1,760 @@ +""" + This file is part of ComfyUI. + Copyright (C) 2024 Comfy + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . +""" + + +import torch +import math +import struct +import numpy as np +from PIL import Image +import itertools + + +def calculate_parameters(sd, prefix=""): + params = 0 + for k in sd.keys(): + if k.startswith(prefix): + w = sd[k] + params += w.nelement() + return params + +def weight_dtype(sd, prefix=""): + dtypes = {} + for k in sd.keys(): + if k.startswith(prefix): + w = sd[k] + dtypes[w.dtype] = dtypes.get(w.dtype, 0) + 1 + + if len(dtypes) == 0: + return None + + return max(dtypes, key=dtypes.get) + +def state_dict_key_replace(state_dict, keys_to_replace): + for x in keys_to_replace: + if x in state_dict: + state_dict[keys_to_replace[x]] = state_dict.pop(x) + return state_dict + +def state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=False): + if filter_keys: + out = {} + else: + out = state_dict + for rp in replace_prefix: + replace = list(map(lambda a: (a, "{}{}".format(replace_prefix[rp], a[len(rp):])), filter(lambda a: a.startswith(rp), state_dict.keys()))) + for x in replace: + w = state_dict.pop(x[0]) + out[x[1]] = w + return out + + +def transformers_convert(sd, prefix_from, prefix_to, number): + keys_to_replace = { + "{}positional_embedding": "{}embeddings.position_embedding.weight", + "{}token_embedding.weight": "{}embeddings.token_embedding.weight", + "{}ln_final.weight": "{}final_layer_norm.weight", + "{}ln_final.bias": "{}final_layer_norm.bias", + } + + for k in keys_to_replace: + x = k.format(prefix_from) + if x in sd: + sd[keys_to_replace[k].format(prefix_to)] = sd.pop(x) + + resblock_to_replace = { + "ln_1": "layer_norm1", + "ln_2": "layer_norm2", + "mlp.c_fc": "mlp.fc1", + "mlp.c_proj": "mlp.fc2", + "attn.out_proj": "self_attn.out_proj", + } + + for resblock in range(number): + for x in resblock_to_replace: + for y in ["weight", "bias"]: + k = "{}transformer.resblocks.{}.{}.{}".format(prefix_from, resblock, x, y) + k_to = "{}encoder.layers.{}.{}.{}".format(prefix_to, resblock, resblock_to_replace[x], y) + if k in sd: + sd[k_to] = sd.pop(k) + + for y in ["weight", "bias"]: + k_from = "{}transformer.resblocks.{}.attn.in_proj_{}".format(prefix_from, resblock, y) + if k_from in sd: + weights = sd.pop(k_from) + shape_from = weights.shape[0] // 3 + for x in range(3): + p = ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"] + k_to = "{}encoder.layers.{}.{}.{}".format(prefix_to, resblock, p[x], y) + sd[k_to] = weights[shape_from*x:shape_from*(x + 1)] + + return sd + +def clip_text_transformers_convert(sd, prefix_from, prefix_to): + sd = transformers_convert(sd, prefix_from, "{}text_model.".format(prefix_to), 32) + + tp = "{}text_projection.weight".format(prefix_from) + if tp in sd: + sd["{}text_projection.weight".format(prefix_to)] = sd.pop(tp) + + tp = "{}text_projection".format(prefix_from) + if tp in sd: + sd["{}text_projection.weight".format(prefix_to)] = sd.pop(tp).transpose(0, 1).contiguous() + return sd + + +UNET_MAP_ATTENTIONS = { + "proj_in.weight", + "proj_in.bias", + "proj_out.weight", + "proj_out.bias", + "norm.weight", + "norm.bias", +} + +TRANSFORMER_BLOCKS = { + "norm1.weight", + "norm1.bias", + "norm2.weight", + "norm2.bias", + "norm3.weight", + "norm3.bias", + "attn1.to_q.weight", + "attn1.to_k.weight", + "attn1.to_v.weight", + "attn1.to_out.0.weight", + "attn1.to_out.0.bias", + "attn2.to_q.weight", + "attn2.to_k.weight", + "attn2.to_v.weight", + "attn2.to_out.0.weight", + "attn2.to_out.0.bias", + "ff.net.0.proj.weight", + "ff.net.0.proj.bias", + "ff.net.2.weight", + "ff.net.2.bias", +} + +UNET_MAP_RESNET = { + "in_layers.2.weight": "conv1.weight", + "in_layers.2.bias": "conv1.bias", + "emb_layers.1.weight": "time_emb_proj.weight", + "emb_layers.1.bias": "time_emb_proj.bias", + "out_layers.3.weight": "conv2.weight", + "out_layers.3.bias": "conv2.bias", + "skip_connection.weight": "conv_shortcut.weight", + "skip_connection.bias": "conv_shortcut.bias", + "in_layers.0.weight": "norm1.weight", + "in_layers.0.bias": "norm1.bias", + "out_layers.0.weight": "norm2.weight", + "out_layers.0.bias": "norm2.bias", +} + +UNET_MAP_BASIC = { + ("label_emb.0.0.weight", "class_embedding.linear_1.weight"), + ("label_emb.0.0.bias", "class_embedding.linear_1.bias"), + ("label_emb.0.2.weight", "class_embedding.linear_2.weight"), + ("label_emb.0.2.bias", "class_embedding.linear_2.bias"), + ("label_emb.0.0.weight", "add_embedding.linear_1.weight"), + ("label_emb.0.0.bias", "add_embedding.linear_1.bias"), + ("label_emb.0.2.weight", "add_embedding.linear_2.weight"), + ("label_emb.0.2.bias", "add_embedding.linear_2.bias"), + ("input_blocks.0.0.weight", "conv_in.weight"), + ("input_blocks.0.0.bias", "conv_in.bias"), + ("out.0.weight", "conv_norm_out.weight"), + ("out.0.bias", "conv_norm_out.bias"), + ("out.2.weight", "conv_out.weight"), + ("out.2.bias", "conv_out.bias"), + ("time_embed.0.weight", "time_embedding.linear_1.weight"), + ("time_embed.0.bias", "time_embedding.linear_1.bias"), + ("time_embed.2.weight", "time_embedding.linear_2.weight"), + ("time_embed.2.bias", "time_embedding.linear_2.bias") +} + +def unet_to_diffusers(unet_config): + if "num_res_blocks" not in unet_config: + return {} + num_res_blocks = unet_config["num_res_blocks"] + channel_mult = unet_config["channel_mult"] + transformer_depth = unet_config["transformer_depth"][:] + transformer_depth_output = unet_config["transformer_depth_output"][:] + num_blocks = len(channel_mult) + + transformers_mid = unet_config.get("transformer_depth_middle", None) + + diffusers_unet_map = {} + for x in range(num_blocks): + n = 1 + (num_res_blocks[x] + 1) * x + for i in range(num_res_blocks[x]): + for b in UNET_MAP_RESNET: + diffusers_unet_map["down_blocks.{}.resnets.{}.{}".format(x, i, UNET_MAP_RESNET[b])] = "input_blocks.{}.0.{}".format(n, b) + num_transformers = transformer_depth.pop(0) + if num_transformers > 0: + for b in UNET_MAP_ATTENTIONS: + diffusers_unet_map["down_blocks.{}.attentions.{}.{}".format(x, i, b)] = "input_blocks.{}.1.{}".format(n, b) + for t in range(num_transformers): + for b in TRANSFORMER_BLOCKS: + diffusers_unet_map["down_blocks.{}.attentions.{}.transformer_blocks.{}.{}".format(x, i, t, b)] = "input_blocks.{}.1.transformer_blocks.{}.{}".format(n, t, b) + n += 1 + for k in ["weight", "bias"]: + diffusers_unet_map["down_blocks.{}.downsamplers.0.conv.{}".format(x, k)] = "input_blocks.{}.0.op.{}".format(n, k) + + i = 0 + for b in UNET_MAP_ATTENTIONS: + diffusers_unet_map["mid_block.attentions.{}.{}".format(i, b)] = "middle_block.1.{}".format(b) + for t in range(transformers_mid): + for b in TRANSFORMER_BLOCKS: + diffusers_unet_map["mid_block.attentions.{}.transformer_blocks.{}.{}".format(i, t, b)] = "middle_block.1.transformer_blocks.{}.{}".format(t, b) + + for i, n in enumerate([0, 2]): + for b in UNET_MAP_RESNET: + diffusers_unet_map["mid_block.resnets.{}.{}".format(i, UNET_MAP_RESNET[b])] = "middle_block.{}.{}".format(n, b) + + num_res_blocks = list(reversed(num_res_blocks)) + for x in range(num_blocks): + n = (num_res_blocks[x] + 1) * x + l = num_res_blocks[x] + 1 + for i in range(l): + c = 0 + for b in UNET_MAP_RESNET: + diffusers_unet_map["up_blocks.{}.resnets.{}.{}".format(x, i, UNET_MAP_RESNET[b])] = "output_blocks.{}.0.{}".format(n, b) + c += 1 + num_transformers = transformer_depth_output.pop() + if num_transformers > 0: + c += 1 + for b in UNET_MAP_ATTENTIONS: + diffusers_unet_map["up_blocks.{}.attentions.{}.{}".format(x, i, b)] = "output_blocks.{}.1.{}".format(n, b) + for t in range(num_transformers): + for b in TRANSFORMER_BLOCKS: + diffusers_unet_map["up_blocks.{}.attentions.{}.transformer_blocks.{}.{}".format(x, i, t, b)] = "output_blocks.{}.1.transformer_blocks.{}.{}".format(n, t, b) + if i == l - 1: + for k in ["weight", "bias"]: + diffusers_unet_map["up_blocks.{}.upsamplers.0.conv.{}".format(x, k)] = "output_blocks.{}.{}.conv.{}".format(n, c, k) + n += 1 + + for k in UNET_MAP_BASIC: + diffusers_unet_map[k[1]] = k[0] + + return diffusers_unet_map + +def swap_scale_shift(weight): + shift, scale = weight.chunk(2, dim=0) + new_weight = torch.cat([scale, shift], dim=0) + return new_weight + +MMDIT_MAP_BASIC = { + ("context_embedder.bias", "context_embedder.bias"), + ("context_embedder.weight", "context_embedder.weight"), + ("t_embedder.mlp.0.bias", "time_text_embed.timestep_embedder.linear_1.bias"), + ("t_embedder.mlp.0.weight", "time_text_embed.timestep_embedder.linear_1.weight"), + ("t_embedder.mlp.2.bias", "time_text_embed.timestep_embedder.linear_2.bias"), + ("t_embedder.mlp.2.weight", "time_text_embed.timestep_embedder.linear_2.weight"), + ("x_embedder.proj.bias", "pos_embed.proj.bias"), + ("x_embedder.proj.weight", "pos_embed.proj.weight"), + ("y_embedder.mlp.0.bias", "time_text_embed.text_embedder.linear_1.bias"), + ("y_embedder.mlp.0.weight", "time_text_embed.text_embedder.linear_1.weight"), + ("y_embedder.mlp.2.bias", "time_text_embed.text_embedder.linear_2.bias"), + ("y_embedder.mlp.2.weight", "time_text_embed.text_embedder.linear_2.weight"), + ("pos_embed", "pos_embed.pos_embed"), + ("final_layer.adaLN_modulation.1.bias", "norm_out.linear.bias", swap_scale_shift), + ("final_layer.adaLN_modulation.1.weight", "norm_out.linear.weight", swap_scale_shift), + ("final_layer.linear.bias", "proj_out.bias"), + ("final_layer.linear.weight", "proj_out.weight"), +} + +MMDIT_MAP_BLOCK = { + ("context_block.adaLN_modulation.1.bias", "norm1_context.linear.bias"), + ("context_block.adaLN_modulation.1.weight", "norm1_context.linear.weight"), + ("context_block.attn.proj.bias", "attn.to_add_out.bias"), + ("context_block.attn.proj.weight", "attn.to_add_out.weight"), + ("context_block.mlp.fc1.bias", "ff_context.net.0.proj.bias"), + ("context_block.mlp.fc1.weight", "ff_context.net.0.proj.weight"), + ("context_block.mlp.fc2.bias", "ff_context.net.2.bias"), + ("context_block.mlp.fc2.weight", "ff_context.net.2.weight"), + ("x_block.adaLN_modulation.1.bias", "norm1.linear.bias"), + ("x_block.adaLN_modulation.1.weight", "norm1.linear.weight"), + ("x_block.attn.proj.bias", "attn.to_out.0.bias"), + ("x_block.attn.proj.weight", "attn.to_out.0.weight"), + ("x_block.mlp.fc1.bias", "ff.net.0.proj.bias"), + ("x_block.mlp.fc1.weight", "ff.net.0.proj.weight"), + ("x_block.mlp.fc2.bias", "ff.net.2.bias"), + ("x_block.mlp.fc2.weight", "ff.net.2.weight"), +} + +def mmdit_to_diffusers(mmdit_config, output_prefix=""): + key_map = {} + + depth = mmdit_config.get("depth", 0) + num_blocks = mmdit_config.get("num_blocks", depth) + for i in range(num_blocks): + block_from = "transformer_blocks.{}".format(i) + block_to = "{}joint_blocks.{}".format(output_prefix, i) + + offset = depth * 64 + + for end in ("weight", "bias"): + k = "{}.attn.".format(block_from) + qkv = "{}.x_block.attn.qkv.{}".format(block_to, end) + key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, offset)) + key_map["{}to_k.{}".format(k, end)] = (qkv, (0, offset, offset)) + key_map["{}to_v.{}".format(k, end)] = (qkv, (0, offset * 2, offset)) + + qkv = "{}.context_block.attn.qkv.{}".format(block_to, end) + key_map["{}add_q_proj.{}".format(k, end)] = (qkv, (0, 0, offset)) + key_map["{}add_k_proj.{}".format(k, end)] = (qkv, (0, offset, offset)) + key_map["{}add_v_proj.{}".format(k, end)] = (qkv, (0, offset * 2, offset)) + + for k in MMDIT_MAP_BLOCK: + key_map["{}.{}".format(block_from, k[1])] = "{}.{}".format(block_to, k[0]) + + map_basic = MMDIT_MAP_BASIC.copy() + map_basic.add(("joint_blocks.{}.context_block.adaLN_modulation.1.bias".format(depth - 1), "transformer_blocks.{}.norm1_context.linear.bias".format(depth - 1), swap_scale_shift)) + map_basic.add(("joint_blocks.{}.context_block.adaLN_modulation.1.weight".format(depth - 1), "transformer_blocks.{}.norm1_context.linear.weight".format(depth - 1), swap_scale_shift)) + + for k in map_basic: + if len(k) > 2: + key_map[k[1]] = ("{}{}".format(output_prefix, k[0]), None, k[2]) + else: + key_map[k[1]] = "{}{}".format(output_prefix, k[0]) + + return key_map + + +def auraflow_to_diffusers(mmdit_config, output_prefix=""): + n_double_layers = mmdit_config.get("n_double_layers", 0) + n_layers = mmdit_config.get("n_layers", 0) + + key_map = {} + for i in range(n_layers): + if i < n_double_layers: + index = i + prefix_from = "joint_transformer_blocks" + prefix_to = "{}double_layers".format(output_prefix) + block_map = { + "attn.to_q.weight": "attn.w2q.weight", + "attn.to_k.weight": "attn.w2k.weight", + "attn.to_v.weight": "attn.w2v.weight", + "attn.to_out.0.weight": "attn.w2o.weight", + "attn.add_q_proj.weight": "attn.w1q.weight", + "attn.add_k_proj.weight": "attn.w1k.weight", + "attn.add_v_proj.weight": "attn.w1v.weight", + "attn.to_add_out.weight": "attn.w1o.weight", + "ff.linear_1.weight": "mlpX.c_fc1.weight", + "ff.linear_2.weight": "mlpX.c_fc2.weight", + "ff.out_projection.weight": "mlpX.c_proj.weight", + "ff_context.linear_1.weight": "mlpC.c_fc1.weight", + "ff_context.linear_2.weight": "mlpC.c_fc2.weight", + "ff_context.out_projection.weight": "mlpC.c_proj.weight", + "norm1.linear.weight": "modX.1.weight", + "norm1_context.linear.weight": "modC.1.weight", + } + else: + index = i - n_double_layers + prefix_from = "single_transformer_blocks" + prefix_to = "{}single_layers".format(output_prefix) + + block_map = { + "attn.to_q.weight": "attn.w1q.weight", + "attn.to_k.weight": "attn.w1k.weight", + "attn.to_v.weight": "attn.w1v.weight", + "attn.to_out.0.weight": "attn.w1o.weight", + "norm1.linear.weight": "modCX.1.weight", + "ff.linear_1.weight": "mlp.c_fc1.weight", + "ff.linear_2.weight": "mlp.c_fc2.weight", + "ff.out_projection.weight": "mlp.c_proj.weight" + } + + for k in block_map: + key_map["{}.{}.{}".format(prefix_from, index, k)] = "{}.{}.{}".format(prefix_to, index, block_map[k]) + + MAP_BASIC = { + ("positional_encoding", "pos_embed.pos_embed"), + ("register_tokens", "register_tokens"), + ("t_embedder.mlp.0.weight", "time_step_proj.linear_1.weight"), + ("t_embedder.mlp.0.bias", "time_step_proj.linear_1.bias"), + ("t_embedder.mlp.2.weight", "time_step_proj.linear_2.weight"), + ("t_embedder.mlp.2.bias", "time_step_proj.linear_2.bias"), + ("cond_seq_linear.weight", "context_embedder.weight"), + ("init_x_linear.weight", "pos_embed.proj.weight"), + ("init_x_linear.bias", "pos_embed.proj.bias"), + ("final_linear.weight", "proj_out.weight"), + ("modF.1.weight", "norm_out.linear.weight", swap_scale_shift), + } + + for k in MAP_BASIC: + if len(k) > 2: + key_map[k[1]] = ("{}{}".format(output_prefix, k[0]), None, k[2]) + else: + key_map[k[1]] = "{}{}".format(output_prefix, k[0]) + + return key_map + +def flux_to_diffusers(mmdit_config, output_prefix=""): + n_double_layers = mmdit_config.get("depth", 0) + n_single_layers = mmdit_config.get("depth_single_blocks", 0) + hidden_size = mmdit_config.get("hidden_size", 0) + + key_map = {} + for index in range(n_double_layers): + prefix_from = "transformer_blocks.{}".format(index) + prefix_to = "{}double_blocks.{}".format(output_prefix, index) + + for end in ("weight", "bias"): + k = "{}.attn.".format(prefix_from) + qkv = "{}.img_attn.qkv.{}".format(prefix_to, end) + key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, hidden_size)) + key_map["{}to_k.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size)) + key_map["{}to_v.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size)) + + k = "{}.attn.".format(prefix_from) + qkv = "{}.txt_attn.qkv.{}".format(prefix_to, end) + key_map["{}add_q_proj.{}".format(k, end)] = (qkv, (0, 0, hidden_size)) + key_map["{}add_k_proj.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size)) + key_map["{}add_v_proj.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size)) + + block_map = { + "attn.to_out.0.weight": "img_attn.proj.weight", + "attn.to_out.0.bias": "img_attn.proj.bias", + "norm1.linear.weight": "img_mod.lin.weight", + "norm1.linear.bias": "img_mod.lin.bias", + "norm1_context.linear.weight": "txt_mod.lin.weight", + "norm1_context.linear.bias": "txt_mod.lin.bias", + "attn.to_add_out.weight": "txt_attn.proj.weight", + "attn.to_add_out.bias": "txt_attn.proj.bias", + "ff.net.0.proj.weight": "img_mlp.0.weight", + "ff.net.0.proj.bias": "img_mlp.0.bias", + "ff.net.2.weight": "img_mlp.2.weight", + "ff.net.2.bias": "img_mlp.2.bias", + "ff_context.net.0.proj.weight": "txt_mlp.0.weight", + "ff_context.net.0.proj.bias": "txt_mlp.0.bias", + "ff_context.net.2.weight": "txt_mlp.2.weight", + "ff_context.net.2.bias": "txt_mlp.2.bias", + "attn.norm_q.weight": "img_attn.norm.query_norm.scale", + "attn.norm_k.weight": "img_attn.norm.key_norm.scale", + "attn.norm_added_q.weight": "txt_attn.norm.query_norm.scale", + "attn.norm_added_k.weight": "txt_attn.norm.key_norm.scale", + } + + for k in block_map: + key_map["{}.{}".format(prefix_from, k)] = "{}.{}".format(prefix_to, block_map[k]) + + for index in range(n_single_layers): + prefix_from = "single_transformer_blocks.{}".format(index) + prefix_to = "{}single_blocks.{}".format(output_prefix, index) + + for end in ("weight", "bias"): + k = "{}.attn.".format(prefix_from) + qkv = "{}.linear1.{}".format(prefix_to, end) + key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, hidden_size)) + key_map["{}to_k.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size)) + key_map["{}to_v.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size)) + key_map["{}.proj_mlp.{}".format(prefix_from, end)] = (qkv, (0, hidden_size * 3, hidden_size * 4)) + + block_map = { + "norm.linear.weight": "modulation.lin.weight", + "norm.linear.bias": "modulation.lin.bias", + "proj_out.weight": "linear2.weight", + "proj_out.bias": "linear2.bias", + "attn.norm_q.weight": "norm.query_norm.scale", + "attn.norm_k.weight": "norm.key_norm.scale", + } + + for k in block_map: + key_map["{}.{}".format(prefix_from, k)] = "{}.{}".format(prefix_to, block_map[k]) + + MAP_BASIC = { + ("final_layer.linear.bias", "proj_out.bias"), + ("final_layer.linear.weight", "proj_out.weight"), + ("img_in.bias", "x_embedder.bias"), + ("img_in.weight", "x_embedder.weight"), + ("time_in.in_layer.bias", "time_text_embed.timestep_embedder.linear_1.bias"), + ("time_in.in_layer.weight", "time_text_embed.timestep_embedder.linear_1.weight"), + ("time_in.out_layer.bias", "time_text_embed.timestep_embedder.linear_2.bias"), + ("time_in.out_layer.weight", "time_text_embed.timestep_embedder.linear_2.weight"), + ("txt_in.bias", "context_embedder.bias"), + ("txt_in.weight", "context_embedder.weight"), + ("vector_in.in_layer.bias", "time_text_embed.text_embedder.linear_1.bias"), + ("vector_in.in_layer.weight", "time_text_embed.text_embedder.linear_1.weight"), + ("vector_in.out_layer.bias", "time_text_embed.text_embedder.linear_2.bias"), + ("vector_in.out_layer.weight", "time_text_embed.text_embedder.linear_2.weight"), + ("guidance_in.in_layer.bias", "time_text_embed.guidance_embedder.linear_1.bias"), + ("guidance_in.in_layer.weight", "time_text_embed.guidance_embedder.linear_1.weight"), + ("guidance_in.out_layer.bias", "time_text_embed.guidance_embedder.linear_2.bias"), + ("guidance_in.out_layer.weight", "time_text_embed.guidance_embedder.linear_2.weight"), + ("final_layer.adaLN_modulation.1.bias", "norm_out.linear.bias", swap_scale_shift), + ("final_layer.adaLN_modulation.1.weight", "norm_out.linear.weight", swap_scale_shift), + } + + for k in MAP_BASIC: + if len(k) > 2: + key_map[k[1]] = ("{}{}".format(output_prefix, k[0]), None, k[2]) + else: + key_map[k[1]] = "{}{}".format(output_prefix, k[0]) + + return key_map + +def repeat_to_batch_size(tensor, batch_size, dim=0): + if tensor.shape[dim] > batch_size: + return tensor.narrow(dim, 0, batch_size) + elif tensor.shape[dim] < batch_size: + return tensor.repeat(dim * [1] + [math.ceil(batch_size / tensor.shape[dim])] + [1] * (len(tensor.shape) - 1 - dim)).narrow(dim, 0, batch_size) + return tensor + +def resize_to_batch_size(tensor, batch_size): + in_batch_size = tensor.shape[0] + if in_batch_size == batch_size: + return tensor + + if batch_size <= 1: + return tensor[:batch_size] + + output = torch.empty([batch_size] + list(tensor.shape)[1:], dtype=tensor.dtype, device=tensor.device) + if batch_size < in_batch_size: + scale = (in_batch_size - 1) / (batch_size - 1) + for i in range(batch_size): + output[i] = tensor[min(round(i * scale), in_batch_size - 1)] + else: + scale = in_batch_size / batch_size + for i in range(batch_size): + output[i] = tensor[min(math.floor((i + 0.5) * scale), in_batch_size - 1)] + + return output + +def convert_sd_to(state_dict, dtype): + keys = list(state_dict.keys()) + for k in keys: + state_dict[k] = state_dict[k].to(dtype) + return state_dict + +def safetensors_header(safetensors_path, max_size=100*1024*1024): + with open(safetensors_path, "rb") as f: + header = f.read(8) + length_of_header = struct.unpack(' max_size: + return None + return f.read(length_of_header) + +def set_attr(obj, attr, value): + attrs = attr.split(".") + for name in attrs[:-1]: + obj = getattr(obj, name) + prev = getattr(obj, attrs[-1]) + setattr(obj, attrs[-1], value) + return prev + +def set_attr_param(obj, attr, value): + return set_attr(obj, attr, torch.nn.Parameter(value, requires_grad=False)) + +def copy_to_param(obj, attr, value): + # inplace update tensor instead of replacing it + attrs = attr.split(".") + for name in attrs[:-1]: + obj = getattr(obj, name) + prev = getattr(obj, attrs[-1]) + prev.data.copy_(value) + +def get_attr(obj, attr): + attrs = attr.split(".") + for name in attrs: + obj = getattr(obj, name) + return obj + +def bislerp(samples, width, height): + def slerp(b1, b2, r): + '''slerps batches b1, b2 according to ratio r, batches should be flat e.g. NxC''' + + c = b1.shape[-1] + + #norms + b1_norms = torch.norm(b1, dim=-1, keepdim=True) + b2_norms = torch.norm(b2, dim=-1, keepdim=True) + + #normalize + b1_normalized = b1 / b1_norms + b2_normalized = b2 / b2_norms + + #zero when norms are zero + b1_normalized[b1_norms.expand(-1,c) == 0.0] = 0.0 + b2_normalized[b2_norms.expand(-1,c) == 0.0] = 0.0 + + #slerp + dot = (b1_normalized*b2_normalized).sum(1) + omega = torch.acos(dot) + so = torch.sin(omega) + + #technically not mathematically correct, but more pleasing? + res = (torch.sin((1.0-r.squeeze(1))*omega)/so).unsqueeze(1)*b1_normalized + (torch.sin(r.squeeze(1)*omega)/so).unsqueeze(1) * b2_normalized + res *= (b1_norms * (1.0-r) + b2_norms * r).expand(-1,c) + + #edge cases for same or polar opposites + res[dot > 1 - 1e-5] = b1[dot > 1 - 1e-5] + res[dot < 1e-5 - 1] = (b1 * (1.0-r) + b2 * r)[dot < 1e-5 - 1] + return res + + def generate_bilinear_data(length_old, length_new, device): + coords_1 = torch.arange(length_old, dtype=torch.float32, device=device).reshape((1,1,1,-1)) + coords_1 = torch.nn.functional.interpolate(coords_1, size=(1, length_new), mode="bilinear") + ratios = coords_1 - coords_1.floor() + coords_1 = coords_1.to(torch.int64) + + coords_2 = torch.arange(length_old, dtype=torch.float32, device=device).reshape((1,1,1,-1)) + 1 + coords_2[:,:,:,-1] -= 1 + coords_2 = torch.nn.functional.interpolate(coords_2, size=(1, length_new), mode="bilinear") + coords_2 = coords_2.to(torch.int64) + return ratios, coords_1, coords_2 + + orig_dtype = samples.dtype + samples = samples.float() + n,c,h,w = samples.shape + h_new, w_new = (height, width) + + #linear w + ratios, coords_1, coords_2 = generate_bilinear_data(w, w_new, samples.device) + coords_1 = coords_1.expand((n, c, h, -1)) + coords_2 = coords_2.expand((n, c, h, -1)) + ratios = ratios.expand((n, 1, h, -1)) + + pass_1 = samples.gather(-1,coords_1).movedim(1, -1).reshape((-1,c)) + pass_2 = samples.gather(-1,coords_2).movedim(1, -1).reshape((-1,c)) + ratios = ratios.movedim(1, -1).reshape((-1,1)) + + result = slerp(pass_1, pass_2, ratios) + result = result.reshape(n, h, w_new, c).movedim(-1, 1) + + #linear h + ratios, coords_1, coords_2 = generate_bilinear_data(h, h_new, samples.device) + coords_1 = coords_1.reshape((1,1,-1,1)).expand((n, c, -1, w_new)) + coords_2 = coords_2.reshape((1,1,-1,1)).expand((n, c, -1, w_new)) + ratios = ratios.reshape((1,1,-1,1)).expand((n, 1, -1, w_new)) + + pass_1 = result.gather(-2,coords_1).movedim(1, -1).reshape((-1,c)) + pass_2 = result.gather(-2,coords_2).movedim(1, -1).reshape((-1,c)) + ratios = ratios.movedim(1, -1).reshape((-1,1)) + + result = slerp(pass_1, pass_2, ratios) + result = result.reshape(n, h_new, w_new, c).movedim(-1, 1) + return result.to(orig_dtype) + +def lanczos(samples, width, height): + images = [Image.fromarray(np.clip(255. * image.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8)) for image in samples] + images = [image.resize((width, height), resample=Image.Resampling.LANCZOS) for image in images] + images = [torch.from_numpy(np.array(image).astype(np.float32) / 255.0).movedim(-1, 0) for image in images] + result = torch.stack(images) + return result.to(samples.device, samples.dtype) + +def common_upscale(samples, width, height, upscale_method, crop): + if crop == "center": + old_width = samples.shape[3] + old_height = samples.shape[2] + old_aspect = old_width / old_height + new_aspect = width / height + x = 0 + y = 0 + if old_aspect > new_aspect: + x = round((old_width - old_width * (new_aspect / old_aspect)) / 2) + elif old_aspect < new_aspect: + y = round((old_height - old_height * (old_aspect / new_aspect)) / 2) + s = samples[:,:,y:old_height-y,x:old_width-x] + else: + s = samples + + if upscale_method == "bislerp": + return bislerp(s, width, height) + elif upscale_method == "lanczos": + return lanczos(s, width, height) + else: + return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method) + +def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap): + return math.ceil((height / (tile_y - overlap))) * math.ceil((width / (tile_x - overlap))) + +@torch.inference_mode() +def tiled_scale_multidim(samples, function, tile=(64, 64), overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None): + dims = len(tile) + output = torch.empty([samples.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), samples.shape[2:])), device=output_device) + + for b in range(samples.shape[0]): + s = samples[b:b+1] + out = torch.zeros([s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])), device=output_device) + out_div = torch.zeros([s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])), device=output_device) + + for it in itertools.product(*map(lambda a: range(0, a[0], a[1] - overlap), zip(s.shape[2:], tile))): + s_in = s + upscaled = [] + + for d in range(dims): + pos = max(0, min(s.shape[d + 2] - overlap, it[d])) + l = min(tile[d], s.shape[d + 2] - pos) + s_in = s_in.narrow(d + 2, pos, l) + upscaled.append(round(pos * upscale_amount)) + ps = function(s_in).to(output_device) + mask = torch.ones_like(ps) + feather = round(overlap * upscale_amount) + for t in range(feather): + for d in range(2, dims + 2): + m = mask.narrow(d, t, 1) + m *= ((1.0/feather) * (t + 1)) + m = mask.narrow(d, mask.shape[d] -1 -t, 1) + m *= ((1.0/feather) * (t + 1)) + + o = out + o_d = out_div + for d in range(dims): + o = o.narrow(d + 2, upscaled[d], mask.shape[d + 2]) + o_d = o_d.narrow(d + 2, upscaled[d], mask.shape[d + 2]) + + o += ps * mask + o_d += mask + + if pbar is not None: + pbar.update(1) + + output[b:b+1] = out/out_div + return output + +def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None): + return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap, upscale_amount, out_channels, output_device, pbar) + +PROGRESS_BAR_ENABLED = True +def set_progress_bar_enabled(enabled): + global PROGRESS_BAR_ENABLED + PROGRESS_BAR_ENABLED = enabled + +PROGRESS_BAR_HOOK = None +def set_progress_bar_global_hook(function): + global PROGRESS_BAR_HOOK + PROGRESS_BAR_HOOK = function + +class ProgressBar: + def __init__(self, total): + global PROGRESS_BAR_HOOK + self.total = total + self.current = 0 + self.hook = PROGRESS_BAR_HOOK + + def update_absolute(self, value, total=None, preview=None): + if total is not None: + self.total = total + if value > self.total: + value = self.total + self.current = value + if self.hook is not None: + self.hook(self.current, self.total, preview) + + def update(self, value): + self.update_absolute(self.current + value) diff --git a/packages_3rdparty/webui_lora_collection/lora.py b/packages_3rdparty/webui_lora_collection/lora.py new file mode 100644 index 00000000..e95db050 --- /dev/null +++ b/packages_3rdparty/webui_lora_collection/lora.py @@ -0,0 +1,2 @@ +# TODO: Implement API + diff --git a/packages_3rdparty/webui_lora_collection/lyco_helpers.py b/packages_3rdparty/webui_lora_collection/lyco_helpers.py new file mode 100644 index 00000000..3d4efd7e --- /dev/null +++ b/packages_3rdparty/webui_lora_collection/lyco_helpers.py @@ -0,0 +1,68 @@ +import torch + + +def make_weight_cp(t, wa, wb): + temp = torch.einsum('i j k l, j r -> i r k l', t, wb) + return torch.einsum('i j k l, i r -> r j k l', temp, wa) + + +def rebuild_conventional(up, down, shape, dyn_dim=None): + up = up.reshape(up.size(0), -1) + down = down.reshape(down.size(0), -1) + if dyn_dim is not None: + up = up[:, :dyn_dim] + down = down[:dyn_dim, :] + return (up @ down).reshape(shape) + + +def rebuild_cp_decomposition(up, down, mid): + up = up.reshape(up.size(0), -1) + down = down.reshape(down.size(0), -1) + return torch.einsum('n m k l, i n, m j -> i j k l', mid, up, down) + + +# copied from https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/lokr.py +def factorization(dimension: int, factor:int=-1) -> tuple[int, int]: + ''' + return a tuple of two value of input dimension decomposed by the number closest to factor + second value is higher or equal than first value. + + In LoRA with Kroneckor Product, first value is a value for weight scale. + secon value is a value for weight. + + Because of non-commutative property, A⊗B ≠ B⊗A. Meaning of two matrices is slightly different. + + examples) + factor + -1 2 4 8 16 ... + 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 + 128 -> 8, 16 128 -> 2, 64 128 -> 4, 32 128 -> 8, 16 128 -> 8, 16 + 250 -> 10, 25 250 -> 2, 125 250 -> 2, 125 250 -> 5, 50 250 -> 10, 25 + 360 -> 8, 45 360 -> 2, 180 360 -> 4, 90 360 -> 8, 45 360 -> 12, 30 + 512 -> 16, 32 512 -> 2, 256 512 -> 4, 128 512 -> 8, 64 512 -> 16, 32 + 1024 -> 32, 32 1024 -> 2, 512 1024 -> 4, 256 1024 -> 8, 128 1024 -> 16, 64 + ''' + + if factor > 0 and (dimension % factor) == 0: + m = factor + n = dimension // factor + if m > n: + n, m = m, n + return m, n + if factor < 0: + factor = dimension + m, n = 1, dimension + length = m + n + while m length or new_m>factor: + break + else: + m, n = new_m, new_n + if m > n: + n, m = m, n + return m, n + diff --git a/packages_3rdparty/webui_lora_collection/network.py b/packages_3rdparty/webui_lora_collection/network.py new file mode 100644 index 00000000..89987438 --- /dev/null +++ b/packages_3rdparty/webui_lora_collection/network.py @@ -0,0 +1,228 @@ +from __future__ import annotations +import os +from collections import namedtuple +import enum + +import torch.nn as nn +import torch.nn.functional as F + +from modules import sd_models, cache, errors, hashes, shared +import modules.models.sd3.mmdit + +NetworkWeights = namedtuple('NetworkWeights', ['network_key', 'sd_key', 'w', 'sd_module']) + +metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20} + + +class SdVersion(enum.Enum): + Unknown = 1 + SD1 = 2 + SD2 = 3 + SDXL = 4 + + +class NetworkOnDisk: + def __init__(self, name, filename): + self.name = name + self.filename = filename + self.metadata = {} + self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors" + + def read_metadata(): + metadata = sd_models.read_metadata_from_safetensors(filename) + + return metadata + + if self.is_safetensors: + try: + self.metadata = cache.cached_data_for_file('safetensors-metadata', "lora/" + self.name, filename, read_metadata) + except Exception as e: + errors.display(e, f"reading lora {filename}") + + if self.metadata: + m = {} + for k, v in sorted(self.metadata.items(), key=lambda x: metadata_tags_order.get(x[0], 999)): + m[k] = v + + self.metadata = m + + self.alias = self.metadata.get('ss_output_name', self.name) + + self.hash = None + self.shorthash = None + self.set_hash( + self.metadata.get('sshs_model_hash') or + hashes.sha256_from_cache(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or + '' + ) + + self.sd_version = self.detect_version() + + def detect_version(self): + if str(self.metadata.get('ss_base_model_version', "")).startswith("sdxl_"): + return SdVersion.SDXL + elif str(self.metadata.get('ss_v2', "")) == "True": + return SdVersion.SD2 + elif len(self.metadata): + return SdVersion.SD1 + + return SdVersion.Unknown + + def set_hash(self, v): + self.hash = v + self.shorthash = self.hash[0:12] + + if self.shorthash: + import networks + networks.available_network_hash_lookup[self.shorthash] = self + + def read_hash(self): + if not self.hash: + self.set_hash(hashes.sha256(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or '') + + def get_alias(self): + import networks + if shared.opts.lora_preferred_name == "Filename" or self.alias.lower() in networks.forbidden_network_aliases: + return self.name + else: + return self.alias + + +class Network: # LoraModule + def __init__(self, name, network_on_disk: NetworkOnDisk): + self.name = name + self.network_on_disk = network_on_disk + self.te_multiplier = 1.0 + self.unet_multiplier = 1.0 + self.dyn_dim = None + self.modules = {} + self.bundle_embeddings = {} + self.mtime = None + + self.mentioned_name = None + """the text that was used to add the network to prompt - can be either name or an alias""" + + +class ModuleType: + def create_module(self, net: Network, weights: NetworkWeights) -> Network | None: + return None + + +class NetworkModule: + def __init__(self, net: Network, weights: NetworkWeights): + self.network = net + self.network_key = weights.network_key + self.sd_key = weights.sd_key + self.sd_module = weights.sd_module + + if isinstance(self.sd_module, modules.models.sd3.mmdit.QkvLinear): + s = self.sd_module.weight.shape + self.shape = (s[0] // 3, s[1]) + elif hasattr(self.sd_module, 'weight'): + self.shape = self.sd_module.weight.shape + elif isinstance(self.sd_module, nn.MultiheadAttention): + # For now, only self-attn use Pytorch's MHA + # So assume all qkvo proj have same shape + self.shape = self.sd_module.out_proj.weight.shape + else: + self.shape = None + + self.ops = None + self.extra_kwargs = {} + if isinstance(self.sd_module, nn.Conv2d): + self.ops = F.conv2d + self.extra_kwargs = { + 'stride': self.sd_module.stride, + 'padding': self.sd_module.padding + } + elif isinstance(self.sd_module, nn.Linear): + self.ops = F.linear + elif isinstance(self.sd_module, nn.LayerNorm): + self.ops = F.layer_norm + self.extra_kwargs = { + 'normalized_shape': self.sd_module.normalized_shape, + 'eps': self.sd_module.eps + } + elif isinstance(self.sd_module, nn.GroupNorm): + self.ops = F.group_norm + self.extra_kwargs = { + 'num_groups': self.sd_module.num_groups, + 'eps': self.sd_module.eps + } + + self.dim = None + self.bias = weights.w.get("bias") + self.alpha = weights.w["alpha"].item() if "alpha" in weights.w else None + self.scale = weights.w["scale"].item() if "scale" in weights.w else None + + self.dora_scale = weights.w.get("dora_scale", None) + self.dora_norm_dims = len(self.shape) - 1 + + def multiplier(self): + if 'transformer' in self.sd_key[:20]: + return self.network.te_multiplier + else: + return self.network.unet_multiplier + + def calc_scale(self): + if self.scale is not None: + return self.scale + if self.dim is not None and self.alpha is not None: + return self.alpha / self.dim + + return 1.0 + + def apply_weight_decompose(self, updown, orig_weight): + # Match the device/dtype + orig_weight = orig_weight.to(updown.dtype) + dora_scale = self.dora_scale.to(device=orig_weight.device, dtype=updown.dtype) + updown = updown.to(orig_weight.device) + + merged_scale1 = updown + orig_weight + merged_scale1_norm = ( + merged_scale1.transpose(0, 1) + .reshape(merged_scale1.shape[1], -1) + .norm(dim=1, keepdim=True) + .reshape(merged_scale1.shape[1], *[1] * self.dora_norm_dims) + .transpose(0, 1) + ) + + dora_merged = ( + merged_scale1 * (dora_scale / merged_scale1_norm) + ) + final_updown = dora_merged - orig_weight + return final_updown + + def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None): + if self.bias is not None: + updown = updown.reshape(self.bias.shape) + updown += self.bias.to(orig_weight.device, dtype=updown.dtype) + updown = updown.reshape(output_shape) + + if len(output_shape) == 4: + updown = updown.reshape(output_shape) + + if orig_weight.size().numel() == updown.size().numel(): + updown = updown.reshape(orig_weight.shape) + + if ex_bias is not None: + ex_bias = ex_bias * self.multiplier() + + updown = updown * self.calc_scale() + + if self.dora_scale is not None: + updown = self.apply_weight_decompose(updown, orig_weight) + + return updown * self.multiplier(), ex_bias + + def calc_updown(self, target): + raise NotImplementedError() + + def forward(self, x, y): + """A general forward implementation for all modules""" + if self.ops is None: + raise NotImplementedError() + else: + updown, ex_bias = self.calc_updown(self.sd_module.weight) + return y + self.ops(x, weight=updown, bias=ex_bias, **self.extra_kwargs) + diff --git a/packages_3rdparty/webui_lora_collection/network_full.py b/packages_3rdparty/webui_lora_collection/network_full.py new file mode 100644 index 00000000..cf5fbbb2 --- /dev/null +++ b/packages_3rdparty/webui_lora_collection/network_full.py @@ -0,0 +1,27 @@ +import network + + +class ModuleTypeFull(network.ModuleType): + def create_module(self, net: network.Network, weights: network.NetworkWeights): + if all(x in weights.w for x in ["diff"]): + return NetworkModuleFull(net, weights) + + return None + + +class NetworkModuleFull(network.NetworkModule): + def __init__(self, net: network.Network, weights: network.NetworkWeights): + super().__init__(net, weights) + + self.weight = weights.w.get("diff") + self.ex_bias = weights.w.get("diff_b") + + def calc_updown(self, orig_weight): + output_shape = self.weight.shape + updown = self.weight.to(orig_weight.device) + if self.ex_bias is not None: + ex_bias = self.ex_bias.to(orig_weight.device) + else: + ex_bias = None + + return self.finalize_updown(updown, orig_weight, output_shape, ex_bias) diff --git a/packages_3rdparty/webui_lora_collection/network_glora.py b/packages_3rdparty/webui_lora_collection/network_glora.py new file mode 100644 index 00000000..efe5c681 --- /dev/null +++ b/packages_3rdparty/webui_lora_collection/network_glora.py @@ -0,0 +1,33 @@ + +import network + +class ModuleTypeGLora(network.ModuleType): + def create_module(self, net: network.Network, weights: network.NetworkWeights): + if all(x in weights.w for x in ["a1.weight", "a2.weight", "alpha", "b1.weight", "b2.weight"]): + return NetworkModuleGLora(net, weights) + + return None + +# adapted from https://github.com/KohakuBlueleaf/LyCORIS +class NetworkModuleGLora(network.NetworkModule): + def __init__(self, net: network.Network, weights: network.NetworkWeights): + super().__init__(net, weights) + + if hasattr(self.sd_module, 'weight'): + self.shape = self.sd_module.weight.shape + + self.w1a = weights.w["a1.weight"] + self.w1b = weights.w["b1.weight"] + self.w2a = weights.w["a2.weight"] + self.w2b = weights.w["b2.weight"] + + def calc_updown(self, orig_weight): + w1a = self.w1a.to(orig_weight.device) + w1b = self.w1b.to(orig_weight.device) + w2a = self.w2a.to(orig_weight.device) + w2b = self.w2b.to(orig_weight.device) + + output_shape = [w1a.size(0), w1b.size(1)] + updown = ((w2b @ w1b) + ((orig_weight.to(dtype = w1a.dtype) @ w2a) @ w1a)) + + return self.finalize_updown(updown, orig_weight, output_shape) diff --git a/packages_3rdparty/webui_lora_collection/network_hada.py b/packages_3rdparty/webui_lora_collection/network_hada.py new file mode 100644 index 00000000..d179b29e --- /dev/null +++ b/packages_3rdparty/webui_lora_collection/network_hada.py @@ -0,0 +1,55 @@ +import lyco_helpers +import network + + +class ModuleTypeHada(network.ModuleType): + def create_module(self, net: network.Network, weights: network.NetworkWeights): + if all(x in weights.w for x in ["hada_w1_a", "hada_w1_b", "hada_w2_a", "hada_w2_b"]): + return NetworkModuleHada(net, weights) + + return None + + +class NetworkModuleHada(network.NetworkModule): + def __init__(self, net: network.Network, weights: network.NetworkWeights): + super().__init__(net, weights) + + if hasattr(self.sd_module, 'weight'): + self.shape = self.sd_module.weight.shape + + self.w1a = weights.w["hada_w1_a"] + self.w1b = weights.w["hada_w1_b"] + self.dim = self.w1b.shape[0] + self.w2a = weights.w["hada_w2_a"] + self.w2b = weights.w["hada_w2_b"] + + self.t1 = weights.w.get("hada_t1") + self.t2 = weights.w.get("hada_t2") + + def calc_updown(self, orig_weight): + w1a = self.w1a.to(orig_weight.device) + w1b = self.w1b.to(orig_weight.device) + w2a = self.w2a.to(orig_weight.device) + w2b = self.w2b.to(orig_weight.device) + + output_shape = [w1a.size(0), w1b.size(1)] + + if self.t1 is not None: + output_shape = [w1a.size(1), w1b.size(1)] + t1 = self.t1.to(orig_weight.device) + updown1 = lyco_helpers.make_weight_cp(t1, w1a, w1b) + output_shape += t1.shape[2:] + else: + if len(w1b.shape) == 4: + output_shape += w1b.shape[2:] + updown1 = lyco_helpers.rebuild_conventional(w1a, w1b, output_shape) + + if self.t2 is not None: + t2 = self.t2.to(orig_weight.device) + updown2 = lyco_helpers.make_weight_cp(t2, w2a, w2b) + else: + updown2 = lyco_helpers.rebuild_conventional(w2a, w2b, output_shape) + + updown = updown1 * updown2 + + return self.finalize_updown(updown, orig_weight, output_shape) diff --git a/packages_3rdparty/webui_lora_collection/network_ia3.py b/packages_3rdparty/webui_lora_collection/network_ia3.py new file mode 100644 index 00000000..549b9a75 --- /dev/null +++ b/packages_3rdparty/webui_lora_collection/network_ia3.py @@ -0,0 +1,30 @@ +import network + + +class ModuleTypeIa3(network.ModuleType): + def create_module(self, net: network.Network, weights: network.NetworkWeights): + if all(x in weights.w for x in ["weight"]): + return NetworkModuleIa3(net, weights) + + return None + + +class NetworkModuleIa3(network.NetworkModule): + def __init__(self, net: network.Network, weights: network.NetworkWeights): + super().__init__(net, weights) + + self.w = weights.w["weight"] + self.on_input = weights.w["on_input"].item() + + def calc_updown(self, orig_weight): + w = self.w.to(orig_weight.device) + + output_shape = [w.size(0), orig_weight.size(1)] + if self.on_input: + output_shape.reverse() + else: + w = w.reshape(-1, 1) + + updown = orig_weight * w + + return self.finalize_updown(updown, orig_weight, output_shape) diff --git a/packages_3rdparty/webui_lora_collection/network_lokr.py b/packages_3rdparty/webui_lora_collection/network_lokr.py new file mode 100644 index 00000000..4f3128e5 --- /dev/null +++ b/packages_3rdparty/webui_lora_collection/network_lokr.py @@ -0,0 +1,64 @@ +import torch + +import lyco_helpers +import network + + +class ModuleTypeLokr(network.ModuleType): + def create_module(self, net: network.Network, weights: network.NetworkWeights): + has_1 = "lokr_w1" in weights.w or ("lokr_w1_a" in weights.w and "lokr_w1_b" in weights.w) + has_2 = "lokr_w2" in weights.w or ("lokr_w2_a" in weights.w and "lokr_w2_b" in weights.w) + if has_1 and has_2: + return NetworkModuleLokr(net, weights) + + return None + + +def make_kron(orig_shape, w1, w2): + if len(w2.shape) == 4: + w1 = w1.unsqueeze(2).unsqueeze(2) + w2 = w2.contiguous() + return torch.kron(w1, w2).reshape(orig_shape) + + +class NetworkModuleLokr(network.NetworkModule): + def __init__(self, net: network.Network, weights: network.NetworkWeights): + super().__init__(net, weights) + + self.w1 = weights.w.get("lokr_w1") + self.w1a = weights.w.get("lokr_w1_a") + self.w1b = weights.w.get("lokr_w1_b") + self.dim = self.w1b.shape[0] if self.w1b is not None else self.dim + self.w2 = weights.w.get("lokr_w2") + self.w2a = weights.w.get("lokr_w2_a") + self.w2b = weights.w.get("lokr_w2_b") + self.dim = self.w2b.shape[0] if self.w2b is not None else self.dim + self.t2 = weights.w.get("lokr_t2") + + def calc_updown(self, orig_weight): + if self.w1 is not None: + w1 = self.w1.to(orig_weight.device) + else: + w1a = self.w1a.to(orig_weight.device) + w1b = self.w1b.to(orig_weight.device) + w1 = w1a @ w1b + + if self.w2 is not None: + w2 = self.w2.to(orig_weight.device) + elif self.t2 is None: + w2a = self.w2a.to(orig_weight.device) + w2b = self.w2b.to(orig_weight.device) + w2 = w2a @ w2b + else: + t2 = self.t2.to(orig_weight.device) + w2a = self.w2a.to(orig_weight.device) + w2b = self.w2b.to(orig_weight.device) + w2 = lyco_helpers.make_weight_cp(t2, w2a, w2b) + + output_shape = [w1.size(0) * w2.size(0), w1.size(1) * w2.size(1)] + if len(orig_weight.shape) == 4: + output_shape = orig_weight.shape + + updown = make_kron(output_shape, w1, w2) + + return self.finalize_updown(updown, orig_weight, output_shape) diff --git a/packages_3rdparty/webui_lora_collection/network_lora.py b/packages_3rdparty/webui_lora_collection/network_lora.py new file mode 100644 index 00000000..8ee26c31 --- /dev/null +++ b/packages_3rdparty/webui_lora_collection/network_lora.py @@ -0,0 +1,94 @@ +import torch + +import lyco_helpers +import modules.models.sd3.mmdit +import network +from modules import devices + + +class ModuleTypeLora(network.ModuleType): + def create_module(self, net: network.Network, weights: network.NetworkWeights): + if all(x in weights.w for x in ["lora_up.weight", "lora_down.weight"]): + return NetworkModuleLora(net, weights) + + if all(x in weights.w for x in ["lora_A.weight", "lora_B.weight"]): + w = weights.w.copy() + weights.w.clear() + weights.w.update({"lora_up.weight": w["lora_B.weight"], "lora_down.weight": w["lora_A.weight"]}) + + return NetworkModuleLora(net, weights) + + return None + + +class NetworkModuleLora(network.NetworkModule): + def __init__(self, net: network.Network, weights: network.NetworkWeights): + super().__init__(net, weights) + + self.up_model = self.create_module(weights.w, "lora_up.weight") + self.down_model = self.create_module(weights.w, "lora_down.weight") + self.mid_model = self.create_module(weights.w, "lora_mid.weight", none_ok=True) + + self.dim = weights.w["lora_down.weight"].shape[0] + + def create_module(self, weights, key, none_ok=False): + weight = weights.get(key) + + if weight is None and none_ok: + return None + + is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, torch.nn.MultiheadAttention, modules.models.sd3.mmdit.QkvLinear] + is_conv = type(self.sd_module) in [torch.nn.Conv2d] + + if is_linear: + weight = weight.reshape(weight.shape[0], -1) + module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) + elif is_conv and key == "lora_down.weight" or key == "dyn_up": + if len(weight.shape) == 2: + weight = weight.reshape(weight.shape[0], -1, 1, 1) + + if weight.shape[2] != 1 or weight.shape[3] != 1: + module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], self.sd_module.kernel_size, self.sd_module.stride, self.sd_module.padding, bias=False) + else: + module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False) + elif is_conv and key == "lora_mid.weight": + module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], self.sd_module.kernel_size, self.sd_module.stride, self.sd_module.padding, bias=False) + elif is_conv and key == "lora_up.weight" or key == "dyn_down": + module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False) + else: + raise AssertionError(f'Lora layer {self.network_key} matched a layer with unsupported type: {type(self.sd_module).__name__}') + + with torch.no_grad(): + if weight.shape != module.weight.shape: + weight = weight.reshape(module.weight.shape) + module.weight.copy_(weight) + + module.to(device=devices.cpu, dtype=devices.dtype) + module.weight.requires_grad_(False) + + return module + + def calc_updown(self, orig_weight): + up = self.up_model.weight.to(orig_weight.device) + down = self.down_model.weight.to(orig_weight.device) + + output_shape = [up.size(0), down.size(1)] + if self.mid_model is not None: + # cp-decomposition + mid = self.mid_model.weight.to(orig_weight.device) + updown = lyco_helpers.rebuild_cp_decomposition(up, down, mid) + output_shape += mid.shape[2:] + else: + if len(down.shape) == 4: + output_shape += down.shape[2:] + updown = lyco_helpers.rebuild_conventional(up, down, output_shape, self.network.dyn_dim) + + return self.finalize_updown(updown, orig_weight, output_shape) + + def forward(self, x, y): + self.up_model.to(device=devices.device) + self.down_model.to(device=devices.device) + + return y + self.up_model(self.down_model(x)) * self.multiplier() * self.calc_scale() + + diff --git a/packages_3rdparty/webui_lora_collection/network_norm.py b/packages_3rdparty/webui_lora_collection/network_norm.py new file mode 100644 index 00000000..d25afcbb --- /dev/null +++ b/packages_3rdparty/webui_lora_collection/network_norm.py @@ -0,0 +1,28 @@ +import network + + +class ModuleTypeNorm(network.ModuleType): + def create_module(self, net: network.Network, weights: network.NetworkWeights): + if all(x in weights.w for x in ["w_norm", "b_norm"]): + return NetworkModuleNorm(net, weights) + + return None + + +class NetworkModuleNorm(network.NetworkModule): + def __init__(self, net: network.Network, weights: network.NetworkWeights): + super().__init__(net, weights) + + self.w_norm = weights.w.get("w_norm") + self.b_norm = weights.w.get("b_norm") + + def calc_updown(self, orig_weight): + output_shape = self.w_norm.shape + updown = self.w_norm.to(orig_weight.device) + + if self.b_norm is not None: + ex_bias = self.b_norm.to(orig_weight.device) + else: + ex_bias = None + + return self.finalize_updown(updown, orig_weight, output_shape, ex_bias) diff --git a/packages_3rdparty/webui_lora_collection/network_oft.py b/packages_3rdparty/webui_lora_collection/network_oft.py new file mode 100644 index 00000000..1c515ebb --- /dev/null +++ b/packages_3rdparty/webui_lora_collection/network_oft.py @@ -0,0 +1,118 @@ +import torch +import network +from einops import rearrange + + +class ModuleTypeOFT(network.ModuleType): + def create_module(self, net: network.Network, weights: network.NetworkWeights): + if all(x in weights.w for x in ["oft_blocks"]) or all(x in weights.w for x in ["oft_diag"]): + return NetworkModuleOFT(net, weights) + + return None + +# Supports both kohya-ss' implementation of COFT https://github.com/kohya-ss/sd-scripts/blob/main/networks/oft.py +# and KohakuBlueleaf's implementation of OFT/COFT https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/diag_oft.py +class NetworkModuleOFT(network.NetworkModule): + def __init__(self, net: network.Network, weights: network.NetworkWeights): + + super().__init__(net, weights) + + self.lin_module = None + self.org_module: list[torch.Module] = [self.sd_module] + + self.scale = 1.0 + self.is_R = False + self.is_boft = False + + # kohya-ss/New LyCORIS OFT/BOFT + if "oft_blocks" in weights.w.keys(): + self.oft_blocks = weights.w["oft_blocks"] # (num_blocks, block_size, block_size) + self.alpha = weights.w.get("alpha", None) # alpha is constraint + self.dim = self.oft_blocks.shape[0] # lora dim + # Old LyCORIS OFT + elif "oft_diag" in weights.w.keys(): + self.is_R = True + self.oft_blocks = weights.w["oft_diag"] + # self.alpha is unused + self.dim = self.oft_blocks.shape[1] # (num_blocks, block_size, block_size) + + is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear] + is_conv = type(self.sd_module) in [torch.nn.Conv2d] + is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention] # unsupported + + if is_linear: + self.out_dim = self.sd_module.out_features + elif is_conv: + self.out_dim = self.sd_module.out_channels + elif is_other_linear: + self.out_dim = self.sd_module.embed_dim + + # LyCORIS BOFT + if self.oft_blocks.dim() == 4: + self.is_boft = True + self.rescale = weights.w.get('rescale', None) + if self.rescale is not None and not is_other_linear: + self.rescale = self.rescale.reshape(-1, *[1]*(self.org_module[0].weight.dim() - 1)) + + self.num_blocks = self.dim + self.block_size = self.out_dim // self.dim + self.constraint = (0 if self.alpha is None else self.alpha) * self.out_dim + if self.is_R: + self.constraint = None + self.block_size = self.dim + self.num_blocks = self.out_dim // self.dim + elif self.is_boft: + self.boft_m = self.oft_blocks.shape[0] + self.num_blocks = self.oft_blocks.shape[1] + self.block_size = self.oft_blocks.shape[2] + self.boft_b = self.block_size + + def calc_updown(self, orig_weight): + oft_blocks = self.oft_blocks.to(orig_weight.device) + eye = torch.eye(self.block_size, device=oft_blocks.device) + + if not self.is_R: + block_Q = oft_blocks - oft_blocks.transpose(-1, -2) # ensure skew-symmetric orthogonal matrix + if self.constraint != 0: + norm_Q = torch.norm(block_Q.flatten()) + new_norm_Q = torch.clamp(norm_Q, max=self.constraint.to(oft_blocks.device)) + block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) + oft_blocks = torch.matmul(eye + block_Q, (eye - block_Q).float().inverse()) + + R = oft_blocks.to(orig_weight.device) + + if not self.is_boft: + # This errors out for MultiheadAttention, might need to be handled up-stream + merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size) + merged_weight = torch.einsum( + 'k n m, k n ... -> k m ...', + R, + merged_weight + ) + merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...') + else: + # TODO: determine correct value for scale + scale = 1.0 + m = self.boft_m + b = self.boft_b + r_b = b // 2 + inp = orig_weight + for i in range(m): + bi = R[i] # b_num, b_size, b_size + if i == 0: + # Apply multiplier/scale and rescale into first weight + bi = bi * scale + (1 - scale) * eye + inp = rearrange(inp, "(c g k) ... -> (c k g) ...", g=2, k=2**i * r_b) + inp = rearrange(inp, "(d b) ... -> d b ...", b=b) + inp = torch.einsum("b i j, b j ... -> b i ...", bi, inp) + inp = rearrange(inp, "d b ... -> (d b) ...") + inp = rearrange(inp, "(c k g) ... -> (c g k) ...", g=2, k=2**i * r_b) + merged_weight = inp + + # Rescale mechanism + if self.rescale is not None: + merged_weight = self.rescale.to(merged_weight) * merged_weight + + updown = merged_weight.to(orig_weight.device) - orig_weight.to(merged_weight.dtype) + output_shape = orig_weight.shape + return self.finalize_updown(updown, orig_weight, output_shape)