Mirror of https://github.com/theroyallab/tabbyAPI.git (synced 2026-04-21 23:09:13 +00:00)
API: Fix chat completion formatting flow
Previously, the flow for parsing chat completion messages and rendering from the prompt template was disconnected between endpoints. Now, create a common function to render and handle everything appropriately afterwards.

Signed-off-by: kingbri <bdashore3@proton.me>
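The commit message describes the shape of the fix rather than the new helper itself. As a rough sketch of the idea, assuming a Jinja2-style template object and OpenAI-style multimodal message parts (every name below is illustrative, not tabbyAPI's actual API), a single shared function can flatten message content, collect image references, and render the prompt once for all chat completion endpoints:

```python
# Hypothetical sketch of the "common function" idea; names are illustrative.
from typing import Any, Dict, List, Tuple


def render_chat_prompt(
    messages: List[Dict[str, Any]],
    prompt_template: Any,  # assumed Jinja2-style object with .render()
) -> Tuple[str, List[Any]]:
    """Flatten multimodal message parts, collect image references, and
    render the prompt template once for every chat completion endpoint."""

    embeddings: List[Any] = []
    for message in messages:
        content = message.get("content")
        if isinstance(content, list):
            text_parts = []
            for part in content:
                if part.get("type") == "image_url":
                    # Defer actual image fetching/encoding to the caller
                    embeddings.append(part["image_url"])
                else:
                    text_parts.append(part.get("text", ""))
            message["content"] = "\n".join(text_parts)

    prompt = prompt_template.render(messages=messages)
    return prompt, embeddings
```

Both the streaming and non-streaming handlers can then call this one function, which is the "common function" the commit message refers to; the diff below shows the backend side of that rewiring.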
@@ -31,7 +31,6 @@ from exllamav2.generator import (
 )
 from itertools import zip_longest
 from loguru import logger
-from PIL import Image
 from typing import List, Optional, Union
 
 from ruamel.yaml import YAML
@@ -374,6 +373,8 @@ class ExllamaV2Container:
             self.draft_config.max_input_len = chunk_size
             self.draft_config.max_attention_size = chunk_size**2
 
+        self.prompt_template = None
+
         # Return the created instance
         return self
 
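Initializing `self.prompt_template` to `None` during container creation gives later lookups a defined default instead of an `AttributeError` when no template has been loaded yet. A minimal sketch of the pattern (the surrounding container API is assumed):

```python
# Minimal sketch: default the attribute up front so later checks are safe
class Container:
    def __init__(self):
        self.prompt_template = None  # populated later if a template loads


container = Container()
if container.prompt_template is None:
    print("No prompt template loaded; the endpoint can fall back or error out")
```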
@@ -875,17 +876,18 @@ class ExllamaV2Container:
         async with self.load_condition:
             self.load_condition.notify_all()
 
-    def encode_tokens(
-        self, text: str, embeddings: MultimodalEmbeddingWrapper, **kwargs
-    ):
+    def encode_tokens(self, text: str, **kwargs):
         """Wrapper to encode tokens from a text string."""
 
+        mm_embeddings: MultimodalEmbeddingWrapper = kwargs.get("embeddings")
+        mm_embeddings_content = mm_embeddings.content if mm_embeddings else []
+
         return (
             self.tokenizer.encode(
                 text,
                 add_bos=unwrap(kwargs.get("add_bos_token"), True),
                 encode_special_tokens=unwrap(kwargs.get("encode_special_tokens"), True),
-                embeddings=embeddings.content,
+                embeddings=mm_embeddings_content,
             )
             .flatten()
             .tolist()
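After this change, `encode_tokens` pulls the wrapper out of `**kwargs` and tolerates its absence, so text-only callers no longer need to construct an empty `MultimodalEmbeddingWrapper`. A hedged usage sketch (the `container` object and wrapper value are assumed, not shown in the diff):

```python
# Text-only call: no embeddings argument required anymore
ids = container.encode_tokens("Hello, world!", add_bos_token=True)

# Multimodal call: the wrapper now travels through **kwargs instead of a
# required positional parameter (mm_embeddings: MultimodalEmbeddingWrapper)
ids = container.encode_tokens(
    rendered_prompt,
    embeddings=mm_embeddings,
    encode_special_tokens=True,
)
```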
@@ -931,7 +933,6 @@ class ExllamaV2Container:
     async def generate(
         self,
         prompt: str,
-        embeddings: MultimodalEmbeddingWrapper,
         request_id: str,
         abort_event: asyncio.Event = None,
         **kwargs,
@@ -939,7 +940,7 @@ class ExllamaV2Container:
         """Generate a response to a prompt."""
         generations = []
         async for generation in self.generate_gen(
-            prompt, embeddings, request_id, abort_event, **kwargs
+            prompt, request_id, abort_event, **kwargs
         ):
             generations.append(generation)
 
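Dropping `embeddings` from the `generate` signature (and from `generate_gen`'s, in the next hunk) is safe because `generate` forwards `**kwargs` verbatim to `generate_gen`: the wrapper rides along in `kwargs` until the generator extracts it. A self-contained sketch of that forwarding pattern (names illustrative):

```python
import asyncio


async def generate_gen(prompt: str, **kwargs):
    # The downstream consumer pulls optional extras out of kwargs
    embeddings = kwargs.get("embeddings")
    yield f"prompt={prompt!r}, embeddings={embeddings!r}"


async def generate(prompt: str, **kwargs):
    # No embeddings parameter needed here; **kwargs carries it through
    generations = []
    async for generation in generate_gen(prompt, **kwargs):
        generations.append(generation)
    return generations


print(asyncio.run(generate("hi", embeddings=["<image_0>"])))
```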
@@ -1005,7 +1006,6 @@ class ExllamaV2Container:
     async def generate_gen(
         self,
         prompt: str,
-        embeddings: MultimodalEmbeddingWrapper,
         request_id: str,
         abort_event: Optional[asyncio.Event] = None,
         **kwargs,
@@ -1270,13 +1270,17 @@ class ExllamaV2Container:
         else:
             stop_conditions += eos_tokens
 
+        # Get multimodal embeddings if present
+        mm_embeddings: MultimodalEmbeddingWrapper = kwargs.get("embeddings")
+        mm_embeddings_content = mm_embeddings.content if mm_embeddings else []
+
         # Encode both positive and negative prompts
         input_ids = [
             self.tokenizer.encode(
                 prompt,
                 add_bos=add_bos_token,
                 encode_special_tokens=True,
-                embeddings=embeddings.content,
+                embeddings=mm_embeddings_content,
             )
             for prompt in prompts
         ]
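The `content if mm_embeddings else []` guard is what keeps text-only requests working: `kwargs.get("embeddings")` returns `None` when no wrapper was passed, and an empty list is a harmless value to hand to the tokenizer. The same guard in isolation (the wrapper's shape is assumed from its usage here):

```python
from dataclasses import dataclass, field
from typing import Any, List, Optional


@dataclass
class MultimodalEmbeddingWrapper:
    # Assumed shape: the real class carries backend embedding objects
    content: List[Any] = field(default_factory=list)


def embeddings_content(wrapper: Optional[MultimodalEmbeddingWrapper]) -> List[Any]:
    # None (text-only request) collapses to an empty, always-safe list
    return wrapper.content if wrapper else []


assert embeddings_content(None) == []
assert embeddings_content(MultimodalEmbeddingWrapper(["e1"])) == ["e1"]
```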
@@ -1327,7 +1331,7 @@ class ExllamaV2Container:
             banned_strings=banned_strings,
             token_healing=token_healing,
             identifier=job_id,
-            embeddings=embeddings.content,
+            embeddings=mm_embeddings_content,
         )
 
         # Save generated tokens and full response