Mirror of https://github.com/ikawrakow/ik_llama.cpp.git, synced 2026-02-28 17:14:17 +00:00
Add MTP decoding support for GLM-4.x MoE (#1270)
* wip: port MTP architecture

  Ports the Multi-Token Prediction (MTP) architecture to the older `llama.cpp` codebase used by `ikllama`. Changes include:
  - Updating `llama_batch` to support `mtp_params`.
  - Modifying `llama_decode_internal` (and `encode`) to handle MTP operations (Warmup, Update, Draft).
  - Adding public APIs for MTP state management (`llama_set_draft_input_hidden_state`).
  - Adapting the embedding extraction logic to skip MTP update passes.

* Refactors `server_slot` to support generic speculative decoding (MTP or draft model).
* core: enable hybrid outputs (logits + embeddings) for MTP support
* fix(mtp): correct KV-cache slot finding for updates
* fix(mtp): persist hidden states to prevent context corruption during drafting
* refactor(mtp): clean unused code
* fix(mtp): update server to new function names
* fix(mtp): fix graph handling and hidden-state saving
* mtp: refactor integration, context params and KV-cache search
* mtp: fix hidden-state extraction and speculative acceptance flow
* server: fix MTP warmup for long prompts and reset token buffer
* llama: refactor MTP operation state into context parameters
* server: fix n_past calculation in MTP acceptance
* llama: fix MTP enable flags
* speculative: refactor MTP to use the common_speculative interface
* context: remove unused signatures
* clip: fix deprecated enum-enum conversion warning
* common: fix format-string crash in help message
* context: fix MTP activation logic
Committed by GitHub
Parent: cbf7fc7e2f
Commit: 09a88c9ae5

Changed file: src/llama.cpp (271)
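The commit message above introduces a new public control surface (`llama_set_mtp_op_type`, `llama_set_draft_input_hidden_state`) but this page only shows the `src/llama.cpp` side. The sketch below is illustrative only: it reflects an assumed caller-side sequencing, not the PR's actual server code. `llama_get_embeddings` is the long-standing llama.cpp accessor and is assumed to return the hidden states kept by the hybrid-output path; the draft op-type constant is omitted because it does not appear in this file's diff.

```cpp
// Hypothetical driver sketch for the MTP ops added in this commit (not the PR's server flow).
#include "llama.h"

static void mtp_step(llama_context * ctx, const llama_batch & batch) {
    // Target pass: normal decode with MTP disabled; the hybrid-output change
    // keeps both logits and the final hidden states around.
    llama_set_mtp_op_type(ctx, MTP_OP_NONE);
    llama_decode(ctx, batch);

    // Update pass: re-run the same positions through the MTP layer so its KV
    // cache stays in sync. Internally this reads lctx.embd (the hidden states
    // captured above) and skips logit extraction.
    llama_set_mtp_op_type(ctx, MTP_OP_UPDATE_ACCEPTED);
    llama_decode(ctx, batch);

    // Draft pass: hand the accepted token's hidden state to the context, then
    // run the MTP head to propose the next token. The concrete draft op-type
    // value is not shown in this diff, so the call is left commented out.
    const float * hidden = llama_get_embeddings(ctx);   // assumed accessor for lctx.embd
    llama_set_draft_input_hidden_state(ctx, hidden);
    // llama_set_mtp_op_type(ctx, /* draft op */); llama_decode(ctx, batch);

    llama_set_mtp_op_type(ctx, MTP_OP_NONE);
}
```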
@@ -546,6 +546,7 @@ struct llama_context::Prev {
    int all_seq_id;
    int n_outputs;
    int n_kv;
    llama_mtp_op_type mtp_op_type;
    ggml_cgraph * graph;
};

@@ -563,11 +564,13 @@ bool llama_context::can_reuse_graph(const llama_batch & u_batch) {
           kv_self.head > 0 &&
           kv_self.n == prev->n_kv &&
           n_outputs == prev->n_outputs &&
           cparams.mtp_op_type == prev->mtp_op_type &&
           update_cache_copies();
}

bool llama_context::update_cache_copies() {
    int n_layer = model.hparams.n_layer - model.hparams.nextn_predict_layers; //cache_copies.size()/2;
    const int n_layer = model.mtp ? model.hparams.n_layer
                                  : model.hparams.n_layer - model.hparams.nextn_predict_layers; //cache_copies.size()/2;
    auto layer_has_attention_kv = [&](int il) {
        return !((model.arch == LLM_ARCH_QWEN3NEXT || model.arch == LLM_ARCH_QWEN35MOE) && model.hparams.is_recurrent(il));
    };

@@ -638,6 +641,12 @@ llama_context::llama_context(const llama_model & model)
    }
}

void llama_context::set_mtp_op_type(llama_mtp_op_type value) {
    LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);

    cparams.mtp_op_type = value;
}

llama_context::~llama_context() {
    ggml_backend_sched_free(sched);

@@ -716,7 +725,8 @@ static bool llama_kv_cache_init(

    const struct llama_hparams & hparams = model.hparams;

    const int64_t n_layer = hparams.n_layer - hparams.nextn_predict_layers;
    const int64_t n_layer = model.mtp ? hparams.n_layer
                                      : hparams.n_layer - hparams.nextn_predict_layers;

    cache.has_shift = false;

@@ -993,7 +1003,8 @@ static bool llama_kv_cache_init(
// to the first cell of the slot.
static bool llama_kv_cache_find_slot(
           struct llama_kv_cache & cache,
        const struct llama_batch & batch) {
        const struct llama_batch & batch,
         enum llama_mtp_op_type op_type) {
    const uint32_t n_tokens = batch.n_tokens;

    if (cache.recurrent) {

@@ -1044,6 +1055,45 @@ static bool llama_kv_cache_find_slot(
    }
    // otherwise, one cell per token.

    bool is_mtp_special_op = (op_type == MTP_OP_WARMUP ||
                              op_type == MTP_OP_UPDATE_ACCEPTED);
    if (is_mtp_special_op) {
        const llama_pos    target_pos = batch.pos[0];
        const llama_seq_id target_seq = batch.seq_id[0][0];

        bool found = false;

        if (cache.head < cache.size &&
            cache.cells[cache.head].pos == target_pos &&
            cache.cells[cache.head].has_seq_id(target_seq)) {
            found = true;
        }
        else {
            for (uint32_t i = 0; i < cache.size; ++i) {
                if (cache.cells[i].pos == target_pos &&
                    cache.cells[i].has_seq_id(target_seq)) {

                    cache.head = i;
                    found = true;
                    break;
                }
            }
        }

        if (!found) {
            LLAMA_LOG_ERROR("%s: MTP Update failed - slot for seq %d pos %d not found\n",
                __func__, target_seq, target_pos);
            return false;
        }

        if (cache.head + n_tokens > cache.size) {
            LLAMA_LOG_ERROR("%s: MTP Update out of bounds\n", __func__);
            return false;
        }

        return true;
    }

    if (n_tokens > cache.size) {
        LLAMA_LOG_ERROR("%s: n_tokens=%d > cache.size=%d\n", __func__, n_tokens, cache.size);
        return false;

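The new branch above is the behavioural change for `MTP_OP_WARMUP` and `MTP_OP_UPDATE_ACCEPTED`: instead of allocating fresh cells, it moves `cache.head` onto the cell that already holds the batch's first token (matched by position and sequence id), so the MTP pass rewrites the slots the main pass already claimed. A minimal standalone illustration of that lookup, using simplified stand-in types rather than the real `llama_kv_cache` structures:

```cpp
#include <cstdint>
#include <vector>

// Simplified stand-in for a KV cache cell (illustration only).
struct kv_cell_sketch {
    int32_t pos;     // position the cell was written at
    int32_t seq_id;  // owning sequence (the real cell can hold several)
};

// Return the index of the cell that already holds (pos, seq), or -1 if that
// position was never decoded - the "slot not found" error case above.
static int find_existing_cell(const std::vector<kv_cell_sketch> & cells,
                              int32_t pos, int32_t seq) {
    for (size_t i = 0; i < cells.size(); ++i) {
        if (cells[i].pos == pos && cells[i].seq_id == seq) {
            return (int) i;
        }
    }
    return -1;
}
```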
@@ -1893,6 +1943,7 @@ static bool llm_load_tensors(
        const float * tensor_split,
        bool use_mlock,
        bool validate_quants,
        bool mtp,
        llama_progress_callback progress_callback,
        void * progress_callback_user_data) {
    model.t_start_us = ggml_time_us();

@@ -1921,6 +1972,7 @@ static bool llm_load_tensors(
    model.main_gpu     = main_gpu;
    model.max_gpu      = max_gpu;
    model.n_gpu_layers = n_gpu_layers;
    model.mtp          = mtp;

    const int n_layer     = hparams.n_layer;
    const int i_gpu_start = std::max((int) hparams.n_layer - n_gpu_layers, (int) 0);

@@ -2300,7 +2352,7 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam

    if (!llm_load_tensors(
        ml, model, params.n_gpu_layers, params.mla, params.split_mode, params.main_gpu, params.max_gpu, params.tensor_split,
        params.use_mlock, params.validate_quants,
        params.use_mlock, params.validate_quants, params.mtp,
        params.progress_callback, params.progress_callback_user_data
    )) {
        return -2;

@@ -2969,8 +3021,9 @@ static size_t llama_output_reserve(llama_context & lctx, size_t n_outputs) {
|
||||
const auto n_embd = hparams.n_embd;
|
||||
|
||||
// TODO: use a per-batch flag for logits presence instead
|
||||
const bool has_logits = !cparams.embeddings;
|
||||
const bool has_embd = lctx.is_encoding || (cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE));
|
||||
const bool has_mtp = lctx.model.hparams.nextn_predict_layers > 0 && lctx.cparams.mtp;
|
||||
const bool has_logits = !cparams.embeddings || has_mtp;
|
||||
const bool has_embd = lctx.is_encoding || (cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE)) || has_mtp;
|
||||
|
||||
const size_t logits_size = has_logits ? n_vocab*n_outputs_max : 0;
|
||||
const size_t embd_size = has_embd ? n_embd*n_outputs_max : 0;
|
||||
@@ -3049,6 +3102,24 @@ static void llama_graph_compute(
    // fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(lctx.sched));
}

static bool prepare_mtp_graph_inputs(struct llama_context & lctx) {
    ggml_tensor * dst = lctx.inp_mtp_states;
    const float * src = nullptr;
    if (lctx.cparams.mtp_op_type == MTP_OP_WARMUP || lctx.cparams.mtp_op_type == MTP_OP_UPDATE_ACCEPTED) {
        src = lctx.embd;
    } else {
        src = lctx.draft_input_hidden_state;
    }

    if (!src) {
        LLAMA_LOG_ERROR("%s: Source hidden state is null\n", __func__);
        return false;
    }

    ggml_backend_tensor_set(dst, src, 0, ggml_nbytes(dst));
    return true;
}

// decode a batch of tokens by evaluating the transformer
//
//   - lctx:      llama context

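`prepare_mtp_graph_inputs` copies a host-side hidden-state buffer into the graph input `inp_mtp_states`: for `MTP_OP_WARMUP` and `MTP_OP_UPDATE_ACCEPTED` it reads `lctx.embd` (filled by the preceding target pass thanks to the hybrid-output change), otherwise it reads `lctx.draft_input_hidden_state`, which the caller must set beforehand. A hedged caller-side sketch for that second case, assuming `llama_get_embeddings` as the way to read back `lctx.embd`:

```cpp
// Sketch: supplying the hidden state consumed by prepare_mtp_graph_inputs()
// when the op type is neither WARMUP nor UPDATE_ACCEPTED (i.e. a draft pass).
const float * last_hidden = llama_get_embeddings(ctx);   // hidden state from the last target pass (assumed accessor)
llama_set_draft_input_hidden_state(ctx, last_hidden);    // becomes lctx.draft_input_hidden_state
// ... then set the draft op type and call llama_decode() as usual.
```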
@@ -3260,7 +3331,7 @@ static int llama_decode_internal(
            kv_self.head = 0;
        }

        if (!llama_kv_cache_find_slot(kv_self, u_batch)) {
        if (!llama_kv_cache_find_slot(kv_self, u_batch, cparams.mtp_op_type)) {
            return 1;
        }

@@ -3322,37 +3393,50 @@ static int llama_decode_internal(
#endif
        if (u_batch.n_tokens == 1 && u_batch.embd == nullptr && lctx.cparams.graph_reuse) {
            lctx.prev = std::make_unique<llama_context::Prev>(llama_context::Prev{
                (int)u_batch.all_seq_id, (int)lctx.n_outputs, (int)lctx.kv_self.n, gf});
                (int)u_batch.all_seq_id, (int)lctx.n_outputs, (int)lctx.kv_self.n,
                cparams.mtp_op_type, gf});
        }
    } else {
        //printf("Reusing graph\n");
        gf = lctx.prev->graph;
    }

    if (cparams.mtp_op_type != MTP_OP_NONE) {
        if (!prepare_mtp_graph_inputs(lctx)) {
            return GGML_STATUS_FAILED;
        }
    }

    // the output is always the last tensor in the graph
    struct ggml_tensor * res  = gf->nodes[gf->n_nodes - 1];
    struct ggml_tensor * embd = gf->nodes[gf->n_nodes - 2];
    struct ggml_tensor * embd = nullptr;

    if (lctx.n_outputs == 0) {
        // no output
        res  = nullptr;
        embd = nullptr;
    } else if (cparams.embeddings) {
        res  = nullptr; // do not extract logits for embedding case
        embd = nullptr;
        for (int i = gf->n_nodes - 1; i >= 0; --i) {
            if (strcmp(gf->nodes[i]->name, "result_embd_pooled") == 0) {
                embd = gf->nodes[i];
                break;
        res = nullptr;
    }
    else {
        const bool has_mtp = lctx.model.hparams.nextn_predict_layers > 0 && lctx.model.mtp;
        if (cparams.embeddings || has_mtp) {
            for (int i = gf->n_nodes - 1; i >= 0; --i) {
                if (strcmp(gf->nodes[i]->name, "result_embd_pooled") == 0) {
                    embd = gf->nodes[i];
                    break;
                }
                if (strcmp(gf->nodes[i]->name, "result_norm") == 0) {
                    embd = gf->nodes[i];
                }
            }
        }
        if (cparams.embeddings && lctx.model.hparams.nextn_predict_layers == 0) {
            res = nullptr; // do not extract logits for embedding case
        } else {
            if (!embd) { // do not extract embeddings when not needed
                GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor");
            }
        }
        GGML_ASSERT(embd != nullptr && "missing embeddings tensor");
    } else {
        embd = nullptr; // do not extract embeddings when not needed
        GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor");
    }
    // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);

#if IK_PRINT_TIMING == 1
    tim1 = ggml_time_us();
#endif

@@ -3392,17 +3476,21 @@ static int llama_decode_internal(
#if IK_PRINT_TIMING
            tim1 = ggml_time_us();
#endif
            ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(lctx.sched, res);
            GGML_ASSERT(backend_res != nullptr);
            GGML_ASSERT(lctx.logits != nullptr);
            // Do not process logits if MTP is only updating the KV cache.
            if (cparams.mtp_op_type != MTP_OP_WARMUP &&
                cparams.mtp_op_type != MTP_OP_UPDATE_ACCEPTED) {
                ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(lctx.sched, res);
                GGML_ASSERT(backend_res != nullptr);
                GGML_ASSERT(lctx.logits != nullptr);

            float * logits_out = lctx.logits + n_outputs_prev*n_vocab;
            const int32_t n_outputs_new = lctx.n_outputs;
                float * logits_out = lctx.logits + n_outputs_prev*n_vocab;
                const int32_t n_outputs_new = lctx.n_outputs;

            if (n_outputs_new) {
                GGML_ASSERT( n_outputs_prev + n_outputs_new <= n_outputs);
                GGML_ASSERT((n_outputs_prev + n_outputs_new)*n_vocab <= (int64_t) lctx.logits_size);
                ggml_backend_tensor_get_async(backend_res, res, logits_out, 0, n_outputs_new*n_vocab*sizeof(float));
                if (n_outputs_new) {
                    GGML_ASSERT( n_outputs_prev + n_outputs_new <= n_outputs);
                    GGML_ASSERT((n_outputs_prev + n_outputs_new)*n_vocab <= (int64_t) lctx.logits_size);
                    ggml_backend_tensor_get_async(backend_res, res, logits_out, 0, n_outputs_new*n_vocab*sizeof(float));
                }
            }
#if IK_PRINT_TIMING
            tim2 = ggml_time_us();

@@ -3411,7 +3499,7 @@ static int llama_decode_internal(
        }

        // extract embeddings
        if (embd) {
        if (embd && cparams.mtp_op_type == MTP_OP_NONE) {
#if IK_PRINT_TIMING
            tim1 = ggml_time_us();
#endif

@@ -3617,57 +3705,59 @@ static int llama_encode_internal(

    // extract embeddings
    if (embd) {
        ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(lctx.sched, embd);
        GGML_ASSERT(backend_embd != nullptr);
        if (cparams.mtp_op_type == MTP_OP_NONE) {
            ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(lctx.sched, embd);
            GGML_ASSERT(backend_embd != nullptr);

            if (llama_model_has_decoder(&lctx.model)) {
                lctx.embd_enc.resize(n_tokens*n_embd);
                float * embd_out = lctx.embd_enc.data();
            if (llama_model_has_decoder(&lctx.model)) {
                lctx.embd_enc.resize(n_tokens*n_embd);
                float * embd_out = lctx.embd_enc.data();

                ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_tokens*n_embd*sizeof(float));
                ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_tokens*n_embd*sizeof(float));

                // remember the sequence ids used during the encoding - needed for cross attention later
                lctx.seq_ids_enc.resize(n_tokens);
                for (uint32_t i = 0; i < n_tokens; i++) {
                    for (int s = 0; s < batch.n_seq_id[i]; s++) {
                        llama_seq_id seq_id = batch.seq_id[i][s];
                        lctx.seq_ids_enc[i].insert(seq_id);
                    }
                }
            } else {
                GGML_ASSERT(lctx.embd != nullptr);

                switch (cparams.pooling_type) {
                    case LLAMA_POOLING_TYPE_NONE:
                        {
                            // extract token embeddings
                            GGML_ASSERT(lctx.embd != nullptr);
                            float * embd_out = lctx.embd;

                            GGML_ASSERT(n_tokens*n_embd <= (int64_t) lctx.embd_size);
                            ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_tokens*n_embd*sizeof(float));
                        } break;
                    case LLAMA_POOLING_TYPE_MEAN:
                    case LLAMA_POOLING_TYPE_CLS:
                    case LLAMA_POOLING_TYPE_LAST:
                        {
                            // extract sequence embeddings
                            auto & embd_seq_out = lctx.embd_seq;
                            embd_seq_out.clear();

                            for (uint32_t i = 0; i < n_tokens; i++) {
                                const llama_seq_id seq_id = batch.seq_id[i][0];
                                if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
                                    continue;
                                }
                                embd_seq_out[seq_id].resize(n_embd);
                                ggml_backend_tensor_get_async(backend_embd, embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
                            }
                        } break;
                    case LLAMA_POOLING_TYPE_UNSPECIFIED:
                        {
                            GGML_ABORT("unknown pooling type");
                // remember the sequence ids used during the encoding - needed for cross attention later
                lctx.seq_ids_enc.resize(n_tokens);
                for (uint32_t i = 0; i < n_tokens; i++) {
                    for (int s = 0; s < batch.n_seq_id[i]; s++) {
                        llama_seq_id seq_id = batch.seq_id[i][s];
                        lctx.seq_ids_enc[i].insert(seq_id);
                    }
                }
            } else {
                GGML_ASSERT(lctx.embd != nullptr);

                switch (cparams.pooling_type) {
                    case LLAMA_POOLING_TYPE_NONE:
                        {
                            // extract token embeddings
                            GGML_ASSERT(lctx.embd != nullptr);
                            float * embd_out = lctx.embd;

                            GGML_ASSERT(n_tokens*n_embd <= (int64_t) lctx.embd_size);
                            ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_tokens*n_embd*sizeof(float));
                        } break;
                    case LLAMA_POOLING_TYPE_MEAN:
                    case LLAMA_POOLING_TYPE_CLS:
                    case LLAMA_POOLING_TYPE_LAST:
                        {
                            // extract sequence embeddings
                            auto & embd_seq_out = lctx.embd_seq;
                            embd_seq_out.clear();

                            for (uint32_t i = 0; i < n_tokens; i++) {
                                const llama_seq_id seq_id = batch.seq_id[i][0];
                                if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
                                    continue;
                                }
                                embd_seq_out[seq_id].resize(n_embd);
                                ggml_backend_tensor_get_async(backend_embd, embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
                            }
                        } break;
                    case LLAMA_POOLING_TYPE_UNSPECIFIED:
                        {
                            GGML_ABORT("unknown pooling type");
                        }
                }
            }
        }
    }

@@ -4223,6 +4313,7 @@ struct llama_model_params llama_model_default_params() {
        /*.validate_quants    =*/ false,
        /*.merge_qkv          =*/ false,
        /*.merge_up_gate_exps =*/ false,
        /*.mtp                =*/ false,
    };

#ifdef GGML_USE_METAL

@@ -4278,6 +4369,8 @@ struct llama_context_params llama_context_default_params() {
        /*.split_mode_graph_scheduling =*/ false,
        ///*.split_mode_f16             =*/ true,
        /*.scheduler_async             =*/ false,
        /*.mtp                         =*/ false,
        /*.mtp_op_type                 =*/ MTP_OP_NONE,
        /*.abort_callback              =*/ nullptr,
        /*.abort_callback_data         =*/ nullptr,
        /*.offload_policy              =*/ nullptr,

@@ -4648,6 +4741,7 @@ struct llama_context * llama_init_from_model(
    cparams.min_experts    = params.min_experts;
    cparams.thresh_experts = params.thresh_experts;
    cparams.cuda_params    = params.cuda_params;
    cparams.mtp            = params.mtp;

    cparams.reduce_type    = params.type_reduce;
    cparams.pooling_type   = params.pooling_type;

@@ -4725,6 +4819,12 @@ struct llama_context * llama_init_from_model(
        }
    }

    if (model->arch != LLM_ARCH_GLM4_MOE && cparams.mtp != 0) {
        cparams.mtp = 0;
    }

    cparams.mtp_op_type = params.mtp_op_type;

    LLAMA_LOG_INFO("%s: n_ctx    = %u\n", __func__, cparams.n_ctx);
    LLAMA_LOG_INFO("%s: n_batch  = %u\n", __func__, cparams.n_batch);
    LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch);

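The guard above silently clears `cparams.mtp` for any architecture other than GLM-4 MoE, so callers can request MTP unconditionally and let the context decide. A minimal init sketch using the new defaults; the model path is a placeholder and `llama_load_model_from_file` is assumed to be the fork's long-standing loader entry point:

```cpp
// Sketch: opting in to MTP at load/init time (field names from the defaults added above).
llama_model_params mparams = llama_model_default_params();
mparams.mtp = true;                       // keep the NextN/MTP tensors when loading (model.mtp)

llama_model * model = llama_load_model_from_file("glm-4-moe.gguf", mparams);  // placeholder path

llama_context_params cparams = llama_context_default_params();
cparams.mtp         = true;               // reserve hybrid (logits + embeddings) outputs
cparams.mtp_op_type = MTP_OP_NONE;        // normal decoding until an MTP pass is requested

llama_context * ctx = llama_init_from_model(model, cparams);  // mtp is cleared internally if arch != GLM4_MOE
```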
@@ -6058,7 +6158,7 @@ struct llama_data_read {
            batch.n_seq_id[i] = 1;
            batch.seq_id[i][0] = dest_seq_id;
        }
        if (!llama_kv_cache_find_slot(kv_self, batch)) {
        if (!llama_kv_cache_find_slot(kv_self, batch, ctx->cparams.mtp_op_type)) {
            llama_batch_free(batch);
            LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
            return false;

@@ -7003,6 +7103,10 @@ int32_t llama_decode(
    return ret;
}

void llama_set_mtp_op_type(llama_context * ctx, llama_mtp_op_type mtp_op_type) {
    ctx->set_mtp_op_type(mtp_op_type);
}

void llama_synchronize(struct llama_context * ctx) {
    ggml_backend_sched_synchronize(ctx->sched);

@@ -8333,3 +8437,8 @@ void llama_set_offload_policy(struct llama_context * lctx, int op, bool on_or_of
    printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXX offload(%s) = %d\n", op_name, on_or_off);
    ggml_backend_sched_set_op_offload(lctx->sched, ggml_op(op), on_or_off);
}

void llama_set_draft_input_hidden_state(struct llama_context * ctx, const float * hidden_state) {
    ctx->draft_input_hidden_state = hidden_state;
}