port functions from common/speculative, common/sampling

This commit is contained in:
T. M.
2025-07-25 04:39:43 +00:00
parent 642b70a64b
commit 422af9eeca
5 changed files with 89 additions and 81 deletions

View File

@@ -442,7 +442,9 @@ static llama_token_data_array llama_sampling_prepare_impl(
cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
}
llama_token_data_array cur_p = { cur.data(), cur.size(), false };
ctx_sampling->cur_p = { cur.data(), cur.size(), false };
llama_token_data_array & cur_p = ctx_sampling->cur_p;
// apply penalties
const auto& penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev;
@@ -507,6 +509,10 @@ void llama_sampling_accept(
}
}
llama_token_data_array * llama_sampling_get_candidates(struct llama_sampling_context * ctx_sampling) {
return &ctx_sampling->cur_p;
}
std::vector<llama_token> llama_sampling_sample_and_accept_n(struct llama_sampling_context * gsmpl, struct llama_context * ctx, const std::vector<llama_token> & draft) {
std::vector<llama_token> result;