mirror of https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-07 15:00:11 +00:00
llama : add token matching support to llama-grammar (#1220)
* llama : add token matching support to llama-grammar

  llama : add token matching support to llama-grammar (#17816)
  common/grammar : replace problematic backtracking regex `[\s\S]*` (#18342)

* disable tests and fix warnings

---------

Co-authored-by: firecoperana <firecoperana>
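For context: lazy grammars in this codebase are armed by triggers, which can be literal words, regexes matched anywhere or against the full output, or specific token ids (the new "token matching" support). A minimal sketch of that dispatch, using stand-in types rather than the actual `common_grammar_trigger` definitions from ik_llama.cpp:

```cpp
#include <cstdint>
#include <cstdio>
#include <string>
#include <vector>

// Stand-ins for the real trigger types; the names here are assumptions made
// for illustration, not the actual ik_llama.cpp definitions.
enum trigger_type { TRIGGER_WORD, TRIGGER_PATTERN, TRIGGER_PATTERN_FULL, TRIGGER_TOKEN };

struct trigger {
    trigger_type type;
    std::string  value; // literal word or regex source
    int32_t      token; // used only for TRIGGER_TOKEN
};

// Mirrors the switch in the diff below: words and "anywhere" patterns are
// pooled into one alternation, full patterns are kept as-is, and token ids
// are collected so the grammar can also be armed by an exact token match.
static void collect_triggers(const std::vector<trigger> & triggers,
                             std::vector<std::string>   & patterns_full,
                             std::vector<std::string>   & patterns_anywhere,
                             std::vector<int32_t>       & tokens) {
    for (const auto & t : triggers) {
        switch (t.type) {
            case TRIGGER_WORD:         patterns_anywhere.push_back(t.value); break; // real code escapes regex metachars first
            case TRIGGER_PATTERN:      patterns_anywhere.push_back(t.value); break;
            case TRIGGER_PATTERN_FULL: patterns_full.push_back(t.value);     break;
            case TRIGGER_TOKEN:        tokens.push_back(t.token);            break;
        }
    }
}

int main() {
    std::vector<trigger> triggers = {
        { TRIGGER_WORD,  "<tool_call>", 0 },
        { TRIGGER_TOKEN, "",            128010 }, // hypothetical token id
    };
    std::vector<std::string> full, anywhere;
    std::vector<int32_t> tokens;
    collect_triggers(triggers, full, anywhere, tokens);
    std::printf("anywhere=%zu full=%zu tokens=%zu\n", anywhere.size(), full.size(), tokens.size());
}
```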
@@ -69,34 +69,10 @@ struct llama_sampling_context * common_sampler_init(const struct llama_vocab* vo
                trigger_tokens.data(), trigger_tokens.size())
            : llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root");

    //if (!grmr) {
    //    return nullptr;
    //}

    // if there is a grammar, parse it
    if (!params.grammar.empty()) {
        result->parsed_grammar = grammar_parser::parse(params.grammar.c_str());
        if (result->parsed_grammar.success) {
            // will be empty (default) if there are parse errors
            if (result->parsed_grammar.rules.empty()) {
                fprintf(stderr, "%s: failed to parse grammar\n", __func__);
                delete result;
                return nullptr;
            }

            // Ensure that there is a "root" node.
            if (result->parsed_grammar.symbol_ids.find("root") == result->parsed_grammar.symbol_ids.end()) {
                fprintf(stderr, "%s: grammar does not contain a 'root' symbol\n", __func__);
                delete result;
                return nullptr;
            }
            if (grmr == nullptr) {
                throw std::runtime_error("Failed to initialize llama_grammar");
            }
        }
    }
    result->prev.resize(params.n_prev);
    result->n_valid = 0;
    result->grammar_str = params.grammar;
    result->grammar_root = "root";
}
    result->grammar = grmr;
    llama_sampling_set_rng_seed(result, params.seed);
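The context lines at the top of this hunk show the shape of the new initialization: when the grammar is lazy, the sampler is built with trigger patterns and trigger tokens; otherwise a plain grammar sampler is built from the same source and "root" symbol, and a failed construction is reported instead of being carried forward. A hedged sketch of that selection, with hypothetical stub factories standing in for the llama_sampler_init_grammar* calls:

```cpp
#include <cstdint>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical stand-ins for llama_sampler_init_grammar_lazy_patterns() and
// llama_sampler_init_grammar(); only the selection + null-check logic from
// the diff is the point here.
struct sampler { bool lazy; };

static sampler * init_grammar_lazy(const std::string &, const std::string &,
                                   const std::vector<const char *> &,
                                   const std::vector<int32_t> &) {
    return new sampler{true};
}
static sampler * init_grammar_eager(const std::string &, const std::string &) {
    return new sampler{false};
}

static sampler * make_grammar_sampler(bool lazy, const std::string & src,
                                      const std::vector<const char *> & patterns,
                                      const std::vector<int32_t> & tokens) {
    sampler * grmr = lazy
        ? init_grammar_lazy(src, "root", patterns, tokens)
        : init_grammar_eager(src, "root");
    // The hunk adds an explicit failure check rather than letting a null
    // grammar propagate into the sampling context.
    if (grmr == nullptr) {
        throw std::runtime_error("Failed to initialize llama_grammar");
    }
    return grmr;
}

int main() {
    std::vector<const char *> patterns;
    std::vector<int32_t> tokens = {128010}; // hypothetical trigger token id
    sampler * s = make_grammar_sampler(!tokens.empty(), "root ::= \"x\"", patterns, tokens);
    delete s;
}
```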
@@ -140,71 +116,27 @@ void common_sampler_free(struct llama_sampling_context * ctx) {
    delete ctx;
}

void common_sampler_reset(const struct llama_vocab* vocab, llama_sampling_context * ctx) {
    if (ctx->grammar != NULL) {
        llama_grammar_free(ctx->grammar);
        ctx->grammar = NULL;
static void llama_grammar_reset(llama_sampling_context * ctx) {
    ctx->prev.clear();
    if (!ctx->grammar) {
        return;
    }
    struct llama_grammar* grmr;
    auto params = ctx->params;
    if (params.grammar.compare(0, 11, "%llguidance") == 0) {
#ifdef LLAMA_USE_LLGUIDANCE
        grmr = llama_sampler_init_llg(vocab, "lark", params.grammar.c_str());
#else
        GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
#endif // LLAMA_USE_LLGUIDANCE
    std::vector<const char*> trigger_patterns_c;
    trigger_patterns_c.reserve(ctx->grammar->trigger_patterns.size());
    for (auto& trigger_pattern : ctx->grammar->trigger_patterns) {
        trigger_patterns_c.push_back(trigger_pattern.pattern.c_str());
    }
    else {
        std::vector<std::string> trigger_patterns;
        std::vector<std::string> patterns_anywhere;
        std::vector<llama_token> trigger_tokens;
        for (const auto& trigger : params.grammar_triggers) {
            switch (trigger.type) {
                case COMMON_GRAMMAR_TRIGGER_TYPE_WORD:
                {
                    const auto& word = trigger.value;
                    patterns_anywhere.push_back(regex_escape(word));
                    break;
                }
                case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
                {
                    patterns_anywhere.push_back(trigger.value);
                    break;
                }
                case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL:
                {
                    trigger_patterns.push_back(trigger.value);
                    break;
                }
                case COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN:
                {
                    const auto token = trigger.token;
                    trigger_tokens.push_back(token);
                    break;
                }
                default:
                    GGML_ASSERT(false && "unknown trigger type");
            }
        }
        if (!patterns_anywhere.empty()) {
            trigger_patterns.push_back("^[\\s\\S]*?(" + string_join(patterns_anywhere, "|") + ")[\\s\\S]*");
        }

        std::vector<const char*> trigger_patterns_c;
        trigger_patterns_c.reserve(trigger_patterns.size());
        for (const auto& regex : trigger_patterns) {
            trigger_patterns_c.push_back(regex.c_str());
        }
    auto* grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, ctx->grammar_str.c_str(), ctx->grammar_root.c_str(),
                                                ctx->grammar->lazy, trigger_patterns_c.data(), trigger_patterns_c.size(),
                                                ctx->grammar->trigger_tokens.data(), ctx->grammar->trigger_tokens.size());

        grmr = params.grammar_lazy
            ? llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root",
                  trigger_patterns_c.data(), trigger_patterns_c.size(),
                  trigger_tokens.data(), trigger_tokens.size())
            : llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root");
    }
    llama_grammar_free_impl(ctx->grammar);
    ctx->grammar = grammar_new;
}

    ctx->grammar = grmr;
void common_sampler_reset(const struct llama_vocab * vocab, llama_sampling_context * ctx) {
    llama_grammar_reset(ctx);
    llama_sampler_dry_reset(ctx->smpl);
}
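The commit message calls the combined trigger regex problematic: a lazy `^[\s\S]*?` prefix, an alternation group, then an unbounded trailing `[\s\S]*`, which can force heavy backtracking on long non-matching outputs in backtracking engines. A small standalone std::regex sketch of how that pattern is assembled from escaped trigger words (the `regex_escape` + `string_join` step in the hunk above); the trigger words and `escape_word` helper here are illustrative stand-ins:

```cpp
#include <iostream>
#include <regex>
#include <string>
#include <vector>

// Escape regex metacharacters so literal trigger words can be embedded in an
// alternation; a stand-in for the regex_escape() used by the code above.
static std::string escape_word(const std::string & word) {
    static const std::regex meta(R"([.^$|()\[\]{}*+?\\])");
    return std::regex_replace(word, meta, R"(\$&)");
}

int main() {
    // Hypothetical trigger words; the real ones come from grammar_triggers.
    std::vector<std::string> words = { "<tool_call>", "```json" };

    std::string alternation;
    for (size_t i = 0; i < words.size(); ++i) {
        alternation += (i ? "|" : "") + escape_word(words[i]);
    }

    // The combined "match anywhere" pattern from the hunk: lazy prefix,
    // alternation group, then the problematic trailing [\s\S]*.
    const std::regex trigger("^[\\s\\S]*?(" + alternation + ")[\\s\\S]*");

    std::string text = "some free-form text, then <tool_call>{\"name\":\"f\"}";
    std::smatch m;
    if (std::regex_match(text, m, trigger)) {
        // m.position(1) is where the grammar would be armed.
        std::cout << "trigger matched at offset " << m.position(1) << "\n";
    }
}
```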
@@ -215,13 +147,15 @@ void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t s
    ctx->rng.seed(seed);
}

void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst) {
void common_sampler_clone(llama_sampling_context * src, llama_sampling_context * dst) {
    if (dst->grammar) {
        llama_grammar_free(dst->grammar);
        dst->grammar = nullptr;
    }

    if (src->grammar) {
        dst->grammar_root = src->grammar_root;
        dst->grammar_str = src->grammar_str;
        dst->grammar = llama_grammar_copy(src->grammar);
    }
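One detail worth noting in the renamed clone path: the grammar is deep-copied via llama_grammar_copy rather than shared, so the destination context can advance its own parse state without disturbing the source. A rough sketch of that ownership rule, with a simplified grammar type standing in for llama_grammar:

```cpp
#include <memory>
#include <string>
#include <vector>

// Simplified stand-in for llama_grammar: the parse stacks are the mutable
// state that must not be shared between two sampling contexts.
struct grammar {
    std::string              src;
    std::vector<std::string> stacks;
};

struct sampling_context {
    std::string              grammar_str;
    std::string              grammar_root;
    std::unique_ptr<grammar> grammar_state;
};

// Mirrors the clone logic above: drop the destination's old grammar, then
// copy the source's grammar string, root, and a fresh deep copy of the state.
static void clone(const sampling_context & src, sampling_context & dst) {
    dst.grammar_state.reset(); // free any previous grammar, as in the diff
    if (src.grammar_state) {
        dst.grammar_root  = src.grammar_root;
        dst.grammar_str   = src.grammar_str;
        dst.grammar_state = std::make_unique<grammar>(*src.grammar_state); // deep copy
    }
}

int main() {
    sampling_context a{ "root ::= \"x\"", "root",
                        std::make_unique<grammar>(grammar{"root ::= \"x\"", {}}) };
    sampling_context b;
    clone(a, b);
    // b now owns its own copy; advancing b's stacks leaves a untouched.
    b.grammar_state->stacks.push_back("advanced");
}
```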