diff --git a/exllamav3/model/config.py b/exllamav3/model/config.py index 2f2330a..1d48ce4 100644 --- a/exllamav3/model/config.py +++ b/exllamav3/model/config.py @@ -163,10 +163,22 @@ class Config(ABC): return RopeSettings( head_dim = self.head_dim, - rope_theta = read_dict(config_dict, float, theta_key, default_rope_theta), + rope_theta = read_dict( + config_dict, + float, + theta_key, + default_rope_theta, + wrong_type_as_missing = True + ), rope_scaling = read_dict(config_dict, dict, ["rope_scaling", "rope_parameters"], None), rotary_dim = read_dict(config_dict, int, "rotary_dim", None), - partial_rotary_factor = read_dict(config_dict, float, ["partial_rotary_factor", "rope_parameters->partial_rotary_factor"], default_partial_rotary_factor), + partial_rotary_factor = read_dict( + config_dict, + float, + ["partial_rotary_factor", "rope_parameters->partial_rotary_factor"], + default_partial_rotary_factor, + wrong_type_as_missing = True + ), max_position_embeddings = read_dict(config_dict, int, "max_position_embeddings", None), original_max_position_embeddings = read_dict(config_dict, int, "original_max_position_embeddings", None), rope_style = rope_style, diff --git a/exllamav3/util/file.py b/exllamav3/util/file.py index e0624e9..7c0b649 100644 --- a/exllamav3/util/file.py +++ b/exllamav3/util/file.py @@ -54,6 +54,7 @@ def read_dict( expected_types: type | list[type], keys: str | list[str], default = no_default, + wrong_type_as_missing: bool = False, ) -> T: """ Utility function to read typed value from (nested) dictionary @@ -74,6 +75,9 @@ def read_dict( Default value to return if the key isn't found, e.g. None. If this is the special value no_default and no keys are matched, raise an exception instead. + :param wrong_type_as_missing: + Treat an existing key of the wrong type as a missing key and apply default value if possible + :return: Requested value if key found, otherwise default value """ @@ -110,7 +114,8 @@ def read_dict( x = int(x) if isinstance(x, t): return cast(T, x) - raise TypeError(f"Value for {key} is not of expected type: {expected_types}") + if not wrong_type_as_missing: + raise TypeError(f"Value for {key} is not of expected type: {expected_types}") if default != no_default: return default