mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-04-26 01:09:10 +00:00
@@ -132,6 +132,12 @@ class CheckpointInfo:
|
||||
|
||||
return self.shorthash
|
||||
|
||||
def __str__(self):
|
||||
return str(dict(filename=self.filename, hash=self.hash))
|
||||
|
||||
def __repr__(self):
|
||||
return str(dict(filename=self.filename, hash=self.hash))
|
||||
|
||||
|
||||
# try:
|
||||
# # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
|
||||
@@ -379,8 +385,8 @@ def apply_alpha_schedule_override(sd_model, p=None):
|
||||
class SdModelData:
|
||||
def __init__(self):
|
||||
self.sd_model = None
|
||||
self.loaded_sd_models = []
|
||||
self.was_loaded_at_least_once = False
|
||||
self.forge_loading_parameters = {}
|
||||
self.forge_hash = ''
|
||||
|
||||
def get_sd_model(self):
|
||||
if self.sd_model is None:
|
||||
@@ -388,12 +394,8 @@ class SdModelData:
|
||||
|
||||
return self.sd_model
|
||||
|
||||
def set_sd_model(self, v, already_loaded=False):
|
||||
def set_sd_model(self, v):
|
||||
self.sd_model = v
|
||||
if already_loaded:
|
||||
sd_vae.base_vae = getattr(v, "base_vae", None)
|
||||
sd_vae.loaded_vae_file = getattr(v, "loaded_vae_file", None)
|
||||
sd_vae.checkpoint_info = v.sd_checkpoint_info
|
||||
|
||||
|
||||
model_data = SdModelData()
|
||||
@@ -461,28 +463,45 @@ def apply_token_merging(sd_model, token_merging_ratio):
|
||||
|
||||
@torch.no_grad()
|
||||
def forge_model_reload():
|
||||
checkpoint_info = select_checkpoint()
|
||||
current_hash = str(model_data.forge_loading_parameters)
|
||||
|
||||
if model_data.forge_hash == current_hash:
|
||||
return model_data.sd_model
|
||||
|
||||
print('Loading Model: ' + str(model_data.forge_loading_parameters))
|
||||
|
||||
timer = Timer()
|
||||
|
||||
if model_data.sd_model:
|
||||
model_data.sd_model = None
|
||||
model_data.loaded_sd_models = []
|
||||
memory_management.unload_all_models()
|
||||
memory_management.soft_empty_cache()
|
||||
gc.collect()
|
||||
|
||||
timer.record("unload existing model")
|
||||
|
||||
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
|
||||
checkpoint_info = model_data.forge_loading_parameters['checkpoint_info']
|
||||
state_dict = load_torch_file(checkpoint_info.filename)
|
||||
timer.record("load state dict")
|
||||
|
||||
state_dict_vae = model_data.forge_loading_parameters.get('vae_filename', None)
|
||||
|
||||
if state_dict_vae is not None:
|
||||
state_dict_vae = load_torch_file(state_dict_vae)
|
||||
|
||||
timer.record("load vae state dict")
|
||||
|
||||
if shared.opts.sd_checkpoint_cache > 0:
|
||||
# cache newly loaded model
|
||||
checkpoints_loaded[checkpoint_info] = state_dict.copy()
|
||||
|
||||
timer.record("cache state dict")
|
||||
|
||||
dynamic_args['forge_unet_storage_dtype'] = model_data.forge_loading_parameters.get('unet_storage_dtype', None)
|
||||
dynamic_args['embedding_dir'] = cmd_opts.embeddings_dir
|
||||
dynamic_args['emphasis_name'] = opts.emphasis
|
||||
sd_model = forge_loader(state_dict)
|
||||
sd_model = forge_loader(state_dict, sd_vae=state_dict_vae)
|
||||
del state_dict
|
||||
timer.record("forge model load")
|
||||
|
||||
sd_model.extra_generation_params = {}
|
||||
@@ -492,22 +511,13 @@ def forge_model_reload():
|
||||
sd_model.sd_model_hash = checkpoint_info.calculate_shorthash()
|
||||
timer.record("calculate hash")
|
||||
|
||||
del state_dict
|
||||
|
||||
# clean up cache if limit is reached
|
||||
while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
|
||||
checkpoints_loaded.popitem(last=False)
|
||||
|
||||
shared.opts.data["sd_checkpoint_hash"] = checkpoint_info.sha256
|
||||
|
||||
sd_vae.delete_base_vae()
|
||||
sd_vae.clear_loaded_vae()
|
||||
vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename).tuple()
|
||||
sd_vae.load_vae(sd_model, vae_file, vae_source)
|
||||
timer.record("load VAE")
|
||||
|
||||
model_data.set_sd_model(sd_model)
|
||||
model_data.was_loaded_at_least_once = True
|
||||
|
||||
script_callbacks.model_loaded_callback(sd_model)
|
||||
|
||||
@@ -515,4 +525,6 @@ def forge_model_reload():
|
||||
|
||||
print(f"Model loaded in {timer.summary()}.")
|
||||
|
||||
model_data.forge_hash = current_hash
|
||||
|
||||
return sd_model
|
||||
|
||||
Reference in New Issue
Block a user