mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-03-03 02:20:01 +00:00
Fix adaptive p sampler bug with string ban (#1287)
* adaptive p: update internal state only if not rewinding * adaptive p: conditional update for speculative decoding * adaptive p: refactor to rewind instead of update * adaptive p fix: better comments * fix rewind check * add record to handle multi-token rewind * better comment
This commit is contained in:
@@ -142,6 +142,16 @@ void common_sampler_reset(common_sampler * ctx) {
|
||||
llama_sampler_dry_reset(ctx->smpl);
|
||||
}
|
||||
|
||||
// Forward any pending record/rewind request to the stateful samplers,
// then consume both flags so the request is acted on at most once.
void common_sampler_review(common_sampler * ctx) {
    llama_review_adaptive_p(ctx->adapt_p_ctx, ctx->record_samplers, ctx->rewind_samplers);

    // One-shot flags: clear them regardless of what the review did.
    ctx->record_samplers = false;
    ctx->rewind_samplers = false;
}
|
||||
|
||||
void llama_sampling_set_rng_seed(struct common_sampler * ctx, uint32_t seed) {
|
||||
if (seed == LLAMA_DEFAULT_SEED) {
|
||||
seed = std::random_device{}();
|
||||
|
||||
@@ -134,6 +134,9 @@ struct common_sampler {
|
||||
llama_token_data_array cur_p; // current candidates
|
||||
|
||||
std::mt19937 rng;
|
||||
|
||||
bool record_samplers = false; // record current state for stateful samplers
|
||||
bool rewind_samplers = false; // rewind stateful samplers to last recorded
|
||||
};
|
||||
|
||||
|
||||
@@ -148,6 +151,11 @@ void common_sampler_free(struct common_sampler * ctx);
|
||||
// - reset grammar
|
||||
void common_sampler_reset(common_sampler * ctx);
|
||||
|
||||
// Review stateful samplers
|
||||
// | record current state for rewinding
|
||||
// | rewind to last recorded state
|
||||
void common_sampler_review(common_sampler * ctx);
|
||||
|
||||
// Set the sampler seed
|
||||
void llama_sampling_set_rng_seed(struct common_sampler * ctx, uint32_t seed);
|
||||
|
||||
|
||||
@@ -3020,6 +3020,8 @@ void server_context::speculative_decoding_accept() {
|
||||
} else {
|
||||
buffer_and_check_string_ban(slot, result);
|
||||
}
|
||||
|
||||
common_sampler_review(slot.ctx_sampling);
|
||||
}
|
||||
SLT_DBG(slot, "accepted %d/%d draft tokens, new n_tokens = %d\n", (int)ids.size() - 1, (int)slot.drafted.size(), slot.n_past);
|
||||
LOG_VERBOSE("speculative decoding result", {
|
||||
@@ -3135,6 +3137,7 @@ void server_context::buffer_and_check_string_ban(server_slot & slot, completion_
|
||||
if (n_rewind > 0 && (slot.rewind_count <20 || slot.rewind_count <= 2 * slot.ban_phrases.size())) {
|
||||
rewind_context(slot, n_rewind);
|
||||
slot.rewind_status = true;
|
||||
slot.ctx_sampling->rewind_samplers = true;
|
||||
}
|
||||
else if (send_result) {
|
||||
slot.rewind_status = false;
|
||||
@@ -3147,6 +3150,7 @@ void server_context::buffer_and_check_string_ban(server_slot & slot, completion_
|
||||
// send 1 token
|
||||
send_token_results(slot.token_buffer, slot, 1);
|
||||
}
|
||||
slot.ctx_sampling->record_samplers = true;
|
||||
}
|
||||
else {
|
||||
// buffer the result
|
||||
@@ -3264,6 +3268,8 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
|
||||
buffer_and_check_string_ban(slot, result);
|
||||
}
|
||||
|
||||
common_sampler_review(slot.ctx_sampling);
|
||||
|
||||
slot.i_batch = -1;
|
||||
}
|
||||
|
||||
|
||||
@@ -1382,6 +1382,8 @@ LLAMA_API struct llama_grammar* llama_sampler_init_grammar_lazy_patterns(
|
||||
llama_token_data_array * candidates,
|
||||
struct llama_sampler_adaptive_p * adapt_p_ctx);
|
||||
|
||||
void llama_review_adaptive_p(struct llama_sampler_adaptive_p * adapt_p_ctx, const bool record, const bool rewind);
|
||||
|
||||
|
||||
/// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
|
||||
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
|
||||
|
||||
@@ -1053,10 +1053,27 @@ struct llama_sampler_dry* llama_sampler_init_dry_impl(const struct llama_vocab&
|
||||
|
||||
// adaptive p
|
||||
|
||||
// Record or rewind the adaptive-p sampler's running statistics.
// The two operations are mutually exclusive: asking for both at once is
// rejected with a warning and the state is left untouched.
void llama_review_adaptive_p_impl(llama_sampler_adaptive_p * adapt_p_ctx, const bool record, const bool rewind) {
    if (record && rewind) {
        LLAMA_LOG_WARN("%s: record AND rewind is invalid\n", __func__);
    } else if (record) {
        // Snapshot the running statistics so a later rewind can restore them.
        adapt_p_ctx->recd_weighted_sum = adapt_p_ctx->weighted_sum;
        adapt_p_ctx->recd_total_weight = adapt_p_ctx->total_weight;
    } else if (rewind) {
        // Restore the statistics captured by the most recent record.
        adapt_p_ctx->weighted_sum = adapt_p_ctx->recd_weighted_sum;
        adapt_p_ctx->total_weight = adapt_p_ctx->recd_total_weight;
    }
    // Neither flag set: nothing to do.
}
|
||||
|
||||
llama_token llama_sample_token_adaptive_p_impl(
|
||||
struct llama_sampling * smpl,
|
||||
llama_token_data_array * candidates,
|
||||
struct llama_sampler_adaptive_p * adapt_p_ctx) {
|
||||
struct llama_sampling * smpl,
|
||||
llama_token_data_array * candidates,
|
||||
struct llama_sampler_adaptive_p * adapt_p_ctx) {
|
||||
GGML_ASSERT(candidates->size > 0);
|
||||
const int64_t t_start_sample_us = ggml_time_us();
|
||||
|
||||
@@ -1082,8 +1099,8 @@ llama_token llama_sample_token_adaptive_p_impl(
|
||||
|
||||
// update history
|
||||
const float update_prob = ctx->updt_w_cur
|
||||
? candidates->data[idx].p / ctx->cum_cur_p
|
||||
: ctx->orig_prob[id] / ctx->cum_orig_prob;
|
||||
? candidates->data[idx].p / ctx->cum_cur_p
|
||||
: ctx->orig_prob[id] / ctx->cum_orig_prob;
|
||||
if (update_prob > 0) {
|
||||
ctx->weighted_sum = ctx->decay * ctx->weighted_sum + update_prob;
|
||||
ctx->total_weight = ctx->decay * ctx->total_weight + 1.0f;
|
||||
@@ -1186,17 +1203,19 @@ struct llama_sampler_adaptive_p * llama_init_adaptive_p_impl(int n_vocab,
|
||||
GGML_ASSERT(n_vocab > 0);
|
||||
const float clamped_decay = std::clamp(decay, 0.0f, 0.99f);
|
||||
auto result = new llama_sampler_adaptive_p {
|
||||
/* .target = */ target,
|
||||
/* .decay = */ clamped_decay,
|
||||
/* .updt_w_cur = */ updt_w_cur,
|
||||
/* .rng = */ std::mt19937(seed),
|
||||
/* .weighted_sum = */ target / (1.0f - clamped_decay),
|
||||
/* .total_weight = */ 1.0f / (1.0f - clamped_decay),
|
||||
/* .orig_prob = */ {},
|
||||
/* .cum_orig_prob = */ 1.0f,
|
||||
/* .cum_cur_p = */ 1.0f,
|
||||
/* .max_xform_logit = */ -INFINITY,
|
||||
/* .cum_probs = */ {},
|
||||
/* .target = */ target,
|
||||
/* .decay = */ clamped_decay,
|
||||
/* .updt_w_cur = */ updt_w_cur,
|
||||
/* .rng = */ std::mt19937(seed),
|
||||
/* .weighted_sum = */ target / (1.0f - clamped_decay),
|
||||
/* .total_weight = */ 1.0f / (1.0f - clamped_decay),
|
||||
/* .orig_prob = */ {},
|
||||
/* .cum_orig_prob = */ 1.0f,
|
||||
/* .cum_cur_p = */ 1.0f,
|
||||
/* .max_xform_logit = */ -INFINITY,
|
||||
/* .cum_probs = */ {},
|
||||
/* .recd_weighted_sum = */ target / (1.0f - clamped_decay),
|
||||
/* .recd_total_weight = */ 1.0f / (1.0f - clamped_decay),
|
||||
};
|
||||
result->orig_prob.resize(n_vocab);
|
||||
return result;
|
||||
|
||||
@@ -83,6 +83,10 @@ struct llama_sampler_adaptive_p {
|
||||
|
||||
// first referenced in sample_token
|
||||
std::vector<float> cum_probs; // cumulative probability distribution
|
||||
|
||||
// recorded states for rewinding
|
||||
float recd_weighted_sum;
|
||||
float recd_total_weight;
|
||||
};
|
||||
|
||||
struct llama_sampler_adaptive_p * llama_init_adaptive_p_impl(int n_vocab,
|
||||
@@ -101,6 +105,8 @@ void llama_sample_adaptive_p_impl(
|
||||
llama_token_data_array * candidates,
|
||||
struct llama_sampler_adaptive_p * adapt_p_ctx);
|
||||
|
||||
void llama_review_adaptive_p_impl(llama_sampler_adaptive_p * adapt_p_ctx, const bool record, const bool rewind);
|
||||
|
||||
|
||||
void llama_sample_repetition_penalties_impl(
|
||||
struct llama_sampling * smpl,
|
||||
|
||||
@@ -8160,6 +8160,10 @@ struct llama_sampler_adaptive_p * llama_init_adaptive_p(int n_vocab, const float
|
||||
return llama_init_adaptive_p_impl(n_vocab, target, decay, updt_w_cur, seed);
|
||||
}
|
||||
|
||||
// Public API entry point: record or rewind the adaptive-p sampler state.
// Thin forwarding wrapper around the internal implementation; see
// llama_review_adaptive_p_impl for the record/rewind semantics.
void llama_review_adaptive_p(struct llama_sampler_adaptive_p * adapt_p_ctx, const bool record, const bool rewind) {
    llama_review_adaptive_p_impl(adapt_p_ctx, record, rewind);
}
|
||||
|
||||
|
||||
int llama_split_prefix(char * dest, size_t maxlen, const char * split_path, int split_no, int split_count) {
|
||||
std::string str_split_path(split_path);
|
||||
|
||||
Reference in New Issue
Block a user