mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-01-26 17:20:01 +00:00
Faster adaptive_p sampling (#1165)
* A hopefully more efficient adaptive_p sampling * While at it, let's fix the formatting too * More formatting * Hopefully better * This should be better * Correctly accumulate adaptive_p sampling time * AVX2
This commit is contained in:
@@ -118,7 +118,9 @@ struct llama_sampling_context * common_sampler_init(const struct llama_vocab* vo
|
|||||||
}
|
}
|
||||||
case llama_sampler_type::ADAPTIVE_P:
|
case llama_sampler_type::ADAPTIVE_P:
|
||||||
{
|
{
|
||||||
result->adapt_p_ctx = llama_init_adaptive_p(params.adaptive_target, params.adaptive_decay, result->rng());
|
GGML_ASSERT(vocab);
|
||||||
|
auto n_vocab = llama_vocab_n_tokens(vocab);
|
||||||
|
result->adapt_p_ctx = llama_init_adaptive_p(n_vocab, params.adaptive_target, params.adaptive_decay, result->rng());
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
|
|||||||
@@ -16,6 +16,7 @@
|
|||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <cstring>
|
#include <cstring>
|
||||||
|
//#include <thread>
|
||||||
|
|
||||||
#ifdef __ARM_NEON
|
#ifdef __ARM_NEON
|
||||||
#include <arm_neon.h>
|
#include <arm_neon.h>
|
||||||
@@ -503,3 +504,49 @@ void iqk_hadamard(struct ggml_tensor * dst, int ith, int nth) {
|
|||||||
fast_ht(nh, y);
|
fast_ht(nh, y);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
namespace {

// Converts logits to thresholded, unnormalized probabilities in place.
//
// Every entry of logits[0..n) that is strictly greater than `min` is replaced with
// exp(logit - max); every other entry (including NaN) is replaced with 0.
//
//   n      - number of entries in `logits`; nothing is touched when n <= 0
//   logits - input logits, overwritten with the computed values
//   max    - value subtracted from each logit before exponentiation
//   min    - cut-off; entries <= min are written as 0
//
// Returns the sum of all values written to `logits`.
float iqk_exp_with_thresh_impl(int n, float * logits, float max, float min) {
    float sum = 0;
    int j = 0;
#ifdef __AVX2__
    const auto vmax = _mm256_set1_ps(max);
    const auto vmin = _mm256_set1_ps(min);
    auto vsum = _mm256_setzero_ps();
    // Process full groups of 8 floats; the remainder is handled by the scalar loop below.
    for (; j + 8 <= n; j += 8) {
        auto x = _mm256_loadu_ps(logits + j);
        // Strict "greater than", so the vector path applies exactly the same cut-off
        // as the scalar loop (a logit equal to `min` is dropped in both).
        auto mask = _mm256_cmp_ps(x, vmin, _CMP_GT_OQ);
        auto exp_x = v_expf(_mm256_sub_ps(x, vmax));
        // Zero the lanes that did not pass the threshold.
        exp_x = _mm256_and_ps(exp_x, mask);
        vsum = _mm256_add_ps(vsum, exp_x);
        _mm256_storeu_ps(logits + j, exp_x);
    }
    sum = hsum_float_8(vsum);
#endif
    // Scalar path: the whole array without AVX2, otherwise the last n % 8 entries.
    for (; j < n; ++j) {
        float p = logits[j] > min ? expf(logits[j] - max) : 0;
        sum += p;
        logits[j] = p;
    }
    return sum;
}

}
|
||||||
|
|
||||||
|
float iqk_exp_with_thresh(int n, float * logits, float max, float min) {
|
||||||
|
return iqk_exp_with_thresh_impl(n, logits, max, min);
|
||||||
|
//if (n < (1 << 16)) return iqk_exp_with_thresh_impl(n, logits, max, min);
|
||||||
|
//std::array<float, 2> result;
|
||||||
|
//auto compute = [logits, max, min, &result] (int first, int last, int ith) {
|
||||||
|
// result[ith] = iqk_exp_with_thresh_impl(last - first, logits + first, max, min);
|
||||||
|
//};
|
||||||
|
//auto t = std::thread(compute, 0, n/2, 0);
|
||||||
|
//compute(n/2, n, 1);
|
||||||
|
//t.join();
|
||||||
|
//return result[0] + result[1];
|
||||||
|
}
|
||||||
|
|||||||
@@ -30,6 +30,8 @@ void iqk_mul_multi_add(struct ggml_tensor * dst, int ith, int nth);
|
|||||||
|
|
||||||
void iqk_hadamard(struct ggml_tensor * dst, int ith, int nth);
|
void iqk_hadamard(struct ggml_tensor * dst, int ith, int nth);
|
||||||
|
|
||||||
|
float iqk_exp_with_thresh(int n, float * logits, float max, float min);
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|||||||
@@ -1384,7 +1384,7 @@ LLAMA_API struct llama_grammar* llama_sampler_init_grammar_lazy_patterns(
|
|||||||
/// @details Adaptive p sampler initializer
|
/// @details Adaptive p sampler initializer
|
||||||
/// @param target Select tokens near this probability (valid range 0.0 to 1.0; <0 = disabled)
|
/// @param target Select tokens near this probability (valid range 0.0 to 1.0; <0 = disabled)
|
||||||
/// @param decay Decay rate for target adaptation over time. lower values -> faster but less stable adaptation. (valid range 0.0 to 1.0; ≤0 = no adaptation)
|
/// @param decay Decay rate for target adaptation over time. lower values -> faster but less stable adaptation. (valid range 0.0 to 1.0; ≤0 = no adaptation)
|
||||||
LLAMA_API struct llama_sampler_adaptive_p * llama_init_adaptive_p(
|
LLAMA_API struct llama_sampler_adaptive_p * llama_init_adaptive_p(int n_vocab,
|
||||||
const float target,
|
const float target,
|
||||||
const float decay,
|
const float decay,
|
||||||
const uint32_t seed);
|
const uint32_t seed);
|
||||||
@@ -1394,8 +1394,7 @@ LLAMA_API struct llama_grammar* llama_sampler_init_grammar_lazy_patterns(
|
|||||||
struct llama_sampler_adaptive_p * adapt_p_ctx);
|
struct llama_sampler_adaptive_p * adapt_p_ctx);
|
||||||
|
|
||||||
/// @details Adaptive p sampler described in https://github.com/MrJackSpade/adaptive-p-docs/blob/main/README.md
|
/// @details Adaptive p sampler described in https://github.com/MrJackSpade/adaptive-p-docs/blob/main/README.md
|
||||||
void llama_sample_adaptive_p(
|
void llama_sample_adaptive_p(struct llama_context * ctx,
|
||||||
struct llama_context * ctx,
|
|
||||||
llama_token_data_array * candidates,
|
llama_token_data_array * candidates,
|
||||||
struct llama_sampler_adaptive_p * adapt_p_ctx);
|
struct llama_sampler_adaptive_p * adapt_p_ctx);
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,8 @@
|
|||||||
#include "llama-vocab.h"
|
#include "llama-vocab.h"
|
||||||
#include "llama-grammar.h"
|
#include "llama-grammar.h"
|
||||||
|
|
||||||
|
#include "iqk/iqk_cpu_ops.h"
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <cstring>
|
#include <cstring>
|
||||||
#include <ctime>
|
#include <ctime>
|
||||||
@@ -1061,8 +1063,8 @@ llama_token llama_sample_token_adaptive_p_impl(
|
|||||||
const size_t idx = std::distance(ctx->cum_probs.begin(), iter);
|
const size_t idx = std::distance(ctx->cum_probs.begin(), iter);
|
||||||
llama_token id = candidates->data[idx].id;
|
llama_token id = candidates->data[idx].id;
|
||||||
|
|
||||||
if (auto it = ctx->orig_prob_map.find(id); it != ctx->orig_prob_map.end()) {
|
GGML_ASSERT(id < int(ctx->orig_prob.size()));
|
||||||
float update_prob = it->second / ctx->cum_orig_prob;
|
if (auto update_prob = ctx->orig_prob[id]; update_prob > 0) {
|
||||||
ctx->weighted_sum = ctx->decay * ctx->weighted_sum + update_prob;
|
ctx->weighted_sum = ctx->decay * ctx->weighted_sum + update_prob;
|
||||||
ctx->total_weight = ctx->decay * ctx->total_weight + 1.0f;
|
ctx->total_weight = ctx->decay * ctx->total_weight + 1.0f;
|
||||||
}
|
}
|
||||||
@@ -1070,16 +1072,6 @@ llama_token llama_sample_token_adaptive_p_impl(
|
|||||||
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
|
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||||
smpl->n_sample++;
|
smpl->n_sample++;
|
||||||
|
|
||||||
//float update_prob = candidates->data[idx].p; // not ideal
|
|
||||||
//if (ctx->orig_prob_map.contains(id)) {
|
|
||||||
// // selected token id is among tracked ids
|
|
||||||
// update_prob = ctx->orig_prob_map[id] / ctx->cum_orig_prob;
|
|
||||||
//}
|
|
||||||
|
|
||||||
//// update history with original probability of selected token
|
|
||||||
//ctx->weighted_sum = ctx->decay * ctx->weighted_sum + update_prob;
|
|
||||||
//ctx->total_weight = ctx->decay * ctx->total_weight + 1.0f;
|
|
||||||
|
|
||||||
return id;
|
return id;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1137,94 +1129,49 @@ void llama_sample_adaptive_p_impl(struct llama_sampling * ctx, llama_token_data_
|
|||||||
ctx->t_sample_us += ggml_time_us() - t_start;
|
ctx->t_sample_us += ggml_time_us() - t_start;
|
||||||
}
|
}
|
||||||
|
|
||||||
void llama_prep_adaptive_p_impl(struct llama_sampling * smpl,
|
void llama_prep_adaptive_p_impl(
|
||||||
|
struct llama_sampling * smpl,
|
||||||
llama_token_data_array * candidates,
|
llama_token_data_array * candidates,
|
||||||
struct llama_sampler_adaptive_p * adapt_p_ctx) {
|
struct llama_sampler_adaptive_p * adapt_p_ctx) {
|
||||||
constexpr float kDelta = 16.6f;
|
constexpr float kDelta = 30.0f; //16.6f;
|
||||||
auto t_start = ggml_time_us();
|
auto t_start = ggml_time_us();
|
||||||
if (!candidates->sorted) {
|
auto & orig_prob = adapt_p_ctx->orig_prob;
|
||||||
float max_logit = candidates->data[0].logit;
|
if (candidates->size != orig_prob.size() || candidates->sorted) {
|
||||||
for (int j = 1; j < int(candidates->size); ++j) {
|
LLAMA_LOG_ERROR("%s: this function must be called before any other sampler has been applied\n", __func__);
|
||||||
max_logit = std::max(max_logit, candidates->data[j].logit);
|
LLAMA_LOG_ERROR("%s: the sampler has been initialized with a vocabulary of %zu, but is being called with %zu candidates\n",
|
||||||
|
__func__, orig_prob.size(), candidates->size);
|
||||||
|
GGML_ABORT("Bad candidates in adaptive_p sampler");
|
||||||
}
|
}
|
||||||
float min_logit = max_logit - kDelta;
|
|
||||||
float cum_prob = 0.0f;
|
float max_logit = -INFINITY;
|
||||||
adapt_p_ctx->orig_prob_map.clear();
|
|
||||||
for (int j = 0; j < int(candidates->size); ++j) {
|
for (int j = 0; j < int(candidates->size); ++j) {
|
||||||
if (candidates->data[j].logit > min_logit) {
|
orig_prob[j] = candidates->data[j].logit;
|
||||||
float prob = expf(candidates->data[j].logit - max_logit);
|
max_logit = std::max(max_logit, orig_prob[j]);
|
||||||
cum_prob += prob;
|
|
||||||
adapt_p_ctx->orig_prob_map[candidates->data[j].id] = prob;
|
|
||||||
}
|
}
|
||||||
}
|
adapt_p_ctx->cum_orig_prob = iqk_exp_with_thresh(orig_prob.size(), orig_prob.data(), max_logit, max_logit - kDelta);
|
||||||
adapt_p_ctx->cum_orig_prob = cum_prob;
|
|
||||||
if (smpl) smpl->t_sample_us += ggml_time_us() - t_start;
|
if (smpl) smpl->t_sample_us += ggml_time_us() - t_start;
|
||||||
return;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
float max_logit = candidates->data[0].logit;
|
struct llama_sampler_adaptive_p * llama_init_adaptive_p_impl(int n_vocab,
|
||||||
float min_logit = max_logit - kDelta;
|
|
||||||
float cum_prob = 0.0f;
|
|
||||||
adapt_p_ctx->orig_prob_map.clear();
|
|
||||||
for (int j = 0; j < int(candidates->size); ++j) {
|
|
||||||
auto logit = candidates->data[j].logit;
|
|
||||||
if (logit <= min_logit) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
float prob = expf(logit - max_logit);
|
|
||||||
cum_prob += prob;
|
|
||||||
adapt_p_ctx->orig_prob_map[candidates->data[j].id] = prob;
|
|
||||||
}
|
|
||||||
adapt_p_ctx->cum_orig_prob = cum_prob;
|
|
||||||
if (smpl) smpl->t_sample_us += ggml_time_us() - t_start;
|
|
||||||
|
|
||||||
//if (!candidates->sorted) {
|
|
||||||
// std::sort(candidates->data, candidates->data + candidates->size,
|
|
||||||
// [](const llama_token_data & a, const llama_token_data & b) {
|
|
||||||
// return a.logit > b.logit;
|
|
||||||
// });
|
|
||||||
// candidates->sorted = true;
|
|
||||||
//}
|
|
||||||
//const float max_logit = candidates->data[0].logit;
|
|
||||||
|
|
||||||
//// decide how many tokens to track based on logit delta
|
|
||||||
//// i.e. do not track unlikely tokens
|
|
||||||
//auto iter = std::lower_bound(
|
|
||||||
// candidates->data,
|
|
||||||
// candidates->data + candidates->size,
|
|
||||||
// max_logit - kDelta, // delta
|
|
||||||
// [](const llama_token_data & data, const float delta) {
|
|
||||||
// return data.logit > delta;
|
|
||||||
// });
|
|
||||||
//const size_t n_track = std::distance(candidates->data, iter);
|
|
||||||
|
|
||||||
//// store orig_prob_map and cum_orig_prob to estimate original probability later
|
|
||||||
//float cum_prob = 0.0f;
|
|
||||||
//adapt_p_ctx->orig_prob_map.clear();
|
|
||||||
//for (size_t i = 0; i < n_track; ++i) {
|
|
||||||
// const float prob = expf(candidates->data[i].logit - max_logit);
|
|
||||||
// cum_prob += prob;
|
|
||||||
// adapt_p_ctx->orig_prob_map[candidates->data[i].id] = prob;
|
|
||||||
//}
|
|
||||||
//adapt_p_ctx->cum_orig_prob = cum_prob;
|
|
||||||
}
|
|
||||||
|
|
||||||
struct llama_sampler_adaptive_p * llama_init_adaptive_p_impl(
|
|
||||||
const float target,
|
const float target,
|
||||||
const float decay,
|
const float decay,
|
||||||
const uint32_t seed) {
|
const uint32_t seed) {
|
||||||
|
GGML_ASSERT(n_vocab > 0);
|
||||||
const float clamped_decay = std::clamp(decay, 0.0f, 0.99f);
|
const float clamped_decay = std::clamp(decay, 0.0f, 0.99f);
|
||||||
return new llama_sampler_adaptive_p {
|
auto result = new llama_sampler_adaptive_p {
|
||||||
/* .target = */ target,
|
/* .target = */ target,
|
||||||
/* .decay = */ clamped_decay,
|
/* .decay = */ clamped_decay,
|
||||||
/* .rng = */ std::mt19937(seed),
|
/* .rng = */ std::mt19937(seed),
|
||||||
/* .weighted_sum = */ target / (1.0f - clamped_decay),
|
/* .weighted_sum = */ target / (1.0f - clamped_decay),
|
||||||
/* .total_weight = */ 1.0f / (1.0f - clamped_decay),
|
/* .total_weight = */ 1.0f / (1.0f - clamped_decay),
|
||||||
/* .orig_logit_map = */ {},
|
/* .orig_prob = */ {},
|
||||||
/* .cum_orig_prob = */ 0.0f,
|
/* .cum_orig_prob = */ 0.0f,
|
||||||
/* .max_xform_logit = */ -INFINITY,
|
/* .max_xform_logit = */ -INFINITY,
|
||||||
/* .cum_probs = */ {},
|
/* .cum_probs = */ {},
|
||||||
};
|
};
|
||||||
|
result->orig_prob.resize(n_vocab);
|
||||||
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
// grammar
|
// grammar
|
||||||
|
|||||||
@@ -73,7 +73,7 @@ struct llama_sampler_adaptive_p {
|
|||||||
float total_weight; // sum(decay^i), converges to 1/(1-decay)
|
float total_weight; // sum(decay^i), converges to 1/(1-decay)
|
||||||
|
|
||||||
// first referenced in prep
|
// first referenced in prep
|
||||||
std::unordered_map<llama_token, float> orig_prob_map; // probabilities before sampler_queue
|
std::vector<float> orig_prob; // for storing the original probabilities
|
||||||
float cum_orig_prob; // for normalizing orig_prob in sample_token
|
float cum_orig_prob; // for normalizing orig_prob in sample_token
|
||||||
|
|
||||||
// first referenced in sample
|
// first referenced in sample
|
||||||
@@ -83,7 +83,7 @@ struct llama_sampler_adaptive_p {
|
|||||||
std::vector<float> cum_probs; // cumulative probability distribution
|
std::vector<float> cum_probs; // cumulative probability distribution
|
||||||
};
|
};
|
||||||
|
|
||||||
struct llama_sampler_adaptive_p * llama_init_adaptive_p_impl(
|
struct llama_sampler_adaptive_p * llama_init_adaptive_p_impl(int n_vocab,
|
||||||
const float target,
|
const float target,
|
||||||
const float decay,
|
const float decay,
|
||||||
const uint32_t seed);
|
const uint32_t seed);
|
||||||
|
|||||||
@@ -7687,10 +7687,9 @@ void llama_sample_dry([[maybe_unused]] struct llama_context* ctx, struct llama_s
|
|||||||
llama_sampler_dry_apply(smpl, candidates_p);
|
llama_sampler_dry_apply(smpl, candidates_p);
|
||||||
}
|
}
|
||||||
|
|
||||||
void llama_sample_adaptive_p(
|
void llama_sample_adaptive_p(llama_context * ctx,
|
||||||
[[maybe_unused]] struct llama_context * ctx,
|
|
||||||
llama_token_data_array * candidates,
|
llama_token_data_array * candidates,
|
||||||
struct llama_sampler_adaptive_p * adapt_p_ctx) {
|
llama_sampler_adaptive_p * adapt_p_ctx) {
|
||||||
llama_sample_adaptive_p_impl(&ctx->sampling, candidates, adapt_p_ctx);
|
llama_sample_adaptive_p_impl(&ctx->sampling, candidates, adapt_p_ctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -7797,8 +7796,8 @@ void llama_sampler_dry_accept(struct llama_sampler_dry* smpl, llama_token token)
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
struct llama_sampler_adaptive_p * llama_init_adaptive_p(const float target, const float decay, const uint32_t seed) {
|
struct llama_sampler_adaptive_p * llama_init_adaptive_p(int n_vocab, const float target, const float decay, const uint32_t seed) {
|
||||||
return llama_init_adaptive_p_impl(target, decay, seed);
|
return llama_init_adaptive_p_impl(n_vocab, target, decay, seed);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user