reset ngram mod state for rejected tokens

This commit is contained in:
SamuelOliveirads
2026-04-17 13:12:40 -03:00
parent 8ff2d943a3
commit dc4797b723
4 changed files with 19 additions and 7 deletions

View File

@@ -208,7 +208,7 @@ void common_ngram_map_begin(
count_keys, count_keys_del, count_values_del, count_map_entries_upd);
}
map.idx_last_check = (map.size_last_begin > 0) ? map.size_last_begin - 1 : 0;
map.idx_last_check = size_begin;
map.size_last_begin = size_begin;
}
@@ -386,7 +386,7 @@ void common_ngram_map_draft(common_ngram_map & map,
LOG_DBG("%s: key_idx = %zu, key_offset = %zu, key_num = %d, draft.size = %zu\n", __func__,
curr_key.key_idx, key_offset, curr_key.key_num, draft.size());
map.last_draft_created = false;
map.last_draft_created = true;
map.last_draft_key_idx = key_offset;
map.last_draft_value_idx = 0; // value 0 is used for simple mode
return;
@@ -524,7 +524,7 @@ void common_ngram_map_accept(common_ngram_map & map, uint16_t n_accepted) {
struct common_ngram_map_value & curr_value = curr_key.values[val_idx]; // value used for draft generation.
// update the value statistics
LOG_INF("common_ngram_map_send_accepted: n_accepted = %d, prev value_num = %d\n",
LOG_DBG("common_ngram_map_send_accepted: n_accepted = %d, prev value_num = %d\n",
n_accepted, curr_value.n_accepted);
curr_value.n_accepted = n_accepted;
}

View File

@@ -160,6 +160,12 @@ void llama_sampling_set_rng_seed(struct common_sampler * ctx, uint32_t seed) {
}
void common_sampler_clone(common_sampler * src, common_sampler * dst) {
dst->params = src->params;
dst->mirostat_mu = src->mirostat_mu;
dst->n_valid = src->n_valid;
dst->rng = src->rng;
dst->server_biases = src->server_biases;
if (dst->grammar) {
llama_grammar_free(dst->grammar);
dst->grammar = nullptr;
@@ -172,7 +178,12 @@ void common_sampler_clone(common_sampler * src, common_sampler * dst) {
}
dst->prev = src->prev;
if (dst->smpl) {
llama_sampler_dry_free(dst->smpl);
dst->smpl = nullptr;
}
dst->smpl = llama_sampler_dry_clone(src->smpl);
}
llama_token llama_sampling_last(common_sampler * ctx) {

View File

@@ -595,6 +595,7 @@ struct common_speculative_state_ngram_mod : public common_speculative_state {
i_last = 0;
n_draft_last = 0;
n_low = 0;
const size_t n = mod.get_n();
@@ -1130,10 +1131,6 @@ void common_speculative_accept(common_speculative * spec, uint16_t n_accepted) {
spec->t_step_start_us = 0;
}
if (n_accepted == 0) {
return;
}
common_speculative_state * impl = spec->curr_impl;
GGML_ASSERT(impl);

View File

@@ -3660,6 +3660,10 @@ void server_context::speculative_decoding_accept() {
}
}
for (llama_token id : ids) {
common_sampler_accept(slot.ctx_sampling, ctx, id, true);
}
llama_batch_free(re_batch);
SLT_DBG(slot, "spec checkpoint restored (gpu=%s): re-decoded %d accepted tokens (rejected %d)\n",
gpu_ckpt ? "yes" : "no", n_accepted, (int)(n_draft - (ids.size() - 1)));