rework lora loading

and add logs
This commit is contained in:
lllyasviel
2024-02-25 11:04:14 -08:00
parent 437c348926
commit 5e5b60b5b1
3 changed files with 40 additions and 30 deletions

View File

@@ -62,7 +62,8 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
for filename, strength_model, strength_clip in compiled_lora_targets:
lora_sd = load_lora_state_dict(filename)
current_sd.forge_objects.unet, current_sd.forge_objects.clip = load_lora_for_models(
current_sd.forge_objects.unet, current_sd.forge_objects.clip, lora_sd, strength_model, strength_clip)
current_sd.forge_objects.unet, current_sd.forge_objects.clip, lora_sd, strength_model, strength_clip,
filename=filename)
current_sd.forge_objects_after_applying_lora = current_sd.forge_objects.shallow_copy()
return

View File

@@ -158,10 +158,8 @@ def load_lora(lora, to_load):
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (diff_bias,))
loaded_keys.add(diff_bias_name)
for x in lora.keys():
if x not in loaded_keys:
print("lora key not loaded", x)
return patch_dict
remaining_dict = {x: y for x, y in lora.items() if x not in loaded_keys}
return patch_dict, remaining_dict
def model_lora_keys_clip(model, key_map={}):
sdk = model.state_dict().keys()

View File

@@ -56,34 +56,45 @@ def load_clip_weights(model, sd):
return load_model_weights(model, sd)
def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
key_map = {}
if model is not None:
key_map = ldm_patched.modules.lora.model_lora_keys_unet(model.model, key_map)
if clip is not None:
key_map = ldm_patched.modules.lora.model_lora_keys_clip(clip.cond_stage_model, key_map)
def load_lora_for_models(model, clip, lora, strength_model, strength_clip, filename='default'):
model_flag = type(model.model).__name__ if model is not None else 'default'
loaded = ldm_patched.modules.lora.load_lora(lora, key_map)
if model is not None:
new_modelpatcher = model.clone()
k = new_modelpatcher.add_patches(loaded, strength_model)
else:
k = ()
new_modelpatcher = None
unet_keys = ldm_patched.modules.lora.model_lora_keys_unet(model.model) if model is not None else {}
clip_keys = ldm_patched.modules.lora.model_lora_keys_clip(clip.cond_stage_model) if clip is not None else {}
if clip is not None:
new_clip = clip.clone()
k1 = new_clip.add_patches(loaded, strength_clip)
else:
k1 = ()
new_clip = None
k = set(k)
k1 = set(k1)
for x in loaded:
if (x not in k) and (x not in k1):
print("NOT LOADED", x)
lora_unmatch = lora
lora_unet, lora_unmatch = ldm_patched.modules.lora.load_lora(lora_unmatch, unet_keys)
lora_clip, lora_unmatch = ldm_patched.modules.lora.load_lora(lora_unmatch, clip_keys)
return (new_modelpatcher, new_clip)
if len(lora_unmatch) > 12:
print(f'[LORA] LoRA version mismatch for {model_flag}: {filename}')
return model, clip
if len(lora_unmatch) > 0:
print(f'[LORA] Loading {filename} for {model_flag} with unmatched keys {list(lora_unmatch.keys())}')
new_model = model.clone() if model is not None else None
new_clip = clip.clone() if model is not None else None
if new_model is not None and len(lora_unet) > 0:
loaded_keys = new_model.add_patches(lora_unet, strength_model)
skipped_keys = [item for item in lora_unet if item not in loaded_keys]
if len(skipped_keys) > 12:
print(f'[LORA] Mismatch {filename} for {model_flag}-UNet with {len(skipped_keys)} keys mismatched in {len(loaded_keys)} keys')
else:
print(f'[LORA] Loaded {filename} for {model_flag}-UNet with {len(loaded_keys)} keys at weight {strength_model} (skipped {len(skipped_keys)} keys)')
model = new_model
if new_clip is not None and len(lora_clip) > 0:
loaded_keys = new_clip.add_patches(lora_clip, strength_clip)
skipped_keys = [item for item in lora_clip if item not in loaded_keys]
if len(skipped_keys) > 12:
print(f'[LORA] Mismatch {filename} for {model_flag}-CLIP with {len(skipped_keys)} keys mismatched in {len(loaded_keys)} keys')
else:
print(f'[LORA] Loaded {filename} for {model_flag}-CLIP with {len(loaded_keys)} keys at weight {strength_clip} (skipped {len(skipped_keys)} keys)')
clip = new_clip
return model, clip
class CLIP: