mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-23 14:44:09 +00:00
Faster MoE inference (#112)
* multi_sdd: WIP * multi_sdd: CPU works * multi_add: CUDA * multi_add: simplify * multi_add: Metal * Metal: speed up mul_mat_id For the Granite-1B MoE model PP-512 goes from 156 t/s to 890 t/s, so nearly a 6X speedup! --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
@@ -8351,25 +8351,40 @@ static struct ggml_tensor * llm_build_moe_ffn(
|
||||
|
||||
experts = ggml_mul(ctx, experts, weights);
|
||||
|
||||
// aggregate experts
|
||||
ggml_tensor * moe_out = nullptr;
|
||||
for (int i = 0; i < n_expert_used; ++i) {
|
||||
ggml_tensor * cur_expert = ggml_view_2d(ctx, experts, n_embd, n_tokens,
|
||||
experts->nb[2], i*experts->nb[1]);
|
||||
|
||||
if (i == 0) {
|
||||
moe_out = cur_expert;
|
||||
} else {
|
||||
moe_out = ggml_add(ctx, moe_out, cur_expert);
|
||||
}
|
||||
}
|
||||
|
||||
if (n_expert_used == 1) {
|
||||
// avoid returning a non-contiguous tensor
|
||||
moe_out = ggml_cont(ctx, moe_out);
|
||||
return ggml_cont(ctx, ggml_view_2d(ctx, experts, n_embd, n_tokens, experts->nb[2], 0));
|
||||
}
|
||||
if (n_expert_used == 2) {
|
||||
return ggml_add(ctx, ggml_view_2d(ctx, experts, n_embd, n_tokens, experts->nb[2], 0),
|
||||
ggml_view_2d(ctx, experts, n_embd, n_tokens, experts->nb[2], experts->nb[1]));
|
||||
}
|
||||
return ggml_multi_add(ctx, ggml_view_2d(ctx, experts, n_embd, n_tokens, experts->nb[2], 0), n_expert_used);
|
||||
|
||||
return moe_out;
|
||||
//// aggregate experts
|
||||
//ggml_tensor * moe_out = nullptr;
|
||||
////ggml_tensor * first_expert = nullptr;
|
||||
//for (int i = 0; i < n_expert_used; ++i) {
|
||||
// ggml_tensor * cur_expert = ggml_view_2d(ctx, experts, n_embd, n_tokens,
|
||||
// experts->nb[2], i*experts->nb[1]);
|
||||
|
||||
// if (i == 0) {
|
||||
// moe_out = cur_expert;
|
||||
// //first_expert = cur_expert;
|
||||
// //printf("%s: %d: %d x %d x %d x %d | %d x %d x %d x %d\n", __func__, ggml_is_contiguous(first_expert),
|
||||
// // (int)cur_expert->ne[0], (int)cur_expert->ne[1], (int)cur_expert->ne[2], (int)cur_expert->ne[3],
|
||||
// // (int)cur_expert->nb[0], (int)cur_expert->nb[1], (int)cur_expert->nb[2], (int)cur_expert->nb[3]);
|
||||
// } else {
|
||||
// moe_out = ggml_add(ctx, moe_out, cur_expert);
|
||||
// //printf("%s: %d %d\n", __func__, ggml_is_contiguous(cur_expert), ggml_are_same_shape(cur_expert, first_expert));
|
||||
// }
|
||||
//}
|
||||
|
||||
//if (n_expert_used == 1) {
|
||||
// // avoid returning a non-contiguous tensor
|
||||
// moe_out = ggml_cont(ctx, moe_out);
|
||||
//}
|
||||
|
||||
//return moe_out;
|
||||
}
|
||||
|
||||
static struct ggml_tensor * llm_build_kqv(
|
||||
@@ -9011,6 +9026,7 @@ struct llm_build_context {
|
||||
// compute Q and K and RoPE them
|
||||
struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
|
||||
if (hparams.f_attention_scale != 0) {
|
||||
// Why is hparams.f_attention_scale not simply absorbed into model.layers[il].wq ?
|
||||
Qcur = ggml_scale(ctx0, Qcur, hparams.f_attention_scale);
|
||||
}
|
||||
cb(Qcur, "Qcur", il);
|
||||
@@ -9062,6 +9078,7 @@ struct llm_build_context {
|
||||
|
||||
// For Granite architecture
|
||||
if (hparams.f_residual_scale) {
|
||||
// Why is hparams.f_residual_scale not simply absorbed into model.layers[il].wv ?
|
||||
cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
|
||||
}
|
||||
|
||||
@@ -9103,6 +9120,7 @@ struct llm_build_context {
|
||||
|
||||
// For Granite architecture
|
||||
if (hparams.f_residual_scale) {
|
||||
// Why is hparams.f_residual_scale not simply absorbed into model.layers[il].ffn_down_exps ?
|
||||
cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
|
||||
}
|
||||
|
||||
@@ -9128,6 +9146,7 @@ struct llm_build_context {
|
||||
|
||||
// For Granite architecture
|
||||
if (hparams.f_logit_scale) {
|
||||
// Why is hparams.f_logit_scale not simply absorbed into model.output ?
|
||||
cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user