add llama_sampling_sample_and_accept_n to sampling

This commit is contained in:
T. M.
2025-07-25 03:39:59 +00:00
parent 5c96a7f619
commit 99c1ef3c01
2 changed files with 50 additions and 0 deletions

View File

@@ -506,3 +506,34 @@ void llama_sampling_accept(
llama_sampler_dry_accept(ctx_sampling->smpl, id);
}
}
std::vector<llama_token> llama_sampling_sample_and_accept_n(struct llama_sampling_context * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const std::vector<llama_token> & draft) {
GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1");
std::vector<llama_token> result;
result.reserve(idxs.size());
size_t i = 0;
for (; i < draft.size(); i++) {
const llama_token id = llama_sampling_sample(gsmpl, ctx, nullptr, idxs[i]);
llama_sampling_accept(gsmpl, ctx, id, true);
result.push_back(id);
if (draft[i] != id) {
break;
}
}
if (i == draft.size()) {
const llama_token id = llama_sampling_sample(gsmpl, ctx, nullptr, idxs[i]);
llama_sampling_accept(gsmpl, ctx, id, true);
result.push_back(id);
}
return result;
}

View File

@@ -176,3 +176,22 @@ void llama_sampling_accept(
struct llama_context * ctx_main,
llama_token id,
bool apply_grammar);
// generalized version of common_sampler_sample
//
// will cross-reference the sampled tokens with a batch of draft tokens and accept those that match
// if the sampler disagrees at some point, we stop and return the accepted tokens up to now
//
// common_sampler_sample_n(gsmpl, ctx, { idx }, {});
//
// is equivalent to
//
// common_sampler_sample(gsmpl, ctx, idx);
// common_sampler_accept(gsmpl, token, true);
//
// requires: idxs.size() == draft.size() + 1
//
// returns at least 1 token, up to idxs.size()
//
std::vector<llama_token> llama_sampling_sample_and_accept_n(struct llama_sampling_context * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const std::vector<llama_token> & draft);