mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-03-12 06:50:08 +00:00
Merge branch 'main' into s6/dots
This commit is contained in:
@@ -427,6 +427,7 @@ struct llm_tokenizer_bpe {
|
||||
break;
|
||||
case LLAMA_VOCAB_PRE_TYPE_STABLELM2:
|
||||
case LLAMA_VOCAB_PRE_TYPE_QWEN2:
|
||||
case LLAMA_VOCAB_PRE_TYPE_HUNYUAN:
|
||||
regex_exprs = {
|
||||
// original regex from tokenizer.json
|
||||
// "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
|
||||
@@ -477,6 +478,13 @@ struct llm_tokenizer_bpe {
|
||||
"'(?:[sSdDmMtT]|[lL][lL]|[vV][eE]|[rR][eE])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]|\\s+(?!\\S)|\\s+",
|
||||
};
|
||||
break;
|
||||
case LLAMA_VOCAB_PRE_TYPE_SEED_CODER:
|
||||
regex_exprs = {
|
||||
// original regex from tokenizer.json
|
||||
// "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1}| ?[^\\s\\p{L}\\p{N}\r\n]+|\\s*[\r\n]+|\\s+(?!\\S)|\\s+"
|
||||
"(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1}| ?[^\\s\\p{L}\\p{N}\\r\\n]+|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
||||
};
|
||||
break;
|
||||
default:
|
||||
// default regex for BPE tokenization pre-processing
|
||||
regex_exprs = {
|
||||
|
||||
339
src/llama.cpp
339
src/llama.cpp
@@ -236,6 +236,7 @@ enum llm_arch {
|
||||
LLM_ARCH_GRANITE_MOE,
|
||||
LLM_ARCH_COHERE2,
|
||||
LLM_ARCH_DOTS1,
|
||||
LLM_ARCH_HUNYUAN_MOE,
|
||||
LLM_ARCH_UNKNOWN,
|
||||
};
|
||||
|
||||
@@ -293,6 +294,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||
{ LLM_ARCH_GRANITE_MOE, "granitemoe" },
|
||||
{ LLM_ARCH_COHERE2, "cohere2" },
|
||||
{ LLM_ARCH_DOTS1, "dots1" },
|
||||
{ LLM_ARCH_HUNYUAN_MOE, "hunyuan-moe" },
|
||||
{ LLM_ARCH_UNKNOWN, "(unknown)" },
|
||||
};
|
||||
|
||||
@@ -1624,6 +1626,28 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
|
||||
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
|
||||
{ LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
|
||||
}
|
||||
},
|
||||
LLM_ARCH_HUNYUAN_MOE,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
|
||||
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
|
||||
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
|
||||
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
|
||||
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
|
||||
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_UNKNOWN,
|
||||
@@ -1669,6 +1693,7 @@ enum llm_chat_template {
|
||||
LLM_CHAT_TEMPLATE_LLAMA4,
|
||||
LLM_CHAT_TEMPLATE_BITNET,
|
||||
LLM_CHAT_TEMPLATE_DOTS1,
|
||||
LLM_CHAT_TEMPLATE_HUNYUAN_MOE,
|
||||
LLM_CHAT_TEMPLATE_UNKNOWN,
|
||||
};
|
||||
|
||||
@@ -1706,6 +1731,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
|
||||
{ "gigachat", LLM_CHAT_TEMPLATE_GIGACHAT },
|
||||
{ "megrez", LLM_CHAT_TEMPLATE_MEGREZ },
|
||||
{ "llama4", LLM_CHAT_TEMPLATE_LLAMA4 },
|
||||
{ "hunyuan-moe", LLM_CHAT_TEMPLATE_HUNYUAN_MOE },
|
||||
{ "bitnet", LLM_CHAT_TEMPLATE_BITNET },
|
||||
};
|
||||
|
||||
@@ -2602,6 +2628,7 @@ enum e_model {
|
||||
MODEL_27B,
|
||||
MODEL_17B_16E,
|
||||
MODEL_17B_128E,
|
||||
MODEL_80B_A13B,
|
||||
};
|
||||
|
||||
static const size_t kiB = 1024;
|
||||
@@ -5236,6 +5263,7 @@ static const char * llama_model_type_name(e_model type) {
|
||||
case MODEL_27B: return "27B";
|
||||
case MODEL_17B_16E: return "17Bx16E (Scout)";
|
||||
case MODEL_17B_128E: return "17Bx128E (Maverick)";
|
||||
case MODEL_80B_A13B: return "80B.A13B";
|
||||
default: return "?B";
|
||||
}
|
||||
}
|
||||
@@ -6084,6 +6112,17 @@ static void llm_load_hparams(
|
||||
default: model.type = e_model::MODEL_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_HUNYUAN_MOE:
|
||||
{
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
|
||||
ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp);
|
||||
|
||||
switch (hparams.n_layer) {
|
||||
case 32: model.type = e_model::MODEL_80B_A13B; break;
|
||||
default: model.type = e_model::MODEL_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
default: (void)0;
|
||||
}
|
||||
|
||||
@@ -6354,6 +6393,14 @@ static void llm_load_vocab(
|
||||
tokenizer_pre == "bailingmoe") {
|
||||
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_BAILINGMOE;
|
||||
vocab.tokenizer_clean_spaces = false;
|
||||
} else if (
|
||||
tokenizer_pre == "seed-coder") {
|
||||
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_SEED_CODER;
|
||||
vocab.tokenizer_clean_spaces = false;
|
||||
} else if (
|
||||
tokenizer_pre == "hunyuan") {
|
||||
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_HUNYUAN;
|
||||
vocab.tokenizer_clean_spaces = false;
|
||||
} else {
|
||||
throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
|
||||
}
|
||||
@@ -9260,6 +9307,47 @@ static bool llm_load_tensors(
|
||||
}
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_HUNYUAN_MOE:
|
||||
{
|
||||
model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||
|
||||
// output
|
||||
model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
||||
model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||
|
||||
// if output is NULL, init from the input tok embed
|
||||
if (model.output == NULL) {
|
||||
model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
|
||||
}
|
||||
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
ggml_context * ctx_layer = ctx_for_layer(i);
|
||||
ggml_context * ctx_split = ctx_for_layer_split(i);
|
||||
|
||||
auto & layer = model.layers[i];
|
||||
|
||||
layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
|
||||
|
||||
layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
|
||||
layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0);
|
||||
layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0);
|
||||
layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
|
||||
|
||||
layer.attn_k_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0);
|
||||
layer.attn_q_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0);
|
||||
|
||||
layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
|
||||
|
||||
layer.ffn_gate_inp = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
|
||||
layer.ffn_gate_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0);
|
||||
layer.ffn_down_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0);
|
||||
layer.ffn_up_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0);
|
||||
|
||||
layer.ffn_gate_shexp = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0);
|
||||
layer.ffn_up_shexp = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0);
|
||||
layer.ffn_down_shexp = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, 0);
|
||||
}
|
||||
} break;
|
||||
default:
|
||||
throw std::runtime_error("unknown architecture");
|
||||
}
|
||||
@@ -9697,12 +9785,7 @@ static struct ggml_tensor * llm_build_norm(
|
||||
const llm_build_cb & cb,
|
||||
int il, float scale_eps = 1) {
|
||||
|
||||
#ifdef GGML_USE_VULKAN
|
||||
constexpr bool use_fused_rms_norm = false;
|
||||
#else
|
||||
constexpr bool use_fused_rms_norm = true;
|
||||
#endif
|
||||
if (use_fused_rms_norm && type == LLM_NORM_RMS && mw) {
|
||||
if (type == LLM_NORM_RMS && mw) {
|
||||
cur = ggml_fused_rms_norm(ctx, cur, mw, scale_eps * hparams.f_norm_rms_eps);
|
||||
if (mb) {
|
||||
cb(cur, "fused_norm", il);
|
||||
@@ -9793,13 +9876,7 @@ static struct ggml_tensor * llm_build_ffn(
|
||||
cur = tmp;
|
||||
}
|
||||
|
||||
#ifdef GGML_USE_VULKAN
|
||||
constexpr bool use_fused_mul_unary = false;
|
||||
#else
|
||||
constexpr bool use_fused_mul_unary = true;
|
||||
#endif
|
||||
|
||||
if (use_fused_mul_unary && type_gate == LLM_FFN_PAR &&
|
||||
if (type_gate == LLM_FFN_PAR &&
|
||||
(type_op == LLM_FFN_SILU || type_op == LLM_FFN_RELU || (type_op == LLM_FFN_GELU && !act_scales))) {
|
||||
cur = ggml_fused_mul_unary(ctx, cur, tmp, type_op == LLM_FFN_SILU ? GGML_UNARY_OP_SILU :
|
||||
type_op == LLM_FFN_RELU ? GGML_UNARY_OP_RELU : GGML_UNARY_OP_GELU);
|
||||
@@ -9981,6 +10058,28 @@ llm_expert_gating_func_type gating_op,
|
||||
cb(cur, "ffn_moe_weighted", il);
|
||||
}
|
||||
|
||||
//#ifdef GGML_USE_VULKAN
|
||||
// // aggregate experts
|
||||
// ggml_tensor * moe_out = nullptr;
|
||||
// //ggml_tensor * first_expert = nullptr;
|
||||
// for (int i = 0; i < n_expert_used; ++i) {
|
||||
// ggml_tensor * cur_expert = ggml_view_2d(ctx, experts, n_embd, n_tokens,
|
||||
// experts->nb[2], i*experts->nb[1]);
|
||||
//
|
||||
// if (i == 0) {
|
||||
// moe_out = cur_expert;
|
||||
// } else {
|
||||
// moe_out = ggml_add(ctx, moe_out, cur_expert);
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// if (n_expert_used == 1) {
|
||||
// // avoid returning a non-contiguous tensor
|
||||
// moe_out = ggml_cont(ctx, moe_out);
|
||||
// }
|
||||
//
|
||||
// return moe_out;
|
||||
//#else
|
||||
if (n_expert_used == 1) {
|
||||
return ggml_cont(ctx, ggml_view_2d(ctx, experts, n_embd, n_tokens, experts->nb[2], 0));
|
||||
}
|
||||
@@ -9989,32 +10088,8 @@ llm_expert_gating_func_type gating_op,
|
||||
ggml_view_2d(ctx, experts, n_embd, n_tokens, experts->nb[2], experts->nb[1]));
|
||||
}
|
||||
return ggml_multi_add(ctx, ggml_view_2d(ctx, experts, n_embd, n_tokens, experts->nb[2], 0), n_expert_used);
|
||||
//#endif
|
||||
|
||||
//// aggregate experts
|
||||
//ggml_tensor * moe_out = nullptr;
|
||||
////ggml_tensor * first_expert = nullptr;
|
||||
//for (int i = 0; i < n_expert_used; ++i) {
|
||||
// ggml_tensor * cur_expert = ggml_view_2d(ctx, experts, n_embd, n_tokens,
|
||||
// experts->nb[2], i*experts->nb[1]);
|
||||
|
||||
// if (i == 0) {
|
||||
// moe_out = cur_expert;
|
||||
// //first_expert = cur_expert;
|
||||
// //printf("%s: %d: %d x %d x %d x %d | %d x %d x %d x %d\n", __func__, ggml_is_contiguous(first_expert),
|
||||
// // (int)cur_expert->ne[0], (int)cur_expert->ne[1], (int)cur_expert->ne[2], (int)cur_expert->ne[3],
|
||||
// // (int)cur_expert->nb[0], (int)cur_expert->nb[1], (int)cur_expert->nb[2], (int)cur_expert->nb[3]);
|
||||
// } else {
|
||||
// moe_out = ggml_add(ctx, moe_out, cur_expert);
|
||||
// //printf("%s: %d %d\n", __func__, ggml_is_contiguous(cur_expert), ggml_are_same_shape(cur_expert, first_expert));
|
||||
// }
|
||||
//}
|
||||
|
||||
//if (n_expert_used == 1) {
|
||||
// // avoid returning a non-contiguous tensor
|
||||
// moe_out = ggml_cont(ctx, moe_out);
|
||||
//}
|
||||
|
||||
//return moe_out;
|
||||
}
|
||||
|
||||
static struct ggml_tensor * llm_build_kqv(
|
||||
@@ -16988,8 +17063,7 @@ struct llm_build_context {
|
||||
// inp_pos - contains the positions
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
//auto * inp_attn = build_attn_inp_kv_unified();
|
||||
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
||||
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
ggml_tensor * inpSA = inpL;
|
||||
@@ -17038,7 +17112,7 @@ struct llm_build_context {
|
||||
cb(Kcur, "Kcur", il);
|
||||
//cb(Vcur, "Vcur", il);
|
||||
|
||||
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
|
||||
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
|
||||
model.layers[il].wo, model.layers[il].bo,
|
||||
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||
|
||||
@@ -17098,7 +17172,7 @@ struct llm_build_context {
|
||||
}
|
||||
|
||||
cur = ggml_add(ctx0, cur, ffn_inp);
|
||||
cur = lctx.cvec.apply_to(ctx0, cur, il);
|
||||
cur = lctx.cvec.apply_to(ctx0, cur, il);
|
||||
cb(cur, "l_out", il);
|
||||
|
||||
// input for next layer
|
||||
@@ -17119,12 +17193,162 @@ struct llm_build_context {
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
|
||||
return gf;
|
||||
return gf;
|
||||
}
|
||||
|
||||
struct ggml_cgraph * build_hunyuan_moe() {
|
||||
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
|
||||
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
|
||||
|
||||
// inp_pos - contains the positions
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
||||
|
||||
const float kq_scale = 1.0f / sqrtf(float(n_embd_head));
|
||||
|
||||
ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
ggml_tensor * inpSA = inpL;
|
||||
|
||||
// norm
|
||||
cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, cb, il);
|
||||
cb(cur, "attn_norm", il);
|
||||
|
||||
// self-attention
|
||||
{
|
||||
// rope freq factors for llama3; may return nullptr for llama2 and other models
|
||||
struct ggml_tensor * rope_factors = build_rope_factors(il);
|
||||
|
||||
// compute Q and K and RoPE them
|
||||
ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
|
||||
cb(Qcur, "Qcur", il);
|
||||
if (model.layers[il].bq) {
|
||||
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
|
||||
cb(Qcur, "Qcur", il);
|
||||
}
|
||||
|
||||
ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
|
||||
cb(Kcur, "Kcur", il);
|
||||
if (model.layers[il].bk) {
|
||||
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
|
||||
cb(Kcur, "Kcur", il);
|
||||
}
|
||||
|
||||
ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
|
||||
cb(Vcur, "Vcur", il);
|
||||
if (model.layers[il].bv) {
|
||||
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
|
||||
cb(Vcur, "Vcur", il);
|
||||
}
|
||||
|
||||
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
||||
|
||||
Qcur = ggml_rope_ext(
|
||||
ctx0, Qcur, inp_pos, rope_factors,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
|
||||
cb(Qcur, "Qcur", il);
|
||||
cb(Kcur, "Kcur", il);
|
||||
cb(Vcur, "Vcur", il);
|
||||
|
||||
Kcur = ggml_rope_ext(
|
||||
ctx0, Kcur, inp_pos, rope_factors,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
|
||||
Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, cb, il);
|
||||
cb(Kcur, "Kcur_norm", il);
|
||||
|
||||
Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, cb, il);
|
||||
cb(Qcur, "Qcur_norm", il);
|
||||
|
||||
cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il);
|
||||
cb(cur, "attn_out", il);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
||||
}
|
||||
|
||||
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
||||
cb(ffn_inp, "ffn_inp", il);
|
||||
|
||||
cur = llm_build_norm(ctx0,ffn_inp, hparams, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, cb, il);
|
||||
cb(cur, "ffn_norm", il);
|
||||
|
||||
// feed-forward network (non-MoE)
|
||||
ggml_tensor * cur_mlp = llm_build_ffn(ctx0, lctx, cur,
|
||||
model.layers[il].ffn_up_shexp, NULL, NULL,
|
||||
model.layers[il].ffn_gate_shexp, NULL, NULL,
|
||||
model.layers[il].ffn_down_shexp, NULL, NULL,
|
||||
NULL,
|
||||
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
|
||||
cb(cur_mlp, "ffn_mlp", il);
|
||||
|
||||
// MoE branch
|
||||
ggml_tensor * cur_moe = llm_build_moe_ffn(ctx0, lctx, cur,
|
||||
model.layers[il].ffn_gate_inp,
|
||||
model.layers[il].ffn_up_exps,
|
||||
model.layers[il].ffn_gate_exps,
|
||||
model.layers[il].ffn_down_exps,
|
||||
nullptr,
|
||||
n_expert, n_expert_used,
|
||||
LLM_FFN_SILU,
|
||||
true, // norm_topk_prob
|
||||
false,
|
||||
0.0,
|
||||
LLM_EXPERT_GATING_FUNC_SOFTMAX,
|
||||
cb,
|
||||
il);
|
||||
cb(cur_moe, "ffn_moe_out", il);
|
||||
|
||||
ggml_tensor * ffn_out = ggml_add(ctx0, cur_moe, cur_mlp);
|
||||
cb(ffn_out, "ffn_out", il);
|
||||
|
||||
cur = ggml_add(ctx0, ffn_out, ffn_inp);
|
||||
|
||||
cur = lctx.cvec.apply_to(ctx0, cur, il);
|
||||
cb(cur, "l_out", il);
|
||||
|
||||
// input for next layer
|
||||
inpL = cur;
|
||||
}
|
||||
|
||||
cur = inpL;
|
||||
|
||||
cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, NULL, LLM_NORM_RMS, cb, -1);
|
||||
|
||||
cb(cur, "result_norm", -1);
|
||||
//res->t_embd = cur;
|
||||
|
||||
// lm_head
|
||||
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
|
||||
|
||||
cb(cur, "result_output", -1);
|
||||
//res->t_logits = cur;
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
|
||||
return gf;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
|
||||
static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) {
|
||||
llama_batch dummy;
|
||||
dummy.n_tokens = 0;
|
||||
@@ -17422,6 +17646,10 @@ static struct ggml_cgraph * llama_build_graph(
|
||||
{
|
||||
result = llm.build_dots1();
|
||||
} break;
|
||||
case LLM_ARCH_HUNYUAN_MOE:
|
||||
{
|
||||
result = llm.build_hunyuan_moe();
|
||||
} break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
@@ -21195,6 +21423,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
|
||||
case LLM_ARCH_GPTNEOX:
|
||||
case LLM_ARCH_CODESHELL:
|
||||
case LLM_ARCH_DOTS1:
|
||||
case LLM_ARCH_HUNYUAN_MOE:
|
||||
return LLAMA_ROPE_TYPE_NEOX;
|
||||
|
||||
// all model arches should be listed explicitly here
|
||||
@@ -23010,6 +23239,8 @@ static llm_chat_template llama_chat_detect_template(const std::string & tmpl) {
|
||||
return LLM_CHAT_TEMPLATE_LLAMA4;
|
||||
} else if (tmpl_contains("<|endofuserprompt|>")) {
|
||||
return LLM_CHAT_TEMPLATE_DOTS1;
|
||||
} else if (tmpl_contains("<|startoftext|>") && tmpl_contains("<|extra_4|>")) {
|
||||
return LLM_CHAT_TEMPLATE_HUNYUAN_MOE;
|
||||
}
|
||||
return LLM_CHAT_TEMPLATE_UNKNOWN;
|
||||
}
|
||||
@@ -23443,6 +23674,18 @@ static int32_t llama_chat_apply_template_internal(
|
||||
if (add_ass) {
|
||||
ss << "<|response|>";
|
||||
}
|
||||
} else if (tmpl == LLM_CHAT_TEMPLATE_HUNYUAN_MOE) {
|
||||
// tencent/Hunyuan-A13B-Instruct
|
||||
for (auto message : chat) {
|
||||
std::string role(message->role);
|
||||
if (role == "system") {
|
||||
ss << "<|startoftext|>" << message->content << "<|extra_4|>";
|
||||
} else if (role == "assistant") {
|
||||
ss << "<|startoftext|>" << message->content << "<|eos|>";
|
||||
} else {
|
||||
ss << "<|startoftext|>" << message->content << "<|extra_0|>";
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// template not supported
|
||||
return -1;
|
||||
@@ -23650,6 +23893,9 @@ struct llama_sampler_dry * llama_sampler_init_dry(const struct llama_vocab* voca
|
||||
}
|
||||
|
||||
void llama_sampler_dry_reset(struct llama_sampler_dry* smpl) {
|
||||
if (!smpl) {
|
||||
return;
|
||||
}
|
||||
smpl->last_tokens.clear();
|
||||
smpl->dry_repeat_count.clear();
|
||||
smpl->dry_max_token_repeat.clear();
|
||||
@@ -23675,6 +23921,9 @@ struct llama_sampler_dry* llama_sampler_dry_clone(struct llama_sampler_dry* smpl
|
||||
}
|
||||
|
||||
void llama_sampler_dry_accept(struct llama_sampler_dry* smpl, llama_token token) {
|
||||
if (!smpl) {
|
||||
return;
|
||||
}
|
||||
if (smpl->dry_multiplier == 0.0f || smpl->dry_base < 1.0f || smpl->dry_penalty_last_n == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user