From 5e5b60b5b151991ca9d30e642dedd2da90032d96 Mon Sep 17 00:00:00 2001 From: lllyasviel Date: Sun, 25 Feb 2024 11:04:14 -0800 Subject: [PATCH] rework lora loading and add logs --- extensions-builtin/Lora/networks.py | 3 +- ldm_patched/modules/lora.py | 6 +-- ldm_patched/modules/sd.py | 61 +++++++++++++++++------------ 3 files changed, 40 insertions(+), 30 deletions(-) diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index 05823b54..83026c80 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -62,7 +62,8 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No for filename, strength_model, strength_clip in compiled_lora_targets: lora_sd = load_lora_state_dict(filename) current_sd.forge_objects.unet, current_sd.forge_objects.clip = load_lora_for_models( - current_sd.forge_objects.unet, current_sd.forge_objects.clip, lora_sd, strength_model, strength_clip) + current_sd.forge_objects.unet, current_sd.forge_objects.clip, lora_sd, strength_model, strength_clip, + filename=filename) current_sd.forge_objects_after_applying_lora = current_sd.forge_objects.shallow_copy() return diff --git a/ldm_patched/modules/lora.py b/ldm_patched/modules/lora.py index 7e78e65d..7f74f119 100644 --- a/ldm_patched/modules/lora.py +++ b/ldm_patched/modules/lora.py @@ -158,10 +158,8 @@ def load_lora(lora, to_load): patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (diff_bias,)) loaded_keys.add(diff_bias_name) - for x in lora.keys(): - if x not in loaded_keys: - print("lora key not loaded", x) - return patch_dict + remaining_dict = {x: y for x, y in lora.items() if x not in loaded_keys} + return patch_dict, remaining_dict def model_lora_keys_clip(model, key_map={}): sdk = model.state_dict().keys() diff --git a/ldm_patched/modules/sd.py b/ldm_patched/modules/sd.py index ed232b5a..8c018cc5 100644 --- a/ldm_patched/modules/sd.py +++ b/ldm_patched/modules/sd.py @@ -56,34 +56,45 @@ def load_clip_weights(model, sd): return load_model_weights(model, sd) -def load_lora_for_models(model, clip, lora, strength_model, strength_clip): - key_map = {} - if model is not None: - key_map = ldm_patched.modules.lora.model_lora_keys_unet(model.model, key_map) - if clip is not None: - key_map = ldm_patched.modules.lora.model_lora_keys_clip(clip.cond_stage_model, key_map) +def load_lora_for_models(model, clip, lora, strength_model, strength_clip, filename='default'): + model_flag = type(model.model).__name__ if model is not None else 'default' - loaded = ldm_patched.modules.lora.load_lora(lora, key_map) - if model is not None: - new_modelpatcher = model.clone() - k = new_modelpatcher.add_patches(loaded, strength_model) - else: - k = () - new_modelpatcher = None + unet_keys = ldm_patched.modules.lora.model_lora_keys_unet(model.model) if model is not None else {} + clip_keys = ldm_patched.modules.lora.model_lora_keys_clip(clip.cond_stage_model) if clip is not None else {} - if clip is not None: - new_clip = clip.clone() - k1 = new_clip.add_patches(loaded, strength_clip) - else: - k1 = () - new_clip = None - k = set(k) - k1 = set(k1) - for x in loaded: - if (x not in k) and (x not in k1): - print("NOT LOADED", x) + lora_unmatch = lora + lora_unet, lora_unmatch = ldm_patched.modules.lora.load_lora(lora_unmatch, unet_keys) + lora_clip, lora_unmatch = ldm_patched.modules.lora.load_lora(lora_unmatch, clip_keys) - return (new_modelpatcher, new_clip) + if len(lora_unmatch) > 12: + print(f'[LORA] LoRA version mismatch for {model_flag}: {filename}') + return model, clip + + if len(lora_unmatch) > 0: + print(f'[LORA] Loading {filename} for {model_flag} with unmatched keys {list(lora_unmatch.keys())}') + + new_model = model.clone() if model is not None else None + new_clip = clip.clone() if model is not None else None + + if new_model is not None and len(lora_unet) > 0: + loaded_keys = new_model.add_patches(lora_unet, strength_model) + skipped_keys = [item for item in lora_unet if item not in loaded_keys] + if len(skipped_keys) > 12: + print(f'[LORA] Mismatch {filename} for {model_flag}-UNet with {len(skipped_keys)} keys mismatched in {len(loaded_keys)} keys') + else: + print(f'[LORA] Loaded {filename} for {model_flag}-UNet with {len(loaded_keys)} keys at weight {strength_model} (skipped {len(skipped_keys)} keys)') + model = new_model + + if new_clip is not None and len(lora_clip) > 0: + loaded_keys = new_clip.add_patches(lora_clip, strength_clip) + skipped_keys = [item for item in lora_clip if item not in loaded_keys] + if len(skipped_keys) > 12: + print(f'[LORA] Mismatch {filename} for {model_flag}-CLIP with {len(skipped_keys)} keys mismatched in {len(loaded_keys)} keys') + else: + print(f'[LORA] Loaded {filename} for {model_flag}-CLIP with {len(loaded_keys)} keys at weight {strength_clip} (skipped {len(skipped_keys)} keys)') + clip = new_clip + + return model, clip class CLIP: