mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-03-02 10:00:07 +00:00
Update grammar (#1023)
* grammar : fix JSON Schema for string regex with top-level alt. (#9903) Prior to this commit, using a JSON Schema containing a string with `pattern` regular expression that uses top-level alternation (e.g. `"pattern": "^A|B|C|D$"`) would result in invalid JSON output from the constrained sampling grammar, because it ended up creating a grammar rule like this for the string: ``` thing ::= "\"" "A" | "B" | "C" | "D" "\"" space ``` Note that this rule will only match a starting quote for the "A" case, and will only match an ending quote for the "D" case, so this rule will always produce invalid JSON when used for sampling (that is, the JSON will always be lacking the starting quote, the ending quote, or both). This was fixed in a simple way by adding parentheses to the generated rule (for all string pattern rules, to keep it simple), such that the new generated rule looks like this (correct): ``` thing ::= "\"" ("A" | "B" | "C" | "D") "\"" space ``` * grammars : add English-only grammar (#10612) * grammar : handle maxItems == 0 in JSON schema (#13117) Co-authored-by: Richard Lyons <frob@cloudstaff.com> * grammar-parser : fix possible null-deref (#9004) Fixes: https://bugs.chromium.org/p/oss-fuzz/issues/detail?id=70680 Signed-off-by: David Korczynski <david@adalogics.com> * llama : fix typo in llama-grammar.h [no ci] (#11816) * * server: fix "--grammar-file" parameter (#12285) * common : use std::string_view now that we target c++17 (#14319) * json : support `enum` values within `allOf` (#15830) * grammar : use int64_t to avoid int overflows in int schema to grammar conversion logic (#16626) * grammar : support array references in json schema (#16792) * grammar : support array references in json schema * Update json-schema-to-grammar.cpp Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com> * grammar : improve regex when naming ref derived rules * grammar : replace non-conformant definitions array with anyOf test case --------- Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com> # Conflicts: # tests/test-json-schema-to-grammar.cpp * merge fix * llama : minor grammar refactor (#10897) * llama: fix error on bad grammar (#12628) * grammar : fix integer overflow (#17381) * Fix DoS / integer overflow * Remove optional, use INT64_MAX instead as placeholder value (it's technically -1, so it fits :) * White space * Actually, since it's unsigned, use UINT64_MAX # Conflicts: # src/llama-grammar.cpp * grammar: fix regression caused by #17381 (#17412) * grammar: fix regression caused by #17381 * more readable # Conflicts: # src/llama-grammar.cpp * Merge Fix * Fix warnings --------- Signed-off-by: David Korczynski <david@adalogics.com> Co-authored-by: Joe Eli McIlvain <joe.eli.mac@gmail.com> Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> Co-authored-by: frob <rick+github@frob.com.au> Co-authored-by: Richard Lyons <frob@cloudstaff.com> Co-authored-by: DavidKorczynski <david@adalogics.com> Co-authored-by: Daniel Bevenius <daniel.bevenius@gmail.com> Co-authored-by: firecoperana <firecoperana> Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com> Co-authored-by: Aldehir Rojas <hello@alde.dev> Co-authored-by: Olivier Chafik <olivier.chafik@gmail.com> Co-authored-by: Piotr Wilkin (ilintar) <piotr.wilkin@syndatis.com> Co-authored-by: Xuan-Son Nguyen <son@huggingface.co> Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
@@ -13,22 +13,14 @@
|
||||
#include <vector>
|
||||
|
||||
static bool llama_sample_grammar_string(struct llama_grammar * grammar, const std::string & input_str, size_t & error_pos, std::string & error_msg) {
|
||||
auto decoded = decode_utf8(input_str, {});
|
||||
const auto & code_points = decoded.first;
|
||||
|
||||
const llama_grammar_rules & rules = llama_grammar_get_rules (grammar);
|
||||
llama_grammar_stacks & cur_stacks = llama_grammar_get_stacks(grammar);
|
||||
|
||||
const auto cpts = unicode_cpts_from_utf8(input_str);
|
||||
auto& cur_stacks = llama_grammar_get_stacks(grammar);
|
||||
size_t pos = 0;
|
||||
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
|
||||
const llama_grammar_stacks prev_stacks = llama_grammar_get_stacks(grammar); // copy
|
||||
|
||||
llama_grammar_accept(rules, prev_stacks, *it, cur_stacks);
|
||||
|
||||
for (const auto& cpt : cpts) {
|
||||
llama_grammar_accept(grammar, cpt);
|
||||
if (cur_stacks.empty()) {
|
||||
error_pos = pos;
|
||||
error_msg = "Unexpected character '" + unicode_cpt_to_utf8(*it) + "'";
|
||||
cur_stacks = prev_stacks;
|
||||
error_msg = "Unexpected character '" + unicode_cpt_to_utf8(cpt) + "'";
|
||||
return false;
|
||||
}
|
||||
++pos;
|
||||
|
||||
@@ -10,6 +10,9 @@ from typing import Any, List, Optional, Set, Tuple, Union
|
||||
|
||||
def _build_repetition(item_rule, min_items, max_items, separator_rule=None):
|
||||
|
||||
if max_items == 0:
|
||||
return ""
|
||||
|
||||
if min_items == 0 and max_items == 1:
|
||||
return f'{item_rule}?'
|
||||
|
||||
@@ -368,8 +371,17 @@ class SchemaConverter:
|
||||
raise ValueError(f'Unsupported ref {ref}')
|
||||
|
||||
for sel in ref.split('#')[-1].split('/')[1:]:
|
||||
assert target is not None and sel in target, f'Error resolving ref {ref}: {sel} not in {target}'
|
||||
target = target[sel]
|
||||
assert target is not None, f'Error resolving ref {ref}: {sel} not in {target}'
|
||||
if isinstance(target, list):
|
||||
try:
|
||||
sel_index = int(sel)
|
||||
except ValueError:
|
||||
raise ValueError(f'Error resolving ref {ref}: {sel} not in {target}')
|
||||
assert 0 <= sel_index < len(target), f'Error resolving ref {ref}: {sel} not in {target}'
|
||||
target = target[sel_index]
|
||||
else:
|
||||
assert sel in target, f'Error resolving ref {ref}: {sel} not in {target}'
|
||||
target = target[sel]
|
||||
|
||||
self._refs[ref] = target
|
||||
else:
|
||||
@@ -540,11 +552,12 @@ class SchemaConverter:
|
||||
return self._add_rule(
|
||||
name,
|
||||
to_rule(transform()) if self._raw_pattern \
|
||||
else "\"\\\"\" " + to_rule(transform()) + " \"\\\"\" space")
|
||||
else "\"\\\"\" (" + to_rule(transform()) + ") \"\\\"\" space")
|
||||
|
||||
|
||||
def _resolve_ref(self, ref):
|
||||
ref_name = ref.split('/')[-1]
|
||||
ref_fragment = ref.split('#')[-1]
|
||||
ref_name = 'ref' + re.sub(r'[^a-zA-Z0-9-]+', '-', ref_fragment)
|
||||
if ref_name not in self._rules and ref not in self._refs_being_resolved:
|
||||
self._refs_being_resolved.add(ref)
|
||||
resolved = self._refs[ref]
|
||||
@@ -583,9 +596,10 @@ class SchemaConverter:
|
||||
properties = list(schema.get('properties', {}).items())
|
||||
return self._add_rule(rule_name, self._build_object_rule(properties, required, name, schema.get('additionalProperties')))
|
||||
|
||||
elif schema_type in (None, 'object') and 'allOf' in schema:
|
||||
elif schema_type in (None, 'object', 'string') and 'allOf' in schema:
|
||||
required = set()
|
||||
properties = []
|
||||
enum_sets = []
|
||||
hybrid_name = name
|
||||
def add_component(comp_schema, is_required):
|
||||
if (ref := comp_schema.get('$ref')) is not None:
|
||||
@@ -597,6 +611,9 @@ class SchemaConverter:
|
||||
if is_required:
|
||||
required.add(prop_name)
|
||||
|
||||
if 'enum' in comp_schema:
|
||||
enum_sets.append(set(comp_schema['enum']))
|
||||
|
||||
for t in schema['allOf']:
|
||||
if 'anyOf' in t:
|
||||
for tt in t['anyOf']:
|
||||
@@ -604,6 +621,15 @@ class SchemaConverter:
|
||||
else:
|
||||
add_component(t, is_required=True)
|
||||
|
||||
if enum_sets:
|
||||
enum_intersection = enum_sets[0]
|
||||
for s in enum_sets[1:]:
|
||||
enum_intersection &= s
|
||||
|
||||
if enum_intersection:
|
||||
rule = '(' + ' | '.join((self._generate_constant_rule(v) for v in sorted(enum_intersection))) + ') space'
|
||||
return self._add_rule(rule_name, rule)
|
||||
|
||||
return self._add_rule(rule_name, self._build_object_rule(properties, required, hybrid_name, additional_properties=None))
|
||||
|
||||
elif schema_type in (None, 'array') and ('items' in schema or 'prefixItems' in schema):
|
||||
|
||||
@@ -345,10 +345,14 @@ export class SchemaConverter {
|
||||
|
||||
const selectors = ref.split('#')[1].split('/').slice(1);
|
||||
for (const sel of selectors) {
|
||||
if (!target || !(sel in target)) {
|
||||
const selIndex = parseInt(sel, 10);
|
||||
if (target && sel in target) {
|
||||
target = target[sel];
|
||||
} else if (target && selIndex in target) {
|
||||
target = target[selIndex];
|
||||
} else {
|
||||
throw new Error(`Error resolving ref ${ref}: ${sel} not in ${JSON.stringify(target)}`);
|
||||
}
|
||||
target = target[sel];
|
||||
}
|
||||
|
||||
this._refs[ref] = target;
|
||||
@@ -594,7 +598,8 @@ export class SchemaConverter {
|
||||
}
|
||||
|
||||
_resolveRef(ref) {
|
||||
let refName = ref.split('/').pop();
|
||||
let refFragment = ref.split('#').pop();
|
||||
let refName = 'ref' + refFragment.replace(/[^a-zA-Z0-9-]+/g, '-');
|
||||
if (!(refName in this._rules) && !this._refsBeingResolved.has(ref)) {
|
||||
this._refsBeingResolved.add(ref);
|
||||
const resolved = this._refs[ref];
|
||||
@@ -631,9 +636,10 @@ export class SchemaConverter {
|
||||
const required = new Set(schema.required || []);
|
||||
const properties = Object.entries(schema.properties ?? {});
|
||||
return this._addRule(ruleName, this._buildObjectRule(properties, required, name, schema.additionalProperties));
|
||||
} else if ((schemaType === undefined || schemaType === 'object') && 'allOf' in schema) {
|
||||
} else if ((schemaType === undefined || schemaType === 'object' || schemaType === 'string') && 'allOf' in schema) {
|
||||
const required = new Set();
|
||||
const properties = [];
|
||||
const enumSets = [];
|
||||
const addComponent = (compSchema, isRequired) => {
|
||||
const ref = compSchema.$ref;
|
||||
if (ref !== undefined) {
|
||||
@@ -648,6 +654,10 @@ export class SchemaConverter {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if ('enum' in compSchema) {
|
||||
enumSets.push(new Set(compSchema.enum || []));
|
||||
}
|
||||
};
|
||||
|
||||
for (const t of schema.allOf) {
|
||||
@@ -660,6 +670,14 @@ export class SchemaConverter {
|
||||
}
|
||||
}
|
||||
|
||||
if (enumSets.length > 0) {
|
||||
const enumIntersection = new Set([...enumSets[0]].filter(v => enumSets.every(s => s.has(v))));
|
||||
if (enumIntersection.size > 0) {
|
||||
const sortedEnums = [...enumIntersection].sort((a, b) => a.localeCompare(b));
|
||||
const rule = '(' + sortedEnums.map(v => this._generateConstantRule(v)).join(' | ') + ') space';
|
||||
return this._addRule(ruleName, rule);
|
||||
}
|
||||
}
|
||||
return this._addRule(ruleName, this._buildObjectRule(properties, required, name, null));
|
||||
} else if ((schemaType === undefined || schemaType === 'array') && ('items' in schema || 'prefixItems' in schema)) {
|
||||
const items = schema.items ?? schema.prefixItems;
|
||||
|
||||
@@ -3779,7 +3779,7 @@ struct server_context {
|
||||
common_prefix prefix = slot.cache_tokens.get_common_prefix(ctx, prompt_tokens, true); // string level match
|
||||
common_prefix prefix_nonexact = slot.cache_tokens.get_common_prefix(ctx, prompt_tokens, false);
|
||||
auto n_past0 = slot.cache_tokens.get_common_prefix_exact(prompt_tokens); // token level match
|
||||
LLAMA_LOG_INFO("======== Cache: cache_size = %ld, n_past0 = %ld, n_past1 = %ld, n_past_prompt1 = %ld, n_past2 = %ld, n_past_prompt2 = %ld\n", (int32_t) slot.cache_tokens.size(), (int32_t) n_past0, (int32_t) prefix.first, prefix.second, (int32_t) prefix_nonexact.first, (int32_t) prefix_nonexact.second);
|
||||
LLAMA_LOG_INFO("======== Cache: cache_size = %d, n_past0 = %d, n_past1 = %d, n_past_prompt1 = %d, n_past2 = %d, n_past_prompt2 = %d\n", (int32_t) slot.cache_tokens.size(), (int32_t) n_past0, (int32_t) prefix.first, (int32_t)prefix.second, (int32_t) prefix_nonexact.first, (int32_t) prefix_nonexact.second);
|
||||
int32_t size_threshold = 20;
|
||||
if (prefix.first + size_threshold < prefix_nonexact.first) {
|
||||
LLAMA_LOG_WARN("Common part contains missing or extra space and new line\n");
|
||||
@@ -4042,7 +4042,7 @@ struct server_context {
|
||||
slot.state = SLOT_STATE_PROCESSING;
|
||||
slot.command = SLOT_COMMAND_NONE;
|
||||
slot.release();
|
||||
LLAMA_LOG_INFO("n_past =% d\n", slot.cache_tokens.size());
|
||||
LLAMA_LOG_INFO("n_past = %d\n", (int)slot.cache_tokens.size());
|
||||
send_error(slot, "Input prompt is too big compared to KV size. Please try increasing KV size.");
|
||||
}
|
||||
break; // break loop of n_batch
|
||||
|
||||
@@ -913,7 +913,9 @@ static json oaicompat_chat_params_parse(
|
||||
|
||||
llama_params["chat_format"] = static_cast<int>(chat_params.format);
|
||||
llama_params["prompt"] = chat_params.prompt;
|
||||
llama_params["grammar"] = chat_params.grammar;
|
||||
if (!chat_params.grammar.empty()) {
|
||||
llama_params["grammar"] = chat_params.grammar;
|
||||
}
|
||||
llama_params["grammar_lazy"] = chat_params.grammar_lazy;
|
||||
auto grammar_triggers = json::array();
|
||||
for (const auto & trigger : chat_params.grammar_triggers) {
|
||||
|
||||
Reference in New Issue
Block a user