Fix logprobs (#787)

This commit is mostly a cherry-pick of ggml-org/llama.cpp#10783, plus an
optimization to do a partial sort when sorting the logits.

That mainline PR and its follow-ups were partially cherry-picked by #723,
but the result wasn't really in a working state yet.

A couple of additional changes:
* Include timing information in the response, which has (unintentionally?)
  been done in mainline since ggml-org/llama.cpp#10643.
* Also return the actual logprobs for accepted draft tokens. This is
  still a TODO in mainline [1].

Note that there is a TG performance penalty for returning the logprobs, as
we need to sort the logits. With a partial sort, the penalty is quite
small (see the sketch after the numbers below). Here are some numbers I
got using the same prompt:

This PR with partial sort:
* no draft, no logprobs: 12.87 tok/s
* no draft, with logprobs: 12.61 tok/s (2.0% drop)
* with draft, no logprobs: 36.74 tok/s
* with draft, with logprobs: 36.12 tok/s (1.7% drop)

If we instead cherry-pick the full sort from the mainline PR:
* no draft, no logprobs: 12.81 tok/s
* no draft, with logprobs: 12.02 tok/s (6.2% drop)
* with draft, no logprobs: 36.59 tok/s
* with draft, with logprobs: 29.08 tok/s (20.5% drop)
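
For illustration (not part of this commit), here is a minimal standalone
sketch of the partial-sort idea: std::partial_sort only orders the top-k
entries, which is roughly O(n log k) work instead of the O(n log n) of a
full std::sort over the vocabulary. The vocabulary size and top-k below
are made-up example values.

#include <algorithm>
#include <functional>
#include <random>
#include <utility>
#include <vector>

int main() {
    const int n_vocab = 128000; // example vocabulary size
    const int n_top   = 20;     // number of logprobs actually returned

    // fake logits paired with their token ids
    std::vector<std::pair<float, int>> logits(n_vocab);
    std::mt19937 rng(42);
    std::normal_distribution<float> dist(0.0f, 4.0f);
    for (int i = 0; i < n_vocab; ++i) {
        logits[i] = {dist(rng), i};
    }

    // full sort: orders all n_vocab entries
    //std::sort(logits.begin(), logits.end(), std::greater<std::pair<float, int>>{});

    // partial sort: only the first n_top entries end up ordered, which is all
    // the response needs; the rest only matter for the softmax denominator
    std::partial_sort(logits.begin(), logits.begin() + n_top, logits.end(),
                      std::greater<std::pair<float, int>>{});
    return 0;
}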

[1] https://github.com/ggml-org/llama.cpp/blob/b6548/tools/server/server.cpp#L4019

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Author: Yap Sok Ann
Date: 2025-09-25 20:43:30 +07:00
Committed by: GitHub
Parent: 8e497e704e
Commit: 6bb76b142d
2 changed files with 111 additions and 72 deletions


@@ -406,7 +406,6 @@ static json probs_vector_to_json(const llama_context * ctx, const std::vector<co
return out;
}
//
// OAI utils
//
@@ -616,13 +615,12 @@ static json oaicompat_chat_params_parse(
     // Handle "logprobs" field
     // TODO: The response format of this option is not yet OAI-compatible, but seems like no one really using it; We may need to fix it in the future
-    if (body.contains("logprobs")) {
+    if (json_value(body, "logprobs", false)) {
         if (has_tools && stream) {
             throw std::runtime_error("logprobs is not supported with tools + stream");
         }
         llama_params["n_probs"] = json_value(body, "top_logprobs", 20);
-    }
-    else if (body.contains("top_logprobs")) {
+    } else if (body.contains("top_logprobs") && !body.at("top_logprobs").is_null()) {
         throw std::runtime_error("top_logprobs requires logprobs to be set to true");
     }
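
Not part of the diff, but for reference, a small self-contained sketch of
how request bodies map to n_probs after this change. The json_value()
below is a simplified stand-in for the server's helper of the same name
(a missing or null field falls back to the default), and the request
bodies are made-up examples.

#include <nlohmann/json.hpp>
#include <iostream>
#include <string>

using json = nlohmann::json;

// simplified stand-in for the server's json_value() helper:
// a missing or null field falls back to the default value
template <typename T>
static T json_value(const json & body, const std::string & key, const T & def) {
    return body.contains(key) && !body.at(key).is_null() ? body.at(key).get<T>() : def;
}

int main() {
    // "logprobs" now has to be an explicit true; a body that merely contains
    // the key (e.g. "logprobs": null) no longer enables it
    for (const json & body : {json{{"logprobs", true}, {"top_logprobs", 5}},
                              json{{"logprobs", true}},
                              json{{"top_logprobs", 5}}}) {
        if (json_value(body, "logprobs", false)) {
            std::cout << body.dump() << " -> n_probs = " << json_value(body, "top_logprobs", 20) << "\n";
        } else if (body.contains("top_logprobs") && !body.at("top_logprobs").is_null()) {
            std::cout << body.dump() << " -> error: top_logprobs requires logprobs\n";
        }
    }
    return 0;
}
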
@@ -715,3 +713,43 @@ static json format_error_response(const std::string & message, const enum error_
{"type", type_str},
};
}
struct token_probabilities {
    float sampled_token_p;
    std::vector<llama_token_data> cur;
};

static token_probabilities get_token_probabilities(llama_context * ctx, int idx, llama_token sampled_token_id, int n_sorted) {
    const auto * logits = llama_get_logits_ith(ctx, idx);
    const int n_vocab = llama_n_vocab(llama_get_model(ctx));

    n_sorted = std::min(n_sorted, n_vocab);

    // pair (logit, token_id) so the token ids survive the sort
    std::vector<std::pair<float, llama_token>> sorted(n_vocab);
    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
        sorted[token_id] = {logits[token_id], token_id};
    }

    // only the top n_sorted entries need to be ordered; the rest only contribute to the softmax sum
    std::partial_sort(sorted.begin(), sorted.begin() + n_sorted, sorted.end(), std::greater<std::pair<float, llama_token>>{});

    const float max_l = sorted.front().first;

    float cum_sum = 0.0f;
    float sampled_token_p = 0.0f;
    bool  sampled_token_found = false;

    std::vector<llama_token_data> cur(n_sorted);
    for (int i = 0; i < n_vocab; ++i) {
        const float p = expf(sorted[i].first - max_l);
        cum_sum += p;
        if (i < n_sorted) {
            cur[i] = {sorted[i].second, sorted[i].first, p};
        }
        // the sampled token (e.g. an accepted draft token) may not be among the top n_sorted
        if (!sampled_token_found && sorted[i].second == sampled_token_id) {
            sampled_token_p = p;
            sampled_token_found = true;
        }
    }

    // normalize; cum_sum already covers the full vocabulary from the loop above
    const float inv_cum_sum = 1.0f/cum_sum;
    for (int i = 0; i < n_sorted; ++i) {
        cur[i].p *= inv_cum_sum;
    }
    sampled_token_p *= inv_cum_sum;

    return {sampled_token_p, cur};
}
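
Also not part of the diff: a tiny standalone check (example logit values,
no llama.cpp types) of the math the function above implements, i.e.
normalizing the top-n_sorted probabilities by a softmax sum taken over the
whole vocabulary, then reporting p and log(p) for the top entries.

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <functional>
#include <utility>
#include <vector>

int main() {
    const std::vector<float> logits = {1.5f, -0.3f, 4.2f, 0.7f, 2.9f, -1.1f}; // example values
    const int n_sorted = 3;

    std::vector<std::pair<float, int>> sorted(logits.size());
    for (size_t i = 0; i < logits.size(); ++i) {
        sorted[i] = {logits[i], (int) i};
    }
    std::partial_sort(sorted.begin(), sorted.begin() + n_sorted, sorted.end(),
                      std::greater<std::pair<float, int>>{});

    // the softmax denominator runs over the *whole* vocabulary, not just the top n_sorted
    const float max_l = sorted.front().first;
    float cum_sum = 0.0f;
    for (const auto & li : sorted) {
        cum_sum += expf(li.first - max_l);
    }

    for (int i = 0; i < n_sorted; ++i) {
        const float p = expf(sorted[i].first - max_l) / cum_sum;
        printf("token %d: p = %.4f, logprob = %.4f\n", sorted[i].second, p, logf(p));
    }
    return 0;
}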