mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-05-13 01:05:59 +00:00
- Move tool config from template_vars to a separate YAML config.
- Add a new per-generation stream collector, used for both streaming and non-streaming requests so that the logic is consistent for both.
- Move responsibility for switching between phases to the stream collector.
- Collect tool calls during streaming and parse them at the end of each generation.
- Prevent streaming empty content spans (be nice to clients).
- Correctly aggregate usage stats for n>1 requests; always emit them with the last chunk of the last generation to finish.
- Collect logprobs in the model wrapper and correctly handle logprobs for multi-token characters, etc.
- Respect the top_logprobs argument in the request.
- Handle a number of edge cases, such as a <think> tag being part of a held string.
- Retain the tool-parsing and inference-abort fixes from #413, and apply a similar fix to non-streaming requests as well.

Still TODO:
- Testing and validation with more models and tool schemas (tested on Qwen so far).
- Enable JSON constraint for JSON tool models.
- Possibly some pydantification.
- Documentation.
504 lines · 18 KiB · Python
"""Tool call processing utilities for OAI server."""
|
|
|
|
import json
|
|
import re
|
|
from common.logger import xlogger
|
|
from typing import Any, List, Tuple
|
|
|
|
from endpoints.OAI.types.tools import ToolCall, Tool
|
|
|
|
|
|
# JSON schema (draft-07) for a list of OAI-style tool calls:
#   [{"function": {"name": <str>, "arguments": <object>}}, ...]
# NOTE(review): its consumer is not visible in this file; presumably it is meant
# for constraining JSON-format tool generation — confirm against callers.
TOOL_CALL_SCHEMA = {
    "$schema": "http://json-schema.org/draft-07/schema#",
    "type": "array",
    "items": {
        "type": "object",
        "properties": {
            "function": {
                "type": "object",
                "properties": {
                    "name": {"type": "string"},
                    # Generated as an object; converted to OAI's JSON *string*
                    # form during post-processing.
                    "arguments": {"type": "object"},
                },
                "required": ["name", "arguments"],
            },
        },
        "required": ["function"],
    },
}
|
|
|
|
# ---------------------------------------------------------------------------
# XML parsing regex patterns
# Derived from vLLM's Qwen3CoderToolParser and the official Qwen parser.
# These handle both complete and partially-closed tags.
# ---------------------------------------------------------------------------

# Matches complete <tool_call>...</tool_call> blocks; group 1 is the inner text
TOOL_CALL_BLOCK_RE = re.compile(
    r"<tool_call>(.*?)</tool_call>",
    re.DOTALL,
)

# Matches <function=NAME>BODY</function> blocks; groups are (NAME, BODY)
FUNCTION_RE = re.compile(
    r"<function=(.*?)>(.*?)</function>",
    re.DOTALL,
)

# Matches <parameter=KEY>VALUE</terminator>; groups are (KEY, VALUE)
# Terminates on: </parameter>, next <parameter=, </function>, <tool_call>,
# or the end of the input string. The end-of-input case is required because
# the XML parser applies this pattern to a function BODY that no longer
# contains </function>; without it, a final parameter whose </parameter>
# was omitted by the model would be silently dropped.
PARAMETER_RE = re.compile(
    r"<parameter=(.*?)>(.*?)"
    r"(?:</parameter>|(?=<parameter=)|(?=</function>)|(?=<tool_call>)|\Z)",
    re.DOTALL,
)

# Markdown code fence patterns (opening fence with optional "json" tag, and
# closing fence at the end of a line)
CODE_FENCE_RE = re.compile(r"^```(?:json)?\s*", re.MULTILINE)
CODE_FENCE_END_RE = re.compile(r"\s*```\s*$", re.MULTILINE)
|
|
|
|
def _coerce_param_value(raw: str) -> Any:
|
|
"""Coerce a raw parameter value string to the appropriate Python type.
|
|
|
|
Strategy (safe, no eval()):
|
|
1. Strip leading/trailing newlines (official template emits \\n
|
|
after opening tag and before closing tag).
|
|
2. Try json.loads — handles objects, arrays, numbers, bools, null.
|
|
3. Fall back to plain string.
|
|
"""
|
|
# Strip template-inserted newlines around values
|
|
if raw.startswith("\n"):
|
|
raw = raw[1:]
|
|
if raw.endswith("\n"):
|
|
raw = raw[:-1]
|
|
|
|
stripped = raw.strip()
|
|
|
|
# Empty string
|
|
if not stripped:
|
|
return ""
|
|
|
|
# Try JSON parse (handles objects, arrays, numbers, booleans, null)
|
|
try:
|
|
return json.loads(stripped)
|
|
except (json.JSONDecodeError, ValueError):
|
|
pass
|
|
|
|
# Fall back to string — never eval()
|
|
return stripped
|
|
|
|
|
|
class ToolCallProcessor:
    """Parse, filter and serialise model-emitted tool calls.

    Every method is a ``@staticmethod``; the class only groups them. Two
    families of input are supported:

    * JSON — a bare array/object, optionally markdown-fenced or wrapped in
      ``<tool_call>`` tags (``from_json`` / ``from_auto``).
    * XML — Qwen3-Coder style ``<function=NAME><parameter=KEY>VALUE
      </parameter></function>`` blocks, with or without a ``<tool_call>``
      wrapper (``from_xml``).
    """

    # Stray wrapper tags that can remain in plain content, e.g. an opening
    # <tool_call> whose closing tag was never generated.
    _WRAPPER_TAG_RE = re.compile(r"</?tool_call>")

    # ------------------------------------------------------------------
    # JSON normalization helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _normalize_tool_calls(raw) -> list:
        """Normalize model-emitted tool call payloads into OAI-like dicts.

        Accepted forms:
          - [{"type":"function","function":{"name":...,"arguments":{...}}}]
          - [{"name":...,"arguments":{...}}]
          - {"name":...,"arguments":{...}}

        Args:
            raw: A decoded JSON value.

        Returns:
            List of ``{"type": "function", "function": {"name", "arguments"}}``
            dicts in which ``arguments`` is always a dict. Items that are not
            dicts, or have no ``name``, are skipped. String ``arguments`` are
            JSON-decoded; undecodable strings become ``{"input": <str>}``;
            any other non-dict arguments become ``{}``.

        Raises:
            ValueError: If ``raw`` is neither a dict nor a list.
        """
        if isinstance(raw, dict):
            raw = [raw]
        if not isinstance(raw, list):
            raise ValueError("tool_calls payload is not list/dict")

        normalized: list = []
        for item in raw:
            if not isinstance(item, dict):
                continue

            # Nested OAI shape vs. flat {name, arguments} shape
            if "function" in item and isinstance(item["function"], dict):
                fn = item["function"]
                name = fn.get("name")
                arguments = fn.get("arguments", {})
            else:
                name = item.get("name")
                arguments = item.get("arguments", {})

            if name is None:
                continue

            # Some models emit arguments already stringified (OAI style)
            if isinstance(arguments, str):
                try:
                    arguments = json.loads(arguments)
                except json.JSONDecodeError:
                    arguments = {"input": arguments}

            normalized.append(
                {
                    "type": "function",
                    "function": {
                        "name": name,
                        "arguments": arguments if isinstance(arguments, dict) else {},
                    },
                }
            )
        return normalized

    @staticmethod
    def _safe_json_loads(payload: str) -> list:
        """Best-effort JSON parse for model-emitted tool payloads.

        Handles: clean JSON, markdown-fenced JSON, JSON substrings in
        surrounding text, flat {name, arguments} dicts, and single objects.

        Returns:
            Normalized tool call dicts (see ``_normalize_tool_calls``); may be
            empty if the JSON parsed but contained no usable tool calls.

        Raises:
            json.JSONDecodeError: If no attempt yields a list/dict payload.
        """
        # Direct parse
        try:
            return ToolCallProcessor._normalize_tool_calls(json.loads(payload))
        except (json.JSONDecodeError, ValueError):
            pass

        # Clean up common model artifacts (markdown fences, whitespace)
        cleaned = payload.strip()
        cleaned = CODE_FENCE_RE.sub("", cleaned)
        cleaned = CODE_FENCE_END_RE.sub("", cleaned)
        cleaned = cleaned.strip()

        # Try cleaned
        try:
            return ToolCallProcessor._normalize_tool_calls(json.loads(cleaned))
        except (json.JSONDecodeError, ValueError):
            pass

        # Find JSON array substring
        start = cleaned.find("[")
        end = cleaned.rfind("]")
        if start != -1 and end != -1 and end > start:
            try:
                return ToolCallProcessor._normalize_tool_calls(json.loads(cleaned[start : end + 1]))
            except (json.JSONDecodeError, ValueError):
                pass

        # Find JSON object substring
        obj_start = cleaned.find("{")
        obj_end = cleaned.rfind("}")
        if obj_start != -1 and obj_end != -1 and obj_end > obj_start:
            try:
                return ToolCallProcessor._normalize_tool_calls(
                    json.loads(cleaned[obj_start : obj_end + 1])
                )
            except (json.JSONDecodeError, ValueError):
                pass

        raise json.JSONDecodeError("Could not extract valid JSON from payload", payload, 0)

    @staticmethod
    def _build_tool_calls(normalized: list) -> List[ToolCall]:
        """Convert normalized tool call dicts into ToolCall objects.

        ``function.arguments`` is serialised to a JSON string (the OAI wire
        form). The input dicts are mutated in place.
        """
        result = []
        for tool_call in normalized:
            tool_call["function"]["arguments"] = json.dumps(tool_call["function"]["arguments"])
            result.append(ToolCall(**tool_call))
        return result

    # ------------------------------------------------------------------
    # JSON parsing
    # ------------------------------------------------------------------

    @staticmethod
    def from_json(tool_calls_str: str) -> List[ToolCall]:
        """Postprocess tool call JSON to a parseable class.

        Handles clean JSON arrays, markdown-fenced output, flat dicts,
        and other common model output variations via _safe_json_loads.

        Returns:
            Parsed tool calls; may be empty when the JSON holds none.

        Raises:
            json.JSONDecodeError: If no JSON payload can be extracted.
            ToolCall construction errors propagate unchanged.
        """
        xlogger.debug(f"JSON Parser: Parsing tool calls ({len(tool_calls_str)} chars)")

        tool_calls = ToolCallProcessor._safe_json_loads(tool_calls_str)
        result = ToolCallProcessor._build_tool_calls(tool_calls)

        xlogger.debug(f"JSON Parser: Successfully parsed {len(result)} tool call(s)")
        return result

    # ------------------------------------------------------------------
    # XML parsing (Qwen3-Coder / GLM-4.5 style)
    # ------------------------------------------------------------------

    @staticmethod
    def from_xml(raw_text: str) -> List[ToolCall]:
        """Parse Qwen3-Coder XML-format tool calls into ToolCall objects.

        Handles:
          - Wrapped: <tool_call><function=name>...</function></tool_call>
          - Bare: <function=name>...</function> (missing wrapper)
          - Multiple sequential tool call blocks, returned in text order
            even when wrapped and bare blocks are interleaved
          - Multi-line parameter values
          - Missing </parameter> closing tags
          - A partial tag left at the very end by a generation cutoff

        NOTE: ``<think>`` blocks are NOT stripped here; a ``<function=...>``
        block inside reasoning text would be parsed as a tool call.

        Returns:
            Parsed tool calls; empty list if no ``<function=...>`` block exists.
        """
        xlogger.debug(f"XML Parser: Parsing tool calls ({len(raw_text)} chars)")

        text = raw_text

        # Stage 1: Drop an incomplete XML tag at the end (generation cutoff)
        stripped_end = text.rstrip()
        if stripped_end.endswith(("<", "</", "<parameter", "<function")):
            xlogger.warning(
                f"XML Parser: Detected incomplete XML tag at end: ...{stripped_end[-80:]}"
            )
            text = re.sub(r"<[^>]*$", "", text)

        # Stage 2: Extract (absolute position, name, body) for each function block.
        # Wrapped <tool_call>...</tool_call> blocks first.
        function_blocks = []
        wrapped_positions = []
        for block_match in TOOL_CALL_BLOCK_RE.finditer(text):
            wrapped_positions.append((block_match.start(), block_match.end()))
            inner_offset = block_match.start(1)
            for func_match in FUNCTION_RE.finditer(block_match.group(1)):
                function_blocks.append(
                    (inner_offset + func_match.start(), func_match.group(1), func_match.group(2))
                )

        # Then bare <function> blocks NOT inside any wrapped region
        for func_match in FUNCTION_RE.finditer(text):
            pos = func_match.start()
            if any(start <= pos < end for start, end in wrapped_positions):
                continue
            xlogger.debug("XML Parser: Found bare <function> block without <tool_call> wrapper")
            function_blocks.append((pos, func_match.group(1), func_match.group(2)))

        if not function_blocks:
            xlogger.warning("XML Parser: No <function=...> blocks found")
            return []

        # Keep the order in which the model emitted the calls
        function_blocks.sort(key=lambda block: block[0])

        # Stage 3: Parse each function block into a ToolCall
        tool_calls = []
        for _, func_name_raw, func_body in function_blocks:
            func_name = func_name_raw.strip()

            params = {}
            for param_match in PARAMETER_RE.finditer(func_body):
                key = param_match.group(1).strip()
                params[key] = _coerce_param_value(param_match.group(2))

            arguments_json = json.dumps(params, ensure_ascii=False)
            tool_calls.append(ToolCall(function=Tool(name=func_name, arguments=arguments_json)))

        xlogger.debug(f"XML Parser: Successfully parsed {len(tool_calls)} tool call(s)")
        return tool_calls

    # ------------------------------------------------------------------
    # Auto-detect parsing (XML markers → JSON → JSON-in-tool_call → XML)
    # ------------------------------------------------------------------

    @staticmethod
    def from_auto(raw_text: str) -> List[ToolCall]:
        """Auto-detect format and parse.

        Tries in order:
          0. XML, when the text contains a <function=...>...</function> block.
             The JSON attempt falls back to parsing JSON *substrings*, so a
             JSON-valued <parameter> (e.g. '{"name": "bob"}' or '[1, 2]')
             would otherwise be mis-read as (or hide) the tool calls.
          1. Pure JSON (standard TabbyAPI / Llama)
          2. JSON inside <tool_call> wrappers (Qwen3-Instruct style)
          3. XML with <function=...> tags (Qwen3-Coder style)

        A JSON attempt that parses but yields no tool calls falls through to
        the next attempt instead of returning an empty list.
        """
        xlogger.debug("Auto Parser: Attempting format auto-detection")

        # Attempt 0: explicit XML function blocks take precedence
        if FUNCTION_RE.search(raw_text):
            result = ToolCallProcessor.from_xml(raw_text)
            if result:
                xlogger.debug("Auto Parser: Detected XML format")
                return result

        # Attempt 1: Pure JSON array
        try:
            result = ToolCallProcessor.from_json(raw_text)
        except (json.JSONDecodeError, ValueError, KeyError) as e:
            xlogger.debug(f"Auto Parser: Not JSON ({e}), trying next format")
        else:
            if result:
                xlogger.debug("Auto Parser: Detected JSON format")
                return result
            xlogger.debug("Auto Parser: JSON yielded no tool calls, trying next format")

        # Attempt 2: JSON inside <tool_call> wrappers (Qwen3-Instruct).
        # Uses the same normalization as from_json, so non-dict items are
        # skipped rather than raising AttributeError.
        try:
            all_tool_calls = []
            for match in TOOL_CALL_BLOCK_RE.finditer(raw_text):
                inner = match.group(1).strip()
                if not inner.startswith(("{", "[")):
                    continue
                normalized = ToolCallProcessor._normalize_tool_calls(json.loads(inner))
                all_tool_calls.extend(ToolCallProcessor._build_tool_calls(normalized))
            if all_tool_calls:
                xlogger.debug(
                    "Auto Parser: Detected JSON-inside-tool_call "
                    f"format ({len(all_tool_calls)} call(s))"
                )
                return all_tool_calls
        except (json.JSONDecodeError, ValueError, KeyError) as e:
            xlogger.debug(
                "Auto Parser: Not JSON-in-tool_call trying XML",
                str(e),
                details=f"({e})",
            )

        # Attempt 3: XML format (Qwen3-Coder style)
        result = ToolCallProcessor.from_xml(raw_text)
        if result:
            xlogger.debug("Auto Parser: Detected XML format")
        else:
            xlogger.warning("Auto Parser: All format detection attempts failed")
        return result

    # ------------------------------------------------------------------
    # Dispatcher
    # ------------------------------------------------------------------

    @staticmethod
    def parse(tool_calls_str: str, tool_format: str = "json") -> List[ToolCall]:
        """Dispatch tool call parsing to the appropriate format handler.

        Args:
            tool_calls_str: Raw tool call text from model generation.
            tool_format: One of ``"json"``, ``"xml"``, ``"auto"``; any other
                value is treated as ``"json"``.

        Returns:
            List of parsed ToolCall objects. Empty list on parse failure
            (never raises).
        """
        try:
            if tool_format == "xml":
                return ToolCallProcessor.from_xml(tool_calls_str)
            elif tool_format == "auto":
                return ToolCallProcessor.from_auto(tool_calls_str)
            else:
                return ToolCallProcessor.from_json(tool_calls_str)
        except Exception as e:
            # Boundary catch: a parse failure must not abort the request
            xlogger.error(
                "ToolCallProcessor.parse: Failed to parse tool calls",
                {"tool_format": tool_format, "e": str(e)},
                details=f"(format={tool_format}): {e}",
            )
            return []

    # ------------------------------------------------------------------
    # Filtering
    # ------------------------------------------------------------------

    @staticmethod
    def filter_by_name(tool_calls: List[ToolCall], function_name: str) -> List[ToolCall]:
        """Filter parsed tool calls to only those matching a function name."""
        filtered = [tc for tc in tool_calls if tc.function.name == function_name]
        if not filtered:
            xlogger.warning(
                f"filter_by_name: No tool calls matched '{function_name}' "
                f"(had {len(tool_calls)} call(s))"
            )
        return filtered

    # ------------------------------------------------------------------
    # Content / tool-call separation
    # ------------------------------------------------------------------

    @staticmethod
    def extract_content_and_tools(
        raw_text: str,
    ) -> Tuple[str, List[ToolCall]]:
        """Separate plain text content from XML tool call blocks.

        Used when the model mixes reasoning text with tool calls, e.g.:
        ``"I'll help with that: <tool_call><function=...>...``

        Text outside wrapped ``<tool_call>`` blocks and bare ``<function>``
        blocks becomes content. Stray ``<tool_call>`` / ``</tool_call>`` tags
        (e.g. an unclosed wrapper) are removed from it, and the remaining
        pieces are whitespace-trimmed and joined with single spaces.

        Returns:
            Tuple of (remaining_content, tool_calls).
        """
        text = raw_text

        # Collect all XML regions to exclude from content: wrapped blocks...
        xml_regions = [(m.start(), m.end()) for m in TOOL_CALL_BLOCK_RE.finditer(text)]

        # ...and bare function blocks not inside wrappers
        for match in FUNCTION_RE.finditer(text):
            pos = match.start()
            if not any(start <= pos < end for start, end in xml_regions):
                xml_regions.append((match.start(), match.end()))

        # Sort and extract content (everything outside XML regions)
        xml_regions.sort()
        content_parts = []
        last_end = 0
        for start, end in xml_regions:
            if start > last_end:
                content_parts.append(text[last_end:start])
            # max() keeps the cursor monotonic should regions ever overlap
            last_end = max(last_end, end)
        content_parts.append(text[last_end:])

        cleaned_parts = (
            ToolCallProcessor._WRAPPER_TAG_RE.sub("", part).strip() for part in content_parts
        )
        content = " ".join(part for part in cleaned_parts if part).strip()

        # Parse tool calls from the full text
        tool_calls = ToolCallProcessor.from_xml(text)

        xlogger.debug(
            f"extract_content_and_tools: Found {len(tool_calls)} tool call(s)",
            {"tool_calls": tool_calls},
            details=f" content={'yes' if content else 'no'} ({len(content)} chars)",
        )

        return content, tool_calls

    # ------------------------------------------------------------------
    # Serialisation helpers (unchanged from original)
    # ------------------------------------------------------------------

    @staticmethod
    def dump(tool_calls: List[ToolCall]) -> List[dict]:
        """
        Convert ToolCall objects to a list of dictionaries.

        Args:
            tool_calls (List[ToolCall]): List of ToolCall objects to convert

        Returns:
            List[dict]: List of dictionaries representing the tool calls.
            Entries whose ``model_dump`` raises JSONDecodeError or
            AttributeError are skipped with a warning.
        """

        # Don't use list comprehension here
        # as that will fail rather than warn
        dumped_tool_calls = []
        for tool_call_obj in tool_calls:
            try:
                dumped_tool_calls.append(tool_call_obj.model_dump())
            except (json.JSONDecodeError, AttributeError) as e:
                xlogger.warning("Error processing tool call:", str(e), details=str(e))
        return dumped_tool_calls

    @staticmethod
    def to_json(tool_calls: List[ToolCall]) -> str:
        """
        Convert ToolCall objects to JSON string representation.

        Args:
            tool_calls (List[ToolCall]): List of ToolCall objects to convert

        Returns:
            str: JSON representation of the tool calls ("" for no calls)
        """

        if not tool_calls:
            return ""

        # Use the dump method to get the list of dictionaries
        dumped_tool_calls = ToolCallProcessor.dump(tool_calls)

        # Serialize the dumped array
        return json.dumps(dumped_tool_calls, indent=2)
|