Update sd_models.py

This commit is contained in:
lllyasviel
2024-01-25 04:42:39 -08:00
parent 50a3deb39c
commit 854997c163

View File

@@ -599,92 +599,26 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
else:
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
if shared.opts.sd_checkpoint_cache > 0:
# cache newly loaded model
checkpoints_loaded[checkpoint_info] = state_dict.copy()
sd_model = forge_loader.load_model_for_a1111(timer=timer, checkpoint_info=checkpoint_info, state_dict=state_dict)
checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
clip_is_included_into_sd = any(x for x in [sd1_clip_weight, sd2_clip_weight, sdxl_clip_weight, sdxl_refiner_clip_weight] if x in state_dict)
timer.record("find config")
sd_config = OmegaConf.load(checkpoint_config)
repair_config(sd_config)
timer.record("load config")
if hasattr(sd_config.model.params, 'network_config'):
sd_config.model.params.network_config.target = 'modules_forge.forge_loader.FakeObject'
if hasattr(sd_config.model.params, 'unet_config'):
sd_config.model.params.unet_config.target = 'modules_forge.forge_loader.FakeObject'
if hasattr(sd_config.model.params, 'first_stage_config'):
sd_config.model.params.first_stage_config.target = 'modules_forge.forge_loader.FakeObject'
print(f"Creating model from config: {checkpoint_config}")
sd_model = None
try:
with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd or shared.cmd_opts.do_not_download_clip):
with forge_ops.use_patched_ops(manual_cast):
sd_model = instantiate_from_config(sd_config.model)
except Exception as e:
errors.display(e, "creating model quickly", full_traceback=True)
if sd_model is None:
print('Failed to create model quickly; will retry using slow method.', file=sys.stderr)
with forge_ops.use_patched_ops(manual_cast):
sd_model = instantiate_from_config(sd_config.model)
sd_model.used_config = checkpoint_config
timer.record("create model")
state_dict_for_a1111 = {k: v for k, v in state_dict.items() if not k.startswith('model.diffusion_model.') and not k.startswith('first_stage_model.')}
state_dict_for_forge = {k: v for k, v in state_dict.items()}
del state_dict
unet_patcher, vae_patcher = forge_loader.load_unet_and_vae(state_dict_for_forge)
sd_model.first_stage_model = vae_patcher.first_stage_model
sd_model.model.diffusion_model = unet_patcher.model.diffusion_model
sd_model.unet_patcher = unet_patcher
sd_model.model.diffusion_model.patcher = unet_patcher
sd_model.vae_patcher = vae_patcher
sd_model.first_stage_model.patcher = vae_patcher
timer.record("create unet patcher")
del state_dict_for_forge
# clean up cache if limit is reached
while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
checkpoints_loaded.popitem(last=False)
def patched_decode_first_stage(sample):
sample = unet_patcher.model.model_config.latent_format.process_out(sample)
return vae_patcher.decode(sample).movedim(-1, 1) * 2.0 - 1.0
shared.opts.data["sd_checkpoint_hash"] = checkpoint_info.sha256
sd_model.decode_first_stage = patched_decode_first_stage
sd_vae.delete_base_vae()
sd_vae.clear_loaded_vae()
vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename).tuple()
sd_vae.load_vae(sd_model, vae_file, vae_source)
timer.record("load VAE")
load_model_weights(sd_model, checkpoint_info, state_dict_for_a1111, timer)
del state_dict_for_a1111
timer.record("load weights from state dict")
current_clip = sd_model.conditioner if hasattr(sd_model, 'conditioner') else sd_model.cond_stage_model
clip_load_device = model_management.text_encoder_device()
clip_offload_device = model_management.text_encoder_offload_device()
clip_dtype = model_management.text_encoder_dtype()
current_clip.to(clip_dtype)
clip_patcher = ldm_patched.modules.model_patcher.ModelPatcher(
current_clip,
load_device=clip_load_device,
offload_device=clip_offload_device
)
sd_model.clip_patcher = clip_patcher
current_clip.patcher = clip_patcher
timer.record("create clip patcher")
sd_hijack.model_hijack.hijack(sd_model)
timer.record("hijack")
sd_model.eval()
model_data.set_sd_model(sd_model)
model_data.was_loaded_at_least_once = True
@@ -696,7 +630,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
timer.record("scripts callbacks")
with devices.autocast(), torch.no_grad():
with torch.no_grad():
sd_model.cond_stage_model_empty_prompt = get_empty_cond(sd_model)
timer.record("calculate empty prompt")