diff --git a/common/sampling.cpp b/common/sampling.cpp
index 08a19b45..bd915626 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -506,3 +506,34 @@ void llama_sampling_accept(
         llama_sampler_dry_accept(ctx_sampling->smpl, id);
     }
 }
+
+// Samples at successive positions of idxs, accepting every sampled token into gsmpl, and
+// stops after the first sampled token that differs from the draft token at that position.
+//
+// requires: idxs.size() == draft.size() + 1
+//
+// returns the sampled tokens in order: at least 1, at most idxs.size()
+std::vector<llama_token> llama_sampling_sample_and_accept_n(struct llama_sampling_context * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const std::vector<llama_token> & draft) {
+    GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1");
+
+    std::vector<llama_token> result;
+    result.reserve(idxs.size());
+
+    // idxs has one more entry than draft: that extra, last position is only reached (and
+    // sampled) when every draft token matched
+    for (size_t i = 0; i <= draft.size(); i++) {
+        const llama_token id = llama_sampling_sample(gsmpl, ctx, nullptr, idxs[i]);
+
+        llama_sampling_accept(gsmpl, ctx, id, true);
+
+        result.push_back(id);
+
+        // the sampler disagrees with the draft - the sampled token is kept, then we stop
+        if (i < draft.size() && draft[i] != id) {
+            break;
+        }
+    }
+
+    return result;
+}
+
diff --git a/common/sampling.h b/common/sampling.h
index 1d5bf0b9..2517daee 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -176,3 +176,22 @@ void llama_sampling_accept(
         struct llama_context * ctx_main,
         llama_token id,
         bool apply_grammar);
+
+// generalized version of llama_sampling_sample
+//
+// will cross-reference the sampled tokens with a batch of draft tokens and accept those that match
+// if the sampler disagrees at some point, we stop and return the accepted tokens up to now
+//
+//      llama_sampling_sample_and_accept_n(gsmpl, ctx, { idx }, {});
+//
+// is equivalent to
+//
+//      llama_token token = llama_sampling_sample(gsmpl, ctx, nullptr, idx);
+//      llama_sampling_accept(gsmpl, ctx, token, true);
+//
+// requires: idxs.size() == draft.size() + 1
+//
+// returns at least 1 token, up to idxs.size()
+//
+std::vector<llama_token> llama_sampling_sample_and_accept_n(struct llama_sampling_context * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const std::vector<llama_token> & draft);
+