diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp
index 9d3134e4..6ab13132 100644
--- a/src/llama-sampling.cpp
+++ b/src/llama-sampling.cpp
@@ -1038,8 +1038,7 @@ struct llama_sampler_dry* llama_sampler_init_dry_impl(const struct llama_vocab&
 llama_token llama_sample_token_adaptive_p_impl(
     struct llama_sampling * smpl,
     llama_token_data_array * candidates,
-    struct llama_sampler_adaptive_p * adapt_p_ctx)
-{
+    struct llama_sampler_adaptive_p * adapt_p_ctx) {
     GGML_ASSERT(candidates->size > 0);
 
     const int64_t t_start_sample_us = ggml_time_us();
@@ -1065,15 +1064,13 @@ llama_token llama_sample_token_adaptive_p_impl(
     smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
     smpl->n_sample++;
 
-    float update_prob = candidates->data[idx].p; // not ideal
-    if (ctx->orig_prob_map.contains(id)) {
-        // selected token id is among tracked ids
-        update_prob = ctx->orig_prob_map[id] / ctx->cum_orig_prob;
+    // update history with the original probability of the selected token
+    // (skipped when the token is not among the tracked ids)
+    if (auto it = ctx->orig_prob_map.find(id); it != ctx->orig_prob_map.end()) {
+        float update_prob = it->second / ctx->cum_orig_prob;
+        ctx->weighted_sum = ctx->decay * ctx->weighted_sum + update_prob;
+        ctx->total_weight = ctx->decay * ctx->total_weight + 1.0f;
     }
 
-    // update history with original probability of selected token
-    ctx->weighted_sum = ctx->decay * ctx->weighted_sum + update_prob;
-    ctx->total_weight = ctx->decay * ctx->total_weight + 1.0f;
-
     return id;
 }
@@ -1130,37 +1127,45 @@ void llama_sample_adaptive_p_impl(llama_token_data_array * candidates, struct ll
 
 void llama_prep_adaptive_p_impl(
     llama_token_data_array * candidates,
-    struct llama_sampler_adaptive_p * adapt_p_ctx)
-{
+    struct llama_sampler_adaptive_p * adapt_p_ctx) {
+    // track only tokens within kDelta of the max logit,
+    // i.e. do not track unlikely tokens
+    constexpr float kDelta = 16.6f;
     if (!candidates->sorted) {
-        std::sort(candidates->data, candidates->data + candidates->size,
-            [](const llama_token_data & a, const llama_token_data & b) {
-                return a.logit > b.logit;
-            });
-        candidates->sorted = true;
+        float max_logit = candidates->data[0].logit;
+        for (int j = 1; j < int(candidates->size); ++j) {
+            max_logit = std::max(max_logit, candidates->data[j].logit);
+        }
+        float min_logit = max_logit - kDelta;
+        float cum_prob = 0.0f;
+        adapt_p_ctx->orig_prob_map.clear();
+        for (int j = 0; j < int(candidates->size); ++j) {
+            if (candidates->data[j].logit > min_logit) {
+                float prob = expf(candidates->data[j].logit - max_logit);
+                cum_prob += prob;
+                adapt_p_ctx->orig_prob_map[candidates->data[j].id] = prob;
+            }
+        }
+        adapt_p_ctx->cum_orig_prob = cum_prob;
+        return;
     }
-    const float max_logit = candidates->data[0].logit;
 
-    // decide how many tokens to track based on logit delta
-    // i.e. do not track unlikely tokens
-    auto iter = std::lower_bound(
-        candidates->data,
-        candidates->data + candidates->size,
-        max_logit - 16.6f, // delta
-        [](const llama_token_data & data, const float delta) {
-            return data.logit > delta;
-        });
-    const size_t n_track = std::distance(candidates->data, iter);
-
-    // store orig_prob_map and cum_orig_prob to estimate original probability later
+    // sorted candidates: logits are descending, so stop at the first
+    // token below the delta threshold
+    float max_logit = candidates->data[0].logit;
+    float min_logit = max_logit - kDelta;
     float cum_prob = 0.0f;
     adapt_p_ctx->orig_prob_map.clear();
-    for (size_t i = 0; i < n_track; ++i) {
-        const float prob = expf(candidates->data[i].logit - max_logit);
+    for (int j = 0; j < int(candidates->size); ++j) {
+        auto logit = candidates->data[j].logit;
+        if (logit <= min_logit) {
+            break;
+        }
+        float prob = expf(logit - max_logit);
         cum_prob += prob;
-        adapt_p_ctx->orig_prob_map[candidates->data[i].id] = prob;
+        adapt_p_ctx->orig_prob_map[candidates->data[j].id] = prob;
     }
     adapt_p_ctx->cum_orig_prob = cum_prob;
 }
 
 struct llama_sampler_adaptive_p * llama_init_adaptive_p_impl(
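For reference, a minimal standalone sketch of the logic this patch installs, mirroring llama_prep_adaptive_p_impl and the history update on a plain vector of logits rather than a llama_token_data_array. The decay value, the state struct, and the names prep/accept are illustrative assumptions, not llama.cpp API.

    // Standalone sketch (not llama.cpp API): mirrors the tracking and
    // history-update logic from the patch above.
    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <unordered_map>
    #include <vector>

    struct adaptive_p_state {
        std::unordered_map<int, float> orig_prob_map; // id -> exp(logit - max_logit)
        float cum_orig_prob = 0.0f; // sum of tracked unnormalized probs
        float weighted_sum  = 0.0f; // decayed sum of selected tokens' original probs
        float total_weight  = 0.0f; // decayed count of updates
        float decay         = 0.99f; // assumed value; not shown in this diff
    };

    // Analogue of llama_prep_adaptive_p_impl: track only tokens whose logit
    // is within kDelta of the maximum, storing unnormalized softmax terms.
    void prep(adaptive_p_state & st, const std::vector<float> & logits) {
        constexpr float kDelta = 16.6f;
        float max_logit = logits[0];
        for (float l : logits) {
            max_logit = std::max(max_logit, l);
        }
        st.orig_prob_map.clear();
        st.cum_orig_prob = 0.0f;
        for (int id = 0; id < int(logits.size()); ++id) {
            if (logits[id] > max_logit - kDelta) {
                float prob = std::exp(logits[id] - max_logit);
                st.cum_orig_prob += prob;
                st.orig_prob_map[id] = prob;
            }
        }
    }

    // Analogue of the history update in llama_sample_token_adaptive_p_impl:
    // fold the selected token's original probability into a decayed average.
    // As in the patch, an untracked token is skipped entirely.
    void accept(adaptive_p_state & st, int id) {
        auto it = st.orig_prob_map.find(id);
        if (it == st.orig_prob_map.end()) {
            return;
        }
        st.weighted_sum = st.decay * st.weighted_sum + it->second / st.cum_orig_prob;
        st.total_weight = st.decay * st.total_weight + 1.0f;
    }

    int main() {
        adaptive_p_state st;
        prep(st, {2.0f, 1.0f, -30.0f}); // -30 is more than 16.6 below max: not tracked
        accept(st, 0);                  // suppose token 0 was sampled
        std::printf("tracked=%zu avg_orig_p=%.3f\n",
                    st.orig_prob_map.size(), st.weighted_sum / st.total_weight);
        return 0;
    }

The running estimate of how confident the model has recently been is then weighted_sum / total_weight, which is what makes the sampler "adaptive".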