From e747c8c564900049bf593ce1e7a030fbd1d75736 Mon Sep 17 00:00:00 2001 From: lllyasviel Date: Thu, 25 Jan 2024 12:05:42 -0800 Subject: [PATCH] Update networks.py --- extensions-builtin/Lora/networks.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index da4d0269..7eb97980 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -8,8 +8,13 @@ import torch from typing import Union from modules import shared, sd_models, errors, scripts +from ldm_patched.modules.utils import load_torch_file +from ldm_patched.modules.sd import load_lora_for_models +lora_state_dict_cache = {} +lora_state_dict_cache_max_length = 5 + module_types = [] @@ -109,6 +114,8 @@ def purge_networks_from_memory(): def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None): + global lora_state_dict_cache + current_sd = sd_models.model_data.get_sd_model() if current_sd is None: return @@ -129,7 +136,17 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No current_sd.current_lora_hash = compiled_lora_targets_hash + for filename, strength_model, strength_clip in compiled_lora_targets: + if filename in lora_state_dict_cache: + lora_sd = lora_state_dict_cache[filename] + else: + if len(lora_state_dict_cache) > lora_state_dict_cache_max_length: + lora_state_dict_cache = {} + + lora_sd = load_torch_file(filename, safe_load=True) + lora_state_dict_cache[filename] = lora_sd + a = 0 return