mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-04-28 18:21:42 +00:00
Tree: Basic formatting and comments
Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
@@ -1,7 +1,7 @@
|
|||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
from time import time
|
from time import time
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from typing import List, Optional, Dict, Union
|
from typing import List, Optional, Union
|
||||||
from OAI.types.common import LogProbs, UsageStats, CommonCompletionRequest
|
from OAI.types.common import LogProbs, UsageStats, CommonCompletionRequest
|
||||||
|
|
||||||
class CompletionRespChoice(BaseModel):
|
class CompletionRespChoice(BaseModel):
|
||||||
|
|||||||
22
main.py
22
main.py
@@ -9,7 +9,12 @@ from sse_starlette import EventSourceResponse
|
|||||||
from OAI.types.completion import CompletionRequest
|
from OAI.types.completion import CompletionRequest
|
||||||
from OAI.types.chat_completion import ChatCompletionRequest
|
from OAI.types.chat_completion import ChatCompletionRequest
|
||||||
from OAI.types.model import ModelCard, ModelLoadRequest, ModelLoadResponse
|
from OAI.types.model import ModelCard, ModelLoadRequest, ModelLoadResponse
|
||||||
from OAI.types.token import TokenEncodeRequest, TokenEncodeResponse, TokenDecodeRequest, TokenDecodeResponse
|
from OAI.types.token import (
|
||||||
|
TokenEncodeRequest,
|
||||||
|
TokenEncodeResponse,
|
||||||
|
TokenDecodeRequest,
|
||||||
|
TokenDecodeResponse
|
||||||
|
)
|
||||||
from OAI.utils import (
|
from OAI.utils import (
|
||||||
create_completion_response,
|
create_completion_response,
|
||||||
get_model_list,
|
get_model_list,
|
||||||
@@ -27,6 +32,7 @@ app = FastAPI()
|
|||||||
model_container: Optional[ModelContainer] = None
|
model_container: Optional[ModelContainer] = None
|
||||||
config: dict = {}
|
config: dict = {}
|
||||||
|
|
||||||
|
# Model list endpoint
|
||||||
@app.get("/v1/models", dependencies=[Depends(check_api_key)])
|
@app.get("/v1/models", dependencies=[Depends(check_api_key)])
|
||||||
@app.get("/v1/model/list", dependencies=[Depends(check_api_key)])
|
@app.get("/v1/model/list", dependencies=[Depends(check_api_key)])
|
||||||
async def list_models():
|
async def list_models():
|
||||||
@@ -40,6 +46,7 @@ async def list_models():
|
|||||||
|
|
||||||
return models.json()
|
return models.json()
|
||||||
|
|
||||||
|
# Currently loaded model endpoint
|
||||||
@app.get("/v1/model", dependencies=[Depends(check_api_key)])
|
@app.get("/v1/model", dependencies=[Depends(check_api_key)])
|
||||||
async def get_current_model():
|
async def get_current_model():
|
||||||
if model_container is None or model_container.model is None:
|
if model_container is None or model_container.model is None:
|
||||||
@@ -48,6 +55,7 @@ async def get_current_model():
|
|||||||
model_card = ModelCard(id=model_container.get_model_path().name)
|
model_card = ModelCard(id=model_container.get_model_path().name)
|
||||||
return model_card.json()
|
return model_card.json()
|
||||||
|
|
||||||
|
# Load model endpoint
|
||||||
@app.post("/v1/model/load", dependencies=[Depends(check_admin_key)])
|
@app.post("/v1/model/load", dependencies=[Depends(check_admin_key)])
|
||||||
async def load_model(data: ModelLoadRequest):
|
async def load_model(data: ModelLoadRequest):
|
||||||
if model_container and model_container.model:
|
if model_container and model_container.model:
|
||||||
@@ -89,6 +97,7 @@ async def load_model(data: ModelLoadRequest):
|
|||||||
|
|
||||||
return EventSourceResponse(generator())
|
return EventSourceResponse(generator())
|
||||||
|
|
||||||
|
# Unload model endpoint
|
||||||
@app.get("/v1/model/unload", dependencies=[Depends(check_admin_key)])
|
@app.get("/v1/model/unload", dependencies=[Depends(check_admin_key)])
|
||||||
async def unload_model():
|
async def unload_model():
|
||||||
global model_container
|
global model_container
|
||||||
@@ -99,6 +108,7 @@ async def unload_model():
|
|||||||
model_container.unload()
|
model_container.unload()
|
||||||
model_container = None
|
model_container = None
|
||||||
|
|
||||||
|
# Encode tokens endpoint
|
||||||
@app.post("/v1/token/encode", dependencies=[Depends(check_api_key)])
|
@app.post("/v1/token/encode", dependencies=[Depends(check_api_key)])
|
||||||
async def encode_tokens(data: TokenEncodeRequest):
|
async def encode_tokens(data: TokenEncodeRequest):
|
||||||
tokens = model_container.get_tokens(data.text, None, **data.get_params())[0].tolist()
|
tokens = model_container.get_tokens(data.text, None, **data.get_params())[0].tolist()
|
||||||
@@ -106,6 +116,7 @@ async def encode_tokens(data: TokenEncodeRequest):
|
|||||||
|
|
||||||
return response.json()
|
return response.json()
|
||||||
|
|
||||||
|
# Decode tokens endpoint
|
||||||
@app.post("/v1/token/decode", dependencies=[Depends(check_api_key)])
|
@app.post("/v1/token/decode", dependencies=[Depends(check_api_key)])
|
||||||
async def decode_tokens(data: TokenDecodeRequest):
|
async def decode_tokens(data: TokenDecodeRequest):
|
||||||
message = model_container.get_tokens(None, data.tokens, **data.get_params())
|
message = model_container.get_tokens(None, data.tokens, **data.get_params())
|
||||||
@@ -113,6 +124,7 @@ async def decode_tokens(data: TokenDecodeRequest):
|
|||||||
|
|
||||||
return response.json()
|
return response.json()
|
||||||
|
|
||||||
|
# Completions endpoint
|
||||||
@app.post("/v1/completions", dependencies=[Depends(check_api_key)])
|
@app.post("/v1/completions", dependencies=[Depends(check_api_key)])
|
||||||
async def generate_completion(request: Request, data: CompletionRequest):
|
async def generate_completion(request: Request, data: CompletionRequest):
|
||||||
model_path = model_container.get_model_path()
|
model_path = model_container.get_model_path()
|
||||||
@@ -138,6 +150,7 @@ async def generate_completion(request: Request, data: CompletionRequest):
|
|||||||
|
|
||||||
return response.json()
|
return response.json()
|
||||||
|
|
||||||
|
# Chat completions endpoint
|
||||||
@app.post("/v1/chat/completions", dependencies=[Depends(check_api_key)])
|
@app.post("/v1/chat/completions", dependencies=[Depends(check_api_key)])
|
||||||
async def generate_chat_completion(request: Request, data: ChatCompletionRequest):
|
async def generate_chat_completion(request: Request, data: ChatCompletionRequest):
|
||||||
model_path = model_container.get_model_path()
|
model_path = model_container.get_model_path()
|
||||||
@@ -200,4 +213,9 @@ if __name__ == "__main__":
|
|||||||
loading_bar.next()
|
loading_bar.next()
|
||||||
|
|
||||||
network_config = config.get("network", {})
|
network_config = config.get("network", {})
|
||||||
uvicorn.run(app, host=network_config.get("host", "127.0.0.1"), port=network_config.get("port", 5000), log_level="debug")
|
uvicorn.run(
|
||||||
|
app,
|
||||||
|
host=network_config.get("host", "127.0.0.1"),
|
||||||
|
port=network_config.get("port", 5000),
|
||||||
|
log_level="debug"
|
||||||
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user