diff --git a/common/sampling.cpp b/common/sampling.cpp
index 60318fa7..f68fdc0e 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -125,7 +125,7 @@ struct llama_sampling_context * common_sampler_init(const struct llama_vocab* vo
             break;
         }
     }
-    
+
     return result;
 }
 
@@ -419,7 +419,7 @@ static void sampler_queue(
             case llama_sampler_type::ADAPTIVE_P: use_adaptive_p = true; break;
             default : break;
         }
-    
+
     }
     if (use_adaptive_p) {
         // adaptive p should be put to the last, so we ignore the order in the sampler
@@ -451,7 +451,7 @@ static llama_token llama_sampling_sample_impl(
     if (ctx_sampling->grammar != NULL && is_resampling) {
         float* logits = llama_get_logits_ith(ctx_main, idx);
         // Apply grammar constraints to all candidates
-        llama_grammar_sample(ctx_sampling->grammar, ctx_main, &cur_p); 
+        llama_grammar_sample(ctx_sampling->grammar, ctx_main, &cur_p);
     }
 
     if (temp < 0.0) {
@@ -471,7 +471,7 @@ static llama_token llama_sampling_sample_impl(
             id = llama_sample_token_mirostat_v2(ctx_main, &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling->mirostat_mu);
         } else if (adaptive_target >= 0.0f && ctx_sampling->adapt_p_ctx!=nullptr) {
             // adaptive p sampling
-            llama_prep_adaptive_p(&cur_p, ctx_sampling->adapt_p_ctx);
+            llama_prep_adaptive_p(ctx_main, &cur_p, ctx_sampling->adapt_p_ctx);
             sampler_queue(ctx_main, params, ctx_sampling, cur_p, std::max(1, params.min_keep));
             id = llama_sample_token_adaptive_p(ctx_main, &cur_p, ctx_sampling->adapt_p_ctx);
         } else {
diff --git a/include/llama.h b/include/llama.h
index 3d17f9b2..dd0bb409 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -1389,7 +1389,7 @@ LLAMA_API struct llama_grammar* llama_sampler_init_grammar_lazy_patterns(
             const float decay,
             const uint32_t seed);
 
-    void llama_prep_adaptive_p(
+    void llama_prep_adaptive_p(struct llama_context * ctx,
         llama_token_data_array * candidates,
         struct llama_sampler_adaptive_p * adapt_p_ctx);
 
diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp
index 9d3134e4..acfcefd4 100644
--- a/src/llama-sampling.cpp
+++ b/src/llama-sampling.cpp
@@ -1038,8 +1038,7 @@ struct llama_sampler_dry* llama_sampler_init_dry_impl(const struct llama_vocab&
 llama_token llama_sample_token_adaptive_p_impl(
         struct llama_sampling * smpl,
         llama_token_data_array * candidates,
-        struct llama_sampler_adaptive_p * adapt_p_ctx)
-{
+        struct llama_sampler_adaptive_p * adapt_p_ctx) {
     GGML_ASSERT(candidates->size > 0);
 
     const int64_t t_start_sample_us = ggml_time_us();
@@ -1062,30 +1061,38 @@ llama_token llama_sample_token_adaptive_p_impl(
     const size_t idx = std::distance(ctx->cum_probs.begin(), iter);
     llama_token id = candidates->data[idx].id;
 
+    if (auto it = ctx->orig_prob_map.find(id); it != ctx->orig_prob_map.end()) {
+        float update_prob = it->second / ctx->cum_orig_prob;
+        ctx->weighted_sum = ctx->decay * ctx->weighted_sum + update_prob;
+        ctx->total_weight = ctx->decay * ctx->total_weight + 1.0f;
+    }
+
     smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
     smpl->n_sample++;
 
-    float update_prob = candidates->data[idx].p; // not ideal
-    if (ctx->orig_prob_map.contains(id)) {
-        // selected token id is among tracked ids
-        update_prob = ctx->orig_prob_map[id] / ctx->cum_orig_prob;
-    }
+    //float update_prob = candidates->data[idx].p; // not ideal
+    //if (ctx->orig_prob_map.contains(id)) {
+    //    // selected token id is among tracked ids
+    //    update_prob = ctx->orig_prob_map[id] / ctx->cum_orig_prob;
+    //}
 
-    // update history with original probability of selected token
-    ctx->weighted_sum = ctx->decay * ctx->weighted_sum + update_prob;
-    ctx->total_weight = ctx->decay * ctx->total_weight + 1.0f;
+    //// update history with original probability of selected token
+    //ctx->weighted_sum = ctx->decay * ctx->weighted_sum + update_prob;
+    //ctx->total_weight = ctx->decay * ctx->total_weight + 1.0f;
 
     return id;
 }
 
-void llama_sample_adaptive_p_impl(llama_token_data_array * candidates, struct llama_sampler_adaptive_p * adapt_p_ctx)
-{
+void llama_sample_adaptive_p_impl(struct llama_sampling * ctx, llama_token_data_array * candidates,
+        struct llama_sampler_adaptive_p * adapt_p_ctx) {
     if (adapt_p_ctx->target < 0.0f) {
         // sampler is disabled
         llama_sample_softmax_impl(nullptr, candidates);
         return;
     }
 
+    auto t_start = ggml_time_us();
+
     // incomplete softmax because final division can be fused
     float max_l = candidates->data[0].logit;
     if (!candidates->sorted) {
@@ -1126,48 +1133,86 @@ void llama_sample_adaptive_p_impl(llama_token_data_array * candidates, struct ll
     }
     candidates->sorted = false;
     adapt_p_ctx->max_xform_logit = max_logit;
+
+    ctx->t_sample_us += ggml_time_us() - t_start;
 }
 
-void llama_prep_adaptive_p_impl(
+void llama_prep_adaptive_p_impl(struct llama_sampling * smpl,
         llama_token_data_array * candidates,
-        struct llama_sampler_adaptive_p * adapt_p_ctx)
-{
+        struct llama_sampler_adaptive_p * adapt_p_ctx) {
+    constexpr float kDelta = 16.6f;
+    auto t_start = ggml_time_us();
+
     if (!candidates->sorted) {
-        std::sort(candidates->data, candidates->data + candidates->size,
-            [](const llama_token_data & a, const llama_token_data & b) {
-                return a.logit > b.logit;
-        });
-        candidates->sorted = true;
+        float max_logit = candidates->data[0].logit;
+        for (int j = 1; j < int(candidates->size); ++j) {
+            max_logit = std::max(max_logit, candidates->data[j].logit);
+        }
+        float min_logit = max_logit - kDelta;
+        float cum_prob = 0.0f;
+        adapt_p_ctx->orig_prob_map.clear();
+        for (int j = 0; j < int(candidates->size); ++j) {
+            if (candidates->data[j].logit > min_logit) {
+                float prob = expf(candidates->data[j].logit - max_logit);
+                cum_prob += prob;
+                adapt_p_ctx->orig_prob_map[candidates->data[j].id] = prob;
+            }
+        }
+        adapt_p_ctx->cum_orig_prob = cum_prob;
+        if (smpl) smpl->t_sample_us += ggml_time_us() - t_start;
+        return;
     }
-    const float max_logit = candidates->data[0].logit;
-
-    // decide how many tokens to track based on logit delta
-    // i.e. do not track unlikely tokens
-    auto iter = std::lower_bound(
-        candidates->data,
-        candidates->data + candidates->size,
-        max_logit - 16.6f, // delta
-        [](const llama_token_data & data, const float delta) {
-            return data.logit > delta;
-        });
-    const size_t n_track = std::distance(candidates->data, iter);
-
-    // store orig_prob_map and cum_orig_prob to estimate original probability later
+    float max_logit = candidates->data[0].logit;
+    float min_logit = max_logit - kDelta;
     float cum_prob = 0.0f;
     adapt_p_ctx->orig_prob_map.clear();
-    for (size_t i = 0; i < n_track; ++i) {
-        const float prob = expf(candidates->data[i].logit - max_logit);
+    for (int j = 0; j < int(candidates->size); ++j) {
+        auto logit = candidates->data[j].logit;
+        if (logit <= min_logit) {
+            break;
+        }
+        float prob = expf(logit - max_logit);
         cum_prob += prob;
-        adapt_p_ctx->orig_prob_map[candidates->data[i].id] = prob;
+        adapt_p_ctx->orig_prob_map[candidates->data[j].id] = prob;
     }
     adapt_p_ctx->cum_orig_prob = cum_prob;
+    if (smpl) smpl->t_sample_us += ggml_time_us() - t_start;
+
+    //if (!candidates->sorted) {
+    //    std::sort(candidates->data, candidates->data + candidates->size,
+    //        [](const llama_token_data & a, const llama_token_data & b) {
+    //            return a.logit > b.logit;
+    //    });
+    //    candidates->sorted = true;
+    //}
+    //const float max_logit = candidates->data[0].logit;
+
+    //// decide how many tokens to track based on logit delta
+    //// i.e. do not track unlikely tokens
+    //auto iter = std::lower_bound(
+    //    candidates->data,
+    //    candidates->data + candidates->size,
+    //    max_logit - kDelta, // delta
+    //    [](const llama_token_data & data, const float delta) {
+    //        return data.logit > delta;
+    //    });
+    //const size_t n_track = std::distance(candidates->data, iter);
+
+    //// store orig_prob_map and cum_orig_prob to estimate original probability later
+    //float cum_prob = 0.0f;
+    //adapt_p_ctx->orig_prob_map.clear();
+    //for (size_t i = 0; i < n_track; ++i) {
+    //    const float prob = expf(candidates->data[i].logit - max_logit);
+    //    cum_prob += prob;
+    //    adapt_p_ctx->orig_prob_map[candidates->data[i].id] = prob;
+    //}
+    //adapt_p_ctx->cum_orig_prob = cum_prob;
 }
 
 struct llama_sampler_adaptive_p * llama_init_adaptive_p_impl(
         const float target,
         const float decay,
-        const uint32_t seed)
-{
+        const uint32_t seed) {
     const float clamped_decay = std::clamp(decay, 0.0f, 0.99f);
     return new llama_sampler_adaptive_p {
         /* .target = */ target,
diff --git a/src/llama-sampling.h b/src/llama-sampling.h
index 55b4371c..3249c843 100644
--- a/src/llama-sampling.h
+++ b/src/llama-sampling.h
@@ -89,10 +89,12 @@ struct llama_sampler_adaptive_p * llama_init_adaptive_p_impl(
         const uint32_t seed);
 
 void llama_prep_adaptive_p_impl(
+        struct llama_sampling * smpl,
         llama_token_data_array * candidates,
         struct llama_sampler_adaptive_p * adapt_p_ctx);
 
 void llama_sample_adaptive_p_impl(
+        struct llama_sampling * smpl,
         llama_token_data_array * candidates,
         struct llama_sampler_adaptive_p * adapt_p_ctx);
 
diff --git a/src/llama.cpp b/src/llama.cpp
index 7175eca9..d76ac5b7 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -7690,14 +7690,12 @@ void llama_sample_dry([[maybe_unused]] struct llama_context* ctx, struct llama_s
 void llama_sample_adaptive_p(
     [[maybe_unused]] struct llama_context * ctx,
     llama_token_data_array * candidates,
-    struct llama_sampler_adaptive_p * adapt_p_ctx)
-{
-    llama_sample_adaptive_p_impl(candidates, adapt_p_ctx);
+    struct llama_sampler_adaptive_p * adapt_p_ctx) {
+    llama_sample_adaptive_p_impl(&ctx->sampling, candidates, adapt_p_ctx);
 }
 
-void llama_prep_adaptive_p(llama_token_data_array * candidates, struct llama_sampler_adaptive_p * adapt_p_ctx)
-{
-    llama_prep_adaptive_p_impl(candidates, adapt_p_ctx);
+void llama_prep_adaptive_p(struct llama_context * ctx, llama_token_data_array * candidates, struct llama_sampler_adaptive_p * adapt_p_ctx) {
+    llama_prep_adaptive_p_impl(&ctx->sampling, candidates, adapt_p_ctx);
 }
 
 
@@ -7743,8 +7741,7 @@ llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_arra
 llama_token llama_sample_token_adaptive_p(
         struct llama_context * ctx,
         llama_token_data_array * candidates,
-        struct llama_sampler_adaptive_p * adapt_p_ctx)
-{
+        struct llama_sampler_adaptive_p * adapt_p_ctx) {
     return llama_sample_token_adaptive_p_impl(&ctx->sampling, candidates, adapt_p_ctx);
 }
 
@@ -7800,8 +7797,7 @@ void llama_sampler_dry_accept(struct llama_sampler_dry* smpl, llama_token token)
 }
 
 
-struct llama_sampler_adaptive_p * llama_init_adaptive_p(const float target, const float decay, const uint32_t seed)
-{
+struct llama_sampler_adaptive_p * llama_init_adaptive_p(const float target, const float decay, const uint32_t seed) {
     return llama_init_adaptive_p_impl(target, decay, seed);
 }
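
Reviewer note (not part of the patch): a minimal caller-side sketch of how the reworked entry points are meant to be driven, mirroring the call order in common/sampling.cpp — llama_prep_adaptive_p() snapshots the original probabilities before anything else mutates the logits, the adaptive-p transform runs last in the sampler queue, and llama_sample_token_adaptive_p() both picks the token and updates the sampler's decayed history. The helper name, the target/decay/seed values, and the llama_token_data_array construction boilerplate are illustrative assumptions, not code from this branch.

```cpp
// Hypothetical usage sketch (not from this patch). Assumes the classic
// llama_token_data_array layout { data, size, sorted } and the public
// signatures added/changed in include/llama.h above.
#include "llama.h"

#include <vector>

static llama_token sample_with_adaptive_p(
        struct llama_context * ctx,
        struct llama_sampler_adaptive_p * adapt_p_ctx,
        int idx) {
    const int n_vocab = llama_n_vocab(llama_get_model(ctx));
    const float * logits = llama_get_logits_ith(ctx, idx);

    // build the candidate array from the raw logits of output slot `idx`
    std::vector<llama_token_data> cur;
    cur.reserve(n_vocab);
    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
        cur.push_back(llama_token_data{ token_id, logits[token_id], 0.0f });
    }
    llama_token_data_array cur_p = { cur.data(), cur.size(), false };

    // 1) record the original probabilities of the plausible tokens (within ~16.6 logits of the max)
    llama_prep_adaptive_p(ctx, &cur_p, adapt_p_ctx);

    // 2) other samplers from the queue would run here; the adaptive-p transform goes last
    llama_sample_adaptive_p(ctx, &cur_p, adapt_p_ctx);

    // 3) pick a token and fold its original probability into the decayed history
    return llama_sample_token_adaptive_p(ctx, &cur_p, adapt_p_ctx);
}

// e.g. (placeholder values) — sampler state shared across the whole generation:
// struct llama_sampler_adaptive_p * adapt = llama_init_adaptive_p(/*target=*/0.3f, /*decay=*/0.9f, /*seed=*/42);
```

With this patch the history update only happens when the chosen token is among the tracked ids (an orig_prob_map hit), and the prep/transform work is now charged to llama_sampling::t_sample_us through the new context argument.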