Fix model prediction detection

Closes #1109
This commit is contained in:
catboxanon
2024-10-19 06:18:02 -04:00
parent 1fae20d94f
commit 5ec47a6b93

View File

@@ -243,6 +243,7 @@ def split_state_dict(sd, additional_state_dicts: list = None):
sd = replace_state_dict(sd, asd, guess)
guess.clip_target = guess.clip_target(sd)
guess.model_type = guess.model_type(sd)
state_dict = {
guess.unet_target: try_filter_state_dict(sd, guess.unet_key_prefix),
@@ -286,6 +287,15 @@ def forge_loader(sd, additional_state_dicts=None):
if component is not None:
huggingface_components[component_name] = component
# Fix Huggingface prediction type using estimated config detection
prediction_types = {
'EPS': 'epsilon',
'V_PREDICTION': 'v_prediction',
'EDM': 'edm',
}
if 'scheduler' in huggingface_components and hasattr(huggingface_components['scheduler'], 'config') and 'prediction_type' in huggingface_components['scheduler'].config:
huggingface_components['scheduler'].config.prediction_type = prediction_types.get(estimated_config.model_type.name, 'epsilon')
for M in possible_models:
if any(isinstance(estimated_config, x) for x in M.matched_guesses):
return M(estimated_config=estimated_config, huggingface_components=huggingface_components)