Templates: Switch to Jinja2

Jinja2 is a lightweight template parser that Transformers already uses for parsing chat completions. It's much more efficient than Fastchat, can be imported as part of requirements, and unblocks Pydantic's version. Users now have to provide their own template if needed; a separate repo may be usable for common prompt template storage.

Signed-off-by: kingbri <bdashore3@proton.me>
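For context, here is a minimal sketch of what Jinja2-based chat templating looks like. The template string and the `get_prompt_from_template` signature below are illustrative assumptions, not the exact code this commit adds in `templating.py`:

```python
# Sketch of Jinja2-based prompt templating (assumed API; the real
# templating.py in this commit may differ).
from jinja2 import Environment

# A hypothetical ChatML-style template a user might supply.
EXAMPLE_TEMPLATE = (
    "{% for message in messages %}"
    "<|im_start|>{{ message['role'] }}\n{{ message['content'] }}<|im_end|>\n"
    "{% endfor %}"
    "<|im_start|>assistant\n"
)

def get_prompt_from_template(messages, template_str: str) -> str:
    """Render an OpenAI-style message list into a single prompt string."""
    template = Environment().from_string(template_str)
    return template.render(messages=messages)

if __name__ == "__main__":
    msgs = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello!"},
    ]
    print(get_prompt_from_template(msgs, EXAMPLE_TEMPLATE))
```

Because Jinja2 renders from a plain template string, the server no longer needs Fastchat's per-model conversation registry; users supply the template that matches their model.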
main.py (19 changed lines)
```diff
@@ -27,10 +27,10 @@ from OAI.utils import (
     create_completion_response,
     get_model_list,
     get_lora_list,
-    get_chat_completion_prompt,
     create_chat_completion_response,
     create_chat_completion_stream_chunk
 )
+from templating import get_prompt_from_template
 from utils import get_generator_error, get_sse_packet, load_progress, unwrap

 app = FastAPI()
```
```diff
@@ -76,6 +76,7 @@ async def list_models():
 @app.get("/v1/internal/model/info", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
 async def get_current_model():
     model_name = model_container.get_model_path().name
+    prompt_template = model_container.prompt_template
     model_card = ModelCard(
         id = model_name,
         parameters = ModelCardParameters(
```
```diff
@@ -83,7 +84,7 @@ async def get_current_model():
             rope_alpha = model_container.config.scale_alpha_value,
             max_seq_len = model_container.config.max_seq_len,
             cache_mode = "FP8" if model_container.cache_fp8 else "FP16",
-            prompt_template = unwrap(model_container.prompt_template, "auto")
+            prompt_template = prompt_template.name if prompt_template else None
         ),
         logging = gen_logging.config
     )
```
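After this change, `/v1/internal/model/info` reports the loaded template's name, or `null` when none is set, instead of the old `"auto"` fallback. A hedged usage sketch; the host, port, auth header name, and exact response shape are assumptions, only the endpoint path comes from the diff:

```python
# Query the model info endpoint to see which prompt template is loaded.
import requests

resp = requests.get(
    "http://localhost:5000/v1/internal/model/info",  # assumed host/port
    headers={"x-api-key": "your-api-key"},           # assumed header name
)
info = resp.json()
# parameters.prompt_template is now the template's name, or null
# when no template has been provided for the model.
print(info["parameters"]["prompt_template"])
```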
```diff
@@ -302,19 +303,21 @@ async def generate_completion(request: Request, data: CompletionRequest):
 # Chat completions endpoint
 @app.post("/v1/chat/completions", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
 async def generate_chat_completion(request: Request, data: ChatCompletionRequest):
+    if model_container.prompt_template is None:
+        return HTTPException(422, "This endpoint is disabled because a prompt template is not set.")
+
     model_path = model_container.get_model_path()

     if isinstance(data.messages, str):
         prompt = data.messages
     else:
-        # If the request specified prompt template isn't found, use the one from model container
-        # Otherwise, let fastchat figure it out
-        prompt_template = unwrap(data.prompt_template, model_container.prompt_template)
-
         try:
-            prompt = get_chat_completion_prompt(model_path.name, data.messages, prompt_template)
+            prompt = get_prompt_from_template(data.messages, model_container.prompt_template)
         except KeyError:
-            return HTTPException(400, f"Could not find a Conversation from prompt template '{prompt_template}'. Check your spelling?")
+            return HTTPException(
+                400,
+                f"Could not find a Conversation from prompt template '{model_container.prompt_template.name}'. Check your spelling?"
+            )

     if data.stream:
         const_id = f"chatcmpl-{uuid4().hex}"
```
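Since templates are now user-provided, the endpoint returns a 422 when no template is loaded rather than letting Fastchat guess one. A hedged client-side sketch; the host, port, and auth header are assumptions, while the endpoint path and the 422 behavior come from the diff above:

```python
# Call the chat completions endpoint and handle the no-template case.
import requests

payload = {
    "messages": [
        {"role": "user", "content": "Summarize Jinja2 in one sentence."}
    ],
    "stream": False,
}
resp = requests.post(
    "http://localhost:5000/v1/chat/completions",  # assumed host/port
    headers={"x-api-key": "your-api-key"},        # assumed header name
    json=payload,
)
if resp.status_code == 422:
    # Returned when the loaded model has no prompt template configured.
    print("Set a prompt template before using /v1/chat/completions.")
else:
    print(resp.json())
```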