mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-04-21 06:48:56 +00:00
Fix .yaml config loading (#2224)
This commit is contained in:
@@ -290,14 +290,39 @@ def forge_loader(sd, additional_state_dicts=None):
|
||||
if component is not None:
|
||||
huggingface_components[component_name] = component
|
||||
|
||||
# Fix Huggingface prediction type using estimated config detection
|
||||
yaml_config = None
|
||||
yaml_config_prediction_type = None
|
||||
|
||||
try:
|
||||
import yaml
|
||||
from pathlib import Path
|
||||
config_filename = os.path.splitext(sd)[0] + '.yaml'
|
||||
if Path(config_filename).is_file():
|
||||
with open(config_filename, 'r') as stream:
|
||||
yaml_config = yaml.safe_load(stream)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# Fix Huggingface prediction type using .yaml config or estimated config detection
|
||||
prediction_types = {
|
||||
'EPS': 'epsilon',
|
||||
'V_PREDICTION': 'v_prediction',
|
||||
'EDM': 'edm',
|
||||
}
|
||||
if 'scheduler' in huggingface_components and hasattr(huggingface_components['scheduler'], 'config') and 'prediction_type' in huggingface_components['scheduler'].config:
|
||||
huggingface_components['scheduler'].config.prediction_type = prediction_types.get(estimated_config.model_type.name, huggingface_components['scheduler'].config.prediction_type)
|
||||
|
||||
has_prediction_type = 'scheduler' in huggingface_components and hasattr(huggingface_components['scheduler'], 'config') and 'prediction_type' in huggingface_components['scheduler'].config
|
||||
|
||||
if yaml_config is not None:
|
||||
model_config_params = config.get('model', {}).get('params', {})
|
||||
if "parameterization" in model_config_params:
|
||||
if model_config_params["parameterization"] == "v":
|
||||
yaml_config_prediction_type = 'v_prediction'
|
||||
|
||||
if has_prediction_type:
|
||||
if yaml_config_prediction_type is not None:
|
||||
huggingface_components['scheduler'].config.prediction_type = yaml_config_prediction_type
|
||||
else:
|
||||
huggingface_components['scheduler'].config.prediction_type = prediction_types.get(estimated_config.model_type.name, huggingface_components['scheduler'].config.prediction_type)
|
||||
|
||||
for M in possible_models:
|
||||
if any(isinstance(estimated_config, x) for x in M.matched_guesses):
|
||||
|
||||
Reference in New Issue
Block a user