From 7e8d4440338ddfb6b53b500549090e1f30be55f4 Mon Sep 17 00:00:00 2001 From: firecoperana <18252262+firecoperana@users.noreply.github.com> Date: Mon, 2 Feb 2026 23:57:17 -0600 Subject: [PATCH] llama : add token matching support to llama-grammar (#1220) * llama : add token matching support to llama-grammar llama : add token matching support to llama-grammar (#17816) common/grammar : replace problematic backtracking regex `[\s\S]*` (#18342) * disable tests and fix warnings --------- Co-authored-by: firecoperana --- common/CMakeLists.txt | 2 - common/chat.cpp | 8 +- common/grammar-parser.cpp | 542 --------------------- common/grammar-parser.h | 30 -- common/regex-partial.cpp | 26 +- common/sampling.cpp | 108 +--- common/sampling.h | 10 +- examples/gbnf-validator/gbnf-validator.cpp | 21 +- examples/infill/infill.cpp | 2 +- examples/speculative/speculative.cpp | 4 +- grammars/README.md | 24 + include/llama.h | 73 +-- src/llama-grammar.cpp | 408 ++++++++++++---- src/llama-grammar.h | 96 +++- src/llama-sampling.cpp | 2 +- src/llama.cpp | 27 +- tests/CMakeLists.txt | 22 +- tests/test-grammar-integration.cpp | 111 ++++- tests/test-grammar-parser.cpp | 14 + tests/test-llama-grammar.cpp | 2 +- tests/test-regex-partial.cpp | 28 +- 21 files changed, 644 insertions(+), 916 deletions(-) delete mode 100644 common/grammar-parser.cpp delete mode 100644 common/grammar-parser.h diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt index 4d3f462b..e1992589 100644 --- a/common/CMakeLists.txt +++ b/common/CMakeLists.txt @@ -63,8 +63,6 @@ add_library(${TARGET} STATIC sampling.cpp console.h console.cpp - grammar-parser.h - grammar-parser.cpp json-partial.h json-partial.cpp llguidance.cpp diff --git a/common/chat.cpp b/common/chat.cpp index 850340b8..34a48ea5 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -1519,7 +1519,7 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp // Trigger on tool calls that appear in the commentary channel data.grammar_triggers.push_back({ COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN, - "<\\|channel\\|>(commentary|analysis) to" + "<\\|channel\\|>(?:commentary|analysis) to" }); // Trigger tool calls that appear in the role section, either at the @@ -1850,17 +1850,17 @@ static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat (inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call)); // Trigger on some common known "good bad" outputs (only from the start and with a json that's about a specific argument name to avoid false positives) data.grammar_triggers.push_back({ - COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL, + COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN, // If thinking_forced_open, then we capture the tag in the grammar, // (important for required tool choice) and in the trigger's first capture (decides what is sent to the grammar) - std::string(data.thinking_forced_open ? "[\\s\\S]*?(\\s*)" : "(?:[\\s\\S]*?\\s*)?") + ( + std::string(data.thinking_forced_open ? "(\\s*)" : "") + ( "\\s*(" "(?:" "||||)?" 
"\\s*\\{\\s*\"name\"\\s*:\\s*\"(?:" + string_join(escaped_names, "|") + ")\"" ")" - ")[\\s\\S]*" + ")" ), }); data.preserved_tokens = { diff --git a/common/grammar-parser.cpp b/common/grammar-parser.cpp deleted file mode 100644 index d7b0fcba..00000000 --- a/common/grammar-parser.cpp +++ /dev/null @@ -1,542 +0,0 @@ -#include "grammar-parser.h" -#include -#include -#include -#include -#include -#include - -namespace grammar_parser { - // NOTE: assumes valid utf8 (but checks for overrun) - // copied from llama.cpp - static std::pair decode_utf8(const char * src) { - static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 }; - uint8_t first_byte = static_cast(*src); - uint8_t highbits = first_byte >> 4; - int len = lookup[highbits]; - uint8_t mask = (1 << (8 - len)) - 1; - uint32_t value = first_byte & mask; - const char * end = src + len; // may overrun! - const char * pos = src + 1; - for ( ; pos < end && *pos; pos++) { - value = (value << 6) + (static_cast(*pos) & 0x3F); - } - return std::make_pair(value, pos); - } - - static uint32_t get_symbol_id(parse_state & state, const char * src, size_t len) { - uint32_t next_id = static_cast(state.symbol_ids.size()); - auto result = state.symbol_ids.emplace(std::string(src, len), next_id); - return result.first->second; - } - - static uint32_t generate_symbol_id(parse_state & state, const std::string & base_name) { - uint32_t next_id = static_cast(state.symbol_ids.size()); - state.symbol_ids[base_name + '_' + std::to_string(next_id)] = next_id; - return next_id; - } - - static void add_rule( - parse_state & state, - uint32_t rule_id, - const std::vector & rule) { - if (state.rules.size() <= rule_id) { - state.rules.resize(rule_id + 1); - } - state.rules[rule_id] = rule; - } - - static bool is_digit_char(char c) { - return '0' <= c && c <= '9'; - } - - static bool is_word_char(char c) { - return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || is_digit_char(c); - } - - static std::pair parse_hex(const char * src, int size) { - const char * pos = src; - const char * end = src + size; - uint32_t value = 0; - for ( ; pos < end && *pos; pos++) { - value <<= 4; - char c = *pos; - if ('a' <= c && c <= 'f') { - value += c - 'a' + 10; - } else if ('A' <= c && c <= 'F') { - value += c - 'A' + 10; - } else if ('0' <= c && c <= '9') { - value += c - '0'; - } else { - break; - } - } - if (pos != end) { - throw std::runtime_error("expecting " + std::to_string(size) + " hex chars at " + src); - } - return std::make_pair(value, pos); - } - - static const char * parse_space(const char * src, bool newline_ok) { - const char * pos = src; - while (*pos == ' ' || *pos == '\t' || *pos == '#' || - (newline_ok && (*pos == '\r' || *pos == '\n'))) { - if (*pos == '#') { - while (*pos && *pos != '\r' && *pos != '\n') { - pos++; - } - } else { - pos++; - } - } - return pos; - } - - static const char * parse_name(const char * src) { - const char * pos = src; - while (is_word_char(*pos)) { - pos++; - } - if (pos == src) { - throw std::runtime_error(std::string("expecting name at ") + src); - } - return pos; - } - - static const char * parse_int(const char * src) { - const char * pos = src; - while (is_digit_char(*pos)) { - pos++; - } - if (pos == src) { - throw std::runtime_error(std::string("expecting integer at ") + src); - } - return pos; - } - - static std::pair parse_char(const char * src) { - if (*src == '\\') { - switch (src[1]) { - case 'x': return parse_hex(src + 2, 2); - case 'u': return parse_hex(src + 2, 4); - case 'U': return 
parse_hex(src + 2, 8); - case 't': return std::make_pair('\t', src + 2); - case 'r': return std::make_pair('\r', src + 2); - case 'n': return std::make_pair('\n', src + 2); - case '\\': - case '"': - case '[': - case ']': - return std::make_pair(src[1], src + 2); - default: - throw std::runtime_error(std::string("unknown escape at ") + src); - } - } else if (*src) { - return decode_utf8(src); - } - throw std::runtime_error("unexpected end of input"); - } - - const char * parse_alternates( - parse_state & state, - const char * src, - const std::string & rule_name, - uint32_t rule_id, - bool is_nested); - - static const char * parse_sequence( - parse_state & state, - const char * src, - const std::string & rule_name, - std::vector & out_elements, - bool is_nested) { - size_t last_sym_start = out_elements.size(); - const char * pos = src; - - auto handle_repetitions = [&](int min_times, int max_times) { - - if (last_sym_start == out_elements.size()) { - throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos); - } - - // apply transformation to previous symbol (last_sym_start to end) according to - // the following rewrite rules: - // S{m,n} --> S S S (m times) S'(n-m) - // S'(x) ::= S S'(x-1) | - // (... n-m definitions of these S' rules ...) - // S'(1) ::= S | - // S{m,} --> S S S (m times) S' - // S' ::= S S' | - // S* --> S{0,} - // --> S' ::= S S' | - // S+ --> S{1,} - // --> S S' - // S' ::= S S' | - // S? --> S{0,1} - // --> S' - // S' ::= S | - - std::vector previous_elements(out_elements.begin() + last_sym_start, out_elements.end()); - if (min_times == 0) { - out_elements.resize(last_sym_start); - } else { - // Repeat the previous elements (min_times - 1) times - for (int i = 1; i < min_times; i++) { - out_elements.insert(out_elements.end(), previous_elements.begin(), previous_elements.end()); - } - } - - uint32_t last_rec_rule_id = 0; - auto n_opt = max_times < 0 ? 1 : max_times - min_times; - - std::vector rec_rule(previous_elements); - for (int i = 0; i < n_opt; i++) { - rec_rule.resize(previous_elements.size()); - uint32_t rec_rule_id = generate_symbol_id(state, rule_name); - if (i > 0 || max_times < 0) { - rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, max_times < 0 ? rec_rule_id : last_rec_rule_id}); - } - rec_rule.push_back({LLAMA_GRETYPE_ALT, 0}); - rec_rule.push_back({LLAMA_GRETYPE_END, 0}); - add_rule(state, rec_rule_id, rec_rule); - last_rec_rule_id = rec_rule_id; - } - if (n_opt > 0) { - out_elements.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id}); - } - }; - - while (*pos) { - if (*pos == '"') { // literal string - pos++; - last_sym_start = out_elements.size(); - while (*pos != '"') { - if (!*pos) { - throw std::runtime_error("unexpected end of input"); - } - auto char_pair = parse_char(pos); - pos = char_pair.second; - out_elements.push_back({LLAMA_GRETYPE_CHAR, char_pair.first}); - } - pos = parse_space(pos + 1, is_nested); - } else if (*pos == '[') { // char range(s) - pos++; - enum llama_gretype start_type = LLAMA_GRETYPE_CHAR; - if (*pos == '^') { - pos++; - start_type = LLAMA_GRETYPE_CHAR_NOT; - } - last_sym_start = out_elements.size(); - while (*pos != ']') { - if (!*pos) { - throw std::runtime_error("unexpected end of input"); - } - auto char_pair = parse_char(pos); - pos = char_pair.second; - enum llama_gretype type = last_sym_start < out_elements.size() - ? 
LLAMA_GRETYPE_CHAR_ALT - : start_type; - - out_elements.push_back({type, char_pair.first}); - if (pos[0] == '-' && pos[1] != ']') { - if (!pos[1]) { - throw std::runtime_error("unexpected end of input"); - } - auto endchar_pair = parse_char(pos + 1); - pos = endchar_pair.second; - out_elements.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first}); - } - } - pos = parse_space(pos + 1, is_nested); - } else if (is_word_char(*pos)) { // rule reference - const char * name_end = parse_name(pos); - uint32_t ref_rule_id = get_symbol_id(state, pos, name_end - pos); - pos = parse_space(name_end, is_nested); - last_sym_start = out_elements.size(); - out_elements.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id}); - } else if (*pos == '(') { // grouping - // parse nested alternates into synthesized rule - pos = parse_space(pos + 1, true); - uint32_t sub_rule_id = generate_symbol_id(state, rule_name); - pos = parse_alternates(state, pos, rule_name, sub_rule_id, true); - last_sym_start = out_elements.size(); - // output reference to synthesized rule - out_elements.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id}); - if (*pos != ')') { - throw std::runtime_error(std::string("expecting ')' at ") + pos); - } - pos = parse_space(pos + 1, is_nested); - } else if (*pos == '.') { // any char - last_sym_start = out_elements.size(); - out_elements.push_back({LLAMA_GRETYPE_CHAR_ANY, 0}); - pos = parse_space(pos + 1, is_nested); - } else if (*pos == '*') { - pos = parse_space(pos + 1, is_nested); - handle_repetitions(0, -1); - } else if (*pos == '+') { - pos = parse_space(pos + 1, is_nested); - handle_repetitions(1, -1); - } else if (*pos == '?') { - pos = parse_space(pos + 1, is_nested); - handle_repetitions(0, 1); - } else if (*pos == '{') { - pos = parse_space(pos + 1, is_nested); - - if (!is_digit_char(*pos)) { - throw std::runtime_error(std::string("expecting an int at ") + pos); - } - const char * int_end = parse_int(pos); - int min_times = std::stoul(std::string(pos, int_end - pos)); - pos = parse_space(int_end, is_nested); - - int max_times = -1; - - if (*pos == '}') { - max_times = min_times; - pos = parse_space(pos + 1, is_nested); - } else if (*pos == ',') { - pos = parse_space(pos + 1, is_nested); - - if (is_digit_char(*pos)) { - const char * int_end = parse_int(pos); - max_times = std::stoul(std::string(pos, int_end - pos)); - pos = parse_space(int_end, is_nested); - } - - if (*pos != '}') { - throw std::runtime_error(std::string("expecting '}' at ") + pos); - } - pos = parse_space(pos + 1, is_nested); - } else { - throw std::runtime_error(std::string("expecting ',' at ") + pos); - } - handle_repetitions(min_times, max_times); - } else { - break; - } - } - return pos; - } - - const char * parse_alternates( - parse_state & state, - const char * src, - const std::string & rule_name, - uint32_t rule_id, - bool is_nested) { - std::vector rule; - const char * pos = parse_sequence(state, src, rule_name, rule, is_nested); - while (*pos == '|') { - rule.push_back({LLAMA_GRETYPE_ALT, 0}); - pos = parse_space(pos + 1, true); - pos = parse_sequence(state, pos, rule_name, rule, is_nested); - } - rule.push_back({LLAMA_GRETYPE_END, 0}); - add_rule(state, rule_id, rule); - return pos; - } - - static const char * parse_rule(parse_state & state, const char * src) { - const char * name_end = parse_name(src); - const char * pos = parse_space(name_end, false); - size_t name_len = name_end - src; - uint32_t rule_id = get_symbol_id(state, src, name_len); - const std::string name(src, name_len); - - if (!(pos[0] == ':' 
&& pos[1] == ':' && pos[2] == '=')) { - throw std::runtime_error(std::string("expecting ::= at ") + pos); - } - pos = parse_space(pos + 3, true); - - pos = parse_alternates(state, pos, name, rule_id, false); - - if (*pos == '\r') { - pos += pos[1] == '\n' ? 2 : 1; - } else if (*pos == '\n') { - pos++; - } else if (*pos) { - throw std::runtime_error(std::string("expecting newline or end at ") + pos); - } - return parse_space(pos, true); - } - - parse_state parse(const char * src) { - try { - parse_state state; - const char * pos = parse_space(src, true); - while (*pos) { - pos = parse_rule(state, pos); - } - // Validate the state to ensure that all rules are defined - for (const auto & rule : state.rules) { - if (rule.empty()) { - throw std::runtime_error("Undefined rule"); - } - for (const auto & elem : rule) { - if (elem.type == LLAMA_GRETYPE_RULE_REF) { - // Ensure that the rule at that location exists - if (elem.value >= state.rules.size() || state.rules[elem.value].empty()) { - // Get the name of the rule that is missing - for (const auto & kv : state.symbol_ids) { - if (kv.second == elem.value) { - throw std::runtime_error("Undefined rule identifier '" + kv.first + "'"); - } - } - } - } - } - } - state.success = true; - return state; - } catch (const std::exception & err) { - fprintf(stderr, "%s: error parsing grammar: %s\n\n%s\n", __func__, err.what(), src); - parse_state state; - state.success = false; - return state; - } - } - - static void print_grammar_char(FILE * file, uint32_t c) { - if (0x20 <= c && c <= 0x7f) { - fprintf(file, "%c", static_cast(c)); - } else { - // cop out of encoding UTF-8 - fprintf(file, "", c); - } - } - - static bool is_char_element(llama_grammar_element elem) { - switch (elem.type) { - case LLAMA_GRETYPE_CHAR: return true; - case LLAMA_GRETYPE_CHAR_NOT: return true; - case LLAMA_GRETYPE_CHAR_ALT: return true; - case LLAMA_GRETYPE_CHAR_RNG_UPPER: return true; - case LLAMA_GRETYPE_CHAR_ANY: return true; - default: return false; - } - } - - static void print_rule_binary(FILE * file, const std::vector & rule) { - for (auto elem : rule) { - switch (elem.type) { - case LLAMA_GRETYPE_END: fprintf(file, "END"); break; - case LLAMA_GRETYPE_ALT: fprintf(file, "ALT"); break; - case LLAMA_GRETYPE_RULE_REF: fprintf(file, "RULE_REF"); break; - case LLAMA_GRETYPE_CHAR: fprintf(file, "CHAR"); break; - case LLAMA_GRETYPE_CHAR_NOT: fprintf(file, "CHAR_NOT"); break; - case LLAMA_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break; - case LLAMA_GRETYPE_CHAR_ALT: fprintf(file, "CHAR_ALT"); break; - case LLAMA_GRETYPE_CHAR_ANY: fprintf(file, "CHAR_ANY"); break; - } - switch (elem.type) { - case LLAMA_GRETYPE_END: - case LLAMA_GRETYPE_ALT: - case LLAMA_GRETYPE_RULE_REF: - fprintf(file, "(%u) ", elem.value); - break; - case LLAMA_GRETYPE_CHAR: - case LLAMA_GRETYPE_CHAR_NOT: - case LLAMA_GRETYPE_CHAR_RNG_UPPER: - case LLAMA_GRETYPE_CHAR_ALT: - case LLAMA_GRETYPE_CHAR_ANY: - fprintf(file, "(\""); - print_grammar_char(file, elem.value); - fprintf(file, "\") "); - break; - } - } - fprintf(file, "\n"); - } - - static void print_rule( - FILE * file, - uint32_t rule_id, - const std::vector & rule, - const std::map & symbol_id_names) { - if (rule.empty() || rule.back().type != LLAMA_GRETYPE_END) { - throw std::runtime_error( - "malformed rule, does not end with LLAMA_GRETYPE_END: " + std::to_string(rule_id)); - } - fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str()); - for (size_t i = 0, end = rule.size() - 1; i < end; i++) { - llama_grammar_element elem = rule[i]; 
- switch (elem.type) { - case LLAMA_GRETYPE_END: - throw std::runtime_error( - "unexpected end of rule: " + std::to_string(rule_id) + "," + - std::to_string(i)); - case LLAMA_GRETYPE_ALT: - fprintf(file, "| "); - break; - case LLAMA_GRETYPE_RULE_REF: - fprintf(file, "%s ", symbol_id_names.at(elem.value).c_str()); - break; - case LLAMA_GRETYPE_CHAR: - fprintf(file, "["); - print_grammar_char(file, elem.value); - break; - case LLAMA_GRETYPE_CHAR_NOT: - fprintf(file, "[^"); - print_grammar_char(file, elem.value); - break; - case LLAMA_GRETYPE_CHAR_RNG_UPPER: - if (i == 0 || !is_char_element(rule[i - 1])) { - throw std::runtime_error( - "LLAMA_GRETYPE_CHAR_RNG_UPPER without preceding char: " + - std::to_string(rule_id) + "," + std::to_string(i)); - } - fprintf(file, "-"); - print_grammar_char(file, elem.value); - break; - case LLAMA_GRETYPE_CHAR_ALT: - if (i == 0 || !is_char_element(rule[i - 1])) { - throw std::runtime_error( - "LLAMA_GRETYPE_CHAR_ALT without preceding char: " + - std::to_string(rule_id) + "," + std::to_string(i)); - } - print_grammar_char(file, elem.value); - break; - case LLAMA_GRETYPE_CHAR_ANY: - fprintf(file, "."); - break; - } - if (is_char_element(elem)) { - switch (rule[i + 1].type) { - case LLAMA_GRETYPE_CHAR_ALT: - case LLAMA_GRETYPE_CHAR_RNG_UPPER: - case LLAMA_GRETYPE_CHAR_ANY: - break; - default: - fprintf(file, "] "); - } - } - } - fprintf(file, "\n"); - } - - void print_grammar(FILE * file, const parse_state & state) { - try { - std::map symbol_id_names; - for (const auto & kv : state.symbol_ids) { - symbol_id_names[kv.second] = kv.first; - } - for (size_t i = 0, end = state.rules.size(); i < end; i++) { - // fprintf(file, "%zu: ", i); - // print_rule_binary(file, state.rules[i]); - print_rule(file, uint32_t(i), state.rules[i], symbol_id_names); - // fprintf(file, "\n"); - } - } catch (const std::exception & err) { - fprintf(stderr, "\n%s: error printing grammar: %s\n", __func__, err.what()); - } - } - - std::vector parse_state::c_rules() { - std::vector ret; - ret.reserve(rules.size()); - for (const auto & rule : rules) { - ret.push_back(rule.data()); - } - return ret; - } -} diff --git a/common/grammar-parser.h b/common/grammar-parser.h deleted file mode 100644 index 3939bc30..00000000 --- a/common/grammar-parser.h +++ /dev/null @@ -1,30 +0,0 @@ -// Implements a parser for an extended Backus-Naur form (BNF), producing the -// binary context-free grammar format specified by llama.h. Supports character -// ranges, grouping, and repetition operators. 
As an example, a grammar for -// arithmetic might look like: -// -// root ::= expr -// expr ::= term ([-+*/] term)* -// term ::= num | "(" space expr ")" space -// num ::= [0-9]+ space -// space ::= [ \t\n]* - -#pragma once -#include "llama.h" -#include -#include -#include -#include - -namespace grammar_parser { - struct parse_state { - std::map symbol_ids; - std::vector> rules; - - std::vector c_rules(); - bool success; - }; - - parse_state parse(const char * src); - void print_grammar(FILE * file, const parse_state & state); -} diff --git a/common/regex-partial.cpp b/common/regex-partial.cpp index 4bff6b66..e667a209 100644 --- a/common/regex-partial.cpp +++ b/common/regex-partial.cpp @@ -27,7 +27,7 @@ common_regex_match common_regex::search(const std::string & input, size_t pos, b return res; } std::match_results srmatch; - if (std::regex_match(input.rbegin(), input.rend() - pos, srmatch, rx_reversed_partial)) { + if (std::regex_search(input.rbegin(), input.rend() - pos, srmatch, rx_reversed_partial, std::regex_constants::match_continuous)) { auto group = srmatch[1].str(); if (group.length() != 0) { auto it = srmatch[1].second.base(); @@ -55,18 +55,18 @@ common_regex_match common_regex::search(const std::string & input, size_t pos, b to see if a string ends with a partial regex match, but but it's not in std::regex yet. Instead, we'll the regex into a partial match regex operating as a full match on the reverse iterators of the input. - - /abcd/ -> (dcba|cba|ba|a).* -> ((?:(?:(?:(?:d)?c)?b)?a).* - - /a|b/ -> (a|b).* + - /abcd/ -> ^(dcba|cba|ba|a) -> ^((?:(?:(?:(?:d)?c)?b)?a) + - /a|b/ -> ^(a|b) - /a*?/ -> error, could match "" - - /a*b/ -> ((?:b)?a*+).* (final repetitions become eager) - - /.*?ab/ -> ((?:b)?a).* (merge .*) - - /a.*?b/ -> ((?:b)?.*?a).* (keep reluctant matches) - - /a(bc)d/ -> ((?:(?:d)?(?:(?:c)?b))?a).* - - /a(bc|de)/ -> ((?:(?:(?:e)?d)?|(?:(?:c)?b)?)?a).* - - /ab{2,4}c/ -> abbb?b?c -> ((?:(?:(?:(?:(?:c)?b)?b)?b?)?b?)?a).* + - /a*b/ -> ^((?:b)?a*+) (final repetitions become eager) + - /.*?ab/ -> ^((?:b)?a) (omit .*) + - /a.*?b/ -> ^((?:b)?.*?a) (keep reluctant matches) + - /a(bc)d/ -> ^((?:(?:d)?(?:(?:c)?b))?a) + - /a(bc|de)/ -> ^((?:(?:(?:e)?d)?|(?:(?:c)?b)?)?a) + - /ab{2,4}c/ -> ^cbbb?b?a -> ^((?:(?:(?:(?:(?:c)?b)?b)?b?)?b?)?a) - The regex will match a reversed string fully, and the end of the first (And only) capturing group will indicate the reversed start of the original partial pattern - (i.e. just where the final .* starts in the inverted pattern; all other groups are turned into non-capturing groups, and reluctant quantifiers are ignored) + The regex will match a reversed string fully, and the end of the first (And only) capturing group will indicate the reversed start of the original partial pattern. + All other groups are turned into non-capturing groups, and reluctant quantifiers are ignored. */ std::string regex_to_reversed_partial_regex(const std::string & pattern) { auto it = pattern.begin(); @@ -177,7 +177,7 @@ std::string regex_to_reversed_partial_regex(const std::string & pattern) { } } - // /abcd/ -> (dcba|cba|ba|a).* -> ((?:(?:(?:d)?c)?b)?a).* + // /abcd/ -> ^(dcba|cba|ba|a) -> ^((?:(?:(?:d)?c)?b)?a) // if n(=4) parts, opening n-1(=3) non-capturing groups after the 1 capturing group // We'll do the outermost capturing group and final .* in the enclosing function. 
std::vector res_alts; @@ -200,5 +200,5 @@ std::string regex_to_reversed_partial_regex(const std::string & pattern) { throw std::runtime_error("Unmatched '(' in pattern"); } - return "(" + res + ")[\\s\\S]*"; + return "^(" + res + ")"; } diff --git a/common/sampling.cpp b/common/sampling.cpp index cfec27fd..7c5accf4 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -69,34 +69,10 @@ struct llama_sampling_context * common_sampler_init(const struct llama_vocab* vo trigger_tokens.data(), trigger_tokens.size()) : llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root"); - //if (!grmr) { - // return nullptr; - //} - - // if there is a grammar, parse it - if (!params.grammar.empty()) { - result->parsed_grammar = grammar_parser::parse(params.grammar.c_str()); - if (result->parsed_grammar.success) { - // will be empty (default) if there are parse errors - if (result->parsed_grammar.rules.empty()) { - fprintf(stderr, "%s: failed to parse grammar\n", __func__); - delete result; - return nullptr; - } - - // Ensure that there is a "root" node. - if (result->parsed_grammar.symbol_ids.find("root") == result->parsed_grammar.symbol_ids.end()) { - fprintf(stderr, "%s: grammar does not contain a 'root' symbol\n", __func__); - delete result; - return nullptr; - } - if (grmr == nullptr) { - throw std::runtime_error("Failed to initialize llama_grammar"); - } - } - } result->prev.resize(params.n_prev); result->n_valid = 0; + result->grammar_str = params.grammar; + result->grammar_root = "root"; } result->grammar = grmr; llama_sampling_set_rng_seed(result, params.seed); @@ -140,71 +116,27 @@ void common_sampler_free(struct llama_sampling_context * ctx) { delete ctx; } -void common_sampler_reset(const struct llama_vocab* vocab, llama_sampling_context * ctx) { - - if (ctx->grammar != NULL) { - llama_grammar_free(ctx->grammar); - ctx->grammar = NULL; +static void llama_grammar_reset(llama_sampling_context * ctx) { + ctx->prev.clear(); + if (!ctx->grammar) { + return; } - struct llama_grammar* grmr; - auto params = ctx->params; - if (params.grammar.compare(0, 11, "%llguidance") == 0) { -#ifdef LLAMA_USE_LLGUIDANCE - grmr = llama_sampler_init_llg(vocab, "lark", params.grammar.c_str()); -#else - GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled"); -#endif // LLAMA_USE_LLGUIDANCE + std::vector trigger_patterns_c; + trigger_patterns_c.reserve(ctx->grammar->trigger_patterns.size()); + for (auto& trigger_pattern : ctx->grammar->trigger_patterns) { + trigger_patterns_c.push_back(trigger_pattern.pattern.c_str()); } - else { - std::vector trigger_patterns; - std::vector patterns_anywhere; - std::vector trigger_tokens; - for (const auto& trigger : params.grammar_triggers) { - switch (trigger.type) { - case COMMON_GRAMMAR_TRIGGER_TYPE_WORD: - { - const auto& word = trigger.value; - patterns_anywhere.push_back(regex_escape(word)); - break; - } - case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN: - { - patterns_anywhere.push_back(trigger.value); - break; - } - case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL: - { - trigger_patterns.push_back(trigger.value); - break; - } - case COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN: - { - const auto token = trigger.token; - trigger_tokens.push_back(token); - break; - } - default: - GGML_ASSERT(false && "unknown trigger type"); - } - } - if (!patterns_anywhere.empty()) { - trigger_patterns.push_back("^[\\s\\S]*?(" + string_join(patterns_anywhere, "|") + ")[\\s\\S]*"); - } - std::vector trigger_patterns_c; - trigger_patterns_c.reserve(trigger_patterns.size()); - for (const 
auto& regex : trigger_patterns) { - trigger_patterns_c.push_back(regex.c_str()); - } + auto* grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, ctx->grammar_str.c_str(), ctx->grammar_root.c_str(), + ctx->grammar->lazy, trigger_patterns_c.data(), trigger_patterns_c.size(), + ctx->grammar->trigger_tokens.data(), ctx->grammar->trigger_tokens.size()); - grmr = params.grammar_lazy - ? llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root", - trigger_patterns_c.data(), trigger_patterns_c.size(), - trigger_tokens.data(), trigger_tokens.size()) - : llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root"); - } + llama_grammar_free_impl(ctx->grammar); + ctx->grammar = grammar_new; +} - ctx->grammar = grmr; +void common_sampler_reset(const struct llama_vocab * vocab, llama_sampling_context * ctx) { + llama_grammar_reset(ctx); llama_sampler_dry_reset(ctx->smpl); } @@ -215,13 +147,15 @@ void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t s ctx->rng.seed(seed); } -void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst) { +void common_sampler_clone(llama_sampling_context * src, llama_sampling_context * dst) { if (dst->grammar) { llama_grammar_free(dst->grammar); dst->grammar = nullptr; } if (src->grammar) { + dst->grammar_root = src->grammar_root; + dst->grammar_str = src->grammar_str; dst->grammar = llama_grammar_copy(src->grammar); } diff --git a/common/sampling.h b/common/sampling.h index f2d1b1bf..6ccfca6a 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -1,7 +1,7 @@ #pragma once #include "llama.h" -#include "grammar-parser.h" +#include "llama-grammar.h" #include #include #include @@ -113,10 +113,10 @@ struct llama_sampling_context { // mirostat sampler state float mirostat_mu; - llama_grammar * grammar; + std::string grammar_str; + std::string grammar_root; - // internal - grammar_parser::parse_state parsed_grammar; + llama_grammar * grammar; // TODO: replace with ring-buffer std::vector prev; @@ -148,7 +148,7 @@ void common_sampler_reset(const struct llama_vocab* vocab, llama_sampling_contex void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed); // Copy the sampler context -void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst); +void common_sampler_clone(llama_sampling_context * src, llama_sampling_context * dst); // Get the last sampled token llama_token llama_sampling_last(llama_sampling_context * ctx); diff --git a/examples/gbnf-validator/gbnf-validator.cpp b/examples/gbnf-validator/gbnf-validator.cpp index df968413..ae22553c 100644 --- a/examples/gbnf-validator/gbnf-validator.cpp +++ b/examples/gbnf-validator/gbnf-validator.cpp @@ -1,6 +1,6 @@ #define LLAMA_API_INTERNAL -#include "grammar-parser.h" +#include "llama-grammar.h" #include "ggml.h" #include "llama.h" #include "unicode.h" @@ -77,27 +77,30 @@ int main(int argc, char** argv) { grammar_str = buffer.str(); } + // Parse the GBNF grammar - auto parsed_grammar = grammar_parser::parse(grammar_str.c_str()); + llama_grammar_parser parser; + auto parsed_grammar = parser.parse(grammar_str.c_str()); // will be empty (default) if there are parse errors - if (parsed_grammar.rules.empty()) { - fprintf(stdout, "%s: failed to parse grammar\n", __func__); + if (!parser.parse(grammar_str.c_str()) || parser.rules.empty()) { + fprintf(stderr, "%s: failed to parse grammar\n", __func__); return 1; } // Ensure that there is a "root" node. 
- if (parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end()) { - fprintf(stdout, "%s: grammar does not contain a 'root' symbol\n", __func__); + if (parser.symbol_ids.find("root") == parser.symbol_ids.end()) { + fprintf(stderr, "%s: grammar does not contain a 'root' symbol\n", __func__); return 1; } - std::vector grammar_rules(parsed_grammar.c_rules()); + std::vector grammar_rules(parser.c_rules()); // Create the LLAMA grammar - auto grammar = llama_grammar_init( + auto grammar = llama_grammar_init_impl( grammar_rules.data(), - grammar_rules.size(), parsed_grammar.symbol_ids.at("root")); + grammar_rules.size(), parser.symbol_ids.at("root")); + if (grammar == nullptr) { throw std::runtime_error("Failed to initialize llama_grammar"); } diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp index f98b2aba..60eed056 100644 --- a/examples/infill/infill.cpp +++ b/examples/infill/infill.cpp @@ -2,7 +2,7 @@ #include "console.h" #include "llama.h" -#include "grammar-parser.h" +#include "llama-grammar.h" #include #include diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index 8fb1c7d3..487b02a8 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -434,7 +434,7 @@ int main(int argc, char ** argv) { break; } - llama_sampling_cp(ctx_sampling, drafts[0].ctx_sampling); + common_sampler_clone(ctx_sampling, drafts[0].ctx_sampling); int n_seq_cur = 1; int n_past_cur = n_past_dft; @@ -503,7 +503,7 @@ int main(int argc, char ** argv) { drafts[n_seq_cur].i_batch_dft = drafts[s].i_batch_dft; drafts[n_seq_cur].i_batch_tgt = drafts[s].i_batch_tgt; - llama_sampling_cp(drafts[s].ctx_sampling, drafts[n_seq_cur].ctx_sampling); + common_sampler_clone(drafts[s].ctx_sampling, drafts[n_seq_cur].ctx_sampling); sa.push_back(n_seq_cur); diff --git a/grammars/README.md b/grammars/README.md index 01b02abb..873925c9 100644 --- a/grammars/README.md +++ b/grammars/README.md @@ -67,6 +67,30 @@ Parentheses `()` can be used to group sequences, which allows for embedding alte - `{m,n}` repeats the precedent symbol or sequence at between `m` and `n` times (included) - `{0,n}` repeats the precedent symbol or sequence at most `n` times (included) +## Tokens + +Tokens allow grammars to match specific tokenizer tokens rather than character sequences. This is useful for constraining outputs based on special tokens (like `` or ``). + +Tokens can be specified in two ways: + +1. **Token ID**: Use angle brackets with the token ID in square brackets: `<[token-id]>`. For example, `<[1000]>` matches the token with ID 1000. + +2. **Token string**: Use angle brackets with the token text directly: ``. For example, `` will match the token whose text is exactly ``. This only works if the string tokenizes to exactly one token in the vocabulary, otherwise the grammar will fail to parse. + +You can negate token matches using the `!` prefix: `!<[1000]>` or `!` matches any token *except* the specified one. + +``` +# Match a thinking block: ... 
+# Using token strings (requires these to be single tokens in the vocab) +root ::= thinking .* +thinking ::= !* + +# Equivalent grammar using explicit token IDs +# Assumes token 1000 = , token 1001 = +root ::= <[1000]> thinking <[1001]> .* +thinking ::= !<[1001]>* +``` + ## Comments and newlines Comments can be specified with `#`: diff --git a/include/llama.h b/include/llama.h index a0a8e3ac..73c8c293 100644 --- a/include/llama.h +++ b/include/llama.h @@ -489,39 +489,7 @@ extern "C" { // grammar types struct llama_grammar; - // grammar element type - enum llama_gretype { - // end of rule definition - LLAMA_GRETYPE_END = 0, - // start of alternate definition for rule - LLAMA_GRETYPE_ALT = 1, - - // non-terminal element: reference to rule - LLAMA_GRETYPE_RULE_REF = 2, - - // terminal element: character (code point) - LLAMA_GRETYPE_CHAR = 3, - - // inverse char(s) ([^a], [^a-b] [^abc]) - LLAMA_GRETYPE_CHAR_NOT = 4, - - // modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to - // be an inclusive range ([a-z]) - LLAMA_GRETYPE_CHAR_RNG_UPPER = 5, - - // modifies a preceding LLAMA_GRETYPE_CHAR or - // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA]) - LLAMA_GRETYPE_CHAR_ALT = 6, - - // any character (.) - LLAMA_GRETYPE_CHAR_ANY = 7, - }; - - typedef struct llama_grammar_element { - enum llama_gretype type; - uint32_t value; // Unicode code point or rule ID - } llama_grammar_element; // performance timing information struct llama_timings { @@ -1194,10 +1162,10 @@ extern "C" { /// @param n_rules The number of rules. /// @param start_rule_index The index of the root rule (the starting point of the grammar). /// @return The initialized llama_grammar or nullptr if initialization failed. - LLAMA_API struct llama_grammar * llama_grammar_init( - const llama_grammar_element ** rules, - size_t n_rules, - size_t start_rule_index); + //LLAMA_API struct llama_grammar * llama_grammar_init( + // const llama_grammar_element ** rules, + // size_t n_rules, + // size_t start_rule_index); struct llama_sampler_grammar; LLAMA_API void llama_grammar_init_lazy(struct llama_sampler_grammar * grammar); @@ -1489,39 +1457,6 @@ const std::vector> & llama_internal struct llama_context * ctx ); -struct llama_partial_utf8 { - uint32_t value; // bit value so far (unshifted) - int n_remain; // num bytes remaining; -1 indicates invalid sequence -}; - -struct llama_grammar_candidate { - size_t index; - const uint32_t * code_points; - llama_partial_utf8 partial_utf8; -}; - -using llama_grammar_rule = std::vector< llama_grammar_element>; -using llama_grammar_stack = std::vector; - -using llama_grammar_rules = std::vector; -using llama_grammar_stacks = std::vector; -using llama_grammar_candidates = std::vector; - -const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar * grammar); - llama_grammar_stacks & llama_grammar_get_stacks( struct llama_grammar * grammar); - -void llama_grammar_accept(struct llama_grammar* grammar, uint32_t chr); - -std::vector llama_grammar_reject_candidates_for_stack( - const llama_grammar_rules & rules, - const llama_grammar_stack & stack, - const llama_grammar_candidates & candidates); - -std::pair, llama_partial_utf8> decode_utf8( - const std::string & src, - llama_partial_utf8 partial_start); - - // Randomly selects a token from the candidates based on their probabilities using given std::mt19937. // This is a temporary workaround in order to fix race conditions when sampling with multiple sequences. 
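For reference only (not part of the patch): a minimal C++ sketch of how the token-matching GBNF syntax documented in `grammars/README.md` above could be assembled programmatically. The `<[id]>` / `!<[id]>` forms and the example rule structure come straight from the README hunk; `make_thinking_grammar` and the token IDs 1000/1001 are hypothetical placeholders.

```cpp
// Hypothetical helper (assumption, not in this patch): builds the README's
// example grammar from explicit token IDs. <[id]> matches exactly that token,
// !<[id]> matches any single token except it. 1000/1001 stand in for the
// opening/closing "thinking" tokens of a given vocab.
#include <cstdint>
#include <string>

static std::string make_thinking_grammar(int32_t open_id, int32_t close_id) {
    return
        "root     ::= <[" + std::to_string(open_id)  + "]> thinking <[" + std::to_string(close_id) + "]> .*\n"
        "thinking ::= !<[" + std::to_string(close_id) + "]>*\n";
}
```

The resulting string would be used where `params.grammar` is used today, e.g. passed to `llama_sampler_init_grammar(vocab, grammar.c_str(), "root")` as done in `common/sampling.cpp` in this patch.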
diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp index 12f6f3dc..5d3a864e 100644 --- a/src/llama-grammar.cpp +++ b/src/llama-grammar.cpp @@ -31,7 +31,7 @@ static std::pair decode_utf8(const char* src) { // Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as // pointer. If an invalid sequence is encountered, returns `llama_partial_utf8.n_remain == -1`. -std::pair, llama_partial_utf8> decode_utf8( +static std::pair, llama_partial_utf8> decode_utf8( const std::string & src, llama_partial_utf8 partial_start) { static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 }; @@ -188,7 +188,53 @@ static std::pair parse_char(const char* src) { throw std::runtime_error("unexpected end of input"); } -static void print_grammar_char(FILE* file, uint32_t c) { +static std::pair parse_token(const llama_vocab * vocab, const char * src) { + const char * pos = src; + if (*pos != '<') { + throw std::runtime_error(std::string("expecting '<' at ") + pos); + } + pos++; + + // Parse <[id]> + if (*pos == '[') { + pos++; + const char * int_end = parse_int(pos); + uint32_t token_id = std::stoul(std::string(pos, int_end - pos)); + pos = int_end; + if (*pos != ']') { + throw std::runtime_error(std::string("expecting ']' at ") + pos); + } + pos++; + if (*pos != '>') { + throw std::runtime_error(std::string("expecting '>' at ") + pos); + } + pos++; + return std::make_pair(token_id, pos); + } + + if (vocab == nullptr) { + throw std::runtime_error(std::string("no vocab to parse token at ") + src); + } + + // Parse and tokenize to obtain the token id + while (*pos != 0 && *pos != '>') { + pos++; + } + if (*pos != '>') { + throw std::runtime_error(std::string("expecting '>' at ") + pos); + } + pos++; + + llama_token tokens[2]; + int32_t n_tokens = vocab->tokenize(src, static_cast(pos - src), tokens, 2, false, true); + if (n_tokens != 1) { + // must tokenize to exactly 1 token + throw std::runtime_error("invalid token '" + std::string(src, pos - src) + "'"); + } + return std::make_pair(tokens[0], pos); +} + +static void print_grammar_char(FILE * file, uint32_t c) { if (0x20 <= c && c <= 0x7f) { fprintf(file, "%c", static_cast(c)); } @@ -220,6 +266,8 @@ static void print_rule_binary(FILE* file, const llama_grammar_rule& rule) { case LLAMA_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break; case LLAMA_GRETYPE_CHAR_ALT: fprintf(file, "CHAR_ALT"); break; case LLAMA_GRETYPE_CHAR_ANY: fprintf(file, "CHAR_ANY"); break; + case LLAMA_GRETYPE_TOKEN: fprintf(file, "TOKEN"); break; + case LLAMA_GRETYPE_TOKEN_NOT: fprintf(file, "TOKEN_NOT"); break; } switch (elem.type) { case LLAMA_GRETYPE_END: @@ -236,6 +284,17 @@ static void print_rule_binary(FILE* file, const llama_grammar_rule& rule) { print_grammar_char(file, elem.value); fprintf(file, "\") "); break; + case LLAMA_GRETYPE_TOKEN: + fprintf(file, "<["); + fprintf(file, "%u", elem.value); + fprintf(file, "]> "); + break; + case LLAMA_GRETYPE_TOKEN_NOT: + fprintf(file, "!"); + fprintf(file, "<["); + fprintf(file, "%u", elem.value); + fprintf(file, "]> "); + break; } } fprintf(file, "\n"); @@ -292,6 +351,17 @@ static void print_rule( case LLAMA_GRETYPE_CHAR_ANY: fprintf(file, "."); break; + case LLAMA_GRETYPE_TOKEN: + fprintf(file, "<["); + fprintf(file, "%u", elem.value); + fprintf(file, "]> "); + break; + case LLAMA_GRETYPE_TOKEN_NOT: + fprintf(file, "!"); + fprintf(file, "<["); + fprintf(file, "%u", elem.value); + fprintf(file, "]> "); + break; } if (is_char_element(elem)) { switch (rule[i + 1].type) { 
@@ -307,6 +377,44 @@ static void print_rule( fprintf(file, "\n"); } +// +// Regex utilities +// + +size_t llama_grammar_trigger_pattern::find(const std::string & input) const { + auto find_start_pos = [](const std::smatch & match) { + // get from the first matched capturing group to the end of the string + size_t start = std::string::npos; + for (auto i = 1u; i < match.size(); i++) { + if (match.length(i) > 0) { + start = match.position(i); + break; + } + } + if (start == std::string::npos) { + start = match.position(0); + } + return start; + }; + + if (!pattern.empty() && pattern.front() == '^' && pattern.back() == '$') { + // match against the entire input + std::smatch match; + if (std::regex_match(input, match, regex)) { + return find_start_pos(match); + } + } + + // search anywhere + std::smatch match; + if (std::regex_search(input, match, regex)) { + return find_start_pos(match); + } + + return std::string::npos; +} + + // // implementation // @@ -454,9 +562,19 @@ const char* llama_grammar_parser::parse_sequence( } } pos = parse_space(pos + 1, is_nested); - } - else if (is_word_char(*pos)) { // rule reference - const char* name_end = parse_name(pos); + } else if (*pos == '<' || *pos == '!') { // token + auto type = LLAMA_GRETYPE_TOKEN; + if (*pos == '!') { // token inverse + type = LLAMA_GRETYPE_TOKEN_NOT; + pos++; + } + auto token_pair = parse_token(vocab, pos); + const char * token_end = token_pair.second; + last_sym_start = rule.size(); + rule.push_back({type, token_pair.first}); + pos = parse_space(token_end, is_nested); + } else if (is_word_char(*pos)) { // rule reference + const char * name_end = parse_name(pos); uint32_t ref_rule_id = get_symbol_id(pos, name_end - pos); pos = parse_space(name_end, is_nested); last_sym_start = rule.size(); @@ -720,6 +838,21 @@ static bool llama_grammar_match_partial_char( return !is_positive_char; } +// returns true iff token matches the rule at pos (regular or inverse) +// asserts that pos is pointing to a token element +static bool llama_grammar_match_token( + const llama_grammar_element * pos, + const llama_token token) { + GGML_ASSERT(pos->type == LLAMA_GRETYPE_TOKEN || pos->type == LLAMA_GRETYPE_TOKEN_NOT); + if (pos->type == LLAMA_GRETYPE_TOKEN) { + return pos->value == static_cast(token); + } + if (pos->type == LLAMA_GRETYPE_TOKEN_NOT) { + return pos->value != static_cast(token); + } + return false; +} + // transforms a grammar pushdown stack into N possible stacks, all ending // at a character range (terminal element) static void llama_grammar_advance_stack( @@ -768,6 +901,8 @@ static void llama_grammar_advance_stack( case LLAMA_GRETYPE_CHAR: case LLAMA_GRETYPE_CHAR_NOT: case LLAMA_GRETYPE_CHAR_ANY: + case LLAMA_GRETYPE_TOKEN: + case LLAMA_GRETYPE_TOKEN_NOT: if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) { // only add the stack if it's not a duplicate of one we already have new_stacks.emplace_back(stack); @@ -864,31 +999,38 @@ llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar) return grammar->stacks; } +static void llama_grammar_accept_chr( + struct llama_grammar & grammar, + const llama_grammar_stack & stack, + uint32_t chr, + llama_grammar_stacks & new_stacks) { + if (stack.empty()) { + return; + } -// takes a set of possible pushdown stacks on a grammar, which are required to -// be positioned at a character range (see `llama_grammar_advance_stack`), and -// produces the N possible stacks if the given char is accepted at those -// positions -void 
llama_grammar_accept(struct llama_grammar* grammar, uint32_t chr) { + const llama_grammar_element * pos = stack.back(); + + // ignore if this turns into a token + if (pos->type == LLAMA_GRETYPE_TOKEN || pos->type == LLAMA_GRETYPE_TOKEN_NOT) { + return; + } + + auto match = llama_grammar_match_char(pos, chr); + if (match.first) { + llama_grammar_stack new_stack(stack.begin(), stack.end() - 1); + if (!llama_grammar_is_end_of_sequence(match.second)) { + new_stack.push_back(match.second); + } + llama_grammar_advance_stack(grammar.rules, new_stack, new_stacks); + } +} + +void llama_grammar_accept(struct llama_grammar * grammar, uint32_t chr) { llama_grammar_stacks stacks_new; stacks_new.reserve(grammar->stacks.size()); for (const auto& stack : grammar->stacks) { - if (stack.empty()) { - continue; - } - - auto match = llama_grammar_match_char(stack.back(), chr); - if (match.first) { - const llama_grammar_element* pos = match.second; - - // update top of stack to next element, if any - llama_grammar_stack new_stack(stack.begin(), stack.end() - 1); - if (!llama_grammar_is_end_of_sequence(pos)) { - new_stack.push_back(pos); - } - llama_grammar_advance_stack(grammar->rules, new_stack, stacks_new); - } + llama_grammar_accept_chr(*grammar, stack, chr, stacks_new); } grammar->stacks = std::move(stacks_new); @@ -913,6 +1055,22 @@ llama_grammar_candidates llama_grammar_reject_candidates_for_stack( const llama_grammar_element * stack_pos = stack.back(); + // if the top of the stack is a token rule, then we only need to check the token id + if (stack_pos->type == LLAMA_GRETYPE_TOKEN || stack_pos->type == LLAMA_GRETYPE_TOKEN_NOT) { + for (const auto & tok : candidates) { + if (*tok.code_points == 0) { + // reached the end of a token consumed by char rules, reject iff it ended + // in a partial response + if (tok.partial_utf8.n_remain != 0) { + rejects.push_back(tok); + } + } else if (!llama_grammar_match_token(stack_pos, tok.id)) { + rejects.push_back(tok); + } + } + return rejects; + } + llama_grammar_candidates next_candidates; next_candidates.reserve(candidates.size()); @@ -925,7 +1083,7 @@ llama_grammar_candidates llama_grammar_reject_candidates_for_stack( rejects.push_back(tok); } } else if (llama_grammar_match_char(stack_pos, *tok.code_points).first) { - next_candidates.push_back({ tok.index, tok.code_points + 1, tok.partial_utf8 }); + next_candidates.push_back({ tok.index, tok.code_points + 1, tok.partial_utf8, tok.id }); } else { rejects.push_back(tok); } @@ -943,7 +1101,7 @@ llama_grammar_candidates llama_grammar_reject_candidates_for_stack( auto next_rejects = llama_grammar_reject_candidates(rules, next_stacks, next_candidates); for (const auto & tok : next_rejects) { - rejects.push_back({ tok.index, tok.code_points - 1, tok.partial_utf8 }); + rejects.push_back({ tok.index, tok.code_points - 1, tok.partial_utf8, tok.id }); } return rejects; @@ -969,10 +1127,9 @@ struct llama_grammar* llama_grammar_init_impl( for (size_t i = 0; i < n_rules; i++) { for (pos = rules[i]; pos->type != LLAMA_GRETYPE_END; pos++) { vec_rules[i].push_back(*pos); - } - vec_rules[i].push_back({ LLAMA_GRETYPE_END, 0 }); } - + vec_rules[i].push_back({ LLAMA_GRETYPE_END, 0 }); + } // Check for left recursion std::vector rules_visited(n_rules); std::vector rules_in_progress(n_rules); @@ -1017,12 +1174,13 @@ struct llama_grammar* llama_grammar_init_impl( NULL, std::move(vec_rules), std::move(stacks), - /* .partial_utf8 = */ {}, - /* .lazy =*/ false, - /* .awaiting_trigger = */ false, - /* .trigger_buffer = */ "", - /* 
.trigger_tokens = */ {}, - /* .trigger_patterns = */ {}, + /* .partial_utf8 = */ {}, + /* .lazy = */ false, + /* .awaiting_trigger = */ false, + /* .trigger_buffer = */ "", + /* .trigger_buffer_positions = */ {}, + /* .trigger_tokens = */ {}, + /* .trigger_patterns = */ {}, }; } @@ -1035,7 +1193,7 @@ struct llama_grammar* llama_grammar_init_impl( size_t num_trigger_patterns, const llama_token* trigger_tokens, size_t num_trigger_tokens) { - llama_grammar_parser parser; + llama_grammar_parser parser(vocab); // if there is a grammar, parse it // rules will be empty (default) if there are parse errors @@ -1124,38 +1282,44 @@ struct llama_grammar* llama_grammar_init_impl( vocab, std::move(vec_rules), std::move(stacks), - /* .partial_utf8 = */ {}, - /* .lazy = */ lazy, - /* .awaiting_trigger = */ lazy, - /* .trigger_buffer = */ "", + /* .partial_utf8 = */ {}, + /* .lazy = */ lazy, + /* .awaiting_trigger = */ lazy, + /* .trigger_buffer = */ "", + /* .trigger_buffer_positions = */ {}, std::move(vec_trigger_tokens), std::move(vec_trigger_patterns), }; } void llama_grammar_free_impl(struct llama_grammar * grammar) { + if (grammar == nullptr) { + return; + } delete grammar; } -struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar * grammar) { +struct llama_grammar* llama_grammar_clone_impl(const struct llama_grammar& grammar) { auto* result = new llama_grammar{ - grammar->vocab, - grammar->rules, - grammar->stacks, - grammar->partial_utf8, - grammar->lazy, - grammar->awaiting_trigger, - grammar->trigger_buffer, - grammar->trigger_tokens, - grammar->trigger_patterns, + grammar.vocab, + grammar.rules, + grammar.stacks, + grammar.partial_utf8, + grammar.lazy, + grammar.awaiting_trigger, + grammar.trigger_buffer, + grammar.trigger_buffer_positions, + grammar.trigger_tokens, + grammar.trigger_patterns, }; + // redirect elements in stacks to point to new rules for (size_t is = 0; is < result->stacks.size(); is++) { for (size_t ie = 0; ie < result->stacks[is].size(); ie++) { - for (size_t ir0 = 0; ir0 < grammar->rules.size(); ir0++) { - for (size_t ir1 = 0; ir1 < grammar->rules[ir0].size(); ir1++) { - if (grammar->stacks[is][ie] == &grammar->rules[ir0][ir1]) { - result->stacks[is][ie] = &result->rules[ir0][ir1]; + for (size_t ir0 = 0; ir0 < grammar.rules.size(); ir0++) { + for (size_t ir1 = 0; ir1 < grammar.rules[ir0].size(); ir1++) { + if (grammar.stacks[is][ie] == &grammar.rules[ir0][ir1]) { + result->stacks[is][ie] = &result->rules[ir0][ir1]; } } } @@ -1199,7 +1363,7 @@ void llama_grammar_sample_impl(const struct llama_grammar * grammar, const struc candidates->data[i].logit = -INFINITY; } else { candidates_decoded.push_back(decode_utf8(piece, grammar->partial_utf8)); - candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second }); + candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second, id }); } } @@ -1208,44 +1372,49 @@ void llama_grammar_sample_impl(const struct llama_grammar * grammar, const struc candidates->data[reject.index].logit = -INFINITY; } if (!smpl) { - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; -} + smpl->t_sample_us += ggml_time_us() - t_start_sample_us; + } } -void llama_grammar_accept_token_impl(struct llama_grammar * grammar, const struct llama_vocab * vocab, const struct llama_sampling * smpl, llama_token token) { +void llama_grammar_accept_impl(struct llama_grammar & grammar, const struct llama_vocab * vocab, const struct llama_sampling * smpl, 
llama_token token) { const int64_t t_start_sample_us = ggml_time_us(); - GGML_ASSERT(grammar->vocab != nullptr); - const auto& piece = grammar->vocab->token_to_piece(token); + GGML_ASSERT(grammar.vocab != nullptr); + const auto& piece = grammar.vocab->token_to_piece(token); - if (grammar->awaiting_trigger) { - if (std::find(grammar->trigger_tokens.begin(), grammar->trigger_tokens.end(), token) != grammar->trigger_tokens.end()) { - grammar->awaiting_trigger = false; - grammar->trigger_buffer.clear(); - llama_grammar_accept_str(grammar, piece); + if (grammar.awaiting_trigger) { + if (std::find(grammar.trigger_tokens.begin(), grammar.trigger_tokens.end(), token) != grammar.trigger_tokens.end()) { + grammar.awaiting_trigger = false; + grammar.trigger_buffer.clear(); + llama_grammar_accept_token(grammar, token, piece); LLAMA_LOG_DEBUG("Grammar triggered on token %u (`%s`)", token, piece.c_str()); return; - } else { - grammar->trigger_buffer += piece; + } + else { + auto position = std::make_pair(grammar.trigger_buffer.size(), grammar.trigger_buffer.size() + piece.size()); + grammar.trigger_buffer_positions.push_back(std::make_pair(token, position)); + grammar.trigger_buffer += piece; - std::smatch match; - for (const auto& trigger_pattern : grammar->trigger_patterns) { - if (std::regex_match(grammar->trigger_buffer, match, trigger_pattern.regex)) { - grammar->awaiting_trigger = false; - // get from the first matched capturing group to the end of the string - size_t start = std::string::npos; - for (auto i = 1u; i < match.size(); i++) { - if (match.length(i) > 0) { - start = match.position(i); - break; + for (const auto& trigger_pattern : grammar.trigger_patterns) { + auto start = trigger_pattern.find(grammar.trigger_buffer); + if (start != std::string::npos) { + grammar.awaiting_trigger = false; + + // replay tokens that overlap with [start, end) + for (const auto& [tok, tok_pos] : grammar.trigger_buffer_positions) { + auto [tok_start, tok_end] = tok_pos; + if (tok_end <= start) { + continue; } + + size_t piece_start = (tok_start < start) ? 
start : tok_start; // allow for partial token pieces + size_t piece_len = tok_end - piece_start; + auto tok_piece = grammar.trigger_buffer.substr(piece_start, piece_len); + llama_grammar_accept_token(grammar, tok, tok_piece); } - if (start == std::string::npos) { - start = match.position(0); - } - auto constrained_str = grammar->trigger_buffer.substr(start); - // std::string constrained_str(match[1].first, grammar.trigger_buffer.end()); - grammar->trigger_buffer.clear(); - llama_grammar_accept_str(grammar, constrained_str); + + auto constrained_str = grammar.trigger_buffer.substr(start); + grammar.trigger_buffer.clear(); + grammar.trigger_buffer_positions.clear(); LLAMA_LOG_DEBUG("Grammar triggered on regex: '%s'\n", constrained_str.c_str()); return; } @@ -1256,7 +1425,7 @@ void llama_grammar_accept_token_impl(struct llama_grammar * grammar, const struc } if (llama_token_is_eog(vocab, token)) { - for (const auto & stack : grammar->stacks) { + for (const auto & stack : grammar.stacks) { if (stack.empty()) { return; } @@ -1264,22 +1433,77 @@ void llama_grammar_accept_token_impl(struct llama_grammar * grammar, const struc GGML_ABORT("fatal error"); } - llama_grammar_accept_str(grammar, piece); + llama_grammar_accept_token(grammar, token, piece); smpl->t_sample_us += ggml_time_us() - t_start_sample_us; } -void llama_grammar_accept_str(struct llama_grammar* grammar, const std::string& piece) { +void llama_grammar_accept_str(struct llama_grammar & grammar, const std::string & piece) { // Note terminating 0 in decoded string - const auto decoded = decode_utf8(piece, grammar->partial_utf8); + const auto decoded = decode_utf8(piece, grammar.partial_utf8); const auto& code_points = decoded.first; for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) { - llama_grammar_accept(grammar, *it); + llama_grammar_accept(&grammar, *it); } - grammar->partial_utf8 = decoded.second; - if (grammar->stacks.empty()) { + grammar.partial_utf8 = decoded.second; + if (grammar.stacks.empty()) { throw std::runtime_error("Unexpected empty grammar stack after accepting piece: " + piece); } } +void llama_grammar_accept_token(struct llama_grammar & grammar, llama_token token, const std::string & piece) { + // Note terminating 0 in decoded string + const auto decoded = decode_utf8(piece, grammar.partial_utf8); + const auto & code_points = decoded.first; + + llama_grammar_stacks stacks_new; + stacks_new.reserve(grammar.stacks.size()); + + for (const auto & stack : grammar.stacks) { + if (stack.empty()) { + continue; + } + + const llama_grammar_element * pos = stack.back(); + + if (pos->type == LLAMA_GRETYPE_TOKEN || pos->type == LLAMA_GRETYPE_TOKEN_NOT) { + if (llama_grammar_match_token(pos, token)) { + llama_grammar_stack new_stack(stack.begin(), stack.end() - 1); + if (!llama_grammar_is_end_of_sequence(pos + 1)) { + new_stack.push_back(pos + 1); + } + llama_grammar_advance_stack(grammar.rules, new_stack, stacks_new); + } + } else { + llama_grammar_stacks current_stacks = {stack}; + + for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) { + llama_grammar_stacks next_stacks; + + for (const auto & cur_stack : current_stacks) { + llama_grammar_accept_chr(grammar, cur_stack, *it, next_stacks); + } + + current_stacks = std::move(next_stacks); + if (current_stacks.empty()) { + break; + } + } + + for (auto & surviving_stack : current_stacks) { + if (std::find(stacks_new.begin(), stacks_new.end(), surviving_stack) == stacks_new.end()) { + 
stacks_new.emplace_back(surviving_stack); + } + } + } + } + + grammar.stacks = std::move(stacks_new); + grammar.partial_utf8 = decoded.second; + + if (grammar.stacks.empty()) { + throw std::runtime_error("Unexpected empty grammar stack after accepting piece: " + piece + " (" + std::to_string(token) + ")"); + } +} + diff --git a/src/llama-grammar.h b/src/llama-grammar.h index f13953c1..78f9c5d7 100644 --- a/src/llama-grammar.h +++ b/src/llama-grammar.h @@ -9,11 +9,84 @@ struct llama_vocab; struct llama_sampling; +// grammar element type +enum llama_gretype { + // end of rule definition + LLAMA_GRETYPE_END = 0, + + // start of alternate definition for rule + LLAMA_GRETYPE_ALT = 1, + + // non-terminal element: reference to rule + LLAMA_GRETYPE_RULE_REF = 2, + + // terminal element: character (code point) + LLAMA_GRETYPE_CHAR = 3, + + // inverse char(s) ([^a], [^a-b] [^abc]) + LLAMA_GRETYPE_CHAR_NOT = 4, + + // modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to + // be an inclusive range ([a-z]) + LLAMA_GRETYPE_CHAR_RNG_UPPER = 5, + + // modifies a preceding LLAMA_GRETYPE_CHAR or + // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA]) + LLAMA_GRETYPE_CHAR_ALT = 6, + + // any character (.) + LLAMA_GRETYPE_CHAR_ANY = 7, + + // terminal element: token (<[token-id]>) + LLAMA_GRETYPE_TOKEN = 8, + + // inverse token (!<[token-id]>) + LLAMA_GRETYPE_TOKEN_NOT = 9, +}; + +typedef struct llama_grammar_element { + enum llama_gretype type; + uint32_t value; // Unicode code point, rule ID, or token ID +} llama_grammar_element; + + +struct llama_partial_utf8 { + uint32_t value; // bit value so far (unshifted) + int n_remain; // num bytes remaining; -1 indicates invalid sequence +}; + +struct llama_grammar_candidate { + size_t index; + const uint32_t* code_points; + llama_partial_utf8 partial_utf8; + llama_token id; +}; + +using llama_grammar_rule = std::vector< llama_grammar_element>; +using llama_grammar_stack = std::vector<const llama_grammar_element *>; + +using llama_grammar_rules = std::vector<llama_grammar_rule>; +using llama_grammar_stacks = std::vector<llama_grammar_stack>; +using llama_grammar_candidates = std::vector<llama_grammar_candidate>; + +const llama_grammar_rules& llama_grammar_get_rules(const struct llama_grammar* grammar); +llama_grammar_stacks& llama_grammar_get_stacks(struct llama_grammar* grammar); + +void llama_grammar_accept(struct llama_grammar* grammar, uint32_t chr); + +std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack( + const llama_grammar_rules& rules, + const llama_grammar_stack& stack, + const llama_grammar_candidates& candidates); + struct llama_grammar_parser { + const llama_vocab * vocab; std::map<std::string, uint32_t> symbol_ids; llama_grammar_rules rules; + llama_grammar_parser(const struct llama_vocab * vocab = nullptr) : vocab(vocab) {} + llama_grammar_stack c_rules() const; uint32_t get_symbol_id(const char* src, size_t len); @@ -42,9 +115,15 @@ struct llama_grammar_parser { struct llama_grammar_trigger_pattern { std::string pattern; std::regex regex; + + size_t find(const std::string & input) const; }; + struct llama_grammar { + // maintain a list of llama_tokens and their positions in the trigger_buffer + using token_pos = std::pair<llama_token, std::pair<size_t, size_t>>; + // note: allow null vocab for testing (not great) const llama_vocab* vocab; @@ -60,6 +139,7 @@ struct llama_grammar { bool lazy = false; bool awaiting_trigger = false; // Initialized to true for lazy grammars only std::string trigger_buffer; // Output buffered by lazy grammar. Will be cleared once trigger is found. + std::vector<token_pos> trigger_buffer_positions; // Tokens buffered by lazy grammar.
Used to replay when a trigger is found. std::vector<llama_token> trigger_tokens; // Tokens that trigger a lazy grammar, or tokens to force printing of (even if special). std::vector<llama_grammar_trigger_pattern> trigger_patterns; // Regular expressions that trigger a lazy grammar. Must be a full match of the entire generated @@ -88,7 +168,8 @@ struct llama_grammar* llama_grammar_init_impl( void llama_grammar_free_impl(struct llama_grammar * grammar); -struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar * grammar); + +struct llama_grammar* llama_grammar_clone_impl(const struct llama_grammar& grammar); void llama_grammar_sample_impl( const struct llama_grammar * grammar, @@ -96,13 +177,18 @@ void llama_grammar_sample_impl( const struct llama_sampling * smpl, llama_token_data_array * candidates); -void llama_grammar_accept_token_impl( - struct llama_grammar * grammar, +void llama_grammar_accept_impl( + struct llama_grammar & grammar, const struct llama_vocab * vocab, const struct llama_sampling * smpl, llama_token token); void llama_grammar_accept_str( - struct llama_grammar* grammar, - const std::string& piece); + struct llama_grammar & grammar, + const std::string & piece); + +void llama_grammar_accept_token( + struct llama_grammar & grammar, + llama_token token, + const std::string & piece); diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index bb94af7a..c8cca55c 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -1217,7 +1217,7 @@ static const char* llama_sampler_grammar_name(const struct llama_sampler* /*smpl static void llama_sampler_grammar_accept_impl(struct llama_sampler* smpl, llama_token token) { auto* ctx = (llama_sampler_grammar*)smpl->ctx; if (ctx->grammar) { - llama_grammar_accept_token_impl(ctx->grammar,ctx->vocab ,nullptr, token); + llama_grammar_accept_impl(*ctx->grammar,ctx->vocab ,nullptr, token); } } diff --git a/src/llama.cpp b/src/llama.cpp index a78fef0d..36d28c7e 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -7611,36 +7611,13 @@ int32_t llama_chat_builtin_templates(const char ** output, size_t len) { // grammar // -struct llama_grammar * llama_grammar_init( - const llama_grammar_element ** rules, - size_t n_rules, - size_t start_rule_index) { - return llama_grammar_init_impl(rules, n_rules, start_rule_index); -} - void llama_grammar_free(struct llama_grammar * grammar) { llama_grammar_free_impl(grammar); } -// -//void llama_grammar_init_lazy(struct llama_sampler* smpl) { -// -// if (!grammar) { -// return; -// } -// std::vector<const char *> trigger_patterns_c; -// trigger_patterns_c.reserve(grammar.grammar->trigger_patterns.size()); -// for (auto& trigger_pattern : grammar.grammar->trigger_patterns) { -// trigger_patterns_c.push_back(trigger_pattern.pattern.c_str()); -// } -// //auto* grammar_new = llama_grammar_init_impl(grammar->vocab, "", "root", -// // grammar->lazy, trigger_patterns_c.data(), trigger_patterns_c.size(), -// // grammar->trigger_tokens.data(), grammar->trigger_tokens.size()); -// -//} struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar) { - return llama_grammar_copy_impl(grammar); + return llama_grammar_clone_impl(*grammar); } void llama_grammar_sample( @@ -7661,7 +7638,7 @@ void llama_grammar_accept_token( struct llama_grammar * grammar, struct llama_context * ctx, llama_token token) { - llama_grammar_accept_token_impl(grammar, &ctx->model.vocab, &ctx->sampling, token); + llama_grammar_accept_impl(*grammar, &ctx->model.vocab, &ctx->sampling, token); } // diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index
d0334217..18b35616 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -96,16 +96,16 @@ if (NOT WIN32) #llama_target_and_test(test-llama-grammar.cpp) #llama_target_and_test(test-chat.cpp) # TODO: disabled on loongarch64 because the ggml-ci node lacks Python 3.8 - if (NOT ${CMAKE_SYSTEM_PROCESSOR} MATCHES "loongarch64") - llama_target_and_test(test-json-schema-to-grammar.cpp WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/..) - target_include_directories(test-json-schema-to-grammar PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../examples/server) - endif() + #if (NOT ${CMAKE_SYSTEM_PROCESSOR} MATCHES "loongarch64") + # llama_target_and_test(test-json-schema-to-grammar.cpp WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/..) + # target_include_directories(test-json-schema-to-grammar PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../examples/server) + #endif() # build test-tokenizer-1-bpe target once and add many tests - add_executable(test-tokenizer-1-bpe test-tokenizer-1-bpe.cpp) - target_link_libraries(test-tokenizer-1-bpe PRIVATE common) - install(TARGETS test-tokenizer-1-bpe RUNTIME) + # add_executable(test-tokenizer-1-bpe test-tokenizer-1-bpe.cpp) + # target_link_libraries(test-tokenizer-1-bpe PRIVATE common) + # install(TARGETS test-tokenizer-1-bpe RUNTIME) # TODO: disabled due to slowness #llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-aquila ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-aquila.gguf) @@ -118,11 +118,11 @@ if (NOT WIN32) #llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-starcoder ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-starcoder.gguf) # build test-tokenizer-1-spm target once and add many tests - add_executable(test-tokenizer-1-spm test-tokenizer-1-spm.cpp) - target_link_libraries(test-tokenizer-1-spm PRIVATE common) - install(TARGETS test-tokenizer-1-spm RUNTIME) + # add_executable(test-tokenizer-1-spm test-tokenizer-1-spm.cpp) + # target_link_libraries(test-tokenizer-1-spm PRIVATE common) + # install(TARGETS test-tokenizer-1-spm RUNTIME) - llama_test(test-tokenizer-1-spm NAME test-tokenizer-1-llama-spm ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama-spm.gguf) + # llama_test(test-tokenizer-1-spm NAME test-tokenizer-1-llama-spm ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama-spm.gguf) #llama_test(test-tokenizer-1-spm NAME test-tokenizer-1-baichuan ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-baichuan.gguf) # llama_target_and_test(test-double-float.cpp) # SLOW diff --git a/tests/test-grammar-integration.cpp b/tests/test-grammar-integration.cpp index 503eb43a..b6bf394c 100644 --- a/tests/test-grammar-integration.cpp +++ b/tests/test-grammar-integration.cpp @@ -44,13 +44,66 @@ static bool test_build_grammar_fails(const std::string & grammar_str) { return grammar_fails; } +struct token_and_piece { + llama_token token; + std::string piece; +}; + +// token() encodes a 32-bit ID as 5 bytes: a 0xff marker followed by the ID in big-endian order. +static std::string token(llama_token id) { + return std::string{ + static_cast<char>(0xff), + static_cast<char>((id >> 24) & 0xff), + static_cast<char>((id >> 16) & 0xff), + static_cast<char>((id >> 8) & 0xff), + static_cast<char>(id & 0xff) + }; +} + +// parse_tokens() parses the token encodings above and UTF-8 text.
+static std::vector<token_and_piece> parse_tokens(const std::string & input) { + std::vector<token_and_piece> result; + result.reserve(input.size()); + size_t offset = 0; + while (offset < input.size()) { + try { + if (static_cast<uint8_t>(input[offset]) == 0xff) { + if (offset + 5 > input.size()) { + throw std::runtime_error("not enough bytes for token id"); + } + uint32_t val = + (static_cast<uint8_t>(input[offset + 1]) << 24) | + (static_cast<uint8_t>(input[offset + 2]) << 16) | + (static_cast<uint8_t>(input[offset + 3]) << 8) | + (static_cast<uint8_t>(input[offset + 4])); + auto piece = "<[" + std::to_string(val) + "]>"; + result.push_back({static_cast<llama_token>(val), piece}); + offset += 5; + } else { + uint32_t cpt = unicode_cpt_from_utf8(input, offset); + result.push_back({0, unicode_cpt_to_utf8(cpt)}); + } + } catch (const std::invalid_argument & /*ex*/) { + // Silently ignore invalid UTF-8 input to avoid leaking the exception beyond llama_tokenize + ++offset; + result.push_back({0, unicode_cpt_to_utf8(0xFFFD)}); // replacement character + } + } + return result; +} + static bool match_string(const std::string & input, llama_grammar * grammar) { - auto decoded = decode_utf8(input, {}); + const auto parsed = parse_tokens(input); const auto & code_points = decoded.first; - const llama_grammar_rules & rules = llama_grammar_get_rules (grammar); - llama_grammar_stacks & cur_stacks = llama_grammar_get_stacks(grammar); + for (const auto & in : parsed) { + try { + llama_grammar_accept_token(*grammar, in.token, in.piece); + } catch (const std::runtime_error & /*e*/) { + // normally this shouldn't get hit because of llama_grammar_apply + return false; + } for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) { const llama_grammar_stacks prev_stacks = llama_grammar_get_stacks(grammar); // copy @@ -443,6 +496,30 @@ static void test_simple_grammar() { "12a45", } ); + + // Test case for a simple grammar with tokens + test_grammar( + "simple grammar with tokens", + R"""( + root ::= <[10]> content <[11]> + content ::= (!<[11]>)*)""", + // Passing strings + { + token(10) + "hello world" + token(11), + token(10) + "text with " + token(12) + " other tokens " + token(13) + " mixed in" + token(11), + token(10) + token(11), + token(10) + token(12) + token(13) + token(14) + token(15) + token(11), + token(10) + "a" + token(11), + }, + // Failing strings + { + token(10) + "missing end token", + token(10), + "missing start token" + token(11), + token(10) + token(11) + token(11), // double end token + token(11) + "wrong order" + token(10), + } + ); } static void test_complex_grammar() { @@ -504,6 +581,34 @@ static void test_complex_grammar() { "123+456*789-123/456+789*123-456/789+123*456-789/123+456*789-123/456+789*123-456/", } ); + + // Test case for a more complex grammar with tokens + test_grammar( + "complex grammar with tokens", + R"""( + root ::= reasoning+ content tool-call* + reasoning ::= <[10]> (!<[11]>)* <[11]> + content ::= <[20]> (!<[21]>)* <[21]> + tool-call ::= <[12]> name <[13]> args <[14]> + name ::= (!<[13]>)+ + args ::= (!<[14]>)*)""", + // Passing strings + { + token(10) + "I am thinking" + token(11) + token(20) + "hello world!"
+ token(21) + token(12) + "search" + token(13) + "query=test" + token(14), + token(10) + "reasoning 1" + token(11) + token(10) + "reasoning 2" + token(11) + token(20) + token(21) + token(12) + "tool" + token(13) + token(14), + token(10) + token(11) + token(20) + "content" + token(21), + token(10) + "think" + token(12) + " nested" + token(11) + token(20) + token(10) + "more content" + token(21) + token(12) + "fn" + token(13) + "x=1,y=2" + token(14) + token(12) + "fn2" + token(13) + token(14), + token(10) + "reasoning" + token(11) + token(10) + "more" + token(11) + token(10) + "even more" + token(11) + token(20) + "text" + token(21) + token(12) + "a" + token(13) + "b" + token(14) + token(12) + "c" + token(13) + "d" + token(14), + }, + // Failing strings + { + token(20) + "content only" + token(21), + token(10) + "no closing reasoning", + token(10) + token(11) + token(20) + "no closing content", + token(10) + token(11) + token(20) + token(21) + token(12) + "incomplete tool", + token(10) + token(11) + token(11) + token(20) + token(21), + } + ); } static void test_special_chars() { diff --git a/tests/test-grammar-parser.cpp b/tests/test-grammar-parser.cpp index 5df5abb2..68a38639 100644 --- a/tests/test-grammar-parser.cpp +++ b/tests/test-grammar-parser.cpp @@ -511,5 +511,19 @@ int main() {LLAMA_GRETYPE_END, 0}, }); + // <[1000]> = "" + // <[1001]> = "" + verify_parsing(R"""( + root ::= <[1000]> !<[1001]> <[1001]> + )""", { + {"root", 0} + }, { + // root (index 0) + {LLAMA_GRETYPE_TOKEN, 1000}, + {LLAMA_GRETYPE_TOKEN_NOT, 1001}, + {LLAMA_GRETYPE_TOKEN, 1001}, + {LLAMA_GRETYPE_END, 0}, + }); + return 0; } diff --git a/tests/test-llama-grammar.cpp b/tests/test-llama-grammar.cpp index 1f3a267b..3814f124 100644 --- a/tests/test-llama-grammar.cpp +++ b/tests/test-llama-grammar.cpp @@ -204,7 +204,7 @@ int main() uint32_t *cp = new uint32_t[2]; // dynamically allocate memory for code_point cp[0] = 37 + i; cp[1] = 0; - next_candidates[i] = {i, cp, {}}; + next_candidates[i] = {i, cp, {}, 0}; } std::vector>> expected_reject = { diff --git a/tests/test-regex-partial.cpp b/tests/test-regex-partial.cpp index ffad1897..70af6d75 100644 --- a/tests/test-regex-partial.cpp +++ b/tests/test-regex-partial.cpp @@ -232,52 +232,52 @@ static void test_regex_to_reversed_partial_regex() { printf("[%s]\n", __func__); assert_equals( - "((?:(?:c)?b)?a)[\\s\\S]*", + "^((?:(?:c)?b)?a)", regex_to_reversed_partial_regex("abc")); assert_equals( - "(a+)[\\s\\S]*", + "^(a+)", regex_to_reversed_partial_regex("a+")); assert_equals( - "(a*)[\\s\\S]*", + "^(a*)", regex_to_reversed_partial_regex("a*")); assert_equals( - "(a?)[\\s\\S]*", + "^(a?)", regex_to_reversed_partial_regex("a?")); assert_equals( - "([a-z])[\\s\\S]*", + "^([a-z])", regex_to_reversed_partial_regex("[a-z]")); assert_equals( - "((?:\\w+)?[a-z])[\\s\\S]*", + "^((?:\\w+)?[a-z])", regex_to_reversed_partial_regex("[a-z]\\w+")); assert_equals( - "((?:a|b))[\\s\\S]*", + "^((?:a|b))", regex_to_reversed_partial_regex("(?:a|b)")); assert_equals( - "((?:(?:(?:d)?c)?b)?a)[\\s\\S]*", + "^((?:(?:(?:d)?c)?b)?a)", regex_to_reversed_partial_regex("abcd")); assert_equals( - "((?:b)?a*)[\\s\\S]*", // TODO: ((?:b)?a*+).* ?? + "^((?:b)?a*)", // TODO: ((?:b)?a*+).* ?? 
regex_to_reversed_partial_regex("a*b")); assert_equals( - "((?:(?:b)?a)?.*)[\\s\\S]*", + "^((?:(?:b)?a)?.*)", regex_to_reversed_partial_regex(".*?ab")); assert_equals( - "((?:(?:b)?.*)?a)[\\s\\S]*", + "^((?:(?:b)?.*)?a)", regex_to_reversed_partial_regex("a.*?b")); assert_equals( - "((?:(?:d)?(?:(?:c)?b))?a)[\\s\\S]*", + "^((?:(?:d)?(?:(?:c)?b))?a)", regex_to_reversed_partial_regex("a(bc)d")); assert_equals( - "((?:(?:(?:c)?b|(?:e)?d))?a)[\\s\\S]*", + "^((?:(?:(?:c)?b|(?:e)?d))?a)", regex_to_reversed_partial_regex("a(bc|de)")); assert_equals( - "((?:(?:(?:(?:(?:c)?b?)?b?)?b)?b)?a)[\\s\\S]*", + "^((?:(?:(?:(?:(?:c)?b?)?b?)?b)?b)?a)", regex_to_reversed_partial_regex("ab{2,4}c")); }
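Reviewer note (illustrative only, not part of the patch): the new <[token-id]> and !<[token-id]> terminals can be exercised with the token()/test_grammar() helpers that this patch adds to tests/test-grammar-integration.cpp. The sketch below is a hypothetical extra test case, assuming token IDs 100 and 101 stand in for an opening/closing special-token pair; it follows the same accept-token path as the cases above, but constrains the body with an ordinary character rule instead of a token-negation rule.

    // Hypothetical sketch (not in the patch): a token-delimited block whose body
    // is restricted by a character class. IDs 100/101 are assumed placeholders.
    test_grammar(
        "token-delimited block (sketch)",
        R"""(
            root ::= <[100]> text <[101]>
            text ::= [a-z ]*)""",
        // Passing strings
        {
            token(100) + "some reasoning" + token(101),
            token(100) + token(101), // empty body is allowed by text
        },
        // Failing strings
        {
            "some reasoning" + token(101),          // missing opening token
            token(100) + "some reasoning",          // missing closing token
            token(100) + "UPPER CASE" + token(101), // body violates the character rule
        }
    );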