diff --git a/ggml/src/iqk/iqk_cpu_ops.cpp b/ggml/src/iqk/iqk_cpu_ops.cpp
index 9a20e0c8..4581a1dd 100644
--- a/ggml/src/iqk/iqk_cpu_ops.cpp
+++ b/ggml/src/iqk/iqk_cpu_ops.cpp
@@ -16,6 +16,7 @@
 #include
 #include
 #include
+//#include
 
 #ifdef __ARM_NEON
 #include
@@ -503,3 +504,57 @@ void iqk_hadamard(struct ggml_tensor * dst, int ith, int nth) {
         fast_ht(nh, y);
     }
 }
+
+namespace {
+// Replaces each of the n values in logits with expf(logits[j] - max) when
+// logits[j] > min and with 0 otherwise, and returns the sum of the values
+// written. With AVX2, groups of 8 floats go through v_expf and the remaining
+// n - 8*(n/8) elements take the scalar path.
+float iqk_exp_with_thresh_impl(int n, float * logits, float max, float min) {
+    float sum = 0;
+#ifdef __AVX2__
+    auto vmax = _mm256_set1_ps(max);
+    auto vmin = _mm256_set1_ps(min);
+    auto vsum = _mm256_setzero_ps();
+    for (int j = 0; j < n/8; ++j) {
+        auto x = _mm256_loadu_ps(logits);
+        // Strict '>' so that the vector lanes agree with the scalar code below.
+        auto mask = _mm256_cmp_ps(x, vmin, _CMP_GT_OQ);
+        auto exp_x = v_expf(_mm256_sub_ps(x, vmax));
+        exp_x = _mm256_and_ps(exp_x, mask);
+        vsum = _mm256_add_ps(vsum, exp_x);
+        _mm256_storeu_ps(logits, exp_x);
+        logits += 8;
+    }
+    sum = hsum_float_8(vsum);
+    // logits now points just past the last full group of 8.
+    for (int j = 0; j < n - 8*(n/8); ++j) {
+        float p = logits[j] > min ? expf(logits[j] - max) : 0;
+        sum += p;
+        logits[j] = p;
+    }
+#else
+    for (int j = 0; j < n; ++j) {
+        float p = logits[j] > min ? expf(logits[j] - max) : 0;
+        sum += p;
+        logits[j] = p;
+    }
+#endif
+    return sum;
+}
+}
+
+// Public entry point; currently forwards to the single-threaded implementation.
+// The commented-out code below is a two-thread variant that is not compiled.
+float iqk_exp_with_thresh(int n, float * logits, float max, float min) {
+    return iqk_exp_with_thresh_impl(n, logits, max, min);
+    //if (n < (1 << 16)) return iqk_exp_with_thresh_impl(n, logits, max, min);
+    //std::array result;
+    //auto compute = [logits, max, min, &result] (int first, int last, int ith) {
+    //    result[ith] = iqk_exp_with_thresh_impl(last - first, logits + first, max, min);
+    //};
+    //auto t = std::thread(compute, 0, n/2, 0);
+    //compute(n/2, n, 1);
+    //t.join();
+    //return result[0] + result[1];
+}
diff --git a/ggml/src/iqk/iqk_cpu_ops.h b/ggml/src/iqk/iqk_cpu_ops.h
index 8656d7f1..267c3e85 100644
--- a/ggml/src/iqk/iqk_cpu_ops.h
+++ b/ggml/src/iqk/iqk_cpu_ops.h
@@ -30,6 +30,8 @@ void iqk_mul_multi_add(struct ggml_tensor * dst, int ith, int nth);
 
 void iqk_hadamard(struct ggml_tensor * dst, int ith, int nth);
 
+float iqk_exp_with_thresh(int n, float * logits, float max, float min);
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp
index b598a2f1..deb6f273 100644
--- a/src/llama-sampling.cpp
+++ b/src/llama-sampling.cpp
@@ -2,6 +2,8 @@
 #include "llama-vocab.h"
 #include "llama-grammar.h"
 
+#include "iqk/iqk_cpu_ops.h"
+
 #include
 #include
 #include
@@ -1131,7 +1133,7 @@ void llama_prep_adaptive_p_impl(
         struct llama_sampling * smpl,
         llama_token_data_array * candidates,
         struct llama_sampler_adaptive_p * adapt_p_ctx) {
-    constexpr float kDelta = 16.6f;
+    constexpr float kDelta = 30.0f; //16.6f;
     auto t_start = ggml_time_us();
     auto & orig_prob = adapt_p_ctx->orig_prob;
     if (candidates->size != orig_prob.size() || candidates->sorted) {
@@ -1141,18 +1143,28 @@ void llama_prep_adaptive_p_impl(
         GGML_ABORT("Bad candidates in adaptive_p sampler");
     }
 
-    float max_logit = candidates->data[0].logit;
-    for (int j = 1; j < int(candidates->size); ++j) {
-        max_logit = std::max(max_logit, candidates->data[j].logit);
-    }
-    float min_logit = max_logit - kDelta;
-    float cum_prob = 0.0f;
+    // Stage the raw logits in orig_prob while tracking their maximum;
+    // iqk_exp_with_thresh() then overwrites them in place and returns the sum.
+    float max_logit = -INFINITY;
     for (int j = 0; j < int(candidates->size); ++j) {
-        float prob = candidates->data[j].logit > min_logit ? expf(candidates->data[j].logit - max_logit) : 0.0f;
-        cum_prob += prob;
-        orig_prob[j] = prob;
+        orig_prob[j] = candidates->data[j].logit;
+        max_logit = std::max(max_logit, orig_prob[j]);
     }
-    adapt_p_ctx->cum_orig_prob = cum_prob;
+    adapt_p_ctx->cum_orig_prob = iqk_exp_with_thresh(orig_prob.size(), orig_prob.data(), max_logit, max_logit - kDelta);
+
+    //float max_logit = candidates->data[0].logit;
+    //for (int j = 1; j < int(candidates->size); ++j) {
+    //    max_logit = std::max(max_logit, candidates->data[j].logit);
+    //}
+    //float min_logit = max_logit - kDelta;
+    //float cum_prob = 0.0f;
+    //for (int j = 0; j < int(candidates->size); ++j) {
+    //    float prob = candidates->data[j].logit > min_logit ? expf(candidates->data[j].logit - max_logit) : 0.0f;
+    //    cum_prob += prob;
+    //    orig_prob[j] = prob;
+    //}
+    //adapt_p_ctx->cum_orig_prob = cum_prob;
+
     if (smpl) smpl->t_sample_us += ggml_time_us() - t_start;
 }