mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-03-13 00:49:48 +00:00
Update sd_models.py
This commit is contained in:
@@ -599,92 +599,26 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
|
||||
else:
|
||||
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
|
||||
|
||||
if shared.opts.sd_checkpoint_cache > 0:
|
||||
# cache newly loaded model
|
||||
checkpoints_loaded[checkpoint_info] = state_dict.copy()
|
||||
|
||||
sd_model = forge_loader.load_model_for_a1111(timer=timer, checkpoint_info=checkpoint_info, state_dict=state_dict)
|
||||
|
||||
checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
|
||||
clip_is_included_into_sd = any(x for x in [sd1_clip_weight, sd2_clip_weight, sdxl_clip_weight, sdxl_refiner_clip_weight] if x in state_dict)
|
||||
|
||||
timer.record("find config")
|
||||
|
||||
sd_config = OmegaConf.load(checkpoint_config)
|
||||
repair_config(sd_config)
|
||||
|
||||
timer.record("load config")
|
||||
|
||||
if hasattr(sd_config.model.params, 'network_config'):
|
||||
sd_config.model.params.network_config.target = 'modules_forge.forge_loader.FakeObject'
|
||||
|
||||
if hasattr(sd_config.model.params, 'unet_config'):
|
||||
sd_config.model.params.unet_config.target = 'modules_forge.forge_loader.FakeObject'
|
||||
|
||||
if hasattr(sd_config.model.params, 'first_stage_config'):
|
||||
sd_config.model.params.first_stage_config.target = 'modules_forge.forge_loader.FakeObject'
|
||||
|
||||
print(f"Creating model from config: {checkpoint_config}")
|
||||
|
||||
sd_model = None
|
||||
try:
|
||||
with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd or shared.cmd_opts.do_not_download_clip):
|
||||
with forge_ops.use_patched_ops(manual_cast):
|
||||
sd_model = instantiate_from_config(sd_config.model)
|
||||
|
||||
except Exception as e:
|
||||
errors.display(e, "creating model quickly", full_traceback=True)
|
||||
|
||||
if sd_model is None:
|
||||
print('Failed to create model quickly; will retry using slow method.', file=sys.stderr)
|
||||
|
||||
with forge_ops.use_patched_ops(manual_cast):
|
||||
sd_model = instantiate_from_config(sd_config.model)
|
||||
|
||||
sd_model.used_config = checkpoint_config
|
||||
|
||||
timer.record("create model")
|
||||
|
||||
state_dict_for_a1111 = {k: v for k, v in state_dict.items() if not k.startswith('model.diffusion_model.') and not k.startswith('first_stage_model.')}
|
||||
state_dict_for_forge = {k: v for k, v in state_dict.items()}
|
||||
del state_dict
|
||||
|
||||
unet_patcher, vae_patcher = forge_loader.load_unet_and_vae(state_dict_for_forge)
|
||||
sd_model.first_stage_model = vae_patcher.first_stage_model
|
||||
sd_model.model.diffusion_model = unet_patcher.model.diffusion_model
|
||||
sd_model.unet_patcher = unet_patcher
|
||||
sd_model.model.diffusion_model.patcher = unet_patcher
|
||||
sd_model.vae_patcher = vae_patcher
|
||||
sd_model.first_stage_model.patcher = vae_patcher
|
||||
timer.record("create unet patcher")
|
||||
del state_dict_for_forge
|
||||
# clean up cache if limit is reached
|
||||
while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
|
||||
checkpoints_loaded.popitem(last=False)
|
||||
|
||||
def patched_decode_first_stage(sample):
|
||||
sample = unet_patcher.model.model_config.latent_format.process_out(sample)
|
||||
return vae_patcher.decode(sample).movedim(-1, 1) * 2.0 - 1.0
|
||||
shared.opts.data["sd_checkpoint_hash"] = checkpoint_info.sha256
|
||||
|
||||
sd_model.decode_first_stage = patched_decode_first_stage
|
||||
sd_vae.delete_base_vae()
|
||||
sd_vae.clear_loaded_vae()
|
||||
vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename).tuple()
|
||||
sd_vae.load_vae(sd_model, vae_file, vae_source)
|
||||
timer.record("load VAE")
|
||||
|
||||
load_model_weights(sd_model, checkpoint_info, state_dict_for_a1111, timer)
|
||||
del state_dict_for_a1111
|
||||
timer.record("load weights from state dict")
|
||||
|
||||
current_clip = sd_model.conditioner if hasattr(sd_model, 'conditioner') else sd_model.cond_stage_model
|
||||
clip_load_device = model_management.text_encoder_device()
|
||||
clip_offload_device = model_management.text_encoder_offload_device()
|
||||
clip_dtype = model_management.text_encoder_dtype()
|
||||
|
||||
current_clip.to(clip_dtype)
|
||||
clip_patcher = ldm_patched.modules.model_patcher.ModelPatcher(
|
||||
current_clip,
|
||||
load_device=clip_load_device,
|
||||
offload_device=clip_offload_device
|
||||
)
|
||||
sd_model.clip_patcher = clip_patcher
|
||||
current_clip.patcher = clip_patcher
|
||||
timer.record("create clip patcher")
|
||||
|
||||
sd_hijack.model_hijack.hijack(sd_model)
|
||||
|
||||
timer.record("hijack")
|
||||
|
||||
sd_model.eval()
|
||||
model_data.set_sd_model(sd_model)
|
||||
model_data.was_loaded_at_least_once = True
|
||||
|
||||
@@ -696,7 +630,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
|
||||
|
||||
timer.record("scripts callbacks")
|
||||
|
||||
with devices.autocast(), torch.no_grad():
|
||||
with torch.no_grad():
|
||||
sd_model.cond_stage_model_empty_prompt = get_empty_cond(sd_model)
|
||||
|
||||
timer.record("calculate empty prompt")
|
||||
|
||||
Reference in New Issue
Block a user