This commit is contained in:
Kawrakow
2026-01-19 13:03:12 +00:00
parent bd2434945d
commit c9cd616f84
3 changed files with 70 additions and 11 deletions

View File

@@ -16,6 +16,7 @@
#include <algorithm>
#include <cmath>
#include <cstring>
//#include <thread>
#ifdef __ARM_NEON
#include <arm_neon.h>
@@ -503,3 +504,49 @@ void iqk_hadamard(struct ggml_tensor * dst, int ith, int nth) {
fast_ht(nh, y);
}
}
namespace {
float iqk_exp_with_thresh_impl(int n, float * logits, float max, float min) {
float sum = 0;
#ifdef __AVX2__
auto vmax = _mm256_set1_ps(max);
auto vmin = _mm256_set1_ps(min);
auto vsum = _mm256_setzero_ps();
for (int j = 0; j < n/8; ++j) {
auto x = _mm256_loadu_ps(logits);
auto mask = _mm256_cmp_ps(x, vmin, _CMP_GE_OQ);
auto exp_x = v_expf(_mm256_sub_ps(x, vmax));
exp_x = _mm256_and_ps(exp_x, mask);
vsum = _mm256_add_ps(vsum, exp_x);
_mm256_storeu_ps(logits, exp_x);
logits += 8;
}
sum = hsum_float_8(vsum);
for (int j = 0; j < n - 8*(n/8); ++j) {
float p = logits[j] > min ? expf(logits[j] - max) : 0;
sum += p;
logits[j] = p;
}
#else
for (int j = 0; j < n; ++j) {
float p = logits[j] > min ? expf(logits[j] - max) : 0;
sum += p;
logits[j] = p;
}
#endif
return sum;
}
}
float iqk_exp_with_thresh(int n, float * logits, float max, float min) {
return iqk_exp_with_thresh_impl(n, logits, max, min);
//if (n < (1 << 16)) return iqk_exp_with_thresh_impl(n, logits, max, min);
//std::array<float, 2> result;
//auto compute = [logits, max, min, &result] (int first, int last, int ith) {
// result[ith] = iqk_exp_with_thresh_impl(last - first, logits + first, max, min);
//};
//auto t = std::thread(compute, 0, n/2, 0);
//compute(n/2, n, 1);
//t.join();
//return result[0] + result[1];
}

View File

@@ -30,6 +30,8 @@ void iqk_mul_multi_add(struct ggml_tensor * dst, int ith, int nth);
void iqk_hadamard(struct ggml_tensor * dst, int ith, int nth);
float iqk_exp_with_thresh(int n, float * logits, float max, float min);
#ifdef __cplusplus
}
#endif

View File

@@ -2,6 +2,8 @@
#include "llama-vocab.h"
#include "llama-grammar.h"
#include "iqk/iqk_cpu_ops.h"
#include <algorithm>
#include <cstring>
#include <ctime>
@@ -1131,7 +1133,7 @@ void llama_prep_adaptive_p_impl(
struct llama_sampling * smpl,
llama_token_data_array * candidates,
struct llama_sampler_adaptive_p * adapt_p_ctx) {
constexpr float kDelta = 16.6f;
constexpr float kDelta = 30.0f; //16.6f;
auto t_start = ggml_time_us();
auto & orig_prob = adapt_p_ctx->orig_prob;
if (candidates->size != orig_prob.size() || candidates->sorted) {
@@ -1141,18 +1143,26 @@ void llama_prep_adaptive_p_impl(
GGML_ABORT("Bad candidates in adaptive_p sampler");
}
float max_logit = candidates->data[0].logit;
for (int j = 1; j < int(candidates->size); ++j) {
max_logit = std::max(max_logit, candidates->data[j].logit);
}
float min_logit = max_logit - kDelta;
float cum_prob = 0.0f;
float max_logit = -INFINITY;
for (int j = 0; j < int(candidates->size); ++j) {
float prob = candidates->data[j].logit > min_logit ? expf(candidates->data[j].logit - max_logit) : 0.0f;
cum_prob += prob;
orig_prob[j] = prob;
orig_prob[j] = candidates->data[j].logit;
max_logit = std::max(max_logit, orig_prob[j]);
}
adapt_p_ctx->cum_orig_prob = cum_prob;
adapt_p_ctx->cum_orig_prob = iqk_exp_with_thresh(orig_prob.size(), orig_prob.data(), max_logit, max_logit - kDelta);
//float max_logit = candidates->data[0].logit;
//for (int j = 1; j < int(candidates->size); ++j) {
// max_logit = std::max(max_logit, candidates->data[j].logit);
//}
//float min_logit = max_logit - kDelta;
//float cum_prob = 0.0f;
//for (int j = 0; j < int(candidates->size); ++j) {
// float prob = candidates->data[j].logit > min_logit ? expf(candidates->data[j].logit - max_logit) : 0.0f;
// cum_prob += prob;
// orig_prob[j] = prob;
//}
//adapt_p_ctx->cum_orig_prob = cum_prob;
if (smpl) smpl->t_sample_us += ggml_time_us() - t_start;
}