Refactor architecture logic for code reuse between LLM/VLM

This commit is contained in:
turboderp
2024-11-03 22:34:25 +01:00
parent d92ff8d9e4
commit 6cdfa5e52f
22 changed files with 652 additions and 573 deletions

View File

@@ -1,3 +1,4 @@
from dataclasses import dataclass, field
from enum import IntEnum
# Common keys
@@ -105,7 +106,7 @@ class ExLlamaV2ArchParams:
"""
Get architecture definition from model config. If the architecture isn't recognized, defaults to Llama
architecture.
W
:param arch_string:
Architecture string from config.json
@@ -116,115 +117,135 @@ W
self.arch_string = arch_string
arch_recognized = False
# Keys to expect in model dict
self.expect_keys = []
# Keys to expect in model dict, per layer
self.layer_keys = []
# Map tensors in HF model to standard keys
self.keymap = None
# Fused tensors
self.fused_qkv_key = None
self.fused_mlp_key_12 = None
self.fused_mlp_key_3 = None
@dataclass
class Params:
keys: dict = field(default_factory = lambda: {
"norm_eps": "rms_norm_eps",
"norm_1": ".input_layernorm",
"norm_1_post": None,
"fused_qkv": None,
"mlp_gate": ".mlp.gate_proj",
"mlp_up": ".mlp.up_proj",
"mlp_down": ".mlp.down_proj",
"lm_head": "lm_head",
"norm_2": ".post_attention_layernorm",
"norm_2_post": None,
"fused_mlp_12": None,
"fused_mlp_3": None,
"learned_pos_emb": None
})
# Alternate packing scheme for fused QKV tensor (InternLM2 quirk)
self.fused_qkv_altpack = False
# Compute logit scale from `dim_model_base` key in config.json (MiniCPM quirk)
logit_scale_basedim = False
# Learned position embeddings
self.learned_pos_emb_key = None
# Clamp hidden states to FP16 range
clamp_hidden_states = False
# Default multiplier for MLP inner dim (GPT2 quirk)
self.default_inner_dim_mult = None
# Upcast hidden state to FP32 before adding to residual stream
residual_stream_fp32 = False
# Compute logit scale from `dim_model_base` key in config.json (MiniCPM quirk)
self.logit_scale_basedim = False
# Normalize embeddings (Gemma quirk)
normalize_embeddings = False
# Constant bias for layernorm (Gemma quirk)
norm_constant_bias = 0
# Alternate packing scheme for fused QKV tensor (InternLM2 quirk)
fused_qkv_altpack = False
# SWA required by architecture
swa = False
alternating_swa = False
# Model only works with eager attention
eager_attn_only = False
# Expect bias for linear layers
attention_bias_qkv = False
attention_bias_o = False
mlp_bias = False
# Default multiplier for MLP inner dim (GPT2 quirk)
default_inner_dim_mult = None
# Use gated MLP
mlp_gate = True
# Use block-sparse MLP
is_moe = False
# Use parallel decoder blocks (Cohere quirk)
parallel_decoder_blocks = False
# Use MQA, effectively num_key_value_heads = 1 (GPTBigCode quirk)
mqa = False
# Model is incoherent without BOS at the start of the context
requires_bos = False
# Scale attn weights (GPT2 quirk, not important for inference)
scale_attn_weights = False
# Model implementation works in tensor-parallel mode
supports_tp = False
# Activation function
mlp_act_func = "silu"
# Layer norm type
norm = "rmsnorm"
# RoPE style
rope_style = RopeStyle.NEOX
# Expected keys
expect_keys: list[str] = field(default_factory = lambda: [])
layer_keys: list[str] = field(default_factory = lambda: [])
# Component models
self.lm_prefix = ""
self.vt_prefix = ""
self.mmp_prefix = ""
self.lm = Params()
self.mmp = Params()
self.vt = Params()
self.mmp.keys.update({
"norm_1": None,
"norm_1_post": None,
"norm_2": None,
"norm_2_post": None,
"fused_mlp_12": None,
"fused_mlp_3": None,
})
self.mmp.rope_style = RopeStyle.NONE
# Tensors are transposed in original model weights
self.orig_weights_transposed = False
# Post norm keys
self.norm_key_1_post = None
self.norm_key_2_post = None
# SWA required by architecture
self.swa = False
self.alternating_swa = False
# Model only works with eager attention
self.eager_attn_only = False
# Clamp hidden states to FP16 range
self.clamp_hidden_states = False
# Upcast hidden state to FP32 before adding to residual stream
self.residual_stream_fp32 = False
# Expect bias for linear layers
self.attention_bias_qkv = False
self.attention_bias_o = False
self.mlp_bias = False
# Use gated MLP
self.mlp_gate = True
# Use block-sparse MLP
self.is_moe = False
# Normalize embeddings (Gemma quirk)
self.normalize_embeddings = False
# Constant bias for layernorm (Gemma quirk)
self.norm_constant_bias = 0
# Use parallel decoder blocks (Cohere quirk)
self.parallel_decoder_blocks = False
# Model is incoherent without BOS at the start of the context
self.requires_bos = False
# Use MQA, effectively num_key_value_heads = 1 (GPTBigCode quirk)
self.mqa = False
# Scale attn weights (GPT2 quirk, not important for inference)
self.scale_attn_weights = False
# Model implementation works in tensor-parallel mode
self.supports_tp = False
# Mistral
if arch_string == "MistralForCausalLM":
arch_recognized = True
self.layer_keys += \
self.lm.layer_keys += \
layer_keys_llama_norms + \
layer_keys_llama_attn + \
layer_keys_llama_mlp
self.expect_keys += \
self.lm.expect_keys += \
expect_keys_llama
self.norm_eps_key = "rms_norm_eps"
self.mlp_key_gate = ".mlp.gate_proj"
self.mlp_key_up = ".mlp.up_proj"
self.mlp_key_down = ".mlp.down_proj"
self.lm_head_key = "lm_head"
self.norm_key_1 = ".input_layernorm"
self.norm_key_2 = ".post_attention_layernorm"
self.mlp_act_func = "silu"
self.norm = "rmsnorm"
self.rope_style = RopeStyle.NEOX
self.supports_tp = True
self.lm.supports_tp = True
# Mixtral
if arch_string == "MixtralForCausalLM":
arch_recognized = True
self.layer_keys += \
self.lm.layer_keys += \
layer_keys_llama_norms + \
layer_keys_llama_attn + \
layer_keys_mixtral_mlp
self.expect_keys += \
self.lm.expect_keys += \
expect_keys_llama
self.norm_eps_key = "rms_norm_eps"
self.mlp_key_gate = ".block_sparse_moe.experts.*.w1"
@@ -238,372 +259,307 @@ W
self.norm = "rmsnorm"
self.rope_style = RopeStyle.NEOX
self.is_moe = True
self.lm.keys.update({
"mlp_gate": ".block_sparse_moe.experts.*.w1",
"mlp_up": ".block_sparse_moe.experts.*.w3",
"mlp_down": ".block_sparse_moe.experts.*.w2",
"mlp_expert_gate": ".block_sparse_moe.gate"
})
self.lm.is_moe = True
# Yi
if arch_string == "YiForCausalLM":
arch_recognized = True
self.layer_keys += \
self.lm.layer_keys += \
layer_keys_yi_norms + \
layer_keys_llama_attn + \
layer_keys_llama_mlp
self.expect_keys += \
self.lm.expect_keys += \
expect_keys_llama
self.norm_eps_key = "rms_norm_eps"
self.mlp_key_gate = ".mlp.gate_proj"
self.mlp_key_up = ".mlp.up_proj"
self.mlp_key_down = ".mlp.down_proj"
self.mlp_act_func = "silu"
self.norm_key_1 = ".ln1"
self.norm_key_2 = ".ln2"
self.norm = "rmsnorm"
self.lm_head_key = "lm_head"
self.rope_style = RopeStyle.NEOX
self.lm.keys.update({
"norm_1": ".ln1",
"norm_2": ".ln2",
})
# Orion
if arch_string == "OrionForCausalLM":
arch_recognized = True
self.layer_keys += \
self.lm.layer_keys += \
layer_keys_llama_norms + \
layer_keys_llama_attn + \
layer_keys_llama_mlp
self.expect_keys += \
self.lm.expect_keys += \
expect_keys_llama
self.norm_eps_key = "rms_norm_eps"
self.mlp_key_gate = ".mlp.gate_proj"
self.mlp_key_up = ".mlp.up_proj"
self.mlp_key_down = ".mlp.down_proj"
self.lm_head_key = "lm_head"
self.norm_key_1 = ".input_layernorm"
self.norm_key_2 = ".post_attention_layernorm"
self.mlp_act_func = "silu"
self.norm = "layernorm"
self.rope_style = RopeStyle.NEOX
self.lm.norm = "layernorm"
# Qwen2 (1.5)
if arch_string == "Qwen2ForCausalLM":
arch_recognized = True
self.layer_keys += \
self.lm.layer_keys += \
layer_keys_llama_norms + \
layer_keys_llama_attn + \
layer_keys_llama_mlp
self.expect_keys += \
self.lm.expect_keys += \
expect_keys_llama
self.norm_eps_key = "rms_norm_eps"
self.mlp_key_gate = ".mlp.gate_proj"
self.mlp_key_up = ".mlp.up_proj"
self.mlp_key_down = ".mlp.down_proj"
self.lm_head_key = "lm_head"
self.norm_key_1 = ".input_layernorm"
self.norm_key_2 = ".post_attention_layernorm"
self.mlp_act_func = "silu"
self.norm = "rmsnorm"
self.rope_style = RopeStyle.NEOX
self.attention_bias_qkv = True
self.supports_tp = True
self.lm.attention_bias_qkv = True
self.lm.supports_tp = True
# Gemma
if arch_string == "GemmaForCausalLM":
arch_recognized = True
self.layer_keys += \
self.lm.layer_keys += \
layer_keys_llama_norms + \
layer_keys_llama_attn + \
layer_keys_llama_mlp
self.expect_keys += \
self.lm.expect_keys += \
expect_keys_gemma
self.norm_eps_key = "rms_norm_eps"
self.mlp_key_gate = ".mlp.gate_proj"
self.mlp_key_up = ".mlp.up_proj"
self.mlp_key_down = ".mlp.down_proj"
self.lm_head_key = "model.embed_tokens"
self.norm_key_1 = ".input_layernorm"
self.norm_key_2 = ".post_attention_layernorm"
self.mlp_act_func = "gelu"
self.norm = "rmsnorm"
self.rope_style = RopeStyle.NEOX
self.normalize_embeddings = True
self.norm_constant_bias = 1
self.requires_bos = True
self.lm.keys.update({
"lm_head": "model.embed_tokens",
})
self.lm.mlp_act_func = "gelu"
self.lm.normalize_embeddings = True
self.lm.norm_constant_bias = 1
self.lm.requires_bos = True
# Gemma2
if arch_string == "Gemma2ForCausalLM":
arch_recognized = True
self.layer_keys += \
self.lm.layer_keys += \
layer_keys_gemma2_norms + \
layer_keys_llama_attn + \
layer_keys_llama_mlp
self.expect_keys += \
self.lm.expect_keys += \
expect_keys_gemma
self.norm_eps_key = "rms_norm_eps"
self.mlp_key_gate = ".mlp.gate_proj"
self.mlp_key_up = ".mlp.up_proj"
self.mlp_key_down = ".mlp.down_proj"
self.lm_head_key = "model.embed_tokens"
self.norm_key_1 = ".input_layernorm"
self.norm_key_1_post = ".post_attention_layernorm"
self.norm_key_2 = ".pre_feedforward_layernorm"
self.norm_key_2_post = ".post_feedforward_layernorm"
self.mlp_act_func = "gelu"
self.norm = "rmsnorm"
self.rope_style = RopeStyle.NEOX
self.normalize_embeddings = True
self.norm_constant_bias = 1
self.requires_bos = True
self.pre_post_layernorm = True
self.alternating_swa = True
self.residual_stream_fp32 = True
self.lm.keys.update({
"lm_head": "model.embed_tokens",
"norm_1": ".input_layernorm",
"norm_1_post": ".post_attention_layernorm",
"norm_2": ".pre_feedforward_layernorm",
"norm_2_post": ".post_feedforward_layernorm",
})
self.lm.mlp_act_func = "gelu"
self.lm.normalize_embeddings = True
self.lm.norm_constant_bias = 1
self.lm.requires_bos = True
self.lm.alternating_swa = True
self.lm.residual_stream_fp32 = True
# StarCoder2
if arch_string == "Starcoder2ForCausalLM":
arch_recognized = True
self.layer_keys += \
self.lm.layer_keys += \
layer_keys_llama_norms + \
layer_keys_llama_attn + \
layer_keys_starcoder2_mlp
self.expect_keys += \
self.lm.expect_keys += \
expect_keys_starcoder2
self.norm_eps_key = "norm_epsilon"
self.mlp_key_up = ".mlp.c_fc"
self.mlp_key_down = ".mlp.c_proj"
self.lm_head_key = "model.embed_tokens"
self.norm_key_1 = ".input_layernorm"
self.norm_key_2 = ".post_attention_layernorm"
self.mlp_act_func = "gelu"
self.norm = "layernorm"
self.rope_style = RopeStyle.NEOX
self.attention_bias_qkv = True
self.attention_bias_o = True
self.mlp_bias = True
self.mlp_gate = False
self.lm.keys.update({
"mlp_gate": None,
"mlp_up": ".mlp.c_fc",
"mlp_down": ".mlp.c_proj",
"lm_head": "model.embed_tokens",
})
self.lm.mlp_act_func = "gelu"
self.lm.norm = "layernorm"
self.lm.attention_bias_qkv = True
self.lm.attention_bias_o = True
self.lm.mlp_bias = True
self.lm.mlp_gate = False
# GemMoE
if arch_string == "GemmoeForCausalLM":
arch_recognized = True
print(f" !! Warning, Gemmoe support is experimental and has not been fully tested")
self.layer_keys += \
self.lm.layer_keys += \
layer_keys_llama_norms + \
layer_keys_llama_attn + \
layer_keys_mixtral_mlp
self.expect_keys += \
self.lm.expect_keys += \
expect_keys_gemma
self.norm_eps_key = "rms_norm_eps"
self.mlp_key_gate = ".block_sparse_moe.experts.*.w1"
self.mlp_key_up = ".block_sparse_moe.experts.*.w3"
self.mlp_key_down = ".block_sparse_moe.experts.*.w2"
self.mlp_key_expert_gate = ".block_sparse_moe.gate"
self.lm_head_key = "model.embed_tokens"
self.norm_key_1 = ".input_layernorm"
self.norm_key_2 = ".post_attention_layernorm"
self.mlp_act_func = "gelu"
self.norm = "rmsnorm"
self.rope_style = RopeStyle.NEOX
self.normalize_embeddings = True
self.norm_constant_bias = 1
self.is_moe = True
self.requires_bos = True
self.lm.keys.update({
"mlp_gate": ".block_sparse_moe.experts.*.w1",
"mlp_up": ".block_sparse_moe.experts.*.w3",
"mlp_down": ".block_sparse_moe.experts.*.w2",
"mlp_expert_gate": ".block_sparse_moe.gate",
"lm_head": "model.embed_tokens",
})
self.lm.mlp_act_func = "gelu"
self.lm.normalize_embeddings = True
self.lm.norm_constant_bias = 1
self.lm.is_moe = True
self.lm.requires_bos = True
# Cohere
if arch_string == "CohereForCausalLM":
arch_recognized = True
self.layer_keys += \
self.lm.layer_keys += \
layer_keys_cohere_norms + \
layer_keys_llama_attn + \
layer_keys_llama_mlp
self.expect_keys += \
self.lm.expect_keys += \
expect_keys_gemma
self.norm_eps_key = "layer_norm_eps"
self.mlp_key_gate = ".mlp.gate_proj"
self.mlp_key_up = ".mlp.up_proj"
self.mlp_key_down = ".mlp.down_proj"
self.lm_head_key = "model.embed_tokens"
self.norm_key_1 = ".input_layernorm"
self.norm_key_2 = None
self.mlp_act_func = "silu"
self.norm = "layernorm"
self.rope_style = RopeStyle.GPTJ
self.parallel_decoder_blocks = True
self.requires_bos = True
self.lm.keys.update({
"norm_eps": "layer_norm_eps",
"lm_head": "model.embed_tokens",
"norm_1": ".input_layernorm",
"norm_2": None,
})
self.lm.norm = "layernorm"
self.lm.rope_style = RopeStyle.GPTJ
self.lm.parallel_decoder_blocks = True
self.lm.requires_bos = True
# DBRX
if arch_string == "DbrxForCausalLM":
arch_recognized = True
self.keymap = dbrx_keymap
self.layer_keys += \
self.lm.layer_keys += \
layer_keys_llama_norms + \
layer_keys_dbrx_attn + \
layer_keys_dbrx_mlp
self.expect_keys += \
self.lm.expect_keys += \
expect_keys_llama
self.norm_eps_key = None
self.mlp_key_gate = ".block_sparse_moe.experts.*.w1"
self.mlp_key_up = ".block_sparse_moe.experts.*.v1"
self.mlp_key_down = ".block_sparse_moe.experts.*.w2"
self.mlp_key_expert_gate = ".block_sparse_moe.gate"
self.lm_head_key = "lm_head"
self.norm_key_1 = ".input_layernorm"
self.norm_key_2 = ".post_attention_layernorm"
self.mlp_act_func = "silu"
self.norm = "layernorm"
self.rope_style = RopeStyle.NEOX
self.fused_qkv_key = "Wqkv"
self.is_moe = True
self.lm.keys.update({
"norm_eps": None,
"mlp_gate": ".block_sparse_moe.experts.*.w1",
"mlp_up": ".block_sparse_moe.experts.*.v1",
"mlp_down": ".block_sparse_moe.experts.*.w2",
"mlp_expert_gate": ".block_sparse_moe.gate",
"lm_head": "model.embed_tokens",
"fused_qkv": "Wqkv",
})
self.lm.is_moe = True
# Phi3
if arch_string == "Phi3ForCausalLM":
arch_recognized = True
self.layer_keys += \
self.lm.layer_keys += \
layer_keys_llama_norms + \
layer_keys_phi3_attn + \
layer_keys_phi3_mlp
self.expect_keys += \
self.lm.expect_keys += \
expect_keys_llama
self.norm_eps_key = "rms_norm_eps"
self.mlp_key_gate = ".mlp.gate_proj"
self.mlp_key_up = ".mlp.up_proj"
self.mlp_key_down = ".mlp.down_proj"
self.lm_head_key = "lm_head"
self.norm_key_1 = ".input_layernorm"
self.norm_key_2 = ".post_attention_layernorm"
self.fused_qkv_key = "qkv_proj"
self.fused_mlp_key_12 = "gate_up_proj"
self.mlp_act_func = "silu"
self.norm = "rmsnorm"
self.rope_style = RopeStyle.NEOX
self.lm.keys.update({
"fused_qkv": "qkv_proj",
"fused_mlp_12": "gate_up_proj",
})
# GPTBigCode
if arch_string == "GPTBigCodeForCausalLM":
arch_recognized = True
self.keymap = bigcode_keymap
self.layer_keys += \
self.lm.layer_keys += \
layer_keys_gpt2_norms + \
layer_keys_gpt2_attn + \
layer_keys_gpt2_mlp
self.expect_keys += \
self.lm.expect_keys += \
expect_keys_gpt2
self.norm_eps_key = "layer_norm_epsilon"
self.mlp_key_gate = None
self.mlp_key_up = ".mlp.c_fc"
self.mlp_key_down = ".mlp.c_proj"
self.lm_head_key = "model.embed_tokens"
self.norm_key_1 = ".ln_1"
self.norm_key_2 = ".ln_2"
self.fused_qkv_key = "c_attn"
self.learned_pos_emb_key = "model.wpe"
self.mlp_act_func = "gelu"
self.norm = "layernorm"
self.rope_style = RopeStyle.NONE
self.mqa = True
self.attention_bias_qkv = True
self.attention_bias_o = True
self.mlp_bias = True
self.mlp_gate = False
self.lm.keys.update({
"norm_eps": "layer_norm_epsilon",
"mlp_gate": None,
"mlp_up": ".mlp.c_fc",
"mlp_down": ".mlp.c_proj",
"lm_head": "model.embed_tokens",
"norm_1": ".ln_1",
"norm_2": ".ln_2",
"fused_qkv": "c_attn",
"learned_pos_emb": "model.wpe",
})
self.lm.mlp_act_func = "gelu"
self.lm.norm = "layernorm"
self.lm.rope_style = RopeStyle.NONE
self.lm.mqa = True
self.lm.attention_bias_qkv = True
self.lm.attention_bias_o = True
self.lm.mlp_bias = True
self.lm.mlp_gate = False
# GPT2
if arch_string == "GPT2LMHeadModel":
arch_recognized = True
self.keymap = gpt2_keymap
self.layer_keys += \
self.lm.layer_keys += \
layer_keys_gpt2_norms + \
layer_keys_gpt2_attn + \
layer_keys_gpt2_mlp
self.expect_keys += \
self.lm.expect_keys += \
expect_keys_gpt2
self.norm_eps_key = "layer_norm_epsilon"
self.mlp_key_gate = None
self.mlp_key_up = ".mlp.c_fc"
self.mlp_key_down = ".mlp.c_proj"
self.lm_head_key = "model.embed_tokens"
self.norm_key_1 = ".ln_1"
self.norm_key_2 = ".ln_2"
self.fused_qkv_key = "c_attn"
self.learned_pos_emb_key = "model.wpe"
self.mlp_act_func = "gelu"
self.norm = "layernorm"
self.rope_style = RopeStyle.NONE
self.default_inner_dim_mult = 4
self.lm.keys.update({
"norm_eps": "layer_norm_epsilon",
"mlp_gate": None,
"mlp_up": ".mlp.c_fc",
"mlp_down": ".mlp.c_proj",
"lm_head": "model.embed_tokens",
"norm_1": ".ln_1",
"norm_2": ".ln_2",
"fused_qkv": "c_attn",
"learned_pos_emb": "model.wpe",
})
self.lm.mlp_act_func = "gelu"
self.lm.norm = "layernorm"
self.lm.rope_style = RopeStyle.NONE
self.lm.default_inner_dim_mult = 4
self.lm.attention_bias_qkv = True
self.lm.attention_bias_o = True
self.lm.mlp_bias = True
self.lm.mlp_gate = False
self.orig_weights_transposed = True
self.attention_bias_qkv = True
self.attention_bias_o = True
self.mlp_bias = True
self.mlp_gate = False
# MiniCPM
if arch_string == "MiniCPMForCausalLM":
arch_recognized = True
self.layer_keys += \
self.lm.layer_keys += \
layer_keys_llama_norms + \
layer_keys_llama_attn + \
layer_keys_llama_mlp
self.expect_keys += \
self.lm.expect_keys += \
expect_keys_llama
self.norm_eps_key = "rms_norm_eps"
self.mlp_key_gate = ".mlp.gate_proj"
self.mlp_key_up = ".mlp.up_proj"
self.mlp_key_down = ".mlp.down_proj"
self.lm_head_key = "lm_head"
self.norm_key_1 = ".input_layernorm"
self.norm_key_2 = ".post_attention_layernorm"
self.mlp_act_func = "silu"
self.norm = "rmsnorm"
self.rope_style = RopeStyle.NEOX
self.logit_scale_basedim = True
self.lm.logit_scale_basedim = True
# InternLM2
if arch_string == "InternLM2ForCausalLM":
arch_recognized = True
self.keymap = internlm2_keymap
self.layer_keys += \
self.lm.layer_keys += \
layer_keys_internlm2_norms + \
layer_keys_internlm2_attn + \
layer_keys_internlm2_mlp
self.expect_keys += \
self.lm.expect_keys += \
expect_keys_llama
self.norm_eps_key = "rms_norm_eps"
self.mlp_key_gate = ".feed_forward.w1"
self.mlp_key_up = ".feed_forward.w3"
self.mlp_key_down = ".feed_forward.w2"
self.lm_head_key = "lm_head"
self.norm_key_1 = ".attention_norm"
self.norm_key_2 = ".ffn_norm"
self.fused_qkv_key = "wqkv"
self.mlp_act_func = "silu"
self.norm = "rmsnorm"
self.rope_style = RopeStyle.NEOX
self.fused_qkv_altpack = True
self.lm.keys.update({
"mlp_gate": ".feed_forward.w1",
"mlp_up": ".feed_forward.w3",
"mlp_down": ".feed_forward.w2",
"norm_1": ".attention_norm",
"norm_2": ".ffn_norm",
"fused_qkv": "wqkv",
})
self.lm.fused_qkv_altpack = True
# Index
if arch_string == "IndexForCausalLM":
arch_recognized = True
self.layer_keys += \
self.lm.layer_keys += \
layer_keys_llama_norms + \
layer_keys_llama_attn + \
layer_keys_llama_mlp
self.expect_keys += \
self.lm.expect_keys += \
expect_keys_llama
self.norm_eps_key = "rms_norm_eps"
self.mlp_key_gate = ".mlp.gate_proj"
self.mlp_key_up = ".mlp.up_proj"
self.mlp_key_down = ".mlp.down_proj"
self.lm_head_key = "lm_head"
self.norm_key_1 = ".input_layernorm"
self.norm_key_2 = ".post_attention_layernorm"
self.mlp_act_func = "silu"
self.norm = "rmsnorm"
self.rope_style = RopeStyle.NEOX
# Llama (default + fallback)
@@ -612,49 +568,41 @@ W
print(f" !! Loading as LlamaForCausalLM")
self.arch_string = "LlamaForCausalLM"
if not arch_recognized:
self.layer_keys += \
self.lm.layer_keys += \
layer_keys_llama_norms + \
layer_keys_llama_attn + \
layer_keys_llama_mlp
self.expect_keys += \
self.lm.expect_keys += \
expect_keys_llama
self.norm_eps_key = "rms_norm_eps"
self.mlp_key_gate = ".mlp.gate_proj"
self.mlp_key_up = ".mlp.up_proj"
self.mlp_key_down = ".mlp.down_proj"
self.lm_head_key = "lm_head"
self.norm_key_1 = ".input_layernorm"
self.norm_key_2 = ".post_attention_layernorm"
self.mlp_act_func = "silu"
self.norm = "rmsnorm"
self.rope_style = RopeStyle.NEOX
self.supports_tp = True
self.lm.supports_tp = True
# Arch overrides
if read_config.get("attention_bias", False):
self.attention_bias_qkv = True
self.attention_bias_o = True
self.lm.attention_bias_qkv = True
self.lm.attention_bias_o = True
if read_config.get("mlp_bias", False):
self.mlp_bias = True
self.lm.mlp_bias = True
if read_config.get("tie_word_embeddings", False):
if ["lm_head"] in self.expect_keys:
self.expect_keys.remove(["lm_head"])
self.lm_head_key = "model.embed_tokens"
if ["lm_head"] in self.lm.expect_keys:
self.lm.expect_keys.remove(["lm_head"])
self.lm.keys.update({
"lm_head": "model.embed_tokens",
})
# Sanity checks
if self.residual_stream_fp32:
assert self.norm_key_1_post and self.norm_key_2_post, \
if self.lm.residual_stream_fp32:
assert self.lm.keys["norm_1_post"] and self.lm.keys["norm_2_post"], \
"FP32 residual stream only implemented for arch with post layernorms"
def make_fused_mlp(self):
for x in layer_keys_llama_mlp: self.layer_keys.remove(x)
self.layer_keys += layer_keys_llama_mlp_swiglu
self.fused_mlp_key_12 = layer_keys_llama_mlp_swiglu[0][0]
self.fused_mlp_key_3 = layer_keys_llama_mlp_swiglu[1][0]
for x in layer_keys_llama_mlp: self.lm.layer_keys.remove(x)
self.lm.layer_keys += layer_keys_llama_mlp_swiglu
self.lm.keys.update({
"fused_mlp_12": layer_keys_llama_mlp_swiglu[0][0],
"fused_mlp_3": layer_keys_llama_mlp_swiglu[1][0],
})

View File

@@ -134,11 +134,15 @@ class ExLlamaV2Attention(ExLlamaV2Module):
layer_idx: int,
has_norm: bool = True,
has_residual: bool = True,
sliding_window: int = 0
sliding_window: int = 0,
archparams = None
):
super().__init__(model, key)
super().__init__(model, key, archparams)
cfg = self.model.config
ap = self.archparams
km = self.archparams.keys
self.is_tp = False
self.tp_dq_size = None
@@ -151,27 +155,28 @@ class ExLlamaV2Attention(ExLlamaV2Module):
hidden_size = cfg.hidden_size
if self.has_norm:
if cfg.arch.norm == "layernorm":
self.pre_layernorm = ExLlamaV2LayerNorm(model, key + cfg.arch.norm_key_1)
self.post_layernorm = ExLlamaV2LayerNorm(model, key + cfg.arch.norm_key_1_post) if cfg.arch.norm_key_1_post else None
elif cfg.arch.norm == "rmsnorm":
self.pre_layernorm = ExLlamaV2RMSNorm(model, key + cfg.arch.norm_key_1)
self.post_layernorm = ExLlamaV2RMSNorm(model, key + cfg.arch.norm_key_1_post) if cfg.arch.norm_key_1_post else None
if self.has_norm and (km["norm_1"] or km["norm_1_post"]):
if ap.norm == "layernorm":
self.pre_layernorm = ExLlamaV2LayerNorm(model, key + km["norm_1"], archparams)
self.post_layernorm = ExLlamaV2LayerNorm(model, key + km["norm_1_post"], archparams) if km["norm_1_post"] else None
elif ap.norm == "rmsnorm":
self.pre_layernorm = ExLlamaV2RMSNorm(model, key + km["norm_1"], archparams)
self.post_layernorm = ExLlamaV2RMSNorm(model, key + km["norm_1_post"], archparams) if km["norm_1_post"] else None
else:
self.pre_layernorm = None
self.post_layernorm = None
self.has_norm = False
f_a = 0
f_b = cfg.num_attention_heads * cfg.head_dim
f_c = f_b + cfg.num_key_value_heads * cfg.head_dim
f_d = f_c + cfg.num_key_value_heads * cfg.head_dim
f_key = (key + ".self_attn." + cfg.arch.fused_qkv_key) if cfg.arch.fused_qkv_key else None
f_key = (key + ".self_attn." + km["fused_qkv"]) if km["fused_qkv"] else None
self.q_proj = ExLlamaV2Linear(model, key + ".self_attn.q_proj", hidden_size, cfg.num_attention_heads * cfg.head_dim, cfg.arch.attention_bias_qkv, f_key = f_key, f_beg = f_a, f_end = f_b, altpack_qkv = cfg.arch.fused_qkv_altpack)
self.k_proj = ExLlamaV2Linear(model, key + ".self_attn.k_proj", hidden_size, cfg.num_key_value_heads * cfg.head_dim, cfg.arch.attention_bias_qkv, f_key = f_key, f_beg = f_b, f_end = f_c, altpack_qkv = cfg.arch.fused_qkv_altpack)
self.v_proj = ExLlamaV2Linear(model, key + ".self_attn.v_proj", hidden_size, cfg.num_key_value_heads * cfg.head_dim, cfg.arch.attention_bias_qkv, f_key = f_key, f_beg = f_c, f_end = f_d, altpack_qkv = cfg.arch.fused_qkv_altpack)
self.o_proj = ExLlamaV2Linear(model, key + ".self_attn.o_proj", cfg.num_attention_heads * cfg.head_dim, hidden_size, cfg.arch.attention_bias_o, prescale = cfg.scale_depth)
self.q_proj = ExLlamaV2Linear(model, key + ".self_attn.q_proj", hidden_size, cfg.num_attention_heads * cfg.head_dim, ap.attention_bias_qkv, f_key = f_key, f_beg = f_a, f_end = f_b, altpack_qkv = ap.fused_qkv_altpack)
self.k_proj = ExLlamaV2Linear(model, key + ".self_attn.k_proj", hidden_size, cfg.num_key_value_heads * cfg.head_dim, ap.attention_bias_qkv, f_key = f_key, f_beg = f_b, f_end = f_c, altpack_qkv = ap.fused_qkv_altpack)
self.v_proj = ExLlamaV2Linear(model, key + ".self_attn.v_proj", hidden_size, cfg.num_key_value_heads * cfg.head_dim, ap.attention_bias_qkv, f_key = f_key, f_beg = f_c, f_end = f_d, altpack_qkv = ap.fused_qkv_altpack)
self.o_proj = ExLlamaV2Linear(model, key + ".self_attn.o_proj", cfg.num_attention_heads * cfg.head_dim, hidden_size, ap.attention_bias_o, prescale = cfg.scale_depth)
if cfg.use_qk_norm:
self.q_norm = ExLlamaV2HeadNorm(model, key + ".self_attn.q_norm", cfg.num_attention_heads, cfg.head_dim)
@@ -180,10 +185,12 @@ class ExLlamaV2Attention(ExLlamaV2Module):
self.q_norm = None
self.k_norm = None
self.submodules = [self.q_proj,
self.k_proj,
self.v_proj,
self.o_proj]
self.submodules = [
self.q_proj,
self.k_proj,
self.v_proj,
self.o_proj
]
if self.pre_layernorm:
self.submodules += [self.pre_layernorm]
if self.post_layernorm:
@@ -294,12 +301,12 @@ class ExLlamaV2Attention(ExLlamaV2Module):
cfg.head_dim,
cfg.max_seq_len,
self.has_residual,
cfg.arch.rope_style.value,
self.archparams.rope_style.value,
q_norm,
k_norm,
post_norm_weight,
post_norm_bias,
cfg.arch.residual_stream_fp32,
self.archparams.residual_stream_fp32,
not cfg.no_graphs
)
@@ -522,7 +529,7 @@ class ExLlamaV2Attention(ExLlamaV2Module):
if cfg.use_qk_norm:
q = self.q_norm.forward(q)
k = self.k_norm.forward(k)
if cfg.arch.rope_style != RopeStyle.NONE:
if self.archparams.rope_style != RopeStyle.NONE:
for t, heads in [(q, cfg.num_attention_heads), (k, cfg.num_key_value_heads)]:
ext_c.rope_(
t,
@@ -532,7 +539,7 @@ class ExLlamaV2Attention(ExLlamaV2Module):
heads,
cfg.head_dim,
cache_seqlens,
cfg.arch.rope_style == RopeStyle.NEOX
self.archparams.rope_style == RopeStyle.NEOX
)
if attn_params.is_sequential:
k_ = k_cache_f[:, attn_params.first_index : attn_params.first_index + q_len, :, :]
@@ -659,7 +666,7 @@ class ExLlamaV2Attention(ExLlamaV2Module):
self.v_proj.q_handle,
self.o_proj.q_handle,
cfg.head_dim,
int(cfg.arch.rope_style),
int(self.archparams.rope_style),
batch_size,
q_len,
sin,
@@ -728,7 +735,7 @@ class ExLlamaV2Attention(ExLlamaV2Module):
assert False, "TP not implemented for QK norm" # TODO: ...
# q = self.q_norm.forward(q)
# k = self.k_norm.forward(k)
if cfg.arch.rope_style != RopeStyle.NONE:
if self.archparams.rope_style != RopeStyle.NONE:
for idx, (dev, a, b) in enumerate(split):
context = self.model.get_device_context(dev)
torch.cuda.set_stream(context.stream)
@@ -741,7 +748,7 @@ class ExLlamaV2Attention(ExLlamaV2Module):
(b - a) * heads,
cfg.head_dim,
attn_params.cache_seqlens_tp[idx],
cfg.arch.rope_style == RopeStyle.NEOX
self.archparams.rope_style == RopeStyle.NEOX
)
if attn_params.is_sequential:
k_ = [x[:, attn_params.first_index: attn_params.first_index + q_len, :, :] for x in k_cache_f]
@@ -1113,7 +1120,7 @@ class ExLlamaV2Attention(ExLlamaV2Module):
pass_lora_temp
)
if cfg.arch.clamp_hidden_states:
if self.archparams.clamp_hidden_states:
hidden_states.clamp_(-65504, 65504)
return hidden_states
@@ -1166,7 +1173,7 @@ class ExLlamaV2Attention(ExLlamaV2Module):
self.v_proj.q_handle,
self.o_proj.q_handle,
cfg.head_dim,
int(cfg.arch.rope_style),
int(self.archparams.rope_style),
batch_size,
q_len,
sin,
@@ -1223,7 +1230,7 @@ class ExLlamaV2Attention(ExLlamaV2Module):
else:
k_cache, v_cache = None, None
if cfg.arch.rope_style != RopeStyle.NONE:
if self.archparams.rope_style != RopeStyle.NONE:
for idx, (dev, a, b) in enumerate(split):
context = self.model.get_device_context(dev)
torch.cuda.set_stream(context.stream)
@@ -1236,7 +1243,7 @@ class ExLlamaV2Attention(ExLlamaV2Module):
(b - a) * heads,
cfg.head_dim,
attn_params.position_offsets_tp[idx] if attn_params.position_offsets is not None else none_tensor,
cfg.arch.rope_style == RopeStyle.NEOX
self.archparams.rope_style == RopeStyle.NEOX
)
attn_outputs = []
@@ -1370,9 +1377,9 @@ class ExLlamaV2Attention(ExLlamaV2Module):
else:
position_offsets = none_tensor
if cfg.arch.rope_style != RopeStyle.NONE:
ext_c.rope_(query_states, constants.sin, constants.cos, past_len, num_attention_heads, head_dim, position_offsets, cfg.arch.rope_style == RopeStyle.NEOX)
ext_c.rope_(key_states, constants.sin, constants.cos, past_len, num_key_value_heads, head_dim, position_offsets, cfg.arch.rope_style == RopeStyle.NEOX)
if self.archparams.rope_style != RopeStyle.NONE:
ext_c.rope_(query_states, constants.sin, constants.cos, past_len, num_attention_heads, head_dim, position_offsets, self.archparams.rope_style == RopeStyle.NEOX)
ext_c.rope_(key_states, constants.sin, constants.cos, past_len, num_key_value_heads, head_dim, position_offsets, self.archparams.rope_style == RopeStyle.NEOX)
# Add keys and values to cache
@@ -1417,15 +1424,15 @@ class ExLlamaV2Attention(ExLlamaV2Module):
# Post layernorm
if self.post_layernorm:
attn_proj = self.post_layernorm.forward(attn_proj, output_fp32 = cfg.arch.residual_stream_fp32)
attn_proj = self.post_layernorm.forward(attn_proj, output_fp32 = self.archparams.residual_stream_fp32)
# Add residual connection
hidden_states = (attn_proj + residual) if self.has_residual else attn_proj
if cfg.arch.residual_stream_fp32:
if self.archparams.residual_stream_fp32:
hidden_states = hidden_states.float()
elif cfg.arch.clamp_hidden_states:
elif self.archparams.clamp_hidden_states:
hidden_states.clamp_(-65504, 65504)
if intermediates:

View File

@@ -10,12 +10,21 @@ from typing import Any, Dict, List, TypeVar, Union, cast
T = TypeVar('T')
no_default = object()
def read(input_dict: dict[str, Any], expected_type: type | list[type], keys: str | list[str], default = no_default) -> T:
def read(
input_dict: dict[str, Any],
expected_type: type | list[type],
keys: str | list[str],
default = no_default,
opt_subkey: str | None = None
) -> T:
expected_types = expected_type if isinstance(expected_type, list) else [expected_type]
if isinstance(keys, str): keys = [keys]
if opt_subkey is not None:
keys = keys + [opt_subkey + "->" + k for k in keys]
for key in keys:
input_dict_s = input_dict
@@ -115,6 +124,13 @@ class ExLlamaV2Config:
yarn_rope_original_max_position_embeddings: int | None
checkpoint_fused_mlp: bool
checkpoint_offset_qzeros: bool
vision_model_type: str | None
vision_head_dim: int | None
vision_hidden_size: int | None
vision_hidden_act: int | None
vision_patch_size: int | None
vision_rope_theta: float | None
vision_feature_layer: int | None
# Deprecated fields, kept for compatibility
@@ -204,36 +220,51 @@ class ExLlamaV2Config:
self.bos_token_id = read(read_config, int, "bos_token_id", None) # 1
self.eos_token_id = read(read_config, [int, list], "eos_token_id", None) # 2
self.pad_token_id = read(read_config, int, "pad_token_id", None) # 0
self.vocab_size = read(read_config, int, "vocab_size")
self.vocab_size = read(read_config, int, "vocab_size", opt_subkey = "text_config")
if isinstance(self.eos_token_id, list):
self.eos_token_id = self.eos_token_id[0] # TODO: Figure out a way to maybe use all the EOS tokens somehow
# Standard params
self.initializer_range = read(read_config, float, ["initializer_range"])
self.num_hidden_layers = read(read_config, int, ["num_hidden_layers", "n_layers", "n_layer"])
self.initializer_range = read(read_config, float, ["initializer_range"], 0.02)
self.num_hidden_layers = read(read_config, int, ["num_hidden_layers", "n_layers", "n_layer"], opt_subkey = "text_config")
# Norm params
if self.arch.norm_eps_key:
self.norm_eps = read(read_config, float, self.arch.norm_eps_key)
if self.arch.lm.keys["norm_eps"]:
self.norm_eps = read(read_config, float, self.arch.lm.keys["norm_eps"], opt_subkey = "text_config")
else:
self.norm_eps = 1e-5 # Torch default
# Model dimensions
self.hidden_size = read(read_config, int, ["hidden_size", "d_model", "n_embd"])
self.hidden_size = read(read_config, int, ["hidden_size", "d_model", "n_embd"], opt_subkey = "text_config")
# Attn params
self.num_attention_heads = read(read_config, int, ["num_attention_heads", "n_heads", "n_head"])
self.head_dim = read(read_config, int, "head_dim", self.hidden_size // self.num_attention_heads)
self.num_attention_heads = read(read_config, int, ["num_attention_heads", "n_heads", "n_head"], 0, opt_subkey = "text_config")
self.head_dim = read(
read_config,
int,
"head_dim",
(self.hidden_size // self.num_attention_heads) if self.num_attention_heads else no_default,
opt_subkey = "text_config"
)
if self.arch.mqa:
if not self.num_attention_heads:
self.num_attention_heads = self.hidden_size // self.head_dim
if self.arch.lm.mqa:
self.num_key_value_heads = 1
else:
self.num_key_value_heads = read(read_config, int, ["num_key_value_heads", "attn_config->kv_n_heads"], self.num_attention_heads)
self.num_key_value_heads = read(
read_config,
int,
["num_key_value_heads", "attn_config->kv_n_heads"],
self.num_attention_heads,
opt_subkey = "text_config",
)
self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads
self.use_qk_norm = read(read_config, bool, ["use_qk_norm"], False)
@@ -241,19 +272,25 @@ class ExLlamaV2Config:
# MLP params
if self.arch.default_inner_dim_mult is not None:
default_intermediate_size = self.arch.default_inner_dim_mult * self.hidden_size
if self.arch.lm.default_inner_dim_mult is not None:
default_intermediate_size = self.arch.lm.default_inner_dim_mult * self.hidden_size
else:
default_intermediate_size = no_default
self.intermediate_size = read(read_config, int, ["intermediate_size", "ffn_config->ffn_hidden_size", "n_inner"], default_intermediate_size)
self.intermediate_size = read(
read_config,
int,
["intermediate_size", "ffn_config->ffn_hidden_size", "n_inner"],
default_intermediate_size,
opt_subkey = "text_config",
)
self.num_experts = read(read_config, int, ["num_local_experts", "ffn_config->moe_num_experts"], None)
self.num_experts_per_token = read(read_config, int,["num_experts_per_tok", "ffn_config->moe_top_k"], None)
# Logit/embedding/residual scale
self.logit_scale = read(read_config, float, "logit_scale", 1)
if self.arch.logit_scale_basedim:
if self.arch.lm.logit_scale_basedim:
dim_model_base = read(read_config, int, "dim_model_base", self.hidden_size)
self.logit_scale /= (self.hidden_size / dim_model_base)
@@ -273,16 +310,24 @@ class ExLlamaV2Config:
# Positional embeddings
self.rotary_embedding_base = read(read_config, float, ["rope_theta", "attn_config->rope_theta"], 10000.0)
self.rotary_embedding_base = read(
read_config,
float,
["rope_theta", "attn_config->rope_theta"],
10000.0,
opt_subkey = "text_config",
)
self.max_seq_len = read(read_config, int,["max_sequence_length",
"model_max_length",
"max_position_embeddings",
"max_seq_len",
"n_positions"], 2048)
self.max_seq_len = read(
read_config,
int,
["max_sequence_length", "model_max_length", "max_position_embeddings", "max_seq_len", "n_positions"],
2048,
opt_subkey = "text_config"
)
self.original_max_seq_len = self.max_seq_len
self.sliding_window = read(read_config, int, ["sliding_window", "sliding_window_size"], 0)
self.sliding_window = read(read_config, int, ["sliding_window", "sliding_window_size"], 0, opt_subkey = "text_config")
rs = read(read_config, dict, "rope_scaling", None)
if rs:
@@ -345,37 +390,51 @@ class ExLlamaV2Config:
# Make sure we found all the layers we need
expect_keys = self.arch.expect_keys.copy()
def check_keys(archparams, prefix):
if not self.num_experts or self.num_experts == 1:
per_layer_keys = self.arch.layer_keys
else:
per_layer_keys = set()
for expert_idx in range(self.num_experts):
for k in self.arch.layer_keys:
skt = [sk.replace(".*.", f".{expert_idx}.") for sk in k]
per_layer_keys.add(tuple(skt))
per_layer_keys = list(per_layer_keys)
expect_keys = archparams.expect_keys.copy()
for layer_idx in range(self.num_hidden_layers):
for ks in per_layer_keys:
prefixes = [f"model.layers.{layer_idx}.{k}" for k in ks]
expect_keys.append(prefixes)
if not self.num_experts or self.num_experts == 1:
per_layer_keys = archparams.layer_keys
else:
per_layer_keys = set()
for expert_idx in range(self.num_experts):
for k in archparams.layer_keys:
skt = [sk.replace(".*.", f".{expert_idx}.") for sk in k]
per_layer_keys.add(tuple(skt))
per_layer_keys = list(per_layer_keys)
all_keys = set(self.tensor_file_map.keys())
suffixes = [".q_weight", ".qweight", ".weight", ""]
for layer_idx in range(self.num_hidden_layers):
for ks in per_layer_keys:
prefixes = [f"model.layers.{layer_idx}.{k}" for k in ks]
expect_keys.append(prefixes)
for prefixes in expect_keys:
match = False
for prefix in prefixes:
for suffix in suffixes:
if (prefix + suffix) in all_keys:
match = True
break
if self.arch.lm_prefix:
expect_keys = [
[prefix + k for k in k2]
for k2 in expect_keys
]
all_keys = set(self.tensor_file_map.keys())
suffixes = [".q_weight", ".qweight", ".weight", ""]
for prefixes in expect_keys:
match = False
for prefix in prefixes:
for suffix in suffixes:
if (prefix + suffix) in all_keys:
match = True
break
if match: break
if match: break
if match: break
if not match:
raise ValueError(f" ## Could not find {prefix}.* in model")
if not match:
raise ValueError(f" ## Could not find {prefix}.* in model")
check_keys(self.arch.lm, self.arch.lm_prefix)
check_keys(self.arch.mmp, self.arch.mmp_prefix)
check_keys(self.arch.vt, self.arch.vt_prefix)
# Cleanup
cleanup_stfiles()
@@ -391,7 +450,7 @@ class ExLlamaV2Config:
warnings = []
if self.arch.eager_attn_only:
if self.arch.lm.eager_attn_only:
warnings.append(" !! Warning: Architecture currently supports only eager attention")
if not warn_only:
warnings.append(" !! Warning: flash-attn, xformers and SDPA are disabled")
@@ -406,7 +465,7 @@ class ExLlamaV2Config:
if self.attn_logit_softcapping and not has_flash_attn_with_softcap:
warnings.append(" !! Warning: model requires softcap, not supported in installed version of flash-attn")
disable = True
if (self.arch.swa or self.arch.alternating_swa) and not has_flash_attn_with_window:
if (self.arch.lm.swa or self.arch.lm.alternating_swa) and not has_flash_attn_with_window:
warnings.append(" !! Warning: model requires SWA, not supported in installed version of flash-attn")
disable = True
if disable and not warn_only:
@@ -418,7 +477,7 @@ class ExLlamaV2Config:
if self.attn_logit_softcapping:
warnings.append(" !! Warning: model requires softcap, not supported in xformers")
disable = True
if self.arch.swa or self.arch.alternating_swa:
if self.arch.lm.swa or self.arch.lm.alternating_swa:
warnings.append(" !! Warning: model requires SWA, not supported in xformers")
disable = True
if disable and not warn_only:

View File

@@ -90,7 +90,7 @@ def compile_model(job, save_fn, model):
if isinstance(module, ExLlamaV2MLP):
has_gate = model.config.arch.mlp_gate
has_gate = model.config.arch.lm.mlp_gate
d = get_f_module(job, module.pre_layernorm)
if d: out_dict.update(d); current_size += _dsize(d)
d = get_f_module(job, module.post_layernorm)
@@ -110,7 +110,7 @@ def compile_model(job, save_fn, model):
if isinstance(module, ExLlamaV2ParallelDecoder):
has_gate = model.config.arch.mlp_gate
has_gate = model.config.arch.lm.mlp_gate
has_qk_norm = model.config.use_qk_norm
d = get_f_module(job, module.input_layernorm); out_dict.update(d); current_size += _dsize(d)
d = get_q_module(job, module.attn.q_proj); out_dict.update(d); current_size += _dsize(d)

View File

@@ -208,7 +208,7 @@ model.load(lazy = True)
# Limit context length if necessary
if model.config.arch.rope_style == RopeStyle.NONE:
if model.config.arch.lm.rope_style == RopeStyle.NONE:
max_ctx = model.config.max_seq_len
if job["length"] > max_ctx:
print (f" !! Warning: Reducing calibration length to model max context: {max_ctx}")

View File

@@ -204,7 +204,7 @@ def measure_attn(module, hidden_states, target_states, quantizers, cache, attn_p
def measure_mlp(module, hidden_states, target_states, quantizers, cache, attn_params, reuse_h_up_proj = None):
has_gate = module.model.config.arch.mlp_gate
has_gate = module.model.config.arch.lm.mlp_gate
qjobs, qmaps = get_qparams_reduced(qparams_mlp, not has_gate)
results = []
@@ -490,7 +490,7 @@ def measure_quant(job, save_fn, model, hidden_state_offload_layers):
elif isinstance(module, ExLlamaV2MLP):
mode = "mlp"
has_gate = module.model.config.arch.mlp_gate
has_gate = module.model.config.arch.lm.mlp_gate
if has_gate: quantizers["gate_proj"] = AdaptiveGPTQ(module.gate_proj.linear)
quantizers["up_proj"] = AdaptiveGPTQ(module.up_proj.linear)
quantizers["down_proj"] = AdaptiveGPTQ(module.down_proj.linear)
@@ -508,7 +508,7 @@ def measure_quant(job, save_fn, model, hidden_state_offload_layers):
quantizers["k_proj"] = AdaptiveGPTQ(module.attn.k_proj.linear)
quantizers["v_proj"] = AdaptiveGPTQ(module.attn.v_proj.linear)
quantizers["o_proj"] = AdaptiveGPTQ(module.attn.o_proj.linear)
has_gate = module.model.config.arch.mlp_gate
has_gate = module.model.config.arch.lm.mlp_gate
if has_gate: quantizers["gate_proj"] = AdaptiveGPTQ(module.mlp.gate_proj.linear)
quantizers["up_proj"] = AdaptiveGPTQ(module.mlp.up_proj.linear)
quantizers["down_proj"] = AdaptiveGPTQ(module.mlp.down_proj.linear)

View File

@@ -9,7 +9,7 @@ def optimize(job, save_fn, model):
cfg = model.config
has_gate = cfg.arch.mlp_gate
has_gate = cfg.arch.lm.mlp_gate
if has_gate: mlp_key_gate = cfg.arch.mlp_key_gate
mlp_key_up = cfg.arch.mlp_key_up
mlp_key_down = cfg.arch.mlp_key_down
@@ -38,7 +38,7 @@ def optimize(job, save_fn, model):
key_v = key + ".self_attn.v_proj"
key_o = key + ".self_attn.o_proj"
if not cfg.arch.is_moe:
if not cfg.arch.lm.is_moe:
if has_gate: key_g = key + mlp_key_gate
key_u = key + mlp_key_up
key_d = key + mlp_key_down
@@ -83,7 +83,7 @@ def optimize(job, save_fn, model):
params = []
for i in range(num_layers):
if cfg.arch.parallel_decoder_blocks:
if cfg.arch.lm.parallel_decoder_blocks:
m1 = measurement["model.layers." + str(i) + ".parallel_decoder"]["attn"]
m2 = measurement["model.layers." + str(i) + ".parallel_decoder"]["mlp"]
else:

View File

@@ -153,7 +153,7 @@ def quant_attn(job, module, hidden_states, target_states, quantizers, attn_param
def quant_mlp(job, module, hidden_states, target_states, quantizers, attn_params, strat, reuse_h_up_proj = None):
has_mlp = module.model.config.arch.mlp_gate
has_mlp = module.model.config.arch.lm.mlp_gate
if reuse_h_up_proj is not None:
quantizers["up_proj"].reuse_h(quantizers[reuse_h_up_proj])
@@ -311,7 +311,7 @@ def quant(job, save_fn, model):
elif isinstance(module, ExLlamaV2MLP):
mode = "mlp"
has_mlp = model.config.arch.mlp_gate
has_mlp = model.config.arch.lm.mlp_gate
# testc(module, hidden_states, hidden_i_states, module.post_attention_layernorm, [module.gate_proj, module.up_proj])
if has_mlp: quantizers["gate_proj"] = AdaptiveGPTQ(module.gate_proj.linear)
quantizers["up_proj"] = AdaptiveGPTQ(module.up_proj.linear)
@@ -341,7 +341,7 @@ def quant(job, save_fn, model):
quantizers["k_proj"] = AdaptiveGPTQ(module.attn.k_proj.linear)
quantizers["v_proj"] = AdaptiveGPTQ(module.attn.v_proj.linear)
quantizers["o_proj"] = AdaptiveGPTQ(module.attn.o_proj.linear)
has_gate = module.model.config.arch.mlp_gate
has_gate = module.model.config.arch.lm.mlp_gate
if has_gate: quantizers["gate_proj"] = AdaptiveGPTQ(module.mlp.gate_proj.linear)
quantizers["up_proj"] = AdaptiveGPTQ(module.mlp.up_proj.linear)
quantizers["down_proj"] = AdaptiveGPTQ(module.mlp.down_proj.linear)

View File

@@ -50,7 +50,8 @@ class ExLlamaV2DeviceContext:
self,
model: ExLlamaV2,
device_idx: int,
scratch_bytes: int
scratch_bytes: int,
archparams = None
):
self.model = model
self.device_idx = device_idx
@@ -58,6 +59,7 @@ class ExLlamaV2DeviceContext:
self.scratch = None
self.scratch_bytes = scratch_bytes
self.scratch_idx = 0
self.archparams = archparams or model.config.arch.lm
# Create streams (only one per device)
@@ -140,7 +142,7 @@ class ExLlamaV2DeviceContext:
device = _torch_device(self.device_idx)
cfg = self.model.config
if cfg.arch.rope_style == RopeStyle.NONE:
if self.archparams.rope_style == RopeStyle.NONE:
self.sin = torch.zeros((1,), device = device, dtype = torch.half)
self.cos = self.sin
return
@@ -254,9 +256,9 @@ class ExLlamaV2DeviceContext:
if scale != 1.0: t /= scale
freqs = torch.einsum("i,j->ij", t, inv_freq)
if cfg.arch.rope_style == RopeStyle.NEOX:
if self.archparams.rope_style == RopeStyle.NEOX:
emb = torch.cat((freqs, freqs), dim=-1)
elif cfg.arch.rope_style == RopeStyle.GPTJ:
elif self.archparams.rope_style == RopeStyle.GPTJ:
emb = torch.repeat_interleave(freqs, 2, dim=-1)
else:
raise ValueError()

View File

@@ -23,9 +23,10 @@ class ExLlamaV2Embedding(ExLlamaV2Module):
def __init__(
self,
model: ExLlamaV2,
key: str
key: str,
archparams = None
):
super().__init__(model, key)
super().__init__(model, key, archparams)
self.is_tp = False
@@ -156,9 +157,9 @@ class ExLlamaV2Embedding(ExLlamaV2Module):
# Normalization
if cfg.arch.residual_stream_fp32:
if self.archparams.residual_stream_fp32:
combined_embeddings = combined_embeddings.float()
if cfg.arch.normalize_embeddings:
if self.archparams.normalize_embeddings:
combined_embeddings *= cfg.hidden_size ** 0.5
# Extract indexed embeddings and insert in-place
@@ -179,9 +180,9 @@ class ExLlamaV2Embedding(ExLlamaV2Module):
else:
hidden_states = self.embedding(hidden_states)
if cfg.arch.residual_stream_fp32:
if self.archparams.residual_stream_fp32:
hidden_states = hidden_states.float()
if cfg.arch.normalize_embeddings:
if self.archparams.normalize_embeddings:
hidden_states *= cfg.hidden_size ** 0.5
# Move to pinned temp buffer for TP

View File

@@ -26,9 +26,10 @@ class ExLlamaV2HeadNorm(ExLlamaV2Module):
model: ExLlamaV2,
key: str,
num_heads: int,
head_dim: int
head_dim: int,
archparams = None
):
super().__init__(model, key)
super().__init__(model, key, archparams)
self.layernorm = None
self.weight = None

View File

@@ -21,9 +21,10 @@ class ExLlamaV2LayerNorm(ExLlamaV2Module):
def __init__(
self,
model: ExLlamaV2,
key: str
key: str,
archparams = None
):
super().__init__(model, key)
super().__init__(model, key, archparams)
self.layernorm = None
self.weight = None

View File

@@ -67,9 +67,10 @@ class ExLlamaV2Linear(ExLlamaV2Module):
f_end: int = None,
is_sub_module: bool = True,
altpack_qkv: bool = False,
normalize_unq: bool = False
normalize_unq: bool = False,
archparams = None
):
super().__init__(model, key)
super().__init__(model, key, archparams)
self.is_sub_module = is_sub_module

View File

@@ -5,6 +5,7 @@ import os, json
from safetensors.torch import load_file as safe_load_file
from torch import load as load_file
import torch
import math
from exllamav2.compat import safe_move_tensor
from typing import TYPE_CHECKING
@@ -58,7 +59,7 @@ class ExLlamaV2Lora:
# Compatibility check
assert not self.model.config.arch.residual_stream_fp32, \
assert not self.model.config.arch.lm.residual_stream_fp32, \
"LoRAs not (yet) supported for models with FP32 residual stream"
# Grab relevant items from LoRA config
@@ -88,7 +89,7 @@ class ExLlamaV2Lora:
tensor = f[key]
# Find target
if key.endswith(f'{self.config.arch.lm_head_key}.weight'):
if key.endswith(f'{self.config.arch.lm.keys["lm_head"]}.weight'):
if tensor.dtype == torch.bfloat16:
tensor = tensor.to(torch.float16)
elif tensor.dtype == torch.float32:

View File

@@ -38,15 +38,29 @@ class ExLlamaV2MLP(ExLlamaV2Module):
tp_dq_size: list[int] | None
def __init__(self,
model: ExLlamaV2,
key: str,
layer_idx: int,
has_norm: bool = True,
has_residual: bool = True):
super().__init__(model, key)
def __init__(
self,
model: ExLlamaV2,
key: str,
layer_idx: int,
has_norm: bool = True,
has_residual: bool = True,
archparams = None,
in_features: int | None = None,
out_features: int | None = None,
interm_features: int | None = None,
):
super().__init__(model, key, archparams)
cfg = self.model.config
ap = self.archparams
km = self.archparams.keys
if in_features is None: in_features = cfg.hidden_size
if out_features is None: out_features = cfg.hidden_size
if interm_features is None: interm_features = cfg.intermediate_size
self.in_features = in_features
self.out_features = out_features
self.interm_features = interm_features
self.is_tp = False
self.tp_dq_size = None
@@ -59,23 +73,23 @@ class ExLlamaV2MLP(ExLlamaV2Module):
self.temp_lora_size = 0
f_a = 0
f_b = cfg.intermediate_size
f_c = f_b + cfg.intermediate_size
f_key = (key + ".mlp." + cfg.arch.fused_mlp_key_12) if cfg.arch.fused_mlp_key_12 else None
f_b = interm_features
f_c = f_b + interm_features
f_key = (key + ".mlp." + km["fused_mlp_12"]) if km["fused_mlp_12"] else None
if self.has_norm:
if cfg.arch.norm == "layernorm":
self.pre_layernorm = ExLlamaV2LayerNorm(model, key + cfg.arch.norm_key_2)
self.post_layernorm = ExLlamaV2LayerNorm(model, key + cfg.arch.norm_key_2_post) if cfg.arch.norm_key_2_post else None
elif cfg.arch.norm == "rmsnorm":
self.pre_layernorm = ExLlamaV2RMSNorm(model, key + cfg.arch.norm_key_2)
self.post_layernorm = ExLlamaV2RMSNorm(model, key + cfg.arch.norm_key_2_post) if cfg.arch.norm_key_2_post else None
if self.has_norm and (km["norm_2"] or km["norm_2_post"]):
if ap.norm == "layernorm":
self.pre_layernorm = ExLlamaV2LayerNorm(model, key + km["norm_2"])
self.post_layernorm = ExLlamaV2LayerNorm(model, key + km["norm_2_post"]) if km["norm_2_post"] else None
elif ap.norm == "rmsnorm":
self.pre_layernorm = ExLlamaV2RMSNorm(model, key + km["norm_2"])
self.post_layernorm = ExLlamaV2RMSNorm(model, key + km["norm_2_post"]) if km["norm_2_post"] else None
else:
self.pre_layernorm = None
self.post_layernorm = None
self.up_proj = ExLlamaV2Linear(model, key + cfg.arch.mlp_key_up, cfg.hidden_size, cfg.intermediate_size, self.model.config.arch.mlp_bias, f_key = f_key, f_beg = f_b, f_end = f_c)
self.down_proj = ExLlamaV2Linear(model, key + cfg.arch.mlp_key_down, cfg.intermediate_size, cfg.hidden_size, self.model.config.arch.mlp_bias, prescale = cfg.scale_depth)
self.up_proj = ExLlamaV2Linear(model, key + km["mlp_up"], in_features, interm_features, ap.mlp_bias, f_key = f_key, f_beg = f_b, f_end = f_c)
self.down_proj = ExLlamaV2Linear(model, key + km["mlp_down"], interm_features, out_features, ap.mlp_bias, prescale = cfg.scale_depth)
self.submodules = [self.up_proj,
self.down_proj]
@@ -84,8 +98,8 @@ class ExLlamaV2MLP(ExLlamaV2Module):
if self.post_layernorm:
self.submodules += [self.post_layernorm]
if cfg.arch.mlp_gate:
self.gate_proj = ExLlamaV2Linear(model, key + cfg.arch.mlp_key_gate, cfg.hidden_size, cfg.intermediate_size, self.model.config.arch.mlp_bias, f_key = f_key, f_beg = f_a, f_end = f_b)
if ap.mlp_gate:
self.gate_proj = ExLlamaV2Linear(model, key + km["mlp_gate"], in_features, interm_features, ap.mlp_bias, f_key = f_key, f_beg = f_a, f_end = f_b)
self.submodules += [self.gate_proj]
else:
self.gate_proj = None
@@ -96,7 +110,7 @@ class ExLlamaV2MLP(ExLlamaV2Module):
numel = self.up_proj.numel() + \
self.down_proj.numel()
if self.model.config.arch.mlp_gate:
if self.archparams.arch.mlp_gate:
numel += self.gate_proj.numel()
if self.pre_layernorm is not None:
@@ -113,6 +127,7 @@ class ExLlamaV2MLP(ExLlamaV2Module):
device_context: bool = True
):
cfg = self.model.config
km = self.archparams.keys
if self.pre_layernorm is not None:
self.pre_layernorm.load()
@@ -120,10 +135,10 @@ class ExLlamaV2MLP(ExLlamaV2Module):
self.post_layernorm.load()
if cfg.checkpoint_fused_mlp:
w12 = self.load_weight(self.key + cfg.arch.fused_mlp_key_12)
w1 = nn.Parameter(w12[:cfg.intermediate_size, :].contiguous())
w2 = nn.Parameter(w12[cfg.intermediate_size:, :].contiguous())
w3 = self.load_weight(self.key + cfg.arch.fused_mlp_key_3)
w12 = self.load_weight(self.key + km["fused_mlp_12"])
w1 = nn.Parameter(w12[:self.interm_features, :].contiguous())
w2 = nn.Parameter(w12[self.interm_features:, :].contiguous())
w3 = self.load_weight(self.key + km["fused_mlp_3"])
self.down_proj.load(w3, device_context = device_context)
self.gate_proj.load(w1, device_context = device_context)
self.up_proj.load(w2, device_context = device_context)
@@ -180,11 +195,11 @@ class ExLlamaV2MLP(ExLlamaV2Module):
temp_b,
temp_dq,
cfg.max_input_len * cfg.max_batch_size,
cfg.arch.mlp_act_func == "gelu",
self.archparams.mlp_act_func == "gelu",
self.has_residual,
post_norm_weight,
post_norm_bias,
cfg.arch.residual_stream_fp32,
self.archparams.residual_stream_fp32,
not cfg.no_graphs
)
@@ -205,7 +220,8 @@ class ExLlamaV2MLP(ExLlamaV2Module):
def weight_footprint(self) -> int:
if self.model.config.checkpoint_fused_mlp:
fp = 3 * self.model.config.intermediate_size * self.model.config.hidden_size * 2
fp = 2 * self.in_features * self.interm_features * 2 + \
self.interm_features * self.out_features * 2
else:
fp = self.up_proj.weight_footprint() + \
self.down_proj.weight_footprint()
@@ -230,30 +246,31 @@ class ExLlamaV2MLP(ExLlamaV2Module):
def scratch_space(self) -> int:
cfg = self.model.config
assert cfg.intermediate_size >= cfg.hidden_size
return self.temp_state_size() + \
self.temp_a_size() + \
self.temp_b_size() + \
self.temp_dq_size()
assert self.interm_features >= self.in_features and self.interm_features >= self.out_features
return (
self.temp_state_size() +
self.temp_a_size() +
self.temp_b_size() +
self.temp_dq_size()
)
def temp_state_size(self) -> int:
cfg = self.model.config
return cfg.max_input_len * cfg.max_batch_size * cfg.hidden_size * 2 + 128
return cfg.max_input_len * cfg.max_batch_size * self.out_features * 2 + 128
def temp_a_size(self) -> int:
cfg = self.model.config
return cfg.max_input_len * cfg.max_batch_size * cfg.intermediate_size * 2 + 128
return cfg.max_input_len * cfg.max_batch_size * self.interm_features * 2 + 128
def temp_b_size(self) -> int:
cfg = self.model.config
return cfg.max_input_len * cfg.max_batch_size * cfg.intermediate_size * 2 + 128
return cfg.max_input_len * cfg.max_batch_size * self.interm_features * 2 + 128
def temp_dq_size(self) -> int:
@@ -316,7 +333,7 @@ class ExLlamaV2MLP(ExLlamaV2Module):
pass_loras,
pass_lora_temp)
if cfg.arch.clamp_hidden_states:
if self.archparams.clamp_hidden_states:
hidden_states.clamp_(-65504, 65504)
return hidden_states
@@ -354,7 +371,7 @@ class ExLlamaV2MLP(ExLlamaV2Module):
self.gate_proj.q_handle if self.gate_proj is not None else [],
self.up_proj.q_handle,
self.down_proj.q_handle,
cfg.arch.mlp_act_func == "gelu"
self.archparams.mlp_act_func == "gelu"
)
return ctx.get_pinned(0, batch_size, q_len, cfg.hidden_size)
@@ -391,9 +408,9 @@ class ExLlamaV2MLP(ExLlamaV2Module):
context = self.model.get_device_context(dev)
torch.cuda.set_stream(context.stream)
if cfg.arch.mlp_act_func == "silu":
if self.archparams.mlp_act_func == "silu":
output = F.silu(gate[idx])
elif cfg.arch.mlp_act_func == "gelu":
elif self.archparams.mlp_act_func == "gelu":
output = F.gelu(gate[idx], approximate = "tanh")
output *= up[idx]
# output.clamp_(min = -65504.0, max = 65504.0)
@@ -430,28 +447,28 @@ class ExLlamaV2MLP(ExLlamaV2Module):
if self.gate_proj is not None:
gate = self.gate_proj.forward(post_norm, loras = loras)
if cfg.arch.mlp_act_func == "silu":
if self.archparams.mlp_act_func == "silu":
y = F.silu(gate)
elif cfg.arch.mlp_act_func == "gelu":
elif self.archparams.mlp_act_func == "gelu":
y = F.gelu(gate, approximate = "tanh")
up = self.up_proj.forward(post_norm, loras = loras)
y *= up
y.clamp_(min = -65504.0, max = 65504.0)
else:
up = self.up_proj.forward(post_norm, loras = loras)
if cfg.arch.mlp_act_func == "silu":
if self.archparams.mlp_act_func == "silu":
y = F.silu(up)
elif cfg.arch.mlp_act_func == "gelu":
elif self.archparams.mlp_act_func == "gelu":
y = F.gelu(up, approximate = "tanh")
down = self.down_proj.forward(y, loras = loras)
if self.post_layernorm:
down = self.post_layernorm.forward(down, output_fp32 = cfg.arch.residual_stream_fp32)
down = self.post_layernorm.forward(down, output_fp32 = self.archparams.residual_stream_fp32)
hidden_states = down + residual if self.has_residual else down
if cfg.arch.residual_stream_fp32:
if self.archparams.residual_stream_fp32:
hidden_states = hidden_states.float()
elif cfg.arch.clamp_hidden_states:
elif self.archparams.clamp_hidden_states:
hidden_states = hidden_states.clamp(-65504, 65504)
if intermediates:

View File

@@ -1,8 +1,6 @@
from __future__ import annotations
import os, sys
from exllamav2.architecture import RopeStyle
min_version = (3, 8)
if sys.version_info < min_version:
print("")
@@ -49,6 +47,7 @@ from exllamav2.compat import safe_move_tensor
from exllamav2.stloader import cleanup_stfiles
from exllamav2.device import ExLlamaV2DeviceContext, set_device_streams
from exllamav2.tensor_p import TPContext, BROADCAST_VC
from exllamav2.architecture import RopeStyle
import gc
import threading
from typing import Callable
@@ -76,7 +75,12 @@ class ExLlamaV2:
tp_context: TPContext | None
def __init__(self, config: ExLlamaV2Config, lazy_load = False):
def __init__(
self,
config: ExLlamaV2Config,
lazy_load = False,
archparams = None
):
self.config = config
self.modules = []
@@ -88,47 +92,57 @@ class ExLlamaV2:
# Build model
emb = ExLlamaV2Embedding(self, "model.embed_tokens")
cfg = self.config
if archparams is None: archparams = cfg.arch.lm
self.archparams = archparams
emb = ExLlamaV2Embedding(self, cfg.arch.lm_prefix + "model.embed_tokens")
self.modules += [emb]
if self.config.arch.learned_pos_emb_key:
pos_emb = ExLlamaV2PosEmbedding(self, self.config.arch.learned_pos_emb_key)
if archparams.keys["learned_pos_emb"]:
pos_emb = ExLlamaV2PosEmbedding(self, archparams.keys["learned_pos_emb"])
self.modules += [pos_emb]
for layer_idx in range(self.config.num_hidden_layers):
for layer_idx in range(cfg.num_hidden_layers):
layer_key = f"model.layers.{layer_idx}"
if self.config.arch.parallel_decoder_blocks:
layer_key = cfg.arch.lm_prefix + f"model.layers.{layer_idx}"
if cfg.arch.lm.parallel_decoder_blocks:
pd = ExLlamaV2ParallelDecoder(self, layer_key, layer_idx)
self.modules += [pd]
else:
if self.config.arch.alternating_swa:
swa = self.config.sliding_window if not bool(layer_idx % 2) else 0
elif self.config.arch.swa:
swa = self.config.sliding_window
if cfg.arch.lm.alternating_swa:
swa = cfg.sliding_window if not bool(layer_idx % 2) else 0
elif cfg.arch.lm.swa:
swa = cfg.sliding_window
else:
swa = 0
attn = ExLlamaV2Attention(self, layer_key, layer_idx, sliding_window = swa)
if self.config.arch.is_moe: mlp = ExLlamaV2MoEMLP(self, layer_key, layer_idx)
if cfg.arch.lm.is_moe: mlp = ExLlamaV2MoEMLP(self, layer_key, layer_idx)
else: mlp = ExLlamaV2MLP(self, layer_key, layer_idx)
self.modules += [attn, mlp]
if self.config.arch.norm == "layernorm": norm = ExLlamaV2LayerNorm(self, "model.norm")
elif self.config.arch.norm == "rmsnorm": norm = ExLlamaV2RMSNorm(self, "model.norm")
else: raise ValueError("unknown norm type")
if cfg.arch.lm.norm == "layernorm":
norm = ExLlamaV2LayerNorm(self, cfg.arch.lm_prefix + "model.norm")
elif cfg.arch.lm.norm == "rmsnorm":
norm = ExLlamaV2RMSNorm(self, cfg.arch.lm_prefix + "model.norm")
else:
raise ValueError("unknown norm type")
self.modules += [norm]
self.head_layer_idx = len(self.modules)
head = ExLlamaV2Linear(self, "lm_head",
self.config.hidden_size,
self.config.vocab_size,
False,
max_out_len = self.config.max_output_len,
prescale = self.config.logit_scale,
is_sub_module = False,
normalize_unq = bool(self.config.norm_head))
if self.config.arch.lm_head_key != "lm_head":
head.alt_key = self.config.arch.lm_head_key
head = ExLlamaV2Linear(
self,
cfg.arch.lm_prefix + "lm_head",
cfg.hidden_size,
cfg.vocab_size,
False,
max_out_len = cfg.max_output_len,
prescale = cfg.logit_scale,
is_sub_module = False,
normalize_unq = bool(cfg.norm_head)
)
if archparams.keys["lm_head"] != "lm_head":
head.alt_key = archparams.keys["lm_head"]
self.modules += [head]
# Compile dictionary of modules
@@ -157,17 +171,21 @@ class ExLlamaV2:
embed_cpu: bool = True
) -> list[float]:
cfg = self.config
self.cache_map = {}
# Constant shared between layers
sincos_size = self.config.head_dim * self.config.max_seq_len * 2
constant_size = sincos_size * 2
if self.archparams.rope_style != RopeStyle.NONE:
sincos_size = cfg.head_dim * cfg.max_seq_len * 2
constant_size = sincos_size * 2
else:
constant_size = 0
# Max size of hidden state
state_size = self.config.hidden_size * self.config.max_input_len * self.config.max_batch_size * 2
mask_size = self.config.max_input_len ** 2 * self.config.max_batch_size * 2
state_size = cfg.hidden_size * cfg.max_input_len * cfg.max_batch_size * 2
mask_size = cfg.max_input_len ** 2 * cfg.max_batch_size * 2
# Bytes remaining per device
@@ -184,7 +202,7 @@ class ExLlamaV2:
# Special case for token embeddings on CPU
if idx == 0 and embed_cpu:
if isinstance(module, ExLlamaV2Embedding) and embed_cpu:
module.set_device_idx(-1)
continue
@@ -245,6 +263,7 @@ class ExLlamaV2:
callback: Callable[[int, int], None] | None = None,
callback_gen: Callable[[int, int], None] | None = None,
progress: bool = False
progress: bool = False,
):
"""
Load model, regular manual split mode.

View File

@@ -25,9 +25,16 @@ class ExLlamaV2Module:
submodules: list[ExLlamaV2Module]
assumed_footprint: int
def __init__(self,
model: ExLlamaV2,
key: str):
def __init__(
self,
model: ExLlamaV2,
key: str,
archparams = None,
):
if archparams is None:
archparams = model.config.arch.lm
self.archparams = archparams
self.model = model
self.key = key

View File

@@ -29,13 +29,18 @@ class ExLlamaV2MoEMLP(ExLlamaV2Module):
temp_lora_size: int
def __init__(self,
model: ExLlamaV2,
key: str,
layer_idx: int):
super().__init__(model, key)
def __init__(
self,
model: ExLlamaV2,
key: str,
layer_idx: int,
archparams = None
):
super().__init__(model, key, archparams)
cfg = self.model.config
ap = self.archparams
km = self.archparams.keys
self.layer_idx = layer_idx
@@ -47,19 +52,19 @@ class ExLlamaV2MoEMLP(ExLlamaV2Module):
self.num_experts = cfg.num_experts
self.num_experts_per_token = cfg.num_experts_per_token
if cfg.arch.norm == "layernorm":
self.post_attention_layernorm = ExLlamaV2LayerNorm(model, key + self.model.config.arch.norm_key_2)
elif cfg.arch.norm == "rmsnorm":
self.post_attention_layernorm = ExLlamaV2RMSNorm(model, key + self.model.config.arch.norm_key_2)
if ap.norm == "layernorm":
self.post_attention_layernorm = ExLlamaV2LayerNorm(model, key + km["norm_2"])
elif ap.norm == "rmsnorm":
self.post_attention_layernorm = ExLlamaV2RMSNorm(model, key + km["norm_2"])
w1_key = key + cfg.arch.mlp_key_gate
w2_key = key + cfg.arch.mlp_key_down
w3_key = key + cfg.arch.mlp_key_up
w1_key = key + km["mlp_gate"]
w2_key = key + km["mlp_down"]
w3_key = key + km["mlp_up"]
w1_f_key = w1_key.replace(".*.", ".")
w2_f_key = w2_key.replace(".*.", ".")
w3_f_key = w3_key.replace(".*.", ".")
gate_key = cfg.arch.mlp_key_expert_gate
gate_key = km["mlp_expert_gate"]
self.w1 = []
self.w2 = []
@@ -72,9 +77,9 @@ class ExLlamaV2MoEMLP(ExLlamaV2Module):
# ad = bd
bu += intermediate_size
# bd += hidden_size
w1 = ExLlamaV2Linear(model, w1_key.replace("*", str(e)), hidden_size, intermediate_size, cfg.arch.mlp_bias, f_key = w1_f_key, f_beg = au, f_end = bu)
w2 = ExLlamaV2Linear(model, w2_key.replace("*", str(e)), intermediate_size, hidden_size, cfg.arch.mlp_bias, f_key = w2_f_key, f_beg = au, f_end = bu)
w3 = ExLlamaV2Linear(model, w3_key.replace("*", str(e)), hidden_size, intermediate_size, cfg.arch.mlp_bias, f_key = w3_f_key, f_beg = au, f_end = bu)
w1 = ExLlamaV2Linear(model, w1_key.replace("*", str(e)), hidden_size, intermediate_size, ap.mlp_bias, f_key = w1_f_key, f_beg = au, f_end = bu)
w2 = ExLlamaV2Linear(model, w2_key.replace("*", str(e)), intermediate_size, hidden_size, ap.mlp_bias, f_key = w2_f_key, f_beg = au, f_end = bu)
w3 = ExLlamaV2Linear(model, w3_key.replace("*", str(e)), hidden_size, intermediate_size, ap.mlp_bias, f_key = w3_f_key, f_beg = au, f_end = bu)
self.w1.append(w1)
self.w2.append(w2)
self.w3.append(w3)
@@ -124,7 +129,7 @@ class ExLlamaV2MoEMLP(ExLlamaV2Module):
device_context.get_scratch_slice(self.temp_logit_size()),
device_context.get_scratch_slice(self.temp_dq_size()),
self.model.config.max_input_len * self.model.config.max_batch_size,
self.model.config.arch.mlp_act_func == "gelu"
self.archparams.mlp_act_func == "gelu"
)
@@ -296,9 +301,9 @@ class ExLlamaV2MoEMLP(ExLlamaV2Module):
gate = self.w1[expert_idx].forward(current_state, loras = loras)
up = self.w3[expert_idx].forward(current_state, loras = loras)
if self.model.config.arch.mlp_act_func == "silu":
if self.archparams.mlp_act_func == "silu":
current_hidden_states = F.silu(gate)
elif self.model.config.arch.mlp_act_func == "gelu":
elif self.archparams.mlp_act_func == "gelu":
current_hidden_states = F.gelu(gate)
current_hidden_states *= up
if intermediates: result[f"pre_down.{expert_idx}"] = current_hidden_states

View File

@@ -24,18 +24,23 @@ class ExLlamaV2ParallelDecoder(ExLlamaV2Module):
attn: ExLlamaV2Attention
mlp: ExLlamaV2MLP
def __init__(self,
model: ExLlamaV2,
key: str,
layer_idx: int):
super().__init__(model, key)
def __init__(
self,
model: ExLlamaV2,
key: str,
layer_idx: int,
archparams = None
):
super().__init__(model, key, archparams)
cfg = self.model.config
self.layer_idx = layer_idx
if self.model.config.arch.norm == "layernorm":
self.input_layernorm = ExLlamaV2LayerNorm(model, key + self.model.config.arch.norm_key_1)
elif self.model.config.arch.norm == "rmsnorm":
self.input_layernorm = ExLlamaV2RMSNorm(model, key + self.model.config.arch.norm_key_1)
if self.archparams.norm == "layernorm":
self.input_layernorm = ExLlamaV2LayerNorm(model, key + self.archparams.keys["norm_1"])
elif self.archparams.norm == "rmsnorm":
self.input_layernorm = ExLlamaV2RMSNorm(model, key + self.archparams.keys["norm_1"])
self.attn = ExLlamaV2Attention(model, key, layer_idx, has_norm = False, has_residual = False)
self.mlp = ExLlamaV2MLP(model, key, layer_idx, has_norm = False, has_residual = False)

View File

@@ -20,8 +20,13 @@ class ExLlamaV2RMSNorm(ExLlamaV2Module):
is_tp: bool
broadcast_type: int | None
def __init__(self, model, key):
super().__init__(model, key)
def __init__(
self,
model,
key,
archparams = None
):
super().__init__(model, key, archparams)
self.is_tp = False
self.broadcast_type = None
@@ -50,8 +55,8 @@ class ExLlamaV2RMSNorm(ExLlamaV2Module):
self.variance_epsilon = self.model.config.norm_eps
# Gemma adds 1 to the norm tensor for some reason
if self.model.config.arch.norm_constant_bias != 0:
self.weight += self.model.config.arch.norm_constant_bias
if self.archparams.norm_constant_bias != 0:
self.weight += self.archparams.norm_constant_bias
def unload(self):
@@ -74,8 +79,8 @@ class ExLlamaV2RMSNorm(ExLlamaV2Module):
def get_weight(self) -> torch.Tensor:
# Make sure to return the original weight tensor for Gemma
if self.model.config.arch.norm_constant_bias != 0:
return self.weight.data - self.model.config.arch.norm_constant_bias
if self.archparams.norm_constant_bias != 0:
return self.weight.data - self.archparams.norm_constant_bias
return self.weight.data

View File

@@ -56,7 +56,7 @@ class TPContext:
self.model = model
cfg = self.model.config
assert cfg.arch.supports_tp, \
assert cfg.arch.lm.supports_tp, \
f"Tensor-parallel is not supported for {cfg.arch.arch_string}"
assert cfg.intermediate_size % 128 == 0, \
"Model intermediate size must be divisible by 128"

View File

@@ -283,7 +283,7 @@ if args.eval_dataset or args.standard_perplexity:
eval_len = [eval_tokens.shape[1]] * eval_tokens.shape[0]
# if args.eval_bos:
if model.config.arch.requires_bos:
if model.config.arch.lm.requires_bos:
boss = torch.full((eval_tokens.shape[0], 1), tokenizer.bos_token_id, dtype = torch.long)
eval_tokens = torch.cat((boss, eval_tokens[:, :-1]), dim = 1)