server : add Anthropic Messages API support (#1012)
@@ -96,6 +96,7 @@ enum oaicompat_type {
    OAICOMPAT_TYPE_CHAT,
    OAICOMPAT_TYPE_COMPLETION,
    OAICOMPAT_TYPE_EMBEDDING,
    OAICOMPAT_TYPE_ANTHROPIC,
};

struct result_timings {
@@ -221,6 +222,8 @@ struct server_task_result {
            return to_json_oaicompat_final();
        case OAICOMPAT_TYPE_CHAT:
            return stream ? to_json_oaicompat_chat_stream() : to_json_oaicompat_chat_final();
        case OAICOMPAT_TYPE_ANTHROPIC:
            return stream ? to_json_anthropic_stream() : to_json_anthropic_final();
        default:
            GGML_ASSERT(false && "Invalid oaicompat_type");
    }
@@ -233,7 +236,9 @@ struct server_task_result {
        case OAICOMPAT_TYPE_COMPLETION:
            return to_json_oaicompat_partial();
        case OAICOMPAT_TYPE_CHAT:
            return to_json_oaicompat_chat_partial();
        case OAICOMPAT_TYPE_ANTHROPIC:
            return to_json_anthropic_partial();
        default:
            GGML_ASSERT(false && "Invalid oaicompat_type");
    }
@@ -539,6 +544,307 @@ struct server_task_result {

        return deltas;
    }

    json to_json_anthropic_final() {
        std::string stop_reason = "max_tokens";
        if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
            stop_reason = oaicompat_msg.tool_calls.empty() ? "end_turn" : "tool_use";
        }

        json content_blocks = json::array();

        common_chat_msg msg;
        if (!oaicompat_msg.empty()) {
            msg = oaicompat_msg;
        } else {
            msg.role = "assistant";
            msg.content = content;
        }


        if (!msg.content.empty()) {
            content_blocks.push_back({
                {"type", "text"},
                {"text", msg.content}
            });
        }

        for (const auto & tool_call : msg.tool_calls) {
            json tool_use_block = {
                {"type", "tool_use"},
                {"id", tool_call.id},
                {"name", tool_call.name}
            };

            try {
                tool_use_block["input"] = json::parse(tool_call.arguments);
            } catch (const std::exception &) {
                tool_use_block["input"] = json::object();
            }

            content_blocks.push_back(tool_use_block);
        }

        json res = {
            {"id", oaicompat_cmpl_id},
            {"type", "message"},
            {"role", "assistant"},
            {"content", content_blocks},
            {"model", oaicompat_model},
            {"stop_reason", stop_reason},
            {"stop_sequence", stopping_word.empty() ? nullptr : json(stopping_word)},
            {"usage", {
                {"input_tokens", n_prompt_tokens},
                {"output_tokens", n_decoded}
            }}
        };

        return res;
    }

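For orientation (not part of the patch itself), a response assembled by to_json_anthropic_final() for a turn that ends in a single tool call would look roughly like this; the id, model name, texts and token counts are illustrative:

    {
      "id": "cmpl-abc123",
      "type": "message",
      "role": "assistant",
      "content": [
        {"type": "text", "text": "Let me check the weather."},
        {"type": "tool_use", "id": "call_1", "name": "get_weather", "input": {"city": "Berlin"}}
      ],
      "model": "my-model",
      "stop_reason": "tool_use",
      "stop_sequence": null,
      "usage": {"input_tokens": 42, "output_tokens": 17}
    }
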
    json to_json_anthropic_stream() {
        json events = json::array();

        std::string stop_reason = "max_tokens";
        if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
            stop_reason = oaicompat_msg.tool_calls.empty() ? "end_turn" : "tool_use";
        }

        bool has_text = !oaicompat_msg.content.empty();
        size_t num_tool_calls = oaicompat_msg.tool_calls.size();

        bool text_block_started = false;
        std::set<size_t> tool_calls_started;

        for (const auto & diff : oaicompat_msg_diffs) {

            if (!diff.content_delta.empty()) {
                if (!text_block_started) {
                    events.push_back({
                        {"event", "content_block_start"},
                        {"data", {
                            {"type", "content_block_start"},
                            {"index", 0},
                            {"content_block", {
                                {"type", "text"},
                                {"text", ""}
                            }}
                        }}
                    });
                    text_block_started = true;
                }

                events.push_back({
                    {"event", "content_block_delta"},
                    {"data", {
                        {"type", "content_block_delta"},
                        {"index", 0},
                        {"delta", {
                            {"type", "text_delta"},
                            {"text", diff.content_delta}
                        }}
                    }}
                });
            }

            if (diff.tool_call_index != std::string::npos) {
                size_t content_block_index = (has_text ? 1 : 0) + diff.tool_call_index;

                if (tool_calls_started.find(diff.tool_call_index) == tool_calls_started.end()) {
                    const auto & full_tool_call = oaicompat_msg.tool_calls[diff.tool_call_index];

                    events.push_back({
                        {"event", "content_block_start"},
                        {"data", {
                            {"type", "content_block_start"},
                            {"index", content_block_index},
                            {"content_block", {
                                {"type", "tool_use"},
                                {"id", full_tool_call.id},
                                {"name", full_tool_call.name}
                            }}
                        }}
                    });
                    tool_calls_started.insert(diff.tool_call_index);
                }

                if (!diff.tool_call_delta.arguments.empty()) {
                    events.push_back({
                        {"event", "content_block_delta"},
                        {"data", {
                            {"type", "content_block_delta"},
                            {"index", content_block_index},
                            {"delta", {
                                {"type", "input_json_delta"},
                                {"partial_json", diff.tool_call_delta.arguments}
                            }}
                        }}
                    });
                }
            }
        }

        if (has_text) {
            events.push_back({
                {"event", "content_block_stop"},
                {"data", {
                    {"type", "content_block_stop"},
                    {"index", 0}
                }}
            });
        }

        for (size_t i = 0; i < num_tool_calls; i++) {
            size_t content_block_index = (has_text ? 1 : 0) + i;
            events.push_back({
                {"event", "content_block_stop"},
                {"data", {
                    {"type", "content_block_stop"},
                    {"index", content_block_index}
                }}
            });
        }

        events.push_back({
            {"event", "message_delta"},
            {"data", {
                {"type", "message_delta"},
                {"delta", {
                    {"stop_reason", stop_reason},
                    {"stop_sequence", stopping_word.empty() ? nullptr : json(stopping_word)}
                }},
                {"usage", {
                    {"output_tokens", n_decoded}
                }}
            }}
        });

        events.push_back({
            {"event", "message_stop"},
            {"data", {
                {"type", "message_stop"}
            }}
        });

        // extra fields for debugging purposes
        if (verbose && !events.empty()) {
            events.front()["data"]["__verbose"] = to_json_non_oaicompat_final();
        }
        // Don't add timings for Anthropic API (breaks spec compliance)
        if (oaicompat != OAICOMPAT_TYPE_ANTHROPIC && timings.prompt_n >= 0 && !events.empty()) {
            events.back()["data"]["timings"] = timings.to_json();
        }

        return events;
    }

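When generation finishes, the events built above are serialized by server_sent_anthropic_event() (added further below in the utils). For a plain text completion that stopped on EOS, the tail of the stream would look roughly like this (the output token count is illustrative):

    event: content_block_stop
    data: {"type":"content_block_stop","index":0}

    event: message_delta
    data: {"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"output_tokens":57}}

    event: message_stop
    data: {"type":"message_stop"}
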
    json to_json_anthropic_partial() {
        json events = json::array();
        bool first = n_decoded == 1;
        static bool text_block_started = false;

        if (first) {
            text_block_started = false;

            events.push_back({
                {"event", "message_start"},
                {"data", {
                    {"type", "message_start"},
                    {"message", {
                        {"id", oaicompat_cmpl_id},
                        {"type", "message"},
                        {"role", "assistant"},
                        {"content", json::array()},
                        {"model", oaicompat_model},
                        {"stop_reason", nullptr},
                        {"stop_sequence", nullptr},
                        {"usage", {
                            {"input_tokens", n_prompt_tokens},
                            {"output_tokens", 0}
                        }}
                    }}
                }}
            });
        }

        for (const auto & diff : oaicompat_msg_diffs) {
            if (!diff.content_delta.empty()) {
                if (!text_block_started) {
                    events.push_back({
                        {"event", "content_block_start"},
                        {"data", {
                            {"type", "content_block_start"},
                            {"index", 0},
                            {"content_block", {
                                {"type", "text"},
                                {"text", ""}
                            }}
                        }}
                    });
                    text_block_started = true;
                }

                events.push_back({
                    {"event", "content_block_delta"},
                    {"data", {
                        {"type", "content_block_delta"},
                        {"index", 0},
                        {"delta", {
                            {"type", "text_delta"},
                            {"text", diff.content_delta}
                        }}
                    }}
                });
            }

            if (diff.tool_call_index != std::string::npos) {
                size_t content_block_index = (text_block_started ? 1 : 0) + diff.tool_call_index;

                if (!diff.tool_call_delta.name.empty()) {
                    events.push_back({
                        {"event", "content_block_start"},
                        {"data", {
                            {"type", "content_block_start"},
                            {"index", content_block_index},
                            {"content_block", {
                                {"type", "tool_use"},
                                {"id", diff.tool_call_delta.id},
                                {"name", diff.tool_call_delta.name}
                            }}
                        }}
                    });
                }

                if (!diff.tool_call_delta.arguments.empty()) {
                    events.push_back({
                        {"event", "content_block_delta"},
                        {"data", {
                            {"type", "content_block_delta"},
                            {"index", content_block_index},
                            {"delta", {
                                {"type", "input_json_delta"},
                                {"partial_json", diff.tool_call_delta.arguments}
                            }}
                        }}
                    });
                }
            }
        }

        if (verbose && !events.empty() && first) {
            events.front()["data"]["__verbose"] = to_json_non_oaicompat_partial();
        }

        if (timings.prompt_n >= 0 && !events.empty()) {
            events.back()["data"]["timings"] = timings.to_json();
        }

        //if (is_progress && !events.empty()) {
        //    events.back()["data"]["prompt_progress"] = progress.to_json();
        //}

        return events;
    }
};

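Taken together, the two streaming methods reproduce the Anthropic event order: message_start, then content_block_start / content_block_delta / content_block_stop per block, then message_delta and message_stop. The first chunk of a streamed text reply would therefore serialize roughly as follows (id, model name and token counts are illustrative):

    event: message_start
    data: {"type":"message_start","message":{"id":"cmpl-abc123","type":"message","role":"assistant","content":[],"model":"my-model","stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":42,"output_tokens":0}}}

    event: content_block_start
    data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}

    event: content_block_delta
    data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello"}}
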
static inline std::string stop_type_to_str(stop_type type) {
@@ -4380,20 +4686,12 @@ int main(int argc, char ** argv) {
    //

    auto middleware_validate_api_key = [&params, &res_error](const httplib::Request & req, httplib::Response & res) {
        // TODO: should we apply API key to all endpoints, including "/health" and "/models"?
        static const std::set<std::string> protected_endpoints = {
            "/props",
            "/completion",
            "/completions",
            "/v1/completions",
            "/chat/completions",
            "/v1/chat/completions",
            "/infill",
            "/tokenize",
            "/detokenize",
            "/embedding",
            "/embeddings",
            "/v1/embeddings",
        static const std::unordered_set<std::string> public_endpoints = {
            "/health",
            "/v1/health",
            "/models",
            "/v1/models",
            "/api/tags"
        };

        // If API key is not set, skip validation
@@ -4401,8 +4699,8 @@ int main(int argc, char ** argv) {
            return true;
        }

        // If path is not in protected_endpoints list, skip validation
        if (protected_endpoints.find(req.path) == protected_endpoints.end()) {
        // If path is public or is static file, skip validation
        if (public_endpoints.find(req.path) != public_endpoints.end() || req.path == "/") {
            return true;
        }

@@ -4417,11 +4715,25 @@ int main(int argc, char ** argv) {
            }
        }

        auth_header = req.get_header_value("X-Api-Key");

        if (std::find(params.api_keys.begin(), params.api_keys.end(), auth_header) != params.api_keys.end()) {
            return true; // API key is valid
        }

        // API key is invalid or not provided
        res_error(res, format_error_response("Invalid API Key", ERROR_TYPE_AUTHENTICATION));

        LOG_WARNING("Unauthorized: Invalid API Key", {});

        res.status = 401;
        res.set_content(
            (json {
                {"error", {
                    {"message", "Invalid API Key"},
                    {"type", "authentication_error"},
                    {"code", 401}
                }}
            }).dump(-1, ' ', false, json::error_handler_t::replace),
            "application/json; charset=utf-8"
        );
        LOG_WARNING("Unauthorized: Invalid API Key\n", {});
        return false;
    };

@@ -4874,6 +5186,13 @@ int main(int argc, char ** argv) {
    else {
        const auto chunked_content_provider = [id_task, &ctx_server, completion_id, oaicompat, send_done = params.send_done](size_t, httplib::DataSink& sink) {
            bool successful_completion = false;
            const auto sse = [oaicompat, &sink](const json &res) {
                if (oaicompat == OAICOMPAT_TYPE_ANTHROPIC) {
                    return server_sent_anthropic_event(sink, res);
                } else {
                    return server_sent_event(sink, res);
                }
            };
            while (true) {
                server_task_result result = ctx_server.queue_results.recv(id_task);
                if (!result.error) {
@@ -4895,7 +5214,7 @@ int main(int argc, char ** argv) {
                    if (res_json.is_array()) {
                        // chat completions and oai completions
                        for (const auto& res : res_json) {
                            if (!server_sent_event(sink, res)) {
                            if (!sse(res)) {
                                // sending failed (HTTP connection closed), cancel the generation
                                ctx_server.queue_results.remove_waiting_task_id(id_task);
                                return false;
@@ -4908,7 +5227,7 @@ int main(int argc, char ** argv) {
                    }
                    else {
                        // legacy completions
                        if (!server_sent_event(sink, res_json)) {
                        if (!sse(res_json)) {
                            ctx_server.queue_results.remove_waiting_task_id(id_task);
                            return false;
                        }
@@ -4918,7 +5237,7 @@ int main(int argc, char ** argv) {
                    }
                }
                else {
                    if (!server_sent_event(sink, result.data)) {
                    if (!sse(result.data)) {
                        ctx_server.queue_results.remove_waiting_task_id(id_task);
                        return false;
                    }
@@ -4926,7 +5245,7 @@ int main(int argc, char ** argv) {
                }
            }
            bool ok = true;
            if (successful_completion) {
            if (successful_completion && oaicompat != OAICOMPAT_TYPE_ANTHROPIC && oaicompat != OAICOMPAT_TYPE_NONE) {
                static const std::string done_message = "data: [DONE]\n\n";
                LOG_VERBOSE("data stream", { {"to_send", done_message} });
                if (!sink.write(done_message.c_str(), done_message.size())) {
@@ -5003,6 +5322,40 @@ int main(int argc, char ** argv) {
            OAICOMPAT_TYPE_CHAT);
    };

    const auto handle_anthropic_messages = [&ctx_server, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
        std::vector<raw_buffer> files;
        json body = json::parse(req.body);
        json body_parsed = anthropic_params_from_json(
            ctx_server.model,
            body,
            ctx_server.oai_parser_opt,
            files);
        return handle_completions_impl(
            SERVER_TASK_TYPE_COMPLETION,
            body_parsed,
            files,
            res,
            OAICOMPAT_TYPE_ANTHROPIC);
    };

    const auto handle_anthropic_count_tokens = [&ctx_server, &handle_completions_impl, &res_ok](const httplib::Request & req, httplib::Response & res) {
        std::vector<raw_buffer> files;
        json body = json::parse(req.body);

        // Parse the Anthropic request (max_tokens is not required for count_tokens)
        json body_parsed = anthropic_params_from_json(
            ctx_server.model,
            body,
            ctx_server.oai_parser_opt,
            files);

        json prompt = body_parsed.at("prompt");
        llama_tokens tokens = tokenize_mixed(llama_get_vocab(ctx_server.ctx), prompt, true, true);

        res_ok(res, {{"input_tokens", static_cast<int>(tokens.size())}});
        return res;
    };
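As a sketch of the count_tokens behaviour (request body and token count are illustrative): the handler runs the same request conversion and chat-template application as /v1/messages, then only tokenizes the resulting prompt, so the returned count reflects the fully templated prompt:

    POST /v1/messages/count_tokens
    {"model": "my-model", "messages": [{"role": "user", "content": "Hello"}]}

    response: {"input_tokens": 9}
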

    // same with handle_chat_completions, but without inference part
    const auto handle_apply_template = [&ctx_server, &params, &res_ok](const httplib::Request& req, httplib::Response& res) {
        auto body = json::parse(req.body);
@@ -5554,6 +5907,8 @@ int main(int argc, char ** argv) {
    svr->Post("/v1/completions", handle_completions_oai);
    svr->Post("/chat/completions", handle_chat_completions);
    svr->Post("/v1/chat/completions", handle_chat_completions);
    svr->Post("/v1/messages", handle_anthropic_messages);
    svr->Post("/v1/messages/count_tokens", handle_anthropic_count_tokens);
    svr->Post("/infill", handle_infill);
    svr->Post("/embedding", handle_embeddings); // legacy
    svr->Post("/embeddings", handle_embeddings);

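A minimal request against the new route could look like the following (illustrative body; fields not handled explicitly are passed through to the usual llama.cpp request parameters by anthropic_params_from_json(), defined below):

    POST /v1/messages
    {
      "model": "my-model",
      "max_tokens": 256,
      "system": "You are a helpful assistant.",
      "messages": [
        {"role": "user", "content": "Write a haiku about autumn."}
      ],
      "stream": true
    }
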
@@ -597,6 +597,18 @@ static bool server_sent_event(httplib::DataSink& sink, const json& data) {
    return sink.write(str.c_str(), str.size());
}

static bool server_sent_anthropic_event(httplib::DataSink& sink, const json& data) {
    const std::string str =
        (data.contains("event") && data.contains("data"))?
            ("event: " + data.at("event").get<std::string>() + "\n" +
             "data: " + data.at("data").dump(-1, ' ', false, json::error_handler_t::replace) + "\n\n"):
            ("data: " + data.at("data").dump(-1, ' ', false, json::error_handler_t::replace) + "\n\n");

    LOG_VERBOSE("data stream, to_send: %s", str.c_str());

    return sink.write(str.c_str(), str.size());
}

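Unlike the OpenAI-compatible stream, which emits bare data: lines, each Anthropic chunk is written as a named SSE event. A text delta produced by the code above would go over the wire roughly as (payload illustrative):

    event: content_block_delta
    data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hi"}}

and, as the change to the streaming loop above shows, no trailing data: [DONE] line is sent for Anthropic streams.
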
//
// OAI utils
//
@@ -946,6 +958,420 @@ static json oaicompat_chat_params_parse(
    return llama_params;
}

static json anthropic_params_from_json(
    const struct llama_model* model,
    const json & body_in, /* anthropic messages api json semantics */
    const oaicompat_parser_options & opt,
    std::vector<raw_buffer> & out_files)
{
    json body = body_in;
    json llama_params;

    if (body.contains("stop_sequences")) {
        llama_params["stop"] = body.at("stop_sequences");
    } else {
        llama_params["stop"] = json::array();
    }

    // handle max_tokens (required in Anthropic, but we're permissive)
    if (!body.contains("max_tokens")) {
        llama_params["n_predict"] = 4096;
    } else {
        llama_params["n_predict"] = body.at("max_tokens");
    }

    if (body.contains("top_k")) {
        llama_params["top_k"] = body.at("top_k");
    }

    if (body.contains("thinking")) {
        json thinking = json_value(body, "thinking", json::object());
        std::string thinking_type = json_value(thinking, "type", std::string());
        if (thinking_type == "enabled") {
            int budget_tokens = json_value(thinking, "budget_tokens", 10000);
            llama_params["thinking_budget_tokens"] = budget_tokens;
        }
    }

    if (body.contains("metadata")) {
        json metadata = json_value(body, "metadata", json::object());
        std::string user_id = json_value(metadata, "user_id", std::string());
        if (!user_id.empty()) {
            llama_params["__metadata_user_id"] = user_id;
        }
    }

    json oai_messages = json::array();
    auto system_param = json_value(body, "system", json());
    if (!system_param.is_null()) {
        std::string system_content;

        if (system_param.is_string()) {
            system_content = system_param.get<std::string>();
        } else if (system_param.is_array()) {
            for (const auto & block : system_param) {
                if (json_value(block, "type", std::string()) == "text") {
                    system_content += json_value(block, "text", std::string());
                }
            }
        }

        oai_messages.push_back({
            {"role", "system"},
            {"content", system_content}
        });
    }

    if (!body.contains("messages")) {
        throw std::runtime_error("'messages' is required");
    }
    json & messages = body.at("messages");
    if (!messages.is_array()) {
        throw std::runtime_error("Expected 'messages' to be an array");
    }

    for (auto & msg : messages) {
        std::string role = json_value(msg, "role", std::string());
        if (role != "assistant" && !msg.contains("content")) {
            throw std::runtime_error("All non-assistant messages must contain 'content'");
        }
        if (role == "assistant") {
            if (!msg.contains("content")) {
                continue;
            }
        }

        json & content = msg.at("content");

        if (content.is_string()) {
            oai_messages.push_back(msg);
            continue;
        }

        if (!content.is_array()) {
            throw std::runtime_error("Expected 'content' to be a string or an array");
        }

        json tool_calls = json::array();
        json converted_content = json::array();
        json tool_results = json::array();
        bool has_tool_calls = false;

        for (auto & block : content) {
            std::string type = json_value(block, "type", std::string());

            if (type == "text") {
                converted_content.push_back(block);
            } else if (type == "image") {
                json source = json_value(block, "source", json::object());
                std::string source_type = json_value(source, "type", std::string());

                if (source_type == "base64") {
                    std::string media_type = json_value(source, "media_type", std::string("image/jpeg"));
                    std::string data = json_value(source, "data", std::string());

                    converted_content.push_back({
                        {"type", "image_url"},
                        {"image_url", {
                            {"url", "data:" + media_type + ";base64," + data}
                        }}
                    });
                } else if (source_type == "url") {
                    std::string url = json_value(source, "url", std::string());
                    converted_content.push_back({
                        {"type", "image_url"},
                        {"image_url", {
                            {"url", url}
                        }}
                    });
                }
            } else if (type == "tool_use") {
                tool_calls.push_back({
                    {"id", json_value(block, "id", std::string())},
                    {"type", "function"},
                    {"function", {
                        {"name", json_value(block, "name", std::string())},
                        {"arguments", json_value(block, "input", json::object()).dump()}
                    }}
                });
                has_tool_calls = true;
            } else if (type == "tool_result") {
                std::string tool_use_id = json_value(block, "tool_use_id", std::string());

                auto result_content = json_value(block, "content", json());
                std::string result_text;
                if (result_content.is_string()) {
                    result_text = result_content.get<std::string>();
                } else if (result_content.is_array()) {
                    for (const auto & c : result_content) {
                        if (json_value(c, "type", std::string()) == "text") {
                            result_text += json_value(c, "text", std::string());
                        }
                    }
                }

                tool_results.push_back({
                    {"role", "tool"},
                    {"tool_call_id", tool_use_id},
                    {"content", result_text}
                });
            }
        }

        if (!tool_results.empty()) {
            if (!converted_content.empty() || has_tool_calls) {
                json new_msg = {{"role", role}};
                if (!converted_content.empty()) {
                    new_msg["content"] = converted_content;
                } else if (has_tool_calls) {
                    new_msg["content"] = "";
                }
                if (!tool_calls.empty()) {
                    new_msg["tool_calls"] = tool_calls;
                }
                oai_messages.push_back(new_msg);
            }
            for (const auto & tool_msg : tool_results) {
                oai_messages.push_back(tool_msg);
            }
        } else {
            if (!converted_content.empty() || has_tool_calls) {
                json new_msg = {{"role", role}};
                if (!converted_content.empty()) {
                    new_msg["content"] = converted_content;
                } else if (has_tool_calls) {
                    new_msg["content"] = "";
                }
                if (!tool_calls.empty()) {
                    new_msg["tool_calls"] = tool_calls;
                }
                oai_messages.push_back(new_msg);
            }
        }
    }

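    // At this point every Anthropic content block has been folded into OAI-style messages.
    // Illustrative example (not from the original patch): an assistant turn
    //     {"role": "assistant", "content": [{"type": "tool_use", "id": "call_1", "name": "get_weather", "input": {"city": "Berlin"}}]}
    // becomes
    //     {"role": "assistant", "content": "", "tool_calls": [{"id": "call_1", "type": "function",
    //      "function": {"name": "get_weather", "arguments": "{\"city\":\"Berlin\"}"}}]}
    // and a following user "tool_result" block becomes a separate
    //     {"role": "tool", "tool_call_id": "call_1", "content": "..."}
    // message.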
    json oai_tools = json::array();
    if (body.contains("tools")) {
        json & tools = body.at("tools");
        if (tools.is_array()) {
            for (auto & tool : tools) {
                oai_tools.push_back({
                    {"type", "function"},
                    {"function", {
                        {"name", json_value(tool, "name", std::string())},
                        {"description", json_value(tool, "description", std::string())},
                        {"parameters", tool.contains("input_schema") ? tool.at("input_schema") : json::object()}
                    }}
                });
            }
        }
    }

    std::string oai_tool_choice = "auto";
    if (body.contains("tool_choice")) {
        json & tc = body.at("tool_choice");
        if (tc.is_object()) {
            std::string type = json_value(tc, "type", std::string());
            if (type == "auto") {
                oai_tool_choice = "auto";
            } else if (type == "any") {
                oai_tool_choice = "required";
            } else if (type == "tool") {
                oai_tool_choice = "required";
            }
        }
    }

    for (auto & msg : oai_messages) {
        if (!msg.contains("content")) {
            continue;
        }
        json & content = msg.at("content");
        if (content.is_string() || content.is_null()) {
            continue;
        }
        if (!content.is_array()) {
            continue;
        }

        for (auto & p : content) {
            std::string type = json_value(p, "type", std::string());
            if (type == "image_url") {
                if (!opt.allow_image) {
                    throw std::runtime_error("image input is not supported - hint: if this is unexpected, you may need to provide the mmproj");
                }

                json image_url = json_value(p, "image_url", json::object());
                std::string url = json_value(image_url, "url", std::string());
                if (string_starts_with(url, "http")) {
                    // download remote image
                    common_remote_params params;
                    params.headers.push_back("User-Agent: ik_llama.cpp/");
                    params.max_size = 1024 * 1024 * 10; // 10MB
                    params.timeout = 10; // seconds
                    LOG_INFO("downloading image from '%s'\n", url.c_str());
                    auto res = common_remote_get_content(url, params);
                    if (200 <= res.first && res.first < 300) {
                        LOG_INFO("downloaded %ld bytes\n", res.second.size());
                        raw_buffer data;
                        data.insert(data.end(), res.second.begin(), res.second.end());
                        out_files.push_back(data);
                    } else {
                        throw std::runtime_error("Failed to download image");
                    }
                } else {
                    // try to decode base64 image
                    std::vector<std::string> parts = string_split<std::string>(url, /*separator*/ ',');
                    if (parts.size() != 2) {
                        throw std::runtime_error("Invalid image_url.url value");
                    } else if (!string_starts_with(parts[0], "data:image/")) {
                        throw std::runtime_error("Invalid image_url.url format: " + parts[0]);
                    } else if (!string_ends_with(parts[0], "base64")) {
                        throw std::runtime_error("image_url.url must be base64 encoded");
                    } else {
                        auto base64_data = parts[1];
                        auto decoded_data = base64_decode(base64_data);
                        out_files.push_back(decoded_data);
                    }
                }

                // replace this chunk with a marker
                p["type"] = "text";
                p["text"] = mtmd_default_marker();
                p.erase("image_url");
            } else if (type == "input_audio") {
                if (!opt.allow_audio) {
                    throw std::runtime_error("audio input is not supported - hint: if this is unexpected, you may need to provide the mmproj");
                }

                json input_audio = json_value(p, "input_audio", json::object());
                std::string data = json_value(input_audio, "data", std::string());
                std::string format = json_value(input_audio, "format", std::string());
                if (format != "wav" && format != "mp3") {
                    throw std::runtime_error("input_audio.format must be either 'wav' or 'mp3'");
                }
                auto decoded_data = base64_decode(data);
                out_files.push_back(decoded_data);

                // replace this chunk with a marker
                p["type"] = "text";
                p["text"] = mtmd_default_marker();
                p.erase("input_audio");
            }
        }
    }

    common_chat_templates_inputs inputs;
    inputs.messages = common_chat_msgs_parse_oaicompat(oai_messages);
    inputs.tools = common_chat_tools_parse_oaicompat(oai_tools);
    inputs.tool_choice = common_chat_tool_choice_parse_oaicompat(oai_tool_choice);
    inputs.json_schema = "";
    inputs.grammar = "";
    inputs.use_jinja = opt.use_jinja;
    inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", false);
    inputs.add_generation_prompt = json_value(body, "add_generation_prompt", true);
    inputs.reasoning_format = opt.reasoning_format;
    inputs.enable_thinking = opt.enable_thinking;

    if (opt.enable_thinking && opt.prefill_assistant) {
        if (!inputs.messages.empty() && inputs.messages.back().role == "assistant") {
            inputs.enable_thinking = false;
        }
    }

    if (!inputs.tools.empty() && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE) {
        llama_params["parse_tool_calls"] = true;
    }

    // merge the template args provided from command line with the args provided in the user request
    auto chat_template_kwargs_object = json_value(body, "chat_template_kwargs", json::object());
    inputs.chat_template_kwargs = opt.chat_template_kwargs;
    for (const auto & item : chat_template_kwargs_object.items()) {
        inputs.chat_template_kwargs[item.key()] = item.value().dump();
    }

    // parse the "enable_thinking" kwarg to override the default value
    auto enable_thinking_kwarg = json_value(inputs.chat_template_kwargs, "enable_thinking", std::string(""));
    if (enable_thinking_kwarg == "true") {
        inputs.enable_thinking = true;
    } else if (enable_thinking_kwarg == "false") {
        inputs.enable_thinking = false;
    } else if (!enable_thinking_kwarg.empty() && enable_thinking_kwarg[0] == '"') {
        throw std::runtime_error("invalid type for \"enable_thinking\" (expected boolean, got string)");
    }

    // if the assistant message appears at the end of list, we do not add end-of-turn token
    bool prefill_assistant_message = !inputs.messages.empty() && inputs.messages.back().role == "assistant" && opt.prefill_assistant;
    common_chat_msg last_message;
    if (prefill_assistant_message) {
        last_message = inputs.messages.back();
        inputs.messages.pop_back();

        // sanity check, max one assistant message at the end of the list
        if (!inputs.messages.empty() && inputs.messages.back().role == "assistant") {
            throw std::runtime_error("Cannot have 2 or more assistant messages at the end of the list.");
        }

        inputs.reasoning_format = COMMON_REASONING_FORMAT_NONE;

        if (inputs.enable_thinking) {
            throw std::runtime_error("Assistant response prefill is incompatible with enable_thinking.");
        }

        inputs.add_generation_prompt = true;
    }

    // Apply chat template to the list of messages
    auto chat_params = common_chat_templates_apply(opt.tmpls, inputs);

    // Append assistant prefilled message
    if (prefill_assistant_message) {
        if (!last_message.content_parts.empty()) {
            for (auto & p : last_message.content_parts) {
                chat_params.prompt += p.text;
            }
        } else {
            chat_params.prompt += last_message.content;
        }
    }

    llama_params["chat_format"] = static_cast<int>(chat_params.format);
    llama_params["prompt"] = chat_params.prompt;
    if (!chat_params.grammar.empty()) {
        llama_params["grammar"] = chat_params.grammar;
    }
    llama_params["grammar_lazy"] = chat_params.grammar_lazy;
    auto grammar_triggers = json::array();
    for (const auto & trigger : chat_params.grammar_triggers) {
        server_grammar_trigger ct(trigger);
        grammar_triggers.push_back(ct.to_json());
    }
    llama_params["grammar_triggers"] = grammar_triggers;
    llama_params["preserved_tokens"] = chat_params.preserved_tokens;
    llama_params["thinking_forced_open"] = chat_params.thinking_forced_open;
    for (const auto & stop : chat_params.additional_stops) {
        llama_params["stop"].push_back(stop);
    }

    // Handle "n" field
    int n_choices = json_value(body, "n", 1);
    if (n_choices != 1) {
        throw std::runtime_error("Only one completion choice is allowed");
    }

    // Copy remaining properties to llama_params
    // This allows user to use llama.cpp-specific params like "mirostat", ... via Anthropic endpoint.
    // See "launch_slot_with_task()" for a complete list of params supported by llama.cpp
    for (const auto & item : body.items()) {
        // Exception: if "n_predict" is present, we overwrite the value specified earlier by "max_tokens"
        if (!llama_params.contains(item.key()) || item.key() == "n_predict") {
            llama_params[item.key()] = item.value();
        }
    }

    return llama_params;
}

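To summarize the conversion above (a rough sketch, not an exhaustive list), the Anthropic request fields are mapped onto the existing llama.cpp request schema as follows:

    stop_sequences              -> stop
    max_tokens                  -> n_predict (defaults to 4096 when omitted)
    system                      -> a leading {"role": "system"} message
    thinking.budget_tokens      -> thinking_budget_tokens (when thinking.type == "enabled")
    tools[].input_schema        -> OAI tools[].function.parameters
    tool_choice "any" / "tool"  -> "required"
    image / tool_use / tool_result content blocks -> image_url parts, tool_calls, and "tool" role messages

All remaining body fields are copied through unchanged, so llama.cpp-specific sampling parameters keep working on the Anthropic endpoint.
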
//
// tokenizer and input processing utils