mirror of
https://github.com/turboderp-org/exllamav2.git
synced 2026-04-20 06:19:00 +00:00
Refactor architecture logic for code reuse between LLM/VLM
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
from dataclasses import dataclass, field
|
||||
from enum import IntEnum
|
||||
|
||||
# Common keys
|
||||
@@ -105,7 +106,7 @@ class ExLlamaV2ArchParams:
|
||||
"""
|
||||
Get architecture definition from model config. If the architecture isn't recognized, defaults to Llama
|
||||
architecture.
|
||||
W
|
||||
|
||||
:param arch_string:
|
||||
Architecture string from config.json
|
||||
|
||||
@@ -116,115 +117,135 @@ W
|
||||
self.arch_string = arch_string
|
||||
arch_recognized = False
|
||||
|
||||
# Keys to expect in model dict
|
||||
self.expect_keys = []
|
||||
|
||||
# Keys to expect in model dict, per layer
|
||||
self.layer_keys = []
|
||||
|
||||
# Map tensors in HF model to standard keys
|
||||
self.keymap = None
|
||||
|
||||
# Fused tensors
|
||||
self.fused_qkv_key = None
|
||||
self.fused_mlp_key_12 = None
|
||||
self.fused_mlp_key_3 = None
|
||||
@dataclass
|
||||
class Params:
|
||||
keys: dict = field(default_factory = lambda: {
|
||||
"norm_eps": "rms_norm_eps",
|
||||
"norm_1": ".input_layernorm",
|
||||
"norm_1_post": None,
|
||||
"fused_qkv": None,
|
||||
"mlp_gate": ".mlp.gate_proj",
|
||||
"mlp_up": ".mlp.up_proj",
|
||||
"mlp_down": ".mlp.down_proj",
|
||||
"lm_head": "lm_head",
|
||||
"norm_2": ".post_attention_layernorm",
|
||||
"norm_2_post": None,
|
||||
"fused_mlp_12": None,
|
||||
"fused_mlp_3": None,
|
||||
"learned_pos_emb": None
|
||||
})
|
||||
|
||||
# Alternate packing scheme for fused QKV tensor (InternLM2 quirk)
|
||||
self.fused_qkv_altpack = False
|
||||
# Compute logit scale from `dim_model_base` key in config.json (MiniCPM quirk)
|
||||
logit_scale_basedim = False
|
||||
|
||||
# Learned position embeddings
|
||||
self.learned_pos_emb_key = None
|
||||
# Clamp hidden states to FP16 range
|
||||
clamp_hidden_states = False
|
||||
|
||||
# Default multiplier for MLP inner dim (GPT2 quirk)
|
||||
self.default_inner_dim_mult = None
|
||||
# Upcast hidden state to FP32 before adding to residual stream
|
||||
residual_stream_fp32 = False
|
||||
|
||||
# Compute logit scale from `dim_model_base` key in config.json (MiniCPM quirk)
|
||||
self.logit_scale_basedim = False
|
||||
# Normalize embeddings (Gemma quirk)
|
||||
normalize_embeddings = False
|
||||
|
||||
# Constant bias for layernorm (Gemma quirk)
|
||||
norm_constant_bias = 0
|
||||
|
||||
# Alternate packing scheme for fused QKV tensor (InternLM2 quirk)
|
||||
fused_qkv_altpack = False
|
||||
|
||||
# SWA required by architecture
|
||||
swa = False
|
||||
alternating_swa = False
|
||||
|
||||
# Model only works with eager attention
|
||||
eager_attn_only = False
|
||||
|
||||
# Expect bias for linear layers
|
||||
attention_bias_qkv = False
|
||||
attention_bias_o = False
|
||||
mlp_bias = False
|
||||
|
||||
# Default multiplier for MLP inner dim (GPT2 quirk)
|
||||
default_inner_dim_mult = None
|
||||
|
||||
# Use gated MLP
|
||||
mlp_gate = True
|
||||
|
||||
# Use block-sparse MLP
|
||||
is_moe = False
|
||||
|
||||
# Use parallel decoder blocks (Cohere quirk)
|
||||
parallel_decoder_blocks = False
|
||||
|
||||
# Use MQA, effectively num_key_value_heads = 1 (GPTBigCode quirk)
|
||||
mqa = False
|
||||
|
||||
# Model is incoherent without BOS at the start of the context
|
||||
requires_bos = False
|
||||
|
||||
# Scale attn weights (GPT2 quirk, not important for inference)
|
||||
scale_attn_weights = False
|
||||
|
||||
# Model implementation works in tensor-parallel mode
|
||||
supports_tp = False
|
||||
|
||||
# Activation function
|
||||
mlp_act_func = "silu"
|
||||
|
||||
# Layer norm type
|
||||
norm = "rmsnorm"
|
||||
|
||||
# RoPE style
|
||||
rope_style = RopeStyle.NEOX
|
||||
|
||||
# Expected keys
|
||||
expect_keys: list[str] = field(default_factory = lambda: [])
|
||||
layer_keys: list[str] = field(default_factory = lambda: [])
|
||||
|
||||
# Component models
|
||||
self.lm_prefix = ""
|
||||
self.vt_prefix = ""
|
||||
self.mmp_prefix = ""
|
||||
self.lm = Params()
|
||||
self.mmp = Params()
|
||||
self.vt = Params()
|
||||
|
||||
self.mmp.keys.update({
|
||||
"norm_1": None,
|
||||
"norm_1_post": None,
|
||||
"norm_2": None,
|
||||
"norm_2_post": None,
|
||||
"fused_mlp_12": None,
|
||||
"fused_mlp_3": None,
|
||||
})
|
||||
self.mmp.rope_style = RopeStyle.NONE
|
||||
|
||||
# Tensors are transposed in original model weights
|
||||
self.orig_weights_transposed = False
|
||||
|
||||
# Post norm keys
|
||||
self.norm_key_1_post = None
|
||||
self.norm_key_2_post = None
|
||||
|
||||
# SWA required by architecture
|
||||
self.swa = False
|
||||
self.alternating_swa = False
|
||||
|
||||
# Model only works with eager attention
|
||||
self.eager_attn_only = False
|
||||
|
||||
# Clamp hidden states to FP16 range
|
||||
self.clamp_hidden_states = False
|
||||
|
||||
# Upcast hidden state to FP32 before adding to residual stream
|
||||
self.residual_stream_fp32 = False
|
||||
|
||||
# Expect bias for linear layers
|
||||
self.attention_bias_qkv = False
|
||||
self.attention_bias_o = False
|
||||
self.mlp_bias = False
|
||||
|
||||
# Use gated MLP
|
||||
self.mlp_gate = True
|
||||
|
||||
# Use block-sparse MLP
|
||||
self.is_moe = False
|
||||
|
||||
# Normalize embeddings (Gemma quirk)
|
||||
self.normalize_embeddings = False
|
||||
|
||||
# Constant bias for layernorm (Gemma quirk)
|
||||
self.norm_constant_bias = 0
|
||||
|
||||
# Use parallel decoder blocks (Cohere quirk)
|
||||
self.parallel_decoder_blocks = False
|
||||
|
||||
# Model is incoherent without BOS at the start of the context
|
||||
self.requires_bos = False
|
||||
|
||||
# Use MQA, effectively num_key_value_heads = 1 (GPTBigCode quirk)
|
||||
self.mqa = False
|
||||
|
||||
# Scale attn weights (GPT2 quirk, not important for inference)
|
||||
self.scale_attn_weights = False
|
||||
|
||||
# Model implementation works in tensor-parallel mode
|
||||
self.supports_tp = False
|
||||
|
||||
# Mistral
|
||||
|
||||
if arch_string == "MistralForCausalLM":
|
||||
arch_recognized = True
|
||||
self.layer_keys += \
|
||||
self.lm.layer_keys += \
|
||||
layer_keys_llama_norms + \
|
||||
layer_keys_llama_attn + \
|
||||
layer_keys_llama_mlp
|
||||
self.expect_keys += \
|
||||
self.lm.expect_keys += \
|
||||
expect_keys_llama
|
||||
self.norm_eps_key = "rms_norm_eps"
|
||||
self.mlp_key_gate = ".mlp.gate_proj"
|
||||
self.mlp_key_up = ".mlp.up_proj"
|
||||
self.mlp_key_down = ".mlp.down_proj"
|
||||
self.lm_head_key = "lm_head"
|
||||
self.norm_key_1 = ".input_layernorm"
|
||||
self.norm_key_2 = ".post_attention_layernorm"
|
||||
self.mlp_act_func = "silu"
|
||||
self.norm = "rmsnorm"
|
||||
self.rope_style = RopeStyle.NEOX
|
||||
self.supports_tp = True
|
||||
self.lm.supports_tp = True
|
||||
|
||||
# Mixtral
|
||||
|
||||
if arch_string == "MixtralForCausalLM":
|
||||
arch_recognized = True
|
||||
self.layer_keys += \
|
||||
self.lm.layer_keys += \
|
||||
layer_keys_llama_norms + \
|
||||
layer_keys_llama_attn + \
|
||||
layer_keys_mixtral_mlp
|
||||
self.expect_keys += \
|
||||
self.lm.expect_keys += \
|
||||
expect_keys_llama
|
||||
self.norm_eps_key = "rms_norm_eps"
|
||||
self.mlp_key_gate = ".block_sparse_moe.experts.*.w1"
|
||||
@@ -238,372 +259,307 @@ W
|
||||
self.norm = "rmsnorm"
|
||||
self.rope_style = RopeStyle.NEOX
|
||||
self.is_moe = True
|
||||
self.lm.keys.update({
|
||||
"mlp_gate": ".block_sparse_moe.experts.*.w1",
|
||||
"mlp_up": ".block_sparse_moe.experts.*.w3",
|
||||
"mlp_down": ".block_sparse_moe.experts.*.w2",
|
||||
"mlp_expert_gate": ".block_sparse_moe.gate"
|
||||
})
|
||||
self.lm.is_moe = True
|
||||
|
||||
# Yi
|
||||
|
||||
if arch_string == "YiForCausalLM":
|
||||
arch_recognized = True
|
||||
self.layer_keys += \
|
||||
self.lm.layer_keys += \
|
||||
layer_keys_yi_norms + \
|
||||
layer_keys_llama_attn + \
|
||||
layer_keys_llama_mlp
|
||||
self.expect_keys += \
|
||||
self.lm.expect_keys += \
|
||||
expect_keys_llama
|
||||
self.norm_eps_key = "rms_norm_eps"
|
||||
self.mlp_key_gate = ".mlp.gate_proj"
|
||||
self.mlp_key_up = ".mlp.up_proj"
|
||||
self.mlp_key_down = ".mlp.down_proj"
|
||||
self.mlp_act_func = "silu"
|
||||
self.norm_key_1 = ".ln1"
|
||||
self.norm_key_2 = ".ln2"
|
||||
self.norm = "rmsnorm"
|
||||
self.lm_head_key = "lm_head"
|
||||
self.rope_style = RopeStyle.NEOX
|
||||
self.lm.keys.update({
|
||||
"norm_1": ".ln1",
|
||||
"norm_2": ".ln2",
|
||||
})
|
||||
|
||||
# Orion
|
||||
|
||||
if arch_string == "OrionForCausalLM":
|
||||
arch_recognized = True
|
||||
self.layer_keys += \
|
||||
self.lm.layer_keys += \
|
||||
layer_keys_llama_norms + \
|
||||
layer_keys_llama_attn + \
|
||||
layer_keys_llama_mlp
|
||||
self.expect_keys += \
|
||||
self.lm.expect_keys += \
|
||||
expect_keys_llama
|
||||
self.norm_eps_key = "rms_norm_eps"
|
||||
self.mlp_key_gate = ".mlp.gate_proj"
|
||||
self.mlp_key_up = ".mlp.up_proj"
|
||||
self.mlp_key_down = ".mlp.down_proj"
|
||||
self.lm_head_key = "lm_head"
|
||||
self.norm_key_1 = ".input_layernorm"
|
||||
self.norm_key_2 = ".post_attention_layernorm"
|
||||
self.mlp_act_func = "silu"
|
||||
self.norm = "layernorm"
|
||||
self.rope_style = RopeStyle.NEOX
|
||||
self.lm.norm = "layernorm"
|
||||
|
||||
# Qwen2 (1.5)
|
||||
|
||||
if arch_string == "Qwen2ForCausalLM":
|
||||
arch_recognized = True
|
||||
self.layer_keys += \
|
||||
self.lm.layer_keys += \
|
||||
layer_keys_llama_norms + \
|
||||
layer_keys_llama_attn + \
|
||||
layer_keys_llama_mlp
|
||||
self.expect_keys += \
|
||||
self.lm.expect_keys += \
|
||||
expect_keys_llama
|
||||
self.norm_eps_key = "rms_norm_eps"
|
||||
self.mlp_key_gate = ".mlp.gate_proj"
|
||||
self.mlp_key_up = ".mlp.up_proj"
|
||||
self.mlp_key_down = ".mlp.down_proj"
|
||||
self.lm_head_key = "lm_head"
|
||||
self.norm_key_1 = ".input_layernorm"
|
||||
self.norm_key_2 = ".post_attention_layernorm"
|
||||
self.mlp_act_func = "silu"
|
||||
self.norm = "rmsnorm"
|
||||
self.rope_style = RopeStyle.NEOX
|
||||
self.attention_bias_qkv = True
|
||||
self.supports_tp = True
|
||||
self.lm.attention_bias_qkv = True
|
||||
self.lm.supports_tp = True
|
||||
|
||||
# Gemma
|
||||
|
||||
if arch_string == "GemmaForCausalLM":
|
||||
arch_recognized = True
|
||||
self.layer_keys += \
|
||||
self.lm.layer_keys += \
|
||||
layer_keys_llama_norms + \
|
||||
layer_keys_llama_attn + \
|
||||
layer_keys_llama_mlp
|
||||
self.expect_keys += \
|
||||
self.lm.expect_keys += \
|
||||
expect_keys_gemma
|
||||
self.norm_eps_key = "rms_norm_eps"
|
||||
self.mlp_key_gate = ".mlp.gate_proj"
|
||||
self.mlp_key_up = ".mlp.up_proj"
|
||||
self.mlp_key_down = ".mlp.down_proj"
|
||||
self.lm_head_key = "model.embed_tokens"
|
||||
self.norm_key_1 = ".input_layernorm"
|
||||
self.norm_key_2 = ".post_attention_layernorm"
|
||||
self.mlp_act_func = "gelu"
|
||||
self.norm = "rmsnorm"
|
||||
self.rope_style = RopeStyle.NEOX
|
||||
self.normalize_embeddings = True
|
||||
self.norm_constant_bias = 1
|
||||
self.requires_bos = True
|
||||
self.lm.keys.update({
|
||||
"lm_head": "model.embed_tokens",
|
||||
})
|
||||
self.lm.mlp_act_func = "gelu"
|
||||
self.lm.normalize_embeddings = True
|
||||
self.lm.norm_constant_bias = 1
|
||||
self.lm.requires_bos = True
|
||||
|
||||
# Gemma2
|
||||
|
||||
if arch_string == "Gemma2ForCausalLM":
|
||||
arch_recognized = True
|
||||
self.layer_keys += \
|
||||
self.lm.layer_keys += \
|
||||
layer_keys_gemma2_norms + \
|
||||
layer_keys_llama_attn + \
|
||||
layer_keys_llama_mlp
|
||||
self.expect_keys += \
|
||||
self.lm.expect_keys += \
|
||||
expect_keys_gemma
|
||||
self.norm_eps_key = "rms_norm_eps"
|
||||
self.mlp_key_gate = ".mlp.gate_proj"
|
||||
self.mlp_key_up = ".mlp.up_proj"
|
||||
self.mlp_key_down = ".mlp.down_proj"
|
||||
self.lm_head_key = "model.embed_tokens"
|
||||
self.norm_key_1 = ".input_layernorm"
|
||||
self.norm_key_1_post = ".post_attention_layernorm"
|
||||
self.norm_key_2 = ".pre_feedforward_layernorm"
|
||||
self.norm_key_2_post = ".post_feedforward_layernorm"
|
||||
self.mlp_act_func = "gelu"
|
||||
self.norm = "rmsnorm"
|
||||
self.rope_style = RopeStyle.NEOX
|
||||
self.normalize_embeddings = True
|
||||
self.norm_constant_bias = 1
|
||||
self.requires_bos = True
|
||||
self.pre_post_layernorm = True
|
||||
self.alternating_swa = True
|
||||
self.residual_stream_fp32 = True
|
||||
self.lm.keys.update({
|
||||
"lm_head": "model.embed_tokens",
|
||||
"norm_1": ".input_layernorm",
|
||||
"norm_1_post": ".post_attention_layernorm",
|
||||
"norm_2": ".pre_feedforward_layernorm",
|
||||
"norm_2_post": ".post_feedforward_layernorm",
|
||||
})
|
||||
self.lm.mlp_act_func = "gelu"
|
||||
self.lm.normalize_embeddings = True
|
||||
self.lm.norm_constant_bias = 1
|
||||
self.lm.requires_bos = True
|
||||
self.lm.alternating_swa = True
|
||||
self.lm.residual_stream_fp32 = True
|
||||
|
||||
# StarCoder2
|
||||
|
||||
if arch_string == "Starcoder2ForCausalLM":
|
||||
arch_recognized = True
|
||||
self.layer_keys += \
|
||||
self.lm.layer_keys += \
|
||||
layer_keys_llama_norms + \
|
||||
layer_keys_llama_attn + \
|
||||
layer_keys_starcoder2_mlp
|
||||
self.expect_keys += \
|
||||
self.lm.expect_keys += \
|
||||
expect_keys_starcoder2
|
||||
self.norm_eps_key = "norm_epsilon"
|
||||
self.mlp_key_up = ".mlp.c_fc"
|
||||
self.mlp_key_down = ".mlp.c_proj"
|
||||
self.lm_head_key = "model.embed_tokens"
|
||||
self.norm_key_1 = ".input_layernorm"
|
||||
self.norm_key_2 = ".post_attention_layernorm"
|
||||
self.mlp_act_func = "gelu"
|
||||
self.norm = "layernorm"
|
||||
self.rope_style = RopeStyle.NEOX
|
||||
self.attention_bias_qkv = True
|
||||
self.attention_bias_o = True
|
||||
self.mlp_bias = True
|
||||
self.mlp_gate = False
|
||||
self.lm.keys.update({
|
||||
"mlp_gate": None,
|
||||
"mlp_up": ".mlp.c_fc",
|
||||
"mlp_down": ".mlp.c_proj",
|
||||
"lm_head": "model.embed_tokens",
|
||||
})
|
||||
self.lm.mlp_act_func = "gelu"
|
||||
self.lm.norm = "layernorm"
|
||||
self.lm.attention_bias_qkv = True
|
||||
self.lm.attention_bias_o = True
|
||||
self.lm.mlp_bias = True
|
||||
self.lm.mlp_gate = False
|
||||
|
||||
# GemMoE
|
||||
|
||||
if arch_string == "GemmoeForCausalLM":
|
||||
arch_recognized = True
|
||||
print(f" !! Warning, Gemmoe support is experimental and has not been fully tested")
|
||||
self.layer_keys += \
|
||||
self.lm.layer_keys += \
|
||||
layer_keys_llama_norms + \
|
||||
layer_keys_llama_attn + \
|
||||
layer_keys_mixtral_mlp
|
||||
self.expect_keys += \
|
||||
self.lm.expect_keys += \
|
||||
expect_keys_gemma
|
||||
self.norm_eps_key = "rms_norm_eps"
|
||||
self.mlp_key_gate = ".block_sparse_moe.experts.*.w1"
|
||||
self.mlp_key_up = ".block_sparse_moe.experts.*.w3"
|
||||
self.mlp_key_down = ".block_sparse_moe.experts.*.w2"
|
||||
self.mlp_key_expert_gate = ".block_sparse_moe.gate"
|
||||
self.lm_head_key = "model.embed_tokens"
|
||||
self.norm_key_1 = ".input_layernorm"
|
||||
self.norm_key_2 = ".post_attention_layernorm"
|
||||
self.mlp_act_func = "gelu"
|
||||
self.norm = "rmsnorm"
|
||||
self.rope_style = RopeStyle.NEOX
|
||||
self.normalize_embeddings = True
|
||||
self.norm_constant_bias = 1
|
||||
self.is_moe = True
|
||||
self.requires_bos = True
|
||||
self.lm.keys.update({
|
||||
"mlp_gate": ".block_sparse_moe.experts.*.w1",
|
||||
"mlp_up": ".block_sparse_moe.experts.*.w3",
|
||||
"mlp_down": ".block_sparse_moe.experts.*.w2",
|
||||
"mlp_expert_gate": ".block_sparse_moe.gate",
|
||||
"lm_head": "model.embed_tokens",
|
||||
})
|
||||
self.lm.mlp_act_func = "gelu"
|
||||
self.lm.normalize_embeddings = True
|
||||
self.lm.norm_constant_bias = 1
|
||||
self.lm.is_moe = True
|
||||
self.lm.requires_bos = True
|
||||
|
||||
# Cohere
|
||||
|
||||
if arch_string == "CohereForCausalLM":
|
||||
arch_recognized = True
|
||||
self.layer_keys += \
|
||||
self.lm.layer_keys += \
|
||||
layer_keys_cohere_norms + \
|
||||
layer_keys_llama_attn + \
|
||||
layer_keys_llama_mlp
|
||||
self.expect_keys += \
|
||||
self.lm.expect_keys += \
|
||||
expect_keys_gemma
|
||||
self.norm_eps_key = "layer_norm_eps"
|
||||
self.mlp_key_gate = ".mlp.gate_proj"
|
||||
self.mlp_key_up = ".mlp.up_proj"
|
||||
self.mlp_key_down = ".mlp.down_proj"
|
||||
self.lm_head_key = "model.embed_tokens"
|
||||
self.norm_key_1 = ".input_layernorm"
|
||||
self.norm_key_2 = None
|
||||
self.mlp_act_func = "silu"
|
||||
self.norm = "layernorm"
|
||||
self.rope_style = RopeStyle.GPTJ
|
||||
self.parallel_decoder_blocks = True
|
||||
self.requires_bos = True
|
||||
self.lm.keys.update({
|
||||
"norm_eps": "layer_norm_eps",
|
||||
"lm_head": "model.embed_tokens",
|
||||
"norm_1": ".input_layernorm",
|
||||
"norm_2": None,
|
||||
})
|
||||
self.lm.norm = "layernorm"
|
||||
self.lm.rope_style = RopeStyle.GPTJ
|
||||
self.lm.parallel_decoder_blocks = True
|
||||
self.lm.requires_bos = True
|
||||
|
||||
# DBRX
|
||||
|
||||
if arch_string == "DbrxForCausalLM":
|
||||
arch_recognized = True
|
||||
self.keymap = dbrx_keymap
|
||||
self.layer_keys += \
|
||||
self.lm.layer_keys += \
|
||||
layer_keys_llama_norms + \
|
||||
layer_keys_dbrx_attn + \
|
||||
layer_keys_dbrx_mlp
|
||||
self.expect_keys += \
|
||||
self.lm.expect_keys += \
|
||||
expect_keys_llama
|
||||
self.norm_eps_key = None
|
||||
self.mlp_key_gate = ".block_sparse_moe.experts.*.w1"
|
||||
self.mlp_key_up = ".block_sparse_moe.experts.*.v1"
|
||||
self.mlp_key_down = ".block_sparse_moe.experts.*.w2"
|
||||
self.mlp_key_expert_gate = ".block_sparse_moe.gate"
|
||||
self.lm_head_key = "lm_head"
|
||||
self.norm_key_1 = ".input_layernorm"
|
||||
self.norm_key_2 = ".post_attention_layernorm"
|
||||
self.mlp_act_func = "silu"
|
||||
self.norm = "layernorm"
|
||||
self.rope_style = RopeStyle.NEOX
|
||||
self.fused_qkv_key = "Wqkv"
|
||||
self.is_moe = True
|
||||
self.lm.keys.update({
|
||||
"norm_eps": None,
|
||||
"mlp_gate": ".block_sparse_moe.experts.*.w1",
|
||||
"mlp_up": ".block_sparse_moe.experts.*.v1",
|
||||
"mlp_down": ".block_sparse_moe.experts.*.w2",
|
||||
"mlp_expert_gate": ".block_sparse_moe.gate",
|
||||
"lm_head": "model.embed_tokens",
|
||||
"fused_qkv": "Wqkv",
|
||||
})
|
||||
self.lm.is_moe = True
|
||||
|
||||
# Phi3
|
||||
|
||||
if arch_string == "Phi3ForCausalLM":
|
||||
arch_recognized = True
|
||||
self.layer_keys += \
|
||||
self.lm.layer_keys += \
|
||||
layer_keys_llama_norms + \
|
||||
layer_keys_phi3_attn + \
|
||||
layer_keys_phi3_mlp
|
||||
self.expect_keys += \
|
||||
self.lm.expect_keys += \
|
||||
expect_keys_llama
|
||||
self.norm_eps_key = "rms_norm_eps"
|
||||
self.mlp_key_gate = ".mlp.gate_proj"
|
||||
self.mlp_key_up = ".mlp.up_proj"
|
||||
self.mlp_key_down = ".mlp.down_proj"
|
||||
self.lm_head_key = "lm_head"
|
||||
self.norm_key_1 = ".input_layernorm"
|
||||
self.norm_key_2 = ".post_attention_layernorm"
|
||||
self.fused_qkv_key = "qkv_proj"
|
||||
self.fused_mlp_key_12 = "gate_up_proj"
|
||||
self.mlp_act_func = "silu"
|
||||
self.norm = "rmsnorm"
|
||||
self.rope_style = RopeStyle.NEOX
|
||||
self.lm.keys.update({
|
||||
"fused_qkv": "qkv_proj",
|
||||
"fused_mlp_12": "gate_up_proj",
|
||||
})
|
||||
|
||||
# GPTBigCode
|
||||
|
||||
if arch_string == "GPTBigCodeForCausalLM":
|
||||
arch_recognized = True
|
||||
self.keymap = bigcode_keymap
|
||||
self.layer_keys += \
|
||||
self.lm.layer_keys += \
|
||||
layer_keys_gpt2_norms + \
|
||||
layer_keys_gpt2_attn + \
|
||||
layer_keys_gpt2_mlp
|
||||
self.expect_keys += \
|
||||
self.lm.expect_keys += \
|
||||
expect_keys_gpt2
|
||||
self.norm_eps_key = "layer_norm_epsilon"
|
||||
self.mlp_key_gate = None
|
||||
self.mlp_key_up = ".mlp.c_fc"
|
||||
self.mlp_key_down = ".mlp.c_proj"
|
||||
self.lm_head_key = "model.embed_tokens"
|
||||
self.norm_key_1 = ".ln_1"
|
||||
self.norm_key_2 = ".ln_2"
|
||||
self.fused_qkv_key = "c_attn"
|
||||
self.learned_pos_emb_key = "model.wpe"
|
||||
self.mlp_act_func = "gelu"
|
||||
self.norm = "layernorm"
|
||||
self.rope_style = RopeStyle.NONE
|
||||
self.mqa = True
|
||||
self.attention_bias_qkv = True
|
||||
self.attention_bias_o = True
|
||||
self.mlp_bias = True
|
||||
self.mlp_gate = False
|
||||
self.lm.keys.update({
|
||||
"norm_eps": "layer_norm_epsilon",
|
||||
"mlp_gate": None,
|
||||
"mlp_up": ".mlp.c_fc",
|
||||
"mlp_down": ".mlp.c_proj",
|
||||
"lm_head": "model.embed_tokens",
|
||||
"norm_1": ".ln_1",
|
||||
"norm_2": ".ln_2",
|
||||
"fused_qkv": "c_attn",
|
||||
"learned_pos_emb": "model.wpe",
|
||||
})
|
||||
self.lm.mlp_act_func = "gelu"
|
||||
self.lm.norm = "layernorm"
|
||||
self.lm.rope_style = RopeStyle.NONE
|
||||
self.lm.mqa = True
|
||||
self.lm.attention_bias_qkv = True
|
||||
self.lm.attention_bias_o = True
|
||||
self.lm.mlp_bias = True
|
||||
self.lm.mlp_gate = False
|
||||
|
||||
# GPT2
|
||||
|
||||
if arch_string == "GPT2LMHeadModel":
|
||||
arch_recognized = True
|
||||
self.keymap = gpt2_keymap
|
||||
self.layer_keys += \
|
||||
self.lm.layer_keys += \
|
||||
layer_keys_gpt2_norms + \
|
||||
layer_keys_gpt2_attn + \
|
||||
layer_keys_gpt2_mlp
|
||||
self.expect_keys += \
|
||||
self.lm.expect_keys += \
|
||||
expect_keys_gpt2
|
||||
self.norm_eps_key = "layer_norm_epsilon"
|
||||
self.mlp_key_gate = None
|
||||
self.mlp_key_up = ".mlp.c_fc"
|
||||
self.mlp_key_down = ".mlp.c_proj"
|
||||
self.lm_head_key = "model.embed_tokens"
|
||||
self.norm_key_1 = ".ln_1"
|
||||
self.norm_key_2 = ".ln_2"
|
||||
self.fused_qkv_key = "c_attn"
|
||||
self.learned_pos_emb_key = "model.wpe"
|
||||
self.mlp_act_func = "gelu"
|
||||
self.norm = "layernorm"
|
||||
self.rope_style = RopeStyle.NONE
|
||||
self.default_inner_dim_mult = 4
|
||||
self.lm.keys.update({
|
||||
"norm_eps": "layer_norm_epsilon",
|
||||
"mlp_gate": None,
|
||||
"mlp_up": ".mlp.c_fc",
|
||||
"mlp_down": ".mlp.c_proj",
|
||||
"lm_head": "model.embed_tokens",
|
||||
"norm_1": ".ln_1",
|
||||
"norm_2": ".ln_2",
|
||||
"fused_qkv": "c_attn",
|
||||
"learned_pos_emb": "model.wpe",
|
||||
})
|
||||
self.lm.mlp_act_func = "gelu"
|
||||
self.lm.norm = "layernorm"
|
||||
self.lm.rope_style = RopeStyle.NONE
|
||||
self.lm.default_inner_dim_mult = 4
|
||||
self.lm.attention_bias_qkv = True
|
||||
self.lm.attention_bias_o = True
|
||||
self.lm.mlp_bias = True
|
||||
self.lm.mlp_gate = False
|
||||
self.orig_weights_transposed = True
|
||||
self.attention_bias_qkv = True
|
||||
self.attention_bias_o = True
|
||||
self.mlp_bias = True
|
||||
self.mlp_gate = False
|
||||
|
||||
# MiniCPM
|
||||
|
||||
if arch_string == "MiniCPMForCausalLM":
|
||||
arch_recognized = True
|
||||
self.layer_keys += \
|
||||
self.lm.layer_keys += \
|
||||
layer_keys_llama_norms + \
|
||||
layer_keys_llama_attn + \
|
||||
layer_keys_llama_mlp
|
||||
self.expect_keys += \
|
||||
self.lm.expect_keys += \
|
||||
expect_keys_llama
|
||||
self.norm_eps_key = "rms_norm_eps"
|
||||
self.mlp_key_gate = ".mlp.gate_proj"
|
||||
self.mlp_key_up = ".mlp.up_proj"
|
||||
self.mlp_key_down = ".mlp.down_proj"
|
||||
self.lm_head_key = "lm_head"
|
||||
self.norm_key_1 = ".input_layernorm"
|
||||
self.norm_key_2 = ".post_attention_layernorm"
|
||||
self.mlp_act_func = "silu"
|
||||
self.norm = "rmsnorm"
|
||||
self.rope_style = RopeStyle.NEOX
|
||||
self.logit_scale_basedim = True
|
||||
self.lm.logit_scale_basedim = True
|
||||
|
||||
# InternLM2
|
||||
|
||||
if arch_string == "InternLM2ForCausalLM":
|
||||
arch_recognized = True
|
||||
self.keymap = internlm2_keymap
|
||||
self.layer_keys += \
|
||||
self.lm.layer_keys += \
|
||||
layer_keys_internlm2_norms + \
|
||||
layer_keys_internlm2_attn + \
|
||||
layer_keys_internlm2_mlp
|
||||
self.expect_keys += \
|
||||
self.lm.expect_keys += \
|
||||
expect_keys_llama
|
||||
self.norm_eps_key = "rms_norm_eps"
|
||||
self.mlp_key_gate = ".feed_forward.w1"
|
||||
self.mlp_key_up = ".feed_forward.w3"
|
||||
self.mlp_key_down = ".feed_forward.w2"
|
||||
self.lm_head_key = "lm_head"
|
||||
self.norm_key_1 = ".attention_norm"
|
||||
self.norm_key_2 = ".ffn_norm"
|
||||
self.fused_qkv_key = "wqkv"
|
||||
self.mlp_act_func = "silu"
|
||||
self.norm = "rmsnorm"
|
||||
self.rope_style = RopeStyle.NEOX
|
||||
self.fused_qkv_altpack = True
|
||||
self.lm.keys.update({
|
||||
"mlp_gate": ".feed_forward.w1",
|
||||
"mlp_up": ".feed_forward.w3",
|
||||
"mlp_down": ".feed_forward.w2",
|
||||
"norm_1": ".attention_norm",
|
||||
"norm_2": ".ffn_norm",
|
||||
"fused_qkv": "wqkv",
|
||||
})
|
||||
self.lm.fused_qkv_altpack = True
|
||||
|
||||
# Index
|
||||
|
||||
if arch_string == "IndexForCausalLM":
|
||||
arch_recognized = True
|
||||
self.layer_keys += \
|
||||
self.lm.layer_keys += \
|
||||
layer_keys_llama_norms + \
|
||||
layer_keys_llama_attn + \
|
||||
layer_keys_llama_mlp
|
||||
self.expect_keys += \
|
||||
self.lm.expect_keys += \
|
||||
expect_keys_llama
|
||||
self.norm_eps_key = "rms_norm_eps"
|
||||
self.mlp_key_gate = ".mlp.gate_proj"
|
||||
self.mlp_key_up = ".mlp.up_proj"
|
||||
self.mlp_key_down = ".mlp.down_proj"
|
||||
self.lm_head_key = "lm_head"
|
||||
self.norm_key_1 = ".input_layernorm"
|
||||
self.norm_key_2 = ".post_attention_layernorm"
|
||||
self.mlp_act_func = "silu"
|
||||
self.norm = "rmsnorm"
|
||||
self.rope_style = RopeStyle.NEOX
|
||||
|
||||
# Llama (default + fallback)
|
||||
|
||||
@@ -612,49 +568,41 @@ W
|
||||
print(f" !! Loading as LlamaForCausalLM")
|
||||
self.arch_string = "LlamaForCausalLM"
|
||||
if not arch_recognized:
|
||||
self.layer_keys += \
|
||||
self.lm.layer_keys += \
|
||||
layer_keys_llama_norms + \
|
||||
layer_keys_llama_attn + \
|
||||
layer_keys_llama_mlp
|
||||
self.expect_keys += \
|
||||
self.lm.expect_keys += \
|
||||
expect_keys_llama
|
||||
self.norm_eps_key = "rms_norm_eps"
|
||||
self.mlp_key_gate = ".mlp.gate_proj"
|
||||
self.mlp_key_up = ".mlp.up_proj"
|
||||
self.mlp_key_down = ".mlp.down_proj"
|
||||
self.lm_head_key = "lm_head"
|
||||
self.norm_key_1 = ".input_layernorm"
|
||||
self.norm_key_2 = ".post_attention_layernorm"
|
||||
self.mlp_act_func = "silu"
|
||||
self.norm = "rmsnorm"
|
||||
self.rope_style = RopeStyle.NEOX
|
||||
self.supports_tp = True
|
||||
self.lm.supports_tp = True
|
||||
|
||||
# Arch overrides
|
||||
|
||||
if read_config.get("attention_bias", False):
|
||||
self.attention_bias_qkv = True
|
||||
self.attention_bias_o = True
|
||||
self.lm.attention_bias_qkv = True
|
||||
self.lm.attention_bias_o = True
|
||||
|
||||
if read_config.get("mlp_bias", False):
|
||||
self.mlp_bias = True
|
||||
self.lm.mlp_bias = True
|
||||
|
||||
if read_config.get("tie_word_embeddings", False):
|
||||
if ["lm_head"] in self.expect_keys:
|
||||
self.expect_keys.remove(["lm_head"])
|
||||
self.lm_head_key = "model.embed_tokens"
|
||||
if ["lm_head"] in self.lm.expect_keys:
|
||||
self.lm.expect_keys.remove(["lm_head"])
|
||||
self.lm.keys.update({
|
||||
"lm_head": "model.embed_tokens",
|
||||
})
|
||||
|
||||
# Sanity checks
|
||||
|
||||
if self.residual_stream_fp32:
|
||||
assert self.norm_key_1_post and self.norm_key_2_post, \
|
||||
if self.lm.residual_stream_fp32:
|
||||
assert self.lm.keys["norm_1_post"] and self.lm.keys["norm_2_post"], \
|
||||
"FP32 residual stream only implemented for arch with post layernorms"
|
||||
|
||||
def make_fused_mlp(self):
|
||||
|
||||
for x in layer_keys_llama_mlp: self.layer_keys.remove(x)
|
||||
self.layer_keys += layer_keys_llama_mlp_swiglu
|
||||
self.fused_mlp_key_12 = layer_keys_llama_mlp_swiglu[0][0]
|
||||
self.fused_mlp_key_3 = layer_keys_llama_mlp_swiglu[1][0]
|
||||
|
||||
|
||||
for x in layer_keys_llama_mlp: self.lm.layer_keys.remove(x)
|
||||
self.lm.layer_keys += layer_keys_llama_mlp_swiglu
|
||||
self.lm.keys.update({
|
||||
"fused_mlp_12": layer_keys_llama_mlp_swiglu[0][0],
|
||||
"fused_mlp_3": layer_keys_llama_mlp_swiglu[1][0],
|
||||
})
|
||||
|
||||
@@ -134,11 +134,15 @@ class ExLlamaV2Attention(ExLlamaV2Module):
|
||||
layer_idx: int,
|
||||
has_norm: bool = True,
|
||||
has_residual: bool = True,
|
||||
sliding_window: int = 0
|
||||
sliding_window: int = 0,
|
||||
archparams = None
|
||||
):
|
||||
super().__init__(model, key)
|
||||
super().__init__(model, key, archparams)
|
||||
|
||||
cfg = self.model.config
|
||||
ap = self.archparams
|
||||
km = self.archparams.keys
|
||||
|
||||
self.is_tp = False
|
||||
self.tp_dq_size = None
|
||||
|
||||
@@ -151,27 +155,28 @@ class ExLlamaV2Attention(ExLlamaV2Module):
|
||||
|
||||
hidden_size = cfg.hidden_size
|
||||
|
||||
if self.has_norm:
|
||||
if cfg.arch.norm == "layernorm":
|
||||
self.pre_layernorm = ExLlamaV2LayerNorm(model, key + cfg.arch.norm_key_1)
|
||||
self.post_layernorm = ExLlamaV2LayerNorm(model, key + cfg.arch.norm_key_1_post) if cfg.arch.norm_key_1_post else None
|
||||
elif cfg.arch.norm == "rmsnorm":
|
||||
self.pre_layernorm = ExLlamaV2RMSNorm(model, key + cfg.arch.norm_key_1)
|
||||
self.post_layernorm = ExLlamaV2RMSNorm(model, key + cfg.arch.norm_key_1_post) if cfg.arch.norm_key_1_post else None
|
||||
if self.has_norm and (km["norm_1"] or km["norm_1_post"]):
|
||||
if ap.norm == "layernorm":
|
||||
self.pre_layernorm = ExLlamaV2LayerNorm(model, key + km["norm_1"], archparams)
|
||||
self.post_layernorm = ExLlamaV2LayerNorm(model, key + km["norm_1_post"], archparams) if km["norm_1_post"] else None
|
||||
elif ap.norm == "rmsnorm":
|
||||
self.pre_layernorm = ExLlamaV2RMSNorm(model, key + km["norm_1"], archparams)
|
||||
self.post_layernorm = ExLlamaV2RMSNorm(model, key + km["norm_1_post"], archparams) if km["norm_1_post"] else None
|
||||
else:
|
||||
self.pre_layernorm = None
|
||||
self.post_layernorm = None
|
||||
self.has_norm = False
|
||||
|
||||
f_a = 0
|
||||
f_b = cfg.num_attention_heads * cfg.head_dim
|
||||
f_c = f_b + cfg.num_key_value_heads * cfg.head_dim
|
||||
f_d = f_c + cfg.num_key_value_heads * cfg.head_dim
|
||||
f_key = (key + ".self_attn." + cfg.arch.fused_qkv_key) if cfg.arch.fused_qkv_key else None
|
||||
f_key = (key + ".self_attn." + km["fused_qkv"]) if km["fused_qkv"] else None
|
||||
|
||||
self.q_proj = ExLlamaV2Linear(model, key + ".self_attn.q_proj", hidden_size, cfg.num_attention_heads * cfg.head_dim, cfg.arch.attention_bias_qkv, f_key = f_key, f_beg = f_a, f_end = f_b, altpack_qkv = cfg.arch.fused_qkv_altpack)
|
||||
self.k_proj = ExLlamaV2Linear(model, key + ".self_attn.k_proj", hidden_size, cfg.num_key_value_heads * cfg.head_dim, cfg.arch.attention_bias_qkv, f_key = f_key, f_beg = f_b, f_end = f_c, altpack_qkv = cfg.arch.fused_qkv_altpack)
|
||||
self.v_proj = ExLlamaV2Linear(model, key + ".self_attn.v_proj", hidden_size, cfg.num_key_value_heads * cfg.head_dim, cfg.arch.attention_bias_qkv, f_key = f_key, f_beg = f_c, f_end = f_d, altpack_qkv = cfg.arch.fused_qkv_altpack)
|
||||
self.o_proj = ExLlamaV2Linear(model, key + ".self_attn.o_proj", cfg.num_attention_heads * cfg.head_dim, hidden_size, cfg.arch.attention_bias_o, prescale = cfg.scale_depth)
|
||||
self.q_proj = ExLlamaV2Linear(model, key + ".self_attn.q_proj", hidden_size, cfg.num_attention_heads * cfg.head_dim, ap.attention_bias_qkv, f_key = f_key, f_beg = f_a, f_end = f_b, altpack_qkv = ap.fused_qkv_altpack)
|
||||
self.k_proj = ExLlamaV2Linear(model, key + ".self_attn.k_proj", hidden_size, cfg.num_key_value_heads * cfg.head_dim, ap.attention_bias_qkv, f_key = f_key, f_beg = f_b, f_end = f_c, altpack_qkv = ap.fused_qkv_altpack)
|
||||
self.v_proj = ExLlamaV2Linear(model, key + ".self_attn.v_proj", hidden_size, cfg.num_key_value_heads * cfg.head_dim, ap.attention_bias_qkv, f_key = f_key, f_beg = f_c, f_end = f_d, altpack_qkv = ap.fused_qkv_altpack)
|
||||
self.o_proj = ExLlamaV2Linear(model, key + ".self_attn.o_proj", cfg.num_attention_heads * cfg.head_dim, hidden_size, ap.attention_bias_o, prescale = cfg.scale_depth)
|
||||
|
||||
if cfg.use_qk_norm:
|
||||
self.q_norm = ExLlamaV2HeadNorm(model, key + ".self_attn.q_norm", cfg.num_attention_heads, cfg.head_dim)
|
||||
@@ -180,10 +185,12 @@ class ExLlamaV2Attention(ExLlamaV2Module):
|
||||
self.q_norm = None
|
||||
self.k_norm = None
|
||||
|
||||
self.submodules = [self.q_proj,
|
||||
self.k_proj,
|
||||
self.v_proj,
|
||||
self.o_proj]
|
||||
self.submodules = [
|
||||
self.q_proj,
|
||||
self.k_proj,
|
||||
self.v_proj,
|
||||
self.o_proj
|
||||
]
|
||||
if self.pre_layernorm:
|
||||
self.submodules += [self.pre_layernorm]
|
||||
if self.post_layernorm:
|
||||
@@ -294,12 +301,12 @@ class ExLlamaV2Attention(ExLlamaV2Module):
|
||||
cfg.head_dim,
|
||||
cfg.max_seq_len,
|
||||
self.has_residual,
|
||||
cfg.arch.rope_style.value,
|
||||
self.archparams.rope_style.value,
|
||||
q_norm,
|
||||
k_norm,
|
||||
post_norm_weight,
|
||||
post_norm_bias,
|
||||
cfg.arch.residual_stream_fp32,
|
||||
self.archparams.residual_stream_fp32,
|
||||
not cfg.no_graphs
|
||||
)
|
||||
|
||||
@@ -522,7 +529,7 @@ class ExLlamaV2Attention(ExLlamaV2Module):
|
||||
if cfg.use_qk_norm:
|
||||
q = self.q_norm.forward(q)
|
||||
k = self.k_norm.forward(k)
|
||||
if cfg.arch.rope_style != RopeStyle.NONE:
|
||||
if self.archparams.rope_style != RopeStyle.NONE:
|
||||
for t, heads in [(q, cfg.num_attention_heads), (k, cfg.num_key_value_heads)]:
|
||||
ext_c.rope_(
|
||||
t,
|
||||
@@ -532,7 +539,7 @@ class ExLlamaV2Attention(ExLlamaV2Module):
|
||||
heads,
|
||||
cfg.head_dim,
|
||||
cache_seqlens,
|
||||
cfg.arch.rope_style == RopeStyle.NEOX
|
||||
self.archparams.rope_style == RopeStyle.NEOX
|
||||
)
|
||||
if attn_params.is_sequential:
|
||||
k_ = k_cache_f[:, attn_params.first_index : attn_params.first_index + q_len, :, :]
|
||||
@@ -659,7 +666,7 @@ class ExLlamaV2Attention(ExLlamaV2Module):
|
||||
self.v_proj.q_handle,
|
||||
self.o_proj.q_handle,
|
||||
cfg.head_dim,
|
||||
int(cfg.arch.rope_style),
|
||||
int(self.archparams.rope_style),
|
||||
batch_size,
|
||||
q_len,
|
||||
sin,
|
||||
@@ -728,7 +735,7 @@ class ExLlamaV2Attention(ExLlamaV2Module):
|
||||
assert False, "TP not implemented for QK norm" # TODO: ...
|
||||
# q = self.q_norm.forward(q)
|
||||
# k = self.k_norm.forward(k)
|
||||
if cfg.arch.rope_style != RopeStyle.NONE:
|
||||
if self.archparams.rope_style != RopeStyle.NONE:
|
||||
for idx, (dev, a, b) in enumerate(split):
|
||||
context = self.model.get_device_context(dev)
|
||||
torch.cuda.set_stream(context.stream)
|
||||
@@ -741,7 +748,7 @@ class ExLlamaV2Attention(ExLlamaV2Module):
|
||||
(b - a) * heads,
|
||||
cfg.head_dim,
|
||||
attn_params.cache_seqlens_tp[idx],
|
||||
cfg.arch.rope_style == RopeStyle.NEOX
|
||||
self.archparams.rope_style == RopeStyle.NEOX
|
||||
)
|
||||
if attn_params.is_sequential:
|
||||
k_ = [x[:, attn_params.first_index: attn_params.first_index + q_len, :, :] for x in k_cache_f]
|
||||
@@ -1113,7 +1120,7 @@ class ExLlamaV2Attention(ExLlamaV2Module):
|
||||
pass_lora_temp
|
||||
)
|
||||
|
||||
if cfg.arch.clamp_hidden_states:
|
||||
if self.archparams.clamp_hidden_states:
|
||||
hidden_states.clamp_(-65504, 65504)
|
||||
|
||||
return hidden_states
|
||||
@@ -1166,7 +1173,7 @@ class ExLlamaV2Attention(ExLlamaV2Module):
|
||||
self.v_proj.q_handle,
|
||||
self.o_proj.q_handle,
|
||||
cfg.head_dim,
|
||||
int(cfg.arch.rope_style),
|
||||
int(self.archparams.rope_style),
|
||||
batch_size,
|
||||
q_len,
|
||||
sin,
|
||||
@@ -1223,7 +1230,7 @@ class ExLlamaV2Attention(ExLlamaV2Module):
|
||||
else:
|
||||
k_cache, v_cache = None, None
|
||||
|
||||
if cfg.arch.rope_style != RopeStyle.NONE:
|
||||
if self.archparams.rope_style != RopeStyle.NONE:
|
||||
for idx, (dev, a, b) in enumerate(split):
|
||||
context = self.model.get_device_context(dev)
|
||||
torch.cuda.set_stream(context.stream)
|
||||
@@ -1236,7 +1243,7 @@ class ExLlamaV2Attention(ExLlamaV2Module):
|
||||
(b - a) * heads,
|
||||
cfg.head_dim,
|
||||
attn_params.position_offsets_tp[idx] if attn_params.position_offsets is not None else none_tensor,
|
||||
cfg.arch.rope_style == RopeStyle.NEOX
|
||||
self.archparams.rope_style == RopeStyle.NEOX
|
||||
)
|
||||
|
||||
attn_outputs = []
|
||||
@@ -1370,9 +1377,9 @@ class ExLlamaV2Attention(ExLlamaV2Module):
|
||||
else:
|
||||
position_offsets = none_tensor
|
||||
|
||||
if cfg.arch.rope_style != RopeStyle.NONE:
|
||||
ext_c.rope_(query_states, constants.sin, constants.cos, past_len, num_attention_heads, head_dim, position_offsets, cfg.arch.rope_style == RopeStyle.NEOX)
|
||||
ext_c.rope_(key_states, constants.sin, constants.cos, past_len, num_key_value_heads, head_dim, position_offsets, cfg.arch.rope_style == RopeStyle.NEOX)
|
||||
if self.archparams.rope_style != RopeStyle.NONE:
|
||||
ext_c.rope_(query_states, constants.sin, constants.cos, past_len, num_attention_heads, head_dim, position_offsets, self.archparams.rope_style == RopeStyle.NEOX)
|
||||
ext_c.rope_(key_states, constants.sin, constants.cos, past_len, num_key_value_heads, head_dim, position_offsets, self.archparams.rope_style == RopeStyle.NEOX)
|
||||
|
||||
# Add keys and values to cache
|
||||
|
||||
@@ -1417,15 +1424,15 @@ class ExLlamaV2Attention(ExLlamaV2Module):
|
||||
# Post layernorm
|
||||
|
||||
if self.post_layernorm:
|
||||
attn_proj = self.post_layernorm.forward(attn_proj, output_fp32 = cfg.arch.residual_stream_fp32)
|
||||
attn_proj = self.post_layernorm.forward(attn_proj, output_fp32 = self.archparams.residual_stream_fp32)
|
||||
|
||||
# Add residual connection
|
||||
|
||||
hidden_states = (attn_proj + residual) if self.has_residual else attn_proj
|
||||
|
||||
if cfg.arch.residual_stream_fp32:
|
||||
if self.archparams.residual_stream_fp32:
|
||||
hidden_states = hidden_states.float()
|
||||
elif cfg.arch.clamp_hidden_states:
|
||||
elif self.archparams.clamp_hidden_states:
|
||||
hidden_states.clamp_(-65504, 65504)
|
||||
|
||||
if intermediates:
|
||||
|
||||
@@ -10,12 +10,21 @@ from typing import Any, Dict, List, TypeVar, Union, cast
|
||||
T = TypeVar('T')
|
||||
no_default = object()
|
||||
|
||||
def read(input_dict: dict[str, Any], expected_type: type | list[type], keys: str | list[str], default = no_default) -> T:
|
||||
def read(
|
||||
input_dict: dict[str, Any],
|
||||
expected_type: type | list[type],
|
||||
keys: str | list[str],
|
||||
default = no_default,
|
||||
opt_subkey: str | None = None
|
||||
) -> T:
|
||||
|
||||
expected_types = expected_type if isinstance(expected_type, list) else [expected_type]
|
||||
|
||||
if isinstance(keys, str): keys = [keys]
|
||||
|
||||
if opt_subkey is not None:
|
||||
keys = keys + [opt_subkey + "->" + k for k in keys]
|
||||
|
||||
for key in keys:
|
||||
input_dict_s = input_dict
|
||||
|
||||
@@ -115,6 +124,13 @@ class ExLlamaV2Config:
|
||||
yarn_rope_original_max_position_embeddings: int | None
|
||||
checkpoint_fused_mlp: bool
|
||||
checkpoint_offset_qzeros: bool
|
||||
vision_model_type: str | None
|
||||
vision_head_dim: int | None
|
||||
vision_hidden_size: int | None
|
||||
vision_hidden_act: int | None
|
||||
vision_patch_size: int | None
|
||||
vision_rope_theta: float | None
|
||||
vision_feature_layer: int | None
|
||||
|
||||
# Deprecated fields, kept for compatibility
|
||||
|
||||
@@ -204,36 +220,51 @@ class ExLlamaV2Config:
|
||||
self.bos_token_id = read(read_config, int, "bos_token_id", None) # 1
|
||||
self.eos_token_id = read(read_config, [int, list], "eos_token_id", None) # 2
|
||||
self.pad_token_id = read(read_config, int, "pad_token_id", None) # 0
|
||||
self.vocab_size = read(read_config, int, "vocab_size")
|
||||
self.vocab_size = read(read_config, int, "vocab_size", opt_subkey = "text_config")
|
||||
|
||||
if isinstance(self.eos_token_id, list):
|
||||
self.eos_token_id = self.eos_token_id[0] # TODO: Figure out a way to maybe use all the EOS tokens somehow
|
||||
|
||||
# Standard params
|
||||
|
||||
self.initializer_range = read(read_config, float, ["initializer_range"])
|
||||
self.num_hidden_layers = read(read_config, int, ["num_hidden_layers", "n_layers", "n_layer"])
|
||||
self.initializer_range = read(read_config, float, ["initializer_range"], 0.02)
|
||||
self.num_hidden_layers = read(read_config, int, ["num_hidden_layers", "n_layers", "n_layer"], opt_subkey = "text_config")
|
||||
|
||||
# Norm params
|
||||
|
||||
if self.arch.norm_eps_key:
|
||||
self.norm_eps = read(read_config, float, self.arch.norm_eps_key)
|
||||
if self.arch.lm.keys["norm_eps"]:
|
||||
self.norm_eps = read(read_config, float, self.arch.lm.keys["norm_eps"], opt_subkey = "text_config")
|
||||
else:
|
||||
self.norm_eps = 1e-5 # Torch default
|
||||
|
||||
# Model dimensions
|
||||
|
||||
self.hidden_size = read(read_config, int, ["hidden_size", "d_model", "n_embd"])
|
||||
self.hidden_size = read(read_config, int, ["hidden_size", "d_model", "n_embd"], opt_subkey = "text_config")
|
||||
|
||||
# Attn params
|
||||
|
||||
self.num_attention_heads = read(read_config, int, ["num_attention_heads", "n_heads", "n_head"])
|
||||
self.head_dim = read(read_config, int, "head_dim", self.hidden_size // self.num_attention_heads)
|
||||
self.num_attention_heads = read(read_config, int, ["num_attention_heads", "n_heads", "n_head"], 0, opt_subkey = "text_config")
|
||||
self.head_dim = read(
|
||||
read_config,
|
||||
int,
|
||||
"head_dim",
|
||||
(self.hidden_size // self.num_attention_heads) if self.num_attention_heads else no_default,
|
||||
opt_subkey = "text_config"
|
||||
)
|
||||
|
||||
if self.arch.mqa:
|
||||
if not self.num_attention_heads:
|
||||
self.num_attention_heads = self.hidden_size // self.head_dim
|
||||
|
||||
if self.arch.lm.mqa:
|
||||
self.num_key_value_heads = 1
|
||||
else:
|
||||
self.num_key_value_heads = read(read_config, int, ["num_key_value_heads", "attn_config->kv_n_heads"], self.num_attention_heads)
|
||||
self.num_key_value_heads = read(
|
||||
read_config,
|
||||
int,
|
||||
["num_key_value_heads", "attn_config->kv_n_heads"],
|
||||
self.num_attention_heads,
|
||||
opt_subkey = "text_config",
|
||||
)
|
||||
self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads
|
||||
self.use_qk_norm = read(read_config, bool, ["use_qk_norm"], False)
|
||||
|
||||
@@ -241,19 +272,25 @@ class ExLlamaV2Config:
|
||||
|
||||
# MLP params
|
||||
|
||||
if self.arch.default_inner_dim_mult is not None:
|
||||
default_intermediate_size = self.arch.default_inner_dim_mult * self.hidden_size
|
||||
if self.arch.lm.default_inner_dim_mult is not None:
|
||||
default_intermediate_size = self.arch.lm.default_inner_dim_mult * self.hidden_size
|
||||
else:
|
||||
default_intermediate_size = no_default
|
||||
|
||||
self.intermediate_size = read(read_config, int, ["intermediate_size", "ffn_config->ffn_hidden_size", "n_inner"], default_intermediate_size)
|
||||
self.intermediate_size = read(
|
||||
read_config,
|
||||
int,
|
||||
["intermediate_size", "ffn_config->ffn_hidden_size", "n_inner"],
|
||||
default_intermediate_size,
|
||||
opt_subkey = "text_config",
|
||||
)
|
||||
self.num_experts = read(read_config, int, ["num_local_experts", "ffn_config->moe_num_experts"], None)
|
||||
self.num_experts_per_token = read(read_config, int,["num_experts_per_tok", "ffn_config->moe_top_k"], None)
|
||||
|
||||
# Logit/embedding/residual scale
|
||||
|
||||
self.logit_scale = read(read_config, float, "logit_scale", 1)
|
||||
if self.arch.logit_scale_basedim:
|
||||
if self.arch.lm.logit_scale_basedim:
|
||||
dim_model_base = read(read_config, int, "dim_model_base", self.hidden_size)
|
||||
self.logit_scale /= (self.hidden_size / dim_model_base)
|
||||
|
||||
@@ -273,16 +310,24 @@ class ExLlamaV2Config:
|
||||
|
||||
# Positional embeddings
|
||||
|
||||
self.rotary_embedding_base = read(read_config, float, ["rope_theta", "attn_config->rope_theta"], 10000.0)
|
||||
self.rotary_embedding_base = read(
|
||||
read_config,
|
||||
float,
|
||||
["rope_theta", "attn_config->rope_theta"],
|
||||
10000.0,
|
||||
opt_subkey = "text_config",
|
||||
)
|
||||
|
||||
self.max_seq_len = read(read_config, int,["max_sequence_length",
|
||||
"model_max_length",
|
||||
"max_position_embeddings",
|
||||
"max_seq_len",
|
||||
"n_positions"], 2048)
|
||||
self.max_seq_len = read(
|
||||
read_config,
|
||||
int,
|
||||
["max_sequence_length", "model_max_length", "max_position_embeddings", "max_seq_len", "n_positions"],
|
||||
2048,
|
||||
opt_subkey = "text_config"
|
||||
)
|
||||
self.original_max_seq_len = self.max_seq_len
|
||||
|
||||
self.sliding_window = read(read_config, int, ["sliding_window", "sliding_window_size"], 0)
|
||||
self.sliding_window = read(read_config, int, ["sliding_window", "sliding_window_size"], 0, opt_subkey = "text_config")
|
||||
|
||||
rs = read(read_config, dict, "rope_scaling", None)
|
||||
if rs:
|
||||
@@ -345,37 +390,51 @@ class ExLlamaV2Config:
|
||||
|
||||
# Make sure we found all the layers we need
|
||||
|
||||
expect_keys = self.arch.expect_keys.copy()
|
||||
def check_keys(archparams, prefix):
|
||||
|
||||
if not self.num_experts or self.num_experts == 1:
|
||||
per_layer_keys = self.arch.layer_keys
|
||||
else:
|
||||
per_layer_keys = set()
|
||||
for expert_idx in range(self.num_experts):
|
||||
for k in self.arch.layer_keys:
|
||||
skt = [sk.replace(".*.", f".{expert_idx}.") for sk in k]
|
||||
per_layer_keys.add(tuple(skt))
|
||||
per_layer_keys = list(per_layer_keys)
|
||||
expect_keys = archparams.expect_keys.copy()
|
||||
|
||||
for layer_idx in range(self.num_hidden_layers):
|
||||
for ks in per_layer_keys:
|
||||
prefixes = [f"model.layers.{layer_idx}.{k}" for k in ks]
|
||||
expect_keys.append(prefixes)
|
||||
if not self.num_experts or self.num_experts == 1:
|
||||
per_layer_keys = archparams.layer_keys
|
||||
else:
|
||||
per_layer_keys = set()
|
||||
for expert_idx in range(self.num_experts):
|
||||
for k in archparams.layer_keys:
|
||||
skt = [sk.replace(".*.", f".{expert_idx}.") for sk in k]
|
||||
per_layer_keys.add(tuple(skt))
|
||||
per_layer_keys = list(per_layer_keys)
|
||||
|
||||
all_keys = set(self.tensor_file_map.keys())
|
||||
suffixes = [".q_weight", ".qweight", ".weight", ""]
|
||||
for layer_idx in range(self.num_hidden_layers):
|
||||
for ks in per_layer_keys:
|
||||
prefixes = [f"model.layers.{layer_idx}.{k}" for k in ks]
|
||||
expect_keys.append(prefixes)
|
||||
|
||||
for prefixes in expect_keys:
|
||||
match = False
|
||||
for prefix in prefixes:
|
||||
for suffix in suffixes:
|
||||
if (prefix + suffix) in all_keys:
|
||||
match = True
|
||||
break
|
||||
if self.arch.lm_prefix:
|
||||
expect_keys = [
|
||||
[prefix + k for k in k2]
|
||||
for k2 in expect_keys
|
||||
]
|
||||
|
||||
all_keys = set(self.tensor_file_map.keys())
|
||||
suffixes = [".q_weight", ".qweight", ".weight", ""]
|
||||
|
||||
for prefixes in expect_keys:
|
||||
match = False
|
||||
for prefix in prefixes:
|
||||
for suffix in suffixes:
|
||||
if (prefix + suffix) in all_keys:
|
||||
match = True
|
||||
break
|
||||
if match: break
|
||||
if match: break
|
||||
if match: break
|
||||
if not match:
|
||||
raise ValueError(f" ## Could not find {prefix}.* in model")
|
||||
if not match:
|
||||
raise ValueError(f" ## Could not find {prefix}.* in model")
|
||||
|
||||
check_keys(self.arch.lm, self.arch.lm_prefix)
|
||||
check_keys(self.arch.mmp, self.arch.mmp_prefix)
|
||||
check_keys(self.arch.vt, self.arch.vt_prefix)
|
||||
|
||||
# Cleanup
|
||||
|
||||
cleanup_stfiles()
|
||||
|
||||
@@ -391,7 +450,7 @@ class ExLlamaV2Config:
|
||||
|
||||
warnings = []
|
||||
|
||||
if self.arch.eager_attn_only:
|
||||
if self.arch.lm.eager_attn_only:
|
||||
warnings.append(" !! Warning: Architecture currently supports only eager attention")
|
||||
if not warn_only:
|
||||
warnings.append(" !! Warning: flash-attn, xformers and SDPA are disabled")
|
||||
@@ -406,7 +465,7 @@ class ExLlamaV2Config:
|
||||
if self.attn_logit_softcapping and not has_flash_attn_with_softcap:
|
||||
warnings.append(" !! Warning: model requires softcap, not supported in installed version of flash-attn")
|
||||
disable = True
|
||||
if (self.arch.swa or self.arch.alternating_swa) and not has_flash_attn_with_window:
|
||||
if (self.arch.lm.swa or self.arch.lm.alternating_swa) and not has_flash_attn_with_window:
|
||||
warnings.append(" !! Warning: model requires SWA, not supported in installed version of flash-attn")
|
||||
disable = True
|
||||
if disable and not warn_only:
|
||||
@@ -418,7 +477,7 @@ class ExLlamaV2Config:
|
||||
if self.attn_logit_softcapping:
|
||||
warnings.append(" !! Warning: model requires softcap, not supported in xformers")
|
||||
disable = True
|
||||
if self.arch.swa or self.arch.alternating_swa:
|
||||
if self.arch.lm.swa or self.arch.lm.alternating_swa:
|
||||
warnings.append(" !! Warning: model requires SWA, not supported in xformers")
|
||||
disable = True
|
||||
if disable and not warn_only:
|
||||
|
||||
@@ -90,7 +90,7 @@ def compile_model(job, save_fn, model):
|
||||
|
||||
if isinstance(module, ExLlamaV2MLP):
|
||||
|
||||
has_gate = model.config.arch.mlp_gate
|
||||
has_gate = model.config.arch.lm.mlp_gate
|
||||
d = get_f_module(job, module.pre_layernorm)
|
||||
if d: out_dict.update(d); current_size += _dsize(d)
|
||||
d = get_f_module(job, module.post_layernorm)
|
||||
@@ -110,7 +110,7 @@ def compile_model(job, save_fn, model):
|
||||
|
||||
if isinstance(module, ExLlamaV2ParallelDecoder):
|
||||
|
||||
has_gate = model.config.arch.mlp_gate
|
||||
has_gate = model.config.arch.lm.mlp_gate
|
||||
has_qk_norm = model.config.use_qk_norm
|
||||
d = get_f_module(job, module.input_layernorm); out_dict.update(d); current_size += _dsize(d)
|
||||
d = get_q_module(job, module.attn.q_proj); out_dict.update(d); current_size += _dsize(d)
|
||||
|
||||
@@ -208,7 +208,7 @@ model.load(lazy = True)
|
||||
|
||||
# Limit context length if necessary
|
||||
|
||||
if model.config.arch.rope_style == RopeStyle.NONE:
|
||||
if model.config.arch.lm.rope_style == RopeStyle.NONE:
|
||||
max_ctx = model.config.max_seq_len
|
||||
if job["length"] > max_ctx:
|
||||
print (f" !! Warning: Reducing calibration length to model max context: {max_ctx}")
|
||||
|
||||
@@ -204,7 +204,7 @@ def measure_attn(module, hidden_states, target_states, quantizers, cache, attn_p
|
||||
|
||||
def measure_mlp(module, hidden_states, target_states, quantizers, cache, attn_params, reuse_h_up_proj = None):
|
||||
|
||||
has_gate = module.model.config.arch.mlp_gate
|
||||
has_gate = module.model.config.arch.lm.mlp_gate
|
||||
|
||||
qjobs, qmaps = get_qparams_reduced(qparams_mlp, not has_gate)
|
||||
results = []
|
||||
@@ -490,7 +490,7 @@ def measure_quant(job, save_fn, model, hidden_state_offload_layers):
|
||||
|
||||
elif isinstance(module, ExLlamaV2MLP):
|
||||
mode = "mlp"
|
||||
has_gate = module.model.config.arch.mlp_gate
|
||||
has_gate = module.model.config.arch.lm.mlp_gate
|
||||
if has_gate: quantizers["gate_proj"] = AdaptiveGPTQ(module.gate_proj.linear)
|
||||
quantizers["up_proj"] = AdaptiveGPTQ(module.up_proj.linear)
|
||||
quantizers["down_proj"] = AdaptiveGPTQ(module.down_proj.linear)
|
||||
@@ -508,7 +508,7 @@ def measure_quant(job, save_fn, model, hidden_state_offload_layers):
|
||||
quantizers["k_proj"] = AdaptiveGPTQ(module.attn.k_proj.linear)
|
||||
quantizers["v_proj"] = AdaptiveGPTQ(module.attn.v_proj.linear)
|
||||
quantizers["o_proj"] = AdaptiveGPTQ(module.attn.o_proj.linear)
|
||||
has_gate = module.model.config.arch.mlp_gate
|
||||
has_gate = module.model.config.arch.lm.mlp_gate
|
||||
if has_gate: quantizers["gate_proj"] = AdaptiveGPTQ(module.mlp.gate_proj.linear)
|
||||
quantizers["up_proj"] = AdaptiveGPTQ(module.mlp.up_proj.linear)
|
||||
quantizers["down_proj"] = AdaptiveGPTQ(module.mlp.down_proj.linear)
|
||||
|
||||
@@ -9,7 +9,7 @@ def optimize(job, save_fn, model):
|
||||
|
||||
cfg = model.config
|
||||
|
||||
has_gate = cfg.arch.mlp_gate
|
||||
has_gate = cfg.arch.lm.mlp_gate
|
||||
if has_gate: mlp_key_gate = cfg.arch.mlp_key_gate
|
||||
mlp_key_up = cfg.arch.mlp_key_up
|
||||
mlp_key_down = cfg.arch.mlp_key_down
|
||||
@@ -38,7 +38,7 @@ def optimize(job, save_fn, model):
|
||||
key_v = key + ".self_attn.v_proj"
|
||||
key_o = key + ".self_attn.o_proj"
|
||||
|
||||
if not cfg.arch.is_moe:
|
||||
if not cfg.arch.lm.is_moe:
|
||||
if has_gate: key_g = key + mlp_key_gate
|
||||
key_u = key + mlp_key_up
|
||||
key_d = key + mlp_key_down
|
||||
@@ -83,7 +83,7 @@ def optimize(job, save_fn, model):
|
||||
params = []
|
||||
|
||||
for i in range(num_layers):
|
||||
if cfg.arch.parallel_decoder_blocks:
|
||||
if cfg.arch.lm.parallel_decoder_blocks:
|
||||
m1 = measurement["model.layers." + str(i) + ".parallel_decoder"]["attn"]
|
||||
m2 = measurement["model.layers." + str(i) + ".parallel_decoder"]["mlp"]
|
||||
else:
|
||||
|
||||
@@ -153,7 +153,7 @@ def quant_attn(job, module, hidden_states, target_states, quantizers, attn_param
|
||||
|
||||
def quant_mlp(job, module, hidden_states, target_states, quantizers, attn_params, strat, reuse_h_up_proj = None):
|
||||
|
||||
has_mlp = module.model.config.arch.mlp_gate
|
||||
has_mlp = module.model.config.arch.lm.mlp_gate
|
||||
|
||||
if reuse_h_up_proj is not None:
|
||||
quantizers["up_proj"].reuse_h(quantizers[reuse_h_up_proj])
|
||||
@@ -311,7 +311,7 @@ def quant(job, save_fn, model):
|
||||
|
||||
elif isinstance(module, ExLlamaV2MLP):
|
||||
mode = "mlp"
|
||||
has_mlp = model.config.arch.mlp_gate
|
||||
has_mlp = model.config.arch.lm.mlp_gate
|
||||
# testc(module, hidden_states, hidden_i_states, module.post_attention_layernorm, [module.gate_proj, module.up_proj])
|
||||
if has_mlp: quantizers["gate_proj"] = AdaptiveGPTQ(module.gate_proj.linear)
|
||||
quantizers["up_proj"] = AdaptiveGPTQ(module.up_proj.linear)
|
||||
@@ -341,7 +341,7 @@ def quant(job, save_fn, model):
|
||||
quantizers["k_proj"] = AdaptiveGPTQ(module.attn.k_proj.linear)
|
||||
quantizers["v_proj"] = AdaptiveGPTQ(module.attn.v_proj.linear)
|
||||
quantizers["o_proj"] = AdaptiveGPTQ(module.attn.o_proj.linear)
|
||||
has_gate = module.model.config.arch.mlp_gate
|
||||
has_gate = module.model.config.arch.lm.mlp_gate
|
||||
if has_gate: quantizers["gate_proj"] = AdaptiveGPTQ(module.mlp.gate_proj.linear)
|
||||
quantizers["up_proj"] = AdaptiveGPTQ(module.mlp.up_proj.linear)
|
||||
quantizers["down_proj"] = AdaptiveGPTQ(module.mlp.down_proj.linear)
|
||||
|
||||
@@ -50,7 +50,8 @@ class ExLlamaV2DeviceContext:
|
||||
self,
|
||||
model: ExLlamaV2,
|
||||
device_idx: int,
|
||||
scratch_bytes: int
|
||||
scratch_bytes: int,
|
||||
archparams = None
|
||||
):
|
||||
self.model = model
|
||||
self.device_idx = device_idx
|
||||
@@ -58,6 +59,7 @@ class ExLlamaV2DeviceContext:
|
||||
self.scratch = None
|
||||
self.scratch_bytes = scratch_bytes
|
||||
self.scratch_idx = 0
|
||||
self.archparams = archparams or model.config.arch.lm
|
||||
|
||||
# Create streams (only one per device)
|
||||
|
||||
@@ -140,7 +142,7 @@ class ExLlamaV2DeviceContext:
|
||||
device = _torch_device(self.device_idx)
|
||||
|
||||
cfg = self.model.config
|
||||
if cfg.arch.rope_style == RopeStyle.NONE:
|
||||
if self.archparams.rope_style == RopeStyle.NONE:
|
||||
self.sin = torch.zeros((1,), device = device, dtype = torch.half)
|
||||
self.cos = self.sin
|
||||
return
|
||||
@@ -254,9 +256,9 @@ class ExLlamaV2DeviceContext:
|
||||
if scale != 1.0: t /= scale
|
||||
|
||||
freqs = torch.einsum("i,j->ij", t, inv_freq)
|
||||
if cfg.arch.rope_style == RopeStyle.NEOX:
|
||||
if self.archparams.rope_style == RopeStyle.NEOX:
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
elif cfg.arch.rope_style == RopeStyle.GPTJ:
|
||||
elif self.archparams.rope_style == RopeStyle.GPTJ:
|
||||
emb = torch.repeat_interleave(freqs, 2, dim=-1)
|
||||
else:
|
||||
raise ValueError()
|
||||
|
||||
@@ -23,9 +23,10 @@ class ExLlamaV2Embedding(ExLlamaV2Module):
|
||||
def __init__(
|
||||
self,
|
||||
model: ExLlamaV2,
|
||||
key: str
|
||||
key: str,
|
||||
archparams = None
|
||||
):
|
||||
super().__init__(model, key)
|
||||
super().__init__(model, key, archparams)
|
||||
|
||||
self.is_tp = False
|
||||
|
||||
@@ -156,9 +157,9 @@ class ExLlamaV2Embedding(ExLlamaV2Module):
|
||||
|
||||
# Normalization
|
||||
|
||||
if cfg.arch.residual_stream_fp32:
|
||||
if self.archparams.residual_stream_fp32:
|
||||
combined_embeddings = combined_embeddings.float()
|
||||
if cfg.arch.normalize_embeddings:
|
||||
if self.archparams.normalize_embeddings:
|
||||
combined_embeddings *= cfg.hidden_size ** 0.5
|
||||
|
||||
# Extract indexed embeddings and insert in-place
|
||||
@@ -179,9 +180,9 @@ class ExLlamaV2Embedding(ExLlamaV2Module):
|
||||
else:
|
||||
hidden_states = self.embedding(hidden_states)
|
||||
|
||||
if cfg.arch.residual_stream_fp32:
|
||||
if self.archparams.residual_stream_fp32:
|
||||
hidden_states = hidden_states.float()
|
||||
if cfg.arch.normalize_embeddings:
|
||||
if self.archparams.normalize_embeddings:
|
||||
hidden_states *= cfg.hidden_size ** 0.5
|
||||
|
||||
# Move to pinned temp buffer for TP
|
||||
|
||||
@@ -26,9 +26,10 @@ class ExLlamaV2HeadNorm(ExLlamaV2Module):
|
||||
model: ExLlamaV2,
|
||||
key: str,
|
||||
num_heads: int,
|
||||
head_dim: int
|
||||
head_dim: int,
|
||||
archparams = None
|
||||
):
|
||||
super().__init__(model, key)
|
||||
super().__init__(model, key, archparams)
|
||||
|
||||
self.layernorm = None
|
||||
self.weight = None
|
||||
|
||||
@@ -21,9 +21,10 @@ class ExLlamaV2LayerNorm(ExLlamaV2Module):
|
||||
def __init__(
|
||||
self,
|
||||
model: ExLlamaV2,
|
||||
key: str
|
||||
key: str,
|
||||
archparams = None
|
||||
):
|
||||
super().__init__(model, key)
|
||||
super().__init__(model, key, archparams)
|
||||
|
||||
self.layernorm = None
|
||||
self.weight = None
|
||||
|
||||
@@ -67,9 +67,10 @@ class ExLlamaV2Linear(ExLlamaV2Module):
|
||||
f_end: int = None,
|
||||
is_sub_module: bool = True,
|
||||
altpack_qkv: bool = False,
|
||||
normalize_unq: bool = False
|
||||
normalize_unq: bool = False,
|
||||
archparams = None
|
||||
):
|
||||
super().__init__(model, key)
|
||||
super().__init__(model, key, archparams)
|
||||
|
||||
self.is_sub_module = is_sub_module
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ import os, json
|
||||
from safetensors.torch import load_file as safe_load_file
|
||||
from torch import load as load_file
|
||||
import torch
|
||||
import math
|
||||
from exllamav2.compat import safe_move_tensor
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
@@ -58,7 +59,7 @@ class ExLlamaV2Lora:
|
||||
|
||||
# Compatibility check
|
||||
|
||||
assert not self.model.config.arch.residual_stream_fp32, \
|
||||
assert not self.model.config.arch.lm.residual_stream_fp32, \
|
||||
"LoRAs not (yet) supported for models with FP32 residual stream"
|
||||
|
||||
# Grab relevant items from LoRA config
|
||||
@@ -88,7 +89,7 @@ class ExLlamaV2Lora:
|
||||
tensor = f[key]
|
||||
|
||||
# Find target
|
||||
if key.endswith(f'{self.config.arch.lm_head_key}.weight'):
|
||||
if key.endswith(f'{self.config.arch.lm.keys["lm_head"]}.weight'):
|
||||
if tensor.dtype == torch.bfloat16:
|
||||
tensor = tensor.to(torch.float16)
|
||||
elif tensor.dtype == torch.float32:
|
||||
|
||||
117
exllamav2/mlp.py
117
exllamav2/mlp.py
@@ -38,15 +38,29 @@ class ExLlamaV2MLP(ExLlamaV2Module):
|
||||
tp_dq_size: list[int] | None
|
||||
|
||||
|
||||
def __init__(self,
|
||||
model: ExLlamaV2,
|
||||
key: str,
|
||||
layer_idx: int,
|
||||
has_norm: bool = True,
|
||||
has_residual: bool = True):
|
||||
|
||||
super().__init__(model, key)
|
||||
def __init__(
|
||||
self,
|
||||
model: ExLlamaV2,
|
||||
key: str,
|
||||
layer_idx: int,
|
||||
has_norm: bool = True,
|
||||
has_residual: bool = True,
|
||||
archparams = None,
|
||||
in_features: int | None = None,
|
||||
out_features: int | None = None,
|
||||
interm_features: int | None = None,
|
||||
):
|
||||
super().__init__(model, key, archparams)
|
||||
cfg = self.model.config
|
||||
ap = self.archparams
|
||||
km = self.archparams.keys
|
||||
|
||||
if in_features is None: in_features = cfg.hidden_size
|
||||
if out_features is None: out_features = cfg.hidden_size
|
||||
if interm_features is None: interm_features = cfg.intermediate_size
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
self.interm_features = interm_features
|
||||
|
||||
self.is_tp = False
|
||||
self.tp_dq_size = None
|
||||
@@ -59,23 +73,23 @@ class ExLlamaV2MLP(ExLlamaV2Module):
|
||||
self.temp_lora_size = 0
|
||||
|
||||
f_a = 0
|
||||
f_b = cfg.intermediate_size
|
||||
f_c = f_b + cfg.intermediate_size
|
||||
f_key = (key + ".mlp." + cfg.arch.fused_mlp_key_12) if cfg.arch.fused_mlp_key_12 else None
|
||||
f_b = interm_features
|
||||
f_c = f_b + interm_features
|
||||
f_key = (key + ".mlp." + km["fused_mlp_12"]) if km["fused_mlp_12"] else None
|
||||
|
||||
if self.has_norm:
|
||||
if cfg.arch.norm == "layernorm":
|
||||
self.pre_layernorm = ExLlamaV2LayerNorm(model, key + cfg.arch.norm_key_2)
|
||||
self.post_layernorm = ExLlamaV2LayerNorm(model, key + cfg.arch.norm_key_2_post) if cfg.arch.norm_key_2_post else None
|
||||
elif cfg.arch.norm == "rmsnorm":
|
||||
self.pre_layernorm = ExLlamaV2RMSNorm(model, key + cfg.arch.norm_key_2)
|
||||
self.post_layernorm = ExLlamaV2RMSNorm(model, key + cfg.arch.norm_key_2_post) if cfg.arch.norm_key_2_post else None
|
||||
if self.has_norm and (km["norm_2"] or km["norm_2_post"]):
|
||||
if ap.norm == "layernorm":
|
||||
self.pre_layernorm = ExLlamaV2LayerNorm(model, key + km["norm_2"])
|
||||
self.post_layernorm = ExLlamaV2LayerNorm(model, key + km["norm_2_post"]) if km["norm_2_post"] else None
|
||||
elif ap.norm == "rmsnorm":
|
||||
self.pre_layernorm = ExLlamaV2RMSNorm(model, key + km["norm_2"])
|
||||
self.post_layernorm = ExLlamaV2RMSNorm(model, key + km["norm_2_post"]) if km["norm_2_post"] else None
|
||||
else:
|
||||
self.pre_layernorm = None
|
||||
self.post_layernorm = None
|
||||
|
||||
self.up_proj = ExLlamaV2Linear(model, key + cfg.arch.mlp_key_up, cfg.hidden_size, cfg.intermediate_size, self.model.config.arch.mlp_bias, f_key = f_key, f_beg = f_b, f_end = f_c)
|
||||
self.down_proj = ExLlamaV2Linear(model, key + cfg.arch.mlp_key_down, cfg.intermediate_size, cfg.hidden_size, self.model.config.arch.mlp_bias, prescale = cfg.scale_depth)
|
||||
self.up_proj = ExLlamaV2Linear(model, key + km["mlp_up"], in_features, interm_features, ap.mlp_bias, f_key = f_key, f_beg = f_b, f_end = f_c)
|
||||
self.down_proj = ExLlamaV2Linear(model, key + km["mlp_down"], interm_features, out_features, ap.mlp_bias, prescale = cfg.scale_depth)
|
||||
|
||||
self.submodules = [self.up_proj,
|
||||
self.down_proj]
|
||||
@@ -84,8 +98,8 @@ class ExLlamaV2MLP(ExLlamaV2Module):
|
||||
if self.post_layernorm:
|
||||
self.submodules += [self.post_layernorm]
|
||||
|
||||
if cfg.arch.mlp_gate:
|
||||
self.gate_proj = ExLlamaV2Linear(model, key + cfg.arch.mlp_key_gate, cfg.hidden_size, cfg.intermediate_size, self.model.config.arch.mlp_bias, f_key = f_key, f_beg = f_a, f_end = f_b)
|
||||
if ap.mlp_gate:
|
||||
self.gate_proj = ExLlamaV2Linear(model, key + km["mlp_gate"], in_features, interm_features, ap.mlp_bias, f_key = f_key, f_beg = f_a, f_end = f_b)
|
||||
self.submodules += [self.gate_proj]
|
||||
else:
|
||||
self.gate_proj = None
|
||||
@@ -96,7 +110,7 @@ class ExLlamaV2MLP(ExLlamaV2Module):
|
||||
numel = self.up_proj.numel() + \
|
||||
self.down_proj.numel()
|
||||
|
||||
if self.model.config.arch.mlp_gate:
|
||||
if self.archparams.arch.mlp_gate:
|
||||
numel += self.gate_proj.numel()
|
||||
|
||||
if self.pre_layernorm is not None:
|
||||
@@ -113,6 +127,7 @@ class ExLlamaV2MLP(ExLlamaV2Module):
|
||||
device_context: bool = True
|
||||
):
|
||||
cfg = self.model.config
|
||||
km = self.archparams.keys
|
||||
|
||||
if self.pre_layernorm is not None:
|
||||
self.pre_layernorm.load()
|
||||
@@ -120,10 +135,10 @@ class ExLlamaV2MLP(ExLlamaV2Module):
|
||||
self.post_layernorm.load()
|
||||
|
||||
if cfg.checkpoint_fused_mlp:
|
||||
w12 = self.load_weight(self.key + cfg.arch.fused_mlp_key_12)
|
||||
w1 = nn.Parameter(w12[:cfg.intermediate_size, :].contiguous())
|
||||
w2 = nn.Parameter(w12[cfg.intermediate_size:, :].contiguous())
|
||||
w3 = self.load_weight(self.key + cfg.arch.fused_mlp_key_3)
|
||||
w12 = self.load_weight(self.key + km["fused_mlp_12"])
|
||||
w1 = nn.Parameter(w12[:self.interm_features, :].contiguous())
|
||||
w2 = nn.Parameter(w12[self.interm_features:, :].contiguous())
|
||||
w3 = self.load_weight(self.key + km["fused_mlp_3"])
|
||||
self.down_proj.load(w3, device_context = device_context)
|
||||
self.gate_proj.load(w1, device_context = device_context)
|
||||
self.up_proj.load(w2, device_context = device_context)
|
||||
@@ -180,11 +195,11 @@ class ExLlamaV2MLP(ExLlamaV2Module):
|
||||
temp_b,
|
||||
temp_dq,
|
||||
cfg.max_input_len * cfg.max_batch_size,
|
||||
cfg.arch.mlp_act_func == "gelu",
|
||||
self.archparams.mlp_act_func == "gelu",
|
||||
self.has_residual,
|
||||
post_norm_weight,
|
||||
post_norm_bias,
|
||||
cfg.arch.residual_stream_fp32,
|
||||
self.archparams.residual_stream_fp32,
|
||||
not cfg.no_graphs
|
||||
)
|
||||
|
||||
@@ -205,7 +220,8 @@ class ExLlamaV2MLP(ExLlamaV2Module):
|
||||
def weight_footprint(self) -> int:
|
||||
|
||||
if self.model.config.checkpoint_fused_mlp:
|
||||
fp = 3 * self.model.config.intermediate_size * self.model.config.hidden_size * 2
|
||||
fp = 2 * self.in_features * self.interm_features * 2 + \
|
||||
self.interm_features * self.out_features * 2
|
||||
else:
|
||||
fp = self.up_proj.weight_footprint() + \
|
||||
self.down_proj.weight_footprint()
|
||||
@@ -230,30 +246,31 @@ class ExLlamaV2MLP(ExLlamaV2Module):
|
||||
|
||||
def scratch_space(self) -> int:
|
||||
|
||||
cfg = self.model.config
|
||||
assert cfg.intermediate_size >= cfg.hidden_size
|
||||
return self.temp_state_size() + \
|
||||
self.temp_a_size() + \
|
||||
self.temp_b_size() + \
|
||||
self.temp_dq_size()
|
||||
assert self.interm_features >= self.in_features and self.interm_features >= self.out_features
|
||||
return (
|
||||
self.temp_state_size() +
|
||||
self.temp_a_size() +
|
||||
self.temp_b_size() +
|
||||
self.temp_dq_size()
|
||||
)
|
||||
|
||||
|
||||
def temp_state_size(self) -> int:
|
||||
|
||||
cfg = self.model.config
|
||||
return cfg.max_input_len * cfg.max_batch_size * cfg.hidden_size * 2 + 128
|
||||
return cfg.max_input_len * cfg.max_batch_size * self.out_features * 2 + 128
|
||||
|
||||
|
||||
def temp_a_size(self) -> int:
|
||||
|
||||
cfg = self.model.config
|
||||
return cfg.max_input_len * cfg.max_batch_size * cfg.intermediate_size * 2 + 128
|
||||
return cfg.max_input_len * cfg.max_batch_size * self.interm_features * 2 + 128
|
||||
|
||||
|
||||
def temp_b_size(self) -> int:
|
||||
|
||||
cfg = self.model.config
|
||||
return cfg.max_input_len * cfg.max_batch_size * cfg.intermediate_size * 2 + 128
|
||||
return cfg.max_input_len * cfg.max_batch_size * self.interm_features * 2 + 128
|
||||
|
||||
|
||||
def temp_dq_size(self) -> int:
|
||||
@@ -316,7 +333,7 @@ class ExLlamaV2MLP(ExLlamaV2Module):
|
||||
pass_loras,
|
||||
pass_lora_temp)
|
||||
|
||||
if cfg.arch.clamp_hidden_states:
|
||||
if self.archparams.clamp_hidden_states:
|
||||
hidden_states.clamp_(-65504, 65504)
|
||||
|
||||
return hidden_states
|
||||
@@ -354,7 +371,7 @@ class ExLlamaV2MLP(ExLlamaV2Module):
|
||||
self.gate_proj.q_handle if self.gate_proj is not None else [],
|
||||
self.up_proj.q_handle,
|
||||
self.down_proj.q_handle,
|
||||
cfg.arch.mlp_act_func == "gelu"
|
||||
self.archparams.mlp_act_func == "gelu"
|
||||
)
|
||||
|
||||
return ctx.get_pinned(0, batch_size, q_len, cfg.hidden_size)
|
||||
@@ -391,9 +408,9 @@ class ExLlamaV2MLP(ExLlamaV2Module):
|
||||
context = self.model.get_device_context(dev)
|
||||
torch.cuda.set_stream(context.stream)
|
||||
|
||||
if cfg.arch.mlp_act_func == "silu":
|
||||
if self.archparams.mlp_act_func == "silu":
|
||||
output = F.silu(gate[idx])
|
||||
elif cfg.arch.mlp_act_func == "gelu":
|
||||
elif self.archparams.mlp_act_func == "gelu":
|
||||
output = F.gelu(gate[idx], approximate = "tanh")
|
||||
output *= up[idx]
|
||||
# output.clamp_(min = -65504.0, max = 65504.0)
|
||||
@@ -430,28 +447,28 @@ class ExLlamaV2MLP(ExLlamaV2Module):
|
||||
|
||||
if self.gate_proj is not None:
|
||||
gate = self.gate_proj.forward(post_norm, loras = loras)
|
||||
if cfg.arch.mlp_act_func == "silu":
|
||||
if self.archparams.mlp_act_func == "silu":
|
||||
y = F.silu(gate)
|
||||
elif cfg.arch.mlp_act_func == "gelu":
|
||||
elif self.archparams.mlp_act_func == "gelu":
|
||||
y = F.gelu(gate, approximate = "tanh")
|
||||
up = self.up_proj.forward(post_norm, loras = loras)
|
||||
y *= up
|
||||
y.clamp_(min = -65504.0, max = 65504.0)
|
||||
else:
|
||||
up = self.up_proj.forward(post_norm, loras = loras)
|
||||
if cfg.arch.mlp_act_func == "silu":
|
||||
if self.archparams.mlp_act_func == "silu":
|
||||
y = F.silu(up)
|
||||
elif cfg.arch.mlp_act_func == "gelu":
|
||||
elif self.archparams.mlp_act_func == "gelu":
|
||||
y = F.gelu(up, approximate = "tanh")
|
||||
|
||||
down = self.down_proj.forward(y, loras = loras)
|
||||
if self.post_layernorm:
|
||||
down = self.post_layernorm.forward(down, output_fp32 = cfg.arch.residual_stream_fp32)
|
||||
down = self.post_layernorm.forward(down, output_fp32 = self.archparams.residual_stream_fp32)
|
||||
hidden_states = down + residual if self.has_residual else down
|
||||
|
||||
if cfg.arch.residual_stream_fp32:
|
||||
if self.archparams.residual_stream_fp32:
|
||||
hidden_states = hidden_states.float()
|
||||
elif cfg.arch.clamp_hidden_states:
|
||||
elif self.archparams.clamp_hidden_states:
|
||||
hidden_states = hidden_states.clamp(-65504, 65504)
|
||||
|
||||
if intermediates:
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
from __future__ import annotations
|
||||
import os, sys
|
||||
|
||||
from exllamav2.architecture import RopeStyle
|
||||
|
||||
min_version = (3, 8)
|
||||
if sys.version_info < min_version:
|
||||
print("")
|
||||
@@ -49,6 +47,7 @@ from exllamav2.compat import safe_move_tensor
|
||||
from exllamav2.stloader import cleanup_stfiles
|
||||
from exllamav2.device import ExLlamaV2DeviceContext, set_device_streams
|
||||
from exllamav2.tensor_p import TPContext, BROADCAST_VC
|
||||
from exllamav2.architecture import RopeStyle
|
||||
import gc
|
||||
import threading
|
||||
from typing import Callable
|
||||
@@ -76,7 +75,12 @@ class ExLlamaV2:
|
||||
|
||||
tp_context: TPContext | None
|
||||
|
||||
def __init__(self, config: ExLlamaV2Config, lazy_load = False):
|
||||
def __init__(
|
||||
self,
|
||||
config: ExLlamaV2Config,
|
||||
lazy_load = False,
|
||||
archparams = None
|
||||
):
|
||||
|
||||
self.config = config
|
||||
self.modules = []
|
||||
@@ -88,47 +92,57 @@ class ExLlamaV2:
|
||||
|
||||
# Build model
|
||||
|
||||
emb = ExLlamaV2Embedding(self, "model.embed_tokens")
|
||||
cfg = self.config
|
||||
if archparams is None: archparams = cfg.arch.lm
|
||||
self.archparams = archparams
|
||||
|
||||
emb = ExLlamaV2Embedding(self, cfg.arch.lm_prefix + "model.embed_tokens")
|
||||
self.modules += [emb]
|
||||
|
||||
if self.config.arch.learned_pos_emb_key:
|
||||
pos_emb = ExLlamaV2PosEmbedding(self, self.config.arch.learned_pos_emb_key)
|
||||
if archparams.keys["learned_pos_emb"]:
|
||||
pos_emb = ExLlamaV2PosEmbedding(self, archparams.keys["learned_pos_emb"])
|
||||
self.modules += [pos_emb]
|
||||
|
||||
for layer_idx in range(self.config.num_hidden_layers):
|
||||
for layer_idx in range(cfg.num_hidden_layers):
|
||||
|
||||
layer_key = f"model.layers.{layer_idx}"
|
||||
if self.config.arch.parallel_decoder_blocks:
|
||||
layer_key = cfg.arch.lm_prefix + f"model.layers.{layer_idx}"
|
||||
if cfg.arch.lm.parallel_decoder_blocks:
|
||||
pd = ExLlamaV2ParallelDecoder(self, layer_key, layer_idx)
|
||||
self.modules += [pd]
|
||||
else:
|
||||
if self.config.arch.alternating_swa:
|
||||
swa = self.config.sliding_window if not bool(layer_idx % 2) else 0
|
||||
elif self.config.arch.swa:
|
||||
swa = self.config.sliding_window
|
||||
if cfg.arch.lm.alternating_swa:
|
||||
swa = cfg.sliding_window if not bool(layer_idx % 2) else 0
|
||||
elif cfg.arch.lm.swa:
|
||||
swa = cfg.sliding_window
|
||||
else:
|
||||
swa = 0
|
||||
attn = ExLlamaV2Attention(self, layer_key, layer_idx, sliding_window = swa)
|
||||
if self.config.arch.is_moe: mlp = ExLlamaV2MoEMLP(self, layer_key, layer_idx)
|
||||
if cfg.arch.lm.is_moe: mlp = ExLlamaV2MoEMLP(self, layer_key, layer_idx)
|
||||
else: mlp = ExLlamaV2MLP(self, layer_key, layer_idx)
|
||||
self.modules += [attn, mlp]
|
||||
|
||||
if self.config.arch.norm == "layernorm": norm = ExLlamaV2LayerNorm(self, "model.norm")
|
||||
elif self.config.arch.norm == "rmsnorm": norm = ExLlamaV2RMSNorm(self, "model.norm")
|
||||
else: raise ValueError("unknown norm type")
|
||||
if cfg.arch.lm.norm == "layernorm":
|
||||
norm = ExLlamaV2LayerNorm(self, cfg.arch.lm_prefix + "model.norm")
|
||||
elif cfg.arch.lm.norm == "rmsnorm":
|
||||
norm = ExLlamaV2RMSNorm(self, cfg.arch.lm_prefix + "model.norm")
|
||||
else:
|
||||
raise ValueError("unknown norm type")
|
||||
self.modules += [norm]
|
||||
|
||||
self.head_layer_idx = len(self.modules)
|
||||
head = ExLlamaV2Linear(self, "lm_head",
|
||||
self.config.hidden_size,
|
||||
self.config.vocab_size,
|
||||
False,
|
||||
max_out_len = self.config.max_output_len,
|
||||
prescale = self.config.logit_scale,
|
||||
is_sub_module = False,
|
||||
normalize_unq = bool(self.config.norm_head))
|
||||
if self.config.arch.lm_head_key != "lm_head":
|
||||
head.alt_key = self.config.arch.lm_head_key
|
||||
head = ExLlamaV2Linear(
|
||||
self,
|
||||
cfg.arch.lm_prefix + "lm_head",
|
||||
cfg.hidden_size,
|
||||
cfg.vocab_size,
|
||||
False,
|
||||
max_out_len = cfg.max_output_len,
|
||||
prescale = cfg.logit_scale,
|
||||
is_sub_module = False,
|
||||
normalize_unq = bool(cfg.norm_head)
|
||||
)
|
||||
if archparams.keys["lm_head"] != "lm_head":
|
||||
head.alt_key = archparams.keys["lm_head"]
|
||||
self.modules += [head]
|
||||
|
||||
# Compile dictionary of modules
|
||||
@@ -157,17 +171,21 @@ class ExLlamaV2:
|
||||
embed_cpu: bool = True
|
||||
) -> list[float]:
|
||||
|
||||
cfg = self.config
|
||||
self.cache_map = {}
|
||||
|
||||
# Constant shared between layers
|
||||
|
||||
sincos_size = self.config.head_dim * self.config.max_seq_len * 2
|
||||
constant_size = sincos_size * 2
|
||||
if self.archparams.rope_style != RopeStyle.NONE:
|
||||
sincos_size = cfg.head_dim * cfg.max_seq_len * 2
|
||||
constant_size = sincos_size * 2
|
||||
else:
|
||||
constant_size = 0
|
||||
|
||||
# Max size of hidden state
|
||||
|
||||
state_size = self.config.hidden_size * self.config.max_input_len * self.config.max_batch_size * 2
|
||||
mask_size = self.config.max_input_len ** 2 * self.config.max_batch_size * 2
|
||||
state_size = cfg.hidden_size * cfg.max_input_len * cfg.max_batch_size * 2
|
||||
mask_size = cfg.max_input_len ** 2 * cfg.max_batch_size * 2
|
||||
|
||||
# Bytes remaining per device
|
||||
|
||||
@@ -184,7 +202,7 @@ class ExLlamaV2:
|
||||
|
||||
# Special case for token embeddings on CPU
|
||||
|
||||
if idx == 0 and embed_cpu:
|
||||
if isinstance(module, ExLlamaV2Embedding) and embed_cpu:
|
||||
|
||||
module.set_device_idx(-1)
|
||||
continue
|
||||
@@ -245,6 +263,7 @@ class ExLlamaV2:
|
||||
callback: Callable[[int, int], None] | None = None,
|
||||
callback_gen: Callable[[int, int], None] | None = None,
|
||||
progress: bool = False
|
||||
progress: bool = False,
|
||||
):
|
||||
"""
|
||||
Load model, regular manual split mode.
|
||||
|
||||
@@ -25,9 +25,16 @@ class ExLlamaV2Module:
|
||||
submodules: list[ExLlamaV2Module]
|
||||
assumed_footprint: int
|
||||
|
||||
def __init__(self,
|
||||
model: ExLlamaV2,
|
||||
key: str):
|
||||
def __init__(
|
||||
self,
|
||||
model: ExLlamaV2,
|
||||
key: str,
|
||||
archparams = None,
|
||||
):
|
||||
|
||||
if archparams is None:
|
||||
archparams = model.config.arch.lm
|
||||
self.archparams = archparams
|
||||
|
||||
self.model = model
|
||||
self.key = key
|
||||
|
||||
@@ -29,13 +29,18 @@ class ExLlamaV2MoEMLP(ExLlamaV2Module):
|
||||
|
||||
temp_lora_size: int
|
||||
|
||||
def __init__(self,
|
||||
model: ExLlamaV2,
|
||||
key: str,
|
||||
layer_idx: int):
|
||||
super().__init__(model, key)
|
||||
def __init__(
|
||||
self,
|
||||
model: ExLlamaV2,
|
||||
key: str,
|
||||
layer_idx: int,
|
||||
archparams = None
|
||||
):
|
||||
super().__init__(model, key, archparams)
|
||||
|
||||
cfg = self.model.config
|
||||
ap = self.archparams
|
||||
km = self.archparams.keys
|
||||
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
@@ -47,19 +52,19 @@ class ExLlamaV2MoEMLP(ExLlamaV2Module):
|
||||
self.num_experts = cfg.num_experts
|
||||
self.num_experts_per_token = cfg.num_experts_per_token
|
||||
|
||||
if cfg.arch.norm == "layernorm":
|
||||
self.post_attention_layernorm = ExLlamaV2LayerNorm(model, key + self.model.config.arch.norm_key_2)
|
||||
elif cfg.arch.norm == "rmsnorm":
|
||||
self.post_attention_layernorm = ExLlamaV2RMSNorm(model, key + self.model.config.arch.norm_key_2)
|
||||
if ap.norm == "layernorm":
|
||||
self.post_attention_layernorm = ExLlamaV2LayerNorm(model, key + km["norm_2"])
|
||||
elif ap.norm == "rmsnorm":
|
||||
self.post_attention_layernorm = ExLlamaV2RMSNorm(model, key + km["norm_2"])
|
||||
|
||||
w1_key = key + cfg.arch.mlp_key_gate
|
||||
w2_key = key + cfg.arch.mlp_key_down
|
||||
w3_key = key + cfg.arch.mlp_key_up
|
||||
w1_key = key + km["mlp_gate"]
|
||||
w2_key = key + km["mlp_down"]
|
||||
w3_key = key + km["mlp_up"]
|
||||
w1_f_key = w1_key.replace(".*.", ".")
|
||||
w2_f_key = w2_key.replace(".*.", ".")
|
||||
w3_f_key = w3_key.replace(".*.", ".")
|
||||
|
||||
gate_key = cfg.arch.mlp_key_expert_gate
|
||||
gate_key = km["mlp_expert_gate"]
|
||||
|
||||
self.w1 = []
|
||||
self.w2 = []
|
||||
@@ -72,9 +77,9 @@ class ExLlamaV2MoEMLP(ExLlamaV2Module):
|
||||
# ad = bd
|
||||
bu += intermediate_size
|
||||
# bd += hidden_size
|
||||
w1 = ExLlamaV2Linear(model, w1_key.replace("*", str(e)), hidden_size, intermediate_size, cfg.arch.mlp_bias, f_key = w1_f_key, f_beg = au, f_end = bu)
|
||||
w2 = ExLlamaV2Linear(model, w2_key.replace("*", str(e)), intermediate_size, hidden_size, cfg.arch.mlp_bias, f_key = w2_f_key, f_beg = au, f_end = bu)
|
||||
w3 = ExLlamaV2Linear(model, w3_key.replace("*", str(e)), hidden_size, intermediate_size, cfg.arch.mlp_bias, f_key = w3_f_key, f_beg = au, f_end = bu)
|
||||
w1 = ExLlamaV2Linear(model, w1_key.replace("*", str(e)), hidden_size, intermediate_size, ap.mlp_bias, f_key = w1_f_key, f_beg = au, f_end = bu)
|
||||
w2 = ExLlamaV2Linear(model, w2_key.replace("*", str(e)), intermediate_size, hidden_size, ap.mlp_bias, f_key = w2_f_key, f_beg = au, f_end = bu)
|
||||
w3 = ExLlamaV2Linear(model, w3_key.replace("*", str(e)), hidden_size, intermediate_size, ap.mlp_bias, f_key = w3_f_key, f_beg = au, f_end = bu)
|
||||
self.w1.append(w1)
|
||||
self.w2.append(w2)
|
||||
self.w3.append(w3)
|
||||
@@ -124,7 +129,7 @@ class ExLlamaV2MoEMLP(ExLlamaV2Module):
|
||||
device_context.get_scratch_slice(self.temp_logit_size()),
|
||||
device_context.get_scratch_slice(self.temp_dq_size()),
|
||||
self.model.config.max_input_len * self.model.config.max_batch_size,
|
||||
self.model.config.arch.mlp_act_func == "gelu"
|
||||
self.archparams.mlp_act_func == "gelu"
|
||||
)
|
||||
|
||||
|
||||
@@ -296,9 +301,9 @@ class ExLlamaV2MoEMLP(ExLlamaV2Module):
|
||||
gate = self.w1[expert_idx].forward(current_state, loras = loras)
|
||||
up = self.w3[expert_idx].forward(current_state, loras = loras)
|
||||
|
||||
if self.model.config.arch.mlp_act_func == "silu":
|
||||
if self.archparams.mlp_act_func == "silu":
|
||||
current_hidden_states = F.silu(gate)
|
||||
elif self.model.config.arch.mlp_act_func == "gelu":
|
||||
elif self.archparams.mlp_act_func == "gelu":
|
||||
current_hidden_states = F.gelu(gate)
|
||||
current_hidden_states *= up
|
||||
if intermediates: result[f"pre_down.{expert_idx}"] = current_hidden_states
|
||||
|
||||
@@ -24,18 +24,23 @@ class ExLlamaV2ParallelDecoder(ExLlamaV2Module):
|
||||
attn: ExLlamaV2Attention
|
||||
mlp: ExLlamaV2MLP
|
||||
|
||||
def __init__(self,
|
||||
model: ExLlamaV2,
|
||||
key: str,
|
||||
layer_idx: int):
|
||||
super().__init__(model, key)
|
||||
def __init__(
|
||||
self,
|
||||
model: ExLlamaV2,
|
||||
key: str,
|
||||
layer_idx: int,
|
||||
archparams = None
|
||||
):
|
||||
super().__init__(model, key, archparams)
|
||||
|
||||
cfg = self.model.config
|
||||
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
if self.model.config.arch.norm == "layernorm":
|
||||
self.input_layernorm = ExLlamaV2LayerNorm(model, key + self.model.config.arch.norm_key_1)
|
||||
elif self.model.config.arch.norm == "rmsnorm":
|
||||
self.input_layernorm = ExLlamaV2RMSNorm(model, key + self.model.config.arch.norm_key_1)
|
||||
if self.archparams.norm == "layernorm":
|
||||
self.input_layernorm = ExLlamaV2LayerNorm(model, key + self.archparams.keys["norm_1"])
|
||||
elif self.archparams.norm == "rmsnorm":
|
||||
self.input_layernorm = ExLlamaV2RMSNorm(model, key + self.archparams.keys["norm_1"])
|
||||
|
||||
self.attn = ExLlamaV2Attention(model, key, layer_idx, has_norm = False, has_residual = False)
|
||||
self.mlp = ExLlamaV2MLP(model, key, layer_idx, has_norm = False, has_residual = False)
|
||||
|
||||
@@ -20,8 +20,13 @@ class ExLlamaV2RMSNorm(ExLlamaV2Module):
|
||||
is_tp: bool
|
||||
broadcast_type: int | None
|
||||
|
||||
def __init__(self, model, key):
|
||||
super().__init__(model, key)
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
key,
|
||||
archparams = None
|
||||
):
|
||||
super().__init__(model, key, archparams)
|
||||
|
||||
self.is_tp = False
|
||||
self.broadcast_type = None
|
||||
@@ -50,8 +55,8 @@ class ExLlamaV2RMSNorm(ExLlamaV2Module):
|
||||
self.variance_epsilon = self.model.config.norm_eps
|
||||
|
||||
# Gemma adds 1 to the norm tensor for some reason
|
||||
if self.model.config.arch.norm_constant_bias != 0:
|
||||
self.weight += self.model.config.arch.norm_constant_bias
|
||||
if self.archparams.norm_constant_bias != 0:
|
||||
self.weight += self.archparams.norm_constant_bias
|
||||
|
||||
|
||||
def unload(self):
|
||||
@@ -74,8 +79,8 @@ class ExLlamaV2RMSNorm(ExLlamaV2Module):
|
||||
def get_weight(self) -> torch.Tensor:
|
||||
|
||||
# Make sure to return the original weight tensor for Gemma
|
||||
if self.model.config.arch.norm_constant_bias != 0:
|
||||
return self.weight.data - self.model.config.arch.norm_constant_bias
|
||||
if self.archparams.norm_constant_bias != 0:
|
||||
return self.weight.data - self.archparams.norm_constant_bias
|
||||
|
||||
return self.weight.data
|
||||
|
||||
|
||||
@@ -56,7 +56,7 @@ class TPContext:
|
||||
self.model = model
|
||||
cfg = self.model.config
|
||||
|
||||
assert cfg.arch.supports_tp, \
|
||||
assert cfg.arch.lm.supports_tp, \
|
||||
f"Tensor-parallel is not supported for {cfg.arch.arch_string}"
|
||||
assert cfg.intermediate_size % 128 == 0, \
|
||||
"Model intermediate size must be divisible by 128"
|
||||
|
||||
@@ -283,7 +283,7 @@ if args.eval_dataset or args.standard_perplexity:
|
||||
eval_len = [eval_tokens.shape[1]] * eval_tokens.shape[0]
|
||||
|
||||
# if args.eval_bos:
|
||||
if model.config.arch.requires_bos:
|
||||
if model.config.arch.lm.requires_bos:
|
||||
boss = torch.full((eval_tokens.shape[0], 1), tokenizer.bos_token_id, dtype = torch.long)
|
||||
eval_tokens = torch.cat((boss, eval_tokens[:, :-1]), dim = 1)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user