only load lora one time

This commit is contained in:
layerdiffusion
2024-08-16 02:02:22 -07:00
parent 243952f364
commit 12369669cf
5 changed files with 162 additions and 140 deletions

View File

@@ -32,28 +32,23 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip, filen
if len(lora_unmatch) > 0:
print(f'[LORA] Loading {filename} for {model_flag} with unmatched keys {list(lora_unmatch.keys())}')
new_model = model.clone() if model is not None else None
new_clip = clip.clone() if clip is not None else None
if new_model is not None and len(lora_unet) > 0:
loaded_keys = new_model.add_patches(lora_unet, strength_model)
if model is not None and len(lora_unet) > 0:
loaded_keys = model.lora_loader.add_patches(lora_unet, strength_model)
skipped_keys = [item for item in lora_unet if item not in loaded_keys]
if len(skipped_keys) > 12:
print(f'[LORA] Mismatch {filename} for {model_flag}-UNet with {len(skipped_keys)} keys mismatched in {len(loaded_keys)} keys')
else:
print(f'[LORA] Loaded {filename} for {model_flag}-UNet with {len(loaded_keys)} keys at weight {strength_model} (skipped {len(skipped_keys)} keys)')
model = new_model
if new_clip is not None and len(lora_clip) > 0:
loaded_keys = new_clip.add_patches(lora_clip, strength_clip)
if clip is not None and len(lora_clip) > 0:
loaded_keys = clip.patcher.lora_loader.add_patches(lora_clip, strength_clip)
skipped_keys = [item for item in lora_clip if item not in loaded_keys]
if len(skipped_keys) > 12:
print(f'[LORA] Mismatch {filename} for {model_flag}-CLIP with {len(skipped_keys)} keys mismatched in {len(loaded_keys)} keys')
else:
print(f'[LORA] Loaded {filename} for {model_flag}-CLIP with {len(loaded_keys)} keys at weight {strength_clip} (skipped {len(skipped_keys)} keys)')
clip = new_clip
return model, clip
return
@functools.lru_cache(maxsize=5)
@@ -112,14 +107,15 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
return
current_sd.current_lora_hash = compiled_lora_targets_hash
current_sd.forge_objects.unet = current_sd.forge_objects_original.unet
current_sd.forge_objects.clip = current_sd.forge_objects_original.clip
current_sd.forge_objects.unet = current_sd.forge_objects_original.unet.clone()
current_sd.forge_objects.clip = current_sd.forge_objects_original.clip.clone()
current_sd.forge_objects.unet.lora_loader.clear_patches()
current_sd.forge_objects.clip.patcher.lora_loader.clear_patches()
for filename, strength_model, strength_clip in compiled_lora_targets:
lora_sd = load_lora_state_dict(filename)
current_sd.forge_objects.unet, current_sd.forge_objects.clip = load_lora_for_models(
current_sd.forge_objects.unet, current_sd.forge_objects.clip, lora_sd, strength_model, strength_clip,
filename=filename)
load_lora_for_models(current_sd.forge_objects.unet, current_sd.forge_objects.clip, lora_sd, strength_model, strength_clip, filename=filename)
current_sd.forge_objects_after_applying_lora = current_sd.forge_objects.shallow_copy()
return