diff --git a/examples/server/server-context.cpp b/examples/server/server-context.cpp index 52b63f7f..c6ce6834 100644 --- a/examples/server/server-context.cpp +++ b/examples/server/server-context.cpp @@ -1977,6 +1977,7 @@ void server_context::send_final_response(server_slot& slot) { res->oai_resp_reasoning_id = slot.oai_resp_reasoning_id; res->oai_resp_message_id = slot.oai_resp_message_id; res->n_decoded = slot.n_decoded; + res->n_prompt_tokens_cache = slot.n_prompt_tokens_cache; res->anthropic_thinking_block_started = slot.anthropic_thinking_block_started; res->anthropic_text_block_started = slot.anthropic_text_block_started; res->n_prompt_tokens = slot.n_prompt_tokens; @@ -3321,16 +3322,16 @@ void server_context::batch_pending_prompt(const int32_t n_ubatch, const int32_t LLAMA_LOG_INFO("======== Cache: cache_size = %d, n_past0 = %d, n_past1 = %d, n_past_prompt1 = %d, n_past2 = %d, n_past_prompt2 = %d\n", (int32_t)slot.cache_tokens.size(), (int32_t)n_past0, (int32_t)prefix.first, (int32_t)prefix.second, (int32_t)prefix_nonexact.first, (int32_t)prefix_nonexact.second); int32_t size_threshold = 20; if (prefix.first + size_threshold < prefix_nonexact.first) { - LLAMA_LOG_WARN("Common part contains missing or extra space and new line\n"); + // LLAMA_LOG_WARN("Common part contains missing or extra space and new line\n"); prefix = prefix_nonexact; } slot.n_past = prefix.first; slot.n_past_prompt = prefix.second; slot.n_past_offset = slot.n_past_prompt - slot.n_past; - if (slot.n_past != slot.n_past_prompt) { - LLAMA_LOG_INFO("Mistokenization found and handled successfully.\n"); - } + //if (slot.n_past != slot.n_past_prompt) { + // LLAMA_LOG_INFO("Mistokenization found and handled successfully.\n"); + //} if ((slot.n_past + size_threshold < slot.cache_tokens.size())) { LLAMA_LOG_WARN("Common part does not match fully\n"); @@ -3360,7 +3361,7 @@ void server_context::batch_pending_prompt(const int32_t n_ubatch, const int32_t slot.n_past_se--; } } - + 
slot.n_prompt_tokens_cache = slot.n_past_prompt; slot.n_prompt_tokens_processed = 0; } diff --git a/examples/server/server-context.h b/examples/server/server-context.h index 50e05131..937b4dec 100644 --- a/examples/server/server-context.h +++ b/examples/server/server-context.h @@ -56,6 +56,7 @@ struct server_slot { int32_t n_predict = -1; // TODO: disambiguate from params.n_predict int32_t n_prompt_tokens = 0; + int32_t n_prompt_tokens_cache = 0; int32_t n_prompt_tokens_processed = 0; json prompt; // can be either a string, array of strings or array of token ids diff --git a/examples/server/server-task.cpp b/examples/server/server-task.cpp index 561f3120..f1e8e3cb 100644 --- a/examples/server/server-task.cpp +++ b/examples/server/server-task.cpp @@ -120,6 +120,16 @@ json server_task_result_cmpl_partial::to_json_oaicompat_partial() { return res; } +json server_task_result_cmpl_final::usage_json_oaicompat() { + return json{ + {"completion_tokens", n_decoded}, + {"prompt_tokens", n_prompt_tokens}, + {"total_tokens", n_decoded + n_prompt_tokens}, + {"prompt_tokens_details", json { {"cached_tokens", n_prompt_tokens_cache} }}, + }; +} + + json server_task_result_cmpl_final::to_json_oaicompat_final() { std::time_t t = std::time(0); json logprobs = json(nullptr); // OAI default to null @@ -144,11 +154,7 @@ json server_task_result_cmpl_final::to_json_oaicompat_final() { {"created", t}, {"model", oaicompat_model}, {"object", "text_completion"}, - {"usage", json { - {"completion_tokens", n_decoded}, - {"prompt_tokens", n_prompt_tokens}, - {"total_tokens", n_decoded + n_prompt_tokens} - }}, + {"usage", usage_json_oaicompat()}, {"id", oaicompat_cmpl_id} }; @@ -379,11 +385,7 @@ json server_task_result_cmpl_final::to_json_oaicompat_chat_final() { {"created", t}, {"model", oaicompat_model}, {"object", "chat.completion"}, - {"usage", json { - {"completion_tokens", n_decoded}, - {"prompt_tokens", n_prompt_tokens}, - {"total_tokens", n_decoded + n_prompt_tokens} - }}, + {"usage", 
usage_json_oaicompat()}, {"id", oaicompat_cmpl_id} }; @@ -445,11 +447,7 @@ json server_task_result_cmpl_final::to_json_oaicompat_chat_stream() { {"id", oaicompat_cmpl_id}, {"model", oaicompat_model}, {"object", "chat.completion.chunk"}, - {"usage", json { - {"completion_tokens", n_decoded}, - {"prompt_tokens", n_prompt_tokens}, - {"total_tokens", n_decoded + n_prompt_tokens}, - }}, + {"usage", usage_json_oaicompat()}, }); } if (timings.prompt_n >= 0) { @@ -523,10 +521,11 @@ json server_task_result_cmpl_final::to_json_oaicompat_resp_final() { {"object", "response"}, {"output", output}, {"status", "completed"}, - {"usage", json{ + {"usage", json { {"input_tokens", n_prompt_tokens}, {"output_tokens", n_decoded}, {"total_tokens", n_decoded + n_prompt_tokens}, + {"input_tokens_details", json { {"cached_tokens", n_prompt_tokens_cache} }}, }}, }; @@ -633,11 +632,12 @@ json server_task_result_cmpl_final::to_json_oaicompat_resp_stream() { {"status", "completed"}, {"model", oaicompat_model}, {"output", output}, - {"usage", json{ + {"usage", json { {"input_tokens", n_prompt_tokens}, {"output_tokens", n_decoded}, {"total_tokens", n_decoded + n_prompt_tokens}, - }}, + {"input_tokens_details", json { {"cached_tokens", n_prompt_tokens_cache} }}, + }} }}, }}, }); @@ -703,7 +703,8 @@ json server_task_result_cmpl_final::to_json_anthropic_final() { {"stop_reason", stop_reason}, {"stop_sequence", stopping_word.empty() ? 
nullptr : json(stopping_word)}, {"usage", { - {"input_tokens", n_prompt_tokens}, + {"cache_read_input_tokens", n_prompt_tokens_cache}, + {"input_tokens", n_prompt_tokens - n_prompt_tokens_cache}, {"output_tokens", n_decoded} }} }; @@ -923,7 +924,8 @@ json server_task_result_cmpl_partial::to_json_anthropic_partial() { {"stop_reason", nullptr}, {"stop_sequence", nullptr}, {"usage", { - {"input_tokens", n_prompt_tokens}, + {"cache_read_input_tokens", n_prompt_tokens_cache}, + {"input_tokens", n_prompt_tokens - n_prompt_tokens_cache}, {"output_tokens", 0} }} }} diff --git a/examples/server/server-task.h b/examples/server/server-task.h index 9529261f..2bba8cbb 100644 --- a/examples/server/server-task.h +++ b/examples/server/server-task.h @@ -166,6 +166,7 @@ struct server_task_result { bool truncated; int32_t n_decoded; int32_t n_prompt_tokens; + int32_t n_prompt_tokens_cache = 0; // default 0: read by partial-result paths (e.g. to_json_anthropic_partial) that may not assign it int32_t n_tokens_cached; bool has_new_line; std::string stopping_word; @@ -258,6 +259,8 @@ struct server_task_result_cmpl_final : server_task_result { json to_json_non_oaicompat_final(); + json usage_json_oaicompat(); + json to_json_oaicompat_final(); json to_json_oaicompat_chat_final(); diff --git a/examples/server/server.cpp b/examples/server/server.cpp index d1740e60..feaf1b4e 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1262,7 +1262,8 @@ int main(int argc, char ** argv) { {"object", "model"}, {"created", std::time(0)}, {"owned_by", "llamacpp"}, - {"meta", model_meta} + {"meta", model_meta}, + {"max_model_len", params.n_ctx}, // vLLM-compat field }, }} };