Fix .yaml config loading (#2224)

This commit is contained in:
catboxanon
2024-10-30 16:18:44 -04:00
committed by GitHub
parent ecd4d28e46
commit b691b1e755

View File

@@ -290,14 +290,39 @@ def forge_loader(sd, additional_state_dicts=None):
if component is not None:
huggingface_components[component_name] = component
# Fix Huggingface prediction type using estimated config detection
yaml_config = None
yaml_config_prediction_type = None
try:
import yaml
from pathlib import Path
config_filename = os.path.splitext(sd)[0] + '.yaml'
if Path(config_filename).is_file():
with open(config_filename, 'r') as stream:
yaml_config = yaml.safe_load(stream)
except ImportError:
pass
# Fix Huggingface prediction type using .yaml config or estimated config detection
prediction_types = {
'EPS': 'epsilon',
'V_PREDICTION': 'v_prediction',
'EDM': 'edm',
}
if 'scheduler' in huggingface_components and hasattr(huggingface_components['scheduler'], 'config') and 'prediction_type' in huggingface_components['scheduler'].config:
huggingface_components['scheduler'].config.prediction_type = prediction_types.get(estimated_config.model_type.name, huggingface_components['scheduler'].config.prediction_type)
has_prediction_type = 'scheduler' in huggingface_components and hasattr(huggingface_components['scheduler'], 'config') and 'prediction_type' in huggingface_components['scheduler'].config
if yaml_config is not None:
model_config_params = config.get('model', {}).get('params', {})
if "parameterization" in model_config_params:
if model_config_params["parameterization"] == "v":
yaml_config_prediction_type = 'v_prediction'
if has_prediction_type:
if yaml_config_prediction_type is not None:
huggingface_components['scheduler'].config.prediction_type = yaml_config_prediction_type
else:
huggingface_components['scheduler'].config.prediction_type = prediction_types.get(estimated_config.model_type.name, huggingface_components['scheduler'].config.prediction_type)
for M in possible_models:
if any(isinstance(estimated_config, x) for x in M.matched_guesses):