From 47343e2f1adb70a697a7fd4638808de721ae25d4 Mon Sep 17 00:00:00 2001
From: kingbri
Date: Mon, 13 Nov 2023 21:38:34 -0500
Subject: [PATCH] OAI: Add models support

The models endpoint fetches all the models that OAI has to offer.
However, since this is an OAI clone, just list the models inside the
user's configured model directory instead.

Signed-off-by: kingbri
---
 OAI/models/models.py | 13 +++++++++++++
 OAI/utils.py         | 11 +++++++++++
 main.py              | 22 +++++++++++++++++-----
 model.py             | 13 ++++++-------
 4 files changed, 47 insertions(+), 12 deletions(-)
 create mode 100644 OAI/models/models.py

diff --git a/OAI/models/models.py b/OAI/models/models.py
new file mode 100644
index 0000000..44a2f26
--- /dev/null
+++ b/OAI/models/models.py
@@ -0,0 +1,13 @@
+from pydantic import BaseModel, Field
+from time import time
+from typing import List
+
+class ModelCard(BaseModel):
+    id: str
+    object: str = "model"
+    created: int = Field(default_factory=lambda: int(time()))
+    owned_by: str = "tabbyAPI"
+
+class ModelList(BaseModel):
+    object: str = "list"
+    data: List[ModelCard] = Field(default_factory=list)
diff --git a/OAI/utils.py b/OAI/utils.py
index ebd80e3..8b76d39 100644
--- a/OAI/utils.py
+++ b/OAI/utils.py
@@ -1,5 +1,7 @@
+import pathlib
 from OAI.models.completions import CompletionResponse, CompletionRespChoice
 from OAI.models.common import UsageStats
+from OAI.models.models import ModelList, ModelCard
 from typing import Optional
 
 def create_completion_response(text: str, index: int, model_name: Optional[str]):
@@ -17,3 +19,12 @@ def create_completion_response(text: str, index: int, model_name: Optional[str])
     )
 
     return response
+
+def get_model_list(model_path: pathlib.Path):
+    model_card_list = ModelList()
+    for path in model_path.parent.iterdir():
+        if path.is_dir():
+            model_card = ModelCard(id = path.name)
+            model_card_list.data.append(model_card)
+
+    return model_card_list
diff --git a/main.py b/main.py
index d9b53e3..38c7979 100644
--- a/main.py
+++ b/main.py
@@ -4,15 +4,27 @@ from fastapi import FastAPI, Request
 from model import ModelContainer
 from progress.bar import IncrementalBar
 from sse_starlette import EventSourceResponse
-from OAI.models.completions import CompletionRequest, CompletionResponse, CompletionRespChoice
-from OAI.utils import create_completion_response
+from OAI.models.completions import CompletionRequest, CompletionResponse
+from OAI.models.models import ModelCard, ModelList
+from OAI.utils import create_completion_response, get_model_list
 
 app = FastAPI()
 
 # Initialize a model container. This can be undefined at any point in time
 model_container: ModelContainer = None
 
-@app.post("/v1/completions")
+@app.get("/v1/models")
+@app.get("/v1/model/list")
+async def list_models():
+    models = get_model_list(model_container.get_model_path())
+
+    return models.model_dump_json()
+
+@app.get("/v1/model")
+async def get_current_model():
+    return ModelCard(id = model_container.get_model_path().name)
+
+@app.post("/v1/completions", response_class=CompletionResponse)
 async def generate_completion(request: Request, data: CompletionRequest):
     if data.stream:
         async def generator():
@@ -21,14 +33,14 @@ async def generate_completion(request: Request, data: CompletionRequest):
             if await request.is_disconnected():
                 break
 
-            response = create_completion_response(part, index, model_container.get_model_name())
+            response = create_completion_response(part, index, model_container.get_model_path().name)
 
             yield response.model_dump_json()
 
         return EventSourceResponse(generator())
     else:
         response_text = model_container.generate(**data.to_gen_params())
-        response = create_completion_response(response_text, 0, model_container.get_model_name())
+        response = create_completion_response(response_text, 0, model_container.get_model_path().name)
 
         return response.model_dump_json()
diff --git a/model.py b/model.py
index b7368b0..0434c60 100644
--- a/model.py
+++ b/model.py
@@ -1,4 +1,4 @@
-import gc, time
+import gc, time, pathlib
 import torch
 from exllamav2 import(
     ExLlamaV2,
@@ -11,7 +11,6 @@ from exllamav2.generator import(
     ExLlamaV2StreamingGenerator,
     ExLlamaV2Sampler
 )
-from os import path
 from typing import Optional
 
 # Bytes to reserve on first device when loading with auto split
@@ -102,11 +101,11 @@ class ModelContainer:
             self.draft_config.max_input_len = kwargs["chunk_size"]
             self.draft_config.max_attn_size = kwargs["chunk_size"] ** 2
 
-    def get_model_name(self):
-        if self.draft_enabled:
-            return path.basename(path.normpath(self.draft_config.model_dir))
-        else:
-            return path.basename(path.normpath(self.config.model_dir))
+
+    def get_model_path(self):
+        model_path = pathlib.Path(self.draft_config.model_dir if self.draft_enabled else self.config.model_dir)
+        return model_path
+
    def load(self, progress_callback = None):
        """
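
For reference, here is a minimal sketch of the payload the new GET /v1/models
route serializes, built on the ModelCard/ModelList schemas this patch adds.
The model directory names are made up for illustration:

    # Build the same payload GET /v1/models returns, using the schemas
    # added in OAI/models/models.py. The directory names are hypothetical.
    from OAI.models.models import ModelCard, ModelList

    model_list = ModelList()
    for name in ["llama2-13b-exl2", "mistral-7b-exl2"]:  # hypothetical dirs
        model_list.data.append(ModelCard(id=name))

    # Prints something like:
    # {"object":"list","data":[{"id":"llama2-13b-exl2","object":"model",
    #  "created":1699935514,"owned_by":"tabbyAPI"}, ...]}
    print(model_list.model_dump_json())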
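
A client-side sketch against a running server follows; the base URL is an
assumption, not something this patch fixes. Note that list_models returns
models.model_dump_json(), a str, so FastAPI encodes the body a second time
and the client may need to unwrap it:

    # Stdlib-only client sketch. BASE_URL is an assumption; point it at
    # whatever address the tabbyAPI server is actually bound to.
    import json
    from urllib.request import urlopen

    BASE_URL = "http://127.0.0.1:8000"  # assumed host/port

    with urlopen(f"{BASE_URL}/v1/models") as resp:
        body = json.load(resp)
    # The handler returns model_dump_json() (a str), so the payload can
    # arrive JSON-encoded twice; unwrap it if so.
    models = json.loads(body) if isinstance(body, str) else body

    with urlopen(f"{BASE_URL}/v1/model") as resp:
        current = json.load(resp)  # ModelCard serializes to a plain object

    print([card["id"] for card in models["data"]], current["id"])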