From b781e7f80f240e159abb5589a933276d8afa547e Mon Sep 17 00:00:00 2001
From: lllyasviel
Date: Thu, 25 Jan 2024 03:45:58 -0800
Subject: [PATCH] i

---
 modules/sd_models.py          |  2 ++
 modules_forge/forge_loader.py | 62 +++++++++++++++++++++++++++++++----
 2 files changed, 58 insertions(+), 6 deletions(-)

diff --git a/modules/sd_models.py b/modules/sd_models.py
index 42729bf5..636bc518 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -599,6 +599,8 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
     else:
         state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
 
+    sd_model = forge_loader.load_model_for_a1111(timer=timer, checkpoint_info=checkpoint_info, state_dict=state_dict)
+
     checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
     clip_is_included_into_sd = any(x for x in [sd1_clip_weight, sd2_clip_weight, sdxl_clip_weight, sdxl_refiner_clip_weight] if x in state_dict)
 
diff --git a/modules_forge/forge_loader.py b/modules_forge/forge_loader.py
index 582d2e27..d4b612f0 100644
--- a/modules_forge/forge_loader.py
+++ b/modules_forge/forge_loader.py
@@ -1,4 +1,5 @@
 import torch
+import contextlib
 
 from ldm_patched.modules import model_management
 from ldm_patched.modules import model_detection
@@ -7,6 +8,61 @@ from ldm_patched.modules.sd import VAE
 import ldm_patched.modules.model_patcher
 import ldm_patched.modules.utils
 
+from omegaconf import OmegaConf
+from modules.sd_models_config import find_checkpoint_config
+from ldm.util import instantiate_from_config
+
+import open_clip
+from transformers import CLIPTextModel, CLIPTokenizer
+
+
+class FakeObject(torch.nn.Module):
+    def __init__(self, *args, **kwargs):
+        super().__init__()
+        self.visual = None
+        return
+
+
+@contextlib.contextmanager
+def no_clip():
+    backup_openclip = open_clip.create_model_and_transforms
+    backup_CLIPTextModel = CLIPTextModel.from_pretrained
+    backup_CLIPTokenizer = CLIPTokenizer.from_pretrained
+
+    try:
+        open_clip.create_model_and_transforms = lambda *args, **kwargs: (FakeObject(), None, None)
+        CLIPTextModel.from_pretrained = lambda *args, **kwargs: FakeObject()
+        CLIPTokenizer.from_pretrained = lambda *args, **kwargs: FakeObject()
+        yield
+
+    finally:
+        open_clip.create_model_and_transforms = backup_openclip
+        CLIPTextModel.from_pretrained = backup_CLIPTextModel
+        CLIPTokenizer.from_pretrained = backup_CLIPTokenizer
+        return
+
+
+def load_model_for_a1111(timer, checkpoint_info=None, state_dict=None):
+    a1111_config = find_checkpoint_config(state_dict, checkpoint_info)
+    a1111_config = OmegaConf.load(a1111_config)
+    timer.record("forge solving config")
+
+    if hasattr(a1111_config.model.params, 'network_config'):
+        a1111_config.model.params.network_config.target = 'modules_forge.forge_loader.FakeObject'
+
+    if hasattr(a1111_config.model.params, 'unet_config'):
+        a1111_config.model.params.unet_config.target = 'modules_forge.forge_loader.FakeObject'
+
+    if hasattr(a1111_config.model.params, 'first_stage_config'):
+        a1111_config.model.params.first_stage_config.target = 'modules_forge.forge_loader.FakeObject'
+
+    with no_clip():
+        sd_model = instantiate_from_config(a1111_config.model)
+
+    timer.record("forge instantiate config")
+
+    return sd_model
+
 
 def load_unet_and_vae(sd):
     parameters = ldm_patched.modules.utils.calculate_parameters(sd, "model.diffusion_model.")
@@ -34,9 +90,3 @@ def load_unet_and_vae(sd):
 
     vae_patcher = VAE(sd=vae_sd)
     return model_patcher, vae_patcher
-
-
-class FakeObject(torch.nn.Module):
-    def __init__(self, *args, **kwargs):
-        super().__init__()
-        return