API: Don't use response_class

This argument in route decorators caused many errors and isn't
needed for building responses.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri
2023-11-14 21:56:15 -05:00
parent b625bface9
commit 4670a77c26
3 changed files with 5 additions and 5 deletions

View File

@@ -1,7 +1,7 @@
import pathlib import pathlib
from OAI.types.completions import CompletionResponse, CompletionRespChoice from OAI.types.completions import CompletionResponse, CompletionRespChoice
from OAI.types.common import UsageStats from OAI.types.common import UsageStats
from OAI.types.models import ModelList, ModelCard from OAI.types.model import ModelList, ModelCard
from typing import Optional from typing import Optional
def create_completion_response(text: str, index: int, model_name: Optional[str]): def create_completion_response(text: str, index: int, model_name: Optional[str]):

View File

@@ -6,8 +6,8 @@ from fastapi import FastAPI, Request, HTTPException, Depends
from model import ModelContainer from model import ModelContainer
from progress.bar import IncrementalBar from progress.bar import IncrementalBar
from sse_starlette import EventSourceResponse from sse_starlette import EventSourceResponse
from OAI.types.completions import CompletionRequest, CompletionResponse from OAI.types.completions import CompletionRequest
from OAI.types.models import ModelCard, ModelList, ModelLoadRequest, ModelLoadResponse from OAI.types.model import ModelCard, ModelLoadRequest, ModelLoadResponse
from OAI.utils import create_completion_response, get_model_list from OAI.utils import create_completion_response, get_model_list
from typing import Optional from typing import Optional
from utils import load_progress from utils import load_progress
@@ -34,7 +34,7 @@ async def get_current_model():
model_card = ModelCard(id=model_container.get_model_path().name) model_card = ModelCard(id=model_container.get_model_path().name)
return model_card.model_dump_json() return model_card.model_dump_json()
@app.post("/v1/model/load", response_class=ModelLoadResponse, dependencies=[Depends(check_admin_key)]) @app.post("/v1/model/load", dependencies=[Depends(check_admin_key)])
async def load_model(data: ModelLoadRequest): async def load_model(data: ModelLoadRequest):
if model_container and model_container.model: if model_container and model_container.model:
raise HTTPException(400, "A model is already loaded! Please unload it first.") raise HTTPException(400, "A model is already loaded! Please unload it first.")
@@ -80,7 +80,7 @@ async def unload_model():
model_container.unload() model_container.unload()
model_container = None model_container = None
@app.post("/v1/completions", response_class=CompletionResponse, dependencies=[Depends(check_api_key)]) @app.post("/v1/completions", dependencies=[Depends(check_api_key)])
async def generate_completion(request: Request, data: CompletionRequest): async def generate_completion(request: Request, data: CompletionRequest):
if data.stream: if data.stream:
async def generator(): async def generator():