diff --git a/common/common.cpp b/common/common.cpp index 6bf6e4f9..f975aee3 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -817,6 +817,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.mla_attn = true; return true; } + if (arg == "-fmoe" || arg == "--fused-moe") { + params.fused_moe_up_gate = true; + return true; + } if (arg == "-co" || arg == "--color") { params.use_color = true; return true; @@ -1466,6 +1470,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "*", " --chunks N", "max number of chunks to process (default: %d, -1 = all)", params.n_chunks }); options.push_back({ "*", "-fa, --flash-attn", "enable Flash Attention (default: %s)", params.flash_attn ? "enabled" : "disabled" }); options.push_back({ "*", "-mla, --mla-use", "enable MLA (default: %s)", params.mla_attn ? "enabled" : "disabled" }); + options.push_back({ "*", "-fmoe, --fused-moe", "enable fused MoE (default: %s)", params.fused_moe_up_gate ? "enabled" : "disabled" }); options.push_back({ "*", "-p, --prompt PROMPT", "prompt to start generation with\n" "in conversation mode, this will be used as system prompt\n" "(default: '%s')", params.prompt.c_str() }); @@ -2303,6 +2308,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param cparams.offload_kqv = !params.no_kv_offload; cparams.flash_attn = params.flash_attn; cparams.mla_attn = params.mla_attn; + cparams.fused_moe_up_gate = params.fused_moe_up_gate; cparams.type_k = kv_cache_type_from_str(params.cache_type_k); cparams.type_v = kv_cache_type_from_str(params.cache_type_v); @@ -3301,6 +3307,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l fprintf(stream, "cont_batching: %s # default: false\n", params.cont_batching ? "true" : "false"); fprintf(stream, "flash_attn: %s # default: false\n", params.flash_attn ? "true" : "false"); fprintf(stream, "mla_attn: %s # default: false\n", params.mla_attn ? "true" : "false"); + fprintf(stream, "fused_moe: %s # default: false\n", params.fused_moe_up_gate ? "true" : "false"); fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp); const std::vector tensor_split_vector(params.tensor_split, params.tensor_split + llama_max_devices()); diff --git a/common/common.h b/common/common.h index b5b67986..f86a58cb 100644 --- a/common/common.h +++ b/common/common.h @@ -175,6 +175,7 @@ struct gpt_params { bool cont_batching = true; // insert new sequences for decoding on-the-fly bool flash_attn = false; // flash attention bool mla_attn = false; // MLA + bool fused_moe_up_gate = false; // fused up*unary(gate) op for MoE models bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix bool ignore_eos = false; // ignore generated EOS tokens diff --git a/include/llama.h b/include/llama.h index b5ad65e7..23e32642 100644 --- a/include/llama.h +++ b/include/llama.h @@ -377,6 +377,7 @@ extern "C" { bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU bool flash_attn; // whether to use flash attention [EXPERIMENTAL] bool mla_attn; // whether to use MLA attention [EXPERIMENTAL] + bool fused_moe_up_gate; // whether to use fused MoE up/down op [EXPERIMENTAL] // Abort callback // if it returns true, execution of llama_decode() will be aborted diff --git a/src/llama.cpp b/src/llama.cpp index 424506c0..eed7aa61 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -2516,6 +2516,7 @@ struct llama_cparams { bool offload_kqv; bool flash_attn; bool mla_attn; + bool fused_moe_up_gate; enum llama_pooling_type pooling_type; @@ -8629,32 +8630,19 @@ llm_expert_gating_func_type gating_op, cur = ggml_reshape_3d(ctx, cur, n_embd, 1, n_tokens); - ggml_tensor * par = ggml_moe_up_gate(ctx, up_exps, gate_exps, cur, selected_experts, type_op == LLM_FFN_SILU ? GGML_UNARY_OP_SILU : GGML_UNARY_OP_GELU); + ggml_tensor * par; + if (lctx.cparams.fused_moe_up_gate) { + par = ggml_moe_up_gate(ctx, up_exps, gate_exps, cur, selected_experts, type_op == LLM_FFN_SILU ? GGML_UNARY_OP_SILU : GGML_UNARY_OP_GELU); + } else { + ggml_tensor * up = llm_build_lora_mm_id(lctx, ctx, up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] + cb(up, "ffn_moe_up", il); - //ggml_tensor * up = llm_build_lora_mm_id(lctx, ctx, up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] - //cb(up, "ffn_moe_up", il); + ggml_tensor * gate = llm_build_lora_mm_id(lctx, ctx, gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] + cb(gate, "ffn_moe_gate", il); - //ggml_tensor * gate = llm_build_lora_mm_id(lctx, ctx, gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] - //cb(gate, "ffn_moe_gate", il); - - //// This is equivalent to the commented out code below - //ggml_tensor * par = ggml_fused_mul_unary(ctx, gate, up, type_op == LLM_FFN_SILU ? GGML_UNARY_OP_SILU : GGML_UNARY_OP_GELU); - - ////switch (type_op) { - //// case LLM_FFN_SILU: - //// { - //// gate = ggml_silu(ctx, gate); - //// cb(gate, "ffn_moe_silu", il); - //// } break; - //// case LLM_FFN_GELU: - //// { - //// gate = ggml_gelu(ctx, gate); - //// cb(gate, "ffn_moe_gelu", il); - //// } break; - //// default: - //// GGML_ABORT("fatal error"); - ////} - ////ggml_tensor * par = ggml_mul(ctx, up, gate); // [n_ff, n_expert_used, n_tokens] + // This is equivalent to the commented out code below + par = ggml_fused_mul_unary(ctx, gate, up, type_op == LLM_FFN_SILU ? GGML_UNARY_OP_SILU : GGML_UNARY_OP_GELU); + } cb(par, "ffn_moe_gate_par", il); ggml_tensor * experts = llm_build_lora_mm_id(lctx, ctx, down_exps, par, selected_experts); // [n_embd, n_expert_used, n_tokens] @@ -8910,6 +8898,7 @@ struct llm_build_context { const bool flash_attn; const bool mla_attn; + const bool fused_moe_up_gate; const enum llama_pooling_type pooling_type; const enum llama_rope_type rope_type; @@ -8961,6 +8950,7 @@ struct llm_build_context { n_ctx_orig (cparams.n_ctx_orig_yarn), flash_attn (cparams.flash_attn), mla_attn (cparams.mla_attn), + fused_moe_up_gate(cparams.fused_moe_up_gate), pooling_type (cparams.pooling_type), rope_type (hparams.rope_type), cb (cb), @@ -17608,6 +17598,7 @@ struct llama_context_params llama_context_default_params() { /*.offload_kqv =*/ true, /*.flash_attn =*/ false, /*.mla_attn =*/ false, + /*.fused_moe_up_gate =*/ false, /*.abort_callback =*/ nullptr, /*.abort_callback_data =*/ nullptr, }; @@ -17807,6 +17798,7 @@ struct llama_context * llama_new_context_with_model( cparams.offload_kqv = params.offload_kqv; cparams.flash_attn = params.flash_attn; cparams.mla_attn = params.mla_attn; + cparams.fused_moe_up_gate= params.fused_moe_up_gate; cparams.pooling_type = params.pooling_type; cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx; @@ -17874,6 +17866,7 @@ struct llama_context * llama_new_context_with_model( LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch); LLAMA_LOG_INFO("%s: flash_attn = %d\n", __func__, cparams.flash_attn); LLAMA_LOG_INFO("%s: mla_attn = %d\n", __func__, cparams.mla_attn); + LLAMA_LOG_INFO("%s: fused_moe = %d\n", __func__, cparams.fused_moe_up_gate); LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base); LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);