diff --git a/common/grammar-parser.cpp b/common/grammar-parser.cpp index da378eee..d7b0fcba 100644 --- a/common/grammar-parser.cpp +++ b/common/grammar-parser.cpp @@ -369,6 +369,9 @@ namespace grammar_parser { } // Validate the state to ensure that all rules are defined for (const auto & rule : state.rules) { + if (rule.empty()) { + throw std::runtime_error("Undefined rule"); + } for (const auto & elem : rule) { if (elem.type == LLAMA_GRETYPE_RULE_REF) { // Ensure that the rule at that location exists diff --git a/common/json-schema-to-grammar.cpp b/common/json-schema-to-grammar.cpp index 11412c03..3c7b4cc8 100644 --- a/common/json-schema-to-grammar.cpp +++ b/common/json-schema-to-grammar.cpp @@ -1,4 +1,5 @@ #include "json-schema-to-grammar.h" +#include "common.h" #include #include #include @@ -19,6 +20,9 @@ static std::string repeat(const std::string & str, size_t n); static std::string build_repetition(const std::string & item_rule, int min_items, int max_items, const std::string & separator_rule = "") { auto has_max = max_items != std::numeric_limits::max(); + if (max_items == 0) { + return ""; + } if (min_items == 0 && max_items == 1) { return item_rule + "?"; } @@ -40,52 +44,9 @@ static std::string build_repetition(const std::string & item_rule, int min_items return result; } -/* Minimalistic replacement for std::string_view, which is only available from C++17 onwards */ -class string_view { - const std::string & _str; - const size_t _start; - const size_t _end; -public: - string_view(const std::string & str, size_t start = 0, size_t end = std::string::npos) : _str(str), _start(start), _end(end == std::string::npos ? 
str.length() : end) {} - - size_t size() const { - return _end - _start; - } - - size_t length() const { - return size(); - } - - operator std::string() const { - return str(); - } - - std::string str() const { - return _str.substr(_start, _end - _start); - } - - string_view substr(size_t pos, size_t len = std::string::npos) const { - return string_view(_str, _start + pos, len == std::string::npos ? _end : _start + pos + len); - } - - char operator[](size_t pos) const { - auto index = _start + pos; - if (index >= _end) { - throw std::out_of_range("string_view index out of range"); - } - return _str[_start + pos]; - } - - bool operator==(const string_view & other) const { - std::string this_str = *this; - std::string other_str = other; - return this_str == other_str; - } -}; - -static void _build_min_max_int(int min_value, int max_value, std::stringstream & out, int decimals_left = 16, bool top_level = true) { - auto has_min = min_value != std::numeric_limits::min(); - auto has_max = max_value != std::numeric_limits::max(); +static void _build_min_max_int(int64_t min_value, int64_t max_value, std::stringstream & out, int decimals_left = 16, bool top_level = true) { + auto has_min = min_value != std::numeric_limits::min(); + auto has_max = max_value != std::numeric_limits::max(); auto digit_range = [&](char from, char to) { out << "["; @@ -111,14 +72,14 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream & } out << "}"; }; - std::function uniform_range = - [&](const string_view & from, const string_view & to) { + std::function uniform_range = + [&](const std::string_view & from, const std::string_view & to) { size_t i = 0; while (i < from.length() && i < to.length() && from[i] == to[i]) { i++; } if (i > 0) { - out << "\"" << from.substr(0, i).str() << "\""; + out << "\"" << from.substr(0, i) << "\""; } if (i < from.length() && i < to.length()) { if (i > 0) { @@ -201,7 +162,7 @@ static void _build_min_max_int(int min_value, int max_value, 
std::stringstream & if (has_min) { if (min_value < 0) { out << "\"-\" ("; - _build_min_max_int(std::numeric_limits::min(), -min_value, out, decimals_left, /* top_level= */ false); + _build_min_max_int(std::numeric_limits::min(), -min_value, out, decimals_left, /* top_level= */ false); out << ") | [0] | [1-9] "; more_digits(0, decimals_left - 1); } else if (min_value == 0) { @@ -236,7 +197,7 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream & } digit_range(c, c); out << " ("; - _build_min_max_int(std::stoi(min_s.substr(1)), std::numeric_limits::max(), out, less_decimals, /* top_level= */ false); + _build_min_max_int(std::stoll(min_s.substr(1)), std::numeric_limits::max(), out, less_decimals, /* top_level= */ false); out << ")"; if (c < '9') { out << " | "; @@ -258,7 +219,7 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream & _build_min_max_int(0, max_value, out, decimals_left, /* top_level= */ true); } else { out << "\"-\" ("; - _build_min_max_int(-max_value, std::numeric_limits::max(), out, decimals_left, /* top_level= */ false); + _build_min_max_int(-max_value, std::numeric_limits::max(), out, decimals_left, /* top_level= */ false); out << ")"; } return; @@ -615,7 +576,7 @@ private: } return join_seq(); }; - return _add_rule(name, "\"\\\"\" " + to_rule(transform()) + " \"\\\"\" space"); + return _add_rule(name, "\"\\\"\" (" + to_rule(transform()) + ") \"\\\"\" space"); } /* @@ -688,7 +649,10 @@ private: } std::string _resolve_ref(const std::string & ref) { - std::string ref_name = ref.substr(ref.find_last_of('/') + 1); + auto it = ref.find('#'); + std::string ref_fragment = it != std::string::npos ? 
ref.substr(it + 1) : ref; + static const std::regex nonalphanumeric_regex(R"([^a-zA-Z0-9-]+)"); + std::string ref_name = "ref" + std::regex_replace(ref_fragment, nonalphanumeric_regex, "-"); if (_rules.find(ref_name) == _rules.end() && _refs_being_resolved.find(ref) == _refs_being_resolved.end()) { _refs_being_resolved.insert(ref); json resolved = _refs[ref]; @@ -861,11 +825,24 @@ public: std::vector tokens = split(pointer, "/"); for (size_t i = 1; i < tokens.size(); ++i) { std::string sel = tokens[i]; - if (target.is_null() || !target.contains(sel)) { + if (target.is_object() && target.contains(sel)) { + target = target[sel]; + } else if (target.is_array()) { + size_t sel_index; + try { + sel_index = std::stoul(sel); + } catch (const std::invalid_argument & e) { + sel_index = target.size(); + } + if (sel_index >= target.size()) { + _errors.push_back("Error resolving ref " + ref + ": " + sel + " not in " + target.dump()); + return; + } + target = target[sel_index]; + } else { _errors.push_back("Error resolving ref " + ref + ": " + sel + " not in " + target.dump()); return; } - target = target[sel]; } _refs[ref] = target; } @@ -931,9 +908,10 @@ public: _build_object_rule( properties, required, name, schema.contains("additionalProperties") ? 
schema["additionalProperties"] : json())); - } else if ((schema_type.is_null() || schema_type == "object") && schema.contains("allOf")) { + } else if ((schema_type.is_null() || schema_type == "object" || schema_type == "string") && schema.contains("allOf")) { std::unordered_set required; std::vector> properties; + std::map enum_values; std::string hybrid_name = name; std::function add_component = [&](const json & comp_schema, bool is_required) { if (comp_schema.contains("$ref")) { @@ -945,6 +923,14 @@ public: required.insert(prop.key()); } } + } else if (comp_schema.contains("enum")) { + for (const auto & v : comp_schema["enum"]) { + const auto rule = _generate_constant_rule(v); + if (enum_values.find(rule) == enum_values.end()) { + enum_values[rule] = 0; + } + enum_values[rule] += 1; + } } else { // todo warning } @@ -958,6 +944,17 @@ public: add_component(t, true); } } + if (!enum_values.empty()) { + std::vector enum_intersection; + for (const auto & p : enum_values) { + if (p.second == schema["allOf"].size()) { + enum_intersection.push_back(p.first); + } + } + if (!enum_intersection.empty()) { + return _add_rule(rule_name, "(" + string_join(enum_intersection, " | ") + ") space"); + } + } return _add_rule(rule_name, _build_object_rule(properties, required, hybrid_name, json())); } else if ((schema_type.is_null() || schema_type == "array") && (schema.contains("items") || schema.contains("prefixItems"))) { json items = schema.contains("items") ? schema["items"] : schema["prefixItems"]; @@ -992,17 +989,17 @@ public: int max_len = schema.contains("maxLength") ? 
schema["maxLength"].get() : std::numeric_limits::max(); return _add_rule(rule_name, "\"\\\"\" " + build_repetition(char_rule, min_len, max_len) + " \"\\\"\" space"); } else if (schema_type == "integer" && (schema.contains("minimum") || schema.contains("exclusiveMinimum") || schema.contains("maximum") || schema.contains("exclusiveMaximum"))) { - int min_value = std::numeric_limits::min(); - int max_value = std::numeric_limits::max(); + int64_t min_value = std::numeric_limits::min(); + int64_t max_value = std::numeric_limits::max(); if (schema.contains("minimum")) { - min_value = schema["minimum"].get(); + min_value = schema["minimum"].get(); } else if (schema.contains("exclusiveMinimum")) { - min_value = schema["exclusiveMinimum"].get() + 1; + min_value = schema["exclusiveMinimum"].get() + 1; } if (schema.contains("maximum")) { - max_value = schema["maximum"].get(); + max_value = schema["maximum"].get(); } else if (schema.contains("exclusiveMaximum")) { - max_value = schema["exclusiveMaximum"].get() - 1; + max_value = schema["exclusiveMaximum"].get() - 1; } std::stringstream out; out << "("; diff --git a/common/sampling.cpp b/common/sampling.cpp index 300ac312..0e19de4c 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -22,7 +22,6 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_vocab* vo #endif // LLAMA_USE_LLGUIDANCE } else { - std::vector trigger_patterns; std::vector patterns_anywhere; std::vector trigger_tokens; @@ -70,30 +69,34 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_vocab* vo trigger_tokens.data(), trigger_tokens.size()) : llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root"); - // if there is a grammar, parse it - if (!params.grammar.empty()) { - result->parsed_grammar = grammar_parser::parse(params.grammar.c_str()); - if (result->parsed_grammar.success) { - // will be empty (default) if there are parse errors - if (result->parsed_grammar.rules.empty()) { - fprintf(stderr, 
"%s: failed to parse grammar\n", __func__); - delete result; - return nullptr; - } + //if (!grmr) { + // return nullptr; + //} - // Ensure that there is a "root" node. - if (result->parsed_grammar.symbol_ids.find("root") == result->parsed_grammar.symbol_ids.end()) { - fprintf(stderr, "%s: grammar does not contain a 'root' symbol\n", __func__); - delete result; - return nullptr; - } + // if there is a grammar, parse it + if (!params.grammar.empty()) { + result->parsed_grammar = grammar_parser::parse(params.grammar.c_str()); + if (result->parsed_grammar.success) { + // will be empty (default) if there are parse errors + if (result->parsed_grammar.rules.empty()) { + fprintf(stderr, "%s: failed to parse grammar\n", __func__); + delete result; + return nullptr; + } + + // Ensure that there is a "root" node. + if (result->parsed_grammar.symbol_ids.find("root") == result->parsed_grammar.symbol_ids.end()) { + fprintf(stderr, "%s: grammar does not contain a 'root' symbol\n", __func__); + delete result; + return nullptr; + } if (grmr == nullptr) { - throw std::runtime_error("Failed to initialize llama_grammar"); + throw std::runtime_error("Failed to initialize llama_grammar"); + } + } } - } - } - result->prev.resize(params.n_prev); - result->n_valid = 0; + result->prev.resize(params.n_prev); + result->n_valid = 0; } result->grammar = grmr; // init DRY diff --git a/examples/gbnf-validator/gbnf-validator.cpp b/examples/gbnf-validator/gbnf-validator.cpp index 48a705e1..df968413 100644 --- a/examples/gbnf-validator/gbnf-validator.cpp +++ b/examples/gbnf-validator/gbnf-validator.cpp @@ -13,22 +13,14 @@ #include static bool llama_sample_grammar_string(struct llama_grammar * grammar, const std::string & input_str, size_t & error_pos, std::string & error_msg) { - auto decoded = decode_utf8(input_str, {}); - const auto & code_points = decoded.first; - - const llama_grammar_rules & rules = llama_grammar_get_rules (grammar); - llama_grammar_stacks & cur_stacks = 
llama_grammar_get_stacks(grammar); - + const auto cpts = unicode_cpts_from_utf8(input_str); + auto& cur_stacks = llama_grammar_get_stacks(grammar); size_t pos = 0; - for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) { - const llama_grammar_stacks prev_stacks = llama_grammar_get_stacks(grammar); // copy - - llama_grammar_accept(rules, prev_stacks, *it, cur_stacks); - + for (const auto& cpt : cpts) { + llama_grammar_accept(grammar, cpt); if (cur_stacks.empty()) { error_pos = pos; - error_msg = "Unexpected character '" + unicode_cpt_to_utf8(*it) + "'"; - cur_stacks = prev_stacks; + error_msg = "Unexpected character '" + unicode_cpt_to_utf8(cpt) + "'"; return false; } ++pos; diff --git a/examples/json_schema_to_grammar.py b/examples/json_schema_to_grammar.py index 06fd1bfe..26989157 100755 --- a/examples/json_schema_to_grammar.py +++ b/examples/json_schema_to_grammar.py @@ -10,6 +10,9 @@ from typing import Any, List, Optional, Set, Tuple, Union def _build_repetition(item_rule, min_items, max_items, separator_rule=None): + if max_items == 0: + return "" + if min_items == 0 and max_items == 1: return f'{item_rule}?' 
@@ -368,8 +371,17 @@ class SchemaConverter: raise ValueError(f'Unsupported ref {ref}') for sel in ref.split('#')[-1].split('/')[1:]: - assert target is not None and sel in target, f'Error resolving ref {ref}: {sel} not in {target}' - target = target[sel] + assert target is not None, f'Error resolving ref {ref}: {sel} not in {target}' + if isinstance(target, list): + try: + sel_index = int(sel) + except ValueError: + raise ValueError(f'Error resolving ref {ref}: {sel} not in {target}') + assert 0 <= sel_index < len(target), f'Error resolving ref {ref}: {sel} not in {target}' + target = target[sel_index] + else: + assert sel in target, f'Error resolving ref {ref}: {sel} not in {target}' + target = target[sel] self._refs[ref] = target else: @@ -540,11 +552,12 @@ class SchemaConverter: return self._add_rule( name, to_rule(transform()) if self._raw_pattern \ - else "\"\\\"\" " + to_rule(transform()) + " \"\\\"\" space") + else "\"\\\"\" (" + to_rule(transform()) + ") \"\\\"\" space") def _resolve_ref(self, ref): - ref_name = ref.split('/')[-1] + ref_fragment = ref.split('#')[-1] + ref_name = 'ref' + re.sub(r'[^a-zA-Z0-9-]+', '-', ref_fragment) if ref_name not in self._rules and ref not in self._refs_being_resolved: self._refs_being_resolved.add(ref) resolved = self._refs[ref] @@ -583,9 +596,10 @@ class SchemaConverter: properties = list(schema.get('properties', {}).items()) return self._add_rule(rule_name, self._build_object_rule(properties, required, name, schema.get('additionalProperties'))) - elif schema_type in (None, 'object') and 'allOf' in schema: + elif schema_type in (None, 'object', 'string') and 'allOf' in schema: required = set() properties = [] + enum_sets = [] hybrid_name = name def add_component(comp_schema, is_required): if (ref := comp_schema.get('$ref')) is not None: @@ -597,6 +611,9 @@ class SchemaConverter: if is_required: required.add(prop_name) + if 'enum' in comp_schema: + enum_sets.append(set(comp_schema['enum'])) + for t in schema['allOf']: if 
'anyOf' in t: for tt in t['anyOf']: @@ -604,6 +621,15 @@ class SchemaConverter: else: add_component(t, is_required=True) + if enum_sets: + enum_intersection = enum_sets[0] + for s in enum_sets[1:]: + enum_intersection &= s + + if enum_intersection: + rule = '(' + ' | '.join((self._generate_constant_rule(v) for v in sorted(enum_intersection))) + ') space' + return self._add_rule(rule_name, rule) + return self._add_rule(rule_name, self._build_object_rule(properties, required, hybrid_name, additional_properties=None)) elif schema_type in (None, 'array') and ('items' in schema or 'prefixItems' in schema): diff --git a/examples/server/public_legacy/json-schema-to-grammar.mjs b/examples/server/public_legacy/json-schema-to-grammar.mjs index b12bf2ab..1d9dc510 100644 --- a/examples/server/public_legacy/json-schema-to-grammar.mjs +++ b/examples/server/public_legacy/json-schema-to-grammar.mjs @@ -345,10 +345,14 @@ export class SchemaConverter { const selectors = ref.split('#')[1].split('/').slice(1); for (const sel of selectors) { - if (!target || !(sel in target)) { + const selIndex = parseInt(sel, 10); + if (target && sel in target) { + target = target[sel]; + } else if (target && selIndex in target) { + target = target[selIndex]; + } else { throw new Error(`Error resolving ref ${ref}: ${sel} not in ${JSON.stringify(target)}`); } - target = target[sel]; } this._refs[ref] = target; @@ -594,7 +598,8 @@ export class SchemaConverter { } _resolveRef(ref) { - let refName = ref.split('/').pop(); + let refFragment = ref.split('#').pop(); + let refName = 'ref' + refFragment.replace(/[^a-zA-Z0-9-]+/g, '-'); if (!(refName in this._rules) && !this._refsBeingResolved.has(ref)) { this._refsBeingResolved.add(ref); const resolved = this._refs[ref]; @@ -631,9 +636,10 @@ export class SchemaConverter { const required = new Set(schema.required || []); const properties = Object.entries(schema.properties ?? 
{}); return this._addRule(ruleName, this._buildObjectRule(properties, required, name, schema.additionalProperties)); - } else if ((schemaType === undefined || schemaType === 'object') && 'allOf' in schema) { + } else if ((schemaType === undefined || schemaType === 'object' || schemaType === 'string') && 'allOf' in schema) { const required = new Set(); const properties = []; + const enumSets = []; const addComponent = (compSchema, isRequired) => { const ref = compSchema.$ref; if (ref !== undefined) { @@ -648,6 +654,10 @@ export class SchemaConverter { } } } + + if ('enum' in compSchema) { + enumSets.push(new Set(compSchema.enum || [])); + } }; for (const t of schema.allOf) { @@ -660,6 +670,14 @@ export class SchemaConverter { } } + if (enumSets.length > 0) { + const enumIntersection = new Set([...enumSets[0]].filter(v => enumSets.every(s => s.has(v)))); + if (enumIntersection.size > 0) { + const sortedEnums = [...enumIntersection].sort((a, b) => a.localeCompare(b)); + const rule = '(' + sortedEnums.map(v => this._generateConstantRule(v)).join(' | ') + ') space'; + return this._addRule(ruleName, rule); + } + } return this._addRule(ruleName, this._buildObjectRule(properties, required, name, null)); } else if ((schemaType === undefined || schemaType === 'array') && ('items' in schema || 'prefixItems' in schema)) { const items = schema.items ?? 
schema.prefixItems; diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 1a1e1b6c..50371d57 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -3779,7 +3779,7 @@ struct server_context { common_prefix prefix = slot.cache_tokens.get_common_prefix(ctx, prompt_tokens, true); // string level match common_prefix prefix_nonexact = slot.cache_tokens.get_common_prefix(ctx, prompt_tokens, false); auto n_past0 = slot.cache_tokens.get_common_prefix_exact(prompt_tokens); // token level match - LLAMA_LOG_INFO("======== Cache: cache_size = %ld, n_past0 = %ld, n_past1 = %ld, n_past_prompt1 = %ld, n_past2 = %ld, n_past_prompt2 = %ld\n", (int32_t) slot.cache_tokens.size(), (int32_t) n_past0, (int32_t) prefix.first, prefix.second, (int32_t) prefix_nonexact.first, (int32_t) prefix_nonexact.second); + LLAMA_LOG_INFO("======== Cache: cache_size = %d, n_past0 = %d, n_past1 = %d, n_past_prompt1 = %d, n_past2 = %d, n_past_prompt2 = %d\n", (int32_t) slot.cache_tokens.size(), (int32_t) n_past0, (int32_t) prefix.first, (int32_t)prefix.second, (int32_t) prefix_nonexact.first, (int32_t) prefix_nonexact.second); int32_t size_threshold = 20; if (prefix.first + size_threshold < prefix_nonexact.first) { LLAMA_LOG_WARN("Common part contains missing or extra space and new line\n"); @@ -4042,7 +4042,7 @@ struct server_context { slot.state = SLOT_STATE_PROCESSING; slot.command = SLOT_COMMAND_NONE; slot.release(); - LLAMA_LOG_INFO("n_past =% d\n", slot.cache_tokens.size()); + LLAMA_LOG_INFO("n_past = %d\n", (int)slot.cache_tokens.size()); send_error(slot, "Input prompt is too big compared to KV size. 
Please try increasing KV size."); } break; // break loop of n_batch diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index 2bc76472..fdd41e1b 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -913,7 +913,9 @@ static json oaicompat_chat_params_parse( llama_params["chat_format"] = static_cast(chat_params.format); llama_params["prompt"] = chat_params.prompt; - llama_params["grammar"] = chat_params.grammar; + if (!chat_params.grammar.empty()) { + llama_params["grammar"] = chat_params.grammar; + } llama_params["grammar_lazy"] = chat_params.grammar_lazy; auto grammar_triggers = json::array(); for (const auto & trigger : chat_params.grammar_triggers) { diff --git a/grammars/english.gbnf b/grammars/english.gbnf new file mode 100644 index 00000000..2e53686c --- /dev/null +++ b/grammars/english.gbnf @@ -0,0 +1,6 @@ +# note: this might be incomplete, mostly an example +root ::= en-char+ ([ \t\n] en-char+)* +en-char ::= letter | digit | punctuation +letter ::= [a-zA-Z] +digit ::= [0-9] +punctuation ::= [!"#$%&'()*+,-./:;<=>?@[\\\]^_`{|}~] diff --git a/include/llama.h b/include/llama.h index 9f0fe3e1..ea63574f 100644 --- a/include/llama.h +++ b/include/llama.h @@ -1313,12 +1313,17 @@ extern "C" { LLAMA_API void llama_sampler_reset(struct llama_sampler* smpl); +/// @details Initializes a GBNF grammar, see grammars/README.md for details. +/// @param vocab The vocabulary that this grammar will be used with. +/// @param grammar_str The production rules for the grammar, encoded as a string. Returns an empty grammar if empty. Returns NULL if parsing of grammar_str fails. +/// @param grammar_root The name of the start symbol for the grammar.
LLAMA_API struct llama_grammar* llama_sampler_init_grammar( const struct llama_vocab* vocab, const char* grammar_str, const char* grammar_root); - /// @details Lazy grammar sampler, introduced in https://github.com/ggerganov/llama.cpp/pull/9639 + +/// @details Lazy grammar sampler, introduced in https://github.com/ggerganov/llama.cpp/pull/9639 /// @param trigger_words A list of words that will trigger the grammar sampler. This may be updated to a loose regex syntax (w/ ^) in a near future. /// @param trigger_tokens A list of tokens that will trigger the grammar sampler. DEPRECATED(LLAMA_API struct llama_grammar* llama_sampler_init_grammar_lazy( @@ -1473,11 +1478,7 @@ using llama_grammar_candidates = std::vector; const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar * grammar); llama_grammar_stacks & llama_grammar_get_stacks( struct llama_grammar * grammar); -void llama_grammar_accept( - const llama_grammar_rules & rules, - const llama_grammar_stacks & stacks, - const uint32_t chr, - llama_grammar_stacks & new_stacks); +void llama_grammar_accept(struct llama_grammar* grammar, uint32_t chr); std::vector llama_grammar_reject_candidates_for_stack( const llama_grammar_rules & rules, diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp index 12f20dae..12f6f3dc 100644 --- a/src/llama-grammar.cpp +++ b/src/llama-grammar.cpp @@ -5,8 +5,14 @@ #include #include +#include #include +#define MAX_REPETITION_THRESHOLD 2000 +// +// helpers +// + // NOTE: assumes valid utf8 (but checks for overrun) static std::pair decode_utf8(const char* src) { static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 }; @@ -349,8 +355,10 @@ const char* llama_grammar_parser::parse_sequence( size_t last_sym_start = rule.size(); const char* pos = src; - auto handle_repetitions = [&](int min_times, int max_times) { - + // use UINT64_MAX as the empty value because we aligned to the proper uint64_t type so -1 can't be used + // (though it's
technically the same as -1 now) + auto handle_repetitions = [&](uint64_t min_times, uint64_t max_times) { + bool no_max = max_times == UINT64_MAX; if (last_sym_start == rule.size()) { throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos); } @@ -378,20 +386,20 @@ const char* llama_grammar_parser::parse_sequence( } else { // Repeat the previous elements (min_times - 1) times - for (int i = 1; i < min_times; i++) { + for (uint64_t i = 1; i < min_times; i++) { rule.insert(rule.end(), prev_rule.begin(), prev_rule.end()); } } uint32_t last_rec_rule_id = 0; - auto n_opt = max_times < 0 ? 1 : max_times - min_times; + auto n_opt = no_max ? 1 : max_times - min_times; llama_grammar_rule rec_rule(prev_rule); - for (int i = 0; i < n_opt; i++) { + for (uint64_t i = 0; i < n_opt; i++) { rec_rule.resize(prev_rule.size()); uint32_t rec_rule_id = generate_symbol_id(rule_name); - if (i > 0 || max_times < 0) { - rec_rule.push_back({ LLAMA_GRETYPE_RULE_REF, max_times < 0 ? rec_rule_id : last_rec_rule_id }); + if (i > 0 || no_max) { + rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, no_max ? 
rec_rule_id : last_rec_rule_id}); } rec_rule.push_back({ LLAMA_GRETYPE_ALT, 0 }); rec_rule.push_back({ LLAMA_GRETYPE_END, 0 }); @@ -491,10 +499,10 @@ const char* llama_grammar_parser::parse_sequence( throw std::runtime_error(std::string("expecting an int at ") + pos); } const char* int_end = parse_int(pos); - int min_times = std::stoul(std::string(pos, int_end - pos)); + uint64_t min_times = std::stoul(std::string(pos, int_end - pos)); pos = parse_space(int_end, is_nested); - int max_times = -1; + uint64_t max_times = UINT64_MAX; // default: no max limit if (*pos == '}') { max_times = min_times; @@ -517,6 +525,10 @@ const char* llama_grammar_parser::parse_sequence( else { throw std::runtime_error(std::string("expecting ',' at ") + pos); } + bool has_max = max_times != UINT64_MAX; + if (min_times > MAX_REPETITION_THRESHOLD || (has_max && max_times > MAX_REPETITION_THRESHOLD)) { + throw std::runtime_error(std::string("number of repetitions exceeds sane defaults, please reduce the number of repetitions")); + } handle_repetitions(min_times, max_times); } else { @@ -857,32 +869,30 @@ llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar) // be positioned at a character range (see `llama_grammar_advance_stack`), and // produces the N possible stacks if the given char is accepted at those // positions -void llama_grammar_accept( - const llama_grammar_rules & rules, - const llama_grammar_stacks & stacks, - const uint32_t chr, - llama_grammar_stacks & new_stacks) { - new_stacks.clear(); +void llama_grammar_accept(struct llama_grammar* grammar, uint32_t chr) { + llama_grammar_stacks stacks_new; + stacks_new.reserve(grammar->stacks.size()); - for (const auto & stack : stacks) { + for (const auto& stack : grammar->stacks) { if (stack.empty()) { continue; } auto match = llama_grammar_match_char(stack.back(), chr); if (match.first) { - const llama_grammar_element * pos = match.second; + const llama_grammar_element* pos = match.second; // update top of 
stack to next element, if any llama_grammar_stack new_stack(stack.begin(), stack.end() - 1); if (!llama_grammar_is_end_of_sequence(pos)) { new_stack.push_back(pos); } - llama_grammar_advance_stack(rules, new_stack, new_stacks); + llama_grammar_advance_stack(grammar->rules, new_stack, stacks_new); } } -} + grammar->stacks = std::move(stacks_new); +} llama_grammar_candidates llama_grammar_reject_candidates_for_stack( const llama_grammar_rules & rules, @@ -1236,11 +1246,11 @@ void llama_grammar_accept_token_impl(struct llama_grammar * grammar, const struc // std::string constrained_str(match[1].first, grammar.trigger_buffer.end()); grammar->trigger_buffer.clear(); llama_grammar_accept_str(grammar, constrained_str); - //LLAMA_LOG_DEBUG("Grammar triggered on regex: '%s'\n", constrained_str.c_str()); + LLAMA_LOG_DEBUG("Grammar triggered on regex: '%s'\n", constrained_str.c_str()); return; } } - //LLAMA_LOG_DEBUG("Grammar still awaiting trigger after token %d (`%s`)\n", token, piece.c_str()); + LLAMA_LOG_DEBUG("Grammar still awaiting trigger after token %d (`%s`)\n", token, piece.c_str()); return; } } @@ -1259,29 +1269,17 @@ void llama_grammar_accept_token_impl(struct llama_grammar * grammar, const struc } void llama_grammar_accept_str(struct llama_grammar* grammar, const std::string& piece) { - // Note terminating 0 in decoded string - const auto decoded = decode_utf8(piece, grammar->partial_utf8); - const auto & code_points = decoded.first; - llama_grammar_stacks tmp_new_stacks; - for (auto it = code_points.begin(), end = code_points.end()-1; it != end; ++it) { - llama_grammar_accept(grammar->rules, grammar->stacks, *it, tmp_new_stacks); - // avoid empty grammar stack at the end of the code_points - // mainline has this bug too, reason unknown - if (end == code_points.end() - 1) { - if (tmp_new_stacks.size()) { - grammar->stacks = tmp_new_stacks; - } - } - else { - grammar->stacks = tmp_new_stacks; - } + const auto decoded = decode_utf8(piece, grammar->partial_utf8); + 
const auto& code_points = decoded.first; + for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) { + llama_grammar_accept(grammar, *it); } grammar->partial_utf8 = decoded.second; if (grammar->stacks.empty()) { throw std::runtime_error("Unexpected empty grammar stack after accepting piece: " + piece); } - } + diff --git a/src/llama-grammar.h b/src/llama-grammar.h index 532f9d22..f13953c1 100644 --- a/src/llama-grammar.h +++ b/src/llama-grammar.h @@ -55,7 +55,7 @@ struct llama_grammar { llama_partial_utf8 partial_utf8; // lazy grammars wait for trigger words or tokens before constraining the sampling. - // we still ahve trigger_tokens for non-lazy grammars to force printing of special trigger tokens. + // we still have trigger_tokens for non-lazy grammars to force printing of special trigger tokens. // (useful e.g. for tool_choice=required) bool lazy = false; bool awaiting_trigger = false; // Initialized to true for lazy grammars only diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 0d23e146..e7daa175 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -1171,8 +1171,10 @@ struct llama_grammar* llama_sampler_init_grammar_impl( num_trigger_patterns = 1; } grammar = llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, trigger_patterns, num_trigger_patterns, trigger_tokens, num_trigger_tokens); - } - else { + if (!grammar) { + return nullptr; + } + } else { grammar = nullptr; } return grammar; diff --git a/tests/test-grammar-integration.cpp b/tests/test-grammar-integration.cpp index 78a19c8d..503eb43a 100644 --- a/tests/test-grammar-integration.cpp +++ b/tests/test-grammar-integration.cpp @@ -318,6 +318,30 @@ static void test_simple_grammar() { "0123", } ); + test_schema( + "min 1 max 900719925474091", + // Schema + R"""({ + "type": "integer", + "exclusiveMinimum": 0, + "maximum": 900719925474091 + })""", + // Passing strings + { + "1", + "2", + "10", + "900719925474090", + "900719925474091", + }, + 
// Failing strings + { + "0", + "01", + "900719925474092", + "9007199254740910", + } + ); test_schema( "min -1 max 1", R"""({ diff --git a/tests/test-json-schema-to-grammar.cpp b/tests/test-json-schema-to-grammar.cpp index 53474e0e..2b4a0b76 100755 --- a/tests/test-json-schema-to-grammar.cpp +++ b/tests/test-json-schema-to-grammar.cpp @@ -595,6 +595,22 @@ static void test_all(const std::string & lang, std::function