diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 95e40911..546b4c63 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -558,6 +558,7 @@ struct slot_params {
     std::vector<std::string> antiprompt;
 
     bool timings_per_token = false;
+    bool post_sampling_probs = false;
 
     json input_prefix;
     json input_suffix;
@@ -1549,6 +1550,8 @@ struct server_context {
         slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
         slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);
 
+        slot.params.post_sampling_probs = json_value(data, "post_sampling_probs", default_params.post_sampling_probs);
+
         // speculative decoding parameters
         slot.params.speculative.n_max = json_value(data, "speculative.n_max", params.n_draft);
         slot.params.speculative.n_min = json_value(data, "speculative.n_min", params.n_draft_min);
@@ -1951,26 +1954,7 @@ struct server_context {
         }
 
         // check if there is incomplete UTF-8 character at the end
-        bool incomplete = false;
-        for (unsigned i = 1; i < 5 && i <= slot.generated_text.size(); ++i) {
-            unsigned char c = slot.generated_text[slot.generated_text.size() - i];
-            if ((c & 0xC0) == 0x80) {
-                // continuation byte: 10xxxxxx
-                continue;
-            }
-            if ((c & 0xE0) == 0xC0) {
-                // 2-byte character: 110xxxxx ...
-                incomplete = i < 2;
-            } else if ((c & 0xF0) == 0xE0) {
-                // 3-byte character: 1110xxxx ...
-                incomplete = i < 3;
-            } else if ((c & 0xF8) == 0xF0) {
-                // 4-byte character: 11110xxx ...
-                incomplete = i < 4;
-            }
-            // else 1-byte character or invalid byte
-            break;
-        }
+        bool incomplete = validate_utf8(slot.generated_text) < slot.generated_text.size();
 
         if (!incomplete) {
             size_t pos = std::min(slot.n_sent_text, slot.generated_text.size());
@@ -2066,6 +2050,49 @@ struct server_context {
         return slot.has_next_token; // continue
     }
 
+    void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) {
+        size_t n_probs = slot.sparams.n_probs;
+        size_t n_vocab = llama_n_vocab(llama_get_model(ctx));
+
+        if (post_sampling) {
+            const auto * cur_p = llama_sampling_get_candidates(slot.ctx_sampling);
+            const size_t max_probs = cur_p->size;
+
+            // set probability for sampled token
+            for (size_t i = 0; i < max_probs; i++) {
+                if (cur_p->data[i].id == result.tok) {
+                    result.prob = cur_p->data[i].p;
+                    break;
+                }
+            }
+
+            // set probability for top n_probs tokens
+            result.probs.reserve(max_probs);
+            for (size_t i = 0; i < std::min(max_probs, n_probs); i++) {
+                result.probs.push_back({
+                    cur_p->data[i].id,
+                    llama_detokenize(ctx, {cur_p->data[i].id}, special),
+                    cur_p->data[i].p
+                });
+            }
+        } else {
+            auto && [sampled_token_p, cur] = get_token_probabilities(ctx, idx, result.tok, n_probs);
+
+            // set probability for sampled token
+            result.prob = sampled_token_p;
+
+            // set probability for top n_probs tokens
+            result.probs.reserve(n_probs);
+            for (size_t i = 0; i < std::min(n_vocab, n_probs); i++) {
+                result.probs.push_back({
+                    cur[i].id,
+                    llama_detokenize(ctx, {cur[i].id}, special),
+                    cur[i].p
+                });
+            }
+        }
+    }
+
     json get_formated_generation(const server_slot & slot) const {
         const auto eos_bias = slot.sparams.logit_bias.find(llama_token_eos(model));
         const bool ignore_eos = eos_bias != slot.sparams.logit_bias.end() && eos_bias->second < 0.0f && std::isinf(eos_bias->second);
@@ -2163,6 +2190,7 @@ struct server_context {
         res.stop = false;
         res.stream = slot.params.stream;
         res.content = tkn.text_to_send;
+        res.post_sampling_probs = slot.params.post_sampling_probs;
         res.oaicompat = slot.params.oaicompat;
         res.oaicompat_model = slot.params.oaicompat_model;
         res.oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
@@ -2175,26 +2203,18 @@ struct server_context {
             {"multimodal", false}
         };
         slot.update_chat_msg(res.oaicompat_msg_diffs);
+
+        // populate res.probs_output
         if (slot.sparams.n_probs > 0) {
-            const std::vector<llama_token> to_send_toks = llama_tokenize(ctx, tkn.text_to_send, false);
-            const size_t probs_pos = std::min(slot.n_sent_token_probs, slot.generated_token_probs.size());
-            const size_t probs_stop_pos = std::min(slot.n_sent_token_probs + to_send_toks.size(), slot.generated_token_probs.size());
-
-            std::vector<completion_token_output> probs_output;
-            if (probs_pos < probs_stop_pos) {
-                probs_output = std::vector<completion_token_output>(
-                    slot.generated_token_probs.begin() + probs_pos,
-                    slot.generated_token_probs.begin() + probs_stop_pos);
-            }
-            slot.n_sent_token_probs = probs_stop_pos;
-
-            res.data["completion_probabilities"] = probs_vector_to_json(ctx, probs_output);
+            res.probs_output = {tkn}; // copy the token probs
+            res.data["completion_probabilities"] = probs_vector_to_json(ctx, res.probs_output);
         }
 
         if (slot.oaicompat) {
            res.data["oaicompat_token_ctr"] = slot.n_decoded;
            res.data["model"] = slot.oaicompat_model;
         }
 
+        // populate timings if this is final response or timings_per_token is enabled
         if (slot.params.timings_per_token) {
            res.timings = slot.get_timings();
@@ -2212,6 +2232,8 @@ struct server_context {
         res.stream = slot.params.stream;
         res.include_usage = slot.params.include_usage;
         res.content = slot.generated_text;
+        res.timings = slot.get_timings();
+        res.post_sampling_probs = slot.params.post_sampling_probs;
         res.oaicompat = slot.params.oaicompat;
         res.oaicompat_model = slot.params.oaicompat_model;
         res.oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
@@ -2239,26 +2261,23 @@ struct server_context {
             //{"oaicompat_chat_format", slot.params.oaicompat_chat_format},
         };
 
+        // populate res.probs_output
         if (slot.sparams.n_probs > 0) {
-            std::vector<completion_token_output> probs;
             if (!slot.params.stream && slot.stopped_word) {
                 const std::vector<llama_token> stop_word_toks = llama_tokenize(ctx, slot.stopping_word, false);
 
                 size_t safe_offset = std::min(slot.generated_token_probs.size(), stop_word_toks.size());
-                probs = std::vector<completion_token_output>(
+                res.probs_output = std::vector<completion_token_output>(
                     slot.generated_token_probs.begin(),
                     slot.generated_token_probs.end() - safe_offset);
             } else {
-                probs = std::vector<completion_token_output>(
+                res.probs_output = std::vector<completion_token_output>(
                     slot.generated_token_probs.begin(),
                     slot.generated_token_probs.end());
             }
-            //res.generation_params = slot.params;
-            res.data["completion_probabilities"] = probs_vector_to_json(ctx, probs);
+            res.data["completion_probabilities"] = probs_vector_to_json(ctx, res.probs_output);
         }
 
-        res.timings = slot.get_timings();
-
         if (slot.oaicompat) {
             res.data["oaicompat_token_ctr"] = slot.n_decoded;
             res.data["model"] = slot.oaicompat_model;
@@ -3199,7 +3218,8 @@ struct server_context {
             }
 
             completion_token_output result;
-            const llama_token id = llama_sampling_sample(slot.ctx_sampling, ctx, NULL, slot.i_batch - i);
+            const int tok_idx = slot.i_batch - i;
+            const llama_token id = llama_sampling_sample(slot.ctx_sampling, ctx, NULL, tok_idx);
 
             llama_sampling_accept(slot.ctx_sampling, ctx, id, true);
 
@@ -3215,35 +3235,12 @@ struct server_context {
 
             slot.t_token_generation = (t_current - slot.t_start_generation) / 1e3;
 
-            llama_token_data_array cur_p = { slot.ctx_sampling->cur.data(), slot.ctx_sampling->cur.size(), false };
             result.tok = id;
+            result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs
             result.text_to_send = llama_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
 
-            const size_t n_probs = std::min(cur_p.size, (size_t) slot.sparams.n_probs);
-            if (n_probs > 0) {
-                const size_t n_valid = slot.ctx_sampling->n_valid;
-
-                // Make sure at least n_probs top tokens are at the front of the vector:
-                if (slot.sparams.temp == 0.0f && n_probs > n_valid) {
-                    llama_sample_top_k(ctx, &cur_p, n_probs, 0);
-                }
-
-                if (slot.sparams.temp == 0.0f) {
-                    // With greedy sampling the probabilities have possibly not been calculated.
-                    for (size_t i = 0; i < n_probs; ++i) {
-                        result.probs.push_back({
-                            cur_p.data[i].id, llama_detokenize(ctx, {cur_p.data[i].id}, params.special),
-                            i == 0 ? 1.0f : 0.0f
-                        });
-                    }
-                } else {
-                    for (size_t i = 0; i < n_probs; ++i) {
-                        result.probs.push_back({
-                            cur_p.data[i].id, llama_detokenize(ctx, {cur_p.data[i].id}, params.special),
-                            i >= n_valid ? 0.0f : cur_p.data[i].p // Tokens filtered out due to e.g. top_k have 0 probability.
-                        });
-                    }
-                }
+            if (slot.sparams.n_probs > 0) {
+                populate_token_probs(slot, result, slot.params.post_sampling_probs, params.special, tok_idx);
             }
 
             if (!process_token(result, slot)) {
@@ -3348,7 +3345,11 @@ struct server_context {
 
                 result.tok = ids[i];
                 result.text_to_send = llama_token_to_piece(ctx, result.tok, params.special);
-                // result.prob = 1.0f; // set later
+                result.prob = 1.0f; // set later
+
+                if (slot.sparams.n_probs > 0) {
+                    populate_token_probs(slot, result, slot.params.post_sampling_probs, params.special, i);
+                }
 
                 if (!process_token(result, slot)) {
                     // release slot because of stop condition
diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp
index 5efa9cdd..d7fd85f9 100644
--- a/examples/server/utils.hpp
+++ b/examples/server/utils.hpp
@@ -406,7 +406,6 @@ static json probs_vector_to_json(const llama_context * ctx, const std::vector<completion_token_output> & probs) {
+struct token_probabilities {
+    float sampled_token_p;
+    std::vector<llama_token_data> cur;
+};
+
+static token_probabilities get_token_probabilities(llama_context * ctx, int idx, llama_token sampled_token_id, int n_sorted) {
+    const auto * logits = llama_get_logits_ith(ctx, idx);
+    const int n_vocab = llama_n_vocab(llama_get_model(ctx));
+    n_sorted = std::min(n_sorted, n_vocab);
+
+    // pair up (logit, token id) and bring the top n_sorted logits to the front
+    std::vector<std::pair<float, llama_token>> sorted(n_vocab);
+    for (llama_token token_id = 0; token_id < n_vocab; token_id++) sorted[token_id] = {logits[token_id], token_id};
+
+    std::partial_sort(sorted.begin(), sorted.begin() + n_sorted, sorted.end(), std::greater<std::pair<float, llama_token>>{});
+
+    // softmax over the full vocabulary; keep the top n_sorted entries and the sampled token's probability
+    float max_l = sorted.front().first;
+    float cum_sum = 0.0f;
+    float sampled_token_p = 0.0f;
+    bool sampled_token_found = false;
+    std::vector<llama_token_data> cur(n_sorted);
+    for (int i = 0; i < n_vocab; ++i) {
+        float p = expf(sorted[i].first - max_l);
+        cum_sum += p;
+        if (i < n_sorted) {
+            cur[i] = {sorted[i].second, sorted[i].first, p};
+        }
+        if (!sampled_token_found && sorted[i].second == sampled_token_id) {
+            sampled_token_p = p;
+            sampled_token_found = true;
+        }
+    }
+
+    // normalize (cum_sum already covers the whole vocabulary)
+    float inv_cum_sum = 1.0f/cum_sum;
+    for (int i = 0; i < n_sorted; ++i) cur[i].p *= inv_cum_sum;
+    sampled_token_p *= inv_cum_sum;
+
+    return {sampled_token_p, cur};
+}
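
Example request (a minimal sketch, not part of the patch: host, port, prompt and n_predict are illustrative, and it assumes the server's usual /completion endpoint; n_probs and post_sampling_probs are the request fields read by this change):

    curl http://127.0.0.1:8080/completion -d '{
        "prompt": "Hello",
        "n_predict": 4,
        "n_probs": 5,
        "post_sampling_probs": true
    }'

With post_sampling_probs enabled, the entries in completion_probabilities come from the candidates left after the sampling chain (llama_sampling_get_candidates); with it disabled, they are softmax probabilities computed from the raw logits by get_token_probabilities.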