Mirror of https://github.com/ikawrakow/ik_llama.cpp.git (synced 2026-01-26 17:20:01 +00:00)
Update grammar (#1023)
* grammar : fix JSON Schema for string regex with top-level alt. (#9903)

  Prior to this commit, using a JSON Schema containing a string with a `pattern` regular expression that uses top-level alternation (e.g. `"pattern": "^A|B|C|D$"`) would result in invalid JSON output from the constrained sampling grammar, because it ended up creating a grammar rule like this for the string:

  ```
  thing ::= "\"" "A" | "B" | "C" | "D" "\"" space
  ```

  Note that this rule will only match a starting quote for the "A" case, and will only match an ending quote for the "D" case, so this rule will always produce invalid JSON when used for sampling (that is, the JSON will always be lacking the starting quote, the ending quote, or both).

  This was fixed in a simple way by adding parentheses to the generated rule (for all string pattern rules, to keep it simple), such that the new generated rule looks like this (correct):

  ```
  thing ::= "\"" ("A" | "B" | "C" | "D") "\"" space
  ```

* grammars : add English-only grammar (#10612)

* grammar : handle maxItems == 0 in JSON schema (#13117)

  Co-authored-by: Richard Lyons <frob@cloudstaff.com>

* grammar-parser : fix possible null-deref (#9004)

  Fixes: https://bugs.chromium.org/p/oss-fuzz/issues/detail?id=70680
  Signed-off-by: David Korczynski <david@adalogics.com>

* llama : fix typo in llama-grammar.h [no ci] (#11816)

* server: fix "--grammar-file" parameter (#12285)

* common : use std::string_view now that we target c++17 (#14319)

* json : support `enum` values within `allOf` (#15830)

* grammar : use int64_t to avoid int overflows in int schema to grammar conversion logic (#16626)

* grammar : support array references in json schema (#16792)

  * grammar : support array references in json schema
  * Update json-schema-to-grammar.cpp
    Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
  * grammar : improve regex when naming ref derived rules
  * grammar : replace non-conformant definitions array with anyOf test case
  ---------
  Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

  # Conflicts:
  #   tests/test-json-schema-to-grammar.cpp

* merge fix

* llama : minor grammar refactor (#10897)

* llama: fix error on bad grammar (#12628)

* grammar : fix integer overflow (#17381)

  * Fix DoS / integer overflow
  * Remove optional, use INT64_MAX instead as placeholder value (it's technically -1, so it fits :)
  * White space
  * Actually, since it's unsigned, use UINT64_MAX

  # Conflicts:
  #   src/llama-grammar.cpp

* grammar: fix regression caused by #17381 (#17412)

  * grammar: fix regression caused by #17381
  * more readable

  # Conflicts:
  #   src/llama-grammar.cpp

* Merge Fix

* Fix warnings

---------

Signed-off-by: David Korczynski <david@adalogics.com>
Co-authored-by: Joe Eli McIlvain <joe.eli.mac@gmail.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Co-authored-by: frob <rick+github@frob.com.au>
Co-authored-by: Richard Lyons <frob@cloudstaff.com>
Co-authored-by: DavidKorczynski <david@adalogics.com>
Co-authored-by: Daniel Bevenius <daniel.bevenius@gmail.com>
Co-authored-by: firecoperana <firecoperana>
Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
Co-authored-by: Aldehir Rojas <hello@alde.dev>
Co-authored-by: Olivier Chafik <olivier.chafik@gmail.com>
Co-authored-by: Piotr Wilkin (ilintar) <piotr.wilkin@syndatis.com>
Co-authored-by: Xuan-Son Nguyen <son@huggingface.co>
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
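Editor's note: the parenthesization fix above comes down to operator precedence in GBNF, where concatenation binds tighter than `|`. A minimal Python sketch (plain string manipulation, not the actual converter code) of the difference between the old and new generated rules:

```python
# Why the generated string rule needs parentheses around the pattern body:
# without them, the opening quote attaches only to the "A" branch and the
# closing quote only to the "D" branch.
alternatives = ['"A"', '"B"', '"C"', '"D"']  # from pattern ^A|B|C|D$
quote = '"\\""'                              # the GBNF literal for a JSON quote

broken = quote + ' ' + ' | '.join(alternatives) + ' ' + quote + ' space'
fixed  = quote + ' (' + ' | '.join(alternatives) + ') ' + quote + ' space'

print(broken)  # "\"" "A" | "B" | "C" | "D" "\"" space
print(fixed)   # "\"" ("A" | "B" | "C" | "D") "\"" space
```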
@@ -369,6 +369,9 @@ namespace grammar_parser {
     }
     // Validate the state to ensure that all rules are defined
     for (const auto & rule : state.rules) {
+        if (rule.empty()) {
+            throw std::runtime_error("Undefined rule");
+        }
         for (const auto & elem : rule) {
             if (elem.type == LLAMA_GRETYPE_RULE_REF) {
                 // Ensure that the rule at that location exists
@@ -1,4 +1,5 @@
 #include "json-schema-to-grammar.h"
+#include "common.h"
 #include <algorithm>
 #include <fstream>
 #include <map>
@@ -19,6 +20,9 @@ static std::string repeat(const std::string & str, size_t n);
 static std::string build_repetition(const std::string & item_rule, int min_items, int max_items, const std::string & separator_rule = "") {
     auto has_max = max_items != std::numeric_limits<int>::max();

+    if (max_items == 0) {
+        return "";
+    }
     if (min_items == 0 && max_items == 1) {
         return item_rule + "?";
     }
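For reference, a hedged Python sketch of the `build_repetition` cases this hunk touches (the separator handling and the general bounded case in the real code are simplified here):

```python
def build_repetition(item_rule: str, min_items: int, max_items) -> str:
    # new early-exit: "maxItems": 0 means the item may not appear at all,
    # so the repetition contributes nothing to the rule
    if max_items == 0:
        return ""
    if min_items == 0 and max_items == 1:
        return f"{item_rule}?"
    # the real code also threads a separator rule through; the general
    # case emits {m,n}-style bounds
    return f"{item_rule}{{{min_items},{'' if max_items is None else max_items}}}"

print(repr(build_repetition("item", 0, 0)))  # '' -> e.g. root ::= "[" space "]" space
print(build_repetition("item", 0, 1))        # item?
print(build_repetition("item", 2, 4))        # item{2,4}
```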
@@ -40,52 +44,9 @@ static std::string build_repetition(const std::string & item_rule, int min_items
     return result;
 }

-/* Minimalistic replacement for std::string_view, which is only available from C++17 onwards */
-class string_view {
-    const std::string & _str;
-    const size_t _start;
-    const size_t _end;
-public:
-    string_view(const std::string & str, size_t start = 0, size_t end = std::string::npos) : _str(str), _start(start), _end(end == std::string::npos ? str.length() : end) {}
-
-    size_t size() const {
-        return _end - _start;
-    }
-
-    size_t length() const {
-        return size();
-    }
-
-    operator std::string() const {
-        return str();
-    }
-
-    std::string str() const {
-        return _str.substr(_start, _end - _start);
-    }
-
-    string_view substr(size_t pos, size_t len = std::string::npos) const {
-        return string_view(_str, _start + pos, len == std::string::npos ? _end : _start + pos + len);
-    }
-
-    char operator[](size_t pos) const {
-        auto index = _start + pos;
-        if (index >= _end) {
-            throw std::out_of_range("string_view index out of range");
-        }
-        return _str[_start + pos];
-    }
-
-    bool operator==(const string_view & other) const {
-        std::string this_str = *this;
-        std::string other_str = other;
-        return this_str == other_str;
-    }
-};
-
-static void _build_min_max_int(int min_value, int max_value, std::stringstream & out, int decimals_left = 16, bool top_level = true) {
-    auto has_min = min_value != std::numeric_limits<int>::min();
-    auto has_max = max_value != std::numeric_limits<int>::max();
+static void _build_min_max_int(int64_t min_value, int64_t max_value, std::stringstream & out, int decimals_left = 16, bool top_level = true) {
+    auto has_min = min_value != std::numeric_limits<int64_t>::min();
+    auto has_max = max_value != std::numeric_limits<int64_t>::max();

     auto digit_range = [&](char from, char to) {
         out << "[";
@@ -111,14 +72,14 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream &
         }
         out << "}";
     };
-    std::function<void(const string_view &, const string_view &)> uniform_range =
-        [&](const string_view & from, const string_view & to) {
+    std::function<void(const std::string_view &, const std::string_view &)> uniform_range =
+        [&](const std::string_view & from, const std::string_view & to) {
             size_t i = 0;
             while (i < from.length() && i < to.length() && from[i] == to[i]) {
                 i++;
             }
             if (i > 0) {
-                out << "\"" << from.substr(0, i).str() << "\"";
+                out << "\"" << from.substr(0, i) << "\"";
             }
             if (i < from.length() && i < to.length()) {
                 if (i > 0) {
@@ -201,7 +162,7 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream &
     if (has_min) {
         if (min_value < 0) {
             out << "\"-\" (";
-            _build_min_max_int(std::numeric_limits<int>::min(), -min_value, out, decimals_left, /* top_level= */ false);
+            _build_min_max_int(std::numeric_limits<int64_t>::min(), -min_value, out, decimals_left, /* top_level= */ false);
             out << ") | [0] | [1-9] ";
             more_digits(0, decimals_left - 1);
         } else if (min_value == 0) {
@@ -236,7 +197,7 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream &
             }
             digit_range(c, c);
             out << " (";
-            _build_min_max_int(std::stoi(min_s.substr(1)), std::numeric_limits<int>::max(), out, less_decimals, /* top_level= */ false);
+            _build_min_max_int(std::stoll(min_s.substr(1)), std::numeric_limits<int64_t>::max(), out, less_decimals, /* top_level= */ false);
             out << ")";
             if (c < '9') {
                 out << " | ";
@@ -258,7 +219,7 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream &
         _build_min_max_int(0, max_value, out, decimals_left, /* top_level= */ true);
     } else {
         out << "\"-\" (";
-        _build_min_max_int(-max_value, std::numeric_limits<int>::max(), out, decimals_left, /* top_level= */ false);
+        _build_min_max_int(-max_value, std::numeric_limits<int64_t>::max(), out, decimals_left, /* top_level= */ false);
         out << ")";
     }
     return;
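The int to int64_t widening matters for bounds that are perfectly ordinary JSON numbers but do not fit in 32 bits; the new test case further down uses exactly such a value. A quick check:

```python
# Why get<int>() was not enough for integer schema bounds.
INT32_MAX = 2**31 - 1        # 2147483647
maximum = 900719925474091    # bound used in the new test case below
print(maximum > INT32_MAX)   # True  -> would overflow a 32-bit int
print(maximum <= 2**63 - 1)  # True  -> fits in int64_t
```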
@@ -615,7 +576,7 @@ private:
             }
             return join_seq();
         };
-        return _add_rule(name, "\"\\\"\" " + to_rule(transform()) + " \"\\\"\" space");
+        return _add_rule(name, "\"\\\"\" (" + to_rule(transform()) + ") \"\\\"\" space");
     }

     /*
@@ -688,7 +649,10 @@ private:
     }

     std::string _resolve_ref(const std::string & ref) {
-        std::string ref_name = ref.substr(ref.find_last_of('/') + 1);
+        auto it = ref.find('#');
+        std::string ref_fragment = it != std::string::npos ? ref.substr(it + 1) : ref;
+        static const std::regex nonalphanumeric_regex(R"([^a-zA-Z0-9-]+)");
+        std::string ref_name = "ref" + std::regex_replace(ref_fragment, nonalphanumeric_regex, "-");
         if (_rules.find(ref_name) == _rules.end() && _refs_being_resolved.find(ref) == _refs_being_resolved.end()) {
             _refs_being_resolved.insert(ref);
             json resolved = _refs[ref];
@@ -861,11 +825,24 @@ public:
                 std::vector<std::string> tokens = split(pointer, "/");
                 for (size_t i = 1; i < tokens.size(); ++i) {
                     std::string sel = tokens[i];
-                    if (target.is_null() || !target.contains(sel)) {
+                    if (target.is_object() && target.contains(sel)) {
+                        target = target[sel];
+                    } else if (target.is_array()) {
+                        size_t sel_index;
+                        try {
+                            sel_index = std::stoul(sel);
+                        } catch (const std::invalid_argument & e) {
+                            sel_index = target.size();
+                        }
+                        if (sel_index >= target.size()) {
+                            _errors.push_back("Error resolving ref " + ref + ": " + sel + " not in " + target.dump());
+                            return;
+                        }
+                        target = target[sel_index];
+                    } else {
                         _errors.push_back("Error resolving ref " + ref + ": " + sel + " not in " + target.dump());
                         return;
                     }
-                    target = target[sel];
                 }
                 _refs[ref] = target;
             }
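A hedged Python sketch of the pointer walk this change implements: object keys are followed by name, and array elements can now be addressed by numeric index (e.g. `#/properties/a/anyOf/0`), instead of any array step failing outright:

```python
def resolve_pointer(doc, ref: str):
    # simplified model of the C++ change above; not the project API
    target = doc
    for sel in ref.split('#')[-1].split('/')[1:]:
        if isinstance(target, dict) and sel in target:
            target = target[sel]
        elif isinstance(target, list) and sel.isdigit() and int(sel) < len(target):
            target = target[int(sel)]
        else:
            raise ValueError(f"Error resolving ref {ref}: {sel} not in {target}")
    return target

schema = {"properties": {"a": {"anyOf": [{"type": "string"}, {"type": "number"}]}}}
print(resolve_pointer(schema, "#/properties/a/anyOf/0"))  # {'type': 'string'}
```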
@@ -931,9 +908,10 @@ public:
                 _build_object_rule(
                     properties, required, name,
                     schema.contains("additionalProperties") ? schema["additionalProperties"] : json()));
-        } else if ((schema_type.is_null() || schema_type == "object") && schema.contains("allOf")) {
+        } else if ((schema_type.is_null() || schema_type == "object" || schema_type == "string") && schema.contains("allOf")) {
             std::unordered_set<std::string> required;
             std::vector<std::pair<std::string, json>> properties;
+            std::map<std::string, size_t> enum_values;
             std::string hybrid_name = name;
             std::function<void(const json &, bool)> add_component = [&](const json & comp_schema, bool is_required) {
                 if (comp_schema.contains("$ref")) {
@@ -945,6 +923,14 @@ public:
                             required.insert(prop.key());
                         }
                     }
+                } else if (comp_schema.contains("enum")) {
+                    for (const auto & v : comp_schema["enum"]) {
+                        const auto rule = _generate_constant_rule(v);
+                        if (enum_values.find(rule) == enum_values.end()) {
+                            enum_values[rule] = 0;
+                        }
+                        enum_values[rule] += 1;
+                    }
                 } else {
                     // todo warning
                 }
@@ -958,6 +944,17 @@ public:
                     add_component(t, true);
                 }
             }
+            if (!enum_values.empty()) {
+                std::vector<std::string> enum_intersection;
+                for (const auto & p : enum_values) {
+                    if (p.second == schema["allOf"].size()) {
+                        enum_intersection.push_back(p.first);
+                    }
+                }
+                if (!enum_intersection.empty()) {
+                    return _add_rule(rule_name, "(" + string_join(enum_intersection, " | ") + ") space");
+                }
+            }
             return _add_rule(rule_name, _build_object_rule(properties, required, hybrid_name, json()));
         } else if ((schema_type.is_null() || schema_type == "array") && (schema.contains("items") || schema.contains("prefixItems"))) {
             json items = schema.contains("items") ? schema["items"] : schema["prefixItems"];
@@ -992,17 +989,17 @@ public:
             int max_len = schema.contains("maxLength") ? schema["maxLength"].get<int>() : std::numeric_limits<int>::max();
             return _add_rule(rule_name, "\"\\\"\" " + build_repetition(char_rule, min_len, max_len) + " \"\\\"\" space");
         } else if (schema_type == "integer" && (schema.contains("minimum") || schema.contains("exclusiveMinimum") || schema.contains("maximum") || schema.contains("exclusiveMaximum"))) {
-            int min_value = std::numeric_limits<int>::min();
-            int max_value = std::numeric_limits<int>::max();
+            int64_t min_value = std::numeric_limits<int64_t>::min();
+            int64_t max_value = std::numeric_limits<int64_t>::max();
             if (schema.contains("minimum")) {
-                min_value = schema["minimum"].get<int>();
+                min_value = schema["minimum"].get<int64_t>();
             } else if (schema.contains("exclusiveMinimum")) {
-                min_value = schema["exclusiveMinimum"].get<int>() + 1;
+                min_value = schema["exclusiveMinimum"].get<int64_t>() + 1;
             }
             if (schema.contains("maximum")) {
-                max_value = schema["maximum"].get<int>();
+                max_value = schema["maximum"].get<int64_t>();
             } else if (schema.contains("exclusiveMaximum")) {
-                max_value = schema["exclusiveMaximum"].get<int>() - 1;
+                max_value = schema["exclusiveMaximum"].get<int64_t>() - 1;
             }
             std::stringstream out;
             out << "(";
@@ -22,7 +22,6 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_vocab* vo
 #endif // LLAMA_USE_LLGUIDANCE
     }
     else {
-
         std::vector<std::string> trigger_patterns;
         std::vector<std::string> patterns_anywhere;
         std::vector<llama_token> trigger_tokens;
@@ -70,6 +69,10 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_vocab* vo
             trigger_tokens.data(), trigger_tokens.size())
         : llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root");

+    //if (!grmr) {
+    //    return nullptr;
+    //}
+
     // if there is a grammar, parse it
     if (!params.grammar.empty()) {
         result->parsed_grammar = grammar_parser::parse(params.grammar.c_str());
@@ -13,22 +13,14 @@
 #include <vector>

 static bool llama_sample_grammar_string(struct llama_grammar * grammar, const std::string & input_str, size_t & error_pos, std::string & error_msg) {
-    auto decoded = decode_utf8(input_str, {});
-    const auto & code_points = decoded.first;
-
-    const llama_grammar_rules & rules = llama_grammar_get_rules (grammar);
-    llama_grammar_stacks & cur_stacks = llama_grammar_get_stacks(grammar);
-
+    const auto cpts = unicode_cpts_from_utf8(input_str);
+    auto& cur_stacks = llama_grammar_get_stacks(grammar);
     size_t pos = 0;
-    for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
-        const llama_grammar_stacks prev_stacks = llama_grammar_get_stacks(grammar); // copy
-
-        llama_grammar_accept(rules, prev_stacks, *it, cur_stacks);
-
+    for (const auto& cpt : cpts) {
+        llama_grammar_accept(grammar, cpt);
         if (cur_stacks.empty()) {
             error_pos = pos;
-            error_msg = "Unexpected character '" + unicode_cpt_to_utf8(*it) + "'";
-            cur_stacks = prev_stacks;
+            error_msg = "Unexpected character '" + unicode_cpt_to_utf8(cpt) + "'";
             return false;
         }
         ++pos;
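The validator loop is now a plain fold over code points. Sketched in Python, with `accept` standing in for `llama_grammar_accept` plus the stack-emptiness check (a hypothetical helper, not the real API):

```python
def sample_grammar_string(accept, code_points):
    # accept(cpt) -> number of surviving grammar stacks after consuming cpt
    for pos, cpt in enumerate(code_points):
        if accept(cpt) == 0:
            return False, pos, f"Unexpected character {chr(cpt)!r}"
    return True, len(code_points), ""
```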
@@ -10,6 +10,9 @@ from typing import Any, List, Optional, Set, Tuple, Union

 def _build_repetition(item_rule, min_items, max_items, separator_rule=None):

+    if max_items == 0:
+        return ""
+
     if min_items == 0 and max_items == 1:
         return f'{item_rule}?'

@@ -368,7 +371,16 @@ class SchemaConverter:
             raise ValueError(f'Unsupported ref {ref}')

         for sel in ref.split('#')[-1].split('/')[1:]:
-            assert target is not None and sel in target, f'Error resolving ref {ref}: {sel} not in {target}'
+            assert target is not None, f'Error resolving ref {ref}: {sel} not in {target}'
+            if isinstance(target, list):
+                try:
+                    sel_index = int(sel)
+                except ValueError:
+                    raise ValueError(f'Error resolving ref {ref}: {sel} not in {target}')
+                assert 0 <= sel_index < len(target), f'Error resolving ref {ref}: {sel} not in {target}'
+                target = target[sel_index]
+            else:
+                assert sel in target, f'Error resolving ref {ref}: {sel} not in {target}'
                 target = target[sel]

         self._refs[ref] = target
@@ -540,11 +552,12 @@ class SchemaConverter:
         return self._add_rule(
             name,
             to_rule(transform()) if self._raw_pattern \
-                else "\"\\\"\" " + to_rule(transform()) + " \"\\\"\" space")
+                else "\"\\\"\" (" + to_rule(transform()) + ") \"\\\"\" space")


     def _resolve_ref(self, ref):
-        ref_name = ref.split('/')[-1]
+        ref_fragment = ref.split('#')[-1]
+        ref_name = 'ref' + re.sub(r'[^a-zA-Z0-9-]+', '-', ref_fragment)
         if ref_name not in self._rules and ref not in self._refs_being_resolved:
             self._refs_being_resolved.add(ref)
             resolved = self._refs[ref]
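The renaming scheme is the same across the C++, Python, and JS converters: derive the rule name from the whole `#` fragment rather than the last path segment, so refs that end in the same name can no longer collide. A quick check against the expected test output further down:

```python
import re

def ref_rule_name(ref: str) -> str:
    ref_fragment = ref.split('#')[-1]
    return 'ref' + re.sub(r'[^a-zA-Z0-9-]+', '-', ref_fragment)

print(ref_rule_name('#/definitions/foo'))  # ref-definitions-foo
print(ref_rule_name('#/definitions/bar'))  # ref-definitions-bar
```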
@@ -583,9 +596,10 @@ class SchemaConverter:
             properties = list(schema.get('properties', {}).items())
             return self._add_rule(rule_name, self._build_object_rule(properties, required, name, schema.get('additionalProperties')))

-        elif schema_type in (None, 'object') and 'allOf' in schema:
+        elif schema_type in (None, 'object', 'string') and 'allOf' in schema:
             required = set()
             properties = []
+            enum_sets = []
             hybrid_name = name
             def add_component(comp_schema, is_required):
                 if (ref := comp_schema.get('$ref')) is not None:
@@ -597,6 +611,9 @@ class SchemaConverter:
                         if is_required:
                             required.add(prop_name)

+                if 'enum' in comp_schema:
+                    enum_sets.append(set(comp_schema['enum']))
+
             for t in schema['allOf']:
                 if 'anyOf' in t:
                     for tt in t['anyOf']:
@@ -604,6 +621,15 @@ class SchemaConverter:
                 else:
                     add_component(t, is_required=True)

+            if enum_sets:
+                enum_intersection = enum_sets[0]
+                for s in enum_sets[1:]:
+                    enum_intersection &= s
+
+                if enum_intersection:
+                    rule = '(' + ' | '.join((self._generate_constant_rule(v) for v in sorted(enum_intersection))) + ') space'
+                    return self._add_rule(rule_name, rule)
+
             return self._add_rule(rule_name, self._build_object_rule(properties, required, hybrid_name, additional_properties=None))

         elif schema_type in (None, 'array') and ('items' in schema or 'prefixItems' in schema):
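In all three converters the allOf-enum logic reduces to a set intersection: a value satisfies the conjunction only if every branch's enum contains it. A standalone check that reproduces the expected grammar from the new multiple-enum test:

```python
enum_sets = [{"a", "b", "c"}, {"b", "c", "d"}]   # enums of the two $ref'd definitions
intersection = set.intersection(*enum_sets)       # {'b', 'c'}
rule = "(" + " | ".join(f'"\\"{v}\\""' for v in sorted(intersection)) + ") space"
print(rule)  # ("\"b\"" | "\"c\"") space
```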
@@ -345,10 +345,14 @@ export class SchemaConverter {

     const selectors = ref.split('#')[1].split('/').slice(1);
     for (const sel of selectors) {
-      if (!target || !(sel in target)) {
+      const selIndex = parseInt(sel, 10);
+      if (target && sel in target) {
+        target = target[sel];
+      } else if (target && selIndex in target) {
+        target = target[selIndex];
+      } else {
         throw new Error(`Error resolving ref ${ref}: ${sel} not in ${JSON.stringify(target)}`);
       }
-      target = target[sel];
     }

     this._refs[ref] = target;
@@ -594,7 +598,8 @@ export class SchemaConverter {
   }

   _resolveRef(ref) {
-    let refName = ref.split('/').pop();
+    let refFragment = ref.split('#').pop();
+    let refName = 'ref' + refFragment.replace(/[^a-zA-Z0-9-]+/g, '-');
     if (!(refName in this._rules) && !this._refsBeingResolved.has(ref)) {
       this._refsBeingResolved.add(ref);
       const resolved = this._refs[ref];
@@ -631,9 +636,10 @@ export class SchemaConverter {
       const required = new Set(schema.required || []);
       const properties = Object.entries(schema.properties ?? {});
       return this._addRule(ruleName, this._buildObjectRule(properties, required, name, schema.additionalProperties));
-    } else if ((schemaType === undefined || schemaType === 'object') && 'allOf' in schema) {
+    } else if ((schemaType === undefined || schemaType === 'object' || schemaType === 'string') && 'allOf' in schema) {
      const required = new Set();
      const properties = [];
+      const enumSets = [];
       const addComponent = (compSchema, isRequired) => {
         const ref = compSchema.$ref;
         if (ref !== undefined) {
@@ -648,6 +654,10 @@ export class SchemaConverter {
             }
           }
         }
+
+        if ('enum' in compSchema) {
+          enumSets.push(new Set(compSchema.enum || []));
+        }
       };

       for (const t of schema.allOf) {
@@ -660,6 +670,14 @@ export class SchemaConverter {
         }
       }

+      if (enumSets.length > 0) {
+        const enumIntersection = new Set([...enumSets[0]].filter(v => enumSets.every(s => s.has(v))));
+        if (enumIntersection.size > 0) {
+          const sortedEnums = [...enumIntersection].sort((a, b) => a.localeCompare(b));
+          const rule = '(' + sortedEnums.map(v => this._generateConstantRule(v)).join(' | ') + ') space';
+          return this._addRule(ruleName, rule);
+        }
+      }
       return this._addRule(ruleName, this._buildObjectRule(properties, required, name, null));
     } else if ((schemaType === undefined || schemaType === 'array') && ('items' in schema || 'prefixItems' in schema)) {
       const items = schema.items ?? schema.prefixItems;
@@ -3779,7 +3779,7 @@ struct server_context {
             common_prefix prefix = slot.cache_tokens.get_common_prefix(ctx, prompt_tokens, true); // string level match
             common_prefix prefix_nonexact = slot.cache_tokens.get_common_prefix(ctx, prompt_tokens, false);
             auto n_past0 = slot.cache_tokens.get_common_prefix_exact(prompt_tokens); // token level match
-            LLAMA_LOG_INFO("======== Cache: cache_size = %ld, n_past0 = %ld, n_past1 = %ld, n_past_prompt1 = %ld, n_past2 = %ld, n_past_prompt2 = %ld\n", (int32_t) slot.cache_tokens.size(), (int32_t) n_past0, (int32_t) prefix.first, prefix.second, (int32_t) prefix_nonexact.first, (int32_t) prefix_nonexact.second);
+            LLAMA_LOG_INFO("======== Cache: cache_size = %d, n_past0 = %d, n_past1 = %d, n_past_prompt1 = %d, n_past2 = %d, n_past_prompt2 = %d\n", (int32_t) slot.cache_tokens.size(), (int32_t) n_past0, (int32_t) prefix.first, (int32_t)prefix.second, (int32_t) prefix_nonexact.first, (int32_t) prefix_nonexact.second);
             int32_t size_threshold = 20;
             if (prefix.first + size_threshold < prefix_nonexact.first) {
                 LLAMA_LOG_WARN("Common part contains missing or extra space and new line\n");
@@ -4042,7 +4042,7 @@ struct server_context {
                     slot.state = SLOT_STATE_PROCESSING;
                     slot.command = SLOT_COMMAND_NONE;
                     slot.release();
-                    LLAMA_LOG_INFO("n_past =% d\n", slot.cache_tokens.size());
+                    LLAMA_LOG_INFO("n_past = %d\n", (int)slot.cache_tokens.size());
                     send_error(slot, "Input prompt is too big compared to KV size. Please try increasing KV size.");
                 }
                 break; // break loop of n_batch
@@ -913,7 +913,9 @@ static json oaicompat_chat_params_parse(

     llama_params["chat_format"] = static_cast<int>(chat_params.format);
     llama_params["prompt"] = chat_params.prompt;
+    if (!chat_params.grammar.empty()) {
         llama_params["grammar"] = chat_params.grammar;
+    }
     llama_params["grammar_lazy"] = chat_params.grammar_lazy;
     auto grammar_triggers = json::array();
     for (const auto & trigger : chat_params.grammar_triggers) {
grammars/english.gbnf (new file, 6 lines)
@@ -0,0 +1,6 @@
+# note: this might be incomplete, mostly an example
+root ::= en-char+ ([ \t\n] en-char+)*
+en-char ::= letter | digit | punctuation
+letter ::= [a-zA-Z]
+digit ::= [0-9]
+punctuation ::= [!"#$%&'()*+,-./:;<=>?@[\\\]^_`{|}~]
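A rough regex rendering of the new grammar's shape, useful for eyeballing what it accepts (hedged: GBNF and Python regex escaping differ, and the grammar itself is marked as possibly incomplete):

```python
import re

# en-char = letter | digit | punctuation; the three punctuation ranges below
# cover exactly the ASCII punctuation listed in the grammar's punctuation rule
en_char = r"[A-Za-z0-9!-/:-@\[-`{-~]"
root = re.compile(rf"{en_char}+(?:[ \t\n]{en_char}+)*\Z")

print(bool(root.match("Hello, world!")))  # True
print(bool(root.match("héllo")))          # False: 'é' is not an en-char
```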
@@ -1313,12 +1313,16 @@ extern "C" {

     LLAMA_API void llama_sampler_reset(struct llama_sampler* smpl);

+    /// @details Intializes a GBNF grammar, see grammars/README.md for details.
+    /// @param vocab The vocabulary that this grammar will be used with.
+    /// @param grammar_str The production rules for the grammar, encoded as a string. Returns an empty grammar if empty. Returns NULL if parsing of grammar_str fails.
+    /// @param grammar_root The name of the start symbol for the grammar.
     LLAMA_API struct llama_grammar* llama_sampler_init_grammar(
             const struct llama_vocab* vocab,
             const char* grammar_str,

             const char* grammar_root);
     /// @details Lazy grammar sampler, introduced in https://github.com/ggerganov/llama.cpp/pull/9639
     /// @param trigger_words A list of words that will trigger the grammar sampler. This may be updated to a loose regex syntax (w/ ^) in a near future.
     /// @param trigger_tokens A list of tokens that will trigger the grammar sampler.
     DEPRECATED(LLAMA_API struct llama_grammar* llama_sampler_init_grammar_lazy(
@@ -1473,11 +1477,7 @@ using llama_grammar_candidates = std::vector<llama_grammar_candidate>;
     const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar * grammar);
     llama_grammar_stacks & llama_grammar_get_stacks( struct llama_grammar * grammar);

-    void llama_grammar_accept(
-            const llama_grammar_rules & rules,
-            const llama_grammar_stacks & stacks,
-            const uint32_t chr,
-            llama_grammar_stacks & new_stacks);
+    void llama_grammar_accept(struct llama_grammar* grammar, uint32_t chr);

     std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
             const llama_grammar_rules & rules,
@@ -5,8 +5,14 @@

 #include <cmath>
 #include <algorithm>
+#include <cstdint>
 #include <stdexcept>

+#define MAX_REPETITION_THRESHOLD 2000
+//
+// helpers
+//
+
 // NOTE: assumes valid utf8 (but checks for overrun)
 static std::pair<uint32_t, const char*> decode_utf8(const char* src) {
     static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
@@ -349,8 +355,10 @@ const char* llama_grammar_parser::parse_sequence(
     size_t last_sym_start = rule.size();
     const char* pos = src;

-    auto handle_repetitions = [&](int min_times, int max_times) {
+    // use UINT64_MAX as the empty value because we aligned to the proper uint64_t type so -1 can't be used
+    // (though it's technically the same as -1 now)
+    auto handle_repetitions = [&](uint64_t min_times, uint64_t max_times) {
+        bool no_max = max_times == UINT64_MAX;
         if (last_sym_start == rule.size()) {
             throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos);
         }
@@ -378,20 +386,20 @@ const char* llama_grammar_parser::parse_sequence(
         }
         else {
             // Repeat the previous elements (min_times - 1) times
-            for (int i = 1; i < min_times; i++) {
+            for (uint64_t i = 1; i < min_times; i++) {
                 rule.insert(rule.end(), prev_rule.begin(), prev_rule.end());
             }
         }

         uint32_t last_rec_rule_id = 0;
-        auto n_opt = max_times < 0 ? 1 : max_times - min_times;
+        auto n_opt = no_max ? 1 : max_times - min_times;

         llama_grammar_rule rec_rule(prev_rule);
-        for (int i = 0; i < n_opt; i++) {
+        for (uint64_t i = 0; i < n_opt; i++) {
             rec_rule.resize(prev_rule.size());
             uint32_t rec_rule_id = generate_symbol_id(rule_name);
-            if (i > 0 || max_times < 0) {
-                rec_rule.push_back({ LLAMA_GRETYPE_RULE_REF, max_times < 0 ? rec_rule_id : last_rec_rule_id });
+            if (i > 0 || no_max) {
+                rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, no_max ? rec_rule_id : last_rec_rule_id});
             }
             rec_rule.push_back({ LLAMA_GRETYPE_ALT, 0 });
             rec_rule.push_back({ LLAMA_GRETYPE_END, 0 });
@@ -491,10 +499,10 @@ const char* llama_grammar_parser::parse_sequence(
             throw std::runtime_error(std::string("expecting an int at ") + pos);
         }
         const char* int_end = parse_int(pos);
-        int min_times = std::stoul(std::string(pos, int_end - pos));
+        uint64_t min_times = std::stoul(std::string(pos, int_end - pos));
         pos = parse_space(int_end, is_nested);

-        int max_times = -1;
+        uint64_t max_times = UINT64_MAX; // default: no max limit

         if (*pos == '}') {
             max_times = min_times;
@@ -517,6 +525,10 @@ const char* llama_grammar_parser::parse_sequence(
         else {
             throw std::runtime_error(std::string("expecting ',' at ") + pos);
         }
+        bool has_max = max_times != UINT64_MAX;
+        if (min_times > MAX_REPETITION_THRESHOLD || (has_max && max_times > MAX_REPETITION_THRESHOLD)) {
+            throw std::runtime_error(std::string("number of repetitions exceeds sane defaults, please reduce the number of repetitions"));
+        }
         handle_repetitions(min_times, max_times);
     }
     else {
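The overflow fix replaces the old signed sentinel (`max_times = -1`) with `UINT64_MAX` and caps repetition counts outright. A sketch of the guard's logic in Python:

```python
MAX_REPETITION_THRESHOLD = 2000
UNBOUNDED = 2**64 - 1  # stand-in for UINT64_MAX, the new "no max" marker

def check_repetition(min_times: int, max_times: int) -> None:
    has_max = max_times != UNBOUNDED
    if min_times > MAX_REPETITION_THRESHOLD or (has_max and max_times > MAX_REPETITION_THRESHOLD):
        raise ValueError("number of repetitions exceeds sane defaults")

check_repetition(3, 5)          # ok
check_repetition(0, UNBOUNDED)  # ok: '*' and '+' have no upper bound
# check_repetition(9999999999, 9999999999) would raise instead of overflowing an int
```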
@@ -857,32 +869,30 @@ llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar)
 // be positioned at a character range (see `llama_grammar_advance_stack`), and
 // produces the N possible stacks if the given char is accepted at those
 // positions
-void llama_grammar_accept(
-        const llama_grammar_rules & rules,
-        const llama_grammar_stacks & stacks,
-        const uint32_t chr,
-        llama_grammar_stacks & new_stacks) {
-    new_stacks.clear();
+void llama_grammar_accept(struct llama_grammar* grammar, uint32_t chr) {
+    llama_grammar_stacks stacks_new;
+    stacks_new.reserve(grammar->stacks.size());

-    for (const auto & stack : stacks) {
+    for (const auto& stack : grammar->stacks) {
         if (stack.empty()) {
             continue;
         }

         auto match = llama_grammar_match_char(stack.back(), chr);
         if (match.first) {
-            const llama_grammar_element * pos = match.second;
+            const llama_grammar_element* pos = match.second;

             // update top of stack to next element, if any
             llama_grammar_stack new_stack(stack.begin(), stack.end() - 1);
             if (!llama_grammar_is_end_of_sequence(pos)) {
                 new_stack.push_back(pos);
             }
-            llama_grammar_advance_stack(rules, new_stack, new_stacks);
+            llama_grammar_advance_stack(grammar->rules, new_stack, stacks_new);
         }
     }
-}
+
+    grammar->stacks = std::move(stacks_new);
+}

 llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
     const llama_grammar_rules & rules,
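The signature change turns `llama_grammar_accept` from a pure function over caller-owned stacks into an in-place update of the grammar's own state. A toy Python model of that filtering (real stacks hold grammar-element pointers and advance via `llama_grammar_advance_stack`; here each stack is just a list of accepted-character sets):

```python
class ToyGrammar:
    def __init__(self, stacks):
        self.stacks = stacks

    def accept(self, ch):
        stacks_new = []
        for stack in self.stacks:
            if stack and ch in stack[-1]:
                stacks_new.append(stack[:-1])  # pop the element that matched
        self.stacks = stacks_new               # replaces the old in/out parameters

g = ToyGrammar([[{'b'}, {'a'}], [{'c'}]])
g.accept('a')
print(g.stacks)  # [[{'b'}]] -- only the stack that could accept 'a' survives
```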
@@ -1236,11 +1246,11 @@ void llama_grammar_accept_token_impl(struct llama_grammar * grammar, const struc
             // std::string constrained_str(match[1].first, grammar.trigger_buffer.end());
             grammar->trigger_buffer.clear();
             llama_grammar_accept_str(grammar, constrained_str);
-            //LLAMA_LOG_DEBUG("Grammar triggered on regex: '%s'\n", constrained_str.c_str());
+            LLAMA_LOG_DEBUG("Grammar triggered on regex: '%s'\n", constrained_str.c_str());
             return;
         }
     }
-    //LLAMA_LOG_DEBUG("Grammar still awaiting trigger after token %d (`%s`)\n", token, piece.c_str());
+    LLAMA_LOG_DEBUG("Grammar still awaiting trigger after token %d (`%s`)\n", token, piece.c_str());
     return;
     }
 }
@@ -1259,29 +1269,17 @@
 }

 void llama_grammar_accept_str(struct llama_grammar* grammar, const std::string& piece) {

     // Note terminating 0 in decoded string
     const auto decoded = decode_utf8(piece, grammar->partial_utf8);
-    const auto & code_points = decoded.first;
-    llama_grammar_stacks tmp_new_stacks;
-    for (auto it = code_points.begin(), end = code_points.end()-1; it != end; ++it) {
-        llama_grammar_accept(grammar->rules, grammar->stacks, *it, tmp_new_stacks);
-        // avoid empty grammar stack at the end of the code_points
-        // mainline has this bug too, reason unknown
-        if (end == code_points.end() - 1) {
-            if (tmp_new_stacks.size()) {
-                grammar->stacks = tmp_new_stacks;
-            }
-        }
-        else {
-            grammar->stacks = tmp_new_stacks;
-        }
-    }
+    const auto& code_points = decoded.first;
+
+    for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
+        llama_grammar_accept(grammar, *it);
+    }

     grammar->partial_utf8 = decoded.second;
     if (grammar->stacks.empty()) {
         throw std::runtime_error("Unexpected empty grammar stack after accepting piece: " + piece);
     }

 }
@@ -55,7 +55,7 @@ struct llama_grammar {
     llama_partial_utf8 partial_utf8;

     // lazy grammars wait for trigger words or tokens before constraining the sampling.
-    // we still ahve trigger_tokens for non-lazy grammars to force printing of special trigger tokens.
+    // we still have trigger_tokens for non-lazy grammars to force printing of special trigger tokens.
     // (useful e.g. for tool_choice=required)
     bool lazy = false;
     bool awaiting_trigger = false; // Initialized to true for lazy grammars only
@@ -1171,8 +1171,10 @@ struct llama_grammar* llama_sampler_init_grammar_impl(
             num_trigger_patterns = 1;
         }
         grammar = llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, trigger_patterns, num_trigger_patterns, trigger_tokens, num_trigger_tokens);
+        if (!grammar) {
+            return nullptr;
         }
-    else {
+    } else {
         grammar = nullptr;
     }
     return grammar;
@@ -318,6 +318,30 @@ static void test_simple_grammar() {
             "0123",
         }
     );
+    test_schema(
+        "min 1 max 900719925474091",
+        // Schema
+        R"""({
+            "type": "integer",
+            "exclusiveMinimum": 0,
+            "maximum": 900719925474091
+        })""",
+        // Passing strings
+        {
+            "1",
+            "2",
+            "10",
+            "900719925474090",
+            "900719925474091",
+        },
+        // Failing strings
+        {
+            "0",
+            "01",
+            "900719925474092",
+            "9007199254740910",
+        }
+    );
     test_schema(
         "min -1 max 1",
         R"""({
@@ -595,6 +595,22 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
         )"""
     });

+    test({
+        SUCCESS,
+        "maxItems 0",
+        R"""({
+            "items": {
+                "type": "boolean"
+            },
+            "maxItems": 0
+        })""",
+        R"""(
+            boolean ::= ("true" | "false") space
+            root ::= "[" space "]" space
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
+        )"""
+    });
+
     test({
         SUCCESS,
         "maxItems 1",
@@ -694,7 +710,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
         "pattern": "^abc?d*efg+(hij)?kl$"
         })""",
         R"""(
-            root ::= "\"" "ab" "c"? "d"* "ef" "g"+ ("hij")? "kl" "\"" space
+            root ::= "\"" ("ab" "c"? "d"* "ef" "g"+ ("hij")? "kl") "\"" space
             space ::= | " " | "\n" [ \t]{0,20}
         )"""
     });
@@ -707,7 +723,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
         "pattern": "^\\[\\]\\{\\}\\(\\)\\|\\+\\*\\?$"
         })""",
         R"""(
-            root ::= "\"" "[]{}()|+*?" "\"" space
+            root ::= "\"" ("[]{}()|+*?") "\"" space
             space ::= | " " | "\n" [ \t]{0,20}
         )"""
     });
@@ -720,7 +736,20 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
         "pattern": "^\"$"
         })""",
         R"""(
-            root ::= "\"" "\"" "\"" space
+            root ::= "\"" ("\"") "\"" space
+            space ::= | " " | "\n" [ \t]{0,20}
+        )"""
+    });
+
+    test({
+        SUCCESS,
+        "regexp with top-level alternation",
+        R"""({
+            "type": "string",
+            "pattern": "^A|B|C|D$"
+        })""",
+        R"""(
+            root ::= "\"" ("A" | "B" | "C" | "D") "\"" space
             space ::= | " " | "\n" [ \t]{0,20}
         )"""
     });
@@ -734,7 +763,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
         })""",
         R"""(
             dot ::= [^\x0A\x0D]
-            root ::= "\"" ("(" root-1{1,3} ")")? root-1{3,3} "-" root-1{4,4} " " "a"{3,5} "nd" dot dot dot "\"" space
+            root ::= "\"" (("(" root-1{1,3} ")")? root-1{3,3} "-" root-1{4,4} " " "a"{3,5} "nd" dot dot dot) "\"" space
             root-1 ::= [0-9]
             space ::= | " " | "\n" [ \t]{0,20}
         )"""
@@ -1091,9 +1120,9 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
         })""",
         R"""(
             char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
-            foo ::= "{" space foo-a-kv "}" space
-            foo-a-kv ::= "\"a\"" space ":" space string
-            root ::= foo
+            ref-definitions-foo ::= "{" space ref-definitions-foo-a-kv "}" space
+            ref-definitions-foo-a-kv ::= "\"a\"" space ":" space string
+            root ::= ref-definitions-foo
             space ::= | " " | "\n" [ \t]{0,20}
             string ::= "\"" char* "\"" space
         )"""
@@ -1118,20 +1147,58 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
         "type": "object"
         })""",
         R"""(
-            alternative-0 ::= foo
-            alternative-1 ::= bar
-            bar ::= "{" space (bar-b-kv )? "}" space
-            bar-b-kv ::= "\"b\"" space ":" space number
+            alternative-0 ::= ref-definitions-foo
+            alternative-1 ::= ref-definitions-bar
             decimal-part ::= [0-9]{1,16}
-            foo ::= "{" space (foo-a-kv )? "}" space
-            foo-a-kv ::= "\"a\"" space ":" space number
             integral-part ::= [0] | [1-9] [0-9]{0,15}
             number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
+            ref-definitions-bar ::= "{" space (ref-definitions-bar-b-kv )? "}" space
+            ref-definitions-bar-b-kv ::= "\"b\"" space ":" space number
+            ref-definitions-foo ::= "{" space (ref-definitions-foo-a-kv )? "}" space
+            ref-definitions-foo-a-kv ::= "\"a\"" space ":" space number
             root ::= alternative-0 | alternative-1
             space ::= | " " | "\n" [ \t]{0,20}
         )"""
     });

+    test({
+        SUCCESS,
+        "anyOf $ref",
+        R"""({
+            "properties": {
+                "a": {
+                    "anyOf": [
+                        {"type": "string"},
+                        {"type": "number"}
+                    ]
+                },
+                "b": {
+                    "anyOf": [
+                        {"$ref": "#/properties/a/anyOf/0"},
+                        {"type": "boolean"}
+                    ]
+                }
+            },
+            "type": "object"
+        })""",
+        R"""(
+            a ::= string | number
+            a-kv ::= "\"a\"" space ":" space a
+            a-rest ::= ( "," space b-kv )?
+            b ::= b-0 | boolean
+            b-0 ::= string
+            b-kv ::= "\"b\"" space ":" space b
+            boolean ::= ("true" | "false") space
+            char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
+            decimal-part ::= [0-9]{1,16}
+            integral-part ::= [0] | [1-9] [0-9]{0,15}
+            number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
+            root ::= "{" space (a-kv a-rest | b-kv )? "}" space
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
+            string ::= "\"" char* "\"" space
+        )"""
+    });
+
     test({
         SUCCESS,
         "mix of allOf, anyOf and $ref (similar to https://json.schemastore.org/tsconfig.json)",
@@ -1176,6 +1243,51 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
         )"""
     });

+    test({
+        SUCCESS,
+        "allOf with enum schema",
+        R"""({
+            "allOf": [
+                {"$ref": "#/definitions/foo"}
+            ],
+            "definitions": {
+                "foo": {
+                    "type": "string",
+                    "enum": ["a", "b"]
+                }
+            }
+        })""",
+        R"""(
+            root ::= ("\"a\"" | "\"b\"") space
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
+        )"""
+    });
+
+    test({
+        SUCCESS,
+        "allOf with multiple enum schemas",
+        R"""({
+            "allOf": [
+                {"$ref": "#/definitions/foo"},
+                {"$ref": "#/definitions/bar"}
+            ],
+            "definitions": {
+                "foo": {
+                    "type": "string",
+                    "enum": ["a", "b", "c"]
+                },
+                "bar": {
+                    "type": "string",
+                    "enum": ["b", "c", "d"]
+                }
+            }
+        })""",
+        R"""(
+            root ::= ("\"b\"" | "\"c\"") space
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
+        )"""
+    });
+
     test({
         SUCCESS,
         "conflicting names",