diff --git a/OAI/models/models.py b/OAI/models/models.py
new file mode 100644
index 0000000..44a2f26
--- /dev/null
+++ b/OAI/models/models.py
@@ -0,0 +1,13 @@
+from pydantic import BaseModel, Field
+from time import time
+from typing import List
+
+class ModelCard(BaseModel):
+    id: str
+    object: str = "model"
+    created: int = Field(default_factory=lambda: int(time()))
+    owned_by: str = "tabbyAPI"
+
+class ModelList(BaseModel):
+    object: str = "list"
+    data: List[ModelCard] = Field(default_factory=list)
diff --git a/OAI/utils.py b/OAI/utils.py
index ebd80e3..8b76d39 100644
--- a/OAI/utils.py
+++ b/OAI/utils.py
@@ -1,5 +1,7 @@
+import pathlib
 from OAI.models.completions import CompletionResponse, CompletionRespChoice
 from OAI.models.common import UsageStats
+from OAI.models.models import ModelList, ModelCard
 from typing import Optional
 
 def create_completion_response(text: str, index: int, model_name: Optional[str]):
@@ -17,3 +19,12 @@ def create_completion_response(text: str, index: int, model_name: Optional[str])
     )
 
     return response
+
+def get_model_list(model_path: pathlib.Path):
+    model_card_list = ModelList()
+    for path in model_path.parent.iterdir():
+        if path.is_dir():
+            model_card = ModelCard(id=path.name)
+            model_card_list.data.append(model_card)
+
+    return model_card_list
diff --git a/main.py b/main.py
index d9b53e3..38c7979 100644
--- a/main.py
+++ b/main.py
@@ -4,15 +4,27 @@ from fastapi import FastAPI, Request
 from model import ModelContainer
 from progress.bar import IncrementalBar
 from sse_starlette import EventSourceResponse
-from OAI.models.completions import CompletionRequest, CompletionResponse, CompletionRespChoice
-from OAI.utils import create_completion_response
+from OAI.models.completions import CompletionRequest, CompletionResponse
+from OAI.models.models import ModelCard, ModelList
+from OAI.utils import create_completion_response, get_model_list
 
 app = FastAPI()
 
 # Initialize a model container. This can be undefined at any period of time
 model_container: ModelContainer = None
 
-@app.post("/v1/completions")
+@app.get("/v1/models")
+@app.get("/v1/model/list")
+async def list_models():
+    models = get_model_list(model_container.get_model_path())
+
+    return models
+
+@app.get("/v1/model")
+async def get_current_model():
+    return ModelCard(id=model_container.get_model_path().name)
+
+@app.post("/v1/completions", response_model=CompletionResponse)
 async def generate_completion(request: Request, data: CompletionRequest):
     if data.stream:
         async def generator():
@@ -21,14 +33,14 @@ async def generate_completion(request: Request, data: CompletionRequest):
             if await request.is_disconnected():
                 break
 
-            response = create_completion_response(part, index, model_container.get_model_name())
+            response = create_completion_response(part, index, model_container.get_model_path().name)
 
             yield response.model_dump_json()
 
         return EventSourceResponse(generator())
     else:
         response_text = model_container.generate(**data.to_gen_params())
-        response = create_completion_response(response_text, 0, model_container.get_model_name())
+        response = create_completion_response(response_text, 0, model_container.get_model_path().name)
 
-        return response.model_dump_json()
+        return response
diff --git a/model.py b/model.py
index b7368b0..0434c60 100644
--- a/model.py
+++ b/model.py
@@ -1,4 +1,4 @@
-import gc, time
+import gc, time, pathlib
 import torch
 from exllamav2 import(
     ExLlamaV2,
@@ -11,7 +11,6 @@ from exllamav2.generator import(
     ExLlamaV2StreamingGenerator,
     ExLlamaV2Sampler
 )
-from os import path
 from typing import Optional
 
 # Bytes to reserve on first device when loading with auto split
@@ -102,11 +101,11 @@
         self.draft_config.max_input_len = kwargs["chunk_size"]
         self.draft_config.max_attn_size = kwargs["chunk_size"] ** 2
 
-    def get_model_name(self):
-        if self.draft_enabled:
-            return path.basename(path.normpath(self.draft_config.model_dir))
-        else:
-            return path.basename(path.normpath(self.config.model_dir))
+
+    def get_model_path(self):
+        model_path = pathlib.Path(self.draft_config.model_dir if self.draft_enabled else self.config.model_dir)
+        return model_path
+
     def load(self, progress_callback = None):
         """
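A quick way to smoke-test the new routes once the server is up. This is a sketch under assumptions, not part of the change: it presumes the API is reachable at http://localhost:8000 (uvicorn's default port) and that at least one sibling model folder exists next to the loaded model's directory.

import requests  # any HTTP client works; requests is used here for illustration

BASE_URL = "http://localhost:8000"  # assumption: local server on the default port

# GET /v1/models (alias /v1/model/list) walks the parent directory of the
# loaded model and returns one ModelCard per directory found there.
models = requests.get(f"{BASE_URL}/v1/models").json()
print([card["id"] for card in models["data"]])

# GET /v1/model describes the currently loaded model.
current = requests.get(f"{BASE_URL}/v1/model").json()
print(current["id"], current["owned_by"])  # owned_by defaults to "tabbyAPI"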