mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-01-26 09:09:50 +00:00
* grammar : fix JSON Schema for string regex with top-level alt. (#9903) Prior to this commit, using a JSON Schema containing a string with `pattern` regular expression that uses top-level alternation (e.g. `"pattern": "^A|B|C|D$"`) would result in invalid JSON output from the constrained sampling grammar, because it ended up creating a grammar rule like this for the string: ``` thing ::= "\"" "A" | "B" | "C" | "D" "\"" space ``` Note that this rule will only match a starting quote for the "A" case, and will only match an ending quote for the "D" case, so this rule will always produce invalid JSON when used for sampling (that is, the JSON will always be lacking the starting quote, the ending quote, or both). This was fixed in a simple way by adding parentheses to the generated rule (for all string pattern rules, to keep it simple), such that the new generated rule looks like this (correct): ``` thing ::= "\"" ("A" | "B" | "C" | "D") "\"" space ``` * grammars : add English-only grammar (#10612) * grammar : handle maxItems == 0 in JSON schema (#13117) Co-authored-by: Richard Lyons <frob@cloudstaff.com> * grammar-parser : fix possible null-deref (#9004) Fixes: https://bugs.chromium.org/p/oss-fuzz/issues/detail?id=70680 Signed-off-by: David Korczynski <david@adalogics.com> * llama : fix typo in llama-grammar.h [no ci] (#11816) * * server: fix "--grammar-file" parameter (#12285) * common : use std::string_view now that we target c++17 (#14319) * json : support `enum` values within `allOf` (#15830) * grammar : use int64_t to avoid int overflows in int schema to grammar conversion logic (#16626) * grammar : support array references in json schema (#16792) * grammar : support array references in json schema * Update json-schema-to-grammar.cpp Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com> * grammar : improve regex when naming ref derived rules * grammar : replace non-conformant definitions array with anyOf test case --------- Co-authored-by: Sigbjørn 
Skjæret <sigbjorn.skjaeret@scala.com> # Conflicts: # tests/test-json-schema-to-grammar.cpp * merge fix * llama : minor grammar refactor (#10897) * llama: fix error on bad grammar (#12628) * grammar : fix integer overflow (#17381) * Fix DoS / integer overflow * Remove optional, use INT64_MAX instead as placeholder value (it's technically -1, so it fits :) * White space * Actually, since it's unsigned, use UINT64_MAX # Conflicts: # src/llama-grammar.cpp * grammar: fix regression caused by #17381 (#17412) * grammar: fix regression caused by #17381 * more readable # Conflicts: # src/llama-grammar.cpp * Merge Fix * Fix warnings --------- Signed-off-by: David Korczynski <david@adalogics.com> Co-authored-by: Joe Eli McIlvain <joe.eli.mac@gmail.com> Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> Co-authored-by: frob <rick+github@frob.com.au> Co-authored-by: Richard Lyons <frob@cloudstaff.com> Co-authored-by: DavidKorczynski <david@adalogics.com> Co-authored-by: Daniel Bevenius <daniel.bevenius@gmail.com> Co-authored-by: firecoperana <firecoperana> Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com> Co-authored-by: Aldehir Rojas <hello@alde.dev> Co-authored-by: Olivier Chafik <olivier.chafik@gmail.com> Co-authored-by: Piotr Wilkin (ilintar) <piotr.wilkin@syndatis.com> Co-authored-by: Xuan-Son Nguyen <son@huggingface.co> Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
543 lines
22 KiB
C++
543 lines
22 KiB
C++
#include "grammar-parser.h"
|
|
#include <cstdint>
|
|
#include <cwchar>
|
|
#include <string>
|
|
#include <utility>
|
|
#include <stdexcept>
|
|
#include <exception>
|
|
|
|
namespace grammar_parser {
|
|
// Decodes a single UTF-8 sequence starting at `src`.
//
// The sequence length is taken from the top four bits of the lead byte. The input is
// assumed to be valid UTF-8; the only safety check is that decoding stops early at a
// NUL terminator instead of reading past it.
//
// Returns the decoded code point and a pointer to the first byte after the bytes consumed.
// (copied from llama.cpp)
static std::pair<uint32_t, const char *> decode_utf8(const char * src) {
    // sequence length indexed by the high nibble of the lead byte
    static const int seq_len_by_high_nibble[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };

    const uint8_t lead      = static_cast<uint8_t>(*src);
    const int     n_bytes   = seq_len_by_high_nibble[lead >> 4];
    const uint8_t lead_mask = (1 << (8 - n_bytes)) - 1;

    uint32_t     code_point = lead & lead_mask;
    const char * cursor     = src + 1;

    // fold in the continuation bytes, six payload bits each
    int remaining = n_bytes - 1;
    while (remaining > 0 && *cursor) {
        code_point = (code_point << 6) + (static_cast<uint8_t>(*cursor) & 0x3F);
        ++cursor;
        --remaining;
    }
    return std::make_pair(code_point, cursor);
}
|
|
|
|
// Looks up the id of the symbol named by the `len` characters at `src`.
//
// A name that is not yet present in `state.symbol_ids` is inserted with an id equal to
// the current number of entries; the id stored for the name is returned either way.
static uint32_t get_symbol_id(parse_state & state, const char * src, size_t len) {
    // candidate id, only stored when emplace() actually inserts
    const uint32_t fresh_id = static_cast<uint32_t>(state.symbol_ids.size());
    const auto     inserted = state.symbol_ids.emplace(std::string(src, len), fresh_id);
    const auto &   entry    = *inserted.first;
    return entry.second;
}
|
|
|
|
// Registers a synthesized symbol named "<base_name>_<id>" and returns its id.
//
// The id is the number of entries in `state.symbol_ids` before the new name is stored.
static uint32_t generate_symbol_id(parse_state & state, const std::string & base_name) {
    const uint32_t    new_id           = static_cast<uint32_t>(state.symbol_ids.size());
    const std::string synthesized_name = base_name + '_' + std::to_string(new_id);
    state.symbol_ids[synthesized_name] = new_id;
    return new_id;
}
|
|
|
|
// Stores `rule` at index `rule_id` of `state.rules`, growing the table first when the
// index is not yet present. Any rule previously stored at that index is overwritten.
static void add_rule(
        parse_state                              & state,
        uint32_t                                   rule_id,
        const std::vector<llama_grammar_element> & rule) {
    auto & rules = state.rules;

    const bool slot_exists = rule_id < rules.size();
    if (!slot_exists) {
        rules.resize(rule_id + 1);
    }
    rules[rule_id] = rule;
}
|
|
|
|
// True when `c` is one of the ASCII decimal digits '0'..'9'.
static bool is_digit_char(char c) {
    const bool below_zero = c < '0';
    const bool above_nine = c > '9';
    return !(below_zero || above_nine);
}
|
|
|
|
static bool is_word_char(char c) {
|
|
return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || is_digit_char(c);
|
|
}
|
|
|
|
// Parses exactly `size` hexadecimal digits (either case) starting at `src`.
//
// Returns the accumulated value and a pointer just past the last digit.
// Throws std::runtime_error when fewer than `size` hex digits are available, either
// because the input ends or because a non-hex character is encountered.
static std::pair<uint32_t, const char *> parse_hex(const char * src, int size) {
    // value of one hex digit, or -1 when `ch` is not a hex digit
    const auto nibble_of = [](char ch) -> int {
        if ('0' <= ch && ch <= '9') {
            return ch - '0';
        }
        if ('a' <= ch && ch <= 'f') {
            return ch - 'a' + 10;
        }
        if ('A' <= ch && ch <= 'F') {
            return ch - 'A' + 10;
        }
        return -1;
    };

    uint32_t           accum  = 0;
    const char *       cursor = src;
    const char * const stop   = src + size;

    while (cursor < stop && *cursor) {
        const int nibble = nibble_of(*cursor);
        if (nibble < 0) {
            break;
        }
        accum = (accum << 4) + nibble;
        ++cursor;
    }

    // stopping early means the escape was too short
    if (cursor != stop) {
        throw std::runtime_error("expecting " + std::to_string(size) + " hex chars at " + src);
    }
    return std::make_pair(accum, cursor);
}
|
|
|
|
// Skips blanks (space, tab) and `#` comments starting at `src`.
//
// A comment runs up to, but does not include, the next '\r', '\n' or NUL. Line breaks
// themselves are skipped only when `newline_ok` is true.
// Returns a pointer to the first character that was not skipped.
static const char * parse_space(const char * src, bool newline_ok) {
    const char * cursor = src;
    for (;;) {
        const char ch = *cursor;

        if (ch == '#') {
            // consume the comment body; the line terminator is handled by the next iteration
            do {
                ++cursor;
            } while (*cursor != '\0' && *cursor != '\r' && *cursor != '\n');
            continue;
        }

        const bool is_blank      = ch == ' ' || ch == '\t';
        const bool is_line_break = ch == '\r' || ch == '\n';
        if (is_blank || (newline_ok && is_line_break)) {
            ++cursor;
            continue;
        }

        return cursor;
    }
}
|
|
|
|
static const char * parse_name(const char * src) {
|
|
const char * pos = src;
|
|
while (is_word_char(*pos)) {
|
|
pos++;
|
|
}
|
|
if (pos == src) {
|
|
throw std::runtime_error(std::string("expecting name at ") + src);
|
|
}
|
|
return pos;
|
|
}
|
|
|
|
static const char * parse_int(const char * src) {
|
|
const char * pos = src;
|
|
while (is_digit_char(*pos)) {
|
|
pos++;
|
|
}
|
|
if (pos == src) {
|
|
throw std::runtime_error(std::string("expecting integer at ") + src);
|
|
}
|
|
return pos;
|
|
}
|
|
|
|
static std::pair<uint32_t, const char *> parse_char(const char * src) {
|
|
if (*src == '\\') {
|
|
switch (src[1]) {
|
|
case 'x': return parse_hex(src + 2, 2);
|
|
case 'u': return parse_hex(src + 2, 4);
|
|
case 'U': return parse_hex(src + 2, 8);
|
|
case 't': return std::make_pair('\t', src + 2);
|
|
case 'r': return std::make_pair('\r', src + 2);
|
|
case 'n': return std::make_pair('\n', src + 2);
|
|
case '\\':
|
|
case '"':
|
|
case '[':
|
|
case ']':
|
|
return std::make_pair(src[1], src + 2);
|
|
default:
|
|
throw std::runtime_error(std::string("unknown escape at ") + src);
|
|
}
|
|
} else if (*src) {
|
|
return decode_utf8(src);
|
|
}
|
|
throw std::runtime_error("unexpected end of input");
|
|
}
|
|
|
|
// Forward declaration: parse_sequence() calls parse_alternates() to handle parenthesized
// groups, while parse_alternates() itself is built on parse_sequence().
const char * parse_alternates(
        parse_state       & state,
        const char        * src,
        const std::string & rule_name,
        uint32_t            rule_id,
        bool                is_nested);
|
|
|
|
// Converts the decimal digit run [first, last) of a `{m}` / `{m,n}` repetition bound to an int.
//
// Throws std::runtime_error when the value is larger than INT32_MAX; without this check the
// value would be silently narrowed and could turn into 0 or into a negative number, which
// the repetition logic interprets as "no upper bound". std::stoul itself throws
// std::out_of_range when the digits do not even fit in unsigned long.
static int parse_repetition_count(const char * first, const char * last) {
    static_assert(sizeof(int) >= sizeof(int32_t), "repetition counts are stored in int");
    const unsigned long parsed = std::stoul(std::string(first, last - first));
    if (parsed > static_cast<unsigned long>(INT32_MAX)) {
        throw std::runtime_error(std::string("repetition count out of range at ") + first);
    }
    return static_cast<int>(parsed);
}

// Parses one sequence (the items between `|` separators) starting at `src` and appends the
// resulting grammar elements to `out_elements`.
//
//   state        - parse state; receives symbol ids and any synthesized helper rules
//   src          - current position in the grammar text
//   rule_name    - name of the rule being parsed, used as prefix for synthesized rule names
//   out_elements - output vector the elements of this sequence are appended to
//   is_nested    - forwarded to parse_space(): when true, line breaks between items are skipped
//
// Recognized items: "literal", [char ranges], rule references, ( groups ), `.` and the
// repetition operators `*`, `+`, `?`, `{m}`, `{m,}`, `{m,n}` applied to the preceding item.
//
// Returns the position of the first character that does not start an item.
// Throws std::runtime_error on malformed input.
static const char * parse_sequence(
        parse_state                        & state,
        const char                         * src,
        const std::string                  & rule_name,
        std::vector<llama_grammar_element> & out_elements,
        bool                                 is_nested) {
    // index in out_elements where the most recently parsed item begins
    size_t last_sym_start = out_elements.size();
    const char * pos = src;

    // Applies a repetition operator to the previous item. A negative max_times means
    // "no upper bound".
    // NOTE(review): min_times copies of the item are emitted inline, so very large counts
    // still produce proportionally large rules.
    auto handle_repetitions = [&](int min_times, int max_times) {

        if (last_sym_start == out_elements.size()) {
            throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos);
        }

        // apply transformation to previous symbol (last_sym_start to end) according to
        // the following rewrite rules:
        // S{m,n} --> S S S (m times) S'(n-m)
        //            S'(x)   ::= S S'(x-1) |
        //            (... n-m definitions of these S' rules ...)
        //            S'(1)   ::= S |
        // S{m,} -->  S S S (m times) S'
        //            S'     ::= S S' |
        // S*     --> S{0,}
        //        --> S'     ::= S S' |
        // S+     --> S{1,}
        //        --> S S'
        //            S'     ::= S S' |
        // S?     --> S{0,1}
        //        --> S'
        //            S'     ::= S |

        std::vector<llama_grammar_element> previous_elements(out_elements.begin() + last_sym_start, out_elements.end());
        if (min_times == 0) {
            // the item becomes entirely optional: drop the inline copy
            out_elements.resize(last_sym_start);
        } else {
            // Repeat the previous elements (min_times - 1) times
            for (int i = 1; i < min_times; i++) {
                out_elements.insert(out_elements.end(), previous_elements.begin(), previous_elements.end());
            }
        }

        uint32_t last_rec_rule_id = 0;
        // number of optional helper rules: one self-recursive rule when unbounded,
        // otherwise one per optional occurrence
        auto n_opt = max_times < 0 ? 1 : max_times - min_times;

        std::vector<llama_grammar_element> rec_rule(previous_elements);
        for (int i = 0; i < n_opt; i++) {
            rec_rule.resize(previous_elements.size());
            uint32_t rec_rule_id = generate_symbol_id(state, rule_name);
            if (i > 0 || max_times < 0) {
                // unbounded: refer to itself; bounded: chain to the previously generated rule
                rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, max_times < 0 ? rec_rule_id : last_rec_rule_id});
            }
            rec_rule.push_back({LLAMA_GRETYPE_ALT, 0});
            rec_rule.push_back({LLAMA_GRETYPE_END, 0});
            add_rule(state, rec_rule_id, rec_rule);
            last_rec_rule_id = rec_rule_id;
        }
        if (n_opt > 0) {
            out_elements.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id});
        }
    };

    while (*pos) {
        if (*pos == '"') { // literal string
            pos++;
            last_sym_start = out_elements.size();
            while (*pos != '"') {
                if (!*pos) {
                    throw std::runtime_error("unexpected end of input");
                }
                auto char_pair = parse_char(pos);
                pos = char_pair.second;
                out_elements.push_back({LLAMA_GRETYPE_CHAR, char_pair.first});
            }
            pos = parse_space(pos + 1, is_nested);
        } else if (*pos == '[') { // char range(s)
            pos++;
            enum llama_gretype start_type = LLAMA_GRETYPE_CHAR;
            if (*pos == '^') {
                pos++;
                start_type = LLAMA_GRETYPE_CHAR_NOT;
            }
            last_sym_start = out_elements.size();
            while (*pos != ']') {
                if (!*pos) {
                    throw std::runtime_error("unexpected end of input");
                }
                auto char_pair = parse_char(pos);
                pos = char_pair.second;
                // first member carries the class type, later members are alternatives
                enum llama_gretype type = last_sym_start < out_elements.size()
                    ? LLAMA_GRETYPE_CHAR_ALT
                    : start_type;

                out_elements.push_back({type, char_pair.first});
                // a '-' directly before ']' is not a range operator
                if (pos[0] == '-' && pos[1] != ']') {
                    if (!pos[1]) {
                        throw std::runtime_error("unexpected end of input");
                    }
                    auto endchar_pair = parse_char(pos + 1);
                    pos = endchar_pair.second;
                    out_elements.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first});
                }
            }
            pos = parse_space(pos + 1, is_nested);
        } else if (is_word_char(*pos)) { // rule reference
            const char * name_end = parse_name(pos);
            uint32_t ref_rule_id = get_symbol_id(state, pos, name_end - pos);
            pos = parse_space(name_end, is_nested);
            last_sym_start = out_elements.size();
            out_elements.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id});
        } else if (*pos == '(') { // grouping
            // parse nested alternates into synthesized rule
            pos = parse_space(pos + 1, true);
            uint32_t sub_rule_id = generate_symbol_id(state, rule_name);
            pos = parse_alternates(state, pos, rule_name, sub_rule_id, true);
            last_sym_start = out_elements.size();
            // output reference to synthesized rule
            out_elements.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
            if (*pos != ')') {
                throw std::runtime_error(std::string("expecting ')' at ") + pos);
            }
            pos = parse_space(pos + 1, is_nested);
        } else if (*pos == '.') { // any char
            last_sym_start = out_elements.size();
            out_elements.push_back({LLAMA_GRETYPE_CHAR_ANY, 0});
            pos = parse_space(pos + 1, is_nested);
        } else if (*pos == '*') {
            pos = parse_space(pos + 1, is_nested);
            handle_repetitions(0, -1);
        } else if (*pos == '+') {
            pos = parse_space(pos + 1, is_nested);
            handle_repetitions(1, -1);
        } else if (*pos == '?') {
            pos = parse_space(pos + 1, is_nested);
            handle_repetitions(0, 1);
        } else if (*pos == '{') {
            pos = parse_space(pos + 1, is_nested);

            if (!is_digit_char(*pos)) {
                throw std::runtime_error(std::string("expecting an int at ") + pos);
            }
            const char * int_end = parse_int(pos);
            // range-checked: rejects counts that do not fit in int instead of wrapping
            const int min_times = parse_repetition_count(pos, int_end);
            pos = parse_space(int_end, is_nested);

            int max_times = -1; // -1: no upper bound

            if (*pos == '}') {
                // {m}: exactly m repetitions
                max_times = min_times;
                pos = parse_space(pos + 1, is_nested);
            } else if (*pos == ',') {
                pos = parse_space(pos + 1, is_nested);

                if (is_digit_char(*pos)) {
                    const char * max_end = parse_int(pos);
                    max_times = parse_repetition_count(pos, max_end);
                    pos = parse_space(max_end, is_nested);
                }

                if (*pos != '}') {
                    throw std::runtime_error(std::string("expecting '}' at ") + pos);
                }
                pos = parse_space(pos + 1, is_nested);
            } else {
                throw std::runtime_error(std::string("expecting ',' at ") + pos);
            }
            handle_repetitions(min_times, max_times);
        } else {
            // not the start of an item: the sequence ends here
            break;
        }
    }
    return pos;
}
|
|
|
|
// Parses one or more `|`-separated sequences starting at `src` and stores the result as
// rule `rule_id`.
//
// The sequences are concatenated into a single element list, separated by
// LLAMA_GRETYPE_ALT and terminated by LLAMA_GRETYPE_END. `rule_name` and `is_nested` are
// forwarded to parse_sequence().
//
// Returns the position following the last sequence.
const char * parse_alternates(
        parse_state       & state,
        const char        * src,
        const std::string & rule_name,
        uint32_t            rule_id,
        bool                is_nested) {
    std::vector<llama_grammar_element> alternatives;

    const char * cursor = parse_sequence(state, src, rule_name, alternatives, is_nested);
    for (; *cursor == '|'; ) {
        alternatives.push_back({LLAMA_GRETYPE_ALT, 0});
        // line breaks are always allowed right after '|'
        const char * next_start = parse_space(cursor + 1, true);
        cursor = parse_sequence(state, next_start, rule_name, alternatives, is_nested);
    }

    alternatives.push_back({LLAMA_GRETYPE_END, 0});
    add_rule(state, rule_id, alternatives);
    return cursor;
}
|
|
|
|
// Parses a single `name ::= alternates` rule definition starting at `src`.
//
// The rule must be followed by a line break ("\r", "\r\n" or "\n") or the end of input.
// Returns the position of the next rule, with trailing whitespace and comments skipped.
// Throws std::runtime_error when the name, the `::=` operator or the line end is missing.
static const char * parse_rule(parse_state & state, const char * src) {
    const char * const name_end = parse_name(src);
    const size_t       name_len = name_end - src;
    const char *       pos      = parse_space(name_end, false);

    const uint32_t    rule_id = get_symbol_id(state, src, name_len);
    const std::string name(src, name_len);

    const bool missing_define_op = pos[0] != ':' || pos[1] != ':' || pos[2] != '=';
    if (missing_define_op) {
        throw std::runtime_error(std::string("expecting ::= at ") + pos);
    }
    pos = parse_space(pos + 3, true);

    pos = parse_alternates(state, pos, name, rule_id, false);

    // consume exactly one line terminator, if any
    switch (*pos) {
        case '\r':
            pos += pos[1] == '\n' ? 2 : 1;
            break;
        case '\n':
            ++pos;
            break;
        case '\0':
            break;
        default:
            throw std::runtime_error(std::string("expecting newline or end at ") + pos);
    }
    return parse_space(pos, true);
}
|
|
|
|
// Parses the complete grammar text `src` into a parse_state.
//
// After all rules are read, every rule slot must be non-empty and every rule reference
// must point at a defined rule. On success the returned state has `success == true`.
// Any std::exception raised while parsing or validating is reported on stderr, and a
// fresh state with `success == false` is returned instead.
parse_state parse(const char * src) {
    try {
        parse_state state;

        for (const char * pos = parse_space(src, true); *pos; ) {
            pos = parse_rule(state, pos);
        }

        // Validate the state to ensure that all rules are defined
        for (const auto & rule : state.rules) {
            if (rule.empty()) {
                throw std::runtime_error("Undefined rule");
            }
            for (const auto & elem : rule) {
                if (elem.type != LLAMA_GRETYPE_RULE_REF) {
                    continue;
                }
                // Ensure that the rule at that location exists
                const bool is_defined =
                    elem.value < state.rules.size() && !state.rules[elem.value].empty();
                if (is_defined) {
                    continue;
                }
                // Get the name of the rule that is missing
                for (const auto & kv : state.symbol_ids) {
                    if (kv.second == elem.value) {
                        throw std::runtime_error("Undefined rule identifier '" + kv.first + "'");
                    }
                }
            }
        }

        state.success = true;
        return state;
    } catch (const std::exception & err) {
        fprintf(stderr, "%s: error parsing grammar: %s\n\n%s\n", __func__, err.what(), src);
        parse_state failed;
        failed.success = false;
        return failed;
    }
}
|
|
|
|
// Writes the character value `c` to `file`: values in 0x20..0x7f are written as a single
// byte, everything else as a "<U+XXXX>" placeholder.
static void print_grammar_char(FILE * file, uint32_t c) {
    const bool in_ascii_range = c >= 0x20 && c <= 0x7f;
    if (!in_ascii_range) {
        // cop out of encoding UTF-8
        fprintf(file, "<U+%04X>", c);
        return;
    }
    fprintf(file, "%c", static_cast<char>(c));
}
|
|
|
|
// True when `elem` is one of the character-matching element types
// (CHAR, CHAR_NOT, CHAR_ALT, CHAR_RNG_UPPER or CHAR_ANY).
static bool is_char_element(llama_grammar_element elem) {
    const auto kind = elem.type;
    return kind == LLAMA_GRETYPE_CHAR
        || kind == LLAMA_GRETYPE_CHAR_NOT
        || kind == LLAMA_GRETYPE_CHAR_ALT
        || kind == LLAMA_GRETYPE_CHAR_RNG_UPPER
        || kind == LLAMA_GRETYPE_CHAR_ANY;
}
|
|
|
|
static void print_rule_binary(FILE * file, const std::vector<llama_grammar_element> & rule) {
|
|
for (auto elem : rule) {
|
|
switch (elem.type) {
|
|
case LLAMA_GRETYPE_END: fprintf(file, "END"); break;
|
|
case LLAMA_GRETYPE_ALT: fprintf(file, "ALT"); break;
|
|
case LLAMA_GRETYPE_RULE_REF: fprintf(file, "RULE_REF"); break;
|
|
case LLAMA_GRETYPE_CHAR: fprintf(file, "CHAR"); break;
|
|
case LLAMA_GRETYPE_CHAR_NOT: fprintf(file, "CHAR_NOT"); break;
|
|
case LLAMA_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break;
|
|
case LLAMA_GRETYPE_CHAR_ALT: fprintf(file, "CHAR_ALT"); break;
|
|
case LLAMA_GRETYPE_CHAR_ANY: fprintf(file, "CHAR_ANY"); break;
|
|
}
|
|
switch (elem.type) {
|
|
case LLAMA_GRETYPE_END:
|
|
case LLAMA_GRETYPE_ALT:
|
|
case LLAMA_GRETYPE_RULE_REF:
|
|
fprintf(file, "(%u) ", elem.value);
|
|
break;
|
|
case LLAMA_GRETYPE_CHAR:
|
|
case LLAMA_GRETYPE_CHAR_NOT:
|
|
case LLAMA_GRETYPE_CHAR_RNG_UPPER:
|
|
case LLAMA_GRETYPE_CHAR_ALT:
|
|
case LLAMA_GRETYPE_CHAR_ANY:
|
|
fprintf(file, "(\"");
|
|
print_grammar_char(file, elem.value);
|
|
fprintf(file, "\") ");
|
|
break;
|
|
}
|
|
}
|
|
fprintf(file, "\n");
|
|
}
|
|
|
|
static void print_rule(
|
|
FILE * file,
|
|
uint32_t rule_id,
|
|
const std::vector<llama_grammar_element> & rule,
|
|
const std::map<uint32_t, std::string> & symbol_id_names) {
|
|
if (rule.empty() || rule.back().type != LLAMA_GRETYPE_END) {
|
|
throw std::runtime_error(
|
|
"malformed rule, does not end with LLAMA_GRETYPE_END: " + std::to_string(rule_id));
|
|
}
|
|
fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str());
|
|
for (size_t i = 0, end = rule.size() - 1; i < end; i++) {
|
|
llama_grammar_element elem = rule[i];
|
|
switch (elem.type) {
|
|
case LLAMA_GRETYPE_END:
|
|
throw std::runtime_error(
|
|
"unexpected end of rule: " + std::to_string(rule_id) + "," +
|
|
std::to_string(i));
|
|
case LLAMA_GRETYPE_ALT:
|
|
fprintf(file, "| ");
|
|
break;
|
|
case LLAMA_GRETYPE_RULE_REF:
|
|
fprintf(file, "%s ", symbol_id_names.at(elem.value).c_str());
|
|
break;
|
|
case LLAMA_GRETYPE_CHAR:
|
|
fprintf(file, "[");
|
|
print_grammar_char(file, elem.value);
|
|
break;
|
|
case LLAMA_GRETYPE_CHAR_NOT:
|
|
fprintf(file, "[^");
|
|
print_grammar_char(file, elem.value);
|
|
break;
|
|
case LLAMA_GRETYPE_CHAR_RNG_UPPER:
|
|
if (i == 0 || !is_char_element(rule[i - 1])) {
|
|
throw std::runtime_error(
|
|
"LLAMA_GRETYPE_CHAR_RNG_UPPER without preceding char: " +
|
|
std::to_string(rule_id) + "," + std::to_string(i));
|
|
}
|
|
fprintf(file, "-");
|
|
print_grammar_char(file, elem.value);
|
|
break;
|
|
case LLAMA_GRETYPE_CHAR_ALT:
|
|
if (i == 0 || !is_char_element(rule[i - 1])) {
|
|
throw std::runtime_error(
|
|
"LLAMA_GRETYPE_CHAR_ALT without preceding char: " +
|
|
std::to_string(rule_id) + "," + std::to_string(i));
|
|
}
|
|
print_grammar_char(file, elem.value);
|
|
break;
|
|
case LLAMA_GRETYPE_CHAR_ANY:
|
|
fprintf(file, ".");
|
|
break;
|
|
}
|
|
if (is_char_element(elem)) {
|
|
switch (rule[i + 1].type) {
|
|
case LLAMA_GRETYPE_CHAR_ALT:
|
|
case LLAMA_GRETYPE_CHAR_RNG_UPPER:
|
|
case LLAMA_GRETYPE_CHAR_ANY:
|
|
break;
|
|
default:
|
|
fprintf(file, "] ");
|
|
}
|
|
}
|
|
}
|
|
fprintf(file, "\n");
|
|
}
|
|
|
|
void print_grammar(FILE * file, const parse_state & state) {
|
|
try {
|
|
std::map<uint32_t, std::string> symbol_id_names;
|
|
for (const auto & kv : state.symbol_ids) {
|
|
symbol_id_names[kv.second] = kv.first;
|
|
}
|
|
for (size_t i = 0, end = state.rules.size(); i < end; i++) {
|
|
// fprintf(file, "%zu: ", i);
|
|
// print_rule_binary(file, state.rules[i]);
|
|
print_rule(file, uint32_t(i), state.rules[i], symbol_id_names);
|
|
// fprintf(file, "\n");
|
|
}
|
|
} catch (const std::exception & err) {
|
|
fprintf(stderr, "\n%s: error printing grammar: %s\n", __func__, err.what());
|
|
}
|
|
}
|
|
|
|
std::vector<const llama_grammar_element *> parse_state::c_rules() {
|
|
std::vector<const llama_grammar_element *> ret;
|
|
ret.reserve(rules.size());
|
|
for (const auto & rule : rules) {
|
|
ret.push_back(rule.data());
|
|
}
|
|
return ret;
|
|
}
|
|
}
|