llama4: WIP

This commit is contained in:
Iwan Kawrakow
2025-04-09 09:56:45 +03:00
parent 5f44f4b3d0
commit 7a1fc34023
3 changed files with 281 additions and 19 deletions

View File

@@ -100,7 +100,12 @@ extern "C" {
LLAMA_VOCAB_PRE_TYPE_TEKKEN = 20,
LLAMA_VOCAB_PRE_TYPE_SMOLLM = 21,
LLAMA_VOCAB_PRE_TYPE_CODESHELL = 22,
LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 23, //llama.cpp lists this as 28
LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 28, //llama.cpp lists this as 28
LLAMA_VOCAB_PRE_TYPE_GPT4O = 29,
LLAMA_VOCAB_PRE_TYPE_SUPERBPE = 30,
LLAMA_VOCAB_PRE_TYPE_TRILLION = 31,
LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 32,
LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33,
};
// note: these values should be synchronized with ggml_rope

View File

@@ -439,6 +439,27 @@ struct llm_tokenizer_bpe {
"[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))*((?=[\\p{L}])([^A-Z]))+|[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))+((?=[\\p{L}])([^A-Z]))*|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
};
break;
case LLAMA_VOCAB_PRE_TYPE_GPT4O:
regex_exprs = {
// original regex from tokenizer.json
// "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
"[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))*((?=[\\p{L}])([^A-Z]))+(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))+((?=[\\p{L}])([^A-Z]))*(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
};
break;
case LLAMA_VOCAB_PRE_TYPE_SUPERBPE:
regex_exprs = {
"\\p{N}+",
"(?=(\\d{3})+(?!\\d))",
};
break;
case LLAMA_VOCAB_PRE_TYPE_BAILINGMOE:
regex_exprs = {
// original regex from tokenizer.json
// "'(?i:[sdmt]|ll|ve|re)|[^\\r\\n\\p{L}\\p{N}]?+\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]++[\\r\\n]*|\\s*[\\r\\n]|\\s+(?!\\S)|\\s+"
// FIXME? Changed possessive quantifiers (?+ and ++) to greedy to avoid errors and imatrix hanging (tried atomic grouping but it's not supported?)
"'(?:[sSdDmMtT]|[lL][lL]|[vV][eE]|[rR][eE])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]|\\s+(?!\\S)|\\s+",
};
break;
default:
// default regex for BPE tokenization pre-processing
regex_exprs = {

View File

@@ -184,6 +184,7 @@ static std::string format(const char * fmt, ...) {
enum llm_arch {
LLM_ARCH_LLAMA,
LLM_ARCH_LLAMA4,
LLM_ARCH_FALCON,
LLM_ARCH_BAICHUAN,
LLM_ARCH_GROK,
@@ -232,6 +233,7 @@ enum llm_arch {
static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_LLAMA, "llama" },
{ LLM_ARCH_LLAMA4, "llama4" },
{ LLM_ARCH_FALCON, "falcon" },
{ LLM_ARCH_GROK, "grok" },
{ LLM_ARCH_GPT2, "gpt2" },
@@ -319,6 +321,7 @@ enum llm_kv {
LLM_KV_TIME_DECAY_EXTRA_DIM,
LLM_KV_RESIDUAL_SCALE,
LLM_KV_EMBEDDING_SCALE,
LLM_KV_INTERLEAVE_MOE_LAYER_STEP,
LLM_KV_ATTENTION_HEAD_COUNT,
LLM_KV_ATTENTION_HEAD_COUNT_KV,
@@ -614,6 +617,35 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
},
},
{
LLM_ARCH_LLAMA4,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" },
{ LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" },
{ LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
},
},
{
LLM_ARCH_BAICHUAN,
{
@@ -1432,6 +1464,7 @@ enum llm_chat_template {
LLM_CHAT_TEMPLATE_GRANITE,
LLM_CHAT_TEMPLATE_GIGACHAT,
LLM_CHAT_TEMPLATE_MEGREZ,
LLM_CHAT_TEMPLATE_LLAMA4,
LLM_CHAT_TEMPLATE_UNKNOWN,
};
@@ -1467,6 +1500,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
{ "granite", LLM_CHAT_TEMPLATE_GRANITE },
{ "gigachat", LLM_CHAT_TEMPLATE_GIGACHAT },
{ "megrez", LLM_CHAT_TEMPLATE_MEGREZ },
{ "llama4", LLM_CHAT_TEMPLATE_LLAMA4 },
};
@@ -2357,6 +2391,8 @@ enum e_model {
MODEL_10B_128x3_66B,
MODEL_57B_A14B,
MODEL_27B,
MODEL_17B_16E,
MODEL_17B_128E,
};
static const size_t kiB = 1024;
@@ -2444,6 +2480,14 @@ struct llama_hparams {
bool use_alibi = false;
bool attn_soft_cap = false;
uint32_t n_moe_layer_step = 0;
bool use_kq_norm = true;
uint32_t n_attn_chunk = 0;
// values below seems to be fixed on llama4
uint32_t n_no_rope_layer_step = 4;
uint32_t n_attn_temp_floor_scale = 8192;
float f_attn_temp_scale = 0.1;
// needed by encoder-decoder models (e.g. T5, FLAN-T5)
// ref: https://github.com/ggerganov/llama.cpp/pull/8141
llama_token dec_start_token_id = -1;
@@ -4945,6 +4989,8 @@ static const char * llama_model_type_name(e_model type) {
case MODEL_10B_128x3_66B: return "10B+128x3.66B";
case MODEL_57B_A14B: return "57B.A14B";
case MODEL_27B: return "27B";
case MODEL_17B_16E: return "17Bx16E (Scout)";
case MODEL_17B_128E: return "17Bx128E (Maverick)";
default: return "?B";
}
}
@@ -5106,6 +5152,25 @@ static void llm_load_hparams(
}
}
} break;
case LLM_ARCH_LLAMA4:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
ml.get_key(LLM_KV_INTERLEAVE_MOE_LAYER_STEP, hparams.n_moe_layer_step);
hparams.n_swa_pattern = 4; // pattern: 3 chunked - 1 full
hparams.n_attn_chunk = 8192; // should this be a gguf kv? currently it's the same for Scout and Maverick
hparams.n_swa = 1; // TODO @ngxson : this is added to trigger the SWA branch (we store the chunked attn mask in the SWA tensor), will need to clean this up later
switch (hparams.n_expert) {
case 16: model.type = MODEL_17B_16E; break;
case 128: model.type = MODEL_17B_128E; break;
default: model.type = MODEL_UNKNOWN;
}
if (model.type == MODEL_17B_128E) {
hparams.use_kq_norm = false;
}
} break;
case LLM_ARCH_MINICPM:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
@@ -5929,6 +5994,23 @@ static void llm_load_vocab(
} else if (
tokenizer_pre == "codeshell") {
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_CODESHELL;
} else if (
tokenizer_pre == "gpt-4o" ||
tokenizer_pre == "llama4") {
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_GPT4O;
vocab.tokenizer_clean_spaces = false;
} else if (
tokenizer_pre == "superbpe") {
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_SUPERBPE;
vocab.tokenizer_clean_spaces = false;
} else if (
tokenizer_pre == "trillion") {
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_TRILLION;
vocab.tokenizer_clean_spaces = false;
} else if (
tokenizer_pre == "bailingmoe") {
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_BAILINGMOE;
vocab.tokenizer_clean_spaces = false;
} else {
throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
}
@@ -6552,6 +6634,7 @@ static bool llm_load_tensors(
const int64_t n_embd_gqa = n_embd_v_gqa;
const int64_t n_vocab = hparams.n_vocab;
const int64_t n_vocab_type = hparams.n_vocab_type;
const int64_t n_rot = hparams.n_rot;
const int64_t n_expert = hparams.n_expert;
const int64_t n_expert_used = hparams.n_expert_used;
const int64_t n_ctx_train = hparams.n_ctx_train;
@@ -6670,6 +6753,61 @@ static bool llm_load_tensors(
}
}
} break;
case LLM_ARCH_LLAMA4:
{
model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab},
llama_model_loader::TENSOR_NOT_REQUIRED);
// output
model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
// if output is NULL, init from the input tok embed
if (model.output == NULL) {
model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab},
llama_model_loader::TENSOR_DUPLICATED);
}
GGML_ASSERT(hparams.n_moe_layer_step > 0 && "Llama 4 requires n_moe_layer_step > 0");
for (int i = 0; i < n_layer; ++i) {
bool is_moe_layer = (i + 1) % hparams.n_moe_layer_step == 0;
ggml_context * ctx_layer = ctx_for_layer(i);
ggml_context * ctx_split = ctx_for_layer_split(i);
auto & layer = model.layers[i];
layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0);
layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0);
layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
layer.rope_freqs = create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2},
llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
if (is_moe_layer) {
int n_ff_exp = hparams.n_ff_exp;
layer.ffn_gate_inp = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
layer.ffn_gate_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0);
layer.ffn_down_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0);
layer.ffn_up_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0);
// Shared expert
const int64_t n_ff_shexp = n_ff_exp;
layer.ffn_gate_shexp = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp}, 0);
layer.ffn_down_shexp = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd }, 0);
layer.ffn_up_shexp = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp}, 0);
} else {
layer.ffn_gate = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
}
}
} break;
case LLM_ARCH_GROK:
{
if (n_expert == 0) {
@@ -8884,6 +9022,7 @@ llm_expert_gating_func_type gating_op,
int il) {
int64_t n_embd = cur->ne[0];
int64_t n_tokens = cur->ne[1];
bool weight_before_ffn = lctx.model.arch == LLM_ARCH_LLAMA4; // for llama4, we apply the sigmoid-ed weights before the FFN
ggml_tensor * logits = llm_build_lora_mm(lctx, ctx, gate_inp, cur); // [n_expert, n_tokens]
cb(logits, "ffn_moe_logits", il);
@@ -8912,6 +9051,12 @@ llm_expert_gating_func_type gating_op,
cb(selection_probs, "ffn_moe_probs_biased", il);
}
// llama4 doesn't have exp_probs_b, and sigmoid is only used after top_k
// see: https://github.com/meta-llama/llama-models/blob/699a02993512fb36936b1b0741e13c06790bcf98/models/llama4/moe.py#L183-L198
if (lctx.model.arch == LLM_ARCH_LLAMA4) {
selection_probs = logits;
}
// select experts
ggml_tensor * selected_experts = ggml_top_k_thresh(ctx, selection_probs, n_expert_used,
lctx.cparams.min_experts, lctx.cparams.thresh_experts); // [n_expert_used, n_tokens]
@@ -8940,6 +9085,14 @@ llm_expert_gating_func_type gating_op,
cur = ggml_reshape_3d(ctx, cur, n_embd, 1, n_tokens);
if (weight_before_ffn) {
// TODO: this is a workaround as we don't yet have a repeat op that takes custom dim (ggml_repeat_4d)
ggml_tensor * repeated = ggml_new_tensor_3d(ctx, cur->type, n_embd, n_expert_used, n_tokens);
repeated = ggml_repeat(ctx, cur, repeated); // [n_embd, n_expert_used, n_tokens]
cur = ggml_mul(ctx, repeated, weights);
cb(cur, "ffn_moe_weighted", il);
}
ggml_tensor * par;
if (lctx.cparams.fused_moe_up_gate) {
par = ggml_moe_up_gate(ctx, up_exps, gate_exps, cur, selected_experts, type_op == LLM_FFN_SILU ? GGML_UNARY_OP_SILU : GGML_UNARY_OP_GELU);
@@ -8958,7 +9111,10 @@ llm_expert_gating_func_type gating_op,
ggml_tensor * experts = llm_build_lora_mm_id(lctx, ctx, down_exps, par, selected_experts); // [n_embd, n_expert_used, n_tokens]
cb(experts, "ffn_moe_down", il);
experts = ggml_mul(ctx, experts, weights);
if (!weight_before_ffn) {
experts = ggml_mul(ctx, experts, weights);
cb(cur, "ffn_moe_weighted", il);
}
if (n_expert_used == 1) {
return ggml_cont(ctx, ggml_view_2d(ctx, experts, n_embd, n_tokens, experts->nb[2], 0));
@@ -9667,14 +9823,26 @@ struct llm_build_context {
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_rot);
struct ggml_tensor * cur;
struct ggml_tensor * inpL;
ggml_tensor * cur;
ggml_tensor * inpL;
ggml_tensor * inp_attn_scale = nullptr;
inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
// inp_pos - contains the positions
struct ggml_tensor * inp_pos = build_inp_pos();
if (model.arch == LLM_ARCH_LLAMA4) {
//inp_attn_scale = build_inp_attn_scale();
inp_attn_scale = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 1, 1, n_tokens*1); // n_pos_per_token(), which is 4 for LLM_ARCH_QWEN2VL, 1 otherwise
ggml_set_input(inp_attn_scale);
auto scale_data = (float *)inp_attn_scale->data;
auto pos = (const int32_t *)inp_pos->data;
for (int i = 0; i < n_tokens; ++i) {
scale_data[i] = std::log(std::floor((pos[i] + 1.0f) / hparams.n_attn_temp_floor_scale) + 1.0f) * hparams.f_attn_temp_scale + 1.0f;
}
}
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
@@ -9683,6 +9851,8 @@ struct llm_build_context {
for (int il = 0; il < n_layer; ++il) {
struct ggml_tensor * inpSA = inpL;
bool use_rope = model.arch == LLM_ARCH_LLAMA4 ? (il + 1) % hparams.n_no_rope_layer_step != 0 : true;
// norm
cur = llm_build_norm(ctx0, inpL, hparams,
model.layers[il].attn_norm, NULL,
@@ -9720,19 +9890,33 @@ struct llm_build_context {
cb(Vcur, "Vcur", il);
}
Qcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il);
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
Kcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, rope_factors,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
if (use_rope) {
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, rope_factors,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, rope_factors,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
} else if (inp_attn_scale) {
Qcur = ggml_mul(ctx0, Qcur, inp_attn_scale);
}
cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);
cb(Vcur, "Vcur", il);
if (model.arch == LLM_ARCH_LLAMA4 && use_rope && hparams.use_kq_norm) {
// Llama4TextL2Norm
Qcur = ggml_rms_norm(ctx0, Qcur, 1e-6);
Kcur = ggml_rms_norm(ctx0, Kcur, 1e-6);
cb(Qcur, "Qcur_normed", il);
cb(Kcur, "Kcur_normed", il);
}
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
model.layers[il].wo, model.layers[il].bo,
@@ -9758,6 +9942,7 @@ struct llm_build_context {
// feed-forward network
if (model.layers[il].ffn_gate_inp == nullptr) {
// non-MoE
cur = llm_build_norm(ctx0, ffn_inp, hparams,
model.layers[il].ffn_norm, NULL,
LLM_NORM_RMS, cb, il);
@@ -9770,6 +9955,37 @@ struct llm_build_context {
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
cb(cur, "ffn_out", il);
} else if (model.arch == LLM_ARCH_LLAMA4) {
// llama4 MoE
ggml_tensor * ffn_inp_normed = llm_build_norm(ctx0, ffn_inp, hparams,
model.layers[il].ffn_norm, NULL,
LLM_NORM_RMS, cb, il);
cb(cur, "ffn_norm", il);
ggml_tensor * moe_out = llm_build_moe_ffn(ctx0, lctx, ffn_inp_normed,
model.layers[il].ffn_gate_inp,
model.layers[il].ffn_up_exps,
model.layers[il].ffn_gate_exps,
model.layers[il].ffn_down_exps,
nullptr,
n_expert, n_expert_used,
LLM_FFN_SILU, false,
false, 0.0,
LLM_EXPERT_GATING_FUNC_SIGMOID,
cb, il);
// Shared experts
ggml_tensor * shexp_out = llm_build_ffn(ctx0, lctx, ffn_inp_normed,
model.layers[il].ffn_up_shexp, NULL, NULL,
model.layers[il].ffn_gate_shexp, NULL, NULL,
model.layers[il].ffn_down_shexp, NULL, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
cb(shexp_out, "ffn_moe_shexp", il);
cur = ggml_add(ctx0, moe_out, shexp_out);
cb(cur, "ffn_moe_out_merged", il);
} else {
// MoE branch
cur = llm_build_norm(ctx0, ffn_inp, hparams,
@@ -9782,11 +9998,11 @@ struct llm_build_context {
model.layers[il].ffn_up_exps,
model.layers[il].ffn_gate_exps,
model.layers[il].ffn_down_exps,
nullptr,
nullptr,
n_expert, n_expert_used,
LLM_FFN_SILU, true,
false, 0.0,
LLM_EXPERT_GATING_FUNC_SOFTMAX,
LLM_EXPERT_GATING_FUNC_SOFTMAX,
cb, il);
cb(cur, "ffn_moe_out", il);
}
@@ -15162,6 +15378,7 @@ static struct ggml_cgraph * llama_build_graph(
switch (model.arch) {
case LLM_ARCH_LLAMA:
case LLM_ARCH_LLAMA4:
case LLM_ARCH_GRANITE:
case LLM_ARCH_GRANITE_MOE:
{
@@ -15504,8 +15721,15 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
// may need to cut off old tokens for sliding window
if (data_swa) {
if (pos - lctx.kv_self.cells[i].pos >= (int32_t)hparams.n_swa) {
f = -INFINITY;
if (hparams.n_attn_chunk) {
llama_pos pos_chunk_start = (pos / hparams.n_attn_chunk) * hparams.n_attn_chunk;
if (lctx.kv_self.cells[i].pos < pos_chunk_start || pos < pos_chunk_start) {
f = -INFINITY;
}
} else {
if (pos - kv_self.cells[i].pos >= (int32_t)hparams.n_swa) {
f = -INFINITY;
}
}
data_swa[h*(n_kv*n_tokens) + j*n_kv + i] = f;
}
@@ -18949,6 +19173,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
// use what we call a normal RoPE, operating on pairs of consecutive head values
case LLM_ARCH_LLAMA:
case LLM_ARCH_LLAMA4:
case LLM_ARCH_BAICHUAN:
case LLM_ARCH_STARCODER:
case LLM_ARCH_PLAMO:
@@ -20774,6 +20999,8 @@ static llm_chat_template llama_chat_detect_template(const std::string & tmpl) {
return LLM_CHAT_TEMPLATE_GIGACHAT;
} else if (tmpl_contains("<|role_start|>")) {
return LLM_CHAT_TEMPLATE_MEGREZ;
} else if (tmpl_contains("<|header_start|>") && tmpl_contains("<|header_end|>")) {
return LLM_CHAT_TEMPLATE_LLAMA4;
}
return LLM_CHAT_TEMPLATE_UNKNOWN;
}
@@ -21155,6 +21382,15 @@ static int32_t llama_chat_apply_template_internal(
if (add_ass) {
ss << "<|role_start|>assistant<|role_end|>";
}
} else if (tmpl == LLM_CHAT_TEMPLATE_LLAMA4) {
// Llama 4
for (auto message : chat) {
std::string role(message->role);
ss << "<|header_start|>" << role << "<|header_end|>\n\n" << trim(message->content) << "<|eot|>";
}
if (add_ass) {
ss << "<|header_start|>assistant<|header_end|>\n\n";
}
} else {
// template not supported
return -1;