mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-01-26 17:20:01 +00:00
add jinja template support (#677)
Co-authored-by: firecoperana <firecoperana>
This commit is contained in:
@@ -52,6 +52,7 @@ set(TARGET common)
|
||||
|
||||
add_library(${TARGET} STATIC
|
||||
base64.hpp
|
||||
chat-template.hpp
|
||||
common.h
|
||||
common.cpp
|
||||
chat.h
|
||||
@@ -72,6 +73,7 @@ add_library(${TARGET} STATIC
|
||||
json-schema-to-grammar.cpp
|
||||
train.h
|
||||
train.cpp
|
||||
minja.hpp
|
||||
ngram-cache.h
|
||||
ngram-cache.cpp
|
||||
)
|
||||
|
||||
249
common/chat-template.hpp
Normal file
249
common/chat-template.hpp
Normal file
@@ -0,0 +1,249 @@
|
||||
/*
|
||||
Copyright 2024 Google LLC
|
||||
|
||||
Use of this source code is governed by an MIT-style
|
||||
license that can be found in the LICENSE file or at
|
||||
https://opensource.org/licenses/MIT.
|
||||
*/
|
||||
// SPDX-License-Identifier: MIT
|
||||
#pragma once
|
||||
|
||||
#include "minja.hpp"
|
||||
#include <json.hpp>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
namespace minja {
|
||||
|
||||
class chat_template {
|
||||
public:
|
||||
|
||||
private:
|
||||
bool supports_tools_ = true;
|
||||
// Meta-Llama-3.1-8B-Instruct's template expects arguments to be an object.
|
||||
// Most other templates (and OpenAI's API) expect the arguments object to be stringified.
|
||||
bool requires_object_arguments_ = false;
|
||||
bool supports_system_role_ = true;
|
||||
bool supports_parallel_tool_calls_ = false;
|
||||
std::string source_;
|
||||
std::string bos_token_;
|
||||
std::string eos_token_;
|
||||
std::shared_ptr<minja::TemplateNode> template_root_;
|
||||
|
||||
std::string try_render(
|
||||
const nlohmann::ordered_json & messages,
|
||||
const nlohmann::ordered_json & tools,
|
||||
bool add_generation_prompt,
|
||||
const nlohmann::ordered_json & extra_context = nlohmann::ordered_json()) const
|
||||
{
|
||||
try {
|
||||
auto prompt = apply(messages, tools, add_generation_prompt, extra_context);
|
||||
// fprintf(stderr, "Prompt: %s\n", prompt.c_str());
|
||||
return prompt;
|
||||
} catch (const std::exception & e) {
|
||||
// fprintf(stderr, "Error: %s\n", e.what());
|
||||
return "";
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
chat_template(const std::string & source, const std::string & bos_token, const std::string & eos_token)
|
||||
: source_(source), bos_token_(bos_token), eos_token_(eos_token)
|
||||
{
|
||||
template_root_ = minja::Parser::parse(source_, {
|
||||
/* .trim_blocks = */ true,
|
||||
/* .lstrip_blocks = */ true,
|
||||
/* .keep_trailing_newline = */ false,
|
||||
});
|
||||
supports_tools_ = source.find("tools") != std::string::npos;
|
||||
|
||||
auto renders_string_arguments =
|
||||
try_render({
|
||||
{
|
||||
{"role", "user"},
|
||||
{"content", "Hey"}
|
||||
},
|
||||
{
|
||||
{"role", "assistant"},
|
||||
{"tool_calls", json::array({
|
||||
{
|
||||
{"id", "call_1___"},
|
||||
{"type", "function"},
|
||||
{"function", {
|
||||
{"arguments", "{\"code\": \"print('Hello, World!')\"}"},
|
||||
{"name", "ipython"},
|
||||
}},
|
||||
},
|
||||
})},
|
||||
}
|
||||
}, {}, false).find("{\"code\": \"print") != std::string::npos;
|
||||
if (!renders_string_arguments) {
|
||||
auto renders_object_arguments =
|
||||
try_render({
|
||||
{
|
||||
{"role", "user"},
|
||||
{"content", "Hey"}
|
||||
},
|
||||
{
|
||||
{"role", "assistant"},
|
||||
{"tool_calls", json::array({
|
||||
{
|
||||
{"id", "call_1___"},
|
||||
{"type", "function"},
|
||||
{"function", {
|
||||
{"arguments", {
|
||||
{"code", "print('Hello, World!')"},
|
||||
}},
|
||||
{"name", "ipython"},
|
||||
}},
|
||||
},
|
||||
})},
|
||||
}
|
||||
}, {}, false).find("{\"code\": \"print") != std::string::npos;
|
||||
requires_object_arguments_ = renders_object_arguments;
|
||||
}
|
||||
supports_parallel_tool_calls_ = source.find("tool_call_id") != std::string::npos;
|
||||
|
||||
supports_system_role_ = try_render({
|
||||
{{"role", "system"}, {"content", "<System Needle>"}},
|
||||
{{"role", "user"}, {"content", "Hey"}}
|
||||
}, {}, false).find("<System Needle>") != std::string::npos;
|
||||
}
|
||||
|
||||
const std::string & source() const { return source_; }
|
||||
const std::string & bos_token() const { return bos_token_; }
|
||||
const std::string & eos_token() const { return eos_token_; }
|
||||
bool supports_tools() const { return supports_tools_; }
|
||||
bool supports_parallel_tool_calls() const { return supports_parallel_tool_calls_; }
|
||||
|
||||
std::string apply(
|
||||
const nlohmann::ordered_json & messages,
|
||||
const nlohmann::ordered_json & tools,
|
||||
bool add_generation_prompt,
|
||||
const nlohmann::ordered_json & extra_context = nlohmann::ordered_json()) const
|
||||
{
|
||||
json actual_messages;
|
||||
|
||||
// First, "fix" messages so they have a chance to be rendered correctly by the template
|
||||
|
||||
if (requires_object_arguments_ || !supports_system_role_ || !supports_tools_) {
|
||||
actual_messages = json::array();
|
||||
|
||||
std::string pending_system;
|
||||
auto flush_sys = [&]() {
|
||||
if (!pending_system.empty()) {
|
||||
actual_messages.push_back({
|
||||
{"role", "user"},
|
||||
{"content", pending_system},
|
||||
});
|
||||
pending_system.clear();
|
||||
}
|
||||
};
|
||||
for (const auto & message_ : messages) {
|
||||
auto message = message_;
|
||||
if (!message.contains("role") || !message.contains("content")) {
|
||||
throw std::runtime_error("message must have 'role' and 'content' fields: " + message.dump());
|
||||
}
|
||||
std::string role = message.at("role");
|
||||
|
||||
if (message.contains("tool_calls")) {
|
||||
if (requires_object_arguments_ || !supports_tools_) {
|
||||
for (auto & tool_call : message.at("tool_calls")) {
|
||||
if (tool_call["type"] == "function") {
|
||||
auto & function = tool_call.at("function");
|
||||
std::string arguments = function.at("arguments");
|
||||
function["arguments"] = json::parse(arguments);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!supports_tools_) {
|
||||
auto content = message.at("content");
|
||||
auto tool_calls = json::array();
|
||||
for (const auto & tool_call : message.at("tool_calls")) {
|
||||
if (tool_call.at("type") != "function") {
|
||||
continue;
|
||||
}
|
||||
const auto & function = tool_call.at("function");
|
||||
auto tc = json {
|
||||
{"name", function.at("name")},
|
||||
{"arguments", function.at("arguments")},
|
||||
};
|
||||
if (tool_call.contains("id")) {
|
||||
tc["id"] = tool_call["id"];
|
||||
}
|
||||
tool_calls.push_back(tc);
|
||||
}
|
||||
auto obj = json {
|
||||
{"tool_calls", tool_calls},
|
||||
};
|
||||
if (!content.is_null() && content != "") {
|
||||
obj["content"] = content;
|
||||
}
|
||||
message["content"] = obj.dump(2);
|
||||
message.erase("tool_calls");
|
||||
}
|
||||
}
|
||||
if (!supports_tools_ && role == "tool") {
|
||||
message["role"] = "user";
|
||||
auto obj = json {
|
||||
{"tool_response", {
|
||||
{"tool", message.at("name")},
|
||||
{"content", message.at("content")},
|
||||
}},
|
||||
};
|
||||
if (message.contains("tool_call_id")) {
|
||||
obj["tool_response"]["tool_call_id"] = message.at("tool_call_id");
|
||||
}
|
||||
message["content"] = obj.dump(2);
|
||||
message.erase("name");
|
||||
}
|
||||
|
||||
if (!message["content"].is_null() && !supports_system_role_) {
|
||||
std::string content = message.at("content");
|
||||
if (role == "system") {
|
||||
if (!pending_system.empty()) pending_system += "\n";
|
||||
pending_system += content;
|
||||
continue;
|
||||
} else {
|
||||
if (role == "user") {
|
||||
if (!pending_system.empty()) {
|
||||
message["content"] = pending_system + (content.empty() ? "" : "\n" + content);
|
||||
pending_system.clear();
|
||||
}
|
||||
} else {
|
||||
flush_sys();
|
||||
}
|
||||
}
|
||||
}
|
||||
actual_messages.push_back(message);
|
||||
}
|
||||
flush_sys();
|
||||
} else {
|
||||
actual_messages = messages;
|
||||
}
|
||||
|
||||
auto context = minja::Context::make(json({
|
||||
{"messages", actual_messages},
|
||||
{"add_generation_prompt", add_generation_prompt},
|
||||
{"bos_token", bos_token_},
|
||||
{"eos_token", eos_token_},
|
||||
}));
|
||||
|
||||
if (!tools.is_null()) {
|
||||
auto tools_val = minja::Value(tools);
|
||||
context->set("tools", tools_val);
|
||||
}
|
||||
if (!extra_context.is_null()) {
|
||||
for (auto & kv : extra_context.items()) {
|
||||
minja::Value val(kv.value());
|
||||
context->set(kv.key(), val);
|
||||
}
|
||||
}
|
||||
|
||||
return template_root_->render(context);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace minja
|
||||
@@ -15,9 +15,11 @@
|
||||
#include "json.hpp"
|
||||
#include "json-schema-to-grammar.h"
|
||||
#include "llama.h"
|
||||
#include "chat-template.hpp"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cinttypes>
|
||||
#include <climits>
|
||||
#include <cmath>
|
||||
#include <codecvt>
|
||||
#include <cstdarg>
|
||||
@@ -199,6 +201,16 @@ int32_t cpu_get_num_math() {
|
||||
return cpu_get_num_physical_cores();
|
||||
}
|
||||
|
||||
|
||||
static std::string read_file(const std::string& fname) {
|
||||
std::ifstream file(fname);
|
||||
if (!file) {
|
||||
throw std::runtime_error(string_format("error: failed to open file '%s'\n", fname.c_str()));
|
||||
}
|
||||
std::string content((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
|
||||
file.close();
|
||||
return content;
|
||||
}
|
||||
//
|
||||
// CLI argument parsing
|
||||
//
|
||||
@@ -278,6 +290,13 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
|
||||
params.tensor_buft_overrides.push_back({nullptr, nullptr});
|
||||
}
|
||||
|
||||
if (!params.chat_template.empty() && !llama_chat_verify_template(nullptr, params.chat_template, params.use_jinja)) {
|
||||
throw std::runtime_error(string_format(
|
||||
"error: the supplied chat template is not supported: %s%s\n",
|
||||
params.chat_template.c_str(),
|
||||
params.use_jinja ? "" : "\nnote: llama.cpp was started without --jinja, we only support commonly used templates"
|
||||
));
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -1425,7 +1444,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
|
||||
}
|
||||
if (arg == "--chat-template") {
|
||||
CHECK_ARG
|
||||
if (!llama_chat_verify_template(argv[i])) {
|
||||
if (!llama_chat_verify_template(nullptr, argv[i], false)) {
|
||||
fprintf(stderr, "error: the supplied chat template is not supported: %s\n", argv[i]);
|
||||
fprintf(stderr, "note: llama.cpp does not use jinja parser, we only support commonly used templates\n");
|
||||
invalid_param = true;
|
||||
@@ -1434,6 +1453,22 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
|
||||
params.chat_template = argv[i];
|
||||
return true;
|
||||
}
|
||||
if (arg == "--chat-template-file") {
|
||||
CHECK_ARG
|
||||
std::string chat_template = read_file(std::string(argv[i]));
|
||||
if (!llama_chat_verify_template(nullptr, chat_template, false)) {
|
||||
fprintf(stderr, "error: the supplied chat template is not supported: %s\n", argv[i]);
|
||||
fprintf(stderr, "note: llama.cpp does not use jinja parser, we only support commonly used templates\n");
|
||||
invalid_param = true;
|
||||
return true;
|
||||
}
|
||||
params.chat_template = chat_template;
|
||||
return true;
|
||||
}
|
||||
if (arg == "--jinja") {
|
||||
params.use_jinja = true;
|
||||
return true;
|
||||
}
|
||||
if (arg == "--slot-prompt-similarity" || arg == "-sps") {
|
||||
CHECK_ARG
|
||||
params.slot_prompt_similarity = std::stof(argv[i]);
|
||||
@@ -1984,6 +2019,22 @@ std::string gpt_params_get_system_info(const gpt_params & params) {
|
||||
// String utils
|
||||
//
|
||||
|
||||
std::string string_format(const char* fmt, ...) {
|
||||
va_list ap;
|
||||
va_list ap2;
|
||||
va_start(ap, fmt);
|
||||
va_copy(ap2, ap);
|
||||
int size = vsnprintf(NULL, 0, fmt, ap);
|
||||
GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT
|
||||
std::vector<char> buf(size + 1);
|
||||
int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2);
|
||||
GGML_ASSERT(size2 == size);
|
||||
va_end(ap2);
|
||||
va_end(ap);
|
||||
return std::string(buf.data(), size);
|
||||
}
|
||||
|
||||
|
||||
std::vector<std::string> string_split(std::string input, char separator) {
|
||||
std::vector<std::string> parts;
|
||||
size_t separator_pos = input.find(separator);
|
||||
@@ -2985,6 +3036,22 @@ std::string llama_token_to_piece(const struct llama_context * ctx, llama_token t
|
||||
return piece;
|
||||
}
|
||||
|
||||
std::string llama_token_to_piece(const struct llama_model* model, llama_token token, bool special) {
|
||||
std::string piece;
|
||||
piece.resize(piece.capacity()); // using string internal cache, 15 bytes + '\n'
|
||||
const int n_chars = llama_token_to_piece(model, token, &piece[0], piece.size(), 0, special);
|
||||
if (n_chars < 0) {
|
||||
piece.resize(-n_chars);
|
||||
int check = llama_token_to_piece(model, token, &piece[0], piece.size(), 0, special);
|
||||
GGML_ASSERT(check == -n_chars);
|
||||
}
|
||||
else {
|
||||
piece.resize(n_chars);
|
||||
}
|
||||
|
||||
return piece;
|
||||
}
|
||||
|
||||
std::string llama_detokenize(llama_context * ctx, const std::vector<llama_token> & tokens, bool special) {
|
||||
std::string text;
|
||||
text.resize(std::max(text.capacity(), tokens.size()));
|
||||
@@ -3011,50 +3078,60 @@ bool llama_should_add_bos_token(const llama_model * model) {
|
||||
// Chat template utils
|
||||
//
|
||||
|
||||
bool llama_chat_verify_template(const std::string & tmpl) {
|
||||
bool llama_chat_verify_template(const struct llama_model* model, const std::string& tmpl, bool use_jinja) {
|
||||
if (use_jinja) {
|
||||
try {
|
||||
auto chat_template = minja::chat_template(tmpl, "<s>", "</s>");
|
||||
chat_template.apply({ {
|
||||
{"role", "user"},
|
||||
{"content", "test"},
|
||||
} }, json(), true);
|
||||
return true;
|
||||
}
|
||||
catch (const std::exception& e) {
|
||||
fprintf(stdout,"%s: failed to apply template: %s\n", __func__, e.what());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
llama_chat_message chat[] = {{"user", "test"}};
|
||||
int res = llama_chat_apply_template(nullptr, tmpl.c_str(), chat, 1, true, nullptr, 0);
|
||||
const int res = llama_chat_apply_template(model, tmpl.c_str(), chat, 1, true, nullptr, 0);
|
||||
return res >= 0;
|
||||
}
|
||||
|
||||
std::string llama_chat_apply_template(const struct llama_model * model,
|
||||
const std::string & tmpl,
|
||||
const common_chat_template& tmpl,
|
||||
const std::vector<llama_chat_msg> & msgs,
|
||||
bool add_ass) {
|
||||
bool add_ass,
|
||||
bool use_jinja) {
|
||||
if (use_jinja) {
|
||||
auto messages = json::array();
|
||||
for (const auto& msg : msgs) {
|
||||
messages.push_back({ {"role", msg.role}, {"content", msg.content} });
|
||||
}
|
||||
return tmpl.apply(messages, /* tools= */ json(), add_ass);
|
||||
}
|
||||
int alloc_size = 0;
|
||||
bool fallback = false; // indicate if we must fallback to default chatml
|
||||
std::vector<llama_chat_message> chat;
|
||||
for (auto & msg : msgs) {
|
||||
chat.push_back({msg.role.c_str(), msg.content.c_str()});
|
||||
alloc_size += (msg.role.size() + msg.content.size()) * 1.25;
|
||||
}
|
||||
|
||||
const char * ptr_tmpl = tmpl.empty() ? nullptr : tmpl.c_str();
|
||||
std::vector<char> buf(alloc_size);
|
||||
|
||||
// run the first time to get the total output length
|
||||
int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size());
|
||||
|
||||
int32_t res = llama_chat_apply_template(model, tmpl.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size());
|
||||
// error: chat template is not supported
|
||||
if (res < 0) {
|
||||
if (ptr_tmpl != nullptr) {
|
||||
// if the custom "tmpl" is not supported, we throw an error
|
||||
// this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template()
|
||||
throw std::runtime_error("this custom template is not supported");
|
||||
} else {
|
||||
// If the built-in template is not supported, we default to chatml
|
||||
res = llama_chat_apply_template(nullptr, "chatml", chat.data(), chat.size(), add_ass, buf.data(), buf.size());
|
||||
fallback = true;
|
||||
}
|
||||
}
|
||||
|
||||
// if it turns out that our buffer is too small, we resize it
|
||||
if ((size_t) res > buf.size()) {
|
||||
buf.resize(res);
|
||||
res = llama_chat_apply_template(
|
||||
fallback ? nullptr : model,
|
||||
fallback ? "chatml" : ptr_tmpl,
|
||||
chat.data(), chat.size(), add_ass, buf.data(), buf.size());
|
||||
res = llama_chat_apply_template(model, tmpl.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size());
|
||||
}
|
||||
|
||||
std::string formatted_chat(buf.data(), res);
|
||||
@@ -3062,12 +3139,13 @@ std::string llama_chat_apply_template(const struct llama_model * model,
|
||||
}
|
||||
|
||||
std::string llama_chat_format_single(const struct llama_model * model,
|
||||
const std::string & tmpl,
|
||||
const common_chat_template& tmpl,
|
||||
const std::vector<llama_chat_msg> & past_msg,
|
||||
const llama_chat_msg & new_msg,
|
||||
bool add_ass) {
|
||||
bool add_ass,
|
||||
bool use_jinja) {
|
||||
std::ostringstream ss;
|
||||
auto fmt_past_msg = past_msg.empty() ? "" : llama_chat_apply_template(model, tmpl, past_msg, false);
|
||||
auto fmt_past_msg = past_msg.empty() ? "" : llama_chat_apply_template(model, tmpl, past_msg, false, use_jinja);
|
||||
std::vector<llama_chat_msg> chat_new(past_msg);
|
||||
// if the past_msg ends with a newline, we must preserve it in the formatted version
|
||||
if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') {
|
||||
@@ -3075,21 +3153,77 @@ std::string llama_chat_format_single(const struct llama_model * model,
|
||||
};
|
||||
// format chat with new_msg
|
||||
chat_new.push_back(new_msg);
|
||||
auto fmt_new_msg = llama_chat_apply_template(model, tmpl, chat_new, add_ass);
|
||||
auto fmt_new_msg = llama_chat_apply_template(model, tmpl, chat_new, add_ass, use_jinja);
|
||||
// get the diff part
|
||||
ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
std::string llama_chat_format_example(const struct llama_model * model,
|
||||
const std::string & tmpl) {
|
||||
std::string llama_chat_format_example(const struct llama_model * model, const common_chat_template& tmpl, bool use_jinja) {
|
||||
std::vector<llama_chat_msg> msgs = {
|
||||
{"system", "You are a helpful assistant"},
|
||||
{"user", "Hello"},
|
||||
{"assistant", "Hi there"},
|
||||
{"user", "How are you?"},
|
||||
};
|
||||
return llama_chat_apply_template(model, tmpl, msgs, true);
|
||||
return llama_chat_apply_template(model, tmpl, msgs, true, use_jinja);
|
||||
}
|
||||
|
||||
|
||||
common_chat_templates llama_chat_templates_from_model(const struct llama_model* model, const std::string& chat_template_override)
|
||||
{
|
||||
auto vocab = llama_model_get_vocab(model);
|
||||
std::string default_template_src = chat_template_override;
|
||||
std::string template_tool_use_src = chat_template_override;
|
||||
bool has_explicit_template = !chat_template_override.empty();
|
||||
if (chat_template_override.empty()) {
|
||||
auto str = llama_model_chat_template(model, /* name */ nullptr);
|
||||
if (str) {
|
||||
default_template_src = str;
|
||||
has_explicit_template = true;
|
||||
}
|
||||
str = llama_model_chat_template(model, /* name */ "tool_use");
|
||||
if (str) {
|
||||
template_tool_use_src = str;
|
||||
has_explicit_template = true;
|
||||
}
|
||||
}
|
||||
if (default_template_src.empty() || default_template_src == "chatml") {
|
||||
if (!template_tool_use_src.empty()) {
|
||||
default_template_src = template_tool_use_src;
|
||||
}
|
||||
else {
|
||||
default_template_src = R"(
|
||||
{%- for message in messages -%}
|
||||
{{- "<|im_start|>" + message.role + "\n" + message.content + "<|im_end|>\n" -}}
|
||||
{%- endfor -%}
|
||||
{%- if add_generation_prompt -%}
|
||||
{{- "<|im_start|>assistant\n" -}}
|
||||
{%- endif -%}
|
||||
)";
|
||||
}
|
||||
}
|
||||
const auto get_token = [&](llama_token token, const char* name, const char* jinja_variable_name) {
|
||||
if (token == LLAMA_TOKEN_NULL) {
|
||||
if (default_template_src.find(jinja_variable_name) != std::string::npos
|
||||
|| template_tool_use_src.find(jinja_variable_name) != std::string::npos) {
|
||||
fprintf(stdout, "%s: warning: vocab does not have a %s token, jinja template won't work as intended.\n", __func__, name);
|
||||
}
|
||||
return std::string();
|
||||
}
|
||||
else {
|
||||
return llama_token_to_piece(model, token, true);
|
||||
}
|
||||
};
|
||||
auto token_bos = get_token(llama_token_bos(model), "BOS", "bos_token");
|
||||
auto token_eos = get_token(llama_token_eos(model), "EOS", "eos_token");
|
||||
return {
|
||||
has_explicit_template,
|
||||
std::make_unique<minja::chat_template>(default_template_src, token_bos, token_eos),
|
||||
template_tool_use_src.empty()
|
||||
? nullptr
|
||||
: std::make_unique<minja::chat_template>(template_tool_use_src, token_bos, token_eos)
|
||||
};
|
||||
}
|
||||
|
||||
//
|
||||
|
||||
@@ -228,6 +228,7 @@ struct gpt_params {
|
||||
std::string hostname = "127.0.0.1";
|
||||
std::string public_path = "";
|
||||
std::string chat_template = "";
|
||||
bool use_jinja = false; // NOLINT
|
||||
std::string system_prompt = "";
|
||||
bool enable_chat_template = true;
|
||||
|
||||
@@ -400,6 +401,11 @@ std::string llama_token_to_piece(
|
||||
llama_token token,
|
||||
bool special = true);
|
||||
|
||||
std::string llama_token_to_piece(
|
||||
const struct llama_model* model,
|
||||
llama_token token,
|
||||
bool special = true);
|
||||
|
||||
// detokenizes a vector of tokens into a string
|
||||
// should work similar to Python's `tokenizer.decode`
|
||||
// optionally renders special/control tokens
|
||||
@@ -423,26 +429,45 @@ struct llama_chat_msg {
|
||||
};
|
||||
|
||||
// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
|
||||
bool llama_chat_verify_template(const std::string & tmpl);
|
||||
bool llama_chat_verify_template(const struct llama_model* , const std::string& tmpl, bool use_jinja);
|
||||
|
||||
namespace minja {
|
||||
class chat_template;
|
||||
}
|
||||
|
||||
typedef minja::chat_template common_chat_template;
|
||||
|
||||
struct common_chat_templates {
|
||||
bool has_explicit_template; // Model had builtin template or template overridde was specified.
|
||||
std::unique_ptr<common_chat_template> template_default; // always set (defaults to chatml)
|
||||
std::unique_ptr<common_chat_template> template_tool_use;
|
||||
};
|
||||
|
||||
|
||||
// CPP wrapper for llama_chat_apply_template
|
||||
// If the built-in template is not supported, we default to chatml
|
||||
// If the custom "tmpl" is not supported, we throw an error
|
||||
std::string llama_chat_apply_template(const struct llama_model * model,
|
||||
const std::string & tmpl,
|
||||
const std::vector<llama_chat_msg> & chat,
|
||||
bool add_ass);
|
||||
std::string llama_chat_apply_template(
|
||||
const struct llama_model* model,
|
||||
const common_chat_template& tmpl,
|
||||
const std::vector< llama_chat_msg>& chat,
|
||||
bool add_ass,
|
||||
bool use_jinja);
|
||||
|
||||
// Format single message, while taking into account the position of that message in chat history
|
||||
std::string llama_chat_format_single(const struct llama_model * model,
|
||||
const std::string & tmpl,
|
||||
const std::vector<llama_chat_msg> & past_msg,
|
||||
const llama_chat_msg & new_msg,
|
||||
bool add_ass);
|
||||
std::string llama_chat_format_single(const struct llama_model* model,
|
||||
const common_chat_template& tmpl,
|
||||
const std::vector< llama_chat_msg>& past_msg,
|
||||
const llama_chat_msg& new_msg,
|
||||
bool add_ass,
|
||||
bool use_jinja);
|
||||
|
||||
// Returns an example of formatted chat
|
||||
std::string llama_chat_format_example(const struct llama_model * model,
|
||||
const std::string & tmpl);
|
||||
std::string llama_chat_format_example(const struct llama_model* model,
|
||||
const common_chat_template& tmpl, bool use_jinja);
|
||||
|
||||
common_chat_templates llama_chat_templates_from_model(const struct llama_model* model, const std::string& chat_template_override);
|
||||
|
||||
|
||||
//
|
||||
// KV cache utils
|
||||
@@ -502,3 +527,5 @@ void yaml_dump_string_multiline(FILE * stream, const char * prop_name, const cha
|
||||
void yaml_dump_non_result_info(
|
||||
FILE * stream, const gpt_params & params, const llama_context * lctx,
|
||||
const std::string & timestamp, const std::vector<int> & prompt_tokens, const char * model_desc);
|
||||
|
||||
std::string string_format(const char* fmt, ...);
|
||||
|
||||
3029
common/minja.hpp
Normal file
3029
common/minja.hpp
Normal file
File diff suppressed because it is too large
Load Diff
@@ -2,7 +2,7 @@
|
||||
|
||||
#include "console.h"
|
||||
#include "llama.h"
|
||||
|
||||
#include "chat-template.hpp"
|
||||
#include <cassert>
|
||||
#include <cinttypes>
|
||||
#include <cmath>
|
||||
@@ -119,10 +119,10 @@ static void llama_log_callback_logTee(ggml_log_level level, const char * text, v
|
||||
LOG_TEE("%s", text);
|
||||
}
|
||||
|
||||
static std::string chat_add_and_format(struct llama_model * model, std::vector<llama_chat_msg> & chat_msgs, std::string role, std::string content) {
|
||||
static std::string chat_add_and_format(struct llama_model * model, common_chat_templates &chat_templates, std::vector<llama_chat_msg> & chat_msgs, std::string role, std::string content) {
|
||||
llama_chat_msg new_msg{role, content};
|
||||
auto formatted = llama_chat_format_single(
|
||||
model, g_params->chat_template, chat_msgs, new_msg, role == "user");
|
||||
auto formatted = llama_chat_format_single(model,
|
||||
*chat_templates.template_default, chat_msgs, new_msg, role == "user", g_params->use_jinja);
|
||||
chat_msgs.push_back({role, content});
|
||||
LOG("formatted: %s\n", formatted.c_str());
|
||||
return formatted;
|
||||
@@ -220,6 +220,7 @@ int main(int argc, char ** argv) {
|
||||
LOG_TEE("%s: error: unable to load model\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
auto chat_templates = llama_chat_templates_from_model(model, params.chat_template);
|
||||
|
||||
const int n_ctx_train = llama_n_ctx_train(model);
|
||||
const int n_ctx = llama_n_ctx(ctx);
|
||||
@@ -229,11 +230,10 @@ int main(int argc, char ** argv) {
|
||||
LOG_TEE("%s: warning: model was trained on only %d context tokens (%d specified)\n",
|
||||
__func__, n_ctx_train, n_ctx);
|
||||
}
|
||||
|
||||
// print chat template example in conversation mode
|
||||
if (params.conversation) {
|
||||
if (params.enable_chat_template) {
|
||||
LOG_TEE("%s: chat template example: %s\n", __func__, llama_chat_format_example(model, params.chat_template).c_str());
|
||||
LOG_TEE("%s: chat template example: %s\n", __func__, llama_chat_format_example(model, *chat_templates.template_default, params.use_jinja).c_str());
|
||||
} else {
|
||||
LOG_TEE("%s: in-suffix/prefix is specified, chat template will be disabled\n", __func__);
|
||||
}
|
||||
@@ -274,11 +274,29 @@ int main(int argc, char ** argv) {
|
||||
LOG("add_bos: %d\n", add_bos);
|
||||
|
||||
std::vector<llama_token> embd_inp;
|
||||
bool waiting_for_first_input = params.conversation && params.enable_chat_template && params.system_prompt.empty();
|
||||
|
||||
{
|
||||
auto prompt = (params.conversation && params.enable_chat_template && !params.prompt.empty())
|
||||
? chat_add_and_format(model, chat_msgs, "system", params.prompt) // format the system prompt in conversation mode
|
||||
: params.prompt;
|
||||
//auto prompt = (params.conversation && params.enable_chat_template && !params.prompt.empty())
|
||||
// ? chat_add_and_format(model, chat_templates,chat_msgs, "system", params.prompt) // format the system prompt in conversation mode
|
||||
// : params.prompt;
|
||||
std::string prompt;
|
||||
|
||||
if (params.conversation && params.enable_chat_template) {
|
||||
// format the system prompt in conversation mode (will use template default if empty)
|
||||
prompt = params.system_prompt;
|
||||
|
||||
if (!prompt.empty()) {
|
||||
prompt = chat_add_and_format(model, chat_templates,chat_msgs, "system", prompt);
|
||||
}
|
||||
}
|
||||
else {
|
||||
// otherwise use the prompt as is
|
||||
prompt = params.prompt;
|
||||
}
|
||||
|
||||
|
||||
|
||||
if (params.interactive_first || !params.prompt.empty() || session_tokens.empty()) {
|
||||
LOG("tokenize the prompt\n");
|
||||
embd_inp = ::llama_tokenize(ctx, prompt, true, true);
|
||||
@@ -292,7 +310,7 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
|
||||
// Should not run without any tokens
|
||||
if (embd_inp.empty()) {
|
||||
if (!params.conversation && embd_inp.empty()) {
|
||||
if (add_bos) {
|
||||
embd_inp.push_back(llama_token_bos(model));
|
||||
LOG("embd_inp was considered empty and bos was added: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
|
||||
@@ -837,7 +855,7 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
|
||||
// deal with end of generation tokens in interactive mode
|
||||
if (llama_token_is_eog(model, llama_sampling_last(ctx_sampling))) {
|
||||
if (!waiting_for_first_input && llama_token_is_eog(model, llama_sampling_last(ctx_sampling))) {
|
||||
LOG("found an EOG token\n");
|
||||
|
||||
if (params.interactive) {
|
||||
@@ -849,7 +867,7 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
|
||||
if (params.enable_chat_template) {
|
||||
chat_add_and_format(model, chat_msgs, "assistant", assistant_ss.str());
|
||||
chat_add_and_format(model, chat_templates, chat_msgs, "assistant", assistant_ss.str());
|
||||
}
|
||||
is_interacting = true;
|
||||
printf("\n");
|
||||
@@ -857,12 +875,12 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
|
||||
// if current token is not EOG, we add it to current assistant message
|
||||
if (params.conversation) {
|
||||
if (params.conversation && !waiting_for_first_input) {
|
||||
auto id = llama_sampling_last(ctx_sampling);
|
||||
assistant_ss << llama_token_to_piece(ctx, id, false);
|
||||
}
|
||||
|
||||
if (n_past > 0 && is_interacting) {
|
||||
if ((n_past > 0 || waiting_for_first_input) && is_interacting) {
|
||||
LOG("waiting for user input\n");
|
||||
|
||||
if (params.conversation) {
|
||||
@@ -914,7 +932,7 @@ int main(int argc, char ** argv) {
|
||||
|
||||
bool format_chat = params.conversation && params.enable_chat_template;
|
||||
std::string user_inp = format_chat
|
||||
? chat_add_and_format(model, chat_msgs, "user", std::move(buffer))
|
||||
? chat_add_and_format(model, chat_templates, chat_msgs, "user", std::move(buffer))
|
||||
: std::move(buffer);
|
||||
// TODO: one inconvenient of current chat template implementation is that we can't distinguish between user input and special tokens (prefix/postfix)
|
||||
const auto line_pfx = ::llama_tokenize(ctx, params.input_prefix, false, true);
|
||||
@@ -952,11 +970,12 @@ int main(int argc, char ** argv) {
|
||||
input_echo = false; // do not echo this again
|
||||
}
|
||||
|
||||
if (n_past > 0) {
|
||||
if (n_past > 0 || waiting_for_first_input) {
|
||||
if (is_interacting) {
|
||||
llama_sampling_reset(ctx_sampling);
|
||||
}
|
||||
is_interacting = false;
|
||||
waiting_for_first_input = false;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -814,6 +814,7 @@ struct server_context {
|
||||
|
||||
server_metrics metrics;
|
||||
|
||||
common_chat_templates chat_templates;
|
||||
// Necessary similarity of prompt for slot selection
|
||||
float slot_prompt_similarity = 0.0f;
|
||||
|
||||
@@ -860,15 +861,47 @@ struct server_context {
|
||||
add_bos_token = llama_should_add_bos_token(model);
|
||||
GGML_ASSERT(llama_add_eos_token(model) != 1);
|
||||
|
||||
if (params.chat_template.empty() && !validate_model_chat_template(params.use_jinja)) {
|
||||
LOG_WARNING("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__);
|
||||
chat_templates = llama_chat_templates_from_model(model, "chatml");
|
||||
}
|
||||
else {
|
||||
chat_templates = llama_chat_templates_from_model(model, params.chat_template);
|
||||
}
|
||||
GGML_ASSERT(chat_templates.template_default.get() != nullptr);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool validate_model_chat_template() const {
|
||||
bool validate_model_chat_template(bool use_jinja) const {
|
||||
llama_chat_message chat[] = {{"user", "test"}};
|
||||
|
||||
const int res = llama_chat_apply_template(model, nullptr, chat, 1, true, nullptr, 0);
|
||||
|
||||
return res > 0;
|
||||
if (use_jinja) {
|
||||
auto templates = llama_chat_templates_from_model(model, "");
|
||||
GGML_ASSERT(templates.template_default);
|
||||
try {
|
||||
templates.template_default->apply({ {
|
||||
{"role", "user"},
|
||||
{"content", "test"},
|
||||
} }, json(), true);
|
||||
if (templates.template_tool_use) {
|
||||
templates.template_tool_use->apply({ {
|
||||
{"role", "user"},
|
||||
{"content", "test"},
|
||||
} }, json(), true);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
catch (const std::exception& e) {
|
||||
LOG_ERROR("failed to apply template: %s\n", e.what());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else {
|
||||
const char* tmpl = llama_model_chat_template(model, /* name */ nullptr);
|
||||
const int32_t chat_res = llama_chat_apply_template(model, tmpl, chat, 1, true, nullptr, 0);
|
||||
return chat_res > 0;
|
||||
}
|
||||
}
|
||||
|
||||
void init() {
|
||||
@@ -3182,22 +3215,16 @@ int main(int argc, char ** argv) {
|
||||
|
||||
const auto model_meta = ctx_server.model_meta();
|
||||
|
||||
// if a custom chat template is not supplied, we will use the one that comes with the model (if any)
|
||||
if (params.chat_template.empty()) {
|
||||
if (!ctx_server.validate_model_chat_template()) {
|
||||
LOG_WARNING("The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses", {});
|
||||
params.chat_template = "chatml";
|
||||
}
|
||||
}
|
||||
|
||||
// print sample chat example to make it clear which template is used
|
||||
{
|
||||
|
||||
LOG_INFO("chat template", {
|
||||
{"chat_example", llama_chat_format_example(ctx_server.model, params.chat_template)},
|
||||
{"chat_template", ctx_server.chat_templates.template_default->source().c_str()},
|
||||
});
|
||||
|
||||
LOG_INFO("chat template", {
|
||||
{"chat_example", llama_chat_format_example(ctx_server.model, *ctx_server.chat_templates.template_default, ctx_server.params.use_jinja).c_str()},
|
||||
{"built_in", params.chat_template.empty()},
|
||||
});
|
||||
}
|
||||
|
||||
//
|
||||
// Middlewares
|
||||
//
|
||||
@@ -3560,9 +3587,11 @@ int main(int argc, char ** argv) {
|
||||
{ "system_prompt", ctx_server.system_prompt.c_str() },
|
||||
{ "default_generation_settings", ctx_server.default_generation_settings_for_props },
|
||||
{ "total_slots", ctx_server.params.n_parallel },
|
||||
{ "chat_template", curr_tmpl.c_str() }
|
||||
{ "chat_template", ctx_server.chat_templates.template_default->source() },
|
||||
};
|
||||
|
||||
if (ctx_server.params.use_jinja && ctx_server.chat_templates.template_tool_use) {
|
||||
data["chat_template_tool_use"] = ctx_server.chat_templates.template_tool_use->source();
|
||||
}
|
||||
res.set_content(data.dump(), "application/json; charset=utf-8");
|
||||
};
|
||||
|
||||
@@ -3573,8 +3602,9 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
|
||||
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
||||
|
||||
json data = json::parse(req.body);
|
||||
auto body = json::parse(req.body);
|
||||
const auto& chat_template = body.contains("tools") && ctx_server.chat_templates.template_tool_use ? *ctx_server.chat_templates.template_tool_use : *ctx_server.chat_templates.template_default;
|
||||
json data = oaicompat_completion_params_parse(ctx_server.model, body, chat_template, ctx_server.params.use_jinja);
|
||||
|
||||
const int id_task = ctx_server.queue_tasks.get_new_id();
|
||||
|
||||
@@ -3674,7 +3704,11 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
|
||||
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
||||
json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
|
||||
|
||||
auto body = json::parse(req.body);
|
||||
const auto& chat_template = body.contains("tools") && ctx_server.chat_templates.template_tool_use ? *ctx_server.chat_templates.template_tool_use : *ctx_server.chat_templates.template_default;
|
||||
json data = oaicompat_completion_params_parse(ctx_server.model,body, chat_template, params.use_jinja);
|
||||
|
||||
|
||||
const int id_task = ctx_server.queue_tasks.get_new_id();
|
||||
|
||||
|
||||
@@ -6,6 +6,8 @@
|
||||
// Change JSON_ASSERT from assert() to GGML_ASSERT:
|
||||
#define JSON_ASSERT GGML_ASSERT
|
||||
#include "json.hpp"
|
||||
#include "minja.hpp"
|
||||
#include "chat-template.hpp"
|
||||
#include "kimi_k2_tools.hpp"
|
||||
#include "qwen3_tools.hpp"
|
||||
#include "deepseek_r1_tools.hpp"
|
||||
@@ -125,7 +127,7 @@ static inline void server_log(const char * level, const char * function, int lin
|
||||
//
|
||||
|
||||
// Format given chat. If tmpl is empty, we take the template from model metadata
|
||||
inline std::string format_chat(const struct llama_model * model, const std::string & tmpl, const std::vector<json> & messages, const json & tools = json::array(), const std::string & model_name = "") {
|
||||
inline std::string format_chat(const struct llama_model * model, common_chat_template tmpl, const std::vector<json> & messages, const json & tools = json::array(), const std::string & model_name = "") {
|
||||
std::vector<llama_chat_msg> chat;
|
||||
|
||||
// Inject tools into the first system message, or create one if none exists
|
||||
@@ -197,8 +199,8 @@ inline std::string format_chat(const struct llama_model * model, const std::stri
|
||||
|
||||
chat.push_back({role, content});
|
||||
}
|
||||
auto formatted_chat = llama_chat_apply_template(model, tmpl, chat, true, /* use_jinja= */ false);
|
||||
|
||||
auto formatted_chat = llama_chat_apply_template(model, tmpl, chat, true);
|
||||
LOG_VERBOSE("formatted_chat", {{"text", formatted_chat.c_str()}});
|
||||
return formatted_chat;
|
||||
}
|
||||
@@ -425,46 +427,24 @@ static tool_choice_type tool_choice_parse_oaicompat(const std::string & tool_cho
|
||||
static json oaicompat_completion_params_parse(
|
||||
const struct llama_model * model,
|
||||
const json & body, /* openai api json semantics */
|
||||
const std::string & chat_template) {
|
||||
const common_chat_template& tmpl,
|
||||
bool use_jinja) {
|
||||
json llama_params;
|
||||
|
||||
llama_params["__oaicompat"] = true;
|
||||
auto tools = json_value(body, "tools", json());
|
||||
auto has_tools = tools.is_array() && !tools.empty();
|
||||
|
||||
if (has_tools) {
|
||||
if (use_jinja) {
|
||||
fprintf(stdout,"tools param is not fully supported yet\n");
|
||||
// Extract tools from the request body
|
||||
json tools = json_value(body, "tools", json::array());
|
||||
|
||||
}
|
||||
// Debug: Log system prompt when tools are detected
|
||||
if (!tools.empty() && server_verbose) {
|
||||
LOG_VERBOSE("Tool calls detected in request", {
|
||||
{"tool_count", tools.size()},
|
||||
{"model", json_value(body, "model", std::string(DEFAULT_OAICOMPAT_MODEL))}
|
||||
});
|
||||
|
||||
// Extract and log system prompt from messages
|
||||
if (body.contains("messages") && body["messages"].is_array()) {
|
||||
for (const auto& msg : body["messages"]) {
|
||||
if (msg.contains("role") && msg["role"] == "system" && msg.contains("content")) {
|
||||
std::string content_str;
|
||||
if (msg["content"].is_string()) {
|
||||
content_str = msg["content"];
|
||||
} else if (msg["content"].is_array()) {
|
||||
// Handle content blocks format
|
||||
for (const auto& block : msg["content"]) {
|
||||
if (block.contains("type") && block["type"] == "text" && block.contains("text")) {
|
||||
if (!content_str.empty()) content_str += " ";
|
||||
content_str += block["text"];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!content_str.empty()) {
|
||||
LOG_VERBOSE("System prompt with tools", {
|
||||
{"system_prompt", content_str.substr(0, 500) + (content_str.length() > 500 ? "..." : "")}
|
||||
});
|
||||
}
|
||||
break; // Only log first system message
|
||||
}
|
||||
}
|
||||
else {
|
||||
throw std::runtime_error("tools param requires --jinja flag");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -472,7 +452,7 @@ static json oaicompat_completion_params_parse(
|
||||
std::string model_name = json_value(body, "model", std::string(DEFAULT_OAICOMPAT_MODEL));
|
||||
|
||||
// Apply chat template to the list of messages with tools
|
||||
llama_params["prompt"] = format_chat(model, chat_template, body.at("messages"), tools, model_name);
|
||||
llama_params["prompt"] = format_chat(model, tmpl, body.at("messages"), tools, model_name);
|
||||
|
||||
// Handle "stop" field
|
||||
if (body.contains("stop") && body.at("stop").is_string()) {
|
||||
@@ -491,6 +471,13 @@ static json oaicompat_completion_params_parse(
|
||||
throw std::runtime_error("response_format type must be one of \"text\" or \"json_object\", but got: " + response_type);
|
||||
}
|
||||
}
|
||||
// Apply chat template to the list of messages
|
||||
if (use_jinja) {
|
||||
llama_params["prompt"] = tmpl.apply(body.at("messages"), tools, /* add_generation_prompt= */ true);
|
||||
}
|
||||
else {
|
||||
llama_params["prompt"] = format_chat(model, tmpl, body.at("messages"));
|
||||
}
|
||||
|
||||
// Handle "n" field
|
||||
int n_choices = json_value(body, "n", 1);
|
||||
|
||||
@@ -554,8 +554,9 @@ extern "C" {
|
||||
LLAMA_API bool llama_supports_mlock (void);
|
||||
LLAMA_API bool llama_supports_gpu_offload(void);
|
||||
|
||||
LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx);
|
||||
|
||||
|
||||
LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx);
|
||||
LLAMA_API uint32_t llama_n_ctx (const struct llama_context * ctx);
|
||||
LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx);
|
||||
LLAMA_API uint32_t llama_n_ubatch (const struct llama_context * ctx);
|
||||
@@ -566,6 +567,8 @@ extern "C" {
|
||||
LLAMA_API enum llama_vocab_type llama_vocab_type (const struct llama_model * model);
|
||||
LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model);
|
||||
|
||||
LLAMA_API const struct llama_vocab* llama_model_get_vocab(const struct llama_model* model);
|
||||
LLAMA_API const char* llama_model_chat_template(const struct llama_model* model, const char* name);
|
||||
LLAMA_API int32_t llama_n_vocab (const struct llama_model * model);
|
||||
LLAMA_API const struct llama_vocab* llama_get_model_vocab(const struct llama_model* model);
|
||||
LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model);
|
||||
|
||||
124
models/templates/Mistral-Small-3.2-24B-Instruct-2506.jinja
Normal file
124
models/templates/Mistral-Small-3.2-24B-Instruct-2506.jinja
Normal file
@@ -0,0 +1,124 @@
|
||||
{%- set today = strftime_now("%Y-%m-%d") %}
|
||||
{%- set default_system_message = "You are Mistral Small 3, a Large Language Model (LLM) created by Mistral AI, a French startup headquartered in Paris.\nYour knowledge base was last updated on 2023-10-01. The current date is " + today + ".\n\nWhen you're not sure about some information or when the user's request requires up-to-date or specific data, you must use the available tools to fetch the information. Do not hesitate to use tools whenever they can provide a more accurate or complete response. If no relevant tools are available, then clearly state that you don't have the information and avoid making up anything.
|
||||
|
||||
If the user's question is not clear, ambiguous, or does not provide enough context for you to accurately answer the question, you do not try to answer it right away and you rather ask the user to clarify their request (e.g. \"What are some good restaurants around me?\" => \"Where are you?\" or \"When is the next flight to Tokyo\" => \"Where do you travel from?\").
|
||||
You are always very attentive to dates, and when asked about information at specific dates, you discard information that is at another date.
|
||||
You follow these instructions in all languages, and always respond to the user in the language they use or request.
|
||||
Next sections describe the capabilities that you have.
|
||||
|
||||
# WEB BROWSING INSTRUCTIONS
|
||||
|
||||
You cannot perform any web search or access internet to open URLs, links etc. If it seems like the user is expecting you to do so, you clarify the situation and ask the user to copy paste the text directly in the chat.
|
||||
|
||||
# MULTI-MODAL INSTRUCTIONS
|
||||
|
||||
You have the ability to read images, but you cannot generate images. You also cannot transcribe audio files or videos.
|
||||
You cannot read nor transcribe audio files or videos.
|
||||
|
||||
# TOOL CALLING INSTRUCTIONS
|
||||
|
||||
You may have access to tools that you can use to fetch information or perform actions. You must use these tools in the following situations:
|
||||
|
||||
1. When the request requires up-to-date information.
|
||||
2. When the request requires specific data that you do not have in your knowledge base.
|
||||
3. When the request involves actions that you cannot perform without tools.
|
||||
|
||||
Always prioritize using tools to provide the most accurate and helpful response. If tools are not available, inform the user that you cannot perform the requested action at the moment." %}
|
||||
|
||||
{{- bos_token }}
|
||||
|
||||
{%- set system_prompt = default_system_message %}
|
||||
{%- set loop_messages = messages %}
|
||||
|
||||
{%- if not tools is defined %}
|
||||
{%- set tools = none %}
|
||||
{%- endif %}
|
||||
|
||||
{%- if messages|length > 0 and messages[0]['role'] == 'system' %}
|
||||
{%- if messages[0]['content'] is string %}
|
||||
{%- set system_prompt = messages[0]['content'] %}
|
||||
{%- else %}
|
||||
{%- set system_prompt = messages[0]['content'][0]['text'] %}
|
||||
{%- endif %}
|
||||
{%- set loop_messages = messages[1:] %}
|
||||
{%- endif %}
|
||||
|
||||
{%- set user_messages = loop_messages | selectattr("role", "equalto", "user") | list %}
|
||||
|
||||
{%- set ns = namespace(index=0) %}
|
||||
{%- for message in loop_messages %}
|
||||
{%- if not (message.role == "tool" or (message.get('tool_calls'))) %}
|
||||
{%- if (message["role"] == "user") != (ns.index % 2 == 0) %}
|
||||
{{- raise_exception("After the optional system message, conversation roles must alternate user/assistant/user/assistant/...") }}
|
||||
{%- endif %}
|
||||
{%- set ns.index = ns.index + 1 %}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
|
||||
{{- '[SYSTEM_PROMPT]' + system_prompt + '[/SYSTEM_PROMPT]' }}
|
||||
|
||||
{%- for message in loop_messages %}
|
||||
{%- if message['role'] == 'system' %}
|
||||
{%- if message['content'] is string %}
|
||||
{{- '[SYSTEM_PROMPT]' + message['content'] + '[/SYSTEM_PROMPT]' }}
|
||||
{%- else %}
|
||||
{{- '[SYSTEM_PROMPT]' + message['content'][0]['text'] + '[/SYSTEM_PROMPT]' }}
|
||||
{%- endif %}
|
||||
{%- elif message['role'] == 'user' %}
|
||||
{%- if tools is not none and (message == user_messages[-1]) %}
|
||||
{{- '[AVAILABLE_TOOLS]' + tools|tojson + '[/AVAILABLE_TOOLS]' }}
|
||||
{%- endif %}
|
||||
{{- '[INST]' }}
|
||||
{%- if message['content'] is string %}
|
||||
{{- message['content'] }}
|
||||
{%- else %}
|
||||
{%- for block in message['content'] %}
|
||||
{%- if block['type'] == 'text' %}
|
||||
{{- block['text'] }}
|
||||
{%- elif block['type'] in ['image', 'image_url'] %}
|
||||
{{- '[IMG]' }}
|
||||
{%- else %}
|
||||
{{- raise_exception('Only text and image blocks are supported in message content!') }}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{%- endif %}
|
||||
{{- '[/INST]' }}
|
||||
{%- elif message['role'] == 'assistant' %}
|
||||
{%- if message.get('tool_calls') %}
|
||||
{%- for tool_call in message.tool_calls %}
|
||||
{{- '[TOOL_CALLS]' + tool_call.function.name }}
|
||||
{%- if not tool_call.id is defined or tool_call.id is not string or tool_call.id|length != 9 %}
|
||||
{{- raise_exception("Tool call IDs should be alphanumeric strings with length 9!") }}
|
||||
{%- endif %}
|
||||
{{- '[CALL_ID]' + tool_call.id }}
|
||||
{{- '[ARGS]' + tool_call['function']['arguments']|tojson }}
|
||||
{%- endfor %}
|
||||
{{- eos_token }}
|
||||
{%- elif message['content'] is string %}
|
||||
{{- message['content'] + eos_token }}
|
||||
{%- else %}
|
||||
{%- for block in message['content'] %}
|
||||
{%- if block['type'] == 'text' %}
|
||||
{{- block['text'] }}
|
||||
{%- elif block['type'] in ['image', 'image_url'] %}
|
||||
{{- '[IMG]' }}
|
||||
{%- else %}
|
||||
{{- raise_exception('Only text and image blocks are supported in assistant content!') }}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{{- eos_token }}
|
||||
{%- endif %}
|
||||
{%- elif message['role'] == 'tool_results' or message['role'] == 'tool' %}
|
||||
{%- if message.content is defined and message.content.content is defined %}
|
||||
{%- set content = message.content.content %}
|
||||
{%- else %}
|
||||
{%- set content = message.content %}
|
||||
{%- endif %}
|
||||
{%- if not message.tool_call_id is defined or message.tool_call_id is not string or message['tool_call_id']|length != 9 %}
|
||||
{{- raise_exception("Tool call IDs should be alphanumeric strings with length 9!") }}
|
||||
{%- endif %}
|
||||
{{- '[TOOL_RESULTS]' + message.tool_call_id + '[TOOL_CONTENT]' + content|string + '[/TOOL_RESULTS]' }}
|
||||
{%- else %}
|
||||
{{- raise_exception('Only system, user, assistant, and tool roles are supported!') }}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
77
scripts/get_hf_chat_template.py
Normal file
77
scripts/get_hf_chat_template.py
Normal file
@@ -0,0 +1,77 @@
|
||||
#!/usr/bin/env python
|
||||
'''
|
||||
Fetches the Jinja chat template of a HuggingFace model.
|
||||
If a model has multiple chat templates, you can specify the variant name.
|
||||
|
||||
Syntax:
|
||||
./scripts/get_hf_chat_template.py model_id [variant]
|
||||
|
||||
Examples:
|
||||
./scripts/get_hf_chat_template.py NousResearch/Meta-Llama-3-8B-Instruct
|
||||
./scripts/get_hf_chat_template.py NousResearch/Hermes-3-Llama-3.1-8B tool_use
|
||||
./scripts/get_hf_chat_template.py meta-llama/Llama-3.2-3B-Instruct
|
||||
'''
|
||||
|
||||
import json
|
||||
import re
|
||||
import sys
|
||||
|
||||
|
||||
def get_hf_chat_template(model_id, variant=None):
|
||||
try:
|
||||
# Use huggingface_hub library if available.
|
||||
# Allows access to gated models if the user has access and ran `huggingface-cli login`.
|
||||
from huggingface_hub import hf_hub_download
|
||||
with open(hf_hub_download(repo_id=model_id, filename="tokenizer_config.json")) as f:
|
||||
config_str = f.read()
|
||||
except ImportError:
|
||||
import requests
|
||||
assert re.match(r"^[\w.-]+/[\w.-]+$", model_id), f"Invalid model ID: {model_id}"
|
||||
response = requests.get(f"https://huggingface.co/{model_id}/resolve/main/tokenizer_config.json")
|
||||
if response.status_code == 401:
|
||||
raise Exception('Access to this model is gated, please request access, authenticate with `huggingface-cli login` and make sure to run `pip install huggingface_hub`')
|
||||
response.raise_for_status()
|
||||
config_str = response.text
|
||||
|
||||
try:
|
||||
config = json.loads(config_str)
|
||||
except json.JSONDecodeError:
|
||||
# Fix https://huggingface.co/NousResearch/Meta-Llama-3-8B-Instruct/blob/main/tokenizer_config.json
|
||||
# (Remove extra '}' near the end of the file)
|
||||
config = json.loads(re.sub(r'\}([\n\s]*\}[\n\s]*\],[\n\s]*"clean_up_tokenization_spaces")', r'\1', config_str))
|
||||
|
||||
chat_template = config['chat_template']
|
||||
if isinstance(chat_template, str):
|
||||
return chat_template
|
||||
else:
|
||||
variants = {
|
||||
ct['name']: ct['template']
|
||||
for ct in chat_template
|
||||
}
|
||||
|
||||
def format_variants():
|
||||
return ', '.join(f'"{v}"' for v in variants.keys())
|
||||
|
||||
if variant is None:
|
||||
if 'default' not in variants:
|
||||
raise Exception(f'Please specify a chat template variant (one of {format_variants()})')
|
||||
variant = 'default'
|
||||
sys.stderr.write(f'Note: picked "default" chat template variant (out of {format_variants()})\n')
|
||||
elif variant not in variants:
|
||||
raise Exception(f"Variant {variant} not found in chat template (found {format_variants()})")
|
||||
|
||||
return variants[variant]
|
||||
|
||||
|
||||
def main(args):
|
||||
if len(args) < 1:
|
||||
raise ValueError("Please provide a model ID and an optional variant name")
|
||||
model_id = args[0]
|
||||
variant = None if len(args) < 2 else args[1]
|
||||
|
||||
template = get_hf_chat_template(model_id, variant)
|
||||
sys.stdout.write(template)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main(sys.argv[1:])
|
||||
@@ -22,7 +22,7 @@ add_library(llama
|
||||
unicode-data.cpp
|
||||
)
|
||||
|
||||
target_include_directories(llama PUBLIC . ../include)
|
||||
target_include_directories(llama PUBLIC . ../include ../common)
|
||||
target_include_directories(llama PRIVATE ../ggml/src)
|
||||
target_compile_features (llama PUBLIC cxx_std_11) # don't bump
|
||||
|
||||
|
||||
@@ -123,15 +123,31 @@
|
||||
//
|
||||
|
||||
// trim whitespace from the beginning and end of a string
|
||||
//static std::string trim(const std::string & str) {
|
||||
// Fails for Chinese character
|
||||
// size_t start = 0;
|
||||
// size_t end = str.size();
|
||||
// while (start < end && isspace(str[start])) {
|
||||
// start += 1;
|
||||
// }
|
||||
// while (end > start && isspace(str[end - 1])) {
|
||||
// end -= 1;
|
||||
// }
|
||||
// return str.substr(start, end - start);
|
||||
//}
|
||||
|
||||
static bool is_utf8_whitespace(uint8_t c) {
|
||||
// Basic ASCII whitespace
|
||||
if (c <= 0x7F) return isspace(c);
|
||||
// Else: Not whitespace (or you'd need a full Unicode table)
|
||||
return false;
|
||||
}
|
||||
|
||||
static std::string trim(const std::string & str) {
|
||||
size_t start = 0;
|
||||
size_t end = str.size();
|
||||
while (start < end && isspace(str[start])) {
|
||||
start += 1;
|
||||
}
|
||||
while (end > start && isspace(str[end - 1])) {
|
||||
end -= 1;
|
||||
}
|
||||
while (start < end && is_utf8_whitespace(str[start])) start++;
|
||||
while (end > start && is_utf8_whitespace(str[end - 1])) end--;
|
||||
return str.substr(start, end - start);
|
||||
}
|
||||
|
||||
@@ -400,6 +416,8 @@ enum llm_kv {
|
||||
LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP,
|
||||
LLM_KV_TOKENIZER_HF_JSON,
|
||||
LLM_KV_TOKENIZER_RWKV,
|
||||
LLM_KV_TOKENIZER_CHAT_TEMPLATE,
|
||||
LLM_KV_TOKENIZER_CHAT_TEMPLATE_N,
|
||||
LLM_KV_TOKENIZER_FIM_PRE_ID,
|
||||
LLM_KV_TOKENIZER_FIM_SUF_ID,
|
||||
LLM_KV_TOKENIZER_FIM_MID_ID,
|
||||
@@ -512,6 +530,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
|
||||
{ LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, "tokenizer.ggml.precompiled_charsmap" },
|
||||
{ LLM_KV_TOKENIZER_HF_JSON, "tokenizer.huggingface.json" },
|
||||
{ LLM_KV_TOKENIZER_RWKV, "tokenizer.rwkv.world" },
|
||||
{ LLM_KV_TOKENIZER_CHAT_TEMPLATE, "tokenizer.chat_template" },
|
||||
{ LLM_KV_TOKENIZER_CHAT_TEMPLATE_N, "tokenizer.chat_template.%s" },
|
||||
{ LLM_KV_TOKENIZER_FIM_PRE_ID, "tokenizer.ggml.fim_pre_token_id" },
|
||||
{ LLM_KV_TOKENIZER_FIM_SUF_ID, "tokenizer.ggml.fim_suf_token_id" },
|
||||
{ LLM_KV_TOKENIZER_FIM_MID_ID, "tokenizer.ggml.fim_mid_token_id" },
|
||||
@@ -530,13 +550,11 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
|
||||
};
|
||||
|
||||
struct LLM_KV {
|
||||
LLM_KV(llm_arch arch) : arch(arch) {}
|
||||
LLM_KV(llm_arch arch, const char* suffix = nullptr);
|
||||
|
||||
llm_arch arch;
|
||||
|
||||
std::string operator()(llm_kv kv) const {
|
||||
return ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch));
|
||||
}
|
||||
const char* suffix;
|
||||
std::string operator()(llm_kv kv) const;
|
||||
};
|
||||
|
||||
enum llm_tensor {
|
||||
@@ -634,6 +652,13 @@ enum llm_tensor {
|
||||
LLM_TENSOR_NEXTN_SHARED_HEAD_NORM,
|
||||
};
|
||||
|
||||
LLM_KV::LLM_KV(llm_arch arch, const char* suffix) : arch(arch), suffix(suffix) {}
|
||||
|
||||
std::string LLM_KV::operator()(llm_kv kv) const {
|
||||
return suffix ? ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch), suffix)
|
||||
: ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch));
|
||||
}
|
||||
|
||||
static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES = {
|
||||
{
|
||||
LLM_ARCH_LLAMA,
|
||||
@@ -21964,6 +21989,10 @@ void llama_free(struct llama_context * ctx) {
|
||||
delete ctx;
|
||||
}
|
||||
|
||||
const struct llama_vocab* llama_model_get_vocab(const struct llama_model* model) {
|
||||
return &model->vocab;
|
||||
}
|
||||
|
||||
const struct llama_model * llama_get_model(const struct llama_context * ctx) {
|
||||
return &ctx->model;
|
||||
}
|
||||
@@ -22152,6 +22181,24 @@ uint64_t llama_model_size(const struct llama_model * model) {
|
||||
return size;
|
||||
}
|
||||
|
||||
const char* llama_model_chat_template(const struct llama_model* model, const char* name) {
|
||||
const auto key = name ? LLM_KV(model->arch, name)(LLM_KV_TOKENIZER_CHAT_TEMPLATE)
|
||||
: LLM_KV(model->arch)(LLM_KV_TOKENIZER_CHAT_TEMPLATE);
|
||||
const auto& it = model->gguf_kv.find(key);
|
||||
if (it == model->gguf_kv.end()) {
|
||||
// one-off fix for very popular models (so we are not flooded with issues)
|
||||
// do not extend this list unless absolutely necessary
|
||||
// Mistral-Small-2503 does not have built-in chat template
|
||||
llama_vocab_pre_type pre_type = model->vocab.type_pre;
|
||||
if (!name && pre_type == LLAMA_VOCAB_PRE_TYPE_TEKKEN && model->layers.size() == 40) {
|
||||
return "mistral-v7-tekken";
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
return it->second.c_str();
|
||||
}
|
||||
|
||||
uint64_t llama_model_n_params(const struct llama_model * model) {
|
||||
uint64_t nparams = 0;
|
||||
for (const auto & it : model->tensors_by_name) {
|
||||
|
||||
@@ -7,6 +7,16 @@
|
||||
|
||||
#include "llama.h"
|
||||
#include "common.h"
|
||||
#include "chat-template.hpp"
|
||||
|
||||
static std::string normalize_newlines(const std::string& s) {
|
||||
#ifdef _WIN32
|
||||
static const std::regex nl_regex("\r\n");
|
||||
return std::regex_replace(s, nl_regex, "\n");
|
||||
#else
|
||||
return s;
|
||||
#endif
|
||||
}
|
||||
|
||||
int main(void) {
|
||||
llama_chat_message conversation[] = {
|
||||
@@ -143,10 +153,10 @@ int main(void) {
|
||||
std::vector<llama_chat_msg> chat2;
|
||||
llama_chat_msg sys_msg{"system", "You are a helpful assistant"};
|
||||
|
||||
auto fmt_sys = [&](std::string tmpl) {
|
||||
auto output = llama_chat_format_single(nullptr, tmpl, chat2, sys_msg, false);
|
||||
printf("fmt_sys(%s) : %s\n", tmpl.c_str(), output.c_str());
|
||||
printf("-------------------------\n");
|
||||
auto fmt_sys = [&](std::string tmpl_str) {
|
||||
minja::chat_template tmpl(tmpl_str, "", "");
|
||||
auto output = llama_chat_format_single(nullptr, tmpl, chat2, sys_msg, false, /* use_jinja= */ false);
|
||||
printf("fmt_sys(%s) : %s\n", tmpl_str.c_str(), output.c_str());
|
||||
return output;
|
||||
};
|
||||
assert(fmt_sys("chatml") == "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n");
|
||||
@@ -162,9 +172,10 @@ int main(void) {
|
||||
chat2.push_back({"assistant", "I am assistant"});
|
||||
llama_chat_msg new_msg{"user", "How are you"};
|
||||
|
||||
auto fmt_single = [&](std::string tmpl) {
|
||||
auto output = llama_chat_format_single(nullptr, tmpl, chat2, new_msg, true);
|
||||
printf("fmt_single(%s) : %s\n", tmpl.c_str(), output.c_str());
|
||||
auto fmt_single = [&](std::string tmpl_str) {
|
||||
minja::chat_template tmpl(tmpl_str, "", "");
|
||||
auto output = llama_chat_format_single(nullptr, tmpl, chat2, new_msg, true, /* use_jinja= */ false);
|
||||
printf("fmt_single(%s) : %s\n", tmpl_str.c_str(), output.c_str());
|
||||
printf("-------------------------\n");
|
||||
return output;
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user