diff --git a/exllamav3/architecture/qwen3_5.py b/exllamav3/architecture/qwen3_5.py index 6a4e4af..78cf869 100644 --- a/exllamav3/architecture/qwen3_5.py +++ b/exllamav3/architecture/qwen3_5.py @@ -16,6 +16,7 @@ from ..modules import ( Linear, GatedDeltaNet, GatedMLP, + DeepstackEmbed ) from ..modules.attn import prepare_for_attn from ..modules.gated_delta_net import prepare_for_recurrence @@ -205,102 +206,116 @@ class Qwen3_5BaseModel(Model): self.first_block_idx = len(self.modules) - self.modules += [ - TransformerBlock( - config = config, - key = f"{key_prefix}.layers.{idx}", - attn_norm = RMSNorm( + for idx in range(config.num_hidden_layers): + self.modules += [ + TransformerBlock( config = config, - key = f"{key_prefix}.layers.{idx}.input_layernorm", - rms_norm_eps = config.rms_norm_eps, - constant_bias = 1.0, - ), - attn = ( - GatedDeltaNet( + key = f"{key_prefix}.layers.{idx}", + attn_norm = RMSNorm( config = config, - key = f"{key_prefix}.layers.{idx}.linear_attn", - layer_idx = idx, - hidden_size = config.hidden_size, - k_head_dim = config.linear_key_head_dim, - v_head_dim = config.linear_value_head_dim, - num_k_heads = config.linear_num_key_heads, - num_v_heads = config.linear_num_value_heads, + key = f"{key_prefix}.layers.{idx}.input_layernorm", rms_norm_eps = config.rms_norm_eps, - conv_kernel_size = config.linear_conv_kernel_dim, - key_a_log = "A_log", - key_dt_bias = "dt_bias", - key_conv1d = "conv1d", - key_qkv = "in_proj_qkv", - key_z = "in_proj_z", - key_b = "in_proj_b", - key_a = "in_proj_a", - key_norm = "norm", - key_o = "out_proj", - qmap = "block.attn", - out_dtype = torch.float, - ) - if config.layer_types[idx] == "linear_attention" else - Attention( - config = config, - key = f"{key_prefix}.layers.{idx}.self_attn", - layer_idx = idx, - hidden_size = config.hidden_size, - head_dim = config.head_dim, - num_q_heads = config.num_q_heads, - num_kv_heads = config.num_kv_heads, - rope_settings = config.rope_settings, - sm_scale = None, - key_q = "q_proj", - key_k = "k_proj", - key_v = "v_proj", - key_o = "o_proj", - qmap = "block.attn", - q_norm = RMSNorm( + constant_bias = 1.0, + ), + attn = ( + GatedDeltaNet( config = config, - key = f"{key_prefix}.layers.{idx}.self_attn.q_norm", - rms_norm_eps = config.rms_norm_eps, - constant_bias = 1.0, - ), - k_norm = RMSNorm( - config = config, - key = f"{key_prefix}.layers.{idx}.self_attn.k_norm", - rms_norm_eps = config.rms_norm_eps, - constant_bias = 1.0, - ), - out_dtype = torch.float, - interleaved_gate = True, - ) - ), - mlp_norm = RMSNorm( - config = config, - key = f"{key_prefix}.layers.{idx}.post_attention_layernorm", - rms_norm_eps = config.rms_norm_eps, - constant_bias = 1.0, - ), - mlp = ( - BlockSparseMLP( - config = config, - key = f"{key_prefix}.layers.{idx}.mlp", - hidden_size = config.hidden_size, - intermediate_size = config.moe_intermediate_size, - num_experts = config.num_experts, - num_experts_per_tok = config.num_experts_per_tok, - key_up = "experts.{expert_idx}.up_proj", - key_gate = "experts.{expert_idx}.gate_proj", - key_down = "experts.{expert_idx}.down_proj", - key_gate_up_split = "experts.gate_up_proj", - key_down_split = "experts.down_proj", - key_routing_gate = "gate", - key_shared_gate = "shared_expert_gate", - transposed_load = False, - qmap = "block.mlp", - interm_dtype = torch.half, - out_dtype = torch.float, - shared_experts = GatedMLP( - config = config, - key = f"{key_prefix}.layers.{idx}.mlp.shared_expert", + key = f"{key_prefix}.layers.{idx}.linear_attn", + layer_idx = idx, hidden_size = config.hidden_size, - intermediate_size = config.shared_expert_intermediate_size, + k_head_dim = config.linear_key_head_dim, + v_head_dim = config.linear_value_head_dim, + num_k_heads = config.linear_num_key_heads, + num_v_heads = config.linear_num_value_heads, + rms_norm_eps = config.rms_norm_eps, + conv_kernel_size = config.linear_conv_kernel_dim, + key_a_log = "A_log", + key_dt_bias = "dt_bias", + key_conv1d = "conv1d", + key_qkv = "in_proj_qkv", + key_z = "in_proj_z", + key_b = "in_proj_b", + key_a = "in_proj_a", + key_norm = "norm", + key_o = "out_proj", + qmap = "block.attn", + out_dtype = torch.float, + ) + if config.layer_types[idx] == "linear_attention" else + Attention( + config = config, + key = f"{key_prefix}.layers.{idx}.self_attn", + layer_idx = idx, + hidden_size = config.hidden_size, + head_dim = config.head_dim, + num_q_heads = config.num_q_heads, + num_kv_heads = config.num_kv_heads, + rope_settings = config.rope_settings, + sm_scale = None, + key_q = "q_proj", + key_k = "k_proj", + key_v = "v_proj", + key_o = "o_proj", + qmap = "block.attn", + q_norm = RMSNorm( + config = config, + key = f"{key_prefix}.layers.{idx}.self_attn.q_norm", + rms_norm_eps = config.rms_norm_eps, + constant_bias = 1.0, + ), + k_norm = RMSNorm( + config = config, + key = f"{key_prefix}.layers.{idx}.self_attn.k_norm", + rms_norm_eps = config.rms_norm_eps, + constant_bias = 1.0, + ), + out_dtype = torch.float, + interleaved_gate = True, + ) + ), + mlp_norm = RMSNorm( + config = config, + key = f"{key_prefix}.layers.{idx}.post_attention_layernorm", + rms_norm_eps = config.rms_norm_eps, + constant_bias = 1.0, + ), + mlp = ( + BlockSparseMLP( + config = config, + key = f"{key_prefix}.layers.{idx}.mlp", + hidden_size = config.hidden_size, + intermediate_size = config.moe_intermediate_size, + num_experts = config.num_experts, + num_experts_per_tok = config.num_experts_per_tok, + key_up = "experts.{expert_idx}.up_proj", + key_gate = "experts.{expert_idx}.gate_proj", + key_down = "experts.{expert_idx}.down_proj", + key_gate_up_split = "experts.gate_up_proj", + key_down_split = "experts.down_proj", + key_routing_gate = "gate", + key_shared_gate = "shared_expert_gate", + transposed_load = False, + qmap = "block.mlp", + interm_dtype = torch.half, + out_dtype = torch.float, + shared_experts = GatedMLP( + config = config, + key = f"{key_prefix}.layers.{idx}.mlp.shared_expert", + hidden_size = config.hidden_size, + intermediate_size = config.shared_expert_intermediate_size, + key_up = "up_proj", + key_gate = "gate_proj", + key_down = "down_proj", + qmap = "block.mlp", + interm_dtype = torch.half, + out_dtype = torch.float, + ) + ) if use_moe else + GatedMLP( + config = config, + key = f"{key_prefix}.layers.{idx}.mlp", + hidden_size = config.hidden_size, + intermediate_size = config.intermediate_size, key_up = "up_proj", key_gate = "gate_proj", key_down = "down_proj", @@ -308,23 +323,19 @@ class Qwen3_5BaseModel(Model): interm_dtype = torch.half, out_dtype = torch.float, ) - ) if use_moe else - GatedMLP( - config = config, - key = f"{key_prefix}.layers.{idx}.mlp", - hidden_size = config.hidden_size, - intermediate_size = config.intermediate_size, - key_up = "up_proj", - key_gate = "gate_proj", - key_down = "down_proj", - qmap = "block.mlp", - interm_dtype = torch.half, - out_dtype = torch.float, - ) - ), - ) - for idx in range(config.num_hidden_layers) - ] + ), + ) + ] + + if config.vision and config.vision.deepstack_visual_indexes: + if idx < len(config.vision.deepstack_visual_indexes): + self.modules += [ + DeepstackEmbed( + config = config, + key = f"{key_prefix}.layers.{idx}.deepstack_embed", + deepstack_index = idx, + ) + ] self.last_kv_module_idx = len(self.modules) - 1