From e9459b6c4af8a3f80f6d0ad379bbc33e94cff284 Mon Sep 17 00:00:00 2001 From: lllyasviel Date: Fri, 9 Feb 2024 20:51:18 -0800 Subject: [PATCH] support sd1.5 model with v prediction #123 --- modules_forge/forge_loader.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/modules_forge/forge_loader.py b/modules_forge/forge_loader.py index 7b283958..eb25120b 100644 --- a/modules_forge/forge_loader.py +++ b/modules_forge/forge_loader.py @@ -17,6 +17,7 @@ from modules.sd_models_xl import extend_sdxl from ldm.util import instantiate_from_config from modules_forge import forge_clip from modules_forge.unet_patcher import UnetPatcher +from ldm_patched.modules.model_base import model_sampling, ModelType import open_clip from transformers import CLIPTextModel, CLIPTokenizer @@ -219,6 +220,9 @@ def load_model_for_a1111(timer, checkpoint_info=None, state_dict=None): sd_model_hash = checkpoint_info.calculate_shorthash() timer.record("calculate hash") + if getattr(sd_model, 'parameterization') == 'v': + sd_model.forge_objects.unet.model.model_sampling = model_sampling(sd_model.forge_objects.unet.model.model_config, ModelType.V_PREDICTION) + sd_model.is_sdxl = conditioner is not None sd_model.is_sd2 = not sd_model.is_sdxl and hasattr(sd_model.cond_stage_model, 'model') sd_model.is_sd1 = not sd_model.is_sdxl and not sd_model.is_sd2