Templates: Switch to Jinja2

Jinja2 is a lightweight template engine that Transformers uses to
render chat completion prompts. It's much more efficient than FastChat
and can be installed as part of the requirements.

This also allows for unblocking Pydantic's version.

Users now have to provide their own template if needed. A separate
repo may be used to store common prompt templates.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri
2023-12-17 00:41:42 -05:00
committed by Brian Dashore
parent 95fd0f075e
commit f631dd6ff7
14 changed files with 115 additions and 74 deletions

19
main.py
View File

@@ -27,10 +27,10 @@ from OAI.utils import (
create_completion_response,
get_model_list,
get_lora_list,
get_chat_completion_prompt,
create_chat_completion_response,
create_chat_completion_stream_chunk
)
from templating import get_prompt_from_template
from utils import get_generator_error, get_sse_packet, load_progress, unwrap
app = FastAPI()
@@ -76,6 +76,7 @@ async def list_models():
@app.get("/v1/internal/model/info", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
async def get_current_model():
model_name = model_container.get_model_path().name
prompt_template = model_container.prompt_template
model_card = ModelCard(
id = model_name,
parameters = ModelCardParameters(
@@ -83,7 +84,7 @@ async def get_current_model():
rope_alpha = model_container.config.scale_alpha_value,
max_seq_len = model_container.config.max_seq_len,
cache_mode = "FP8" if model_container.cache_fp8 else "FP16",
prompt_template = unwrap(model_container.prompt_template, "auto")
prompt_template = prompt_template.name if prompt_template else None
),
logging = gen_logging.config
)
@@ -302,19 +303,21 @@ async def generate_completion(request: Request, data: CompletionRequest):
# Chat completions endpoint
@app.post("/v1/chat/completions", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
async def generate_chat_completion(request: Request, data: ChatCompletionRequest):
if model_container.prompt_template is None:
return HTTPException(422, "This endpoint is disabled because a prompt template is not set.")
model_path = model_container.get_model_path()
if isinstance(data.messages, str):
prompt = data.messages
else:
# If the request specified prompt template isn't found, use the one from model container
# Otherwise, let fastchat figure it out
prompt_template = unwrap(data.prompt_template, model_container.prompt_template)
try:
prompt = get_chat_completion_prompt(model_path.name, data.messages, prompt_template)
prompt = get_prompt_from_template(data.messages, model_container.prompt_template)
except KeyError:
return HTTPException(400, f"Could not find a Conversation from prompt template '{prompt_template}'. Check your spelling?")
return HTTPException(
400,
f"Could not find a Conversation from prompt template '{model_container.prompt_template.name}'. Check your spelling?"
)
if data.stream:
const_id = f"chatcmpl-{uuid4().hex}"