mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-29 10:51:51 +00:00
Fix logprobs (#787)
This commit is mostly a cherry-pick of ggml-org/llama.cpp#10783, plus optimization to do partial sort when sorting the logits. That mainline PR and friends were partially cherry-picked by #723, but wasn't really in a working state yet. A couple of additional changes: * Include timing information in response, which was (unintentionally?) done in mainline since ggml-org/llama.cpp#10643. * Also return the actual logprobs for accepted draft tokens. This is still a TODO in mainline [1]. Note that there is a TG performance penalty to return the logprobs, as we need to sort the logits. By doing partial sort, the penalty is quite small. Here are some numbers I got using the same prompt: This PR with partial sort: * no draft, no logprobs: 12.87 tok/s * no draft, with logprobs: 12.61 tok/s (2.0% drop) * with draft, no logprobs: 36.74 tok/s * with draft, with logprobs: 36.12 tok/s (1.7% drop) If cherry-pick the full sort from mainline PR: * no draft, no logprobs: 12.81 tok/s * no draft, with logprobs: 12.02 tok/s (6.2% drop) * with draft, no logprobs: 36.59 tok/s * with draft, with logprobs: 29.08 tok/s (20.5% drop) [1] https://github.com/ggml-org/llama.cpp/blob/b6548/tools/server/server.cpp#L4019 Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
@@ -406,7 +406,6 @@ static json probs_vector_to_json(const llama_context * ctx, const std::vector<co
|
||||
return out;
|
||||
}
|
||||
|
||||
|
||||
//
|
||||
// OAI utils
|
||||
//
|
||||
@@ -616,13 +615,12 @@ static json oaicompat_chat_params_parse(
|
||||
|
||||
// Handle "logprobs" field
|
||||
// TODO: The response format of this option is not yet OAI-compatible, but seems like no one really using it; We may need to fix it in the future
|
||||
if (body.contains("logprobs")) {
|
||||
if (json_value(body, "logprobs", false)) {
|
||||
if (has_tools && stream) {
|
||||
throw std::runtime_error("logprobs is not supported with tools + stream");
|
||||
}
|
||||
llama_params["n_probs"] = json_value(body, "top_logprobs", 20);
|
||||
}
|
||||
else if (body.contains("top_logprobs")) {
|
||||
} else if (body.contains("top_logprobs") && !body.at("top_logprobs").is_null()) {
|
||||
throw std::runtime_error("top_logprobs requires logprobs to be set to true");
|
||||
}
|
||||
|
||||
@@ -715,3 +713,43 @@ static json format_error_response(const std::string & message, const enum error_
|
||||
{"type", type_str},
|
||||
};
|
||||
}
|
||||
|
||||
struct token_probabilities {
|
||||
float sampled_token_p;
|
||||
std::vector<llama_token_data> cur;
|
||||
};
|
||||
|
||||
static token_probabilities get_token_probabilities(llama_context * ctx, int idx, llama_token sampled_token_id, int n_sorted) {
|
||||
const auto * logits = llama_get_logits_ith(ctx, idx);
|
||||
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
|
||||
n_sorted = std::min(n_sorted, n_vocab);
|
||||
|
||||
std::vector<std::pair<float, llama_token>> sorted(n_vocab);
|
||||
for (llama_token token_id = 0; token_id < n_vocab; token_id++) sorted[token_id] = {logits[token_id], token_id};
|
||||
|
||||
std::partial_sort(sorted.begin(), sorted.begin() + n_sorted, sorted.end(), std::greater<std::pair<float,llama_token>>{});
|
||||
|
||||
float max_l = sorted.front().first;
|
||||
float cum_sum = 0.0f;
|
||||
float sampled_token_p = 0.0f;
|
||||
bool sampled_token_found = false;
|
||||
std::vector<llama_token_data> cur(n_sorted);
|
||||
for (int i = 0; i < n_vocab; ++i) {
|
||||
float p = expf(sorted[i].first - max_l);
|
||||
cum_sum += p;
|
||||
if (i < n_sorted) {
|
||||
cur[i] = {sorted[i].second, sorted[i].first, p};
|
||||
}
|
||||
if (!sampled_token_found && sorted[i].second == sampled_token_id) {
|
||||
sampled_token_p = p;
|
||||
sampled_token_found = true;
|
||||
}
|
||||
}
|
||||
for (int i = n_sorted; i < n_vocab; ++i) cum_sum += expf(sorted[i].first - max_l);
|
||||
|
||||
float inv_cum_sum = 1/cum_sum;
|
||||
for (int i = 0; i < n_sorted; ++i) cur[i].p *= inv_cum_sum;
|
||||
sampled_token_p *= inv_cum_sum;
|
||||
|
||||
return {sampled_token_p, cur};
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user