diff --git a/common/sampling.cpp b/common/sampling.cpp index 1e9131c4..bbe8b37a 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -142,6 +142,16 @@ void common_sampler_reset(common_sampler * ctx) { llama_sampler_dry_reset(ctx->smpl); } +void common_sampler_review(common_sampler * ctx) { + const bool record = ctx->record_samplers; + const bool rewind = ctx->rewind_samplers; + + llama_review_adaptive_p(ctx->adapt_p_ctx, record, rewind); + + ctx->record_samplers = false; + ctx->rewind_samplers = false; +} + void llama_sampling_set_rng_seed(struct common_sampler * ctx, uint32_t seed) { if (seed == LLAMA_DEFAULT_SEED) { seed = std::random_device{}(); diff --git a/common/sampling.h b/common/sampling.h index cc5b7155..27bf61db 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -134,6 +134,9 @@ struct common_sampler { llama_token_data_array cur_p; // current candidates std::mt19937 rng; + + bool record_samplers = false; // record current state for stateful samplers + bool rewind_samplers = false; // rewind stateful samplers to last recorded }; @@ -148,6 +151,11 @@ void common_sampler_free(struct common_sampler * ctx); // - reset grammar void common_sampler_reset(common_sampler * ctx); +// Review stateful samplers +// | record current state for rewinding +// | rewind to last recorded state +void common_sampler_review(common_sampler * ctx); + // Set the sampler seed void llama_sampling_set_rng_seed(struct common_sampler * ctx, uint32_t seed); diff --git a/examples/server/server-context.cpp b/examples/server/server-context.cpp index fe9286e9..73eba614 100644 --- a/examples/server/server-context.cpp +++ b/examples/server/server-context.cpp @@ -3020,6 +3020,8 @@ void server_context::speculative_decoding_accept() { } else { buffer_and_check_string_ban(slot, result); } + + common_sampler_review(slot.ctx_sampling); } SLT_DBG(slot, "accepted %d/%d draft tokens, new n_tokens = %d\n", (int)ids.size() - 1, (int)slot.drafted.size(), slot.n_past); 
LOG_VERBOSE("speculative decoding result", { @@ -3135,6 +3137,7 @@ void server_context::buffer_and_check_string_ban(server_slot & slot, completion_ if (n_rewind > 0 && (slot.rewind_count <20 || slot.rewind_count <= 2 * slot.ban_phrases.size())) { rewind_context(slot, n_rewind); slot.rewind_status = true; + slot.ctx_sampling->rewind_samplers = true; } else if (send_result) { slot.rewind_status = false; @@ -3147,6 +3150,7 @@ void server_context::buffer_and_check_string_ban(server_slot & slot, completion_ // send 1 token send_token_results(slot.token_buffer, slot, 1); } + slot.ctx_sampling->record_samplers = true; } else { // buffer the result @@ -3264,6 +3268,8 @@ void server_context::process_batch_tokens(int32_t & n_batch) { buffer_and_check_string_ban(slot, result); } + common_sampler_review(slot.ctx_sampling); + slot.i_batch = -1; } diff --git a/include/llama.h b/include/llama.h index b7998a15..b8cbde9b 100644 --- a/include/llama.h +++ b/include/llama.h @@ -1382,6 +1382,8 @@ LLAMA_API struct llama_grammar* llama_sampler_init_grammar_lazy_patterns( llama_token_data_array * candidates, struct llama_sampler_adaptive_p * adapt_p_ctx); + LLAMA_API void llama_review_adaptive_p(struct llama_sampler_adaptive_p * adapt_p_ctx, const bool record, const bool rewind); + /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 629a4f83..c442f356 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -1053,10 +1053,27 @@ struct llama_sampler_dry* llama_sampler_init_dry_impl(const struct llama_vocab& // adaptive p +void llama_review_adaptive_p_impl(llama_sampler_adaptive_p * adapt_p_ctx, const bool record, const bool rewind) { + if (record && rewind) { + LLAMA_LOG_WARN("%s: record AND rewind is invalid\n", __func__); + return; + } + if (record) { + adapt_p_ctx->recd_weighted_sum = adapt_p_ctx->weighted_sum; + adapt_p_ctx->recd_total_weight = adapt_p_ctx->total_weight; + return; + } + if (rewind) { + adapt_p_ctx->weighted_sum = adapt_p_ctx->recd_weighted_sum; + adapt_p_ctx->total_weight = adapt_p_ctx->recd_total_weight; + return; + } +} + llama_token llama_sample_token_adaptive_p_impl( - struct llama_sampling * smpl, - llama_token_data_array * candidates, - struct llama_sampler_adaptive_p * adapt_p_ctx) { + struct llama_sampling * smpl, + llama_token_data_array * candidates, + struct llama_sampler_adaptive_p * adapt_p_ctx) { GGML_ASSERT(candidates->size > 0); const int64_t t_start_sample_us = ggml_time_us(); @@ -1082,8 +1099,8 @@ llama_token llama_sample_token_adaptive_p_impl( // update history const float update_prob = ctx->updt_w_cur - ? candidates->data[idx].p / ctx->cum_cur_p - : ctx->orig_prob[id] / ctx->cum_orig_prob; + ? 
candidates->data[idx].p / ctx->cum_cur_p + : ctx->orig_prob[id] / ctx->cum_orig_prob; if (update_prob > 0) { ctx->weighted_sum = ctx->decay * ctx->weighted_sum + update_prob; ctx->total_weight = ctx->decay * ctx->total_weight + 1.0f; @@ -1186,17 +1203,19 @@ struct llama_sampler_adaptive_p * llama_init_adaptive_p_impl(int n_vocab, GGML_ASSERT(n_vocab > 0); const float clamped_decay = std::clamp(decay, 0.0f, 0.99f); auto result = new llama_sampler_adaptive_p { - /* .target = */ target, - /* .decay = */ clamped_decay, - /* .updt_w_cur = */ updt_w_cur, - /* .rng = */ std::mt19937(seed), - /* .weighted_sum = */ target / (1.0f - clamped_decay), - /* .total_weight = */ 1.0f / (1.0f - clamped_decay), - /* .orig_prob = */ {}, - /* .cum_orig_prob = */ 1.0f, - /* .cum_cur_p = */ 1.0f, - /* .max_xform_logit = */ -INFINITY, - /* .cum_probs = */ {}, + /* .target = */ target, + /* .decay = */ clamped_decay, + /* .updt_w_cur = */ updt_w_cur, + /* .rng = */ std::mt19937(seed), + /* .weighted_sum = */ target / (1.0f - clamped_decay), + /* .total_weight = */ 1.0f / (1.0f - clamped_decay), + /* .orig_prob = */ {}, + /* .cum_orig_prob = */ 1.0f, + /* .cum_cur_p = */ 1.0f, + /* .max_xform_logit = */ -INFINITY, + /* .cum_probs = */ {}, + /* .recd_weighted_sum = */ target / (1.0f - clamped_decay), + /* .recd_total_weight = */ 1.0f / (1.0f - clamped_decay), }; result->orig_prob.resize(n_vocab); return result; diff --git a/src/llama-sampling.h b/src/llama-sampling.h index b5b33869..0a52cca5 100644 --- a/src/llama-sampling.h +++ b/src/llama-sampling.h @@ -83,6 +83,10 @@ struct llama_sampler_adaptive_p { // first referenced in sample_token std::vector cum_probs; // cumulative probability distribution + + // recorded states for rewinding + float recd_weighted_sum; + float recd_total_weight; }; struct llama_sampler_adaptive_p * llama_init_adaptive_p_impl(int n_vocab, @@ -101,6 +105,8 @@ void llama_sample_adaptive_p_impl( llama_token_data_array * candidates, struct llama_sampler_adaptive_p * 
adapt_p_ctx); +void llama_review_adaptive_p_impl(llama_sampler_adaptive_p * adapt_p_ctx, const bool record, const bool rewind); + void llama_sample_repetition_penalties_impl( struct llama_sampling * smpl, diff --git a/src/llama.cpp b/src/llama.cpp index ba042199..0be9df8a 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -8160,6 +8160,10 @@ struct llama_sampler_adaptive_p * llama_init_adaptive_p(int n_vocab, const float return llama_init_adaptive_p_impl(n_vocab, target, decay, updt_w_cur, seed); } +void llama_review_adaptive_p(struct llama_sampler_adaptive_p * adapt_p_ctx, const bool record, const bool rewind) { + llama_review_adaptive_p_impl(adapt_p_ctx, record, rewind); +} + int llama_split_prefix(char * dest, size_t maxlen, const char * split_path, int split_no, int split_count) { std::string str_split_path(split_path);