diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index b6e929c9..5e26eb20 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -702,23 +702,39 @@ llama_token llama_sample_token_with_rng_impl(struct llama_sampling * smpl, llama GGML_ASSERT(smpl); const int64_t t_start_sample_us = ggml_time_us(); - llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates); - std::vector<float> probs; - probs.reserve(candidates->size); - for (size_t i = 0; i < candidates->size; ++i) { - probs.push_back(candidates->data[i].p); + if (candidates->size < 2) { + smpl->t_sample_us += ggml_time_us() - t_start_sample_us; + smpl->n_sample++; + return candidates->data[0].id; } - std::discrete_distribution<> dist(probs.begin(), probs.end()); - int idx = dist(rng); + std::vector<float> probs(candidates->size); + probs[0] = candidates->data[0].logit; + float max = probs[0]; + for (int j = 1; j < candidates->size; ++j) { + probs[j] = candidates->data[j].logit; + max = std::max(max, probs[j]); + } - llama_token result = candidates->data[idx].id; + float sump = 0; + for (int j = 0; j < candidates->size; ++j) { + float p = expf(probs[j] - max); + sump += p; + probs[j] = sump; + } + probs.back() += sump; + + auto p = sump * rng() / rng.max(); + auto iter = std::upper_bound(probs.begin(), probs.end(), p); + GGML_ASSERT(iter != probs.end()); + auto idx = std::distance(probs.begin(), iter); + auto id = candidates->data[idx].id; smpl->t_sample_us += ggml_time_us() - t_start_sample_us; smpl->n_sample++; - return result; + return id; } llama_token llama_sample_token_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) {