mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-30 11:21:56 +00:00
Fix logprobs (#787)
This commit is mostly a cherry-pick of ggml-org/llama.cpp#10783, plus optimization to do partial sort when sorting the logits. That mainline PR and friends were partially cherry-picked by #723, but wasn't really in a working state yet. A couple of additional changes: * Include timing information in response, which was (unintentionally?) done in mainline since ggml-org/llama.cpp#10643. * Also return the actual logprobs for accepted draft tokens. This is still a TODO in mainline [1]. Note that there is a TG performance penalty to return the logprobs, as we need to sort the logits. By doing partial sort, the penalty is quite small. Here are some numbers I got using the same prompt: This PR with partial sort: * no draft, no logprobs: 12.87 tok/s * no draft, with logprobs: 12.61 tok/s (2.0% drop) * with draft, no logprobs: 36.74 tok/s * with draft, with logprobs: 36.12 tok/s (1.7% drop) If cherry-pick the full sort from mainline PR: * no draft, no logprobs: 12.81 tok/s * no draft, with logprobs: 12.02 tok/s (6.2% drop) * with draft, no logprobs: 36.59 tok/s * with draft, with logprobs: 29.08 tok/s (20.5% drop) [1] https://github.com/ggml-org/llama.cpp/blob/b6548/tools/server/server.cpp#L4019 Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
@@ -558,6 +558,7 @@ struct slot_params {
|
|||||||
std::vector<std::string> antiprompt;
|
std::vector<std::string> antiprompt;
|
||||||
|
|
||||||
bool timings_per_token = false;
|
bool timings_per_token = false;
|
||||||
|
bool post_sampling_probs = false;
|
||||||
json input_prefix;
|
json input_prefix;
|
||||||
json input_suffix;
|
json input_suffix;
|
||||||
|
|
||||||
@@ -1549,6 +1550,8 @@ struct server_context {
|
|||||||
slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
|
slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
|
||||||
slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);
|
slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);
|
||||||
|
|
||||||
|
slot.params.post_sampling_probs = json_value(data, "post_sampling_probs", default_params.post_sampling_probs);
|
||||||
|
|
||||||
// speculative decoding parameters
|
// speculative decoding parameters
|
||||||
slot.params.speculative.n_max = json_value(data, "speculative.n_max", params.n_draft);
|
slot.params.speculative.n_max = json_value(data, "speculative.n_max", params.n_draft);
|
||||||
slot.params.speculative.n_min = json_value(data, "speculative.n_min", params.n_draft_min);
|
slot.params.speculative.n_min = json_value(data, "speculative.n_min", params.n_draft_min);
|
||||||
@@ -1951,26 +1954,7 @@ struct server_context {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// check if there is incomplete UTF-8 character at the end
|
// check if there is incomplete UTF-8 character at the end
|
||||||
bool incomplete = false;
|
bool incomplete = validate_utf8(slot.generated_text) < slot.generated_text.size();
|
||||||
for (unsigned i = 1; i < 5 && i <= slot.generated_text.size(); ++i) {
|
|
||||||
unsigned char c = slot.generated_text[slot.generated_text.size() - i];
|
|
||||||
if ((c & 0xC0) == 0x80) {
|
|
||||||
// continuation byte: 10xxxxxx
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if ((c & 0xE0) == 0xC0) {
|
|
||||||
// 2-byte character: 110xxxxx ...
|
|
||||||
incomplete = i < 2;
|
|
||||||
} else if ((c & 0xF0) == 0xE0) {
|
|
||||||
// 3-byte character: 1110xxxx ...
|
|
||||||
incomplete = i < 3;
|
|
||||||
} else if ((c & 0xF8) == 0xF0) {
|
|
||||||
// 4-byte character: 11110xxx ...
|
|
||||||
incomplete = i < 4;
|
|
||||||
}
|
|
||||||
// else 1-byte character or invalid byte
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!incomplete) {
|
if (!incomplete) {
|
||||||
size_t pos = std::min(slot.n_sent_text, slot.generated_text.size());
|
size_t pos = std::min(slot.n_sent_text, slot.generated_text.size());
|
||||||
@@ -2066,6 +2050,49 @@ struct server_context {
|
|||||||
return slot.has_next_token; // continue
|
return slot.has_next_token; // continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) {
|
||||||
|
size_t n_probs = slot.sparams.n_probs;
|
||||||
|
size_t n_vocab = llama_n_vocab(llama_get_model(ctx));
|
||||||
|
|
||||||
|
if (post_sampling) {
|
||||||
|
const auto * cur_p = llama_sampling_get_candidates(slot.ctx_sampling);
|
||||||
|
const size_t max_probs = cur_p->size;
|
||||||
|
|
||||||
|
// set probability for sampled token
|
||||||
|
for (size_t i = 0; i < max_probs; i++) {
|
||||||
|
if (cur_p->data[i].id == result.tok) {
|
||||||
|
result.prob = cur_p->data[i].p;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// set probability for top n_probs tokens
|
||||||
|
result.probs.reserve(max_probs);
|
||||||
|
for (size_t i = 0; i < std::min(max_probs, n_probs); i++) {
|
||||||
|
result.probs.push_back({
|
||||||
|
cur_p->data[i].id,
|
||||||
|
llama_detokenize(ctx, {cur_p->data[i].id}, special),
|
||||||
|
cur_p->data[i].p
|
||||||
|
});
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
auto&&[sampled_token_p, cur] = get_token_probabilities(ctx, idx, result.tok, n_probs);
|
||||||
|
|
||||||
|
// set probability for sampled token
|
||||||
|
result.prob = sampled_token_p;
|
||||||
|
|
||||||
|
// set probability for top n_probs tokens
|
||||||
|
result.probs.reserve(n_probs);
|
||||||
|
for (size_t i = 0; i < std::min(n_vocab, n_probs); i++) {
|
||||||
|
result.probs.push_back({
|
||||||
|
cur[i].id,
|
||||||
|
llama_detokenize(ctx, {cur[i].id}, special),
|
||||||
|
cur[i].p
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
json get_formated_generation(const server_slot & slot) const {
|
json get_formated_generation(const server_slot & slot) const {
|
||||||
const auto eos_bias = slot.sparams.logit_bias.find(llama_token_eos(model));
|
const auto eos_bias = slot.sparams.logit_bias.find(llama_token_eos(model));
|
||||||
const bool ignore_eos = eos_bias != slot.sparams.logit_bias.end() && eos_bias->second < 0.0f && std::isinf(eos_bias->second);
|
const bool ignore_eos = eos_bias != slot.sparams.logit_bias.end() && eos_bias->second < 0.0f && std::isinf(eos_bias->second);
|
||||||
@@ -2163,6 +2190,7 @@ struct server_context {
|
|||||||
res.stop = false;
|
res.stop = false;
|
||||||
res.stream = slot.params.stream;
|
res.stream = slot.params.stream;
|
||||||
res.content = tkn.text_to_send;
|
res.content = tkn.text_to_send;
|
||||||
|
res.post_sampling_probs = slot.params.post_sampling_probs;
|
||||||
res.oaicompat = slot.params.oaicompat;
|
res.oaicompat = slot.params.oaicompat;
|
||||||
res.oaicompat_model = slot.params.oaicompat_model;
|
res.oaicompat_model = slot.params.oaicompat_model;
|
||||||
res.oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
|
res.oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
|
||||||
@@ -2175,26 +2203,18 @@ struct server_context {
|
|||||||
{"multimodal", false}
|
{"multimodal", false}
|
||||||
};
|
};
|
||||||
slot.update_chat_msg(res.oaicompat_msg_diffs);
|
slot.update_chat_msg(res.oaicompat_msg_diffs);
|
||||||
|
|
||||||
|
// populate res.probs_output
|
||||||
if (slot.sparams.n_probs > 0) {
|
if (slot.sparams.n_probs > 0) {
|
||||||
const std::vector<llama_token> to_send_toks = llama_tokenize(ctx, tkn.text_to_send, false);
|
res.probs_output = {tkn}; // copy the token probs
|
||||||
const size_t probs_pos = std::min(slot.n_sent_token_probs, slot.generated_token_probs.size());
|
res.data["completion_probabilities"] = probs_vector_to_json(ctx, res.probs_output);
|
||||||
const size_t probs_stop_pos = std::min(slot.n_sent_token_probs + to_send_toks.size(), slot.generated_token_probs.size());
|
|
||||||
|
|
||||||
std::vector<completion_token_output> probs_output;
|
|
||||||
if (probs_pos < probs_stop_pos) {
|
|
||||||
probs_output = std::vector<completion_token_output>(
|
|
||||||
slot.generated_token_probs.begin() + probs_pos,
|
|
||||||
slot.generated_token_probs.begin() + probs_stop_pos);
|
|
||||||
}
|
|
||||||
slot.n_sent_token_probs = probs_stop_pos;
|
|
||||||
|
|
||||||
res.data["completion_probabilities"] = probs_vector_to_json(ctx, probs_output);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (slot.oaicompat) {
|
if (slot.oaicompat) {
|
||||||
res.data["oaicompat_token_ctr"] = slot.n_decoded;
|
res.data["oaicompat_token_ctr"] = slot.n_decoded;
|
||||||
res.data["model"] = slot.oaicompat_model;
|
res.data["model"] = slot.oaicompat_model;
|
||||||
}
|
}
|
||||||
|
|
||||||
// populate timings if this is final response or timings_per_token is enabled
|
// populate timings if this is final response or timings_per_token is enabled
|
||||||
if (slot.params.timings_per_token) {
|
if (slot.params.timings_per_token) {
|
||||||
res.timings = slot.get_timings();
|
res.timings = slot.get_timings();
|
||||||
@@ -2212,6 +2232,8 @@ struct server_context {
|
|||||||
res.stream = slot.params.stream;
|
res.stream = slot.params.stream;
|
||||||
res.include_usage = slot.params.include_usage;
|
res.include_usage = slot.params.include_usage;
|
||||||
res.content = slot.generated_text;
|
res.content = slot.generated_text;
|
||||||
|
res.timings = slot.get_timings();
|
||||||
|
res.post_sampling_probs = slot.params.post_sampling_probs;
|
||||||
res.oaicompat = slot.params.oaicompat;
|
res.oaicompat = slot.params.oaicompat;
|
||||||
res.oaicompat_model = slot.params.oaicompat_model;
|
res.oaicompat_model = slot.params.oaicompat_model;
|
||||||
res.oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
|
res.oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
|
||||||
@@ -2239,26 +2261,23 @@ struct server_context {
|
|||||||
//{"oaicompat_chat_format", slot.params.oaicompat_chat_format},
|
//{"oaicompat_chat_format", slot.params.oaicompat_chat_format},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// populate res.probs_output
|
||||||
if (slot.sparams.n_probs > 0) {
|
if (slot.sparams.n_probs > 0) {
|
||||||
std::vector<completion_token_output> probs;
|
|
||||||
if (!slot.params.stream && slot.stopped_word) {
|
if (!slot.params.stream && slot.stopped_word) {
|
||||||
const std::vector<llama_token> stop_word_toks = llama_tokenize(ctx, slot.stopping_word, false);
|
const std::vector<llama_token> stop_word_toks = llama_tokenize(ctx, slot.stopping_word, false);
|
||||||
|
|
||||||
size_t safe_offset = std::min(slot.generated_token_probs.size(), stop_word_toks.size());
|
size_t safe_offset = std::min(slot.generated_token_probs.size(), stop_word_toks.size());
|
||||||
probs = std::vector<completion_token_output>(
|
res.probs_output = std::vector<completion_token_output>(
|
||||||
slot.generated_token_probs.begin(),
|
slot.generated_token_probs.begin(),
|
||||||
slot.generated_token_probs.end() - safe_offset);
|
slot.generated_token_probs.end() - safe_offset);
|
||||||
} else {
|
} else {
|
||||||
probs = std::vector<completion_token_output>(
|
res.probs_output = std::vector<completion_token_output>(
|
||||||
slot.generated_token_probs.begin(),
|
slot.generated_token_probs.begin(),
|
||||||
slot.generated_token_probs.end());
|
slot.generated_token_probs.end());
|
||||||
}
|
}
|
||||||
//res.generation_params = slot.params;
|
res.data["completion_probabilities"] = probs_vector_to_json(ctx, res.probs_output);
|
||||||
res.data["completion_probabilities"] = probs_vector_to_json(ctx, probs);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
res.timings = slot.get_timings();
|
|
||||||
|
|
||||||
if (slot.oaicompat) {
|
if (slot.oaicompat) {
|
||||||
res.data["oaicompat_token_ctr"] = slot.n_decoded;
|
res.data["oaicompat_token_ctr"] = slot.n_decoded;
|
||||||
res.data["model"] = slot.oaicompat_model;
|
res.data["model"] = slot.oaicompat_model;
|
||||||
@@ -3199,7 +3218,8 @@ struct server_context {
|
|||||||
}
|
}
|
||||||
|
|
||||||
completion_token_output result;
|
completion_token_output result;
|
||||||
const llama_token id = llama_sampling_sample(slot.ctx_sampling, ctx, NULL, slot.i_batch - i);
|
const int tok_idx = slot.i_batch - i;
|
||||||
|
const llama_token id = llama_sampling_sample(slot.ctx_sampling, ctx, NULL, tok_idx);
|
||||||
|
|
||||||
llama_sampling_accept(slot.ctx_sampling, ctx, id, true);
|
llama_sampling_accept(slot.ctx_sampling, ctx, id, true);
|
||||||
|
|
||||||
@@ -3215,35 +3235,12 @@ struct server_context {
|
|||||||
|
|
||||||
slot.t_token_generation = (t_current - slot.t_start_generation) / 1e3;
|
slot.t_token_generation = (t_current - slot.t_start_generation) / 1e3;
|
||||||
|
|
||||||
llama_token_data_array cur_p = { slot.ctx_sampling->cur.data(), slot.ctx_sampling->cur.size(), false };
|
|
||||||
result.tok = id;
|
result.tok = id;
|
||||||
|
result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs
|
||||||
result.text_to_send = llama_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
|
result.text_to_send = llama_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
|
||||||
|
|
||||||
const size_t n_probs = std::min(cur_p.size, (size_t) slot.sparams.n_probs);
|
if (slot.sparams.n_probs > 0) {
|
||||||
if (n_probs > 0) {
|
populate_token_probs(slot, result, slot.params.post_sampling_probs, params.special, tok_idx);
|
||||||
const size_t n_valid = slot.ctx_sampling->n_valid;
|
|
||||||
|
|
||||||
// Make sure at least n_probs top tokens are at the front of the vector:
|
|
||||||
if (slot.sparams.temp == 0.0f && n_probs > n_valid) {
|
|
||||||
llama_sample_top_k(ctx, &cur_p, n_probs, 0);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (slot.sparams.temp == 0.0f) {
|
|
||||||
// With greedy sampling the probabilities have possibly not been calculated.
|
|
||||||
for (size_t i = 0; i < n_probs; ++i) {
|
|
||||||
result.probs.push_back({
|
|
||||||
cur_p.data[i].id,llama_detokenize(ctx, {cur_p.data[i].id}, params.special),
|
|
||||||
i == 0 ? 1.0f : 0.0f
|
|
||||||
});
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for (size_t i = 0; i < n_probs; ++i) {
|
|
||||||
result.probs.push_back({
|
|
||||||
cur_p.data[i].id, llama_detokenize(ctx, {cur_p.data[i].id}, params.special),
|
|
||||||
i >= n_valid ? 0.0f : cur_p.data[i].p // Tokens filtered out due to e.g. top_k have 0 probability.
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!process_token(result, slot)) {
|
if (!process_token(result, slot)) {
|
||||||
@@ -3348,7 +3345,11 @@ struct server_context {
|
|||||||
|
|
||||||
result.tok = ids[i];
|
result.tok = ids[i];
|
||||||
result.text_to_send = llama_token_to_piece(ctx, result.tok, params.special);
|
result.text_to_send = llama_token_to_piece(ctx, result.tok, params.special);
|
||||||
// result.prob = 1.0f; // set later
|
result.prob = 1.0f; // set later
|
||||||
|
|
||||||
|
if (slot.sparams.n_probs > 0) {
|
||||||
|
populate_token_probs(slot, result, slot.params.post_sampling_probs, params.special, i);
|
||||||
|
}
|
||||||
|
|
||||||
if (!process_token(result, slot)) {
|
if (!process_token(result, slot)) {
|
||||||
// release slot because of stop condition
|
// release slot because of stop condition
|
||||||
|
|||||||
@@ -406,7 +406,6 @@ static json probs_vector_to_json(const llama_context * ctx, const std::vector<co
|
|||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
//
|
//
|
||||||
// OAI utils
|
// OAI utils
|
||||||
//
|
//
|
||||||
@@ -616,13 +615,12 @@ static json oaicompat_chat_params_parse(
|
|||||||
|
|
||||||
// Handle "logprobs" field
|
// Handle "logprobs" field
|
||||||
// TODO: The response format of this option is not yet OAI-compatible, but seems like no one really using it; We may need to fix it in the future
|
// TODO: The response format of this option is not yet OAI-compatible, but seems like no one really using it; We may need to fix it in the future
|
||||||
if (body.contains("logprobs")) {
|
if (json_value(body, "logprobs", false)) {
|
||||||
if (has_tools && stream) {
|
if (has_tools && stream) {
|
||||||
throw std::runtime_error("logprobs is not supported with tools + stream");
|
throw std::runtime_error("logprobs is not supported with tools + stream");
|
||||||
}
|
}
|
||||||
llama_params["n_probs"] = json_value(body, "top_logprobs", 20);
|
llama_params["n_probs"] = json_value(body, "top_logprobs", 20);
|
||||||
}
|
} else if (body.contains("top_logprobs") && !body.at("top_logprobs").is_null()) {
|
||||||
else if (body.contains("top_logprobs")) {
|
|
||||||
throw std::runtime_error("top_logprobs requires logprobs to be set to true");
|
throw std::runtime_error("top_logprobs requires logprobs to be set to true");
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -715,3 +713,43 @@ static json format_error_response(const std::string & message, const enum error_
|
|||||||
{"type", type_str},
|
{"type", type_str},
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct token_probabilities {
|
||||||
|
float sampled_token_p;
|
||||||
|
std::vector<llama_token_data> cur;
|
||||||
|
};
|
||||||
|
|
||||||
|
static token_probabilities get_token_probabilities(llama_context * ctx, int idx, llama_token sampled_token_id, int n_sorted) {
|
||||||
|
const auto * logits = llama_get_logits_ith(ctx, idx);
|
||||||
|
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
|
||||||
|
n_sorted = std::min(n_sorted, n_vocab);
|
||||||
|
|
||||||
|
std::vector<std::pair<float, llama_token>> sorted(n_vocab);
|
||||||
|
for (llama_token token_id = 0; token_id < n_vocab; token_id++) sorted[token_id] = {logits[token_id], token_id};
|
||||||
|
|
||||||
|
std::partial_sort(sorted.begin(), sorted.begin() + n_sorted, sorted.end(), std::greater<std::pair<float,llama_token>>{});
|
||||||
|
|
||||||
|
float max_l = sorted.front().first;
|
||||||
|
float cum_sum = 0.0f;
|
||||||
|
float sampled_token_p = 0.0f;
|
||||||
|
bool sampled_token_found = false;
|
||||||
|
std::vector<llama_token_data> cur(n_sorted);
|
||||||
|
for (int i = 0; i < n_vocab; ++i) {
|
||||||
|
float p = expf(sorted[i].first - max_l);
|
||||||
|
cum_sum += p;
|
||||||
|
if (i < n_sorted) {
|
||||||
|
cur[i] = {sorted[i].second, sorted[i].first, p};
|
||||||
|
}
|
||||||
|
if (!sampled_token_found && sorted[i].second == sampled_token_id) {
|
||||||
|
sampled_token_p = p;
|
||||||
|
sampled_token_found = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (int i = n_sorted; i < n_vocab; ++i) cum_sum += expf(sorted[i].first - max_l);
|
||||||
|
|
||||||
|
float inv_cum_sum = 1/cum_sum;
|
||||||
|
for (int i = 0; i < n_sorted; ++i) cur[i].p *= inv_cum_sum;
|
||||||
|
sampled_token_p *= inv_cum_sum;
|
||||||
|
|
||||||
|
return {sampled_token_p, cur};
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user