mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-01-26 17:20:01 +00:00
Merge mainline llama.cpp (#3)
* Merging mainline - WIP * Merging mainline - WIP AVX2 and CUDA appear to work. CUDA performance seems slightly (~1-2%) lower as it is so often the case with llama.cpp/ggml after some "improvements" have been made. * Merging mainline - fix Metal * Remove check --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
@@ -1,11 +1,12 @@
|
||||
#!/usr/bin/env python3
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import itertools
|
||||
import json
|
||||
import re
|
||||
import sys
|
||||
from typing import Any, Dict, List, Set, Tuple, Union
|
||||
|
||||
from typing import Any, List, Optional, Set, Tuple, Union
|
||||
|
||||
def _build_repetition(item_rule, min_items, max_items, separator_rule=None):
|
||||
|
||||
@@ -23,9 +24,173 @@ def _build_repetition(item_rule, min_items, max_items, separator_rule=None):
|
||||
result = item_rule + ' ' + _build_repetition(f'({separator_rule} {item_rule})', min_items - 1 if min_items > 0 else 0, max_items - 1 if max_items is not None else None)
|
||||
return f'({result})?' if min_items == 0 else result
|
||||
|
||||
def _generate_min_max_int(min_value: Optional[int], max_value: Optional[int], out: list, decimals_left: int = 16, top_level: bool = True):
|
||||
has_min = min_value != None
|
||||
has_max = max_value != None
|
||||
|
||||
def digit_range(from_char: str, to_char: str):
|
||||
out.append("[")
|
||||
if from_char == to_char:
|
||||
out.append(from_char)
|
||||
else:
|
||||
out.append(from_char)
|
||||
out.append("-")
|
||||
out.append(to_char)
|
||||
out.append("]")
|
||||
|
||||
def more_digits(min_digits: int, max_digits: int):
|
||||
out.append("[0-9]")
|
||||
if min_digits == max_digits and min_digits == 1:
|
||||
return
|
||||
out.append("{")
|
||||
out.append(str(min_digits))
|
||||
if max_digits != min_digits:
|
||||
out.append(",")
|
||||
if max_digits != sys.maxsize:
|
||||
out.append(str(max_digits))
|
||||
out.append("}")
|
||||
|
||||
def uniform_range(from_str: str, to_str: str):
|
||||
i = 0
|
||||
while i < len(from_str) and from_str[i] == to_str[i]:
|
||||
i += 1
|
||||
if i > 0:
|
||||
out.append("\"")
|
||||
out.append(from_str[:i])
|
||||
out.append("\"")
|
||||
if i < len(from_str):
|
||||
if i > 0:
|
||||
out.append(" ")
|
||||
sub_len = len(from_str) - i - 1
|
||||
if sub_len > 0:
|
||||
from_sub = from_str[i+1:]
|
||||
to_sub = to_str[i+1:]
|
||||
sub_zeros = "0" * sub_len
|
||||
sub_nines = "9" * sub_len
|
||||
|
||||
to_reached = False
|
||||
out.append("(")
|
||||
if from_sub == sub_zeros:
|
||||
digit_range(from_str[i], chr(ord(to_str[i]) - 1))
|
||||
out.append(" ")
|
||||
more_digits(sub_len, sub_len)
|
||||
else:
|
||||
out.append("[")
|
||||
out.append(from_str[i])
|
||||
out.append("] ")
|
||||
out.append("(")
|
||||
uniform_range(from_sub, sub_nines)
|
||||
out.append(")")
|
||||
if ord(from_str[i]) < ord(to_str[i]) - 1:
|
||||
out.append(" | ")
|
||||
if to_sub == sub_nines:
|
||||
digit_range(chr(ord(from_str[i]) + 1), to_str[i])
|
||||
to_reached = True
|
||||
else:
|
||||
digit_range(chr(ord(from_str[i]) + 1), chr(ord(to_str[i]) - 1))
|
||||
out.append(" ")
|
||||
more_digits(sub_len, sub_len)
|
||||
if not to_reached:
|
||||
out.append(" | ")
|
||||
digit_range(to_str[i], to_str[i])
|
||||
out.append(" ")
|
||||
uniform_range(sub_zeros, to_sub)
|
||||
out.append(")")
|
||||
else:
|
||||
out.append("[")
|
||||
out.append(from_str[i])
|
||||
out.append("-")
|
||||
out.append(to_str[i])
|
||||
out.append("]")
|
||||
|
||||
if has_min and has_max:
|
||||
if min_value < 0 and max_value < 0:
|
||||
out.append("\"-\" (")
|
||||
_generate_min_max_int(-max_value, -min_value, out, decimals_left, top_level=True)
|
||||
out.append(")")
|
||||
return
|
||||
|
||||
if min_value < 0:
|
||||
out.append("\"-\" (")
|
||||
_generate_min_max_int(0, -min_value, out, decimals_left, top_level=True)
|
||||
out.append(") | ")
|
||||
min_value = 0
|
||||
|
||||
min_s = str(min_value)
|
||||
max_s = str(max_value)
|
||||
min_digits = len(min_s)
|
||||
max_digits = len(max_s)
|
||||
|
||||
for digits in range(min_digits, max_digits):
|
||||
uniform_range(min_s, "9" * digits)
|
||||
min_s = "1" + "0" * digits
|
||||
out.append(" | ")
|
||||
uniform_range(min_s, max_s)
|
||||
return
|
||||
|
||||
less_decimals = max(decimals_left - 1, 1)
|
||||
|
||||
if has_min:
|
||||
if min_value < 0:
|
||||
out.append("\"-\" (")
|
||||
_generate_min_max_int(None, -min_value, out, decimals_left, top_level=False)
|
||||
out.append(") | [0] | [1-9] ")
|
||||
more_digits(0, decimals_left - 1)
|
||||
elif min_value == 0:
|
||||
if top_level:
|
||||
out.append("[0] | [1-9] ")
|
||||
more_digits(0, less_decimals)
|
||||
else:
|
||||
more_digits(1, decimals_left)
|
||||
elif min_value <= 9:
|
||||
c = str(min_value)
|
||||
range_start = '1' if top_level else '0'
|
||||
if c > range_start:
|
||||
digit_range(range_start, chr(ord(c) - 1))
|
||||
out.append(" ")
|
||||
more_digits(1, less_decimals)
|
||||
out.append(" | ")
|
||||
digit_range(c, "9")
|
||||
out.append(" ")
|
||||
more_digits(0, less_decimals)
|
||||
else:
|
||||
min_s = str(min_value)
|
||||
length = len(min_s)
|
||||
c = min_s[0]
|
||||
|
||||
if c > "1":
|
||||
digit_range("1" if top_level else "0", chr(ord(c) - 1))
|
||||
out.append(" ")
|
||||
more_digits(length, less_decimals)
|
||||
out.append(" | ")
|
||||
digit_range(c, c)
|
||||
out.append(" (")
|
||||
_generate_min_max_int(int(min_s[1:]), None, out, less_decimals, top_level=False)
|
||||
out.append(")")
|
||||
if c < "9":
|
||||
out.append(" | ")
|
||||
digit_range(chr(ord(c) + 1), "9")
|
||||
out.append(" ")
|
||||
more_digits(length - 1, less_decimals)
|
||||
return
|
||||
|
||||
if has_max:
|
||||
if max_value >= 0:
|
||||
if top_level:
|
||||
out.append("\"-\" [1-9] ")
|
||||
more_digits(0, less_decimals)
|
||||
out.append(" | ")
|
||||
_generate_min_max_int(0, max_value, out, decimals_left, top_level=True)
|
||||
else:
|
||||
out.append("\"-\" (")
|
||||
_generate_min_max_int(-max_value, None, out, decimals_left, top_level=False)
|
||||
out.append(")")
|
||||
return
|
||||
|
||||
raise RuntimeError("At least one of min_value or max_value must be set")
|
||||
|
||||
class BuiltinRule:
|
||||
def __init__(self, content: str, deps: list = None):
|
||||
def __init__(self, content: str, deps: list | None = None):
|
||||
self.content = content
|
||||
self.deps = deps or []
|
||||
|
||||
@@ -68,7 +233,7 @@ GRAMMAR_RANGE_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"\]\-\\]')
|
||||
GRAMMAR_LITERAL_ESCAPES = {'\r': '\\r', '\n': '\\n', '"': '\\"', '-': '\\-', ']': '\\]'}
|
||||
|
||||
NON_LITERAL_SET = set('|.()[]{}*+?')
|
||||
ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = set('[]()|{}*+?')
|
||||
ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = set('^$.[]()|{}*+?')
|
||||
|
||||
|
||||
class SchemaConverter:
|
||||
@@ -85,7 +250,7 @@ class SchemaConverter:
|
||||
|
||||
def _format_literal(self, literal):
|
||||
escaped = GRAMMAR_LITERAL_ESCAPE_RE.sub(
|
||||
lambda m: GRAMMAR_LITERAL_ESCAPES.get(m.group(0)), literal
|
||||
lambda m: GRAMMAR_LITERAL_ESCAPES.get(m.group(0)) or m.group(0), literal
|
||||
)
|
||||
return f'"{escaped}"'
|
||||
|
||||
@@ -112,6 +277,51 @@ class SchemaConverter:
|
||||
|
||||
return ''.join(('(', *recurse(0), ')'))
|
||||
|
||||
def _not_strings(self, strings):
|
||||
class TrieNode:
|
||||
def __init__(self):
|
||||
self.children = {}
|
||||
self.is_end_of_string = False
|
||||
|
||||
def insert(self, string):
|
||||
node = self
|
||||
for c in string:
|
||||
node = node.children.setdefault(c, TrieNode())
|
||||
node.is_end_of_string = True
|
||||
|
||||
trie = TrieNode()
|
||||
for s in strings:
|
||||
trie.insert(s)
|
||||
|
||||
char_rule = self._add_primitive('char', PRIMITIVE_RULES['char'])
|
||||
out = ['["] ( ']
|
||||
|
||||
def visit(node):
|
||||
rejects = []
|
||||
first = True
|
||||
for c in sorted(node.children.keys()):
|
||||
child = node.children[c]
|
||||
rejects.append(c)
|
||||
if first:
|
||||
first = False
|
||||
else:
|
||||
out.append(' | ')
|
||||
out.append(f'[{c}]')
|
||||
if child.children:
|
||||
out.append(f' (')
|
||||
visit(child)
|
||||
out.append(')')
|
||||
elif child.is_end_of_string:
|
||||
out.append(f' {char_rule}+')
|
||||
if node.children:
|
||||
if not first:
|
||||
out.append(' | ')
|
||||
out.append(f'[^"{"".join(rejects)}] {char_rule}*')
|
||||
visit(trie)
|
||||
|
||||
out.append(f' ){"" if trie.is_end_of_string else "?"} ["] space')
|
||||
return ''.join(out)
|
||||
|
||||
def _add_rule(self, name, rule):
|
||||
esc_name = INVALID_RULE_CHARS_RE.sub('-', name)
|
||||
if esc_name not in self._rules or self._rules[esc_name] == rule:
|
||||
@@ -195,11 +405,11 @@ class SchemaConverter:
|
||||
i = 0
|
||||
length = len(pattern)
|
||||
|
||||
def to_rule(s: Tuple[str, bool]) -> str:
|
||||
def to_rule(s: tuple[str, bool]) -> str:
|
||||
(txt, is_literal) = s
|
||||
return "\"" + txt + "\"" if is_literal else txt
|
||||
|
||||
def transform() -> Tuple[str, bool]:
|
||||
def transform() -> tuple[str, bool]:
|
||||
'''
|
||||
Parse a unit at index i (advancing it), and return its string representation + whether it's a literal.
|
||||
'''
|
||||
@@ -212,7 +422,7 @@ class SchemaConverter:
|
||||
# We only need a flat structure here to apply repetition operators to the last item, and
|
||||
# to merge literals at the and (we're parsing grouped ( sequences ) recursively and don't treat '|' specially
|
||||
# (GBNF's syntax is luckily very close to regular expressions!)
|
||||
seq: list[Tuple[str, bool]] = []
|
||||
seq: list[tuple[str, bool]] = []
|
||||
|
||||
def get_dot():
|
||||
if self._dotall:
|
||||
@@ -357,13 +567,13 @@ class SchemaConverter:
|
||||
return self._add_rule(rule_name, self._generate_union_rule(name, schema.get('oneOf') or schema['anyOf']))
|
||||
|
||||
elif isinstance(schema_type, list):
|
||||
return self._add_rule(rule_name, self._generate_union_rule(name, [{'type': t} for t in schema_type]))
|
||||
return self._add_rule(rule_name, self._generate_union_rule(name, [{**schema, 'type': t} for t in schema_type]))
|
||||
|
||||
elif 'const' in schema:
|
||||
return self._add_rule(rule_name, self._generate_constant_rule(schema['const']))
|
||||
return self._add_rule(rule_name, self._generate_constant_rule(schema['const']) + ' space')
|
||||
|
||||
elif 'enum' in schema:
|
||||
rule = ' | '.join((self._generate_constant_rule(v) for v in schema['enum']))
|
||||
rule = '(' + ' | '.join((self._generate_constant_rule(v) for v in schema['enum'])) + ') space'
|
||||
return self._add_rule(rule_name, rule)
|
||||
|
||||
elif schema_type in (None, 'object') and \
|
||||
@@ -394,7 +604,7 @@ class SchemaConverter:
|
||||
else:
|
||||
add_component(t, is_required=True)
|
||||
|
||||
return self._add_rule(rule_name, self._build_object_rule(properties, required, hybrid_name, additional_properties=[]))
|
||||
return self._add_rule(rule_name, self._build_object_rule(properties, required, hybrid_name, additional_properties=None))
|
||||
|
||||
elif schema_type in (None, 'array') and ('items' in schema or 'prefixItems' in schema):
|
||||
items = schema.get('items') or schema['prefixItems']
|
||||
@@ -432,6 +642,24 @@ class SchemaConverter:
|
||||
|
||||
return self._add_rule(rule_name, r'"\"" ' + _build_repetition(char_rule, min_len, max_len) + r' "\"" space')
|
||||
|
||||
elif schema_type in (None, 'integer') and \
|
||||
('minimum' in schema or 'exclusiveMinimum' in schema or 'maximum' in schema or 'exclusiveMaximum' in schema):
|
||||
min_value = None
|
||||
max_value = None
|
||||
if 'minimum' in schema:
|
||||
min_value = schema['minimum']
|
||||
elif 'exclusiveMinimum' in schema:
|
||||
min_value = schema['exclusiveMinimum'] + 1
|
||||
if 'maximum' in schema:
|
||||
max_value = schema['maximum']
|
||||
elif 'exclusiveMaximum' in schema:
|
||||
max_value = schema['exclusiveMaximum'] - 1
|
||||
|
||||
out = ["("]
|
||||
_generate_min_max_int(min_value, max_value, out)
|
||||
out.append(") space")
|
||||
return self._add_rule(rule_name, ''.join(out))
|
||||
|
||||
elif (schema_type == 'object') or (len(schema) == 0):
|
||||
return self._add_rule(rule_name, self._add_primitive('object', PRIMITIVE_RULES['object']))
|
||||
|
||||
@@ -450,7 +678,7 @@ class SchemaConverter:
|
||||
self._add_primitive(dep, dep_rule)
|
||||
return n
|
||||
|
||||
def _build_object_rule(self, properties: List[Tuple[str, Any]], required: Set[str], name: str, additional_properties: Union[bool, Any]):
|
||||
def _build_object_rule(self, properties: List[Tuple[str, Any]], required: Set[str], name: str, additional_properties: Optional[Union[bool, Any]]):
|
||||
prop_order = self._prop_order
|
||||
# sort by position in prop_order (if specified) then by original order
|
||||
sorted_props = [kv[0] for _, kv in sorted(enumerate(properties), key=lambda ikv: (prop_order.get(ikv[1][0], len(prop_order)), ikv[0]))]
|
||||
@@ -465,12 +693,16 @@ class SchemaConverter:
|
||||
required_props = [k for k in sorted_props if k in required]
|
||||
optional_props = [k for k in sorted_props if k not in required]
|
||||
|
||||
if additional_properties == True or isinstance(additional_properties, dict):
|
||||
if additional_properties is not None and additional_properties != False:
|
||||
sub_name = f'{name}{"-" if name else ""}additional'
|
||||
value_rule = self.visit({} if additional_properties == True else additional_properties, f'{sub_name}-value')
|
||||
value_rule = self.visit(additional_properties, f'{sub_name}-value') if isinstance(additional_properties, dict) else \
|
||||
self._add_primitive('value', PRIMITIVE_RULES['value'])
|
||||
key_rule = self._add_primitive('string', PRIMITIVE_RULES['string']) if not sorted_props \
|
||||
else self._add_rule(f'{sub_name}-k', self._not_strings(sorted_props))
|
||||
|
||||
prop_kv_rule_names["*"] = self._add_rule(
|
||||
f'{sub_name}-kv',
|
||||
self._add_primitive('string', PRIMITIVE_RULES['string']) + f' ":" space {value_rule}'
|
||||
f'{key_rule} ":" space {value_rule}'
|
||||
)
|
||||
optional_props.append("*")
|
||||
|
||||
@@ -485,15 +717,11 @@ class SchemaConverter:
|
||||
def get_recursive_refs(ks, first_is_optional):
|
||||
[k, *rest] = ks
|
||||
kv_rule_name = prop_kv_rule_names[k]
|
||||
if k == '*':
|
||||
res = self._add_rule(
|
||||
f'{name}{"-" if name else ""}additional-kvs',
|
||||
f'{kv_rule_name} ( "," space ' + kv_rule_name + ' )*'
|
||||
)
|
||||
elif first_is_optional:
|
||||
res = f'( "," space {kv_rule_name} )?'
|
||||
comma_ref = f'( "," space {kv_rule_name} )'
|
||||
if first_is_optional:
|
||||
res = comma_ref + ('*' if k == '*' else '?')
|
||||
else:
|
||||
res = kv_rule_name
|
||||
res = kv_rule_name + (' ' + comma_ref + "*" if k == '*' else '')
|
||||
if len(rest) > 0:
|
||||
res += ' ' + self._add_rule(
|
||||
f'{name}{"-" if name else ""}{k}-rest',
|
||||
|
||||
Reference in New Issue
Block a user