diff --git a/modules_forge/forge_loader.py b/modules_forge/forge_loader.py index 7b283958..eb25120b 100644 --- a/modules_forge/forge_loader.py +++ b/modules_forge/forge_loader.py @@ -17,6 +17,7 @@ from modules.sd_models_xl import extend_sdxl from ldm.util import instantiate_from_config from modules_forge import forge_clip from modules_forge.unet_patcher import UnetPatcher +from ldm_patched.modules.model_base import model_sampling, ModelType import open_clip from transformers import CLIPTextModel, CLIPTokenizer @@ -219,6 +220,9 @@ def load_model_for_a1111(timer, checkpoint_info=None, state_dict=None): sd_model_hash = checkpoint_info.calculate_shorthash() timer.record("calculate hash") + if getattr(sd_model, 'parameterization') == 'v': + sd_model.forge_objects.unet.model.model_sampling = model_sampling(sd_model.forge_objects.unet.model.model_config, ModelType.V_PREDICTION) + sd_model.is_sdxl = conditioner is not None sd_model.is_sd2 = not sd_model.is_sdxl and hasattr(sd_model.cond_stage_model, 'model') sd_model.is_sd1 = not sd_model.is_sdxl and not sd_model.is_sd2