mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-20 22:49:31 +00:00
Refactor chat and server file (#1062)
* Add alternative log functions * chat: fix int overflow, prevent size calculation in float/double (#17357) * chat: fix int overflow, prevent size calculation in float/double * Update common/chat.cpp Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * common : move all common_chat_parse_* to chat-parser.cpp. (#17481) # Conflicts: # common/chat.cpp * server: split server.cpp code into server/common/task/queue/context * Fix compiler warning * Clean up code * common: use native MultiByteToWideChar * move server prompt to server task * Clean code * delete utils.hpp --------- Co-authored-by: firecoperana <firecoperana> Co-authored-by: Xuan-Son Nguyen <son@huggingface.co> Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> Co-authored-by: DAN™ <dranger003@gmail.com>
This commit is contained in:
@@ -12,8 +12,15 @@ endif()
|
||||
|
||||
set(TARGET_SRCS
|
||||
server.cpp
|
||||
utils.hpp
|
||||
httplib.h
|
||||
server-task.cpp
|
||||
server-task.h
|
||||
server-queue.cpp
|
||||
server-queue.h
|
||||
server-common.cpp
|
||||
server-common.h
|
||||
server-context.cpp
|
||||
server-context.h
|
||||
)
|
||||
set(PUBLIC_ASSETS
|
||||
index.html.gz
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
455
examples/server/server-common.h
Normal file
455
examples/server/server-common.h
Normal file
@@ -0,0 +1,455 @@
|
||||
#pragma once
|
||||
|
||||
#include "common.h"
|
||||
#include "log.h"
|
||||
#include "llama.h"
|
||||
#include <src/llama-impl.h>
|
||||
#include "chat.h"
|
||||
#include "mtmd.h"
|
||||
#include "mtmd-helper.h"
|
||||
#define JSON_ASSERT GGML_ASSERT
|
||||
#include <nlohmann/json.hpp>
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <cinttypes>
|
||||
|
||||
|
||||
|
||||
// Change JSON_ASSERT from assert() to GGML_ASSERT:
|
||||
#define JSON_ASSERT GGML_ASSERT
|
||||
#include "base64.hpp"
|
||||
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <sstream>
|
||||
#include <random>
|
||||
#include <set>
|
||||
|
||||
// increase max payload length to allow use of larger context size
|
||||
#define CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 1048576
|
||||
// increase backlog size to avoid connection resets for >> 1 slots
|
||||
#define CPPHTTPLIB_LISTEN_BACKLOG 512
|
||||
// increase max URI length to handle longer prompts in query string
|
||||
#define CPPHTTPLIB_REQUEST_URI_MAX_LENGTH 32768
|
||||
// disable Nagle's algorithm
|
||||
#define CPPHTTPLIB_TCP_NODELAY true
|
||||
#include "httplib.h"
|
||||
|
||||
#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613"
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
#define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
|
||||
#define SLT_CNT(slot, fmt, ...) LOG_CNT("" fmt, __VA_ARGS__)
|
||||
#define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
|
||||
#define SLT_ERR(slot, fmt, ...) LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
|
||||
#define SLT_DBG(slot, fmt, ...) LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
|
||||
|
||||
#define SRV_INF(fmt, ...) LOG_INF("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define SRV_CNT(fmt, ...) LOG_CNT("" fmt, __VA_ARGS__)
|
||||
#define SRV_WRN(fmt, ...) LOG_WRN("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define SRV_ERR(fmt, ...) LOG_ERR("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define SRV_DBG(fmt, ...) LOG_DBG("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
|
||||
// https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11
|
||||
enum error_type {
|
||||
ERROR_TYPE_INVALID_REQUEST,
|
||||
ERROR_TYPE_AUTHENTICATION,
|
||||
ERROR_TYPE_SERVER,
|
||||
ERROR_TYPE_NOT_FOUND,
|
||||
ERROR_TYPE_PERMISSION,
|
||||
ERROR_TYPE_UNAVAILABLE, // custom error
|
||||
ERROR_TYPE_NOT_SUPPORTED, // custom error
|
||||
};
|
||||
|
||||
extern bool server_verbose;
|
||||
extern bool server_log_json;
|
||||
|
||||
#ifndef SERVER_VERBOSE
|
||||
#define SERVER_VERBOSE 1
|
||||
#endif
|
||||
|
||||
#if SERVER_VERBOSE != 1
|
||||
#define LOG_VERBOSE(MSG, ...)
|
||||
#else
|
||||
#define LOG_VERBOSE(MSG, ...) \
|
||||
do \
|
||||
{ \
|
||||
if (server_verbose) \
|
||||
{ \
|
||||
server_log("VERB", __func__, __LINE__, MSG, __VA_ARGS__); \
|
||||
} \
|
||||
} while (0)
|
||||
#endif
|
||||
|
||||
#define LOG_ERROR( MSG, ...) server_log("ERR", __func__, __LINE__, MSG, __VA_ARGS__)
|
||||
#define LOG_WARNING(MSG, ...) server_log("WARN", __func__, __LINE__, MSG, __VA_ARGS__)
|
||||
#define LOG_INFO( MSG, ...) server_log("INFO", __func__, __LINE__, MSG, __VA_ARGS__)
|
||||
|
||||
using raw_buffer = std::vector<uint8_t>;
|
||||
|
||||
void server_log(const char* level, const char* function, int line, const char* message, const json& extra);
|
||||
|
||||
template <typename T>
|
||||
static T json_value(const json& body, const std::string& key, const T& default_value) {
|
||||
// Fallback null to default value
|
||||
if (body.contains(key) && !body.at(key).is_null()) {
|
||||
try {
|
||||
return body.at(key);
|
||||
}
|
||||
catch (NLOHMANN_JSON_NAMESPACE::detail::type_error const& err) {
|
||||
std::stringstream ss;
|
||||
ss << "Wrong type supplied for parameter '" << key << "'. Expected '" << json(default_value).type_name() << "', using default value: " << err.what();
|
||||
LOG_WARNING(ss.str().c_str(), body);
|
||||
return default_value;
|
||||
}
|
||||
}
|
||||
else {
|
||||
return default_value;
|
||||
}
|
||||
}
|
||||
|
||||
// thin wrapper around common_grammar_trigger with (de)serialization functions
|
||||
struct server_grammar_trigger {
|
||||
common_grammar_trigger value;
|
||||
|
||||
server_grammar_trigger() = default;
|
||||
server_grammar_trigger(const common_grammar_trigger& value) : value(value) {}
|
||||
server_grammar_trigger(const json& in);
|
||||
|
||||
json to_json() const;
|
||||
};
|
||||
|
||||
|
||||
//
|
||||
// chat template utils
|
||||
//
|
||||
|
||||
//
|
||||
// base64 utils (TODO: move to common in the future)
|
||||
//
|
||||
|
||||
static const std::string base64_chars =
|
||||
"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
||||
"abcdefghijklmnopqrstuvwxyz"
|
||||
"0123456789+/";
|
||||
|
||||
bool is_base64(uint8_t c);
|
||||
|
||||
std::vector<uint8_t> base64_decode(const std::string& encoded_string);
|
||||
|
||||
//
|
||||
// random string / id
|
||||
//
|
||||
|
||||
std::string random_string();
|
||||
|
||||
std::string gen_chatcmplid();
|
||||
|
||||
std::string gen_tool_call_id();
|
||||
|
||||
//
|
||||
// other common utils
|
||||
//
|
||||
float get_slot_similarity(size_t lcp, size_t prompt_length, size_t cache_length);
|
||||
|
||||
size_t common_part(const std::vector<llama_token>& a, const std::vector<llama_token>& b);
|
||||
|
||||
size_t common_part(const std::string& a, const std::string& b);
|
||||
|
||||
// return the last index of character that can form a valid string
|
||||
// if the last character is potentially cut in half, return the index before the cut
|
||||
// if validate_utf8(text) == text.size(), then the whole text is valid utf8
|
||||
size_t validate_utf8(const std::string& text);
|
||||
|
||||
// TODO: reuse llama_detokenize
|
||||
|
||||
std::string tokens_to_str(llama_context* ctx, const llama_tokens& tokens);
|
||||
|
||||
// format incomplete utf-8 multibyte character for output
|
||||
std::string tokens_to_output_formatted_string(const llama_context* ctx, const llama_token token);
|
||||
|
||||
struct common_prefix {
|
||||
size_t first = 0;
|
||||
size_t second = 0;
|
||||
};
|
||||
|
||||
common_prefix common_prefix_add(const common_prefix& a, const common_prefix& b);
|
||||
|
||||
common_prefix find_common_string_prefix(const std::string& a_str, const std::string& b_str, const std::set<char>& ignore_set);
|
||||
|
||||
size_t find_n_tokens_from_string(const llama_context* ctx, const llama_tokens& a, const size_t max_size, size_t start,
|
||||
std::vector<size_t>& map);
|
||||
|
||||
std::string remove_with_set(std::string str, const std::set<char>& chars_to_remove);
|
||||
|
||||
common_prefix find_largest_common_number(const std::vector<size_t>& a_list, const std::vector<size_t>& b_list);
|
||||
|
||||
size_t find_n_tokens_from_string_with_ignore(const llama_context* ctx, const llama_tokens& a, const size_t max_size, size_t start, const std::set<char>& ignore_set,
|
||||
std::vector<size_t>& map);
|
||||
|
||||
common_prefix find_common_text_token_prefix(const llama_context* ctx, const llama_tokens& a, const llama_tokens& b,
|
||||
size_t start, bool exact);
|
||||
|
||||
struct completion_token_output {
|
||||
llama_token tok;
|
||||
std::string text_to_send;
|
||||
float prob;
|
||||
|
||||
struct prob_info {
|
||||
llama_token tok;
|
||||
std::string txt;
|
||||
float prob;
|
||||
};
|
||||
std::vector<prob_info> probs;
|
||||
|
||||
json to_json(bool post_sampling_probs) const;
|
||||
|
||||
static float logarithm(float x);
|
||||
|
||||
static std::vector<unsigned char> str_to_bytes(const std::string& str);
|
||||
|
||||
static json probs_vector_to_json(const std::vector<completion_token_output>& probs, bool post_sampling_probs);
|
||||
};
|
||||
|
||||
// convert a vector of completion_token_output to json
|
||||
json probs_vector_to_json(const llama_context* ctx, const std::vector<completion_token_output>& probs);
|
||||
|
||||
bool server_sent_event(httplib::DataSink& sink, const json& data);
|
||||
|
||||
bool server_sent_anthropic_event(httplib::DataSink& sink, const json& data);
|
||||
|
||||
//
|
||||
// OAI utils
|
||||
//
|
||||
// used by /completions endpoint
|
||||
json oaicompat_chat_params_parse(const json& body);
|
||||
|
||||
struct oaicompat_parser_options {
|
||||
bool use_jinja;
|
||||
bool prefill_assistant;
|
||||
common_reasoning_format reasoning_format;
|
||||
std::map<std::string, std::string> chat_template_kwargs;
|
||||
common_chat_templates* tmpls;
|
||||
bool allow_image;
|
||||
bool allow_audio;
|
||||
bool enable_thinking = true;
|
||||
};
|
||||
|
||||
// used by /chat/completions endpoint
|
||||
json oaicompat_chat_params_parse(
|
||||
const struct llama_model* model,
|
||||
json& body, /* openai api json semantics */
|
||||
const oaicompat_parser_options& opt,
|
||||
std::vector<raw_buffer>& out_files);
|
||||
|
||||
json anthropic_params_from_json(
|
||||
const struct llama_model* model,
|
||||
const json& body_in, /* anthropic messages api json semantics */
|
||||
const oaicompat_parser_options& opt,
|
||||
std::vector<raw_buffer>& out_files);
|
||||
|
||||
|
||||
//
|
||||
// tokenizer and input processing utils
|
||||
//
|
||||
|
||||
bool json_is_array_of_numbers(const json& data);
|
||||
|
||||
// is array having BOTH numbers & strings?
|
||||
bool json_is_array_of_mixed_numbers_strings(const json& data);
|
||||
|
||||
// does array have any individual integers/tokens?
|
||||
bool json_is_array_and_contains_numbers(const json& data);
|
||||
|
||||
// get value by path(key1 / key2)
|
||||
json json_get_nested_values(const std::vector<std::string>& paths, const json& js);
|
||||
|
||||
/**
|
||||
* this handles 2 cases:
|
||||
* - only string, example: "string"
|
||||
* - mixed string and tokens, example: [12, 34, "string", 56, 78]
|
||||
*/
|
||||
std::vector<llama_token> tokenize_mixed(const llama_vocab* vocab, const json& json_prompt, bool add_special, bool parse_special);
|
||||
|
||||
json format_tokenizer_response(const std::vector<llama_token>& tokens);
|
||||
|
||||
json format_detokenized_response(const std::string& content);
|
||||
|
||||
json format_error_response(const std::string& message, const enum error_type type);
|
||||
|
||||
struct token_probabilities {
|
||||
float sampled_token_p;
|
||||
std::vector<llama_token_data> cur;
|
||||
};
|
||||
|
||||
token_probabilities get_token_probabilities(llama_context* ctx, int idx, llama_token sampled_token_id, int n_sorted);
|
||||
|
||||
/**
|
||||
* server_tokens is a helper to manage the input tokens and image for the server.
|
||||
* it is made this way to simplify the logic of KV cache management.
|
||||
*/
|
||||
struct server_tokens {
|
||||
bool has_mtmd = false;
|
||||
|
||||
private: // disallow accessing these members directly, risking out-of-sync
|
||||
|
||||
// map a **start** index in tokens to the image chunk
|
||||
// note: the order need to be in-sync with tokens
|
||||
std::map<size_t, mtmd::input_chunk_ptr> map_idx_to_media;
|
||||
|
||||
// list of tokens
|
||||
// if the token is LLAMA_TOKEN_NULL, it indicates that this position is occupied by media chunk
|
||||
// otherwise, it is a normal text token
|
||||
// note: a non-text chunk can occupy multiple tokens (aka memory cells) in the token list
|
||||
// note(2): for M-RoPE, an image can occupy different number of pos; do not assume 1-to-1 mapping tokens <-> pos
|
||||
llama_tokens tokens;
|
||||
|
||||
// for ex. with input of 5 text tokens and 2 images (each image occupies 3 tokens and 2 pos):
|
||||
// [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1] [img1]
|
||||
// idx 0 1 2 3 4 5 6 7 8 9 10
|
||||
// pos 0 1 2 3 4 5 5 5 7 7 7
|
||||
// map_idx_to_media will contain: {5, img0}, {8, img1}
|
||||
|
||||
public:
|
||||
server_tokens() = default;
|
||||
~server_tokens() = default;
|
||||
|
||||
// Prevent copying
|
||||
server_tokens(const server_tokens&) = delete;
|
||||
server_tokens& operator=(const server_tokens&) = delete;
|
||||
|
||||
// Allow moving (usually implicitly generated if members are movable)
|
||||
server_tokens(server_tokens&&) = default;
|
||||
server_tokens& operator=(server_tokens&&) = default;
|
||||
|
||||
// Allow accessing elements using [] operator
|
||||
llama_token operator[](size_t index) { return tokens[index]; }
|
||||
const llama_token& operator[](size_t index) const { return tokens[index]; }
|
||||
|
||||
server_tokens(mtmd::input_chunks& mtmd_chunks, bool has_mtmd);
|
||||
|
||||
server_tokens(const llama_tokens& tokens, bool has_mtmd);
|
||||
|
||||
llama_pos pos_next() const;
|
||||
|
||||
// for debugging
|
||||
std::string str() const;
|
||||
|
||||
const mtmd::input_chunk_ptr& find_chunk(size_t idx) const;
|
||||
|
||||
void push_back(llama_token tok);
|
||||
|
||||
// will create a copy of the chunk if it contains non-text data
|
||||
void push_back(const mtmd_input_chunk* chunk);
|
||||
|
||||
// appends server tokens, updates the media map. copies media chunks.
|
||||
void push_back(server_tokens& tokens);
|
||||
|
||||
// for compatibility with context shift and prompt truncation
|
||||
void insert(const std::vector<llama_token>& inp_tokens);
|
||||
|
||||
// for compatibility with context shift and prompt truncation
|
||||
void resize(size_t size);
|
||||
|
||||
llama_token* data();
|
||||
|
||||
llama_tokens::iterator begin();
|
||||
|
||||
llama_tokens::iterator end();
|
||||
|
||||
llama_tokens::const_iterator cbegin();
|
||||
|
||||
llama_tokens::const_iterator cend();
|
||||
|
||||
llama_tokens tokens_data();
|
||||
|
||||
// for compatibility with speculative decoding, ctx shift, slot save/load
|
||||
const std::vector<llama_token>& get_text_tokens() const;
|
||||
|
||||
// for compatibility with speculative decoding
|
||||
void set_token(llama_pos pos, llama_token id);
|
||||
|
||||
size_t size() const;
|
||||
|
||||
bool empty() const;
|
||||
|
||||
void clear();
|
||||
|
||||
void keep_first(size_t n);
|
||||
|
||||
std::string detokenize(const llama_context* ctx, bool special) const;
|
||||
|
||||
std::string detokenize(const llama_context* ctx, bool special, size_t start, size_t length) const;
|
||||
|
||||
size_t find_n_from_tokens(const llama_context* ctx, const server_tokens& b, bool special,
|
||||
size_t start, const size_t length);
|
||||
|
||||
size_t get_common_prefix_exact(const server_tokens& b) const;
|
||||
|
||||
|
||||
common_prefix get_common_prefix(const llama_context* ctx, const server_tokens& b, bool exact = false) const;
|
||||
// take first n tokens of tokens list a
|
||||
// find the common prefix between a and b
|
||||
common_prefix get_common_prefix_first_n(const llama_context* ctx, const server_tokens& b, size_t n, bool exact = false) const;
|
||||
|
||||
// make sure all text tokens are within the vocab range
|
||||
bool validate(const struct llama_context* ctx) const;
|
||||
|
||||
// encode and decode the image chunk
|
||||
int32_t process_chunk(
|
||||
llama_context* ctx,
|
||||
mtmd_context* mctx,
|
||||
size_t idx,
|
||||
llama_pos pos,
|
||||
int32_t seq_id,
|
||||
size_t& n_tokens_out) const;
|
||||
|
||||
// Keep the first n_keep and remove n_discard tokens from tokens
|
||||
void discard_n_tokens(int32_t n_keep, int32_t n_discard);
|
||||
|
||||
// Similarity between prompt and cached
|
||||
float get_tokens_similarity(const llama_context* ctx, const server_tokens& tokens, int n_keep = 0, int n_discard = 0) const;
|
||||
|
||||
// Similarity between common part and cache
|
||||
float get_cached_tokens_similarity(const llama_context* ctx, const server_tokens& tokens, int n_keep = 0, int n_discard = 0) const;
|
||||
};
|
||||
|
||||
// Computes FNV-1a hash of the data
|
||||
std::string fnv_hash(const uint8_t* data, size_t len);
|
||||
|
||||
server_tokens process_mtmd_prompt(mtmd_context* mctx, std::string prompt, std::vector<raw_buffer> files);
|
||||
|
||||
/**
|
||||
* break the input "prompt" object into multiple prompt if needed, then tokenize them
|
||||
* use tokenize_input_prompts() if the input could be an array.
|
||||
* this supports these cases:
|
||||
* - "prompt": "string"
|
||||
* - "prompt": [12, 34, 56]
|
||||
* - "prompt": [12, 34, "string", 56, 78]
|
||||
* - "prompt": { "prompt_string": "string", "multimodal_data": [ "base64" ] }
|
||||
*/
|
||||
server_tokens tokenize_input_subprompt(const llama_vocab* vocab, mtmd_context* mctx, const json& json_prompt, bool add_special, bool parse_special);
|
||||
|
||||
/**
|
||||
* break the input "prompt" object into multiple prompt if needed, then tokenize them
|
||||
* this supports these cases:
|
||||
* - "prompt": "string"
|
||||
* - "prompt": [12, 34, 56]
|
||||
* - "prompt": [12, 34, "string", 56, 78]
|
||||
* - "prompt": { "prompt_string": "string", "multimodal_data": [ "base64" ] }
|
||||
* and multiple prompts (multi-tasks):
|
||||
* - "prompt": ["string1", "string2"]
|
||||
* - "prompt": ["string1", [12, 34, 56]]
|
||||
* - "prompt": [[12, 34, 56], [78, 90, 12]]
|
||||
* - "prompt": [[12, 34, "string", 56, 78], [12, 34, 56], { "prompt_string": "string", "multimodal_data": [ "base64" ]}]
|
||||
*/
|
||||
std::vector<server_tokens> tokenize_input_prompts(const llama_vocab* vocab, mtmd_context* mctx, const json& json_prompt, bool add_special, bool parse_special);
|
||||
|
||||
// Assuming raw_buffer has .data() and .size() members
|
||||
void print_files_info(const std::vector<raw_buffer>& files);
|
||||
|
||||
bool prompt_cache_equal(llama_context* ctx, const server_tokens& cache_tokens,
|
||||
const server_tokens& prompt_tokens, size_t start, const common_prefix& prefix);
|
||||
2763
examples/server/server-context.cpp
Normal file
2763
examples/server/server-context.cpp
Normal file
File diff suppressed because it is too large
Load Diff
316
examples/server/server-context.h
Normal file
316
examples/server/server-context.h
Normal file
@@ -0,0 +1,316 @@
|
||||
#include "server-task.h"
|
||||
#include "server-queue.h"
|
||||
#include "speculative.h"
|
||||
#include "json-schema-to-grammar.h"
|
||||
#include <nlohmann/json_fwd.hpp>
|
||||
|
||||
#include <cstddef>
|
||||
#include <memory>
|
||||
|
||||
|
||||
|
||||
enum slot_state {
|
||||
SLOT_STATE_IDLE,
|
||||
SLOT_STATE_PROCESSING,
|
||||
};
|
||||
|
||||
|
||||
|
||||
enum slot_command {
|
||||
SLOT_COMMAND_NONE,
|
||||
SLOT_COMMAND_LOAD_PROMPT,
|
||||
SLOT_COMMAND_RELEASE,
|
||||
};
|
||||
|
||||
|
||||
|
||||
|
||||
struct slot_params {
|
||||
bool stream = true;
|
||||
bool include_usage = false;
|
||||
bool cache_prompt = true; // remember the prompt to avoid reprocessing all prompt
|
||||
|
||||
int32_t n_keep = 0; // number of tokens to keep from initial prompt
|
||||
int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
|
||||
int32_t n_predict = -1; // new tokens to predict
|
||||
|
||||
std::vector<std::string> antiprompt;
|
||||
|
||||
bool timings_per_token = false;
|
||||
bool post_sampling_probs = false;
|
||||
json input_prefix;
|
||||
json input_suffix;
|
||||
|
||||
// speculative decoding parameters
|
||||
struct {
|
||||
int n_max = 16; // max drafted tokens
|
||||
int n_min = 0; // min drafted tokens to accept
|
||||
float p_min = 0.75f; // min probability required to accept a token in the draft
|
||||
} speculative;
|
||||
|
||||
// OAI-compat fields
|
||||
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
|
||||
std::string oaicompat_model;
|
||||
std::string oaicompat_cmpl_id;
|
||||
common_chat_syntax oaicompat_chat_syntax;
|
||||
|
||||
};
|
||||
|
||||
|
||||
struct server_slot {
|
||||
int id;
|
||||
int id_task = -1;
|
||||
int id_multi = -1;
|
||||
|
||||
struct slot_params params;
|
||||
|
||||
slot_state state = SLOT_STATE_IDLE;
|
||||
slot_command command = SLOT_COMMAND_NONE;
|
||||
|
||||
llama_context* ctx = nullptr;
|
||||
// used to determine the slot that has been used the longest
|
||||
int64_t t_last_used = -1;
|
||||
|
||||
std::unique_ptr<const server_task> task;
|
||||
|
||||
// generation props
|
||||
int32_t n_ctx = 0; // context size per slot
|
||||
int32_t n_past = 0;
|
||||
int32_t n_past_prompt = 0;
|
||||
int32_t n_decoded = 0;
|
||||
int32_t n_remaining = -1;
|
||||
int32_t n_discarded_prompt = 0;
|
||||
int32_t n_kept_prompt = 0;
|
||||
|
||||
int32_t i_batch = -1;
|
||||
int32_t n_predict = -1; // TODO: disambiguate from params.n_predict
|
||||
|
||||
int32_t n_prompt_tokens = 0;
|
||||
int32_t n_prompt_tokens_processed = 0;
|
||||
|
||||
json prompt; // can be either a string, array of strings or array of token ids
|
||||
|
||||
// when a task is submitted, we first tokenize the prompt and store it here
|
||||
server_tokens prompt_tokens;
|
||||
server_tokens cache_tokens;
|
||||
|
||||
std::string generated_text;
|
||||
|
||||
std::vector<completion_token_output> generated_token_probs;
|
||||
common_chat_msg chat_msg;
|
||||
|
||||
bool infill = false;
|
||||
bool embedding = false;
|
||||
bool has_next_token = true;
|
||||
bool truncated = false;
|
||||
bool stopped_eos = false;
|
||||
bool stopped_word = false;
|
||||
bool stopped_limit = false;
|
||||
|
||||
bool oaicompat = false;
|
||||
|
||||
std::string oaicompat_model;
|
||||
std::string stopping_word;
|
||||
stop_type stop;
|
||||
|
||||
server_prompt server_cached_prompt;
|
||||
|
||||
void prompt_save(server_prompt_cache& prompt_cache) const;
|
||||
|
||||
void prompt_load(server_prompt_cache& prompt_cache, const server_tokens& tokens);
|
||||
|
||||
// sampling
|
||||
llama_token sampled;
|
||||
struct llama_sampling_params sparams;
|
||||
llama_sampling_context* ctx_sampling = nullptr;
|
||||
json json_schema;
|
||||
|
||||
common_chat_format chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
||||
std::vector<std::string> generated_tool_call_ids;
|
||||
|
||||
int32_t ga_i = 0; // group-attention state
|
||||
int32_t ga_n = 1; // group-attention factor
|
||||
int32_t ga_w = 512; // group-attention width
|
||||
|
||||
// multimodal
|
||||
mtmd_context* mctx = nullptr;
|
||||
|
||||
// speculative decoding
|
||||
struct llama_speculative* spec = nullptr;
|
||||
llama_context* ctx_dft = nullptr;
|
||||
llama_batch batch_spec = {};
|
||||
|
||||
// speculative decoding stats
|
||||
int32_t n_draft_total = 0; // Total draft tokens generated
|
||||
int32_t n_draft_accepted = 0; // Draft tokens actually accepted
|
||||
|
||||
int32_t n_past_se = 0; // self-extend
|
||||
|
||||
// stats
|
||||
size_t n_sent_text = 0; // number of sent text character
|
||||
size_t n_sent_token_probs = 0;
|
||||
|
||||
int64_t t_start_process_prompt;
|
||||
int64_t t_start_generation;
|
||||
|
||||
double t_prompt_processing; // ms
|
||||
double t_token_generation; // ms
|
||||
|
||||
void reset();
|
||||
|
||||
bool has_budget(gpt_params& global_params);
|
||||
|
||||
bool available() const;
|
||||
|
||||
bool is_processing() const;
|
||||
|
||||
void add_token_string(const completion_token_output& token);
|
||||
|
||||
void release();
|
||||
|
||||
json get_formated_timings() const;
|
||||
|
||||
result_timings get_timings() const;
|
||||
|
||||
const common_chat_msg& update_chat_msg(std::vector<common_chat_msg_diff>& diffs);
|
||||
|
||||
size_t find_stopping_strings(const std::string& text, const size_t last_token_size, bool is_full_stop);
|
||||
|
||||
void print_timings() const;
|
||||
};
|
||||
|
||||
struct server_metrics {
|
||||
int64_t t_start = 0;
|
||||
|
||||
uint64_t n_prompt_tokens_processed_total = 0;
|
||||
uint64_t t_prompt_processing_total = 0;
|
||||
uint64_t n_tokens_predicted_total = 0;
|
||||
uint64_t t_tokens_generation_total = 0;
|
||||
|
||||
uint64_t n_prompt_tokens_processed = 0;
|
||||
uint64_t t_prompt_processing = 0;
|
||||
|
||||
uint64_t n_tokens_predicted = 0;
|
||||
uint64_t t_tokens_generation = 0;
|
||||
|
||||
void init();
|
||||
|
||||
void on_prompt_eval(const server_slot& slot);
|
||||
|
||||
void on_prediction(const server_slot& slot);
|
||||
|
||||
void reset_bucket();
|
||||
};
|
||||
|
||||
struct server_context {
|
||||
llama_model* model = nullptr;
|
||||
llama_context* ctx = nullptr;
|
||||
std::vector<llama_lora_adapter_container> lora_adapters;
|
||||
|
||||
gpt_params params;
|
||||
|
||||
llama_batch batch;
|
||||
|
||||
bool clean_kv_cache = true;
|
||||
bool add_bos_token = true;
|
||||
bool has_eos_token = false;
|
||||
|
||||
// multimodal
|
||||
mtmd_context* mctx = nullptr;
|
||||
|
||||
// For speculative decoding
|
||||
llama_model* model_draft = nullptr;
|
||||
llama_context* ctx_draft = nullptr;
|
||||
llama_context_params cparams_dft;
|
||||
|
||||
int32_t n_ctx; // total context for all clients / slots
|
||||
|
||||
// system prompt
|
||||
bool system_need_update = false;
|
||||
|
||||
std::string system_prompt;
|
||||
std::vector<llama_token> system_tokens;
|
||||
|
||||
// slots / clients
|
||||
std::vector<server_slot> slots;
|
||||
json default_generation_settings_for_props;
|
||||
|
||||
server_queue queue_tasks;
|
||||
server_response queue_results;
|
||||
|
||||
std::unique_ptr<server_prompt_cache> prompt_cache;
|
||||
|
||||
server_metrics metrics;
|
||||
|
||||
common_chat_templates_ptr chat_templates;
|
||||
oaicompat_parser_options oai_parser_opt;
|
||||
// Necessary similarity of prompt for slot selection
|
||||
float slot_prompt_similarity = 0.0f;
|
||||
int32_t cache_ram_n_min = 0;
|
||||
float cache_ram_similarity = 0.5f;
|
||||
|
||||
~server_context();
|
||||
|
||||
bool load_model(const gpt_params& params_);
|
||||
|
||||
void init();
|
||||
|
||||
std::vector<llama_token> tokenize(const json& json_prompt, bool add_special) const;
|
||||
|
||||
server_slot* get_slot_by_id(int id);
|
||||
|
||||
server_slot* get_available_slot(const server_task& task);
|
||||
|
||||
bool launch_slot_with_task(server_slot& slot, server_task& task);
|
||||
|
||||
void kv_cache_clear();
|
||||
|
||||
void system_prompt_update();
|
||||
|
||||
bool system_prompt_set(const std::string& sys_prompt);
|
||||
|
||||
bool process_token(completion_token_output& result, server_slot& slot);
|
||||
|
||||
void populate_token_probs(const server_slot& slot, completion_token_output& result, bool post_sampling, bool special, int idx);
|
||||
|
||||
json get_formated_generation(const server_slot& slot) const;
|
||||
|
||||
void send_error(const server_task& task, const std::string& error, const enum error_type type = ERROR_TYPE_SERVER);
|
||||
|
||||
void send_error(const server_slot& slot, const std::string& error, const enum error_type type = ERROR_TYPE_SERVER);
|
||||
|
||||
void send_error(const int id_task, const int id_multi, const std::string& error, const enum error_type type = ERROR_TYPE_SERVER);
|
||||
|
||||
// if multimodal is enabled, send an error and return false
|
||||
bool ensure_no_mtmd(const int id_task);
|
||||
|
||||
void send_partial_response(server_slot& slot, completion_token_output tkn);
|
||||
|
||||
void send_final_response(server_slot& slot);
|
||||
|
||||
void send_embedding(const server_slot& slot, const llama_batch& batch);
|
||||
|
||||
void request_completion(int id_task, int id_multi, json data, bool infill, bool embedding, server_tokens&& inputs);
|
||||
|
||||
void request_cancel(int id_task);
|
||||
|
||||
void split_multiprompt_task(int id_multi, server_task& multiprompt_task);
|
||||
|
||||
void process_single_task(server_task&& task);
|
||||
|
||||
void on_finish_multitask(const server_task_multi& multitask);
|
||||
|
||||
void print_tokens(const server_tokens& prompt, const server_tokens& cache, size_t start1 = 0, size_t start2 = 0, size_t length = 10);
|
||||
|
||||
void discard_n_kv_and_cache_tokens(llama_context* ctx, server_slot& slot, int32_t n_keep, int32_t n_discard);
|
||||
|
||||
// convert keep first few and discard next tokens in a to b
|
||||
void context_shift_find_n_tokens(llama_context* ctx, const server_tokens& a, const server_tokens& b, int32_t n_keep,
|
||||
int32_t n_discard, int32_t& n_kept, int32_t& n_discarded, bool exact = false);
|
||||
|
||||
void context_shift_prompt(llama_context* ctx, server_slot& slot, bool exact = false);
|
||||
|
||||
void update_slots();
|
||||
|
||||
json model_meta() const;
|
||||
};
|
||||
194
examples/server/server-queue.cpp
Normal file
194
examples/server/server-queue.cpp
Normal file
@@ -0,0 +1,194 @@
|
||||
#include "server-task.h"
|
||||
#include "server-queue.h"
|
||||
#include "server-common.h"
|
||||
|
||||
#include "log.h"
|
||||
#include <chrono>
|
||||
|
||||
#define QUE_INF(fmt, ...) LOG_INF("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define QUE_WRN(fmt, ...) LOG_WRN("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define QUE_ERR(fmt, ...) LOG_ERR("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define QUE_DBG(fmt, ...) LOG_DBG("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
|
||||
#define RES_INF(fmt, ...) LOG_INF("res %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define RES_WRN(fmt, ...) LOG_WRN("res %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define RES_ERR(fmt, ...) LOG_ERR("res %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define RES_DBG(fmt, ...) LOG_DBG("res %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
|
||||
|
||||
int server_queue::post(server_task task) {
|
||||
std::unique_lock<std::mutex> lock(mutex_tasks);
|
||||
if (task.id == -1) {
|
||||
task.id = id++;
|
||||
//LOG_VERBOSE("new task id", { {"new_id", task.id} });
|
||||
QUE_DBG("new task, id = %d\n", task.id);
|
||||
}
|
||||
queue_tasks.push_back(std::move(task));
|
||||
condition_tasks.notify_one();
|
||||
return task.id;
|
||||
}
|
||||
|
||||
void server_queue::defer(server_task&& task) {
|
||||
std::unique_lock<std::mutex> lock(mutex_tasks);
|
||||
queue_tasks_deferred.push_back(std::move(task));
|
||||
}
|
||||
|
||||
int server_queue::get_new_id() {
|
||||
std::unique_lock<std::mutex> lock(mutex_tasks);
|
||||
int new_id = id++;
|
||||
//LOG_VERBOSE("new task id", { {"new_id", new_id} });
|
||||
QUE_DBG("new task, id = %d\n", id);
|
||||
return new_id;
|
||||
}
|
||||
|
||||
void server_queue::notify_slot_changed() {
|
||||
// move deferred tasks back to main loop
|
||||
std::unique_lock<std::mutex> lock(mutex_tasks);
|
||||
for (auto& task : queue_tasks_deferred) {
|
||||
queue_tasks.push_back(std::move(task));
|
||||
}
|
||||
queue_tasks_deferred.clear();
|
||||
}
|
||||
|
||||
void server_queue::on_new_task(std::function<void(server_task&&)> callback) {
|
||||
callback_new_task = std::move(callback);
|
||||
}
|
||||
|
||||
|
||||
void server_queue::start_loop() {
|
||||
running = true;
|
||||
|
||||
while (true) {
|
||||
LOG_VERBOSE("new task may arrive", {});
|
||||
|
||||
while (true) {
|
||||
std::unique_lock<std::mutex> lock(mutex_tasks);
|
||||
if (queue_tasks.empty()) {
|
||||
lock.unlock();
|
||||
break;
|
||||
}
|
||||
server_task task = std::move(queue_tasks.front());
|
||||
queue_tasks.erase(queue_tasks.begin());
|
||||
lock.unlock();
|
||||
//LOG_VERBOSE("callback_new_task", { {"id_task", task.id} });
|
||||
callback_new_task(std::move(task));
|
||||
}
|
||||
|
||||
LOG_VERBOSE("update_multitasks", {});
|
||||
|
||||
// check if we have any finished multitasks
|
||||
auto queue_iterator = queue_multitasks.begin();
|
||||
while (queue_iterator != queue_multitasks.end()) {
|
||||
if (queue_iterator->subtasks_remaining.empty()) {
|
||||
// all subtasks done == multitask is done
|
||||
server_task_multi current_multitask = *queue_iterator;
|
||||
callback_finish_multitask(current_multitask);
|
||||
// remove this multitask
|
||||
queue_iterator = queue_multitasks.erase(queue_iterator);
|
||||
}
|
||||
else {
|
||||
++queue_iterator;
|
||||
}
|
||||
}
|
||||
|
||||
// all tasks in the current loop is processed, slots data is now ready
|
||||
LOG_VERBOSE("callback_update_slots", {});
|
||||
|
||||
callback_update_slots();
|
||||
|
||||
LOG_VERBOSE("wait for new task", {});
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(mutex_tasks);
|
||||
if (queue_tasks.empty()) {
|
||||
if (!running) {
|
||||
LOG_VERBOSE("ending start_loop", {});
|
||||
return;
|
||||
}
|
||||
condition_tasks.wait(lock, [&] {
|
||||
return (!queue_tasks.empty() || !running);
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void server_queue::add_multitask(int id_multi, std::vector<int>& sub_ids) {
|
||||
std::lock_guard<std::mutex> lock(mutex_tasks);
|
||||
server_task_multi multi;
|
||||
multi.id = id_multi;
|
||||
std::copy(sub_ids.begin(), sub_ids.end(), std::inserter(multi.subtasks_remaining, multi.subtasks_remaining.end()));
|
||||
queue_multitasks.push_back(multi);
|
||||
}
|
||||
|
||||
|
||||
void server_queue::update_multitask(int id_multi, int id_sub, server_task_result& result) {
|
||||
std::lock_guard<std::mutex> lock(mutex_tasks);
|
||||
for (auto& multitask : queue_multitasks) {
|
||||
if (multitask.id == id_multi) {
|
||||
multitask.subtasks_remaining.erase(id_sub);
|
||||
multitask.results.push_back(result);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void server_response::add_waiting_task_id(int id_task) {
|
||||
//LOG_VERBOSE("waiting for task id", { {"id_task", id_task} });
|
||||
QUE_DBG("waiting for task id, id = %d\n", id_task);
|
||||
|
||||
std::unique_lock<std::mutex> lock(mutex_results);
|
||||
waiting_task_ids.insert(id_task);
|
||||
}
|
||||
|
||||
void server_response::remove_waiting_task_id(int id_task) {
|
||||
//LOG_VERBOSE("remove waiting for task id", { {"id_task", id_task} });
|
||||
QUE_DBG("remove waiting for task id, id = %d\n", id_task);
|
||||
std::unique_lock<std::mutex> lock(mutex_results);
|
||||
waiting_task_ids.erase(id_task);
|
||||
}
|
||||
|
||||
|
||||
server_task_result server_response::recv(int id_task) {
|
||||
while (true) {
|
||||
std::unique_lock<std::mutex> lock(mutex_results);
|
||||
condition_results.wait(lock, [&] {
|
||||
return !queue_results.empty();
|
||||
});
|
||||
|
||||
for (int i = 0; i < (int)queue_results.size(); i++) {
|
||||
if (queue_results[i].id == id_task) {
|
||||
assert(queue_results[i].id_multi == -1);
|
||||
server_task_result res = queue_results[i];
|
||||
queue_results.erase(queue_results.begin() + i);
|
||||
return res;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// should never reach here
|
||||
}
|
||||
|
||||
void server_response::send(server_task_result result) {
|
||||
//LOG_VERBOSE("send new result", { {"id_task", result.id} });
|
||||
QUE_DBG("send new result, id = %d\n", result.id);
|
||||
std::unique_lock<std::mutex> lock(mutex_results);
|
||||
for (const auto& id_task : waiting_task_ids) {
|
||||
// LOG_TEE("waiting task id %i \n", id_task);
|
||||
// for now, tasks that have associated parent multitasks just get erased once multitask picks up the result
|
||||
if (result.id_multi == id_task) {
|
||||
//LOG_VERBOSE("callback_update_multitask", { {"id_task", id_task} });
|
||||
QUE_DBG("callback_update_multitask, id = %d\n", id_task);
|
||||
callback_update_multitask(id_task, result.id, result);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (result.id == id_task) {
|
||||
//LOG_VERBOSE("queue_results.push_back", { {"id_task", id_task} });
|
||||
QUE_DBG("queue_results.push_back, id = %d\n", id_task);
|
||||
queue_results.push_back(result);
|
||||
condition_results.notify_all();
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
117
examples/server/server-queue.h
Normal file
117
examples/server/server-queue.h
Normal file
@@ -0,0 +1,117 @@
|
||||
#pragma once
|
||||
#include "server-task.h"
|
||||
|
||||
#include <condition_variable>
|
||||
#include <deque>
|
||||
#include <mutex>
|
||||
#include <unordered_set>
|
||||
|
||||
struct server_task_multi {
|
||||
int id = -1;
|
||||
|
||||
std::set<int> subtasks_remaining;
|
||||
std::vector<server_task_result> results;
|
||||
};
|
||||
|
||||
|
||||
struct server_queue {
|
||||
int id = 0;
|
||||
bool running;
|
||||
|
||||
// queues
|
||||
std::vector<server_task> queue_tasks;
|
||||
std::vector<server_task> queue_tasks_deferred;
|
||||
|
||||
std::vector<server_task_multi> queue_multitasks;
|
||||
|
||||
std::mutex mutex_tasks;
|
||||
std::condition_variable condition_tasks;
|
||||
|
||||
// callback functions
|
||||
std::function<void(server_task &&)> callback_new_task;
|
||||
std::function<void(server_task_multi &)> callback_finish_multitask;
|
||||
std::function<void(void)> callback_update_slots;
|
||||
|
||||
|
||||
// Add a new task to the end of the queue
|
||||
int post(server_task task);
|
||||
|
||||
// Add a new task, but defer until one slot is available
|
||||
void defer(server_task&& task);
|
||||
|
||||
// Get the next id for creating anew task
|
||||
int get_new_id();
|
||||
|
||||
// Register function to process a new task
|
||||
void on_new_task(std::function<void(server_task&&)> callback);
|
||||
|
||||
// Register function to process a multitask when it is finished
|
||||
void on_finish_multitask(std::function<void(server_task_multi&)> callback) {
|
||||
callback_finish_multitask = std::move(callback);
|
||||
}
|
||||
|
||||
// Register the function to be called when all slots data is ready to be processed
|
||||
void on_update_slots(std::function<void(void)> callback) {
|
||||
callback_update_slots = std::move(callback);
|
||||
}
|
||||
|
||||
// Call when the state of one slot is changed
|
||||
void notify_slot_changed();
|
||||
|
||||
// end the start_loop routine
|
||||
void terminate() {
|
||||
std::unique_lock<std::mutex> lock(mutex_tasks);
|
||||
running = false;
|
||||
condition_tasks.notify_all();
|
||||
}
|
||||
|
||||
/**
|
||||
* Main loop consists of these steps:
|
||||
* - Wait until a new task arrives
|
||||
* - Process the task (i.e. maybe copy data into slot)
|
||||
* - Check if multitask is finished
|
||||
* - Update all slots
|
||||
*/
|
||||
void start_loop();
|
||||
|
||||
//
|
||||
// functions to manage multitasks
|
||||
//
|
||||
|
||||
// add a multitask by specifying the id of all subtask (subtask is a server_task)
|
||||
void add_multitask(int id_multi, std::vector<int>& sub_ids);
|
||||
|
||||
// updatethe remaining subtasks, while appending results to multitask
|
||||
void update_multitask(int id_multi, int id_sub, server_task_result& result);
|
||||
};
|
||||
|
||||
struct server_response {
|
||||
typedef std::function<void(int, int, server_task_result&)> callback_multitask_t;
|
||||
callback_multitask_t callback_update_multitask;
|
||||
|
||||
// for keeping track of all tasks waiting for the result
|
||||
std::set<int> waiting_task_ids;
|
||||
|
||||
// the main result queue
|
||||
std::vector<server_task_result> queue_results;
|
||||
|
||||
std::mutex mutex_results;
|
||||
std::condition_variable condition_results;
|
||||
|
||||
// add the id_task to the list of tasks waiting for response
|
||||
void add_waiting_task_id(int id_task);
|
||||
|
||||
// when the request is finished, we can remove task associated with it
|
||||
void remove_waiting_task_id(int id_task);
|
||||
|
||||
// This function blocks the thread until there is a response for this id_task
|
||||
server_task_result recv(int id_task);
|
||||
|
||||
// Register the function to update multitask
|
||||
void on_multitask_update(callback_multitask_t callback) {
|
||||
callback_update_multitask = std::move(callback);
|
||||
}
|
||||
|
||||
// Send a new result to a waiting id_task
|
||||
void send(server_task_result result);
|
||||
};
|
||||
816
examples/server/server-task.cpp
Normal file
816
examples/server/server-task.cpp
Normal file
@@ -0,0 +1,816 @@
|
||||
#include "server-task.h"
|
||||
|
||||
|
||||
json result_timings::to_json() const {
|
||||
json base = {
|
||||
{"prompt_n", prompt_n},
|
||||
{"prompt_ms", prompt_ms},
|
||||
{"prompt_per_token_ms", prompt_per_token_ms},
|
||||
{"prompt_per_second", prompt_per_second},
|
||||
|
||||
{"predicted_n", predicted_n},
|
||||
{"predicted_ms", predicted_ms},
|
||||
{"predicted_per_token_ms", predicted_per_token_ms},
|
||||
{"predicted_per_second", predicted_per_second},
|
||||
|
||||
{"n_ctx", n_ctx},
|
||||
{"n_past", n_past},
|
||||
};
|
||||
|
||||
if (draft_n > 0) {
|
||||
base["draft_n"] = draft_n;
|
||||
base["draft_n_accepted"] = draft_n_accepted;
|
||||
}
|
||||
|
||||
return base;
|
||||
}
|
||||
|
||||
|
||||
json server_task_result::to_json_final() {
|
||||
switch (oaicompat) {
|
||||
case OAICOMPAT_TYPE_NONE:
|
||||
return to_json_non_oaicompat_final();
|
||||
case OAICOMPAT_TYPE_COMPLETION:
|
||||
return to_json_oaicompat_final();
|
||||
case OAICOMPAT_TYPE_CHAT:
|
||||
return stream ? to_json_oaicompat_chat_stream() : to_json_oaicompat_chat_final();
|
||||
case OAICOMPAT_TYPE_ANTHROPIC:
|
||||
return stream ? to_json_anthropic_stream() : to_json_anthropic_final();
|
||||
default:
|
||||
GGML_ASSERT(false && "Invalid oaicompat_type");
|
||||
}
|
||||
}
|
||||
|
||||
json server_task_result::to_json_partial() {
|
||||
switch (oaicompat) {
|
||||
case OAICOMPAT_TYPE_NONE:
|
||||
return to_json_non_oaicompat_partial();
|
||||
case OAICOMPAT_TYPE_COMPLETION:
|
||||
return to_json_oaicompat_partial();
|
||||
case OAICOMPAT_TYPE_CHAT:
|
||||
return to_json_oaicompat_chat_partial();
|
||||
case OAICOMPAT_TYPE_ANTHROPIC:
|
||||
return to_json_anthropic_partial();
|
||||
default:
|
||||
GGML_ASSERT(false && "Invalid oaicompat_type");
|
||||
}
|
||||
}
|
||||
|
||||
json server_task_result::to_json_non_oaicompat_partial() {
|
||||
// non-OAI-compat JSON
|
||||
json res = json{
|
||||
{"index", index},
|
||||
{"content", content},
|
||||
{"tokens", tokens},
|
||||
{"stop", false},
|
||||
{"id_slot", id_multi},
|
||||
{"tokens_predicted", n_decoded},
|
||||
{"tokens_evaluated", n_prompt_tokens},
|
||||
};
|
||||
// populate the timings object when needed (usually for the last response or with timings_per_token enabled)
|
||||
if (timings.prompt_n > 0) {
|
||||
res.push_back({ "timings", timings.to_json() });
|
||||
}
|
||||
if (!probs_output.empty()) {
|
||||
res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
json server_task_result::to_json_non_oaicompat_final() {
|
||||
json res = json{
|
||||
{"index", index},
|
||||
{"content", stream ? "" : content}, // in stream mode, content is already in last partial chunk
|
||||
{"tokens", stream ? std::vector<llama_token> {} : tokens},
|
||||
{"id_slot", id_multi},
|
||||
{"stop", true},
|
||||
{"model", oaicompat_model},
|
||||
{"tokens_predicted", n_decoded},
|
||||
{"tokens_evaluated", n_prompt_tokens},
|
||||
//{"generation_settings", default_generation_settings_for_props.to_json()},
|
||||
{"prompt", prompt},
|
||||
{"has_new_line", has_new_line},
|
||||
{"truncated", truncated},
|
||||
//{"stop_type", stop_type_to_str(STOP_TYPE_EOS)},
|
||||
{"stopping_word", stopping_word},
|
||||
{"tokens_cached", n_tokens_cached},
|
||||
{"timings", timings.to_json()},
|
||||
};
|
||||
if (!stream && !probs_output.empty()) {
|
||||
res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs);
|
||||
}
|
||||
return response_fields.empty() ? res : json_get_nested_values(response_fields, res);
|
||||
}
|
||||
|
||||
json server_task_result::to_json_oaicompat_partial() {
|
||||
std::time_t t = std::time(0);
|
||||
json logprobs = json(nullptr); // OAI default to null
|
||||
if (probs_output.size() > 0) {
|
||||
logprobs = json{
|
||||
{"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)},
|
||||
};
|
||||
}
|
||||
json res = json{
|
||||
{"choices", json::array({
|
||||
json{
|
||||
{"text", content},
|
||||
{"index", index},
|
||||
{"logprobs", logprobs},
|
||||
{"finish_reason", nullptr},
|
||||
}
|
||||
})},
|
||||
{"created", t},
|
||||
{"model", oaicompat_model},
|
||||
{"object", "text_completion"},
|
||||
{"usage", json {
|
||||
{"completion_tokens", n_decoded},
|
||||
{"prompt_tokens", n_prompt_tokens},
|
||||
{"total_tokens", n_decoded + n_prompt_tokens}
|
||||
}},
|
||||
{"id", oaicompat_cmpl_id}
|
||||
};
|
||||
|
||||
// extra fields for debugging purposes
|
||||
if (verbose) {
|
||||
res["__verbose"] = to_json_non_oaicompat_partial();
|
||||
}
|
||||
if (timings.prompt_n >= 0) {
|
||||
res.push_back({ "timings", timings.to_json() });
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
json server_task_result::to_json_oaicompat_final() {
|
||||
std::time_t t = std::time(0);
|
||||
json logprobs = json(nullptr); // OAI default to null
|
||||
if (!stream && probs_output.size() > 0) {
|
||||
logprobs = json{
|
||||
{"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)},
|
||||
};
|
||||
}
|
||||
json finish_reason = "length";
|
||||
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
|
||||
finish_reason = "stop";
|
||||
}
|
||||
json res = json{
|
||||
{"choices", json::array({
|
||||
json{
|
||||
{"text", stream ? "" : content}, // in stream mode, content is already in last partial chunk
|
||||
{"index", index},
|
||||
{"logprobs", logprobs},
|
||||
{"finish_reason", finish_reason},
|
||||
}
|
||||
})},
|
||||
{"created", t},
|
||||
{"model", oaicompat_model},
|
||||
{"object", "text_completion"},
|
||||
{"usage", json {
|
||||
{"completion_tokens", n_decoded},
|
||||
{"prompt_tokens", n_prompt_tokens},
|
||||
{"total_tokens", n_decoded + n_prompt_tokens}
|
||||
}},
|
||||
{"id", oaicompat_cmpl_id}
|
||||
};
|
||||
|
||||
// extra fields for debugging purposes
|
||||
if (verbose) {
|
||||
res["__verbose"] = to_json_non_oaicompat_final();
|
||||
}
|
||||
if (timings.prompt_n >= 0) {
|
||||
res.push_back({ "timings", timings.to_json() });
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
json server_task_result::to_json_oaicompat_chat_partial() {
|
||||
bool first = n_decoded == 1;
|
||||
std::time_t t = std::time(0);
|
||||
json choices;
|
||||
|
||||
std::vector<json> deltas;
|
||||
auto add_delta = [&](const json& delta) {
|
||||
deltas.push_back({
|
||||
{"choices", json::array({
|
||||
json {
|
||||
{"finish_reason", nullptr},
|
||||
{"index", 0},
|
||||
{"delta", delta},
|
||||
},
|
||||
})},
|
||||
{"created", t},
|
||||
{"id", oaicompat_cmpl_id},
|
||||
{"model", oaicompat_model},
|
||||
{"object", "chat.completion.chunk"},
|
||||
{"usage", json {
|
||||
{"completion_tokens", n_decoded},
|
||||
{"prompt_tokens", n_prompt_tokens},
|
||||
{"total_tokens", n_decoded + n_prompt_tokens},
|
||||
}},
|
||||
});
|
||||
};
|
||||
// We have to send an initial update to conform to openai behavior
|
||||
if (first) {
|
||||
add_delta({
|
||||
{"role", "assistant"},
|
||||
{"content", nullptr},
|
||||
});
|
||||
}
|
||||
|
||||
for (const auto& diff : oaicompat_msg_diffs) {
|
||||
add_delta(common_chat_msg_diff_to_json_oaicompat<json>(diff));
|
||||
}
|
||||
|
||||
if (!deltas.empty()) {
|
||||
GGML_ASSERT(deltas[deltas.size() - 1].at("choices").size() >= 1);
|
||||
|
||||
if (probs_output.size() > 0) {
|
||||
deltas[deltas.size() - 1].at("choices").at(0)["logprobs"] = json{
|
||||
{"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)},
|
||||
};
|
||||
}
|
||||
|
||||
if (timings.prompt_n >= 0) {
|
||||
deltas[deltas.size() - 1].push_back({ "timings", timings.to_json() });
|
||||
}
|
||||
}
|
||||
|
||||
return deltas;
|
||||
}
|
||||
|
||||
json server_task_result::to_json_oaicompat_chat_final() {
|
||||
std::string finish_reason = "length";
|
||||
common_chat_msg msg;
|
||||
if (!oaicompat_msg.empty()) {
|
||||
msg = oaicompat_msg;
|
||||
}
|
||||
else {
|
||||
msg.role = "assistant";
|
||||
msg.content = content;
|
||||
}
|
||||
if (stop) {
|
||||
finish_reason = msg.tool_calls.empty() ? "stop" : "tool_calls";
|
||||
}
|
||||
|
||||
|
||||
json choice{
|
||||
{"finish_reason", finish_reason},
|
||||
{"index", 0},
|
||||
{"message", msg.to_json_oaicompat<json>()},
|
||||
};
|
||||
|
||||
if (!stream && probs_output.size() > 0) {
|
||||
choice["logprobs"] = json{
|
||||
{"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)},
|
||||
};
|
||||
}
|
||||
|
||||
std::time_t t = std::time(0);
|
||||
|
||||
json res = json{
|
||||
{"choices", json::array({choice})},
|
||||
{"created", t},
|
||||
{"model", oaicompat_model},
|
||||
{"object", "chat.completion"},
|
||||
{"usage", json {
|
||||
{"completion_tokens", n_decoded},
|
||||
{"prompt_tokens", n_prompt_tokens},
|
||||
{"total_tokens", n_decoded + n_prompt_tokens}
|
||||
}},
|
||||
{"id", oaicompat_cmpl_id}
|
||||
};
|
||||
|
||||
// extra fields for debugging purposes
|
||||
if (verbose) {
|
||||
res["__verbose"] = to_json_non_oaicompat_final();
|
||||
}
|
||||
if (timings.prompt_n >= 0) {
|
||||
res.push_back({ "timings", timings.to_json() });
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
json server_task_result::to_json_oaicompat_chat_stream() {
|
||||
std::time_t t = std::time(0);
|
||||
std::string finish_reason = "length";
|
||||
if (stop) {
|
||||
//if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
|
||||
finish_reason = oaicompat_msg.tool_calls.empty() ? "stop" : "tool_calls";
|
||||
}
|
||||
|
||||
json deltas = json::array();
|
||||
for (const auto& diff : oaicompat_msg_diffs) {
|
||||
deltas.push_back({
|
||||
{"choices", json::array({
|
||||
json {
|
||||
{"finish_reason", nullptr},
|
||||
{"index", 0},
|
||||
{"delta", common_chat_msg_diff_to_json_oaicompat<json>(diff)},
|
||||
},
|
||||
})},
|
||||
{"created", t},
|
||||
{"id", oaicompat_cmpl_id},
|
||||
{"model", oaicompat_model},
|
||||
{"object", "chat.completion.chunk"},
|
||||
});
|
||||
}
|
||||
|
||||
deltas.push_back({
|
||||
{"choices", json::array({
|
||||
json {
|
||||
{"finish_reason", finish_reason},
|
||||
{"index", 0},
|
||||
{"delta", json::object()},
|
||||
},
|
||||
})},
|
||||
{"created", t},
|
||||
{"id", oaicompat_cmpl_id},
|
||||
{"model", oaicompat_model},
|
||||
{"object", "chat.completion.chunk"},
|
||||
});
|
||||
if (include_usage) {
|
||||
// OpenAI API spec for chat.completion.chunks specifies an empty `choices` array for the last chunk when including usage
|
||||
// https://platform.openai.com/docs/api-reference/chat_streaming/streaming#chat_streaming/streaming-choices
|
||||
deltas.push_back({
|
||||
{"choices", json::array()},
|
||||
{"created", t},
|
||||
{"id", oaicompat_cmpl_id},
|
||||
{"model", oaicompat_model},
|
||||
{"object", "chat.completion.chunk"},
|
||||
{"usage", json {
|
||||
{"completion_tokens", n_decoded},
|
||||
{"prompt_tokens", n_prompt_tokens},
|
||||
{"total_tokens", n_decoded + n_prompt_tokens},
|
||||
}},
|
||||
});
|
||||
}
|
||||
if (timings.prompt_n >= 0) {
|
||||
deltas.back().push_back({ "timings", timings.to_json() });
|
||||
}
|
||||
// extra fields for debugging purposes
|
||||
if (verbose && !deltas.empty()) {
|
||||
deltas.front()["__verbose"] = to_json_non_oaicompat_final();
|
||||
}
|
||||
|
||||
return deltas;
|
||||
}
|
||||
|
||||
json server_task_result::to_json_anthropic_final() {
|
||||
std::string stop_reason = "max_tokens";
|
||||
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
|
||||
stop_reason = oaicompat_msg.tool_calls.empty() ? "end_turn" : "tool_use";
|
||||
}
|
||||
|
||||
json content_blocks = json::array();
|
||||
|
||||
common_chat_msg msg;
|
||||
if (!oaicompat_msg.empty()) {
|
||||
msg = oaicompat_msg;
|
||||
}
|
||||
else {
|
||||
msg.role = "assistant";
|
||||
msg.content = content;
|
||||
}
|
||||
|
||||
|
||||
if (!msg.content.empty()) {
|
||||
content_blocks.push_back({
|
||||
{"type", "text"},
|
||||
{"text", msg.content}
|
||||
});
|
||||
}
|
||||
|
||||
for (const auto& tool_call : msg.tool_calls) {
|
||||
json tool_use_block = {
|
||||
{"type", "tool_use"},
|
||||
{"id", tool_call.id},
|
||||
{"name", tool_call.name}
|
||||
};
|
||||
|
||||
try {
|
||||
tool_use_block["input"] = json::parse(tool_call.arguments);
|
||||
}
|
||||
catch (const std::exception&) {
|
||||
tool_use_block["input"] = json::object();
|
||||
}
|
||||
|
||||
content_blocks.push_back(tool_use_block);
|
||||
}
|
||||
|
||||
json res = {
|
||||
{"id", oaicompat_cmpl_id},
|
||||
{"type", "message"},
|
||||
{"role", "assistant"},
|
||||
{"content", content_blocks},
|
||||
{"model", oaicompat_model},
|
||||
{"stop_reason", stop_reason},
|
||||
{"stop_sequence", stopping_word.empty() ? nullptr : json(stopping_word)},
|
||||
{"usage", {
|
||||
{"input_tokens", n_prompt_tokens},
|
||||
{"output_tokens", n_decoded}
|
||||
}}
|
||||
};
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
json server_task_result::to_json_anthropic_stream() {
|
||||
json events = json::array();
|
||||
|
||||
std::string stop_reason = "max_tokens";
|
||||
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
|
||||
stop_reason = oaicompat_msg.tool_calls.empty() ? "end_turn" : "tool_use";
|
||||
}
|
||||
|
||||
bool has_text = !oaicompat_msg.content.empty();
|
||||
size_t num_tool_calls = oaicompat_msg.tool_calls.size();
|
||||
|
||||
bool text_block_started = false;
|
||||
std::set<size_t> tool_calls_started;
|
||||
|
||||
for (const auto& diff : oaicompat_msg_diffs) {
|
||||
|
||||
if (!diff.content_delta.empty()) {
|
||||
if (!text_block_started) {
|
||||
events.push_back({
|
||||
{"event", "content_block_start"},
|
||||
{"data", {
|
||||
{"type", "content_block_start"},
|
||||
{"index", 0},
|
||||
{"content_block", {
|
||||
{"type", "text"},
|
||||
{"text", ""}
|
||||
}}
|
||||
}}
|
||||
});
|
||||
text_block_started = true;
|
||||
}
|
||||
|
||||
events.push_back({
|
||||
{"event", "content_block_delta"},
|
||||
{"data", {
|
||||
{"type", "content_block_delta"},
|
||||
{"index", 0},
|
||||
{"delta", {
|
||||
{"type", "text_delta"},
|
||||
{"text", diff.content_delta}
|
||||
}}
|
||||
}}
|
||||
});
|
||||
}
|
||||
|
||||
if (diff.tool_call_index != std::string::npos) {
|
||||
size_t content_block_index = (has_text ? 1 : 0) + diff.tool_call_index;
|
||||
|
||||
if (tool_calls_started.find(diff.tool_call_index) == tool_calls_started.end()) {
|
||||
const auto& full_tool_call = oaicompat_msg.tool_calls[diff.tool_call_index];
|
||||
|
||||
events.push_back({
|
||||
{"event", "content_block_start"},
|
||||
{"data", {
|
||||
{"type", "content_block_start"},
|
||||
{"index", content_block_index},
|
||||
{"content_block", {
|
||||
{"type", "tool_use"},
|
||||
{"id", full_tool_call.id},
|
||||
{"name", full_tool_call.name}
|
||||
}}
|
||||
}}
|
||||
});
|
||||
tool_calls_started.insert(diff.tool_call_index);
|
||||
}
|
||||
|
||||
if (!diff.tool_call_delta.arguments.empty()) {
|
||||
events.push_back({
|
||||
{"event", "content_block_delta"},
|
||||
{"data", {
|
||||
{"type", "content_block_delta"},
|
||||
{"index", content_block_index},
|
||||
{"delta", {
|
||||
{"type", "input_json_delta"},
|
||||
{"partial_json", diff.tool_call_delta.arguments}
|
||||
}}
|
||||
}}
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (has_text) {
|
||||
events.push_back({
|
||||
{"event", "content_block_stop"},
|
||||
{"data", {
|
||||
{"type", "content_block_stop"},
|
||||
{"index", 0}
|
||||
}}
|
||||
});
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < num_tool_calls; i++) {
|
||||
size_t content_block_index = (has_text ? 1 : 0) + i;
|
||||
events.push_back({
|
||||
{"event", "content_block_stop"},
|
||||
{"data", {
|
||||
{"type", "content_block_stop"},
|
||||
{"index", content_block_index}
|
||||
}}
|
||||
});
|
||||
}
|
||||
|
||||
events.push_back({
|
||||
{"event", "message_delta"},
|
||||
{"data", {
|
||||
{"type", "message_delta"},
|
||||
{"delta", {
|
||||
{"stop_reason", stop_reason},
|
||||
{"stop_sequence", stopping_word.empty() ? nullptr : json(stopping_word)}
|
||||
}},
|
||||
{"usage", {
|
||||
{"output_tokens", n_decoded}
|
||||
}}
|
||||
}}
|
||||
});
|
||||
|
||||
events.push_back({
|
||||
{"event", "message_stop"},
|
||||
{"data", {
|
||||
{"type", "message_stop"}
|
||||
}}
|
||||
});
|
||||
|
||||
// extra fields for debugging purposes
|
||||
if (verbose && !events.empty()) {
|
||||
events.front()["data"]["__verbose"] = to_json_non_oaicompat_final();
|
||||
}
|
||||
// Don't add timings for Anthropic API (breaks spec compliance)
|
||||
if (oaicompat != OAICOMPAT_TYPE_ANTHROPIC && timings.prompt_n >= 0 && !events.empty()) {
|
||||
events.back()["data"]["timings"] = timings.to_json();
|
||||
}
|
||||
|
||||
return events;
|
||||
}
|
||||
|
||||
json server_task_result::to_json_anthropic_partial() {
|
||||
json events = json::array();
|
||||
bool first = n_decoded == 1;
|
||||
static bool text_block_started = false;
|
||||
|
||||
if (first) {
|
||||
text_block_started = false;
|
||||
|
||||
events.push_back({
|
||||
{"event", "message_start"},
|
||||
{"data", {
|
||||
{"type", "message_start"},
|
||||
{"message", {
|
||||
{"id", oaicompat_cmpl_id},
|
||||
{"type", "message"},
|
||||
{"role", "assistant"},
|
||||
{"content", json::array()},
|
||||
{"model", oaicompat_model},
|
||||
{"stop_reason", nullptr},
|
||||
{"stop_sequence", nullptr},
|
||||
{"usage", {
|
||||
{"input_tokens", n_prompt_tokens},
|
||||
{"output_tokens", 0}
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
});
|
||||
}
|
||||
|
||||
for (const auto& diff : oaicompat_msg_diffs) {
|
||||
if (!diff.content_delta.empty()) {
|
||||
if (!text_block_started) {
|
||||
events.push_back({
|
||||
{"event", "content_block_start"},
|
||||
{"data", {
|
||||
{"type", "content_block_start"},
|
||||
{"index", 0},
|
||||
{"content_block", {
|
||||
{"type", "text"},
|
||||
{"text", ""}
|
||||
}}
|
||||
}}
|
||||
});
|
||||
text_block_started = true;
|
||||
}
|
||||
|
||||
events.push_back({
|
||||
{"event", "content_block_delta"},
|
||||
{"data", {
|
||||
{"type", "content_block_delta"},
|
||||
{"index", 0},
|
||||
{"delta", {
|
||||
{"type", "text_delta"},
|
||||
{"text", diff.content_delta}
|
||||
}}
|
||||
}}
|
||||
});
|
||||
}
|
||||
|
||||
if (diff.tool_call_index != std::string::npos) {
|
||||
size_t content_block_index = (text_block_started ? 1 : 0) + diff.tool_call_index;
|
||||
|
||||
if (!diff.tool_call_delta.name.empty()) {
|
||||
events.push_back({
|
||||
{"event", "content_block_start"},
|
||||
{"data", {
|
||||
{"type", "content_block_start"},
|
||||
{"index", content_block_index},
|
||||
{"content_block", {
|
||||
{"type", "tool_use"},
|
||||
{"id", diff.tool_call_delta.id},
|
||||
{"name", diff.tool_call_delta.name}
|
||||
}}
|
||||
}}
|
||||
});
|
||||
}
|
||||
|
||||
if (!diff.tool_call_delta.arguments.empty()) {
|
||||
events.push_back({
|
||||
{"event", "content_block_delta"},
|
||||
{"data", {
|
||||
{"type", "content_block_delta"},
|
||||
{"index", content_block_index},
|
||||
{"delta", {
|
||||
{"type", "input_json_delta"},
|
||||
{"partial_json", diff.tool_call_delta.arguments}
|
||||
}}
|
||||
}}
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (verbose && !events.empty() && first) {
|
||||
events.front()["data"]["__verbose"] = to_json_non_oaicompat_partial();
|
||||
}
|
||||
|
||||
if (timings.prompt_n >= 0 && !events.empty()) {
|
||||
events.back()["data"]["timings"] = timings.to_json();
|
||||
}
|
||||
|
||||
//if (is_progress && !events.empty()) {
|
||||
// events.back()["data"]["prompt_progress"] = progress.to_json();
|
||||
//}
|
||||
|
||||
return events;
|
||||
}
|
||||
|
||||
|
||||
size_t server_prompt::size() const {
|
||||
size_t res = data.size();
|
||||
|
||||
for (const auto& checkpoint : checkpoints) {
|
||||
res += checkpoint.size();
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
size_t server_prompt_cache::size() const {
|
||||
size_t res = 0;
|
||||
|
||||
for (const auto& state : states) {
|
||||
res += state.size();
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
size_t server_prompt_cache::n_tokens() const {
|
||||
size_t res = 0;
|
||||
|
||||
for (const auto& state : states) {
|
||||
res += state.n_tokens();
|
||||
}
|
||||
return res;
|
||||
|
||||
}
|
||||
|
||||
bool server_prompt_cache::load(server_prompt& prompt, const server_tokens& tokens_new, llama_context* ctx, int32_t id_slot) {
|
||||
const auto lcp_best = prompt.tokens.get_common_prefix(ctx, tokens_new);
|
||||
|
||||
float f_keep_best = float(lcp_best.second) / prompt.tokens.size();
|
||||
float sim_best = prompt.tokens.get_tokens_similarity(ctx, tokens_new, prompt.n_kept_prompt, prompt.n_discarded_prompt);
|
||||
LLAMA_LOG_INFO(" - looking for better prompt, base f_keep = %.3f, sim = %.3f, n_keep = %d, n_discarded_prompt = %d\n", f_keep_best, sim_best, prompt.n_kept_prompt, prompt.n_discarded_prompt);
|
||||
|
||||
auto it_best = states.end();
|
||||
|
||||
// find the most similar cached prompt, that would also preserve the most context
|
||||
for (auto it = states.begin(); it != states.end(); ++it) {
|
||||
const auto lcp_cur = it->tokens.get_common_prefix(ctx, tokens_new);
|
||||
const float f_keep_cur = float(lcp_cur.first) / it->tokens.size();
|
||||
const float sim_cur = it->tokens.get_tokens_similarity(ctx, tokens_new, it->n_kept_prompt, it->n_discarded_prompt);
|
||||
if (sim_best < sim_cur) {
|
||||
f_keep_best = f_keep_cur;
|
||||
sim_best = sim_cur;
|
||||
it_best = it;
|
||||
}
|
||||
}
|
||||
|
||||
if (it_best != states.end()) {
|
||||
LLAMA_LOG_INFO(" - found better prompt with f_keep = %.3f, sim = %.3f, n_keep = %d, n_discarded_prompt = %d\n", f_keep_best, sim_best, it_best->n_kept_prompt, it_best->n_discarded_prompt);
|
||||
const size_t size = it_best->data.size();
|
||||
const size_t n = llama_state_seq_set_data(ctx, it_best->data.data(), size, id_slot);
|
||||
if (n != size) {
|
||||
LLAMA_LOG_INFO("failed to restore state with size %zu\n", size);
|
||||
return false;
|
||||
}
|
||||
|
||||
it_best->data.clear();
|
||||
it_best->data.shrink_to_fit();
|
||||
|
||||
prompt = std::move(*it_best);
|
||||
|
||||
states.erase(it_best);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// Try to add `prompt` (whose serialized llama state occupies `state_size` bytes)
// to the cache. Returns a pointer to the newly inserted cache entry whose `data`
// buffer the caller is expected to fill, or nullptr when:
//  - the prompt is already fully contained in an existing cache entry, or
//  - the state buffer cannot be allocated (in which case the cache size limit
//    is reduced and the cache is trimmed via update()).
// Cached prompts that are fully contained in the new prompt are evicted.
server_prompt* server_prompt_cache::alloc(const server_prompt& prompt, size_t state_size) {
    // Token sequence as it currently sits in the KV cache, i.e. with the
    // context-shift discarded region removed. This depends only on `prompt`,
    // so compute it once instead of copying the tokens on every iteration.
    auto tokens_ctx_shift = server_tokens(prompt.tokens.get_text_tokens(), false); // copy cache tokens
    tokens_ctx_shift.discard_n_tokens(prompt.n_kept_prompt, prompt.n_discarded_prompt);

    for (auto it = states.begin(); it != states.end();) {
        auto prefix = it->tokens.get_common_prefix(ctx, tokens_ctx_shift);

        const size_t len        = prefix.first;  // common prefix length within the cached entry
        const size_t len_prompt = prefix.second; // common prefix length within the new prompt

        // first check if the current state is contained fully in the cache
        if (len_prompt == tokens_ctx_shift.size()) {
            LLAMA_LOG_INFO("%s", " - prompt is already in the cache, skipping\n");
            return nullptr;
        }
        // next, remove any cached prompts that are fully contained in the current prompt
        else if (len == it->tokens.size()) {
            LLAMA_LOG_INFO(" - removing obsolete cached prompt with length %d\n", (int)len);
            it = states.erase(it);
        }
        else {
            ++it;
        }
    }

    std::vector<uint8_t> state_data;

    // check if we can allocate enough memory for the new state
    try {
        state_data.resize(state_size);
    }
    catch (const std::bad_alloc& e) {
        LLAMA_LOG_INFO("failed to allocate memory for prompt cache state: %s\n", e.what());

        // shrink the limit so that future allocations are more likely to succeed
        limit_size = std::max<size_t>(1, 0.4 * size());

        LLAMA_LOG_INFO(" - cache size limit reduced to %.3f MiB\n", limit_size / (1024.0 * 1024.0));

        update();

        return nullptr;
    }

    // TODO: for some reason we can't copy server_tokens, so we have to do this workaround
    auto& cur = states.emplace_back();
    cur = {
        /*.tokens             =*/ server_tokens(prompt.tokens.get_text_tokens(), false),
        /*.n_keep             =*/ prompt.n_kept_prompt,
        /*.n_discarded_prompt =*/ prompt.n_discarded_prompt,
        /*.data               =*/ std::move(state_data),
        /*.checkpoints        =*/ prompt.checkpoints,
    };

    return &cur;
}
|
||||
|
||||
|
||||
void server_prompt_cache::update() {
|
||||
if (limit_size > 0) {
|
||||
// always keep at least one state, regardless of the limits
|
||||
while (states.size() > 1 && size() > limit_size) {
|
||||
if (states.empty()) {
|
||||
break;
|
||||
}
|
||||
|
||||
LLAMA_LOG_INFO(" - cache size limit reached, removing oldest entry (size = %.3f MiB)\n", states.front().size() / (1024.0 * 1024.0));
|
||||
|
||||
states.pop_front();
|
||||
}
|
||||
}
|
||||
|
||||
// average size per token
|
||||
const float size_per_token = std::max<float>(1.0f, float(size()) / (std::max<size_t>(1, n_tokens())));
|
||||
|
||||
// dynamically increase the token limit if it can fit in the memory limit
|
||||
const size_t limit_tokens_cur = limit_size > 0 ? std::max<size_t>(limit_tokens, limit_size / size_per_token) : limit_tokens;
|
||||
|
||||
LLAMA_LOG_INFO(" - cache state: %zu prompts, %.3f MiB (limits: %.3f MiB, %zu tokens, %zu est)\n",
|
||||
states.size(), size() / (1024.0 * 1024.0), limit_size / (1024.0 * 1024.0), limit_tokens, limit_tokens_cur);
|
||||
|
||||
for (const auto& state : states) {
|
||||
LLAMA_LOG_INFO(" - prompt %p: %7d tokens, %7d discarded, checkpoints: %2zu, %9.3f MiB\n",
|
||||
(const void*)&state, state.n_tokens(), state.n_discarded_prompt, state.checkpoints.size(), state.size() / (1024.0 * 1024.0));
|
||||
}
|
||||
}
|
||||
216
examples/server/server-task.h
Normal file
216
examples/server/server-task.h
Normal file
@@ -0,0 +1,216 @@
|
||||
#pragma once
|
||||
#include "common.h"
|
||||
#include "llama.h"
|
||||
|
||||
#include <string>
|
||||
#include <unordered_set>
|
||||
#include <list>
|
||||
// TODO: prevent including the whole server-common.h as we only use server_tokens
|
||||
#include "server-common.h"
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
// Reason why token generation stopped for a completion.
enum stop_type {
    STOP_TYPE_NONE,  // no stop condition hit (generation still in progress)
    STOP_TYPE_EOS,   // model emitted an end-of-sequence token
    STOP_TYPE_WORD,  // a configured stop word was matched
    STOP_TYPE_LIMIT, // a generation limit was reached
};
|
||||
|
||||
|
||||
|
||||
// Kind of work item submitted to the server queue.
enum server_task_type {
    SERVER_TASK_TYPE_COMPLETION,    // text/chat completion inference
    SERVER_TASK_TYPE_EMBEDDING,     // compute embeddings
    SERVER_TASK_TYPE_RERANK,        // rerank documents against a query
    SERVER_TASK_TYPE_INFILL,        // fill-in-the-middle completion
    SERVER_TASK_TYPE_CANCEL,        // cancel another task (identified via id_target)
    SERVER_TASK_TYPE_NEXT_RESPONSE, // internal queue signaling task
    SERVER_TASK_TYPE_METRICS,       // collect server metrics
    SERVER_TASK_TYPE_SLOT_SAVE,     // save a slot's state
    SERVER_TASK_TYPE_SLOT_RESTORE,  // restore a slot's state
    SERVER_TASK_TYPE_SLOT_ERASE,    // erase a slot's state
    SERVER_TASK_TYPE_SET_LORA,      // apply/change LoRA adapter configuration
};
|
||||
|
||||
// Which API response format a result should be rendered in.
enum oaicompat_type {
    OAICOMPAT_TYPE_NONE,       // native (non-OpenAI) response format
    OAICOMPAT_TYPE_CHAT,       // OpenAI chat completions format
    OAICOMPAT_TYPE_COMPLETION, // OpenAI text completions format
    OAICOMPAT_TYPE_EMBEDDING,  // OpenAI embeddings format
    OAICOMPAT_TYPE_ANTHROPIC,  // Anthropic Messages API format
};
|
||||
|
||||
|
||||
// A single unit of work flowing through the server queue.
struct server_task {
    int id = -1;        // to be filled by server_queue
    int id_multi = -1;  // id of an associated multi-task (-1 = none)
    int id_target = -1; // id of the task this one targets (e.g. the task to cancel)
    //int id_slot = -1;

    // used by SERVER_TASK_TYPE_INFERENCE
    server_tokens tokens;

    // NOTE(review): `type` is left uninitialized by the default constructor —
    // callers constructing via server_task() must assign it before use
    server_task_type type;
    json data; // request payload / task parameters

    bool infill = false;    // true for fill-in-the-middle requests
    bool embedding = false; // true for embedding requests

    server_task() = default;
    server_task(server_task_type type) : type(type) {}

};
|
||||
|
||||
struct result_timings {
|
||||
int32_t prompt_n = -1;
|
||||
double prompt_ms;
|
||||
double prompt_per_token_ms;
|
||||
double prompt_per_second;
|
||||
|
||||
int32_t predicted_n = -1;
|
||||
double predicted_ms;
|
||||
double predicted_per_token_ms;
|
||||
double predicted_per_second;
|
||||
int32_t n_ctx = 0;
|
||||
int32_t n_past = 0;
|
||||
|
||||
// Optional speculative metrics - only included when > 0
|
||||
int32_t draft_n = 0;
|
||||
int32_t draft_n_accepted = 0;
|
||||
|
||||
json to_json() const;
|
||||
};
|
||||
|
||||
struct server_task_result {
|
||||
int id = -1;
|
||||
int id_multi = -1;
|
||||
|
||||
json data;
|
||||
|
||||
bool stop;
|
||||
bool error;
|
||||
bool final_result = false;
|
||||
result_timings timings;
|
||||
// OAI-compat fields
|
||||
//bool verbose = false;
|
||||
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
|
||||
std::string oaicompat_model;
|
||||
std::string oaicompat_cmpl_id;
|
||||
common_chat_format oaicompat_chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
||||
common_chat_msg oaicompat_msg;
|
||||
std::vector<common_chat_msg_diff> oaicompat_msg_diffs;
|
||||
|
||||
int index = 0;
|
||||
|
||||
std::string content;
|
||||
std::vector<llama_token> tokens;
|
||||
|
||||
bool stream;
|
||||
bool include_usage;
|
||||
std::string prompt;
|
||||
//slot_params generation_params;
|
||||
|
||||
bool truncated;
|
||||
int32_t n_decoded;
|
||||
int32_t n_prompt_tokens;
|
||||
int32_t n_tokens_cached;
|
||||
bool has_new_line;
|
||||
std::string stopping_word;
|
||||
|
||||
bool post_sampling_probs = false;
|
||||
std::vector<completion_token_output> probs_output;
|
||||
std::vector<std::string> response_fields;
|
||||
|
||||
//slot_params generation_params;
|
||||
|
||||
bool verbose = false;
|
||||
|
||||
|
||||
int get_index() {
|
||||
return index;
|
||||
}
|
||||
|
||||
bool is_stop() {
|
||||
return true; // in stream mode, final responses are considered stop
|
||||
}
|
||||
|
||||
json to_json_final();
|
||||
|
||||
json to_json_partial();
|
||||
|
||||
json to_json_non_oaicompat_partial();
|
||||
|
||||
json to_json_non_oaicompat_final();
|
||||
|
||||
json to_json_oaicompat_partial();
|
||||
|
||||
json to_json_oaicompat_final();
|
||||
|
||||
json to_json_oaicompat_chat_partial();
|
||||
|
||||
json to_json_oaicompat_chat_final();
|
||||
|
||||
json to_json_oaicompat_chat_stream();
|
||||
|
||||
json to_json_anthropic_final();
|
||||
|
||||
json to_json_anthropic_stream();
|
||||
|
||||
json to_json_anthropic_partial();
|
||||
};
|
||||
|
||||
|
||||
// A saved snapshot of state covering a range of token positions,
// used by server_prompt to roll back without reprocessing.
struct server_prompt_checkpoint {
    llama_pos pos_min; // first token position covered by this checkpoint
    llama_pos pos_max; // last token position covered by this checkpoint

    // serialized state blob for this position range
    std::vector<uint8_t> data;

    // size in bytes of the stored state
    size_t size() const {
        return data.size();
    }
};
|
||||
|
||||
|
||||
struct server_prompt {
|
||||
server_tokens tokens;
|
||||
int n_kept_prompt;
|
||||
int n_discarded_prompt;
|
||||
|
||||
std::vector<uint8_t> data;
|
||||
|
||||
std::list<server_prompt_checkpoint> checkpoints;
|
||||
|
||||
size_t size() const;
|
||||
|
||||
int n_tokens() const {
|
||||
return tokens.size();
|
||||
}
|
||||
};
|
||||
|
||||
struct server_prompt_cache {
|
||||
server_prompt_cache(llama_context* ctx, int32_t limit_size_mib, size_t limit_tokens) {
|
||||
this->ctx = ctx;
|
||||
this->limit_size = 1024ull * 1024ull * (limit_size_mib < 0 ? 0 : limit_size_mib);
|
||||
this->limit_tokens = limit_tokens;
|
||||
}
|
||||
|
||||
std::list<server_prompt> states;
|
||||
|
||||
// in bytes, 0 = no limit
|
||||
size_t limit_size = 0;
|
||||
|
||||
// in tokens, 0 = no limit
|
||||
size_t limit_tokens = 0;
|
||||
llama_context* ctx;
|
||||
size_t size() const;
|
||||
|
||||
size_t n_tokens() const;
|
||||
|
||||
server_prompt* alloc(const server_prompt& prompt, size_t state_size);
|
||||
|
||||
bool load(server_prompt& prompt, const server_tokens& tokens_new, llama_context* ctx, int32_t id_slot);
|
||||
|
||||
void update();
|
||||
};
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user