diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 511a1515..f7de8200 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -3604,7 +3604,7 @@ int main(int argc, char ** argv) {
         res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
         auto body = json::parse(req.body);
         const auto& chat_template = body.contains("tools") && ctx_server.chat_templates.template_tool_use ? *ctx_server.chat_templates.template_tool_use : *ctx_server.chat_templates.template_default;
-        json data = oaicompat_completion_params_parse(ctx_server.model, body, chat_template, ctx_server.params.use_jinja);
+        json data = oaicompat_completion_params_parse(json::parse(req.body));
 
         const int id_task = ctx_server.queue_tasks.get_new_id();
 
@@ -3707,7 +3707,7 @@ int main(int argc, char ** argv) {
         auto body = json::parse(req.body);
         const auto& chat_template = body.contains("tools") && ctx_server.chat_templates.template_tool_use ? *ctx_server.chat_templates.template_tool_use : *ctx_server.chat_templates.template_default;
-        json data = oaicompat_completion_params_parse(ctx_server.model,body, chat_template, params.use_jinja);
+        json data = oaicompat_chat_completion_params_parse(ctx_server.model, body, chat_template, params.use_jinja);
 
         const int id_task = ctx_server.queue_tasks.get_new_id();
 
diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp
index 36d22cc9..5911eeeb 100644
--- a/examples/server/utils.hpp
+++ b/examples/server/utils.hpp
@@ -424,7 +424,47 @@ static tool_choice_type tool_choice_parse_oaicompat(const std::string & tool_cho
 // OAI utils
 //
 
-static json oaicompat_completion_params_parse(
+static json oaicompat_completion_params_parse(const json & body) {
+    json llama_params;
+
+    if (!body.contains("prompt")) {
+        throw std::runtime_error("\"prompt\" is required");
+    }
+
+    // Handle "stop" field
+    if (body.contains("stop") && body.at("stop").is_string()) {
+        llama_params["stop"] = json::array({ body.at("stop").get<std::string>() });
+    }
+    else {
+        llama_params["stop"] = json_value(body, "stop", json::array());
+    }
+
+    // Handle "n" field
+    int n_choices = json_value(body, "n", 1);
+    if (n_choices != 1) {
+        throw std::runtime_error("Only one completion choice is allowed");
+    }
+
+    // Params supported by OAI but unsupported by llama.cpp
+    static const std::vector<std::string> unsupported_params{ "best_of", "echo", "suffix" };
+    for (const auto& param : unsupported_params) {
+        if (body.contains(param)) {
+            throw std::runtime_error("Unsupported param: " + param);
+        }
+    }
+
+    // Copy remaining properties to llama_params
+    for (const auto& item : body.items()) {
+        // Exception: if "n_predict" is present, we overwrite the value specified earlier by "max_tokens"
+        if (!llama_params.contains(item.key()) || item.key() == "n_predict") {
+            llama_params[item.key()] = item.value();
+        }
+    }
+
+    return llama_params;
+}
+
+static json oaicompat_chat_completion_params_parse(
     const struct llama_model * model,
     const json & body, /* openai api json semantics */
     const common_chat_template & tmpl,
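
Standalone sketch (not part of the patch above): a minimal, self-contained program illustrating how the new oaicompat_completion_params_parse() is expected to normalize an OAI-style /v1/completions body. The json_value() helper and the parse_completion_body() function below are simplified local stand-ins that mirror the logic added in utils.hpp, and the nlohmann::json include path is an assumption that may differ from the header vendored by the server.

// sketch.cpp - illustrative only; mirrors the parsing logic added in utils.hpp
#include <iostream>
#include <stdexcept>
#include <string>

#include <nlohmann/json.hpp>

using json = nlohmann::json;

// simplified stand-in for the json_value() helper already present in utils.hpp
template <typename T>
static T json_value(const json & body, const std::string & key, const T & def) {
    return body.contains(key) && !body.at(key).is_null() ? body.at(key).get<T>() : def;
}

// same normalization steps as the patch: require "prompt", normalize "stop",
// reject multi-choice requests and OAI params llama.cpp does not support,
// then copy everything else through unchanged
static json parse_completion_body(const json & body) {
    json llama_params;
    if (!body.contains("prompt")) {
        throw std::runtime_error("\"prompt\" is required");
    }
    if (body.contains("stop") && body.at("stop").is_string()) {
        llama_params["stop"] = json::array({ body.at("stop").get<std::string>() });
    } else {
        llama_params["stop"] = json_value(body, "stop", json::array());
    }
    if (json_value(body, "n", 1) != 1) {
        throw std::runtime_error("Only one completion choice is allowed");
    }
    for (const char * param : { "best_of", "echo", "suffix" }) {
        if (body.contains(param)) {
            throw std::runtime_error(std::string("Unsupported param: ") + param);
        }
    }
    for (const auto & item : body.items()) {
        if (!llama_params.contains(item.key()) || item.key() == "n_predict") {
            llama_params[item.key()] = item.value();
        }
    }
    return llama_params;
}

int main() {
    const json request = {
        {"prompt",     "Once upon a time"},
        {"stop",       "\n"},    // single string -> normalized to ["\n"]
        {"max_tokens", 64},
    };
    // prints the normalized parameters that would be handed to the task queue
    std::cout << parse_completion_body(request).dump(2) << std::endl;
    return 0;
}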