"""Chat completion utilities for OAI server."""

import asyncio
import json
import pathlib
from asyncio import CancelledError
from typing import List, Optional
from fastapi import HTTPException, Request
from jinja2 import TemplateError
from loguru import logger

from common import model
from common.multimodal import MultimodalEmbeddingWrapper
from common.networking import (
    get_generator_error,
    handle_request_disconnect,
    handle_request_error,
    request_disconnect_loop,
)
from common.utils import unwrap
from endpoints.OAI.types.chat_completion import (
    ChatCompletionLogprobs,
    ChatCompletionLogprob,
    ChatCompletionMessage,
    ChatCompletionRequest,
    ChatCompletionRespChoice,
    ChatCompletionStreamChunk,
    ChatCompletionResponse,
    ChatCompletionStreamChoice,
)
from endpoints.OAI.types.common import UsageStats
from endpoints.OAI.types.tools import NamedToolChoice, ToolCall
from endpoints.OAI.utils.completion import _parse_gen_request_id, _stream_collector
from endpoints.OAI.utils.tools import ToolCallProcessor, TOOL_CALL_SCHEMA


def _serialize_stream_chunk(chunk) -> str:
    """Serialize a streaming chunk with OpenAI-compatible field handling.

    Uses exclude_none=True to strip irrelevant null fields (tool_calls,
    tool_call_id, logprobs, usage) while ensuring finish_reason is always
    present on each choice (as null when not set), matching OpenAI's
    observed streaming behavior.
    """
    d = chunk.model_dump(exclude_none=True)
    for choice in d.get("choices", []):
        if "finish_reason" not in choice:
            choice["finish_reason"] = None
    return json.dumps(d, ensure_ascii=False)
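
# Illustrative only: with the handling above, a serialized delta chunk looks
# roughly like the following (id, model, and content are placeholder values;
# exact fields depend on the chunk):
#   {"id": "chatcmpl-abc123", "choices": [{"index": 0,
#     "delta": {"role": "assistant", "content": "Hi"},
#     "finish_reason": null}], "model": "my-model"}
# Null fields such as logprobs and usage are stripped, but finish_reason is
# kept and emitted explicitly as null until the final chunk.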


def _create_response(
    request_id: str,
    generations: List[dict],
    model_name: Optional[str],
    tool_call_format: str = "json",
    tool_choice=None,
):
    """Create a chat completion response from the provided text."""

    choices = []
    for index, generation in enumerate(generations):
        message = ChatCompletionMessage(
            role="assistant", content=unwrap(generation.get("text"), "")
        )

        tool_calls_raw = generation.get("tool_calls")
        if tool_calls_raw:
            parsed = ToolCallProcessor.parse(tool_calls_raw, format=tool_call_format)
            if parsed and isinstance(tool_choice, NamedToolChoice):
                parsed = ToolCallProcessor.filter_by_name(
                    parsed, tool_choice.function.name
                )
            if parsed:
                message.tool_calls = parsed
            else:
                logger.warning(
                    "Tool call text present but parsing returned no results "
                    f"(format={tool_call_format})"
                )

        # Fallback: detect bare XML tool calls in content that were not
        # caught by the two-pass system (model never emitted tool_start)
        if (
            tool_call_format in ("xml", "auto")
            and not message.tool_calls
            and message.content
            and "<function=" in message.content
        ):
            logger.warning(
                "Fallback: Detected bare XML function blocks in content "
                "(tool_start was likely not emitted by model)"
            )
            remaining, parsed = ToolCallProcessor.extract_content_and_tools(
                message.content
            )
            if parsed:
                message.tool_calls = parsed
                message.content = remaining if remaining else None

        logprob_response = None

        token_probs = unwrap(generation.get("token_probs"), {})
        if token_probs:
            logprobs = unwrap(generation.get("logprobs"), [])

            collected_token_probs = []
            # Use a separate counter so the outer choice index isn't clobbered
            for token_index, token in enumerate(token_probs.keys()):
                top_logprobs = [
                    ChatCompletionLogprob(token=token, logprob=logprob)
                    for token, logprob in logprobs[token_index].items()
                ]

                collected_token_probs.append(
                    ChatCompletionLogprob(
                        token=token,
                        logprob=token_probs[token],
                        top_logprobs=top_logprobs,
                    )
                )

            logprob_response = ChatCompletionLogprobs(content=collected_token_probs)

        # Set finish reason
        if message.tool_calls:
            finish_reason = "tool_calls"
        else:
            finish_reason = generation.get("finish_reason", "stop")

        choice = ChatCompletionRespChoice(
            index=index,
            finish_reason=finish_reason,
            stop_str=generation.get("stop_str"),
            message=message,
            logprobs=logprob_response,
        )

        choices.append(choice)

    final_generation = generations[-1]
    prompt_tokens = unwrap(final_generation.get("prompt_tokens"), 0)
    completion_tokens = unwrap(final_generation.get("gen_tokens"), 0)

    response = ChatCompletionResponse(
        id=f"cmpl-{request_id}",
        choices=choices,
        model=model_name,
        usage=UsageStats(
            prompt_tokens=prompt_tokens,
            prompt_time=final_generation.get("prompt_time"),
            prompt_tokens_per_sec=final_generation.get("prompt_tokens_per_sec"),
            completion_tokens=completion_tokens,
            completion_time=final_generation.get("gen_time"),
            completion_tokens_per_sec=final_generation.get("gen_tokens_per_sec"),
            total_tokens=prompt_tokens + completion_tokens,
            total_time=final_generation.get("total_time"),
        ),
    )

    return response
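
# Illustrative only: the backend generation dicts consumed above are assumed to
# look roughly like this (keys inferred from the .get() calls in this module;
# the values are placeholders):
#   {"text": "Hello!", "finish_reason": "stop", "stop_str": None,
#    "prompt_tokens": 12, "gen_tokens": 3, "prompt_time": 0.05,
#    "gen_time": 0.12, "request_id": "abc123-0"}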


def _create_stream_chunk(
    request_id: str,
    generation: Optional[dict] = None,
    model_name: Optional[str] = None,
    is_usage_chunk: bool = False,
):
    """Create a chat completion stream chunk from the provided text.

    Note: Tool-call streaming is handled separately by
    _build_tool_call_chunks(), which emits the two-chunk
    OpenAI-compatible sequence.
    """

    index = generation.get("index")
    choices = []
    usage_stats = None

    if is_usage_chunk:
        prompt_tokens = unwrap(generation.get("prompt_tokens"), 0)
        completion_tokens = unwrap(generation.get("gen_tokens"), 0)

        usage_stats = UsageStats(
            prompt_tokens=prompt_tokens,
            prompt_time=generation.get("prompt_time"),
            prompt_tokens_per_sec=generation.get("prompt_tokens_per_sec"),
            completion_tokens=completion_tokens,
            completion_time=generation.get("gen_time"),
            completion_tokens_per_sec=generation.get("gen_tokens_per_sec"),
            total_tokens=prompt_tokens + completion_tokens,
            total_time=generation.get("total_time"),
        )
    elif "finish_reason" in generation:
        finish_reason = generation.get("finish_reason")
        choice = ChatCompletionStreamChoice(
            index=index, finish_reason=finish_reason, delta={}
        )
        choices.append(choice)
    else:
        message = ChatCompletionMessage(
            role="assistant", content=unwrap(generation.get("text"), "")
        )

        logprob_response = None

        token_probs = unwrap(generation.get("token_probs"), {})
        if token_probs:
            logprobs = unwrap(generation.get("logprobs"), {})
            top_logprobs = [
                ChatCompletionLogprob(token=token, logprob=logprob)
                for token, logprob in logprobs.items()
            ]

            generated_token = next(iter(token_probs))
            token_prob_response = ChatCompletionLogprob(
                token=generated_token,
                logprob=token_probs[generated_token],
                top_logprobs=top_logprobs,
            )

            logprob_response = ChatCompletionLogprobs(content=[token_prob_response])

        choice = ChatCompletionStreamChoice(
            index=index,
            delta=message,
            logprobs=logprob_response,
        )

        choices.append(choice)

    chunk = ChatCompletionStreamChunk(
        id=f"chatcmpl-{request_id}",
        choices=choices,
        model=unwrap(model_name, ""),
        usage=usage_stats,
    )

    return chunk


def _build_tool_call_chunks(
    tool_calls: List[ToolCall],
    request_id: str,
    model_name: str,
) -> List[ChatCompletionStreamChunk]:
    """Build the OpenAI-standard streaming sequence for tool calls.

    Emits two chunks:
    1. Tool-call chunk: role="assistant", complete tool_calls with
       index/id/type/name/arguments (all data in one chunk).
    2. Finish chunk: empty delta, finish_reason="tool_calls".

    Complete arguments are sent in a single chunk rather than streamed
    incrementally, which is valid per OpenAI's spec (clients concatenate
    argument strings across deltas) and maximizes compatibility with
    clients that may not implement multi-chunk tool-call assembly.

    The tool_calls are placed directly into a ChatCompletionMessage
    (not a raw dict) so Pydantic validates them as ToolCall objects
    with the index field preserved (ToolCall declares index as Optional[int]).
    """

    chunk_id = f"chatcmpl-{request_id}"

    # Set index on each tool call for streaming
    for idx, tc in enumerate(tool_calls):
        tc.index = idx

    # Chunk 1: Complete tool call data
    tool_call_message = ChatCompletionMessage(
        role="assistant",
        tool_calls=tool_calls,
    )
    tool_chunk = ChatCompletionStreamChunk(
        id=chunk_id,
        choices=[
            ChatCompletionStreamChoice(
                index=0,
                delta=tool_call_message,
                finish_reason=None,
            )
        ],
        model=model_name,
    )

    # Chunk 2: Finish signal
    # Use model_construct to prevent Pydantic's smart Union from
    # coercing the empty dict {} into ChatCompletionMessage(role="user")
    finish_choice = ChatCompletionStreamChoice.model_construct(
        index=0,
        delta={},
        finish_reason="tool_calls",
        logprobs=None,
    )
    finish_chunk = ChatCompletionStreamChunk(
        id=chunk_id,
        choices=[finish_choice],
        model=model_name,
    )

    return [tool_chunk, finish_chunk]
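
# Illustrative only: assuming the endpoint wraps each yielded string in an SSE
# "data:" line, a client would see roughly this two-chunk sequence (ids, the
# function name, and arguments are placeholders):
#   data: {"id": "chatcmpl-abc", "choices": [{"index": 0,
#          "delta": {"role": "assistant", "tool_calls": [{"index": 0,
#          "id": "call_1", "type": "function", "function": {"name": "get_weather",
#          "arguments": "{\"city\": \"Paris\"}"}}]}, "finish_reason": null}], ...}
#   data: {"id": "chatcmpl-abc", "choices": [{"index": 0, "delta": {},
#          "finish_reason": "tool_calls"}], ...}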


async def _append_template_metadata(data: ChatCompletionRequest, template_vars: dict):
    """Adding metadata is a one-time process."""

    template_metadata = await model.container.prompt_template.extract_metadata(
        template_vars
    )

    # Stop strings
    if isinstance(data.stop, str):
        data.stop = [data.stop] + template_metadata.stop_strings
    else:
        data.stop.extend(template_metadata.stop_strings)

    # If a tool start string is present, append it to the stop strings
    if template_metadata.tool_start:
        data.stop.append(template_metadata.tool_start)


async def format_messages_with_template(
    messages: List[ChatCompletionMessage],
    existing_template_vars: Optional[dict] = None,
):
    """Barebones function to format chat completion messages into a prompt."""

    template_vars = unwrap(existing_template_vars, {})
    mm_embeddings = MultimodalEmbeddingWrapper() if model.container.use_vision else None

    # Convert all messages to a dictionary representation
    message_dicts: List[dict] = []
    for message in messages:
        if isinstance(message.content, list):
            concatenated_content = ""
            for content in message.content:
                if content.type == "text":
                    concatenated_content += content.text
                elif content.type == "image_url" and mm_embeddings:
                    await mm_embeddings.add(content.image_url.url)
                    concatenated_content += mm_embeddings.text_alias[-1]

            # Convert the message content into a concatenated string
            message.content = concatenated_content

        message_dicts.append(message.model_dump(exclude_none=True))

    # Pre-template: convert tool_call arguments from JSON strings to dicts.
    # OpenAI-compatible clients (Kilo, Roo, etc.) send arguments as JSON
    # strings per the OAI spec, but Qwen3-Coder's template calls
    # .items() on arguments, which requires a dict/mapping.
    for msg in message_dicts:
        if msg.get("tool_calls"):
            for tc in msg["tool_calls"]:
                func = tc.get("function", {})
                args = func.get("arguments")
                if isinstance(args, str):
                    try:
                        func["arguments"] = json.loads(args)
                    except (json.JSONDecodeError, ValueError):
                        logger.warning(
                            "Failed to parse tool_call arguments JSON "
                            "string to dict, keeping as string"
                        )

    # Get all special tokens
    special_tokens_dict = model.container.get_special_tokens()

    template_vars.update({"messages": message_dicts, **special_tokens_dict})

    prompt = await model.container.prompt_template.render(template_vars)
    return prompt, mm_embeddings, template_vars
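
# Illustrative only (function name and values are placeholders): the
# normalization above turns a historical assistant tool call such as
#   {"function": {"name": "read_file", "arguments": "{\"path\": \"a.txt\"}"}}
# into
#   {"function": {"name": "read_file", "arguments": {"path": "a.txt"}}}
# so templates that iterate over arguments.items() render without error.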


async def apply_chat_template(data: ChatCompletionRequest):
    """
    Compile the prompt and get any additional stop strings from the template.
    Template stop strings can be overridden by sampler overrides if force is true.
    """

    # Locally store tools dict
    tools = data.model_dump()["tools"]

    try:
        data.template_vars.update(
            {
                "add_generation_prompt": data.add_generation_prompt,
                "tools": tools,
                "functions": data.functions,
            }
        )

        prompt, mm_embeddings, template_vars = await format_messages_with_template(
            data.messages, data.template_vars
        )

        # Append response prefix if present
        if data.response_prefix:
            if data.add_generation_prompt:
                prompt += data.response_prefix
            else:
                logger.warning(
                    "Could not add response prefix because "
                    "add_generation_prompt is False"
                )

        # Remove the starting BOS token if the model adds one
        # This prevents add_bos_token from adding multiple BOS tokens
        bos_token = template_vars.get("bos_token")
        if (
            bos_token
            and model.container.hf_model.add_bos_token()
            and prompt.startswith(bos_token)
        ):
            prompt = prompt.removeprefix(bos_token)

        # Add template metadata
        await _append_template_metadata(data, template_vars)

        return prompt, mm_embeddings
    except KeyError as exc:
        error_message = handle_request_error(
            "Could not find a Conversation from prompt template "
            f"'{model.container.prompt_template.name}'. "
            "Check your spelling?",
        ).error.message

        raise HTTPException(400, error_message) from exc
    except TemplateError as exc:
        error_message = handle_request_error(f"TemplateError: {str(exc)}").error.message

        raise HTTPException(400, error_message) from exc


async def stream_generate_chat_completion(
    prompt: str,
    embeddings: MultimodalEmbeddingWrapper,
    data: ChatCompletionRequest,
    request: Request,
    model_path: pathlib.Path,
):
    """Generator for the generation process."""
    abort_event = asyncio.Event()
    gen_queue = asyncio.Queue()
    gen_tasks: List[asyncio.Task] = []
    tool_start = model.container.prompt_template.metadata.tool_start
    tool_call_format = model.container.prompt_template.metadata.tool_call_format
    disconnect_task = asyncio.create_task(request_disconnect_loop(request))

    try:
        logger.info(f"Received chat completion streaming request {request.state.id}")

        for idx in range(0, data.n):
            task_gen_params = data.model_copy(deep=True)
            request_id = _parse_gen_request_id(data.n, request.state.id, idx)

            gen_task = asyncio.create_task(
                _stream_collector(
                    idx,
                    gen_queue,
                    request_id,
                    prompt,
                    task_gen_params,
                    abort_event,
                    mm_embeddings=embeddings,
                )
            )

            gen_tasks.append(gen_task)

        # Text accumulation for tool calls
        current_generation_text = ""

        # Consumer loop
        while True:
            # Fast path: items already queued, no extra task overhead
            if not gen_queue.empty():
                generation = gen_queue.get_nowait()
            else:
                # Slow path: queue empty, race the get against disconnect
                get_task = asyncio.create_task(gen_queue.get())
                done, _ = await asyncio.wait(
                    [get_task, disconnect_task],
                    return_when=asyncio.FIRST_COMPLETED,
                )
                if disconnect_task in done:
                    get_task.cancel()
                    raise CancelledError()
                generation = get_task.result()

            if disconnect_task.done():
                raise CancelledError()

            # Handle options if a tool model is present
            if tool_start and data.tool_choice != "none":
                if "stop_str" in generation:
                    generations = await generate_tool_calls(
                        prompt,
                        embeddings,
                        data,
                        [generation],
                        request,
                    )

                    # Only one generation present in this case
                    generation = generations[0]

                    # Emit the two-chunk tool-call streaming sequence
                    if "tool_calls" in generation:
                        tool_calls_raw = generation["tool_calls"]
                        parsed = ToolCallProcessor.parse(
                            tool_calls_raw, format=tool_call_format
                        )
                        if parsed and isinstance(data.tool_choice, NamedToolChoice):
                            parsed = ToolCallProcessor.filter_by_name(
                                parsed, data.tool_choice.function.name
                            )
                        if parsed:
                            for tc_chunk in _build_tool_call_chunks(
                                parsed,
                                request.state.id,
                                model_path.name,
                            ):
                                yield _serialize_stream_chunk(tc_chunk)

                    # Handle completion and usage after tool calls
                    if (
                        all(task.done() for task in gen_tasks)
                        and gen_queue.empty()
                    ):
                        if (
                            data.stream_options
                            and data.stream_options.include_usage
                        ):
                            usage_chunk = _create_stream_chunk(
                                request.state.id,
                                generation,
                                model_path.name,
                                is_usage_chunk=True,
                            )
                            yield _serialize_stream_chunk(usage_chunk)

                        logger.info(
                            "Finished chat completion streaming "
                            f"request {request.state.id}"
                        )
                        yield "[DONE]"
                        break

                    continue

                elif "text" in generation:
                    current_generation_text += generation["text"]

            # Stream collector will push an exception to the queue if it fails
            if isinstance(generation, Exception):
                raise generation

            response = _create_stream_chunk(
                request.state.id,
                generation,
                model_path.name,
            )
            yield _serialize_stream_chunk(response)

            # Check if all tasks are completed
            if all(task.done() for task in gen_tasks) and gen_queue.empty():
                # Send a usage chunk
                if data.stream_options and data.stream_options.include_usage:
                    usage_chunk = _create_stream_chunk(
                        request.state.id,
                        generation,
                        model_path.name,
                        is_usage_chunk=True,
                    )
                    yield _serialize_stream_chunk(usage_chunk)

                logger.info(
                    f"Finished chat completion streaming request {request.state.id}"
                )

                yield "[DONE]"
                break
    except CancelledError:
        # Get out if the request gets disconnected
        handle_request_disconnect("Chat completion generation cancelled by user.")
    except Exception:
        yield get_generator_error(
            "Chat completion aborted. Please check the server console."
        )
    finally:
        abort_event.set()
        disconnect_task.cancel()


async def generate_chat_completion(
    prompt: str,
    embeddings: MultimodalEmbeddingWrapper,
    data: ChatCompletionRequest,
    request: Request,
    model_path: pathlib.Path,
):
    gen_tasks: List[asyncio.Task] = []
    tool_start = model.container.prompt_template.metadata.tool_start
    tool_call_format = model.container.prompt_template.metadata.tool_call_format

    try:
        logger.info(f"Received chat completion request {request.state.id}")

        for idx in range(0, data.n):
            request_id = _parse_gen_request_id(data.n, request.state.id, idx)

            gen_tasks.append(
                asyncio.create_task(
                    model.container.generate(
                        request_id,
                        prompt,
                        data,
                        mm_embeddings=embeddings,
                    )
                )
            )

        generations = await asyncio.gather(*gen_tasks)

        # Check all the generations and see if a tool call is required
        force_tool_pass = data.tool_choice == "required" or isinstance(
            data.tool_choice, NamedToolChoice
        )
        if tool_start or force_tool_pass:
            generations = await generate_tool_calls(
                prompt, embeddings, data, generations, request
            )

        response = _create_response(
            request.state.id,
            generations,
            model_path.name,
            tool_call_format=tool_call_format,
            tool_choice=data.tool_choice,
        )

        logger.info(f"Finished chat completion request {request.state.id}")

        return response
    except Exception as exc:
        error_message = handle_request_error(
            f"Chat completion {request.state.id} aborted. "
            "Maybe the model was unloaded? "
            "Please check the server console."
        ).error.message

        # Server error if there's a generation exception
        raise HTTPException(503, error_message) from exc


async def generate_tool_calls(
    prompt: str,
    embeddings: MultimodalEmbeddingWrapper,
    data: ChatCompletionRequest,
    generations: List[dict],
    request: Request,
):
    gen_tasks: List[asyncio.Task] = []
    tool_start = model.container.prompt_template.metadata.tool_start
    tool_call_format = model.container.prompt_template.metadata.tool_call_format
    tool_choice = data.tool_choice

    if tool_choice == "none":
        return generations

    # Tracks which generations asked for a tool call
    tool_idx: List[int] = []

    # Copy to make sure the parent JSON schema doesn't get modified
    tool_data = data.model_copy(deep=True)

    if tool_call_format in ("xml", "auto"):
        # XML / auto mode: let the model generate its natural output
        # without a JSON schema constraint
        logger.debug(
            f"generate_tool_calls: Using '{tool_call_format}' mode "
            f"(no JSON schema constraint)"
        )

        # Remove tool_start from stop strings so the model can emit
        # multiple sequential <tool_call> blocks without stopping early
        if (
            tool_start
            and isinstance(tool_data.stop, list)
            and tool_start in tool_data.stop
        ):
            tool_data.stop = [s for s in tool_data.stop if s != tool_start]
            logger.debug(
                f"generate_tool_calls: Removed '{tool_start}' from "
                f"second-pass stop strings"
            )
    else:
        # JSON mode: constrained generation (existing behavior)
        tool_data.json_schema = TOOL_CALL_SCHEMA

    for idx, gen in enumerate(generations):
        stop_str = gen.get("stop_str")
        should_generate = stop_str == tool_start

        # Force tool generation if tool_choice requires it
        if not should_generate and (
            tool_choice == "required" or isinstance(tool_choice, NamedToolChoice)
        ):
            should_generate = True

        if not should_generate:
            continue

        logger.info(
            f"Detected tool call in chat completion request "
            f"{request.state.id} (format={tool_call_format})"
        )

        # Build a per-generation prompt (avoid mutating the shared prompt)
        tool_prompt = prompt
        precursor_text = gen.get("full_text")
        if precursor_text:
            tool_prompt = tool_prompt + precursor_text

        # For XML/auto mode: append tool_start back to the prompt.
        # The stop string was consumed by the first pass and not included
        # in full_text, but the model expects to continue after <tool_call>.
        # Include a trailing newline to match the canonical template format.
        if tool_call_format in ("xml", "auto") and tool_start:
            tool_prompt = tool_prompt + tool_start + "\n"

        gen_request_id = gen.get("request_id")
        tool_request_id = f"{gen_request_id}-tool"

        gen_tasks.append(
            asyncio.create_task(
                model.container.generate(
                    tool_request_id,
                    tool_prompt,
                    tool_data,
                    mm_embeddings=embeddings,
                )
            )
        )

        tool_idx.append(idx)

    if len(tool_idx) > 0:
        tool_calls = await asyncio.gather(*gen_tasks)

        # Map tool calls to their appropriate generation
        for gen_idx, tool_call in zip(tool_idx, tool_calls, strict=True):
            raw_text = tool_call["text"]

            if tool_call_format in ("xml", "auto"):
                # Prepend tool_start to reconstruct complete XML for the parser
                raw_text = tool_start + "\n" + raw_text

            generations[gen_idx]["tool_calls"] = raw_text

    return generations