mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-01-26 17:20:01 +00:00
add jinja template support (#677)
Co-authored-by: firecoperana <firecoperana>
This commit is contained in:
@@ -15,9 +15,11 @@
|
||||
#include "json.hpp"
|
||||
#include "json-schema-to-grammar.h"
|
||||
#include "llama.h"
|
||||
#include "chat-template.hpp"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cinttypes>
|
||||
#include <climits>
|
||||
#include <cmath>
|
||||
#include <codecvt>
|
||||
#include <cstdarg>
|
||||
@@ -199,6 +201,16 @@ int32_t cpu_get_num_math() {
|
||||
return cpu_get_num_physical_cores();
|
||||
}
|
||||
|
||||
|
||||
static std::string read_file(const std::string& fname) {
|
||||
std::ifstream file(fname);
|
||||
if (!file) {
|
||||
throw std::runtime_error(string_format("error: failed to open file '%s'\n", fname.c_str()));
|
||||
}
|
||||
std::string content((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
|
||||
file.close();
|
||||
return content;
|
||||
}
|
||||
//
|
||||
// CLI argument parsing
|
||||
//
|
||||
@@ -278,6 +290,13 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
|
||||
params.tensor_buft_overrides.push_back({nullptr, nullptr});
|
||||
}
|
||||
|
||||
if (!params.chat_template.empty() && !llama_chat_verify_template(nullptr, params.chat_template, params.use_jinja)) {
|
||||
throw std::runtime_error(string_format(
|
||||
"error: the supplied chat template is not supported: %s%s\n",
|
||||
params.chat_template.c_str(),
|
||||
params.use_jinja ? "" : "\nnote: llama.cpp was started without --jinja, we only support commonly used templates"
|
||||
));
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -1425,7 +1444,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
|
||||
}
|
||||
if (arg == "--chat-template") {
|
||||
CHECK_ARG
|
||||
if (!llama_chat_verify_template(argv[i])) {
|
||||
if (!llama_chat_verify_template(nullptr, argv[i], false)) {
|
||||
fprintf(stderr, "error: the supplied chat template is not supported: %s\n", argv[i]);
|
||||
fprintf(stderr, "note: llama.cpp does not use jinja parser, we only support commonly used templates\n");
|
||||
invalid_param = true;
|
||||
@@ -1434,6 +1453,22 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
|
||||
params.chat_template = argv[i];
|
||||
return true;
|
||||
}
|
||||
if (arg == "--chat-template-file") {
|
||||
CHECK_ARG
|
||||
std::string chat_template = read_file(std::string(argv[i]));
|
||||
if (!llama_chat_verify_template(nullptr, chat_template, false)) {
|
||||
fprintf(stderr, "error: the supplied chat template is not supported: %s\n", argv[i]);
|
||||
fprintf(stderr, "note: llama.cpp does not use jinja parser, we only support commonly used templates\n");
|
||||
invalid_param = true;
|
||||
return true;
|
||||
}
|
||||
params.chat_template = chat_template;
|
||||
return true;
|
||||
}
|
||||
if (arg == "--jinja") {
|
||||
params.use_jinja = true;
|
||||
return true;
|
||||
}
|
||||
if (arg == "--slot-prompt-similarity" || arg == "-sps") {
|
||||
CHECK_ARG
|
||||
params.slot_prompt_similarity = std::stof(argv[i]);
|
||||
@@ -1984,6 +2019,22 @@ std::string gpt_params_get_system_info(const gpt_params & params) {
|
||||
// String utils
|
||||
//
|
||||
|
||||
std::string string_format(const char* fmt, ...) {
|
||||
va_list ap;
|
||||
va_list ap2;
|
||||
va_start(ap, fmt);
|
||||
va_copy(ap2, ap);
|
||||
int size = vsnprintf(NULL, 0, fmt, ap);
|
||||
GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT
|
||||
std::vector<char> buf(size + 1);
|
||||
int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2);
|
||||
GGML_ASSERT(size2 == size);
|
||||
va_end(ap2);
|
||||
va_end(ap);
|
||||
return std::string(buf.data(), size);
|
||||
}
|
||||
|
||||
|
||||
std::vector<std::string> string_split(std::string input, char separator) {
|
||||
std::vector<std::string> parts;
|
||||
size_t separator_pos = input.find(separator);
|
||||
@@ -2985,6 +3036,22 @@ std::string llama_token_to_piece(const struct llama_context * ctx, llama_token t
|
||||
return piece;
|
||||
}
|
||||
|
||||
std::string llama_token_to_piece(const struct llama_model* model, llama_token token, bool special) {
|
||||
std::string piece;
|
||||
piece.resize(piece.capacity()); // using string internal cache, 15 bytes + '\n'
|
||||
const int n_chars = llama_token_to_piece(model, token, &piece[0], piece.size(), 0, special);
|
||||
if (n_chars < 0) {
|
||||
piece.resize(-n_chars);
|
||||
int check = llama_token_to_piece(model, token, &piece[0], piece.size(), 0, special);
|
||||
GGML_ASSERT(check == -n_chars);
|
||||
}
|
||||
else {
|
||||
piece.resize(n_chars);
|
||||
}
|
||||
|
||||
return piece;
|
||||
}
|
||||
|
||||
std::string llama_detokenize(llama_context * ctx, const std::vector<llama_token> & tokens, bool special) {
|
||||
std::string text;
|
||||
text.resize(std::max(text.capacity(), tokens.size()));
|
||||
@@ -3011,50 +3078,60 @@ bool llama_should_add_bos_token(const llama_model * model) {
|
||||
// Chat template utils
|
||||
//
|
||||
|
||||
bool llama_chat_verify_template(const std::string & tmpl) {
|
||||
bool llama_chat_verify_template(const struct llama_model* model, const std::string& tmpl, bool use_jinja) {
|
||||
if (use_jinja) {
|
||||
try {
|
||||
auto chat_template = minja::chat_template(tmpl, "<s>", "</s>");
|
||||
chat_template.apply({ {
|
||||
{"role", "user"},
|
||||
{"content", "test"},
|
||||
} }, json(), true);
|
||||
return true;
|
||||
}
|
||||
catch (const std::exception& e) {
|
||||
fprintf(stdout,"%s: failed to apply template: %s\n", __func__, e.what());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
llama_chat_message chat[] = {{"user", "test"}};
|
||||
int res = llama_chat_apply_template(nullptr, tmpl.c_str(), chat, 1, true, nullptr, 0);
|
||||
const int res = llama_chat_apply_template(model, tmpl.c_str(), chat, 1, true, nullptr, 0);
|
||||
return res >= 0;
|
||||
}
|
||||
|
||||
std::string llama_chat_apply_template(const struct llama_model * model,
|
||||
const std::string & tmpl,
|
||||
const common_chat_template& tmpl,
|
||||
const std::vector<llama_chat_msg> & msgs,
|
||||
bool add_ass) {
|
||||
bool add_ass,
|
||||
bool use_jinja) {
|
||||
if (use_jinja) {
|
||||
auto messages = json::array();
|
||||
for (const auto& msg : msgs) {
|
||||
messages.push_back({ {"role", msg.role}, {"content", msg.content} });
|
||||
}
|
||||
return tmpl.apply(messages, /* tools= */ json(), add_ass);
|
||||
}
|
||||
int alloc_size = 0;
|
||||
bool fallback = false; // indicate if we must fallback to default chatml
|
||||
std::vector<llama_chat_message> chat;
|
||||
for (auto & msg : msgs) {
|
||||
chat.push_back({msg.role.c_str(), msg.content.c_str()});
|
||||
alloc_size += (msg.role.size() + msg.content.size()) * 1.25;
|
||||
}
|
||||
|
||||
const char * ptr_tmpl = tmpl.empty() ? nullptr : tmpl.c_str();
|
||||
std::vector<char> buf(alloc_size);
|
||||
|
||||
// run the first time to get the total output length
|
||||
int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size());
|
||||
|
||||
int32_t res = llama_chat_apply_template(model, tmpl.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size());
|
||||
// error: chat template is not supported
|
||||
if (res < 0) {
|
||||
if (ptr_tmpl != nullptr) {
|
||||
// if the custom "tmpl" is not supported, we throw an error
|
||||
// this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template()
|
||||
throw std::runtime_error("this custom template is not supported");
|
||||
} else {
|
||||
// If the built-in template is not supported, we default to chatml
|
||||
res = llama_chat_apply_template(nullptr, "chatml", chat.data(), chat.size(), add_ass, buf.data(), buf.size());
|
||||
fallback = true;
|
||||
}
|
||||
}
|
||||
|
||||
// if it turns out that our buffer is too small, we resize it
|
||||
if ((size_t) res > buf.size()) {
|
||||
buf.resize(res);
|
||||
res = llama_chat_apply_template(
|
||||
fallback ? nullptr : model,
|
||||
fallback ? "chatml" : ptr_tmpl,
|
||||
chat.data(), chat.size(), add_ass, buf.data(), buf.size());
|
||||
res = llama_chat_apply_template(model, tmpl.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size());
|
||||
}
|
||||
|
||||
std::string formatted_chat(buf.data(), res);
|
||||
@@ -3062,12 +3139,13 @@ std::string llama_chat_apply_template(const struct llama_model * model,
|
||||
}
|
||||
|
||||
std::string llama_chat_format_single(const struct llama_model * model,
|
||||
const std::string & tmpl,
|
||||
const common_chat_template& tmpl,
|
||||
const std::vector<llama_chat_msg> & past_msg,
|
||||
const llama_chat_msg & new_msg,
|
||||
bool add_ass) {
|
||||
bool add_ass,
|
||||
bool use_jinja) {
|
||||
std::ostringstream ss;
|
||||
auto fmt_past_msg = past_msg.empty() ? "" : llama_chat_apply_template(model, tmpl, past_msg, false);
|
||||
auto fmt_past_msg = past_msg.empty() ? "" : llama_chat_apply_template(model, tmpl, past_msg, false, use_jinja);
|
||||
std::vector<llama_chat_msg> chat_new(past_msg);
|
||||
// if the past_msg ends with a newline, we must preserve it in the formatted version
|
||||
if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') {
|
||||
@@ -3075,21 +3153,77 @@ std::string llama_chat_format_single(const struct llama_model * model,
|
||||
};
|
||||
// format chat with new_msg
|
||||
chat_new.push_back(new_msg);
|
||||
auto fmt_new_msg = llama_chat_apply_template(model, tmpl, chat_new, add_ass);
|
||||
auto fmt_new_msg = llama_chat_apply_template(model, tmpl, chat_new, add_ass, use_jinja);
|
||||
// get the diff part
|
||||
ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
std::string llama_chat_format_example(const struct llama_model * model,
|
||||
const std::string & tmpl) {
|
||||
std::string llama_chat_format_example(const struct llama_model * model, const common_chat_template& tmpl, bool use_jinja) {
|
||||
std::vector<llama_chat_msg> msgs = {
|
||||
{"system", "You are a helpful assistant"},
|
||||
{"user", "Hello"},
|
||||
{"assistant", "Hi there"},
|
||||
{"user", "How are you?"},
|
||||
};
|
||||
return llama_chat_apply_template(model, tmpl, msgs, true);
|
||||
return llama_chat_apply_template(model, tmpl, msgs, true, use_jinja);
|
||||
}
|
||||
|
||||
|
||||
common_chat_templates llama_chat_templates_from_model(const struct llama_model* model, const std::string& chat_template_override)
|
||||
{
|
||||
auto vocab = llama_model_get_vocab(model);
|
||||
std::string default_template_src = chat_template_override;
|
||||
std::string template_tool_use_src = chat_template_override;
|
||||
bool has_explicit_template = !chat_template_override.empty();
|
||||
if (chat_template_override.empty()) {
|
||||
auto str = llama_model_chat_template(model, /* name */ nullptr);
|
||||
if (str) {
|
||||
default_template_src = str;
|
||||
has_explicit_template = true;
|
||||
}
|
||||
str = llama_model_chat_template(model, /* name */ "tool_use");
|
||||
if (str) {
|
||||
template_tool_use_src = str;
|
||||
has_explicit_template = true;
|
||||
}
|
||||
}
|
||||
if (default_template_src.empty() || default_template_src == "chatml") {
|
||||
if (!template_tool_use_src.empty()) {
|
||||
default_template_src = template_tool_use_src;
|
||||
}
|
||||
else {
|
||||
default_template_src = R"(
|
||||
{%- for message in messages -%}
|
||||
{{- "<|im_start|>" + message.role + "\n" + message.content + "<|im_end|>\n" -}}
|
||||
{%- endfor -%}
|
||||
{%- if add_generation_prompt -%}
|
||||
{{- "<|im_start|>assistant\n" -}}
|
||||
{%- endif -%}
|
||||
)";
|
||||
}
|
||||
}
|
||||
const auto get_token = [&](llama_token token, const char* name, const char* jinja_variable_name) {
|
||||
if (token == LLAMA_TOKEN_NULL) {
|
||||
if (default_template_src.find(jinja_variable_name) != std::string::npos
|
||||
|| template_tool_use_src.find(jinja_variable_name) != std::string::npos) {
|
||||
fprintf(stdout, "%s: warning: vocab does not have a %s token, jinja template won't work as intended.\n", __func__, name);
|
||||
}
|
||||
return std::string();
|
||||
}
|
||||
else {
|
||||
return llama_token_to_piece(model, token, true);
|
||||
}
|
||||
};
|
||||
auto token_bos = get_token(llama_token_bos(model), "BOS", "bos_token");
|
||||
auto token_eos = get_token(llama_token_eos(model), "EOS", "eos_token");
|
||||
return {
|
||||
has_explicit_template,
|
||||
std::make_unique<minja::chat_template>(default_template_src, token_bos, token_eos),
|
||||
template_tool_use_src.empty()
|
||||
? nullptr
|
||||
: std::make_unique<minja::chat_template>(template_tool_use_src, token_bos, token_eos)
|
||||
};
|
||||
}
|
||||
|
||||
//
|
||||
|
||||
Reference in New Issue
Block a user