Update forge_loader.py

This commit is contained in:
lllyasviel
2024-01-25 04:38:19 -08:00
parent d0eaa5c07e
commit f937148628

View File

@@ -12,6 +12,8 @@ import ldm_patched.modules.clip_vision
from omegaconf import OmegaConf
from modules.sd_models_config import find_checkpoint_config
from modules.shared import cmd_opts
import modules.sd_hijack as sd_hijack
from modules.sd_models_xl import extend_sdxl
from ldm.util import instantiate_from_config
import open_clip
@@ -186,6 +188,23 @@ def load_model_for_a1111(timer, checkpoint_info=None, state_dict=None):
timer.record("forge set components")
sd_hijack.model_hijack.hijack(sd_model)
timer.record("forge hijack")
sd_model_hash = checkpoint_info.calculate_shorthash()
timer.record("calculate hash")
sd_model.is_sdxl = conditioner is not None
sd_model.is_sd2 = not sd_model.is_sdxl and hasattr(sd_model.cond_stage_model, 'model')
sd_model.is_sd1 = not sd_model.is_sdxl and not sd_model.is_sd2
sd_model.is_ssd = sd_model.is_sdxl and 'model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight' not in sd_model.state_dict().keys()
if sd_model.is_sdxl:
extend_sdxl(sd_model)
sd_model.sd_model_hash = sd_model_hash
sd_model.sd_model_checkpoint = checkpoint_info.filename
sd_model.sd_checkpoint_info = checkpoint_info
timer.record("forge finalize")
return