mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-30 19:31:48 +00:00
add llama_sampling_sample_and_accept_n to sampling
This commit is contained in:
@@ -506,3 +506,34 @@ void llama_sampling_accept(
|
|||||||
llama_sampler_dry_accept(ctx_sampling->smpl, id);
|
llama_sampler_dry_accept(ctx_sampling->smpl, id);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<llama_token> llama_sampling_sample_and_accept_n(struct llama_sampling_context * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const std::vector<llama_token> & draft) {
|
||||||
|
GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1");
|
||||||
|
|
||||||
|
std::vector<llama_token> result;
|
||||||
|
result.reserve(idxs.size());
|
||||||
|
|
||||||
|
size_t i = 0;
|
||||||
|
for (; i < draft.size(); i++) {
|
||||||
|
const llama_token id = llama_sampling_sample(gsmpl, ctx, nullptr, idxs[i]);
|
||||||
|
|
||||||
|
llama_sampling_accept(gsmpl, ctx, id, true);
|
||||||
|
|
||||||
|
result.push_back(id);
|
||||||
|
|
||||||
|
if (draft[i] != id) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (i == draft.size()) {
|
||||||
|
const llama_token id = llama_sampling_sample(gsmpl, ctx, nullptr, idxs[i]);
|
||||||
|
|
||||||
|
llama_sampling_accept(gsmpl, ctx, id, true);
|
||||||
|
|
||||||
|
result.push_back(id);
|
||||||
|
}
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -176,3 +176,22 @@ void llama_sampling_accept(
|
|||||||
struct llama_context * ctx_main,
|
struct llama_context * ctx_main,
|
||||||
llama_token id,
|
llama_token id,
|
||||||
bool apply_grammar);
|
bool apply_grammar);
|
||||||
|
|
||||||
|
// generalized version of common_sampler_sample
|
||||||
|
//
|
||||||
|
// will cross-reference the sampled tokens with a batch of draft tokens and accept those that match
|
||||||
|
// if the sampler disagrees at some point, we stop and return the accepted tokens up to now
|
||||||
|
//
|
||||||
|
// common_sampler_sample_n(gsmpl, ctx, { idx }, {});
|
||||||
|
//
|
||||||
|
// is equivalent to
|
||||||
|
//
|
||||||
|
// common_sampler_sample(gsmpl, ctx, idx);
|
||||||
|
// common_sampler_accept(gsmpl, token, true);
|
||||||
|
//
|
||||||
|
// requires: idxs.size() == draft.size() + 1
|
||||||
|
//
|
||||||
|
// returns at least 1 token, up to idxs.size()
|
||||||
|
//
|
||||||
|
std::vector<llama_token> llama_sampling_sample_and_accept_n(struct llama_sampling_context * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const std::vector<llama_token> & draft);
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user