model : add grok-2 support (#782)

Co-authored-by: firecoperana <firecoperana>
This commit is contained in:
firecoperana
2025-09-23 09:31:01 -05:00
committed by GitHub
parent 18f04350e9
commit 8cd2d7ccd7
10 changed files with 271 additions and 82 deletions

View File

@@ -289,6 +289,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_LOGIT_SCALE, "%s.logit_scale" },
{ LLM_KV_DECODER_START_TOKEN_ID, "%s.decoder_start_token_id" },
{ LLM_KV_ATTN_LOGIT_SOFTCAPPING, "%s.attn_logit_softcapping" },
{ LLM_KV_ROUTER_LOGIT_SOFTCAPPING, "%s.router_logit_softcapping" },
{ LLM_KV_FINAL_LOGIT_SOFTCAPPING, "%s.final_logit_softcapping" },
{ LLM_KV_RESIDUAL_SCALE, "%s.residual_scale" },
{ LLM_KV_EMBEDDING_SCALE, "%s.embedding_scale" },
@@ -309,6 +310,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" },
{ LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" },
{ LLM_KV_ATTENTION_SCALE, "%s.attention.scale" },
{ LLM_KV_ATTENTION_OUTPUT_SCALE, "%s.attention.output_scale" },
{ LLM_KV_ATTENTION_TEMPERATURE_LENGTH, "%s.attention.temperature_length" },
{ LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
{ LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" },
@@ -319,6 +322,10 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, "%s.rope.scaling.original_context_length" },
{ LLM_KV_ROPE_SCALING_FINETUNED, "%s.rope.scaling.finetuned" },
{ LLM_KV_ROPE_SCALING_YARN_LOG_MUL, "%s.rope.scaling.yarn_log_multiplier" },
{ LLM_KV_ROPE_SCALING_YARN_EXT_FACTOR, "%s.rope.scaling.yarn_ext_factor" },
{ LLM_KV_ROPE_SCALING_YARN_ATTN_FACTOR, "%s.rope.scaling.yarn_attn_factor" },
{ LLM_KV_ROPE_SCALING_YARN_BETA_FAST, "%s.rope.scaling.yarn_beta_fast" },
{ LLM_KV_ROPE_SCALING_YARN_BETA_SLOW, "%s.rope.scaling.yarn_beta_slow" },
{ LLM_KV_SPLIT_NO, "split.no" },
{ LLM_KV_SPLIT_COUNT, "split.count" },
@@ -507,9 +514,13 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
{ LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" },
{ LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" },
{ LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" },
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
@@ -1558,6 +1569,7 @@ enum llm_chat_template {
LLM_CHAT_TEMPLATE_HUNYUAN_MOE,
LLM_CHAT_TEMPLATE_KIMI_K2,
LLM_CHAT_TEMPLATE_OPENAI_MOE,
LLM_CHAT_TEMPLATE_GROK_2,
LLM_CHAT_TEMPLATE_UNKNOWN,
};
@@ -1599,6 +1611,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
{ "kimi-k2", LLM_CHAT_TEMPLATE_KIMI_K2 },
{ "gpt-oss", LLM_CHAT_TEMPLATE_OPENAI_MOE },
{ "bitnet", LLM_CHAT_TEMPLATE_BITNET },
{ "grok-2", LLM_CHAT_TEMPLATE_GROK_2 },
};
// helper to handle gguf constants
@@ -1930,6 +1943,7 @@ struct llama_hparams {
float f_norm_rms_eps;
float f_attn_logit_softcapping = 50.0f;
float f_router_logit_softcapping = 30.0f;
float f_final_logit_softcapping = 30.0f;
float rope_attn_factor = 1.0f;
@@ -1940,6 +1954,11 @@ struct llama_hparams {
uint32_t n_ctx_orig_yarn;
float rope_yarn_log_mul;
float yarn_ext_factor = -1.0f;
float yarn_attn_factor = 1.0f;
float yarn_beta_fast = 32.0f;
float yarn_beta_slow = 1.0f;
// for State Space Models
uint32_t ssm_d_conv = 0;
uint32_t ssm_d_inner = 0;
@@ -1955,6 +1974,10 @@ struct llama_hparams {
float f_embedding_scale = 0.0f;
float f_attention_scale = 0.0f;
// grok-2
float f_attn_out_scale = 0.0f;
uint32_t attn_temp_length = 0;
bool causal_attn = true;
bool use_alibi = false;
bool attn_soft_cap = false;
@@ -3644,7 +3667,30 @@ static void llm_load_hparams(
} break;
case LLM_ARCH_GROK:
{
// defaults for old GGUFs
hparams.yarn_beta_fast = 8.0f;
hparams.f_logit_scale = 0.5773502691896257f;
hparams.f_embedding_scale = 78.38367176906169f;
hparams.f_attn_out_scale = 0.08838834764831845f;
hparams.f_attn_logit_softcapping = 30.0f;
hparams.f_router_logit_softcapping = 30.0f;
// no final_logit_softcapping in grok-1
hparams.f_final_logit_softcapping = 0.0f;
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false);
ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale, false);
ml.get_key(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale, false);
ml.get_key(LLM_KV_ATTENTION_OUTPUT_SCALE, hparams.f_attn_out_scale, false);
ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping, false);
ml.get_key(LLM_KV_ROUTER_LOGIT_SOFTCAPPING, hparams.f_router_logit_softcapping, false);
ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false);
ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_LENGTH, hparams.attn_temp_length, false);
ml.get_key(LLM_KV_ROPE_SCALING_YARN_EXT_FACTOR, hparams.yarn_ext_factor, false);
ml.get_key(LLM_KV_ROPE_SCALING_YARN_ATTN_FACTOR, hparams.yarn_attn_factor, false);
ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_FAST, hparams.yarn_beta_fast, false);
ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_SLOW, hparams.yarn_beta_slow, false);
switch (hparams.n_layer) {
case 64: model.type = e_model::MODEL_314B; break;
@@ -5238,7 +5284,7 @@ static bool llm_load_tensors(
model.output = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
}
}
const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff/* / n_expert_used*/; // grok-1 n_ff_exp == n_ff
for (int i = 0; i < n_layer; ++i) {
ggml_context * ctx_layer = ctx_for_layer(i);
ggml_context * ctx_split = ctx_for_layer_split(i);
@@ -5256,12 +5302,16 @@ static bool llm_load_tensors(
layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
layer.ffn_gate = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, TENSOR_NOT_REQUIRED);
layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, TENSOR_NOT_REQUIRED);
layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, TENSOR_NOT_REQUIRED);
layer.ffn_gate_inp = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert});
layer.ffn_gate_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, llama_model_loader::TENSOR_NOT_REQUIRED);
layer.ffn_gate_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, llama_model_loader::TENSOR_NOT_REQUIRED);
if (layer.ffn_gate_exps) {
layer.ffn_down_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert});
layer.ffn_up_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert});
layer.ffn_down_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert});
layer.ffn_up_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert });
} else {
// merge split expert into a single tensor for compatibility with older models
// requires disabling mmap
@@ -5287,7 +5337,10 @@ static bool llm_load_tensors(
}
}
layer.layer_out_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd});
layer.ffn_post_norm = create_tensor(ctx_layer,tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED);
if (!layer.ffn_post_norm) {
layer.ffn_post_norm = create_tensor(ctx_layer,tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), { n_embd }, 0);
}
}
} break;
case LLM_ARCH_DBRX:
@@ -8197,7 +8250,7 @@ static struct ggml_tensor * llm_build_kqv(
if (model.arch == LLM_ARCH_GROK) {
// need to do the following:
// multiply by attn_output_multiplyer of 0.08838834764831845
// multiply by attn_output_multiplier
// and then :
// kq = 30 * tanh(kq / 30)
// before the softmax below
@@ -8208,7 +8261,7 @@ static struct ggml_tensor * llm_build_kqv(
//kq = ggml_tanh(ctx, ggml_scale(ctx, kq, 0.08838834764831845f/30.0f));
//kq = ggml_scale(ctx, kq, 30);
kq = ggml_softcap(ctx, kq, 0.08838834764831845f/30.0f, 30.f);
kq = ggml_softcap(ctx, kq, hparams.f_attn_out_scale / hparams.f_attn_logit_softcapping, hparams.f_attn_logit_softcapping);
}
if (hparams.attn_soft_cap) {
@@ -8254,7 +8307,7 @@ static struct ggml_tensor * llm_build_kqv(
ggml_mul_mat_set_prec(kq_i, GGML_PREC_F32);
}
if (model.arch == LLM_ARCH_GROK) {
kq_i = ggml_softcap(ctx, kq_i, 0.08838834764831845f/30.0f, 30.f);
kq_i = ggml_softcap(ctx, kq_i, hparams.f_attn_out_scale / hparams.f_attn_logit_softcapping, hparams.f_attn_logit_softcapping);
}
if (hparams.attn_soft_cap) {
kq_i = ggml_softcap_max(ctx, kq_i, kq_mask, kq_scale, hparams.f_max_alibi_bias,
@@ -9553,15 +9606,11 @@ struct llm_build_context {
inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
// multiply by embedding_multiplier_scale of 78.38367176906169
inpL = ggml_scale(ctx0, inpL, 78.38367176906169f);
// inp_pos - contains the positions
struct ggml_tensor * inp_pos = build_inp_pos();
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
for (int il = 0; il < n_layer; ++il) {
struct ggml_tensor * inpSA = inpL;
@@ -9623,26 +9672,23 @@ struct llm_build_context {
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
// Grok
// if attn_out_norm is present then apply it before adding the input
if (model.layers[il].attn_out_norm) {
cur = llm_build_norm(ctx0, cur, hparams,
model.layers[il].attn_out_norm, NULL,
LLM_NORM_RMS, cb, il);
cb(cur, "attn_out_norm", il);
}
cur = llm_build_norm(ctx0, cur, hparams,
model.layers[il].attn_out_norm, NULL,
LLM_NORM_RMS, cb, il);
cb(cur, "attn_out_norm", il);
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
// feed-forward network
// MoE branch
cur = llm_build_norm(ctx0, ffn_inp, hparams,
model.layers[il].ffn_norm, NULL,
LLM_NORM_RMS, cb, il);
cb(cur, "ffn_norm", il);
cur = llm_build_moe_ffn(ctx0, lctx, cur,
// MoE branch
ggml_tensor* moe_out = llm_build_moe_ffn(ctx0, lctx, cur,
model.layers[il].ffn_gate_inp,
model.layers[il].ffn_up_exps,
model.layers[il].ffn_gate_exps,
@@ -9653,17 +9699,29 @@ struct llm_build_context {
false, 0.0,
LLM_EXPERT_GATING_FUNC_SOFTMAX,
cb, il, gf);
cb(cur, "ffn_moe_out", il);
cb(moe_out, "ffn_moe_out", il);
// Grok
// if layer_out_norm is present then apply it before adding the input
// Idea: maybe ffn_out_norm is a better name
if (model.layers[il].layer_out_norm) {
cur = llm_build_norm(ctx0, cur, hparams,
model.layers[il].layer_out_norm, NULL,
LLM_NORM_RMS, cb, il);
cb(cur, "layer_out_norm", il);
if (model.layers[il].ffn_up) {
ggml_tensor* ffn_out = llm_build_ffn(ctx0, lctx, cur,
model.layers[il].ffn_up, NULL, NULL,
model.layers[il].ffn_gate, NULL, NULL,
model.layers[il].ffn_down, NULL, NULL,
NULL,
LLM_FFN_GELU, LLM_FFN_PAR, cb, il);
cb(ffn_out, "ffn_out", il);
cur = ggml_scale(ctx0, ggml_add(ctx0, ffn_out, moe_out), std::sqrt(2) / 2);
cb(cur, "ffn_out", il);
}
else {
cur = moe_out;
}
cur = llm_build_norm(ctx0, cur, hparams,
model.layers[il].ffn_post_norm, NULL,
LLM_NORM_RMS, cb, il);
cb(cur, "ffn_post_norm", il);
cur = ggml_add(ctx0, cur, ffn_inp);
cb(cur, "ffn_out", il);
@@ -9685,11 +9743,15 @@ struct llm_build_context {
// lm_head
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
// Grok
// multiply logits by output_multiplier_scale of 0.5773502691896257
cur = ggml_scale(ctx0, cur, 0.5773502691896257f);
cur = ggml_scale(ctx0, cur, hparams.f_logit_scale);
// final logit soft-capping
if (hparams.f_final_logit_softcapping) {
/*cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping);
cur = ggml_tanh(ctx0, cur);
cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping);*/
cur = ggml_softcap(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping, hparams.f_final_logit_softcapping);
}
cb(cur, "result_output", -1);
ggml_build_forward_expand(gf, cur);
@@ -19393,9 +19455,9 @@ struct llama_context_params llama_context_default_params() {
/*.rope_freq_base =*/ 0.0f,
/*.rope_freq_scale =*/ 0.0f,
/*.yarn_ext_factor =*/ -1.0f,
/*.yarn_attn_factor =*/ 1.0f,
/*.yarn_beta_fast =*/ 32.0f,
/*.yarn_beta_slow =*/ 1.0f,
/*.yarn_attn_factor =*/ -1.0f,
/*.yarn_beta_fast =*/ -1.0f,
/*.yarn_beta_slow =*/ -1.0f,
/*.yarn_orig_ctx =*/ 0,
/*.defrag_thold =*/ -1.0f,
/*.cb_eval =*/ nullptr,
@@ -19607,10 +19669,10 @@ struct llama_context * llama_new_context_with_model(
cparams.n_seq_max = std::max(1u, params.n_seq_max);
cparams.n_threads = params.n_threads;
cparams.n_threads_batch = params.n_threads_batch;
cparams.yarn_ext_factor = params.yarn_ext_factor;
cparams.yarn_attn_factor = params.yarn_attn_factor;
cparams.yarn_beta_fast = params.yarn_beta_fast;
cparams.yarn_beta_slow = params.yarn_beta_slow;
cparams.yarn_ext_factor = params.yarn_ext_factor >= 0.0f ? params.yarn_ext_factor : hparams.yarn_ext_factor;
cparams.yarn_attn_factor = params.yarn_attn_factor >= 0.0f ? params.yarn_attn_factor : hparams.yarn_attn_factor;
cparams.yarn_beta_fast = params.yarn_beta_fast >= 0.0f ? params.yarn_beta_fast : hparams.yarn_beta_fast;
cparams.yarn_beta_slow = params.yarn_beta_slow >= 0.0f ? params.yarn_beta_slow : hparams.yarn_beta_slow;
cparams.defrag_thold = params.defrag_thold;
cparams.embeddings = params.embeddings;
cparams.offload_kqv = params.offload_kqv;
@@ -21987,6 +22049,8 @@ static llm_chat_template llama_chat_detect_template(const std::string & tmpl) {
return LLM_CHAT_TEMPLATE_HUNYUAN_MOE;
} else if (tmpl_contains("<|im_middle|>") && tmpl_contains("<|im_end|>")) {
return LLM_CHAT_TEMPLATE_KIMI_K2;
} else if (tmpl_contains("'Assistant: ' + message['content'] + '<|separator|>")) {
return LLM_CHAT_TEMPLATE_GROK_2;
} else if (tmpl_contains("<|start|>") && tmpl_contains("<|channel|>")) {
return LLM_CHAT_TEMPLATE_OPENAI_MOE;
}
@@ -22459,6 +22523,22 @@ static int32_t llama_chat_apply_template_internal(
if (add_ass) {
ss << "<|start|>assistant";
}
} else if (tmpl == LLM_CHAT_TEMPLATE_GROK_2) {
for (auto message : chat) {
std::string role(message->role);
if (role == "system") {
ss << "System: " << trim(message->content) << "<|separator|>\n\n";
}
else if (role == "user") {
ss << "Human: " << trim(message->content) << "<|separator|>\n\n";
}
else if (role == "assistant") {
ss << "Assistant: " << message->content << "<|separator|>\n\n";
}
}
if (add_ass) {
ss << "Assistant:";
}
} else {
// template not supported
return -1;