diff --git a/modules_forge/forge_loader.py b/modules_forge/forge_loader.py index e8aba268..15c20110 100644 --- a/modules_forge/forge_loader.py +++ b/modules_forge/forge_loader.py @@ -12,6 +12,8 @@ import ldm_patched.modules.clip_vision from omegaconf import OmegaConf from modules.sd_models_config import find_checkpoint_config from modules.shared import cmd_opts +import modules.sd_hijack as sd_hijack +from modules.sd_models_xl import extend_sdxl from ldm.util import instantiate_from_config import open_clip @@ -186,6 +188,23 @@ def load_model_for_a1111(timer, checkpoint_info=None, state_dict=None): timer.record("forge set components") + sd_hijack.model_hijack.hijack(sd_model) + timer.record("forge hijack") + + sd_model_hash = checkpoint_info.calculate_shorthash() + timer.record("calculate hash") + + sd_model.is_sdxl = conditioner is not None + sd_model.is_sd2 = not sd_model.is_sdxl and hasattr(sd_model.cond_stage_model, 'model') + sd_model.is_sd1 = not sd_model.is_sdxl and not sd_model.is_sd2 + sd_model.is_ssd = sd_model.is_sdxl and 'model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight' not in sd_model.state_dict().keys() + if sd_model.is_sdxl: + extend_sdxl(sd_model) + sd_model.sd_model_hash = sd_model_hash + sd_model.sd_model_checkpoint = checkpoint_info.filename + sd_model.sd_checkpoint_info = checkpoint_info + timer.record("forge finalize") + return