Add Qwen3_5ForCausalLM and Qwen3_5MoeForCausalLM

This commit is contained in:
turboderp
2026-03-11 21:00:23 +01:00
parent 1b9e58c9b5
commit e05f4636ee
2 changed files with 206 additions and 85 deletions

View File

@@ -29,7 +29,7 @@ from .phi3 import Phi3Model
from .qwen2 import Qwen2Model
from .qwen2_5_vl import Qwen2_5VLModel
from .qwen3 import Qwen3Model
from .qwen3_5 import Qwen3_5Model, Qwen3_5MoeModel
from .qwen3_5 import Qwen3_5Model, Qwen3_5MoeModel, Qwen3_5VLModel, Qwen3_5VLMoeModel
from .qwen3_moe import Qwen3MoeModel
from .qwen3_next import Qwen3NextModel
from .qwen3_vl import Qwen3VLModel
@@ -79,6 +79,8 @@ ARCHITECTURES = {
Qwen3Model,
Qwen3_5Model,
Qwen3_5MoeModel,
Qwen3_5VLModel,
Qwen3_5VLMoeModel,
Qwen3MoeModel,
Qwen3NextModel,
Qwen3VLModel,

View File

@@ -39,7 +39,87 @@ def read_qwen3_5_layer_types(config: Config, text_config_path: str, num_layers:
]
class Qwen3_5VLBaseConfig(Config):
    """Shared config loader for the dense Qwen3.5 text and Qwen3.5-VL models.

    Config keys live at the top level for the text-only variant and under the
    ``text_cfg`` sub-dict (e.g. ``"text_config"``) for the multimodal wrapper.
    NOTE(review): reconstructed from a whitespace-mangled diff view; logic
    mirrors the committed version, only formatting/style differ.
    """

    def __init__(
        self,
        directory: str,
        text_cfg: str | None = "text_config",
        text_model = None,
        vision_model = None,
        **kwargs,
    ):
        # Register only the component models that were actually supplied.
        components = {}
        if text_model:
            components["text"] = text_model
        if vision_model:
            components["vision"] = vision_model
        super().__init__(directory, components, **kwargs)

        # Prefix a key with the text sub-config path when one is in use.
        def pfx(key):
            return f"{text_cfg}->{key}" if text_cfg else key

        # Attention params
        self.head_dim = self.read_cfg(int, pfx("head_dim"), None)
        self.hidden_size = self.read_cfg(int, pfx("hidden_size"), no_default)
        self.num_q_heads = self.read_cfg(int, pfx("num_attention_heads"), no_default)
        self.num_kv_heads = self.read_cfg(int, pfx("num_key_value_heads"), self.num_q_heads)
        self.full_attention_interval = self.read_cfg(int, pfx("full_attention_interval"), 4)
        if not self.head_dim:
            # Fall back to the conventional derivation when absent/zero.
            self.head_dim = self.hidden_size // self.num_q_heads

        # Linear attention params
        self.linear_conv_kernel_dim = self.read_cfg(int, pfx("linear_conv_kernel_dim"), 4)
        self.linear_num_key_heads = self.read_cfg(int, pfx("linear_num_key_heads"), 16)
        self.linear_num_value_heads = self.read_cfg(int, pfx("linear_num_value_heads"), 32)
        self.linear_key_head_dim = self.read_cfg(int, pfx("linear_key_head_dim"), 128)
        self.linear_value_head_dim = self.read_cfg(int, pfx("linear_value_head_dim"), 128)

        # MLP params
        self.assert_cfg(str, pfx("hidden_act"), "silu", True)
        self.intermediate_size = self.read_cfg(int, pfx("intermediate_size"), no_default)

        # Norms
        self.rms_norm_eps = self.read_cfg(float, pfx("rms_norm_eps"), no_default)

        # Layers
        self.num_hidden_layers = self.read_cfg(int, pfx("num_hidden_layers"), no_default)
        self.tie_word_embeddings = self.read_cfg(bool, "tie_word_embeddings", False)
        self.layer_types = read_qwen3_5_layer_types(
            self,
            text_cfg,
            self.num_hidden_layers,
            self.full_attention_interval,
        )

        # RoPE
        self.rope_settings = self.read_rope_settings_default(
            RopeStyle.NEOX,
            default_rope_theta = 10000000,
            config_dict = self.read_cfg(dict, text_cfg, no_default) if text_cfg else None
        )

        # Vision model settings (VL variants only)
        if vision_model:
            read_vision_config = self.read_cfg(dict, "vision_config", no_default)
            self.vision = read_qwen3_vl_vision_config(read_vision_config)
            prep_path = os.path.join(self.directory, "preprocessor_config.json")
            with open(prep_path, encoding = "utf8") as f:
                read_prep_config = json.load(f)
            self.vision_pp = read_qwen3_vl_pp_config(read_prep_config)
            self.vision_start_token_id = self.read_cfg(int, "vision_start_token_id", 151652)
            self.vision_end_token_id = self.read_cfg(int, "vision_end_token_id", 151653)
        else:
            self.vision = None
            self.vision_pp = None
class Qwen3_5VLConfig(Qwen3_5VLBaseConfig):
arch_string = "Qwen3_5ForConditionalGeneration"
def __init__(
@@ -49,38 +129,81 @@ class Qwen3_5Config(Config):
):
super().__init__(
directory,
{"text": Qwen3_5Model, "vision": Qwen3VLVisionModel},
"text_config",
Qwen3_5VLModel,
Qwen3VLVisionModel,
**kwargs
)
text_cfg = "text_config"
class Qwen3_5Config(Qwen3_5VLBaseConfig):
    """Config for the text-only dense Qwen3.5 model."""

    arch_string = "Qwen3_5ForCausalLM"

    def __init__(self, directory: str, **kwargs):
        # Text-only variant: keys are at the top level (no sub-config prefix)
        # and there is no vision component.
        super().__init__(
            directory,
            text_cfg = None,
            text_model = Qwen3_5Model,
            vision_model = None,
            **kwargs,
        )
class Qwen3_5VLMoeBaseConfig(Config):
def __init__(
self,
directory: str,
text_cfg: str | None = "text_config",
text_model = None,
vision_model = None,
**kwargs,
):
super().__init__(
directory,
({"text": text_model} if text_model else {}) |
({"vision": vision_model} if vision_model else {}),
**kwargs
)
def pfx(key):
nonlocal text_cfg
return key if not text_cfg else f"{text_cfg}->{key}"
# Attention params
self.head_dim = self.read_cfg(int, f"{text_cfg}->head_dim", None)
self.hidden_size = self.read_cfg(int, f"{text_cfg}->hidden_size", no_default)
self.num_q_heads = self.read_cfg(int, f"{text_cfg}->num_attention_heads", no_default)
self.num_kv_heads = self.read_cfg(int, f"{text_cfg}->num_key_value_heads", self.num_q_heads)
self.full_attention_interval = self.read_cfg(int, f"{text_cfg}->full_attention_interval", 4)
self.head_dim = self.read_cfg(int, pfx("head_dim"), None)
self.hidden_size = self.read_cfg(int, pfx("hidden_size"), no_default)
self.num_q_heads = self.read_cfg(int, pfx("num_attention_heads"), no_default)
self.num_kv_heads = self.read_cfg(int, pfx("num_key_value_heads"), self.num_q_heads)
self.full_attention_interval = self.read_cfg(int, pfx("full_attention_interval"), 4)
if not self.head_dim:
self.head_dim = self.hidden_size // self.num_q_heads
# Linear attn params
self.linear_conv_kernel_dim = self.read_cfg(int, f"{text_cfg}->linear_conv_kernel_dim", 4)
self.linear_num_key_heads = self.read_cfg(int, f"{text_cfg}->linear_num_key_heads", 16)
self.linear_num_value_heads = self.read_cfg(int, f"{text_cfg}->linear_num_value_heads", 32)
self.linear_key_head_dim = self.read_cfg(int, f"{text_cfg}->linear_key_head_dim", 128)
self.linear_value_head_dim = self.read_cfg(int, f"{text_cfg}->linear_value_head_dim", 128)
self.linear_conv_kernel_dim = self.read_cfg(int, pfx("linear_conv_kernel_dim"), 4)
self.linear_num_key_heads = self.read_cfg(int, pfx("linear_num_key_heads"), 16)
self.linear_num_value_heads = self.read_cfg(int, pfx("linear_num_value_heads"), 32)
self.linear_key_head_dim = self.read_cfg(int, pfx("linear_key_head_dim"), 128)
self.linear_value_head_dim = self.read_cfg(int, pfx("linear_value_head_dim"), 128)
# MLP params
self.assert_cfg(str, f"{text_cfg}->hidden_act", "silu", True)
self.intermediate_size = self.read_cfg(int, f"{text_cfg}->intermediate_size", no_default)
self.assert_cfg(str, pfx("hidden_act"), "silu", True)
self.assert_cfg(bool, pfx("norm_topk_prob"), True, True)
self.moe_intermediate_size = self.read_cfg(int, pfx("moe_intermediate_size"), no_default)
self.num_experts = self.read_cfg(int, pfx("num_experts"), no_default)
self.num_experts_per_tok = self.read_cfg(int, pfx("num_experts_per_tok"), no_default)
self.shared_expert_intermediate_size = self.read_cfg(int, pfx("shared_expert_intermediate_size"), 512)
# Norms
self.rms_norm_eps = self.read_cfg(float, f"{text_cfg}->rms_norm_eps", no_default)
self.rms_norm_eps = self.read_cfg(float, pfx("rms_norm_eps"), no_default)
# Layers
self.num_hidden_layers = self.read_cfg(int, f"{text_cfg}->num_hidden_layers", no_default)
self.num_hidden_layers = self.read_cfg(int, pfx("num_hidden_layers"), no_default)
self.tie_word_embeddings = self.read_cfg(bool, "tie_word_embeddings", False)
self.layer_types = read_qwen3_5_layer_types(
self,
@@ -93,23 +216,27 @@ class Qwen3_5Config(Config):
self.rope_settings = self.read_rope_settings_default(
RopeStyle.NEOX,
default_rope_theta = 10000000,
config_dict = self.read_cfg(dict, text_cfg, no_default)
config_dict = self.read_cfg(dict, text_cfg, no_default) if text_cfg else None
)
# Vision model settings
read_vision_config = self.read_cfg(dict, "vision_config", no_default)
self.vision = read_qwen3_vl_vision_config(read_vision_config)
if vision_model:
read_vision_config = self.read_cfg(dict, "vision_config", no_default)
self.vision = read_qwen3_vl_vision_config(read_vision_config)
prep_path = os.path.join(self.directory, "preprocessor_config.json")
with open(prep_path, encoding = "utf8") as f:
read_prep_config = json.load(f)
self.vision_pp = read_qwen3_vl_pp_config(read_prep_config)
prep_path = os.path.join(self.directory, "preprocessor_config.json")
with open(prep_path, encoding = "utf8") as f:
read_prep_config = json.load(f)
self.vision_pp = read_qwen3_vl_pp_config(read_prep_config)
self.vision_start_token_id = self.read_cfg(int, "vision_start_token_id", 151652)
self.vision_end_token_id = self.read_cfg(int, "vision_end_token_id", 151653)
self.vision_start_token_id = self.read_cfg(int, "vision_start_token_id", 151652)
self.vision_end_token_id = self.read_cfg(int, "vision_end_token_id", 151653)
else:
self.vision = None
self.vision_pp = None
class Qwen3_5MoeConfig(Config):
class Qwen3_5VLMoeConfig(Qwen3_5VLMoeBaseConfig):
arch_string = "Qwen3_5MoeForConditionalGeneration"
def __init__(
@@ -119,75 +246,35 @@ class Qwen3_5MoeConfig(Config):
):
super().__init__(
directory,
{"text": Qwen3_5MoeModel, "vision": Qwen3VLVisionModel},
"text_config",
Qwen3_5VLMoeModel,
Qwen3VLVisionModel,
**kwargs
)
text_cfg = "text_config"
# Attention params
self.head_dim = self.read_cfg(int, f"{text_cfg}->head_dim", None)
self.hidden_size = self.read_cfg(int, f"{text_cfg}->hidden_size", no_default)
self.num_q_heads = self.read_cfg(int, f"{text_cfg}->num_attention_heads", no_default)
self.num_kv_heads = self.read_cfg(int, f"{text_cfg}->num_key_value_heads", self.num_q_heads)
self.full_attention_interval = self.read_cfg(int, f"{text_cfg}->full_attention_interval", 4)
class Qwen3_5MoeConfig(Qwen3_5VLMoeBaseConfig):
    """Config for the text-only mixture-of-experts Qwen3.5 model."""

    arch_string = "Qwen3_5MoeForCausalLM"

    def __init__(self, directory: str, **kwargs):
        # Text-only variant: keys are at the top level (no sub-config prefix)
        # and there is no vision component.
        super().__init__(
            directory,
            text_cfg = None,
            text_model = Qwen3_5MoeModel,
            vision_model = None,
            **kwargs,
        )
class Qwen3_5BaseModel(Model):
def __init__(
self,
config: Qwen3_5Config | Qwen3_5MoeConfig,
config: Qwen3_5VLConfig | Qwen3_5VLMoeConfig | Qwen3_5Config | Qwen3_5MoeConfig,
key_prefix: str,
use_moe: bool,
**kwargs
@@ -396,6 +483,22 @@ class Qwen3_5BaseModel(Model):
return p
class Qwen3_5VLModel(Qwen3_5BaseModel):
    """Dense Qwen3.5-VL text model (language tower of the multimodal stack)."""

    config_class = Qwen3_5VLConfig

    def __init__(self, config: Qwen3_5VLConfig, **kwargs):
        # VL checkpoints nest the language weights under "model.language_model".
        super().__init__(
            config = config,
            key_prefix = "model.language_model",
            use_moe = False,
            **kwargs,
        )
class Qwen3_5Model(Qwen3_5BaseModel):
config_class = Qwen3_5Config
@@ -412,6 +515,22 @@ class Qwen3_5Model(Qwen3_5BaseModel):
)
class Qwen3_5VLMoeModel(Qwen3_5BaseModel):
    """Mixture-of-experts Qwen3.5-VL text model (language tower of the multimodal stack)."""

    config_class = Qwen3_5VLMoeConfig

    def __init__(self, config: Qwen3_5VLMoeConfig, **kwargs):
        # VL checkpoints nest the language weights under "model.language_model".
        super().__init__(
            config = config,
            key_prefix = "model.language_model",
            use_moe = True,
            **kwargs,
        )
class Qwen3_5MoeModel(Qwen3_5BaseModel):
config_class = Qwen3_5MoeConfig