Support T5 & CLIP Text Encoder LoRA from OneTrainer

Requested in #1727.
Also includes some cleanups and license updates.
PS: requests for LoRA support must include a download URL for at least one such LoRA.
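For reference, the diff below maps OneTrainer-style text-encoder LoRA keys (a `lora_te2_` prefix with the module path joined by underscores) onto the backend's T5 weight names (`t5xxl.transformer.<module path>.weight`). A minimal sketch of that mapping; the sample module path is a typical HuggingFace T5 attention weight, illustrative rather than copied from the diff:

```python
# Minimal sketch of the OneTrainer T5 key mapping added in this commit.
# The sample module path below is illustrative, not taken from the diff.
def inner_str(k, prefix="", suffix=""):
    return k[len(prefix):-len(suffix)]

model_key = "t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.q.weight"

formatted = inner_str(model_key, "t5xxl.transformer.", ".weight")  # strip wrapper
formatted = "lora_te2_" + formatted.replace(".", "_")              # OneTrainer style

print(formatted)
# -> lora_te2_encoder_block_0_layer_0_SelfAttention_q
```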
layerdiffusion
2024-09-08 01:39:29 -07:00
parent c3366a7689
commit 44eb4ea837
7 changed files with 1438 additions and 39 deletions


@@ -1,24 +1,13 @@
 import torch
 import time
 import packages_3rdparty.webui_lora_collection.lora as lora_utils_webui
 import packages_3rdparty.comfyui_lora_collection.lora as lora_utils_comfyui
 from tqdm import tqdm
 from backend import memory_management, utils
 from backend.args import dynamic_args

-class ForgeLoraCollection:
-    # TODO
-    pass

 extra_weight_calculators = {}

-lora_utils_forge = ForgeLoraCollection()
-lora_collection_priority = [lora_utils_forge, lora_utils_webui, lora_utils_comfyui]
+lora_collection_priority = [lora_utils_webui, lora_utils_comfyui]

 def get_function(function_name: str):
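The `get_function` helper is truncated in this view; it presumably resolves a function name against `lora_collection_priority` and returns the implementation from the first collection that provides it. A hedged sketch, since the body is cut off here:

```python
def get_function(function_name: str):
    # Hedged sketch: walk the collections in priority order and return
    # the first implementation of the requested function.
    for collection in lora_collection_priority:
        if hasattr(collection, function_name):
            return getattr(collection, function_name)
    raise NotImplementedError(f'{function_name} not found in any lora collection')
```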
@@ -32,12 +21,31 @@ def load_lora(lora, to_load):
     return patch_dict, remaining_dict

+def inner_str(k, prefix="", suffix=""):
+    return k[len(prefix):-len(suffix)]

 def model_lora_keys_clip(model, key_map={}):
-    return get_function('model_lora_keys_clip')(model, key_map)
+    model_keys, key_maps = get_function('model_lora_keys_clip')(model, key_map)
+
+    for model_key in model_keys:
+        if model_key.endswith(".weight"):
+            if model_key.startswith("t5xxl.transformer."):
+                # Flux OneTrainer T5
+                formatted = inner_str(model_key, "t5xxl.transformer.", ".weight")
+                formatted = formatted.replace(".", "_")
+                formatted = f"lora_te2_{formatted}"
+                key_map[formatted] = model_key
+
+    return key_maps

 def model_lora_keys_unet(model, key_map={}):
-    return get_function('model_lora_keys_unet')(model, key_map)
+    model_keys, key_maps = get_function('model_lora_keys_unet')(model, key_map)
+
+    # TODO: OFT
+
+    return key_maps

 @torch.inference_mode()
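Downstream, `load_lora` matches entries of the returned key map against the LoRA file's own keys. A hedged sketch of how one mapped OneTrainer T5 entry would be consumed, assuming the common `lora_up`/`lora_down`/`alpha` naming handled by the bundled webui/comfyui lora collections; all tensors and shapes below are hypothetical (rank-4, T5-XXL-sized):

```python
import torch

# Hypothetical LoRA state dict entries for one mapped T5 projection.
lora_sd = {
    "lora_te2_encoder_block_0_layer_0_SelfAttention_q.lora_up.weight": torch.zeros(4096, 4),
    "lora_te2_encoder_block_0_layer_0_SelfAttention_q.lora_down.weight": torch.zeros(4, 4096),
    "lora_te2_encoder_block_0_layer_0_SelfAttention_q.alpha": torch.tensor(4.0),
}

# key_map as produced by model_lora_keys_clip above.
key_map = {
    "lora_te2_encoder_block_0_layer_0_SelfAttention_q":
        "t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.q.weight",
}

for lora_key, model_key in key_map.items():
    up = lora_sd.get(f"{lora_key}.lora_up.weight")
    down = lora_sd.get(f"{lora_key}.lora_down.weight")
    if up is not None and down is not None:
        # The patcher would attach (up @ down), scaled by alpha / rank,
        # to the model weight named by model_key.
        print(f"patch {model_key}: rank {down.shape[0]}")
```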