API: Fix chat completion formatting flow

Previously, the flow for parsing chat completion messages and rendering
them from the prompt template was disconnected between endpoints. Now,
create a common function that renders the prompt and handles everything
appropriately afterwards.

Signed-off-by: kingbri <bdashore3@proton.me>
Author: kingbri
Date: 2024-11-21 17:51:14 -05:00
Parent: c652a6e030
Commit: 902045edbb
6 changed files with 92 additions and 115 deletions
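
The common render function itself lives in the other changed files and is
not part of the hunks shown below. As a rough, hypothetical sketch of the
shape the commit message describes (the name format_messages_with_template
and the prompt_template.render call are assumptions, not the actual API):

# Hypothetical sketch only; the real helper is in the endpoint/common
# files among the six changed, which this view does not show.
def format_messages_with_template(messages: list, prompt_template) -> str:
    """Render chat completion messages through one shared path."""
    # Flatten structured content parts into plain text for the template
    for message in messages:
        content = message.get("content")
        if isinstance(content, list):
            message["content"] = "\n".join(
                part.get("text", "") for part in content
                if part.get("type") == "text"
            )

    # Every endpoint renders through the same template object afterwards
    return prompt_template.render({"messages": messages})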

@@ -31,7 +31,6 @@ from exllamav2.generator import (
 )
 from itertools import zip_longest
 from loguru import logger
-from PIL import Image
 from typing import List, Optional, Union
 from ruamel.yaml import YAML
@@ -374,6 +373,8 @@ class ExllamaV2Container:
         self.draft_config.max_input_len = chunk_size
         self.draft_config.max_attention_size = chunk_size**2
 
+        self.prompt_template = None
+
         # Return the created instance
         return self
@@ -875,17 +876,18 @@ class ExllamaV2Container:
async with self.load_condition:
self.load_condition.notify_all()
def encode_tokens(
self, text: str, embeddings: MultimodalEmbeddingWrapper, **kwargs
):
def encode_tokens(self, text: str, **kwargs):
"""Wrapper to encode tokens from a text string."""
mm_embeddings: MultimodalEmbeddingWrapper = kwargs.get("embeddings")
mm_embeddings_content = mm_embeddings.content if mm_embeddings else []
return (
self.tokenizer.encode(
text,
add_bos=unwrap(kwargs.get("add_bos_token"), True),
encode_special_tokens=unwrap(kwargs.get("encode_special_tokens"), True),
embeddings=embeddings.content,
embeddings=mm_embeddings_content,
)
.flatten()
.tolist()
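
With embeddings moved into kwargs, text-only callers drop the wrapper
entirely and multimodal callers pass it by keyword. A minimal usage sketch,
where container and wrapper are assumed names for an ExllamaV2Container
instance and a MultimodalEmbeddingWrapper:

# Text-only: no wrapper argument needed anymore
ids = container.encode_tokens("Hello there")

# Multimodal: the wrapper is optional and keyword-only now
ids = container.encode_tokens(
    "Describe this image",
    embeddings=wrapper,
    add_bos_token=False,
)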
@@ -931,7 +933,6 @@ class ExllamaV2Container:
     async def generate(
         self,
         prompt: str,
-        embeddings: MultimodalEmbeddingWrapper,
         request_id: str,
         abort_event: asyncio.Event = None,
         **kwargs,
@@ -939,7 +940,7 @@ class ExllamaV2Container:
         """Generate a response to a prompt."""
 
         generations = []
         async for generation in self.generate_gen(
-            prompt, embeddings, request_id, abort_event, **kwargs
+            prompt, request_id, abort_event, **kwargs
         ):
             generations.append(generation)
@@ -1005,7 +1006,6 @@ class ExllamaV2Container:
     async def generate_gen(
         self,
         prompt: str,
-        embeddings: MultimodalEmbeddingWrapper,
         request_id: str,
         abort_event: Optional[asyncio.Event] = None,
         **kwargs,
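
generate and generate_gen lose the positional parameter the same way; the
wrapper simply rides along in **kwargs until generate_gen unpacks it once
(next hunk). A hedged sketch of an endpoint-side call, with the surrounding
variable names assumed:

# Hypothetical call site after the signature change
response = await container.generate(
    prompt=rendered_prompt,
    request_id=request_id,
    abort_event=abort_event,
    embeddings=mm_embeddings,  # optional; omit for text-only requests
)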
@@ -1270,13 +1270,17 @@ class ExllamaV2Container:
             else:
                 stop_conditions += eos_tokens
 
+        # Get multimodal embeddings if present
+        mm_embeddings: MultimodalEmbeddingWrapper = kwargs.get("embeddings")
+        mm_embeddings_content = mm_embeddings.content if mm_embeddings else []
+
         # Encode both positive and negative prompts
         input_ids = [
             self.tokenizer.encode(
                 prompt,
                 add_bos=add_bos_token,
                 encode_special_tokens=True,
-                embeddings=embeddings.content,
+                embeddings=mm_embeddings_content,
             )
             for prompt in prompts
         ]
@@ -1327,7 +1331,7 @@ class ExllamaV2Container:
banned_strings=banned_strings,
token_healing=token_healing,
identifier=job_id,
embeddings=embeddings.content,
embeddings=mm_embeddings_content,
)
# Save generated tokens and full response