diff --git a/common/common.cpp b/common/common.cpp index 8dcb9bfa..f3fe305f 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -925,6 +925,16 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa } return true; } + if (arg == "--adaptive-target") { + CHECK_ARG + sparams.adaptive_target = std::stof(argv[i]); + return true; + } + if (arg == "--adaptive-decay") { + CHECK_ARG + sparams.adaptive_decay = std::stof(argv[i]); + return true; + } if (arg == "--spec-replace") { CHECK_ARG std::string target = argv[i]; @@ -2201,6 +2211,8 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "*", " --xtc-probability p", "xtc probability (default: %.1f, 0.0 = disabled)", (double)sparams.xtc_probability }); options.push_back({ "*", " --xtc-threshold t", "xtc threshold (default: %.1f, >0.5 = disabled)", (double)sparams.xtc_threshold}); options.push_back({ "*", " --top-n-sigma t", "top-n-sigma parmeter (default: %.1f, 0.0 = disabled)", (double)sparams.top_n_sigma}); + options.push_back({ "*", " --adaptive-target", "adaptive-p sampling: (default: %.2f, <0.0 = disabled)", (double)sparams.adaptive_target}); + options.push_back({ "*", " --adaptive-decay", "adaptive-p sampling: (default: %.2f)", (double)sparams.adaptive_decay}); options.push_back({ "*", " -l TOKEN_ID(+/-)BIAS", "modifies the likelihood of token appearing in the completion,\n" "i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',\n" "or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'" }); @@ -4174,6 +4186,8 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l fprintf(stream, "top_p: %f # default: 0.95\n", sparams.top_p); fprintf(stream, "min_p: %f # default: 0.0\n", sparams.min_p); fprintf(stream, "typical_p: %f # default: 1.0\n", sparams.typical_p); + fprintf(stream, "adaptive_target: %f # default: -1.0\n", sparams.adaptive_target); + fprintf(stream, "adaptive_decay: %f # default: 0.9\n", sparams.adaptive_decay); fprintf(stream, "verbose_prompt: %s # default: false\n", params.verbose_prompt ? "true" : "false"); fprintf(stream, "display_prompt: %s # default: true\n", params.display_prompt ? "true" : "false"); } diff --git a/common/sampling.cpp b/common/sampling.cpp index 769fc06c..6709ff41 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -99,7 +99,7 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_vocab* vo result->n_valid = 0; } result->grammar = grmr; - // init DRY + llama_sampling_set_rng_seed(result, params.seed); for (const auto& cnstr : params.samplers_sequence) { switch (cnstr) @@ -116,11 +116,16 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_vocab* vo break; } + case llama_sampler_type::ADAPTIVE_P: + { + result->adapt_p_ctx=llama_sampler_init_adaptive_p(params.adaptive_target, params.adaptive_decay, result->rng()); + break; + } default: break; } } - llama_sampling_set_rng_seed(result, params.seed); + return result; } @@ -247,11 +252,13 @@ std::string llama_sampling_print(const llama_sampling_params & params) { "\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n" "\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, typical_p = %.3f, temp = %.3f\n" "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f\n" - "\txtc_probability = %.3f, xtc_threshold = %.3f, top_n_sigma = %.3f", + "\txtc_probability = %.3f, xtc_threshold = %.3f, top_n_sigma = %.3f\n" + "\tadaptive_target = %.2f, adaptive_decay = %.2f", params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present, params.top_k, params.tfs_z, params.top_p, params.min_p, params.typical_p, params.temp, params.mirostat, params.mirostat_eta, params.mirostat_tau, - params.xtc_probability, params.xtc_threshold, params.top_n_sigma); + params.xtc_probability, params.xtc_threshold, params.top_n_sigma, + params.adaptive_target, params.adaptive_decay); return std::string(result); } @@ -283,6 +290,7 @@ std::string llama_sampling_type_to_str(llama_sampler_type sampler_type) { case llama_sampler_type::TEMPERATURE: return "temperature"; case llama_sampler_type::XTC : return "xtc"; case llama_sampler_type::TOP_N_SIGMA: return "top_n_sigma"; + case llama_sampler_type::ADAPTIVE_P : return "adaptive_p"; default : return ""; } } @@ -297,7 +305,8 @@ std::vector llama_sampling_types_from_names(const std::vecto {"tfs_z", llama_sampler_type::TFS_Z}, {"xtc", llama_sampler_type::XTC}, {"top_n_sigma", llama_sampler_type::TOP_N_SIGMA}, - {"temperature", llama_sampler_type::TEMPERATURE} + {"temperature", llama_sampler_type::TEMPERATURE}, + {"adaptive_p", llama_sampler_type::ADAPTIVE_P}, }; // since samplers names are written multiple ways @@ -314,7 +323,8 @@ std::vector llama_sampling_types_from_names(const std::vecto {"tfs", llama_sampler_type::TFS_Z}, {"xtc", llama_sampler_type::XTC}, {"top-n-sigma", llama_sampler_type::TOP_N_SIGMA}, - {"temp", llama_sampler_type::TEMPERATURE} + {"temp", llama_sampler_type::TEMPERATURE}, + {"adaptive-p", llama_sampler_type::ADAPTIVE_P}, }; std::vector sampler_types; @@ -351,7 +361,8 @@ std::vector llama_sampling_types_from_chars(const std::strin {'f', llama_sampler_type::TFS_Z}, {'x', llama_sampler_type::XTC}, {'n', llama_sampler_type::TOP_N_SIGMA}, - {'t', llama_sampler_type::TEMPERATURE} + {'t', llama_sampler_type::TEMPERATURE}, + {'w', llama_sampler_type::ADAPTIVE_P}, }; std::vector sampler_types; @@ -405,6 +416,7 @@ static void sampler_queue( llama_sample_temp(ctx_main, &cur_p, temp); } break; + case llama_sampler_type::ADAPTIVE_P: llama_sample_adaptive_p(ctx_main, ctx_sampling->adapt_p_ctx, &cur_p); break; default : break; } } @@ -422,6 +434,7 @@ static llama_token llama_sampling_sample_impl( const int mirostat = params.mirostat; const float mirostat_tau = params.mirostat_tau; const float mirostat_eta = params.mirostat_eta; + const float adaptive_target = params.adaptive_target; std::vector original_logits; auto cur_p = llama_sampling_prepare(ctx_sampling, ctx_main, ctx_cfg, idx, /* apply_grammar= */ is_resampling, &original_logits); @@ -451,6 +464,17 @@ static llama_token llama_sampling_sample_impl( } else if (mirostat == 2) { llama_sample_temp(ctx_main, &cur_p, temp); id = llama_sample_token_mirostat_v2(ctx_main, &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling->mirostat_mu); + } else if (adaptive_target >= 0.0f) { + // adaptive p sampling + static thread_local std::vector orig_probs; + orig_probs.resize(cur_p.size); + + // store original probabilities + for (size_t ii = 0; ii < cur_p.size; ++ii) { + orig_probs[ii] = cur_p.data[ii].p; + } + sampler_queue(ctx_main, params, ctx_sampling, cur_p, std::max(1, params.min_keep)); + id = llama_sample_token_adaptive_p(ctx_main, &cur_p, ctx_sampling->adapt_p_ctx, orig_probs.data()); } else { // temperature sampling size_t min_keep = std::max(1, params.min_keep); diff --git a/common/sampling.h b/common/sampling.h index 544d98eb..718dae34 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -18,7 +18,8 @@ enum class llama_sampler_type : char { XTC = 'x', TOP_N_SIGMA = 'n', TYPICAL_P = 'y', - TEMPERATURE = 't' + TEMPERATURE = 't', + ADAPTIVE_P = 'w', }; enum common_grammar_trigger_type { @@ -66,6 +67,8 @@ typedef struct llama_sampling_params { float xtc_probability = 0.0f; // xtc probability float xtc_threshold = 1.0f; // xtc threshold, disabled if > 0.5 float top_n_sigma = 0.0f; // top-n-sigma + float adaptive_target = -1.0f; // select tokens near this probability (valid range 0.0 to 1.0; <0 = disabled) + float adaptive_decay = 0.90f; // decay rate for target adaptation over time. lower values -> faster but less stable adaptation. (valid range 0.0 to 1.0; ≤0 = no adaptation) bool penalize_nl = false; // consider newlines as a repeatable token uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context @@ -80,7 +83,8 @@ typedef struct llama_sampling_params { llama_sampler_type::MIN_P, llama_sampler_type::XTC, llama_sampler_type::TOP_N_SIGMA, - llama_sampler_type::TEMPERATURE + llama_sampler_type::TEMPERATURE, + llama_sampler_type::ADAPTIVE_P, }; @@ -118,6 +122,8 @@ struct llama_sampling_context { std::vector cur; llama_sampler_dry* smpl; + llama_sampler_adaptive_p * adapt_p_ctx; // adaptive p sampler + size_t n_valid; // Number of correct top tokens with correct probabilities. llama_token_data_array cur_p; // current candidates diff --git a/examples/server/server-context.cpp b/examples/server/server-context.cpp index a3f3155a..27d3c193 100644 --- a/examples/server/server-context.cpp +++ b/examples/server/server-context.cpp @@ -807,6 +807,8 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task) slot.sparams.mirostat = json_value(data, "mirostat", default_sparams.mirostat); slot.sparams.mirostat_tau = json_value(data, "mirostat_tau", default_sparams.mirostat_tau); slot.sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta); + slot.sparams.adaptive_target = json_value(data, "adaptive_target", default_sparams.adaptive_target); + slot.sparams.adaptive_decay = json_value(data, "adaptive_decay", default_sparams.adaptive_decay); slot.sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl); slot.params.n_keep = json_value(data, "n_keep", slot.params.n_keep); slot.params.n_discard = json_value(data, "n_discard", default_params.n_discard); @@ -1405,6 +1407,8 @@ json server_context::get_formated_generation(const server_slot& slot) const { {"mirostat", slot.sparams.mirostat}, {"mirostat_tau", slot.sparams.mirostat_tau}, {"mirostat_eta", slot.sparams.mirostat_eta}, + {"adaptive_target", slot.sparams.adaptive_target}, + {"adaptive_decay", slot.sparams.adaptive_decay}, {"penalize_nl", slot.sparams.penalize_nl}, {"stop", slot.params.antiprompt}, {"max_tokens", slot.params.n_predict}, // User configured n_predict diff --git a/include/llama.h b/include/llama.h index 16558c5b..8364a616 100644 --- a/include/llama.h +++ b/include/llama.h @@ -1380,6 +1380,21 @@ LLAMA_API struct llama_grammar* llama_sampler_init_grammar_lazy_patterns( /// @details DRY sampler, designed by p-e-w, as described in: https://github.com/oobabooga/text-generation-webui/pull/5677, porting Koboldcpp implementation authored by pi6am: https://github.com/LostRuins/koboldcpp/pull/982 + /// @details Adaptive p sampler initializer + /// @param target Select tokens near this probability (valid range 0.0 to 1.0; <0 = disabled) + /// @param decay Decay rate for target adaptation over time. lower values -> faster but less stable adaptation. (valid range 0.0 to 1.0; ≤0 = no adaptation) + LLAMA_API struct llama_sampler_adaptive_p * llama_sampler_init_adaptive_p( + const float target, + const float decay, + const uint32_t seed); + + /// @details Adaptive p sampler described in https://github.com/MrJackSpade/adaptive-p-docs/blob/main/README.md + void llama_sample_adaptive_p( + struct llama_context * ctx, + struct llama_sampler_adaptive_p * adapt_p_ctx, + llama_token_data_array * candidates); + + /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. @@ -1417,6 +1432,13 @@ LLAMA_API struct llama_grammar* llama_sampler_init_grammar_lazy_patterns( struct llama_context * ctx, llama_token_data_array * candidates); + /// @details Randonly selects a token from the candidates following adaptive p sampler. + llama_token llama_sample_token_adaptive_p( + struct llama_context * ctx, + llama_token_data_array * candidates, + struct llama_sampler_adaptive_p * adapt_p_ctx, + float * orig_probs); + // // Model split // diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index e7daa175..244c18d8 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -1033,6 +1033,111 @@ struct llama_sampler_dry* llama_sampler_init_dry_impl(const struct llama_vocab& } +// adaptive p + +llama_token llama_sample_token_adaptive_p_impl( + struct llama_sampling * smpl, + llama_token_data_array * candidates, + struct llama_sampler_adaptive_p * adapt_p_ctx, + float * orig_probs) +{ + GGML_ASSERT(candidates->size > 0); + const int64_t t_start_sample_us = ggml_time_us(); + + const size_t count = candidates->size; + adapt_p_ctx->probs.resize(count); + + // cumulative distribution + const float max_logit = adapt_p_ctx->max_logit; + float cum_prob = 0.0f; + for (size_t i = 0; i < count; ++i) { + cum_prob += expf(candidates->data[i].logit - max_logit); + adapt_p_ctx->probs[i] = cum_prob; + } + adapt_p_ctx->probs.back() += 1.0f; // safety margin in case rng() ~= rng.max() + + // find token with cum_prob > target_cum_prob + const float target_cum_prob = cum_prob * (float)adapt_p_ctx->rng() / (float)adapt_p_ctx->rng.max(); + auto iter = std::upper_bound(adapt_p_ctx->probs.begin(), adapt_p_ctx->probs.end(), target_cum_prob); + GGML_ASSERT(iter != adapt_p_ctx->probs.end()); + llama_token id = candidates->data[std::distance(adapt_p_ctx->probs.begin(), iter)].id; + + smpl->t_sample_us += ggml_time_us() - t_start_sample_us; + smpl->n_sample++; + + // update history with original probability of selected token + adapt_p_ctx->weighted_sum = adapt_p_ctx->decay * adapt_p_ctx->weighted_sum + orig_probs[id]; + adapt_p_ctx->total_weight = adapt_p_ctx->decay * adapt_p_ctx->total_weight + 1.0f; + + return id; +} + +void llama_sampler_adaptive_p_apply(struct llama_sampler_adaptive_p * adapt_p_ctx, llama_token_data_array * candidates) +{ + if (adapt_p_ctx->target < 0.0f) { + // sampler is disabled + llama_sample_softmax_impl(nullptr, candidates); + return; + } + + // incomplete softmax because final division can be fused + float max_l = candidates->data[0].logit; + for (size_t i = 1; i < candidates->size; ++i) { + max_l = std::max(max_l, candidates->data[i].logit); + } + float cum_sum = 0.0f; + for (size_t i = 0; i < candidates->size; ++i) { + const float p = expf(candidates->data[i].logit - max_l); + candidates->data[i].p = p; + cum_sum += p; + } + + // compute adapted target probability + const float target = std::clamp(adapt_p_ctx->target, 0.0f, 1.0f); + const float adapted_target = std::clamp(adapt_p_ctx->total_weight == 0.0f + ? target + : 2.0f * target - (adapt_p_ctx->weighted_sum / adapt_p_ctx->total_weight), + 0.0f, 1.0f); + + // transformation constants + static constexpr float peak_logit_value = 5.0f; + static constexpr float inv_width = 1.0f / 0.3f; + static constexpr float sharpness = 10.0f; + + const float fused_target = adapted_target * inv_width; + const float fused_width = inv_width / cum_sum; + + // quadratic near target for finite differentiation, transitioning to linear decay in tails + // unbounded negative logits suppress far-from-target tokens after softmax + float max_logit = -INFINITY; + for (size_t i = 0; i < candidates->size; ++i) { + const float dist = std::abs(candidates->data[i].p * fused_width - fused_target); + const float logit = peak_logit_value - sharpness * dist * dist / (1.0f + dist); + candidates->data[i].logit = logit; + max_logit = std::max(max_logit, logit); + } + candidates->sorted = false; + adapt_p_ctx->max_logit = max_logit; +} + +struct llama_sampler_adaptive_p * llama_sampler_init_adaptive_p_impl( + const float target, + const float decay, + const uint32_t seed) +{ + const float clamped_decay = std::clamp(decay, 0.0f, 0.99f); + return new llama_sampler_adaptive_p { + /* .target = */ target, + /* .decay = */ clamped_decay, + /* .rng = */ std::mt19937(seed), + /* .weighted_sum = */ target / (1.0f - clamped_decay), + /* .total_weight = */ 1.0f / (1.0f - clamped_decay), + /* .max_logit = */ 0.0f, + /* .probs = */ {}, + }; +} + + // grammar struct llama_sampler_grammar { diff --git a/src/llama-sampling.h b/src/llama-sampling.h index 855278e2..61228548 100644 --- a/src/llama-sampling.h +++ b/src/llama-sampling.h @@ -61,6 +61,24 @@ struct llama_sampler_dry * llama_sampler_init_dry_impl( void llama_sampler_dry_apply(struct llama_sampler_dry* smpl, llama_token_data_array* cur_p); +// maintains an exponential moving average of the *ORIGINAL* probabilities of selected tokens +// used to compute an adapted target at each sampling step. +// see llama.h for a full description of the sampler +struct llama_sampler_adaptive_p { + const float target; // target probability (0.0 - 1.0; negative = disabled) + const float decay; // EMA decay; history ≈ 1/(1-decay) tokens (0.0 - 0.99) + std::mt19937 rng; // RNG + float weighted_sum; // sum(p_n * decay^N) + float total_weight; // sum(decay^i), converges to 1/(1-decay) + float max_logit; // maximum logit found during transform + std::vector probs; // cumulative probabilities +}; + +void llama_sampler_adaptive_p_apply( + struct llama_sampler_adaptive_p * adapt_p_ctx, + llama_token_data_array * candidates); + +struct llama_sampler_adaptive_p * llama_sampler_init_adaptive_p_impl(const float target, const float decay, const uint32_t seed); void llama_sample_repetition_penalties_impl( @@ -83,6 +101,6 @@ llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling * smpl, ll llama_token llama_sample_token_greedy_impl (struct llama_sampling * smpl, llama_token_data_array * candidates); llama_token llama_sample_token_with_rng_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, std::mt19937 & rng); llama_token llama_sample_token_impl (struct llama_sampling * smpl, llama_token_data_array * candidates); - +llama_token llama_sample_token_adaptive_p_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, struct llama_sampler_adaptive_p * adapt_p_ctx, float * orig_probs); diff --git a/src/llama.cpp b/src/llama.cpp index 5c28e224..10db6d58 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -7581,6 +7581,13 @@ void llama_sample_dry([[maybe_unused]] struct llama_context* ctx, struct llama_s llama_sampler_dry_apply(smpl, candidates_p); } +void llama_sample_adaptive_p( + [[maybe_unused]] struct llama_context * ctx, + struct llama_sampler_adaptive_p * adapt_p_ctx, + llama_token_data_array * candidates) { + llama_sampler_adaptive_p_apply(adapt_p_ctx, candidates); +} + void llama_sample_repetition_penalties( struct llama_context * ctx, llama_token_data_array * candidates, @@ -7620,6 +7627,15 @@ llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_arra return llama_sample_token_with_rng_impl(&ctx->sampling, candidates, ctx->sampling.rng); } +llama_token llama_sample_token_adaptive_p( + struct llama_context * ctx, + llama_token_data_array * candidates, + struct llama_sampler_adaptive_p * adapt_p_ctx, + float * orig_probs) +{ + return llama_sample_token_adaptive_p_impl(&ctx->sampling, candidates, adapt_p_ctx, orig_probs); +} + int llama_split_path(char * split_path, size_t maxlen, const char * path_prefix, int split_no, int split_count) { static const char * const SPLIT_PATH_FORMAT = "%s-%05d-of-%05d.gguf"; if (snprintf(split_path, maxlen, SPLIT_PATH_FORMAT, path_prefix, split_no + 1, split_count)) { @@ -7671,6 +7687,13 @@ void llama_sampler_dry_accept(struct llama_sampler_dry* smpl, llama_token token) smpl->last_tokens.push_back(token); } + +struct llama_sampler_adaptive_p * llama_sampler_init_adaptive_p(const float target, const float decay, const uint32_t seed) +{ + return llama_sampler_init_adaptive_p_impl(target, decay, seed); +} + + int llama_split_prefix(char * dest, size_t maxlen, const char * split_path, int split_no, int split_count) { std::string str_split_path(split_path); char postfix[32];