Update networks.py

This commit is contained in:
lllyasviel
2024-01-25 12:05:42 -08:00
parent 65142eb8b1
commit e747c8c564

View File

@@ -8,8 +8,13 @@ import torch
from typing import Union
from modules import shared, sd_models, errors, scripts
from ldm_patched.modules.utils import load_torch_file
from ldm_patched.modules.sd import load_lora_for_models
lora_state_dict_cache = {}
lora_state_dict_cache_max_length = 5
module_types = []
@@ -109,6 +114,8 @@ def purge_networks_from_memory():
def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
global lora_state_dict_cache
current_sd = sd_models.model_data.get_sd_model()
if current_sd is None:
return
@@ -129,7 +136,17 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
current_sd.current_lora_hash = compiled_lora_targets_hash
for filename, strength_model, strength_clip in compiled_lora_targets:
if filename in lora_state_dict_cache:
lora_sd = lora_state_dict_cache[filename]
else:
if len(lora_state_dict_cache) > lora_state_dict_cache_max_length:
lora_state_dict_cache = {}
lora_sd = load_torch_file(filename, safe_load=True)
lora_state_dict_cache[filename] = lora_sd
a = 0
return