This should be better

Kawrakow
2026-01-19 08:40:07 +00:00
parent a9f37c2f80
commit 4df3251b12
5 changed files with 22 additions and 49 deletions

View File

@@ -118,7 +118,9 @@ struct llama_sampling_context * common_sampler_init(const struct llama_vocab* vo
         }
         case llama_sampler_type::ADAPTIVE_P:
         {
-            result->adapt_p_ctx = llama_init_adaptive_p(params.adaptive_target, params.adaptive_decay, result->rng());
+            GGML_ASSERT(vocab);
+            auto n_vocab = llama_vocab_n_tokens(vocab);
+            result->adapt_p_ctx = llama_init_adaptive_p(n_vocab, params.adaptive_target, params.adaptive_decay, result->rng());
             break;
         }
         default:

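Note (not part of the commit): the public initializer now takes the vocabulary size up front, so a direct caller has to fetch it before constructing the sampler. A minimal sketch of such a call, assuming the usual llama.h accessors; llama_model_get_vocab() is an assumption here, while llama_vocab_n_tokens() and llama_init_adaptive_p() are the names used in this diff, and the target/decay/seed values are placeholders.

```cpp
#include "llama.h"

// Hypothetical stand-alone caller (illustration only).
struct llama_sampler_adaptive_p * make_adaptive_p(const struct llama_model * model) {
    const struct llama_vocab * vocab = llama_model_get_vocab(model); // assumed accessor
    const int n_vocab = llama_vocab_n_tokens(vocab);                 // as in the diff above
    return llama_init_adaptive_p(n_vocab, /*target=*/0.3f, /*decay=*/0.9f, /*seed=*/1234u);
}
```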
View File

@@ -1384,7 +1384,7 @@ LLAMA_API struct llama_grammar* llama_sampler_init_grammar_lazy_patterns(
 /// @details Adaptive p sampler initializer
 /// @param target Select tokens near this probability (valid range 0.0 to 1.0; <0 = disabled)
 /// @param decay  Decay rate for target adaptation over time. Lower values -> faster but less stable adaptation. (valid range 0.0 to 1.0; ≤0 = no adaptation)
-LLAMA_API struct llama_sampler_adaptive_p * llama_init_adaptive_p(
+LLAMA_API struct llama_sampler_adaptive_p * llama_init_adaptive_p(int n_vocab,
                     const float target,
                     const float decay,
                     const uint32_t seed);

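Side note on the decay parameter documented above (my reading, not stated in the commit): the initializer later in this commit sets total_weight = 1.0f / (1.0f - clamped_decay), i.e. the sum of the geometric weights d^k, so decay roughly sets the length of the window over which the target adapts. A small illustration:

```cpp
#include <cstdio>
#include <initializer_list>

int main() {
    // With exponential decay d, past observations get weights 1, d, d^2, ...
    // whose sum is 1/(1 - d); larger decay -> longer effective memory.
    for (float d : {0.5f, 0.9f, 0.99f}) {
        std::printf("decay = %.2f -> effective window ~ %.0f steps\n", d, 1.0f / (1.0f - d));
    }
    return 0;
}
```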
View File

@@ -1127,63 +1127,34 @@ void llama_prep_adaptive_p_impl(
         struct llama_sampler_adaptive_p * adapt_p_ctx) {
     constexpr float kDelta = 16.6f;
     auto & orig_prob = adapt_p_ctx->orig_prob;
-    if (!candidates->sorted) {
-        float max_logit = candidates->data[0].logit;
-        for (int j = 1; j < int(candidates->size); ++j) {
-            max_logit = std::max(max_logit, candidates->data[j].logit);
-        }
-        float min_logit = max_logit - kDelta;
-        float cum_prob = 0.0f;
-        if (orig_prob.size() != candidates->size) {
-            orig_prob.resize(candidates->size);
-        }
-        for (int j = 0; j < int(candidates->size); ++j) {
-            if (candidates->data[j].logit > min_logit) {
-                float prob = expf(candidates->data[j].logit - max_logit);
-                cum_prob += prob;
-                orig_prob[j] = prob;
-            } else {
-                orig_prob[j] = 0;
-            }
-        }
-        adapt_p_ctx->cum_orig_prob = cum_prob;
-        return;
+    if (candidates->size != orig_prob.size() || candidates->sorted) {
+        LLAMA_LOG_ERROR("%s: this function must be called before any other sampler has been applied\n", __func__);
+        LLAMA_LOG_ERROR("%s: the sampler has been initialized with a vocabulary of %zu, but is being called with %zu candidates\n",
+                __func__, orig_prob.size(), candidates->size);
+        GGML_ABORT("Bad candidates in adaptive_p sampler");
     }
-    // Hopefully we never end here
-    // But if we do, let's issue some warnings
-    if (adapt_p_ctx->n_warn < 10) {
-        LLAMA_LOG_WARN("%s: this function should be called before any other sampler is applied\n", __func__);
-        ++adapt_p_ctx->n_warn;
-    }
-    llama_token max_id = 0;
-    for (int j = 0; j < int(candidates->size); ++j) max_id = std::max(max_id, candidates->data[j].id);
-    if (max_id + 1 != int(orig_prob.size())) orig_prob.resize(max_id + 1);
-    std::memset(orig_prob.data(), 0, orig_prob.size()*sizeof(float));
     float max_logit = candidates->data[0].logit;
     for (int j = 1; j < int(candidates->size); ++j) {
         max_logit = std::max(max_logit, candidates->data[j].logit);
     }
     float min_logit = max_logit - kDelta;
     float cum_prob = 0.0f;
     for (int j = 0; j < int(candidates->size); ++j) {
-        auto logit = candidates->data[j].logit;
-        if (logit <= min_logit) {
-            break;
-        }
-        float prob = expf(logit - max_logit);
+        float prob = candidates->data[j].logit > min_logit ? expf(candidates->data[j].logit - max_logit) : 0.0f;
         cum_prob += prob;
-        orig_prob[candidates->data[j].id] = prob;
+        orig_prob[j] = prob;
     }
     adapt_p_ctx->cum_orig_prob = cum_prob;
 }
-struct llama_sampler_adaptive_p * llama_init_adaptive_p_impl(
+struct llama_sampler_adaptive_p * llama_init_adaptive_p_impl(int n_vocab,
         const float target,
         const float decay,
         const uint32_t seed) {
+    GGML_ASSERT(n_vocab > 0);
     const float clamped_decay = std::clamp(decay, 0.0f, 0.99f);
-    return new llama_sampler_adaptive_p {
+    auto result = new llama_sampler_adaptive_p {
         /* .target = */ target,
         /* .decay = */ clamped_decay,
         /* .rng = */ std::mt19937(seed),
@@ -1191,10 +1162,11 @@ struct llama_sampler_adaptive_p * llama_init_adaptive_p_impl(
         /* .total_weight = */ 1.0f / (1.0f - clamped_decay),
         /* .orig_prob = */ {},
         /* .cum_orig_prob = */ 0.0f,
-        /* .n_warn = */ 0,
         /* .max_xform_logit = */ -INFINITY,
         /* .cum_probs = */ {},
     };
+    result->orig_prob.resize(n_vocab);
+    return result;
 }
 // grammar

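For readability (not part of the commit): the rewritten prep step is just a pruned, unnormalized softmax over the full, untouched candidate list, where any logit more than kDelta = 16.6 below the maximum is treated as zero probability (expf(-16.6f) ≈ 6e-8, negligible mass). A self-contained sketch of the same computation on a plain vector of logits:

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Sketch only: mirrors the logic in llama_prep_adaptive_p_impl, but normalizes at
// the end instead of keeping the cumulative sum (cum_orig_prob) around like the
// sampler does. Assumes a non-empty logit vector.
std::vector<float> pruned_softmax(const std::vector<float> & logits, float k_delta = 16.6f) {
    const float max_logit = *std::max_element(logits.begin(), logits.end());
    const float min_logit = max_logit - k_delta;      // cutoff: exp(-16.6) ~ 6e-8
    std::vector<float> prob(logits.size(), 0.0f);
    float cum_prob = 0.0f;
    for (size_t j = 0; j < logits.size(); ++j) {
        if (logits[j] > min_logit) {
            prob[j] = std::exp(logits[j] - max_logit);
            cum_prob += prob[j];
        }
    }
    for (float & p : prob) p /= cum_prob;             // safe: the max logit always contributes 1
    return prob;
}
```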
View File

@@ -75,7 +75,6 @@ struct llama_sampler_adaptive_p {
     // first referenced in prep
     std::vector<float> orig_prob; // for storing the original probabilities
     float cum_orig_prob;          // for normalizing orig_prob in sample_token
-    int n_warn = 0;               // for warnings
     // first referenced in sample
     float max_xform_logit;        // maximum logit found during transform
@@ -84,7 +83,7 @@ struct llama_sampler_adaptive_p {
     std::vector<float> cum_probs; // cumulative probability distribution
 };
-struct llama_sampler_adaptive_p * llama_init_adaptive_p_impl(
+struct llama_sampler_adaptive_p * llama_init_adaptive_p_impl(int n_vocab,
         const float target,
         const float decay,
         const uint32_t seed);

View File

@@ -7797,8 +7797,8 @@ void llama_sampler_dry_accept(struct llama_sampler_dry* smpl, llama_token token)
 }
-struct llama_sampler_adaptive_p * llama_init_adaptive_p(const float target, const float decay, const uint32_t seed) {
-    return llama_init_adaptive_p_impl(target, decay, seed);
+struct llama_sampler_adaptive_p * llama_init_adaptive_p(int n_vocab, const float target, const float decay, const uint32_t seed) {
+    return llama_init_adaptive_p_impl(n_vocab, target, decay, seed);
 }