llama : add token matching support to llama-grammar (#1220)

* llama : add token matching support to llama-grammar

llama : add token matching support to llama-grammar (#17816)

common/grammar : replace problematic backtracking regex `[\s\S]*` (#18342)

* disable tests and fix warnings

---------

Co-authored-by: firecoperana <firecoperana>
This commit is contained in:
firecoperana
2026-02-02 23:57:17 -06:00
committed by GitHub
parent 8ba7e2b40c
commit 7e8d444033
21 changed files with 644 additions and 916 deletions

View File

@@ -69,34 +69,10 @@ struct llama_sampling_context * common_sampler_init(const struct llama_vocab* vo
trigger_tokens.data(), trigger_tokens.size())
: llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root");
//if (!grmr) {
// return nullptr;
//}
// if there is a grammar, parse it
if (!params.grammar.empty()) {
result->parsed_grammar = grammar_parser::parse(params.grammar.c_str());
if (result->parsed_grammar.success) {
// will be empty (default) if there are parse errors
if (result->parsed_grammar.rules.empty()) {
fprintf(stderr, "%s: failed to parse grammar\n", __func__);
delete result;
return nullptr;
}
// Ensure that there is a "root" node.
if (result->parsed_grammar.symbol_ids.find("root") == result->parsed_grammar.symbol_ids.end()) {
fprintf(stderr, "%s: grammar does not contain a 'root' symbol\n", __func__);
delete result;
return nullptr;
}
if (grmr == nullptr) {
throw std::runtime_error("Failed to initialize llama_grammar");
}
}
}
result->prev.resize(params.n_prev);
result->n_valid = 0;
result->grammar_str = params.grammar;
result->grammar_root = "root";
}
result->grammar = grmr;
llama_sampling_set_rng_seed(result, params.seed);
@@ -140,71 +116,27 @@ void common_sampler_free(struct llama_sampling_context * ctx) {
delete ctx;
}
void common_sampler_reset(const struct llama_vocab* vocab, llama_sampling_context * ctx) {
if (ctx->grammar != NULL) {
llama_grammar_free(ctx->grammar);
ctx->grammar = NULL;
static void llama_grammar_reset(llama_sampling_context * ctx) {
ctx->prev.clear();
if (!ctx->grammar) {
return;
}
struct llama_grammar* grmr;
auto params = ctx->params;
if (params.grammar.compare(0, 11, "%llguidance") == 0) {
#ifdef LLAMA_USE_LLGUIDANCE
grmr = llama_sampler_init_llg(vocab, "lark", params.grammar.c_str());
#else
GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
#endif // LLAMA_USE_LLGUIDANCE
std::vector<const char*> trigger_patterns_c;
trigger_patterns_c.reserve(ctx->grammar->trigger_patterns.size());
for (auto& trigger_pattern : ctx->grammar->trigger_patterns) {
trigger_patterns_c.push_back(trigger_pattern.pattern.c_str());
}
else {
std::vector<std::string> trigger_patterns;
std::vector<std::string> patterns_anywhere;
std::vector<llama_token> trigger_tokens;
for (const auto& trigger : params.grammar_triggers) {
switch (trigger.type) {
case COMMON_GRAMMAR_TRIGGER_TYPE_WORD:
{
const auto& word = trigger.value;
patterns_anywhere.push_back(regex_escape(word));
break;
}
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
{
patterns_anywhere.push_back(trigger.value);
break;
}
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL:
{
trigger_patterns.push_back(trigger.value);
break;
}
case COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN:
{
const auto token = trigger.token;
trigger_tokens.push_back(token);
break;
}
default:
GGML_ASSERT(false && "unknown trigger type");
}
}
if (!patterns_anywhere.empty()) {
trigger_patterns.push_back("^[\\s\\S]*?(" + string_join(patterns_anywhere, "|") + ")[\\s\\S]*");
}
std::vector<const char*> trigger_patterns_c;
trigger_patterns_c.reserve(trigger_patterns.size());
for (const auto& regex : trigger_patterns) {
trigger_patterns_c.push_back(regex.c_str());
}
auto* grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, ctx->grammar_str.c_str(), ctx->grammar_root.c_str(),
ctx->grammar->lazy, trigger_patterns_c.data(), trigger_patterns_c.size(),
ctx->grammar->trigger_tokens.data(), ctx->grammar->trigger_tokens.size());
grmr = params.grammar_lazy
? llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root",
trigger_patterns_c.data(), trigger_patterns_c.size(),
trigger_tokens.data(), trigger_tokens.size())
: llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root");
}
llama_grammar_free_impl(ctx->grammar);
ctx->grammar = grammar_new;
}
ctx->grammar = grmr;
// Resets a sampling context in two steps, in this order:
//   1. llama_grammar_reset(ctx)
//   2. llama_sampler_dry_reset(ctx->smpl)
// `vocab` is part of the signature but is not referenced in this body.
void common_sampler_reset(const struct llama_vocab * vocab, llama_sampling_context * ctx) {
llama_grammar_reset(ctx);
llama_sampler_dry_reset(ctx->smpl);
}
@@ -215,13 +147,15 @@ void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t s
ctx->rng.seed(seed);
}
void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst) {
void common_sampler_clone(llama_sampling_context * src, llama_sampling_context * dst) {
if (dst->grammar) {
llama_grammar_free(dst->grammar);
dst->grammar = nullptr;
}
if (src->grammar) {
dst->grammar_root = src->grammar_root;
dst->grammar_str = src->grammar_str;
dst->grammar = llama_grammar_copy(src->grammar);
}