From d858e16240ce2414246248bff7581d99e5a0d2bd Mon Sep 17 00:00:00 2001 From: lllyasviel Date: Thu, 25 Jan 2024 03:58:17 -0800 Subject: [PATCH] Update forge_loader.py --- modules_forge/forge_loader.py | 94 ++++++++++++++++++++++++++++++++++- 1 file changed, 93 insertions(+), 1 deletion(-) diff --git a/modules_forge/forge_loader.py b/modules_forge/forge_loader.py index d4b612f0..ccedf8c9 100644 --- a/modules_forge/forge_loader.py +++ b/modules_forge/forge_loader.py @@ -4,12 +4,14 @@ import contextlib from ldm_patched.modules import model_management from ldm_patched.modules import model_detection -from ldm_patched.modules.sd import VAE +from ldm_patched.modules.sd import VAE, CLIP, load_model_weights import ldm_patched.modules.model_patcher import ldm_patched.modules.utils +import ldm_patched.modules.clip_vision from omegaconf import OmegaConf from modules.sd_models_config import find_checkpoint_config +from modules.shared import cmd_opts from ldm.util import instantiate_from_config import open_clip @@ -23,6 +25,14 @@ class FakeObject(torch.nn.Module): return +class ForgeSD: + def __init__(self, unet, clip, vae, clipvision): + self.unet = unet + self.clip = clip + self.vae = vae + self.clipvision = clipvision + + @contextlib.contextmanager def no_clip(): backup_openclip = open_clip.create_model_and_transforms @@ -42,6 +52,66 @@ def no_clip(): return +def load_checkpoint_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True): + sd_keys = sd.keys() + clip = None + clipvision = None + vae = None + model = None + model_patcher = None + clip_target = None + + parameters = ldm_patched.modules.utils.calculate_parameters(sd, "model.diffusion_model.") + unet_dtype = model_management.unet_dtype(model_params=parameters) + load_device = model_management.get_torch_device() + manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device) + + class WeightsLoader(torch.nn.Module): + pass + + model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.", unet_dtype) + model_config.set_manual_cast(manual_cast_dtype) + + if model_config is None: + raise RuntimeError("ERROR: Could not detect model type") + + if model_config.clip_vision_prefix is not None: + if output_clipvision: + clipvision = ldm_patched.modules.clip_vision.load_clipvision_from_sd(sd, model_config.clip_vision_prefix, True) + + if output_model: + inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype) + offload_device = model_management.unet_offload_device() + model = model_config.get_model(sd, "model.diffusion_model.", device=inital_load_device) + model.load_model_weights(sd, "model.diffusion_model.") + + if output_vae: + vae_sd = ldm_patched.modules.utils.state_dict_prefix_replace(sd, {"first_stage_model.": ""}, filter_keys=True) + vae_sd = model_config.process_vae_state_dict(vae_sd) + vae = VAE(sd=vae_sd) + + if output_clip: + w = WeightsLoader() + clip_target = model_config.clip_target() + if clip_target is not None: + clip = CLIP(clip_target, embedding_directory=embedding_directory) + w.cond_stage_model = clip.cond_stage_model + sd = model_config.process_clip_state_dict(sd) + load_model_weights(w, sd) + + left_over = sd.keys() + if len(left_over) > 0: + print("left over keys:", left_over) + + if output_model: + model_patcher = ldm_patched.modules.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device(), current_device=inital_load_device) + if inital_load_device != torch.device("cpu"): + print("loaded straight to GPU") + model_management.load_model_gpu(model_patcher) + + return ForgeSD(model_patcher, clip, vae, clipvision) + + def load_model_for_a1111(timer, checkpoint_info=None, state_dict=None): a1111_config = find_checkpoint_config(state_dict, checkpoint_info) a1111_config = OmegaConf.load(a1111_config) @@ -61,6 +131,28 @@ def load_model_for_a1111(timer, checkpoint_info=None, state_dict=None): timer.record("forge instantiate config") + forge_objects = load_checkpoint_guess_config( + state_dict, + output_vae=True, + output_clip=True, + output_clipvision=True, + embedding_directory=cmd_opts.embeddings_dir, + output_model=True + ) + + sd_model.forge_objects = forge_objects + + sd_model.first_stage_model = vae_patcher.first_stage_model + sd_model.model.diffusion_model = unet_patcher.model.diffusion_model + sd_model.unet_patcher = unet_patcher + sd_model.model.diffusion_model.patcher = unet_patcher + sd_model.vae_patcher = vae_patcher + sd_model.first_stage_model.patcher = vae_patcher + + timer.record("forge load real models") + + + return