ik_llama.cpp/common/speculative.h
Samuel Oliveira Alves 09a88c9ae5 Add MTP decoding support for GLM-4.x MoE (#1270)
* wip: port MTP architecture

Ports the Multi-Token Prediction (MTP) architecture to the older `llama.cpp` codebase used by `ikllama`.

Changes include:
- Updating `llama_batch` to support `mtp_params`.
- Modifying `llama_decode_internal` (and `encode`) to handle MTP operations (Warmup, Update, Draft); see the sketch after this list.
- Adding public APIs for MTP state management (`llama_set_draft_input_hidden_state`).
- Adapting the embedding extraction logic to skip MTP update passes.
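
A minimal sketch of how a caller might select one of these operations per decode call, assuming `mtp_params` carries an operation selector; the member and enum names below are hypothetical, since they are not shown here:

    llama_batch batch = /* tokens to decode ... */;
    batch.mtp_params.op = LLAMA_MTP_OP_WARMUP; // hypothetical member and enum value
    llama_decode(ctx, batch);                  // llama_decode_internal dispatches on mtp_params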

* Refactors `server_slot` to support generic speculative decoding (MTP or Draft Model).

* core: enable hybrid outputs (logits + embeddings) for MTP support

* fix(mtp): correct KV-cache slot finding for updates

* fix(mtp): persist hidden states to prevent context corruption during drafting

* refactor(mtp): clean up unused code

* fix(mtp): update server to new function names

* fix(mtp): fix graph and save hidden state

* mtp: refactor integration, context params and kv cache search

* mtp: fix hidden state extraction and speculative acceptance flow

* server: fix MTP warmup for long prompts and reset token buffer

* llama: refactor MTP operation state to context parameters

* server: fix n_past calculation in MTP acceptance

* llama: fix mtp enable flags

* speculative: refactor MTP to use common_speculative interface

* context: remove unused signatures

* clip: fix deprecated enum-enum conversion warning

* common: fix format string crash in help message

* context: fix mtp activation logic
2026-02-22 18:14:39 +01:00

61 lines
2.0 KiB
C++

#pragma once

#include "llama.h"
#include "common.h"

struct common_speculative;

// comma-separated list of all types
std::string common_speculative_type_name_str();

// convert string to type
enum common_speculative_type common_speculative_type_from_name(const std::string & name);

// convert type to string
std::string common_speculative_type_to_str(enum common_speculative_type type);
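
// Illustrative round trip (a sketch; "mtp" as a registered type name is an
// assumption, since the enum's values are defined elsewhere):
//
//   enum common_speculative_type t = common_speculative_type_from_name("mtp");
//   std::string s = common_speculative_type_to_str(t); // -> "mtp"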
// check whether the llama_context is compatible with speculative decoding
// note: clears the memory of the context
bool common_speculative_is_compat(llama_context * ctx_tgt);

common_speculative * common_speculative_init(
        common_params_speculative & params,
        llama_context * ctx_tgt);

void common_speculative_free(common_speculative * spec);

// optionally call once at the beginning of a new generation
void common_speculative_begin(common_speculative * spec, const llama_tokens & prompt);

// sample up to n_draft tokens and add them to the batch using the draft model
llama_tokens common_speculative_draft(
        common_speculative * spec,
        const common_params_speculative & params,
        const llama_tokens & prompt,
        llama_token id_last);

// inform the speculative decoder that n_accepted tokens were accepted by the target model
void common_speculative_accept(common_speculative * spec, uint16_t n_accepted);

// print statistics about the speculative decoding
void common_speculative_print_stats(const common_speculative * spec);
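
// Illustrative end-to-end usage (an outline under assumptions, not code from this
// repository; `verify_with_target` is a hypothetical stand-in for the caller's
// target-model verification step, and prompt/id_last bookkeeping is elided):
//
//   common_params_speculative params = /* ... */;
//   if (common_speculative_is_compat(ctx_tgt)) {              // note: clears the context's memory
//       common_speculative * spec = common_speculative_init(params, ctx_tgt);
//       common_speculative_begin(spec, prompt);               // optional, once per generation
//       while (generating) {
//           llama_tokens draft = common_speculative_draft(spec, params, prompt, id_last);
//           uint16_t n_accepted = verify_with_target(ctx_tgt, draft); // hypothetical helper
//           common_speculative_accept(spec, n_accepted);
//           // append the accepted tokens to `prompt` and refresh `id_last` ...
//       }
//       common_speculative_print_stats(spec);
//       common_speculative_free(spec);
//   }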
// generate speculative draft tokens using the Multi-Token Prediction (MTP) architecture
std::vector<llama_token> mtp_speculative_gen_draft(
        struct common_sampler * smpl,
        struct llama_context * ctx,
        int n_draft,
        float p_min,
        llama_token id_last,
        int32_t n_past,
        llama_seq_id seq_id);

// run an MTP update pass over the batch to keep the MTP layer's KV cache in sync
void mtp_update_kv_cache(struct llama_context * ctx, const llama_batch & batch, bool is_prompt_warmup);

// inform the MTP state which drafted tokens the target model accepted
void mtp_accept_tokens(
        struct llama_context * ctx,
        const std::vector<llama_token> & ids,
        int32_t n_past_base,
        llama_seq_id seq_id);
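
// Illustrative MTP flow (a sketch under assumptions, not the server's exact logic;
// `smpl` is a configured common_sampler and target-model verification is summarized
// in comments):
//
//   // during prompt processing, keep the MTP layer's KV cache warm
//   mtp_update_kv_cache(ctx, batch, /*is_prompt_warmup=*/true);
//
//   // draft up to n_draft tokens starting from the last sampled token
//   std::vector<llama_token> draft =
//           mtp_speculative_gen_draft(smpl, ctx, n_draft, p_min, id_last, n_past, seq_id);
//
//   // decode `draft` with the target model, keep the accepted prefix in `ids`,
//   // then let the MTP state catch up before the next draft
//   mtp_accept_tokens(ctx, ids, n_past, seq_id);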