mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-03-06 03:50:08 +00:00
fix adaptive p sampler rewinding too far back (#1359)
* fix adaptive p sampler rewinding too far back
* update comments
* correct default value for total_weight, more comments
* new variables/names
* update comment for n_rewind
* move null pointer check back to common_sampler_review()
* refactor weighted_sum and total_weight to vector<pair>, better boundary check in llama_review_adaptive_p_impl()
This commit is contained in:
@@ -106,6 +106,8 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
||||
}
|
||||
}
|
||||
|
||||
result->n_rewind = -1;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
@@ -143,16 +145,12 @@ void common_sampler_reset(common_sampler * ctx) {
|
||||
}
|
||||
|
||||
void common_sampler_review(common_sampler * ctx) {
|
||||
if (!ctx->adapt_p_ctx) {
|
||||
return;
|
||||
const int32_t n_rewind = ctx->n_rewind;
|
||||
|
||||
// add stateful samplers here
|
||||
if (ctx->adapt_p_ctx != nullptr) {
|
||||
llama_review_adaptive_p(ctx->adapt_p_ctx, n_rewind);
|
||||
}
|
||||
const bool record = ctx->record_samplers;
|
||||
const bool rewind = ctx->rewind_samplers;
|
||||
|
||||
llama_review_adaptive_p(ctx->adapt_p_ctx, record, rewind);
|
||||
|
||||
ctx->record_samplers = false;
|
||||
ctx->rewind_samplers = false;
|
||||
}
|
||||
|
||||
void llama_sampling_set_rng_seed(struct common_sampler * ctx, uint32_t seed) {
|
||||
|
||||
@@ -135,8 +135,7 @@ struct common_sampler {
|
||||
|
||||
std::mt19937 rng;
|
||||
|
||||
bool record_samplers = false; // record current state for stateful samplers
|
||||
bool rewind_samplers = false; // rewind stateful samplers to last recorded
|
||||
int32_t n_rewind; // number of tokens to rewind
|
||||
};
|
||||
|
||||
|
||||
@@ -152,8 +151,7 @@ void common_sampler_free(struct common_sampler * ctx);
|
||||
void common_sampler_reset(common_sampler * ctx);
|
||||
|
||||
// Review stateful samplers
|
||||
// | record current state for rewinding
|
||||
// | rewind to last recorded state
|
||||
// - rewind internal states (maybe)
|
||||
void common_sampler_review(common_sampler * ctx);
|
||||
|
||||
// Set the sampler seed
|
||||
|
||||
@@ -3332,6 +3332,7 @@ void server_context::buffer_and_check_string_ban(server_slot & slot, completion_
|
||||
bool next_token = has_next_token(result, slot);
|
||||
bool send_result = slot.token_buffer.size() >= slot.n_buffer || !next_token;
|
||||
int32_t n_rewind = 0;
|
||||
bool sent_results = false;
|
||||
// don't restore if last time was also rewind
|
||||
if (!slot.rewind_status) {
|
||||
slot.ctx_sampling->params.logit_bias = slot.logit_bias; // restore logit bias
|
||||
@@ -3343,7 +3344,6 @@ void server_context::buffer_and_check_string_ban(server_slot & slot, completion_
|
||||
if (n_rewind > 0 && (slot.rewind_count <20 || slot.rewind_count <= 2 * slot.ban_phrases.size())) {
|
||||
rewind_context(slot, n_rewind);
|
||||
slot.rewind_status = true;
|
||||
slot.ctx_sampling->rewind_samplers = true;
|
||||
}
|
||||
else if (send_result) {
|
||||
slot.rewind_status = false;
|
||||
@@ -3356,12 +3356,14 @@ void server_context::buffer_and_check_string_ban(server_slot & slot, completion_
|
||||
// send 1 token
|
||||
send_token_results(slot.token_buffer, slot, 1);
|
||||
}
|
||||
slot.ctx_sampling->record_samplers = true;
|
||||
sent_results = true;
|
||||
}
|
||||
else {
|
||||
// buffer the result
|
||||
slot.sampled = result.tok; // for common batch add
|
||||
}
|
||||
|
||||
slot.ctx_sampling->n_rewind = sent_results ? -1 : n_rewind;
|
||||
}
|
||||
|
||||
void server_context::process_batch_tokens(int32_t & n_batch) {
|
||||
|
||||
@@ -1415,7 +1415,7 @@ LLAMA_API struct llama_grammar* llama_sampler_init_grammar_lazy_patterns(
|
||||
llama_token_data_array * candidates,
|
||||
struct llama_sampler_adaptive_p * adapt_p_ctx);
|
||||
|
||||
void llama_review_adaptive_p(struct llama_sampler_adaptive_p * adapt_p_ctx, const bool record, const bool rewind);
|
||||
void llama_review_adaptive_p(struct llama_sampler_adaptive_p * adapt_p_ctx, const int32_t n_rewind);
|
||||
|
||||
|
||||
/// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
|
||||
|
||||
@@ -1053,20 +1053,48 @@ struct llama_sampler_dry* llama_sampler_init_dry_impl(const struct llama_vocab&
|
||||
|
||||
// adaptive p
|
||||
|
||||
void llama_review_adaptive_p_impl(llama_sampler_adaptive_p * adapt_p_ctx, const bool record, const bool rewind) {
|
||||
if (record && rewind) {
|
||||
LLAMA_LOG_WARN("%s: record AND rewind is invalid\n", __func__);
|
||||
void llama_review_adaptive_p_impl(llama_sampler_adaptive_p * adapt_p_ctx, const int32_t n_rewind) {
|
||||
if ((n_rewind == 0) || (adapt_p_ctx->target < 0.0f)) {
|
||||
return;
|
||||
}
|
||||
if (record) {
|
||||
adapt_p_ctx->recd_weighted_sum = adapt_p_ctx->weighted_sum;
|
||||
adapt_p_ctx->recd_total_weight = adapt_p_ctx->total_weight;
|
||||
// auto & weighted_sum = adapt_p_ctx->weighted_sum;
|
||||
// auto & total_weight = adapt_p_ctx->total_weight;
|
||||
|
||||
const int32_t sz = adapt_p_ctx->history.size();
|
||||
if ((sz <= 0) || (sz <= n_rewind)) {
|
||||
// critically short history. reset to initial state
|
||||
LLAMA_LOG_WARN("%s: sz=%d, n_rewind=%d should not be possible\n", __func__, sz, n_rewind);
|
||||
adapt_p_ctx->history.clear();
|
||||
adapt_p_ctx->history.push_back({
|
||||
adapt_p_ctx->target / adapt_p_ctx->decay, // weighted_sum
|
||||
1.0f / adapt_p_ctx->decay }); // total_weight
|
||||
return;
|
||||
}
|
||||
if (rewind) {
|
||||
adapt_p_ctx->weighted_sum = adapt_p_ctx->recd_weighted_sum;
|
||||
adapt_p_ctx->total_weight = adapt_p_ctx->recd_total_weight;
|
||||
return;
|
||||
|
||||
if (n_rewind < 0) {
|
||||
// clear history except most recent
|
||||
adapt_p_ctx->history.front() = adapt_p_ctx->history.back();
|
||||
adapt_p_ctx->history.resize(1);
|
||||
} else {
|
||||
// rewind
|
||||
adapt_p_ctx->history.resize(sz - n_rewind);
|
||||
|
||||
// int32_t sz = weighted_sum.size() - n_rewind;
|
||||
// if (sz > 0) {
|
||||
// weighted_sum.resize(sz);
|
||||
// } else {
|
||||
// LLAMA_LOG_WARN("%s: n_rewind=%d, sz=%d should not be possible\n", __func__, n_rewind, sz);
|
||||
// weighted_sum.clear();
|
||||
// weighted_sum.push_back(adapt_p_ctx->target / adapt_p_ctx->decay); // set to default value
|
||||
// }
|
||||
// sz = total_weight.size() - n_rewind;
|
||||
// if (sz > 0) {
|
||||
// total_weight.resize(sz);
|
||||
// } else {
|
||||
// LLAMA_LOG_WARN("%s: n_rewind=%d, sz=%d should not be possible\n", __func__, n_rewind, sz);
|
||||
// total_weight.clear();
|
||||
// total_weight.push_back(1.0f / adapt_p_ctx->decay); // set to default value
|
||||
// }
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1102,8 +1130,11 @@ llama_token llama_sample_token_adaptive_p_impl(
|
||||
? candidates->data[idx].p / ctx->cum_cur_p
|
||||
: ctx->orig_prob[id] / ctx->cum_orig_prob;
|
||||
if (update_prob > 0) {
|
||||
ctx->weighted_sum = ctx->decay * ctx->weighted_sum + update_prob;
|
||||
ctx->total_weight = ctx->decay * ctx->total_weight + 1.0f;
|
||||
ctx->history.push_back({
|
||||
ctx->decay * ctx->history.back().first + update_prob, // weighted_sum
|
||||
ctx->decay * ctx->history.back().second + 1.0f }); // total_weight
|
||||
// ctx->weighted_sum.push_back(ctx->decay * ctx->weighted_sum.back() + update_prob);
|
||||
// ctx->total_weight.push_back(ctx->decay * ctx->total_weight.back() + 1.0f);
|
||||
}
|
||||
|
||||
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
@@ -1138,10 +1169,12 @@ void llama_sample_adaptive_p_impl(struct llama_sampling * ctx, llama_token_data_
|
||||
adapt_p_ctx->cum_cur_p = cum_sum;
|
||||
|
||||
// compute adapted target probability
|
||||
const float weighted_sum = adapt_p_ctx->history.back().first;
|
||||
const float total_weight = adapt_p_ctx->history.back().second;
|
||||
const float target = std::clamp(adapt_p_ctx->target, 0.0f, 1.0f);
|
||||
const float adapted_target = std::clamp(adapt_p_ctx->total_weight == 0.0f
|
||||
const float adapted_target = std::clamp(total_weight == 0.0f
|
||||
? target
|
||||
: 2.0f * target - (adapt_p_ctx->weighted_sum / adapt_p_ctx->total_weight),
|
||||
: 2.0f * target - (weighted_sum / total_weight),
|
||||
0.0f, 1.0f);
|
||||
|
||||
// transformation constants
|
||||
@@ -1202,16 +1235,20 @@ struct llama_sampler_adaptive_p * llama_init_adaptive_p_impl(int n_vocab,
|
||||
/* .decay = */ clamped_decay,
|
||||
/* .updt_w_cur = */ updt_w_cur,
|
||||
/* .rng = */ std::mt19937(seed),
|
||||
/* .weighted_sum = */ target / (1.0f - clamped_decay),
|
||||
/* .total_weight = */ 1.0f / (1.0f - clamped_decay),
|
||||
// /* .weighted_sum = */ {},
|
||||
// /* .total_weight = */ {},
|
||||
/* .history = */ {},
|
||||
/* .orig_prob = */ {},
|
||||
/* .cum_orig_prob = */ 1.0f,
|
||||
/* .cum_cur_p = */ 1.0f,
|
||||
/* .max_xform_logit = */ -INFINITY,
|
||||
/* .cum_probs = */ {},
|
||||
/* .recd_weighted_sum = */ target / (1.0f - clamped_decay),
|
||||
/* .recd_total_weight = */ 1.0f / (1.0f - clamped_decay),
|
||||
};
|
||||
// result->weighted_sum.push_back(target / (1.0f - clamped_decay));
|
||||
// result->total_weight.push_back(1.0f / (1.0f - clamped_decay));
|
||||
result->history.push_back({
|
||||
target / (1.0f - clamped_decay), // weighted_sum
|
||||
1.0f / (1.0f - clamped_decay) }); // total_weight
|
||||
result->orig_prob.resize(n_vocab);
|
||||
return result;
|
||||
}
|
||||
|
||||
@@ -70,8 +70,9 @@ struct llama_sampler_adaptive_p {
|
||||
const float decay; // EMA decay; history ≈ 1/(1-decay) tokens (0.0 - 0.99)
|
||||
const bool updt_w_cur; // false=original, true=current
|
||||
std::mt19937 rng; // RNG
|
||||
float weighted_sum; // sum(p_n * decay^N)
|
||||
float total_weight; // sum(decay^i), converges to 1/(1-decay)
|
||||
// std::vector<float> weighted_sum; // [0] = sum(p_n * decay^N)
|
||||
// std::vector<float> total_weight; // [0] = sum(decay^i), converges to 1/(1-decay)
|
||||
std::vector<std::pair<float, float>> history; // <weighted_sum, total_weight>
|
||||
|
||||
// first referenced in prep
|
||||
std::vector<float> orig_prob; // for storing the original probabilities
|
||||
@@ -83,10 +84,6 @@ struct llama_sampler_adaptive_p {
|
||||
|
||||
// first referenced in sample_token
|
||||
std::vector<float> cum_probs; // cumulative probability distribution
|
||||
|
||||
// recorded states for rewinding
|
||||
float recd_weighted_sum;
|
||||
float recd_total_weight;
|
||||
};
|
||||
|
||||
struct llama_sampler_adaptive_p * llama_init_adaptive_p_impl(int n_vocab,
|
||||
@@ -105,7 +102,7 @@ void llama_sample_adaptive_p_impl(
|
||||
llama_token_data_array * candidates,
|
||||
struct llama_sampler_adaptive_p * adapt_p_ctx);
|
||||
|
||||
void llama_review_adaptive_p_impl(llama_sampler_adaptive_p * adapt_p_ctx, const bool record, const bool rewind);
|
||||
void llama_review_adaptive_p_impl(llama_sampler_adaptive_p * adapt_p_ctx, const int32_t n_rewind);
|
||||
|
||||
|
||||
void llama_sample_repetition_penalties_impl(
|
||||
|
||||
@@ -8304,8 +8304,8 @@ struct llama_sampler_adaptive_p * llama_init_adaptive_p(int n_vocab, const float
|
||||
return llama_init_adaptive_p_impl(n_vocab, target, decay, updt_w_cur, seed);
|
||||
}
|
||||
|
||||
void llama_review_adaptive_p(struct llama_sampler_adaptive_p * adapt_p_ctx, const bool record, const bool rewind) {
|
||||
llama_review_adaptive_p_impl(adapt_p_ctx, record, rewind);
|
||||
void llama_review_adaptive_p(struct llama_sampler_adaptive_p * adapt_p_ctx, const int32_t n_rewind) {
|
||||
llama_review_adaptive_p_impl(adapt_p_ctx, n_rewind);
|
||||
}
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user