diff --git a/backend/loader.py b/backend/loader.py index 54dbb38e..a0460493 100644 --- a/backend/loader.py +++ b/backend/loader.py @@ -243,6 +243,7 @@ def split_state_dict(sd, additional_state_dicts: list = None): sd = replace_state_dict(sd, asd, guess) guess.clip_target = guess.clip_target(sd) + guess.model_type = guess.model_type(sd) state_dict = { guess.unet_target: try_filter_state_dict(sd, guess.unet_key_prefix), @@ -286,6 +287,15 @@ def forge_loader(sd, additional_state_dicts=None): if component is not None: huggingface_components[component_name] = component + # Fix Huggingface prediction type using estimated config detection + prediction_types = { + 'EPS': 'epsilon', + 'V_PREDICTION': 'v_prediction', + 'EDM': 'edm', + } + if 'scheduler' in huggingface_components and hasattr(huggingface_components['scheduler'], 'config') and 'prediction_type' in huggingface_components['scheduler'].config: + huggingface_components['scheduler'].config.prediction_type = prediction_types.get(estimated_config.model_type.name, 'epsilon') + for M in possible_models: if any(isinstance(estimated_config, x) for x in M.matched_guesses): return M(estimated_config=estimated_config, huggingface_components=huggingface_components)