diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 17f95f0a..046d0d0e 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -1064,22 +1064,12 @@ llama_token llama_sample_token_adaptive_p_impl( smpl->t_sample_us += ggml_time_us() - t_start_sample_us; smpl->n_sample++; - if (auto it = ctx->orig_prob_map.find(id); it != ctx->orig_prob_map.end()) { - float update_prob = it->second / ctx->cum_orig_prob; + GGML_ASSERT(id < int(ctx->orig_prob.size())); + if (auto update_prob = ctx->orig_prob[id]; update_prob > 0) { ctx->weighted_sum = ctx->decay * ctx->weighted_sum + update_prob; ctx->total_weight = ctx->decay * ctx->total_weight + 1.0f; } - //float update_prob = candidates->data[idx].p; // not ideal - //if (ctx->orig_prob_map.contains(id)) { - // // selected token id is among tracked ids - // update_prob = ctx->orig_prob_map[id] / ctx->cum_orig_prob; - //} - - //// update history with original probability of selected token - //ctx->weighted_sum = ctx->decay * ctx->weighted_sum + update_prob; - //ctx->total_weight = ctx->decay * ctx->total_weight + 1.0f; - return id; } @@ -1136,6 +1126,7 @@ void llama_prep_adaptive_p_impl( llama_token_data_array * candidates, struct llama_sampler_adaptive_p * adapt_p_ctx) { constexpr float kDelta = 16.6f; + auto & orig_prob = adapt_p_ctx->orig_prob; if (!candidates->sorted) { float max_logit = candidates->data[0].logit; for (int j = 1; j < int(candidates->size); ++j) { @@ -1143,22 +1134,37 @@ void llama_prep_adaptive_p_impl( } float min_logit = max_logit - kDelta; float cum_prob = 0.0f; - adapt_p_ctx->orig_prob_map.clear(); + if (orig_prob.size() != candidates->size) { + orig_prob.resize(candidates->size); + } for (int j = 0; j < int(candidates->size); ++j) { if (candidates->data[j].logit > min_logit) { float prob = expf(candidates->data[j].logit - max_logit); cum_prob += prob; - adapt_p_ctx->orig_prob_map[candidates->data[j].id] = prob; + orig_prob[j] = prob; + } else { + orig_prob[j] = 0; } } 
adapt_p_ctx->cum_orig_prob = cum_prob; return; } + // Hopefully we never end here + // But if we do, let's issue some warnings + if (adapt_p_ctx->n_warn < 10) { + LLAMA_LOG_WARN("%s: this function should be called before any other sampler is applied\n", __func__); + ++adapt_p_ctx->n_warn; + } + + llama_token max_id = 0; + for (int j = 0; j < int(candidates->size); ++j) max_id = std::max(max_id, candidates->data[j].id); + if (max_id + 1 != int(orig_prob.size())) orig_prob.resize(max_id + 1); + std::memset(orig_prob.data(), 0, orig_prob.size()*sizeof(float)); + float max_logit = candidates->data[0].logit; float min_logit = max_logit - kDelta; float cum_prob = 0.0f; - adapt_p_ctx->orig_prob_map.clear(); for (int j = 0; j < int(candidates->size); ++j) { auto logit = candidates->data[j].logit; if (logit <= min_logit) { @@ -1166,39 +1172,10 @@ void llama_prep_adaptive_p_impl( } float prob = expf(logit - max_logit); cum_prob += prob; - adapt_p_ctx->orig_prob_map[candidates->data[j].id] = prob; + orig_prob[candidates->data[j].id] = prob; } adapt_p_ctx->cum_orig_prob = cum_prob; - //if (!candidates->sorted) { - // std::sort(candidates->data, candidates->data + candidates->size, - // [](const llama_token_data & a, const llama_token_data & b) { - // return a.logit > b.logit; - // }); - // candidates->sorted = true; - //} - //const float max_logit = candidates->data[0].logit; - - //// decide how many tokens to track based on logit delta - //// i.e. 
do not track unlikely tokens - //auto iter = std::lower_bound( - // candidates->data, - // candidates->data + candidates->size, - // max_logit - kDelta, // delta - // [](const llama_token_data & data, const float delta) { - // return data.logit > delta; - // }); - //const size_t n_track = std::distance(candidates->data, iter); - - //// store orig_prob_map and cum_orig_prob to estimate original probability later - //float cum_prob = 0.0f; - //adapt_p_ctx->orig_prob_map.clear(); - //for (size_t i = 0; i < n_track; ++i) { - // const float prob = expf(candidates->data[i].logit - max_logit); - // cum_prob += prob; - // adapt_p_ctx->orig_prob_map[candidates->data[i].id] = prob; - //} - //adapt_p_ctx->cum_orig_prob = cum_prob; } struct llama_sampler_adaptive_p * llama_init_adaptive_p_impl( @@ -1212,8 +1189,9 @@ struct llama_sampler_adaptive_p * llama_init_adaptive_p_impl( /* .rng = */ std::mt19937(seed), /* .weighted_sum = */ target / (1.0f - clamped_decay), /* .total_weight = */ 1.0f / (1.0f - clamped_decay), - /* .orig_logit_map = */ {}, + /* .orig_prob = */ {}, /* .cum_orig_prob = */ 0.0f, + /* .n_warn = */ 0, /* .max_xform_logit = */ -INFINITY, /* .cum_probs = */ {}, }; diff --git a/src/llama-sampling.h b/src/llama-sampling.h index 55b4371c..de1cce27 100644 --- a/src/llama-sampling.h +++ b/src/llama-sampling.h @@ -73,8 +73,9 @@ struct llama_sampler_adaptive_p { float total_weight; // sum(decay^i), converges to 1/(1-decay) // first referenced in prep - std::unordered_map<llama_token, float> orig_prob_map; // probabilities before sampler_queue + std::vector<float> orig_prob; // for storing the original probabilities float cum_orig_prob; // for normalizing orig_prob in sample_token + int n_warn = 0; // for warnings // first referenced in sample float max_xform_logit; // maximum logit found during transform