Tree: Basic formatting and comments

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri
2023-11-16 11:48:30 -05:00
parent 5defb1b0b4
commit 60eb076b43
2 changed files with 21 additions and 3 deletions

View File

@@ -1,7 +1,7 @@
 from uuid import uuid4
 from time import time
 from pydantic import BaseModel, Field
-from typing import List, Optional, Dict, Union
+from typing import List, Optional, Union
 from OAI.types.common import LogProbs, UsageStats, CommonCompletionRequest
 class CompletionRespChoice(BaseModel):

22
main.py
View File

@@ -9,7 +9,12 @@ from sse_starlette import EventSourceResponse
 from OAI.types.completion import CompletionRequest
 from OAI.types.chat_completion import ChatCompletionRequest
 from OAI.types.model import ModelCard, ModelLoadRequest, ModelLoadResponse
-from OAI.types.token import TokenEncodeRequest, TokenEncodeResponse, TokenDecodeRequest, TokenDecodeResponse
+from OAI.types.token import (
+    TokenEncodeRequest,
+    TokenEncodeResponse,
+    TokenDecodeRequest,
+    TokenDecodeResponse
+)
 from OAI.utils import (
     create_completion_response,
     get_model_list,
@@ -27,6 +32,7 @@ app = FastAPI()
 model_container: Optional[ModelContainer] = None
 config: dict = {}
+# Model list endpoint
 @app.get("/v1/models", dependencies=[Depends(check_api_key)])
 @app.get("/v1/model/list", dependencies=[Depends(check_api_key)])
 async def list_models():
@@ -40,6 +46,7 @@ async def list_models():
     return models.json()
+# Currently loaded model endpoint
 @app.get("/v1/model", dependencies=[Depends(check_api_key)])
 async def get_current_model():
     if model_container is None or model_container.model is None:
@@ -48,6 +55,7 @@ async def get_current_model():
     model_card = ModelCard(id=model_container.get_model_path().name)
     return model_card.json()
+# Load model endpoint
 @app.post("/v1/model/load", dependencies=[Depends(check_admin_key)])
 async def load_model(data: ModelLoadRequest):
     if model_container and model_container.model:
@@ -89,6 +97,7 @@ async def load_model(data: ModelLoadRequest):
     return EventSourceResponse(generator())
+# Unload model endpoint
 @app.get("/v1/model/unload", dependencies=[Depends(check_admin_key)])
 async def unload_model():
     global model_container
@@ -99,6 +108,7 @@ async def unload_model():
     model_container.unload()
     model_container = None
+# Encode tokens endpoint
 @app.post("/v1/token/encode", dependencies=[Depends(check_api_key)])
 async def encode_tokens(data: TokenEncodeRequest):
     tokens = model_container.get_tokens(data.text, None, **data.get_params())[0].tolist()
@@ -106,6 +116,7 @@ async def encode_tokens(data: TokenEncodeRequest):
     return response.json()
+# Decode tokens endpoint
 @app.post("/v1/token/decode", dependencies=[Depends(check_api_key)])
 async def decode_tokens(data: TokenDecodeRequest):
     message = model_container.get_tokens(None, data.tokens, **data.get_params())
@@ -113,6 +124,7 @@ async def decode_tokens(data: TokenDecodeRequest):
     return response.json()
+# Completions endpoint
 @app.post("/v1/completions", dependencies=[Depends(check_api_key)])
 async def generate_completion(request: Request, data: CompletionRequest):
     model_path = model_container.get_model_path()
@@ -138,6 +150,7 @@ async def generate_completion(request: Request, data: CompletionRequest):
     return response.json()
+# Chat completions endpoint
 @app.post("/v1/chat/completions", dependencies=[Depends(check_api_key)])
 async def generate_chat_completion(request: Request, data: ChatCompletionRequest):
     model_path = model_container.get_model_path()
@@ -200,4 +213,9 @@ if __name__ == "__main__":
         loading_bar.next()
     network_config = config.get("network", {})
-    uvicorn.run(app, host=network_config.get("host", "127.0.0.1"), port=network_config.get("port", 5000), log_level="debug")
+    uvicorn.run(
+        app,
+        host=network_config.get("host", "127.0.0.1"),
+        port=network_config.get("port", 5000),
+        log_level="debug"
+    )