mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-03-14 15:57:27 +00:00
API: Don't do a second re-render when tool calling
Re-rendering the template is an expensive operation, and it is unnecessary here because the prompt and the current generation text can simply be concatenated.

Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
This commit is contained in:
@@ -361,6 +361,8 @@ async def stream_generate_chat_completion(
|
||||
if tool_start:
|
||||
if "stop_str" in generation:
|
||||
generations = await generate_tool_calls(
|
||||
prompt,
|
||||
embeddings,
|
||||
data,
|
||||
[generation],
|
||||
request,
|
||||
@@ -442,7 +444,9 @@ async def generate_chat_completion(
|
||||
|
||||
# Check all the generations and see if a tool call is required
|
||||
if tool_start:
|
||||
generations = await generate_tool_calls(data, generations, request)
|
||||
generations = await generate_tool_calls(
|
||||
prompt, embeddings, data, generations, request
|
||||
)
|
||||
|
||||
response = _create_response(request.state.id, generations, model_path.name)
|
||||
|
||||
@@ -461,6 +465,8 @@ async def generate_chat_completion(
|
||||
|
||||
|
||||
async def generate_tool_calls(
|
||||
prompt: str,
|
||||
embeddings: MultimodalEmbeddingWrapper,
|
||||
data: ChatCompletionRequest,
|
||||
generations: List[str],
|
||||
request: Request,
|
||||
@@ -482,12 +488,10 @@ async def generate_tool_calls(
|
||||
|
||||
logger.info(f"Detected tool call in chat completion request {request.state.id}")
|
||||
|
||||
# Append the existing generation as part of the response prefix
|
||||
# Append the existing generation text if present
|
||||
precursor_text = current_generation_text or gen.get("text")
|
||||
if precursor_text:
|
||||
tool_data.response_prefix = precursor_text
|
||||
|
||||
pre_tool_prompt, embeddings = await apply_chat_template(tool_data)
|
||||
prompt = prompt + precursor_text
|
||||
|
||||
gen_request_id = _parse_gen_request_id(data.n, request.state.id, idx)
|
||||
tool_request_id = f"{gen_request_id}-tool"
|
||||
@@ -496,7 +500,7 @@ async def generate_tool_calls(
|
||||
asyncio.create_task(
|
||||
model.container.generate(
|
||||
tool_request_id,
|
||||
pre_tool_prompt,
|
||||
prompt,
|
||||
tool_data,
|
||||
mm_embeddings=embeddings,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user