mirror of https://github.com/ikawrakow/ik_llama.cpp.git (synced 2026-02-27 08:34:09 +00:00)
Add MTP decoding support for GLM-4.x MoE (#1270)
* wip: port MTP architecture

  Ports the Multi-Token Prediction (MTP) architecture to the older `llama.cpp` codebase used by `ikllama`. Changes include:
  - Updating `llama_batch` to support `mtp_params`.
  - Modifying `llama_decode_internal` (and `encode`) to handle MTP operations (Warmup, Update, Draft).
  - Adding public APIs for MTP state management (`llama_set_draft_input_hidden_state`).
  - Adapting the embedding extraction logic to skip MTP update passes.

* Refactors `server_slot` to support generic speculative decoding (MTP or draft model).
* core: enable hybrid outputs (logits + embeddings) for MTP support
* fix(mtp): correct KV-cache slot finding for updates
* fix(mtp): persist hidden states to prevent context corruption during drafting
* refactor(mtp): clean up unused code
* fix(mtp): update server to new function names
* fix(mtp): fix graph and save hidden state
* mtp: refactor integration, context params, and KV-cache search
* mtp: fix hidden-state extraction and speculative acceptance flow
* server: fix MTP warmup for long prompts and reset token buffer
* llama: refactor MTP operation state into context parameters
* server: fix n_past calculation in MTP acceptance
* llama: fix MTP enable flags
* speculative: refactor MTP to use the common_speculative interface
* context: remove unused signatures
* clip: fix deprecated enum-enum conversion warning
* common: fix format-string crash in help message
* context: fix MTP activation logic
Commit 09a88c9ae5 (parent cbf7fc7e2f), committed by GitHub.
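The operations named in the commit message (warmup, update, draft) map onto the new `llama_mtp_op_type` values and public API declared in the diff below. The following is a minimal sketch of how a caller might sequence one speculative step through them; only `llama_model_n_nextn_layer`, `llama_set_mtp_op_type`, and `llama_set_draft_input_hidden_state` come from this commit, while the batch handling, verification, and hidden-state source are assumed glue code (and `MTP_OP_WARMUP`, not shown, would prime the MTP KV cache after prompt processing).

    // Illustrative sketch of one speculative step driven through the new MTP API.
    // Only the three MTP entry points are declared in this diff; everything
    // else here is assumed caller-side glue.
    #include "llama.h"

    void mtp_speculative_step(llama_context * ctx, const llama_model * model,
                              llama_batch batch, const float * last_hidden_state) {
        if (llama_model_n_nextn_layer(model) <= 0) {
            // Model has no MTP (NextN) layers: fall back to plain decoding.
            llama_set_mtp_op_type(ctx, MTP_OP_NONE);
            llama_decode(ctx, batch);
            return;
        }

        // Feed the target model's last hidden state to the MTP head, then run
        // a draft-generation pass over it.
        llama_set_draft_input_hidden_state(ctx, last_hidden_state);
        llama_set_mtp_op_type(ctx, MTP_OP_DRAFT_GEN);
        llama_decode(ctx, batch);

        // ... verify the drafted tokens with the main model, then refresh the
        // MTP layer's KV cache for whatever was accepted ...
        llama_set_mtp_op_type(ctx, MTP_OP_UPDATE_ACCEPTED);
        llama_decode(ctx, batch);

        // Return the context to normal decoding.
        llama_set_mtp_op_type(ctx, MTP_OP_NONE);
    }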
@@ -279,6 +279,12 @@ extern "C" {
         LLAMA_SPLIT_MODE_GRAPH = 3, // splits computations across GPUs
     };
 
+    enum llama_mtp_op_type {
+        MTP_OP_NONE = 0,
+        MTP_OP_WARMUP = 1,
+        MTP_OP_UPDATE_ACCEPTED = 2,
+        MTP_OP_DRAFT_GEN = 3,
+    };
+
     typedef struct llama_token_data {
         llama_token id; // token id
@@ -394,6 +400,7 @@
         bool validate_quants; // if true, check for NaNs while loading the model
         bool merge_qkv; // if true, merge separate Q, K, V tensors into a single, contiguous tensor
         bool merge_up_gate_exps; // if true, merge ffn_up_exps and ffn_gate_exps tensors into a single, contiguous tensor
+        bool mtp; // if true, load MTP layers if present
     };
 
     // NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations
@@ -449,6 +456,8 @@
         bool split_mode_graph_scheduling; // if true, force split mode graph scheduling
         //bool split_mode_f16; // if true, cast intermediate results to f16 before copying to other GPUs
         bool scheduler_async; // if true, with split mode "graph" graph evaluation will be done using multiple threads
+        bool mtp; // Activate MTP if supported
+        enum llama_mtp_op_type mtp_op_type;
 
         // Abort callback
         // if it returns true, execution of llama_decode() will be aborted
@@ -1463,6 +1472,17 @@ LLAMA_API struct llama_grammar* llama_sampler_init_grammar_lazy_patterns(
 
     LLAMA_API void llama_dump_timing_info_yaml(FILE * stream, const struct llama_context * ctx);
 
+    //
+    // MTP
+    //
+
+    LLAMA_API int32_t llama_model_n_nextn_layer(const struct llama_model * model);
+
+    // Set which, if any, MTP operation the context will use
+    LLAMA_API void llama_set_mtp_op_type(struct llama_context * ctx, enum llama_mtp_op_type mtp_op_type);
+
+    LLAMA_API void llama_set_draft_input_hidden_state(struct llama_context * ctx, const float * hidden_state);
+
 #ifdef __cplusplus
 }
 #endif
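Taken together, the diff wires MTP in at three levels: a model-param flag controlling whether the MTP (NextN) layers are loaded at all, a context-param flag activating them for a given context, and the per-pass `mtp_op_type` switched via `llama_set_mtp_op_type()`. A hedged setup sketch follows, assuming the loader entry points this codebase already exposes (`llama_model_default_params`, `llama_load_model_from_file`, `llama_context_default_params`, `llama_new_context_with_model`); the model path is a placeholder.

    // Hedged sketch: enable MTP at load time and at context creation.
    // The mtp / mtp_op_type fields come from this diff; the GGUF path
    // is purely illustrative.
    llama_model_params mparams = llama_model_default_params();
    mparams.mtp = true; // load MTP (NextN) layers if present

    llama_model * model = llama_load_model_from_file("glm-4-moe.gguf", mparams);

    llama_context_params cparams = llama_context_default_params();
    cparams.mtp = true;                // activate MTP for this context
    cparams.mtp_op_type = MTP_OP_NONE; // per-pass op type, switched later
                                       // with llama_set_mtp_op_type()

    llama_context * ctx = llama_new_context_with_model(model, cparams);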