From feac0d7f2d70dc13318ac5a6a70f5668c5d05573 Mon Sep 17 00:00:00 2001 From: lllyasviel Date: Thu, 25 Jan 2024 05:23:18 -0800 Subject: [PATCH] Update forge_loader.py --- modules_forge/forge_loader.py | 38 +++++++++-------------------------- 1 file changed, 10 insertions(+), 28 deletions(-) diff --git a/modules_forge/forge_loader.py b/modules_forge/forge_loader.py index 291a4910..fc124be1 100644 --- a/modules_forge/forge_loader.py +++ b/modules_forge/forge_loader.py @@ -205,36 +205,18 @@ def load_model_for_a1111(timer, checkpoint_info=None, state_dict=None): sd_model.sd_checkpoint_info = checkpoint_info timer.record("forge finalize") + def patched_decode_first_stage(sample): + sample = forge_object.unet.model.model_config.latent_format.process_out(sample) + return forge_object.vae.decode(sample).movedim(-1, 1) * 2.0 - 1.0 + + sd_model.decode_first_stage = patched_decode_first_stage + sd_model.unet_patcher = forge_object.unet sd_model.clip_patcher = forge_object.clip.patcher sd_model.vae_patcher = forge_object.vae.patcher + sd_model.unet_patcher_original = forge_object.unet + sd_model.clip_patcher_original = forge_object.clip.patcher + sd_model.vae_patcher_original = forge_object.vae.patcher + timer.record("get patcher") return sd_model - - -def load_unet_and_vae(sd): - parameters = ldm_patched.modules.utils.calculate_parameters(sd, "model.diffusion_model.") - unet_dtype = model_management.unet_dtype(model_params=parameters) - load_device = model_management.get_torch_device() - manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device) - - model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.", unet_dtype) - model_config.set_manual_cast(manual_cast_dtype) - - if model_config is None: - raise RuntimeError("ERROR: Could not detect model type of") - - initial_load_device = model_management.unet_inital_load_device(parameters, unet_dtype) - model = model_config.get_model(sd, "model.diffusion_model.", device=initial_load_device) - model.load_model_weights(sd, "model.diffusion_model.") - - model_patcher = ldm_patched.modules.model_patcher.ModelPatcher(model, - load_device=load_device, - offload_device=model_management.unet_offload_device(), - current_device=initial_load_device) - - vae_sd = ldm_patched.modules.utils.state_dict_prefix_replace(sd, {"first_stage_model.": ""}, filter_keys=True) - vae_sd = model_config.process_vae_state_dict(vae_sd) - vae_patcher = VAE(sd=vae_sd) - - return model_patcher, vae_patcher