mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-25 15:44:10 +00:00
Fix adaptive p sampler bug with string ban (#1287)
* adaptive p: upadte internal state only if not rewinding * adaptive p: conditional update for speculative decoding * adaptive p: refactor to rewind instead of update * adaptive p fix: better comments * fix rewind check * add record to handle multi-token rewind * better comment
This commit is contained in:
@@ -1053,10 +1053,27 @@ struct llama_sampler_dry* llama_sampler_init_dry_impl(const struct llama_vocab&
|
||||
|
||||
// adaptive p
|
||||
|
||||
void llama_review_adaptive_p_impl(llama_sampler_adaptive_p * adapt_p_ctx, const bool record, const bool rewind) {
|
||||
if (record && rewind) {
|
||||
LLAMA_LOG_WARN("%s: record AND rewind is invalid\n", __func__);
|
||||
return;
|
||||
}
|
||||
if (record) {
|
||||
adapt_p_ctx->recd_weighted_sum = adapt_p_ctx->weighted_sum;
|
||||
adapt_p_ctx->recd_total_weight = adapt_p_ctx->total_weight;
|
||||
return;
|
||||
}
|
||||
if (rewind) {
|
||||
adapt_p_ctx->weighted_sum = adapt_p_ctx->recd_weighted_sum;
|
||||
adapt_p_ctx->total_weight = adapt_p_ctx->recd_total_weight;
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
llama_token llama_sample_token_adaptive_p_impl(
|
||||
struct llama_sampling * smpl,
|
||||
llama_token_data_array * candidates,
|
||||
struct llama_sampler_adaptive_p * adapt_p_ctx) {
|
||||
struct llama_sampling * smpl,
|
||||
llama_token_data_array * candidates,
|
||||
struct llama_sampler_adaptive_p * adapt_p_ctx) {
|
||||
GGML_ASSERT(candidates->size > 0);
|
||||
const int64_t t_start_sample_us = ggml_time_us();
|
||||
|
||||
@@ -1082,8 +1099,8 @@ llama_token llama_sample_token_adaptive_p_impl(
|
||||
|
||||
// update history
|
||||
const float update_prob = ctx->updt_w_cur
|
||||
? candidates->data[idx].p / ctx->cum_cur_p
|
||||
: ctx->orig_prob[id] / ctx->cum_orig_prob;
|
||||
? candidates->data[idx].p / ctx->cum_cur_p
|
||||
: ctx->orig_prob[id] / ctx->cum_orig_prob;
|
||||
if (update_prob > 0) {
|
||||
ctx->weighted_sum = ctx->decay * ctx->weighted_sum + update_prob;
|
||||
ctx->total_weight = ctx->decay * ctx->total_weight + 1.0f;
|
||||
@@ -1186,17 +1203,19 @@ struct llama_sampler_adaptive_p * llama_init_adaptive_p_impl(int n_vocab,
|
||||
GGML_ASSERT(n_vocab > 0);
|
||||
const float clamped_decay = std::clamp(decay, 0.0f, 0.99f);
|
||||
auto result = new llama_sampler_adaptive_p {
|
||||
/* .target = */ target,
|
||||
/* .decay = */ clamped_decay,
|
||||
/* .updt_w_cur = */ updt_w_cur,
|
||||
/* .rng = */ std::mt19937(seed),
|
||||
/* .weighted_sum = */ target / (1.0f - clamped_decay),
|
||||
/* .total_weight = */ 1.0f / (1.0f - clamped_decay),
|
||||
/* .orig_prob = */ {},
|
||||
/* .cum_orig_prob = */ 1.0f,
|
||||
/* .cum_cur_p = */ 1.0f,
|
||||
/* .max_xform_logit = */ -INFINITY,
|
||||
/* .cum_probs = */ {},
|
||||
/* .target = */ target,
|
||||
/* .decay = */ clamped_decay,
|
||||
/* .updt_w_cur = */ updt_w_cur,
|
||||
/* .rng = */ std::mt19937(seed),
|
||||
/* .weighted_sum = */ target / (1.0f - clamped_decay),
|
||||
/* .total_weight = */ 1.0f / (1.0f - clamped_decay),
|
||||
/* .orig_prob = */ {},
|
||||
/* .cum_orig_prob = */ 1.0f,
|
||||
/* .cum_cur_p = */ 1.0f,
|
||||
/* .max_xform_logit = */ -INFINITY,
|
||||
/* .cum_probs = */ {},
|
||||
/* .recd_weighted_sum = */ target / (1.0f - clamped_decay),
|
||||
/* .recd_total_weight = */ 1.0f / (1.0f - clamped_decay),
|
||||
};
|
||||
result->orig_prob.resize(n_vocab);
|
||||
return result;
|
||||
|
||||
Reference in New Issue
Block a user