Config: Allow for interpreting config key with incorrect data type as missing key (for weirdly implemented layerwise RoPE settings in some models)

This commit is contained in:
turboderp
2026-03-01 03:16:32 +01:00
parent 489b3aab12
commit b0cfe46702
2 changed files with 20 additions and 3 deletions

View File

@@ -163,10 +163,22 @@ class Config(ABC):
return RopeSettings(
head_dim = self.head_dim,
rope_theta = read_dict(config_dict, float, theta_key, default_rope_theta),
rope_theta = read_dict(
config_dict,
float,
theta_key,
default_rope_theta,
wrong_type_as_missing = True
),
rope_scaling = read_dict(config_dict, dict, ["rope_scaling", "rope_parameters"], None),
rotary_dim = read_dict(config_dict, int, "rotary_dim", None),
partial_rotary_factor = read_dict(config_dict, float, ["partial_rotary_factor", "rope_parameters->partial_rotary_factor"], default_partial_rotary_factor),
partial_rotary_factor = read_dict(
config_dict,
float,
["partial_rotary_factor", "rope_parameters->partial_rotary_factor"],
default_partial_rotary_factor,
wrong_type_as_missing = True
),
max_position_embeddings = read_dict(config_dict, int, "max_position_embeddings", None),
original_max_position_embeddings = read_dict(config_dict, int, "original_max_position_embeddings", None),
rope_style = rope_style,

View File

@@ -54,6 +54,7 @@ def read_dict(
expected_types: type | list[type],
keys: str | list[str],
default = no_default,
wrong_type_as_missing: bool = False,
) -> T:
"""
Utility function to read typed value from (nested) dictionary
@@ -74,6 +75,9 @@ def read_dict(
Default value to return if the key isn't found, e.g. None. If this is the special value no_default
and no keys are matched, raise an exception instead.
:param wrong_type_as_missing:
Treat an existing key of the wrong type as a missing key and apply default value if possible
:return:
Requested value if key found, otherwise default value
"""
@@ -110,7 +114,8 @@ def read_dict(
x = int(x)
if isinstance(x, t):
return cast(T, x)
raise TypeError(f"Value for {key} is not of expected type: {expected_types}")
if not wrong_type_as_missing:
raise TypeError(f"Value for {key} is not of expected type: {expected_types}")
if default != no_default:
return default