Adaptive p: bugfix + optimization + refactor (#1155)

* adaptive-p sampler: fix zeroed orig_probs bug and refactor

- Fix bug where original probabilities were captured as zero by calculating
  them from logits in llama_prep_adaptive_p (new).
- Replace vector with unordered_map to track candidate probabilities,
  filtering for relevance via logit delta (16.6f).
- Standardize API naming: llama_<action/verb>_<focus/name/topic>_<extra/info>
- Update function signatures to follow most other samplers.

* resolve merge bug

* adaptive-p: revert reordering function definitions
This commit is contained in:
dungquixote42
2026-01-18 01:26:06 -05:00
committed by GitHub
parent d71a3ec315
commit 6dfbef27ec
5 changed files with 121 additions and 58 deletions

View File

@@ -61,6 +61,7 @@ struct llama_sampler_dry * llama_sampler_init_dry_impl(
void llama_sampler_dry_apply(struct llama_sampler_dry* smpl, llama_token_data_array* cur_p);
// maintains an exponential moving average of the *ORIGINAL* probabilities of selected tokens
// used to compute an adapted target at each sampling step.
// see llama.h for a full description of the sampler
@@ -70,15 +71,30 @@ struct llama_sampler_adaptive_p {
std::mt19937 rng; // RNG
float weighted_sum; // sum(p_n * decay^N)
float total_weight; // sum(decay^i), converges to 1/(1-decay)
float max_logit; // maximum logit found during transform
std::vector<float> probs; // cumulative probabilities
// first referenced in prep
std::unordered_map<llama_token, float> orig_prob_map; // probabilities before sampler_queue
float cum_orig_prob; // for normalizing orig_prob in sample_token
// first referenced in sample
float max_xform_logit; // maximum logit found during transform
// first referenced in sample_token
std::vector<float> cum_probs; // cumulative probability distribution
};
void llama_sampler_adaptive_p_apply(
struct llama_sampler_adaptive_p * adapt_p_ctx,
llama_token_data_array * candidates);
struct llama_sampler_adaptive_p * llama_init_adaptive_p_impl(
const float target,
const float decay,
const uint32_t seed);
struct llama_sampler_adaptive_p * llama_sampler_init_adaptive_p_impl(const float target, const float decay, const uint32_t seed);
void llama_prep_adaptive_p_impl(
llama_token_data_array * candidates,
struct llama_sampler_adaptive_p * adapt_p_ctx);
void llama_sample_adaptive_p_impl(
llama_token_data_array * candidates,
struct llama_sampler_adaptive_p * adapt_p_ctx);
void llama_sample_repetition_penalties_impl(
@@ -101,6 +117,6 @@ llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling * smpl, ll
llama_token llama_sample_token_greedy_impl (struct llama_sampling * smpl, llama_token_data_array * candidates);
llama_token llama_sample_token_with_rng_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, std::mt19937 & rng);
llama_token llama_sample_token_impl (struct llama_sampling * smpl, llama_token_data_array * candidates);
llama_token llama_sample_token_adaptive_p_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, struct llama_sampler_adaptive_p * adapt_p_ctx, float * orig_probs);
llama_token llama_sample_token_adaptive_p_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, struct llama_sampler_adaptive_p * adapt_p_ctx);