better model load logic

This commit is contained in:
layerdiffusion
2024-08-06 14:34:57 -07:00
parent c1b23bd494
commit 6d789653b9
3 changed files with 4 additions and 20 deletions

View File

@@ -120,10 +120,6 @@ def initialize_rest(*, reload_script_modules=False):
sd_unet.list_unets()
startup_timer.record("scripts list_unets")
from modules_forge import main_thread
import modules.sd_models
main_thread.async_run(modules.sd_models.model_data.get_sd_model)
from modules import shared_items
shared_items.reload_hypernetworks()
startup_timer.record("reload hypernetworks")

View File

@@ -380,25 +380,10 @@ class SdModelData:
self.sd_model = None
self.loaded_sd_models = []
self.was_loaded_at_least_once = False
self.lock = threading.Lock()
def get_sd_model(self):
if self.was_loaded_at_least_once:
return self.sd_model
if self.sd_model is None:
with self.lock:
if self.sd_model is not None or self.was_loaded_at_least_once:
return self.sd_model
try:
load_model()
except Exception as e:
errors.display(e, "loading stable diffusion model", full_traceback=True)
print("", file=sys.stderr)
print("Stable diffusion model failed to load", file=sys.stderr)
self.sd_model = None
raise ValueError('Something went wrong! Model is not loaded yet ...')
return self.sd_model

View File

@@ -66,4 +66,7 @@ def forge_main_entry():
ui_checkpoint.change(lambda x: main_thread.async_run(checkpoint_change, x), inputs=[ui_checkpoint], show_progress=False)
ui_vae.change(lambda x: main_thread.async_run(vae_change, x), inputs=[ui_vae], show_progress=False)
ui_clip_skip.change(lambda x: main_thread.async_run(clip_skip_change, x), inputs=[ui_clip_skip], show_progress=False)
# Load Model
main_thread.async_run(sd_models.load_model)
return