From 90a6970fb76bc62f2acd2e2a98328b79fe0a8f76 Mon Sep 17 00:00:00 2001 From: catboxanon <122327233+catboxanon@users.noreply.github.com> Date: Sat, 2 Nov 2024 10:16:51 -0400 Subject: [PATCH] Compatibility for ldm .yaml configs (#2247) --- backend/loader.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/backend/loader.py b/backend/loader.py index 72ecc8b7..fa157ac2 100644 --- a/backend/loader.py +++ b/backend/loader.py @@ -313,10 +313,12 @@ def forge_loader(sd, additional_state_dicts=None): has_prediction_type = 'scheduler' in huggingface_components and hasattr(huggingface_components['scheduler'], 'config') and 'prediction_type' in huggingface_components['scheduler'].config if yaml_config is not None: - model_config_params = yaml_config.get('model', {}).get('params', {}) - if "parameterization" in model_config_params: - if model_config_params["parameterization"] == "v": - yaml_config_prediction_type = 'v_prediction' + yaml_config_prediction_type: str = ( + yaml_config.get('model', {}).get('params', {}).get('parameterization', '') + or yaml_config.get('model', {}).get('params', {}).get('denoiser_config', {}).get('params', {}).get('scaling_config').get('target', '') + ) + if yaml_config_prediction_type == 'v' or yaml_config_prediction_type.endswith(".VScaling"): + yaml_config_prediction_type = 'v_prediction' if has_prediction_type: if yaml_config_prediction_type is not None: