diff --git a/common/sampling.cpp b/common/sampling.cpp
index f68fdc0e..ba8d3f67 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -118,7 +118,9 @@ struct llama_sampling_context * common_sampler_init(const struct llama_vocab* vo
             }
             case llama_sampler_type::ADAPTIVE_P: {
-                result->adapt_p_ctx = llama_init_adaptive_p(params.adaptive_target, params.adaptive_decay, result->rng());
+                GGML_ASSERT(vocab);
+                auto n_vocab = llama_vocab_n_tokens(vocab);
+                result->adapt_p_ctx = llama_init_adaptive_p(n_vocab, params.adaptive_target, params.adaptive_decay, result->rng());
                 break;
             }
             default:
diff --git a/ggml/src/iqk/iqk_cpu_ops.cpp b/ggml/src/iqk/iqk_cpu_ops.cpp
index 9a20e0c8..4581a1dd 100644
--- a/ggml/src/iqk/iqk_cpu_ops.cpp
+++ b/ggml/src/iqk/iqk_cpu_ops.cpp
@@ -16,6 +16,7 @@
 #include
 #include
 #include
+//#include
 
 #ifdef __ARM_NEON
 #include <arm_neon.h>
@@ -503,3 +504,49 @@ void iqk_hadamard(struct ggml_tensor * dst, int ith, int nth) {
         fast_ht(nh, y);
     }
 }
+
+namespace {
+float iqk_exp_with_thresh_impl(int n, float * logits, float max, float min) {
+    float sum = 0;
+#ifdef __AVX2__
+    auto vmax = _mm256_set1_ps(max);
+    auto vmin = _mm256_set1_ps(min);
+    auto vsum = _mm256_setzero_ps();
+    for (int j = 0; j < n/8; ++j) {
+        auto x = _mm256_loadu_ps(logits);
+        auto mask = _mm256_cmp_ps(x, vmin, _CMP_GE_OQ);
+        auto exp_x = v_expf(_mm256_sub_ps(x, vmax));
+        exp_x = _mm256_and_ps(exp_x, mask);
+        vsum = _mm256_add_ps(vsum, exp_x);
+        _mm256_storeu_ps(logits, exp_x);
+        logits += 8;
+    }
+    sum = hsum_float_8(vsum);
+    for (int j = 0; j < n - 8*(n/8); ++j) {
+        float p = logits[j] > min ? expf(logits[j] - max) : 0;
+        sum += p;
+        logits[j] = p;
+    }
+#else
+    for (int j = 0; j < n; ++j) {
+        float p = logits[j] > min ? expf(logits[j] - max) : 0;
+        sum += p;
+        logits[j] = p;
+    }
+#endif
+    return sum;
+}
+}
+
+float iqk_exp_with_thresh(int n, float * logits, float max, float min) {
+    return iqk_exp_with_thresh_impl(n, logits, max, min);
+    //if (n < (1 << 16)) return iqk_exp_with_thresh_impl(n, logits, max, min);
+    //std::array<float, 2> result;
+    //auto compute = [logits, max, min, &result] (int first, int last, int ith) {
+    //    result[ith] = iqk_exp_with_thresh_impl(last - first, logits + first, max, min);
+    //};
+    //auto t = std::thread(compute, 0, n/2, 0);
+    //compute(n/2, n, 1);
+    //t.join();
+    //return result[0] + result[1];
+}
diff --git a/ggml/src/iqk/iqk_cpu_ops.h b/ggml/src/iqk/iqk_cpu_ops.h
index 8656d7f1..267c3e85 100644
--- a/ggml/src/iqk/iqk_cpu_ops.h
+++ b/ggml/src/iqk/iqk_cpu_ops.h
@@ -30,6 +30,8 @@ void iqk_mul_multi_add(struct ggml_tensor * dst, int ith, int nth);
 
 void iqk_hadamard(struct ggml_tensor * dst, int ith, int nth);
 
+float iqk_exp_with_thresh(int n, float * logits, float max, float min);
+
 #ifdef __cplusplus
 }
 #endif
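
[Reviewer note, not part of the patch] iqk_exp_with_thresh is an in-place, thresholded softmax numerator: each entry above min is replaced by expf(x - max), every other entry becomes exactly 0, and the sum of the surviving terms is returned. A minimal usage sketch follows; it assumes only the declaration above plus caller-side variables logits and n_vocab, and the 30.0f delta mirrors the kDelta constant used by the sampler further down:

    #include <algorithm>
    #include <vector>

    // Copy the raw logits, exponentiate in place, then normalize.
    std::vector<float> probs(logits, logits + n_vocab);
    const float max_logit = *std::max_element(probs.begin(), probs.end());
    const float sum = iqk_exp_with_thresh(int(probs.size()), probs.data(),
                                          max_logit, max_logit - 30.0f);
    for (auto & p : probs) p /= sum; // entries below the threshold stay exactly 0

Skipping expf() for tokens more than delta below the maximum is safe because their contribution to the sum would be at most e^-30 per token anyway.
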
diff --git a/include/llama.h b/include/llama.h
index dd0bb409..72cb9edd 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -1384,7 +1384,7 @@ LLAMA_API struct llama_grammar* llama_sampler_init_grammar_lazy_patterns(
     /// @details Adaptive p sampler initializer
     /// @param target Select tokens near this probability (valid range 0.0 to 1.0; <0 = disabled)
     /// @param decay Decay rate for target adaptation over time. lower values -> faster but less stable adaptation. (valid range 0.0 to 1.0; ≤0 = no adaptation)
-    LLAMA_API struct llama_sampler_adaptive_p * llama_init_adaptive_p(
+    LLAMA_API struct llama_sampler_adaptive_p * llama_init_adaptive_p(int n_vocab,
            const float target,
            const float decay,
            const uint32_t seed);
@@ -1394,10 +1394,9 @@ LLAMA_API struct llama_grammar* llama_sampler_init_grammar_lazy_patterns(
            struct llama_sampler_adaptive_p * adapt_p_ctx);
 
     /// @details Adaptive p sampler described in https://github.com/MrJackSpade/adaptive-p-docs/blob/main/README.md
-    void llama_sample_adaptive_p(
-            struct llama_context * ctx,
-            llama_token_data_array * candidates,
-            struct llama_sampler_adaptive_p * adapt_p_ctx);
+    void llama_sample_adaptive_p(struct llama_context * ctx,
+                                 llama_token_data_array * candidates,
+                                 struct llama_sampler_adaptive_p * adapt_p_ctx);
 
     /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
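
[Reviewer note, not part of the patch] With the signature change the caller now sizes the sampler to the vocabulary once, at init time. A minimal sketch of the intended call sequence; vocab, ctx, and candidates are assumed to exist in the caller, and llama_prep_adaptive_p is a hypothetical name for the prep entry point whose declaration is only partially visible in the hunk above:

    const int n_vocab = llama_vocab_n_tokens(vocab);
    struct llama_sampler_adaptive_p * ap =
            llama_init_adaptive_p(n_vocab, /*target=*/0.3f, /*decay=*/0.9f, /*seed=*/42);

    // prep must see the full, unsorted candidate array, i.e. it has to run
    // before any other sampler touches it (enforced by the check further down).
    llama_prep_adaptive_p(ctx, &candidates, ap);   // hypothetical public wrapper of *_impl
    // ... the rest of the sampler queue may reorder/trim candidates here ...
    llama_sample_adaptive_p(ctx, &candidates, ap);
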
diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp
index acfcefd4..b3106937 100644
--- a/src/llama-sampling.cpp
+++ b/src/llama-sampling.cpp
@@ -2,6 +2,8 @@
 #include "llama-vocab.h"
 #include "llama-grammar.h"
 
+#include "iqk/iqk_cpu_ops.h"
+
 #include
 #include
 #include
@@ -1061,8 +1063,8 @@ llama_token llama_sample_token_adaptive_p_impl(
     const size_t idx = std::distance(ctx->cum_probs.begin(), iter);
     llama_token id = candidates->data[idx].id;
 
-    if (auto it = ctx->orig_prob_map.find(id); it != ctx->orig_prob_map.end()) {
-        float update_prob = it->second / ctx->cum_orig_prob;
+    GGML_ASSERT(id < int(ctx->orig_prob.size()));
+    if (auto update_prob = ctx->orig_prob[id]; update_prob > 0) {
         ctx->weighted_sum = ctx->decay * ctx->weighted_sum + update_prob;
         ctx->total_weight = ctx->decay * ctx->total_weight + 1.0f;
     }
@@ -1070,16 +1072,6 @@ llama_token llama_sample_token_adaptive_p_impl(
     smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
     smpl->n_sample++;
 
-    //float update_prob = candidates->data[idx].p; // not ideal
-    //if (ctx->orig_prob_map.contains(id)) {
-    //    // selected token id is among tracked ids
-    //    update_prob = ctx->orig_prob_map[id] / ctx->cum_orig_prob;
-    //}
-
-    //// update history with original probability of selected token
-    //ctx->weighted_sum = ctx->decay * ctx->weighted_sum + update_prob;
-    //ctx->total_weight = ctx->decay * ctx->total_weight + 1.0f;
-
     return id;
 }
@@ -1137,94 +1129,49 @@ void llama_sample_adaptive_p_impl(struct llama_sampling * ctx, llama_token_data_
     ctx->t_sample_us += ggml_time_us() - t_start;
 }
 
-void llama_prep_adaptive_p_impl(struct llama_sampling * smpl,
+void llama_prep_adaptive_p_impl(
+        struct llama_sampling * smpl,
         llama_token_data_array * candidates,
         struct llama_sampler_adaptive_p * adapt_p_ctx) {
 
-    constexpr float kDelta = 16.6f;
+    constexpr float kDelta = 30.0f; //16.6f;
 
     auto t_start = ggml_time_us();
 
-    if (!candidates->sorted) {
-        float max_logit = candidates->data[0].logit;
-        for (int j = 1; j < int(candidates->size); ++j) {
-            max_logit = std::max(max_logit, candidates->data[j].logit);
-        }
-        float min_logit = max_logit - kDelta;
-        float cum_prob = 0.0f;
-        adapt_p_ctx->orig_prob_map.clear();
-        for (int j = 0; j < int(candidates->size); ++j) {
-            if (candidates->data[j].logit > min_logit) {
-                float prob = expf(candidates->data[j].logit - max_logit);
-                cum_prob += prob;
-                adapt_p_ctx->orig_prob_map[candidates->data[j].id] = prob;
-            }
-        }
-        adapt_p_ctx->cum_orig_prob = cum_prob;
-        if (smpl) smpl->t_sample_us += ggml_time_us() - t_start;
-        return;
+    auto & orig_prob = adapt_p_ctx->orig_prob;
+    if (candidates->size != orig_prob.size() || candidates->sorted) {
+        LLAMA_LOG_ERROR("%s: this function must be called before any other sampler has been applied\n", __func__);
+        LLAMA_LOG_ERROR("%s: the sampler has been initialized with a vocabulary of %zu, but is being called with %zu candidates\n",
+                __func__, orig_prob.size(), candidates->size);
+        GGML_ABORT("Bad candidates in adaptive_p sampler");
     }
-    float max_logit = candidates->data[0].logit;
-    float min_logit = max_logit - kDelta;
-    float cum_prob = 0.0f;
-    adapt_p_ctx->orig_prob_map.clear();
+    float max_logit = -INFINITY;
     for (int j = 0; j < int(candidates->size); ++j) {
-        auto logit = candidates->data[j].logit;
-        if (logit <= min_logit) {
-            break;
-        }
-        float prob = expf(logit - max_logit);
-        cum_prob += prob;
-        adapt_p_ctx->orig_prob_map[candidates->data[j].id] = prob;
+        orig_prob[j] = candidates->data[j].logit;
+        max_logit = std::max(max_logit, orig_prob[j]);
     }
-    adapt_p_ctx->cum_orig_prob = cum_prob;
+    adapt_p_ctx->cum_orig_prob = iqk_exp_with_thresh(orig_prob.size(), orig_prob.data(), max_logit, max_logit - kDelta);
+
     if (smpl) smpl->t_sample_us += ggml_time_us() - t_start;
-
-    //if (!candidates->sorted) {
-    //    std::sort(candidates->data, candidates->data + candidates->size,
-    //            [](const llama_token_data & a, const llama_token_data & b) {
-    //                return a.logit > b.logit;
-    //            });
-    //    candidates->sorted = true;
-    //}
-    //const float max_logit = candidates->data[0].logit;
-
-    //// decide how many tokens to track based on logit delta
-    //// i.e. do not track unlikely tokens
-    //auto iter = std::lower_bound(
-    //        candidates->data,
-    //        candidates->data + candidates->size,
-    //        max_logit - kDelta, // delta
-    //        [](const llama_token_data & data, const float delta) {
-    //            return data.logit > delta;
-    //        });
-    //const size_t n_track = std::distance(candidates->data, iter);
-
-    //// store orig_prob_map and cum_orig_prob to estimate original probability later
-    //float cum_prob = 0.0f;
-    //adapt_p_ctx->orig_prob_map.clear();
-    //for (size_t i = 0; i < n_track; ++i) {
-    //    const float prob = expf(candidates->data[i].logit - max_logit);
-    //    cum_prob += prob;
-    //    adapt_p_ctx->orig_prob_map[candidates->data[i].id] = prob;
-    //}
-    //adapt_p_ctx->cum_orig_prob = cum_prob;
 }
 
-struct llama_sampler_adaptive_p * llama_init_adaptive_p_impl(
+struct llama_sampler_adaptive_p * llama_init_adaptive_p_impl(int n_vocab,
         const float target,
         const float decay,
         const uint32_t seed) {
+    GGML_ASSERT(n_vocab > 0);
     const float clamped_decay = std::clamp(decay, 0.0f, 0.99f);
-    return new llama_sampler_adaptive_p {
+    auto result = new llama_sampler_adaptive_p {
        /* .target = */ target,
        /* .decay = */ clamped_decay,
        /* .rng = */ std::mt19937(seed),
        /* .weighted_sum = */ target / (1.0f - clamped_decay),
        /* .total_weight = */ 1.0f / (1.0f - clamped_decay),
-       /* .orig_logit_map = */ {},
+       /* .orig_prob = */ {},
        /* .cum_orig_prob = */ 0.0f,
        /* .max_xform_logit = */ -INFINITY,
        /* .cum_probs = */ {},
     };
+    result->orig_prob.resize(n_vocab);
+    return result;
 }
 
 // grammar
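
[Reviewer note, not part of the patch] The seed values in the initializer are the fixed points of the two accumulators. With gamma = decay and the per-token update s <- gamma*s + x, the geometric series gives

    \text{total\_weight} = \sum_{i=0}^{\infty}\gamma^{i} = \frac{1}{1-\gamma},
    \qquad
    \text{weighted\_sum} = \sum_{i=0}^{\infty}\gamma^{i}\,p_{t-i} = \frac{\text{target}}{1-\gamma}
    \quad\text{when every } p_{t-i} = \text{target},

so the tracked average weighted_sum / total_weight starts exactly at target and the first sampled tokens are not biased by an empty history. This matches the struct comment "sum(decay^i), converges to 1/(1-decay)" below.
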
diff --git a/src/llama-sampling.h b/src/llama-sampling.h
index 3249c843..8ebbfb49 100644
--- a/src/llama-sampling.h
+++ b/src/llama-sampling.h
@@ -73,7 +73,7 @@ struct llama_sampler_adaptive_p {
     float total_weight; // sum(decay^i), converges to 1/(1-decay)
 
     // first referenced in prep
-    std::unordered_map<llama_token, float> orig_prob_map; // probabilities before sampler_queue
+    std::vector<float> orig_prob; // for storing the original probabilities
     float cum_orig_prob; // for normalizing orig_prob in sample_token
 
     // first referenced in sample
@@ -83,7 +83,7 @@ struct llama_sampler_adaptive_p {
     std::vector<float> cum_probs; // cumulative probability distribution
 };
 
-struct llama_sampler_adaptive_p * llama_init_adaptive_p_impl(
+struct llama_sampler_adaptive_p * llama_init_adaptive_p_impl(int n_vocab,
        const float target,
        const float decay,
        const uint32_t seed);
diff --git a/src/llama.cpp b/src/llama.cpp
index d76ac5b7..488dcadc 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -7687,10 +7687,9 @@ void llama_sample_dry([[maybe_unused]] struct llama_context* ctx, struct llama_s
     llama_sampler_dry_apply(smpl, candidates_p);
 }
 
-void llama_sample_adaptive_p(
-        [[maybe_unused]] struct llama_context * ctx,
-        llama_token_data_array * candidates,
-        struct llama_sampler_adaptive_p * adapt_p_ctx) {
+void llama_sample_adaptive_p(llama_context * ctx,
+                             llama_token_data_array * candidates,
+                             llama_sampler_adaptive_p * adapt_p_ctx) {
     llama_sample_adaptive_p_impl(&ctx->sampling, candidates, adapt_p_ctx);
 }
 
@@ -7797,8 +7796,8 @@ void llama_sampler_dry_accept(struct llama_sampler_dry* smpl, llama_token token)
 }
 
-struct llama_sampler_adaptive_p * llama_init_adaptive_p(const float target, const float decay, const uint32_t seed) {
-    return llama_init_adaptive_p_impl(target, decay, seed);
+struct llama_sampler_adaptive_p * llama_init_adaptive_p(int n_vocab, const float target, const float decay, const uint32_t seed) {
+    return llama_init_adaptive_p_impl(n_vocab, target, decay, seed);
 }