mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-04-24 08:19:19 +00:00
API: Add KoboldAI server
Used for interacting with applications that use KoboldAI's API such as horde. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
@@ -828,10 +828,14 @@ class ExllamaV2Container:
|
|||||||
|
|
||||||
return dict(zip_longest(top_tokens, cleaned_values))
|
return dict(zip_longest(top_tokens, cleaned_values))
|
||||||
|
|
||||||
async def generate(self, prompt: str, request_id: str, **kwargs):
|
async def generate(
|
||||||
|
self, prompt: str, request_id: str, abort_event: asyncio.Event = None, **kwargs
|
||||||
|
):
|
||||||
"""Generate a response to a prompt"""
|
"""Generate a response to a prompt"""
|
||||||
generations = []
|
generations = []
|
||||||
async for generation in self.generate_gen(prompt, request_id, **kwargs):
|
async for generation in self.generate_gen(
|
||||||
|
prompt, request_id, abort_event, **kwargs
|
||||||
|
):
|
||||||
generations.append(generation)
|
generations.append(generation)
|
||||||
|
|
||||||
joined_generation = {
|
joined_generation = {
|
||||||
|
|||||||
103
endpoints/Kobold/router.py
Normal file
103
endpoints/Kobold/router.py
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
from sys import maxsize
|
||||||
|
from fastapi import APIRouter, Depends, Request
|
||||||
|
from sse_starlette import EventSourceResponse
|
||||||
|
|
||||||
|
from common import model
|
||||||
|
from common.auth import check_api_key
|
||||||
|
from common.model import check_model_container
|
||||||
|
from common.utils import unwrap
|
||||||
|
from endpoints.Kobold.types.generation import (
|
||||||
|
AbortRequest,
|
||||||
|
CheckGenerateRequest,
|
||||||
|
GenerateRequest,
|
||||||
|
GenerateResponse,
|
||||||
|
)
|
||||||
|
from endpoints.Kobold.types.token import TokenCountRequest, TokenCountResponse
|
||||||
|
from endpoints.Kobold.utils.generation import (
|
||||||
|
abort_generation,
|
||||||
|
generation_status,
|
||||||
|
get_generation,
|
||||||
|
stream_generation,
|
||||||
|
)
|
||||||
|
from endpoints.core.utils.model import get_current_model
|
||||||
|
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
    "/v1/generate",
    dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def generate(request: Request, data: GenerateRequest) -> GenerateResponse:
    """Run a blocking KAI generation and return the full result."""

    return await get_generation(data, request)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
    "/extra/generate/stream",
    dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def generate_stream(request: Request, data: GenerateRequest) -> EventSourceResponse:
    """Stream a KAI generation as server-sent events.

    Fix: the original annotated the return type as GenerateResponse, but the
    handler actually returns an EventSourceResponse — the annotation was wrong
    and produced a misleading OpenAPI schema.

    ping=maxsize effectively disables SSE keepalive pings, which Kobold
    clients do not expect.
    """

    response = EventSourceResponse(stream_generation(data, request), ping=maxsize)

    return response
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
    "/extra/abort",
    dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def abort_generate(data: AbortRequest):
    """Abort an in-flight generation identified by its genkey."""

    return await abort_generation(data.genkey)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
    "/extra/generate/check",
    dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
@router.post(
    "/extra/generate/check",
    dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def check_generate(data: CheckGenerateRequest) -> GenerateResponse:
    """Poll the partial text of an in-flight generation by genkey.

    Registered for both GET and POST since KAI clients use either verb.
    """

    return await generation_status(data.genkey)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
    "/v1/model", dependencies=[Depends(check_api_key), Depends(check_model_container)]
)
async def current_model():
    """Fetches the current model and who owns it."""

    model_card = get_current_model()
    owner, model_id = model_card.owned_by, model_card.id
    return {"result": f"{owner}/{model_id}"}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
    "/extra/tokencount",
    dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def get_tokencount(data: TokenCountRequest):
    """Tokenize a prompt and report the token count and IDs."""

    # encode_tokens may return None; unwrap substitutes an empty list
    tokens = unwrap(model.container.encode_tokens(data.prompt), [])
    return TokenCountResponse(value=len(tokens), ids=tokens)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/v1/info/version")
async def get_version():
    """Impersonate KAI United."""

    # Hardcoded version string so clients that probe /api/v1/info/version
    # (e.g. Horde workers) accept this server as a KoboldAI United instance.
    return {"result": "1.2.5"}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/extra/version")
async def get_extra_version():
    """Impersonate Koboldcpp."""

    # Hardcoded name/version pair so clients probing /api/extra/version
    # treat this server as a KoboldCpp instance.
    return {"result": "KoboldCpp", "version": "1.61"}
|
||||||
53
endpoints/Kobold/types/generation.py
Normal file
53
endpoints/Kobold/types/generation.py
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from common.sampling import BaseSamplerRequest, get_default_sampler_value
|
||||||
|
|
||||||
|
|
||||||
|
class GenerateRequest(BaseSamplerRequest):
    """A KAI /generate request body layered on the common sampler params."""

    # The text to complete
    prompt: str
    # Accepted for KAI client compatibility; not read anywhere in this file
    use_default_badwordsids: Optional[bool] = False
    # Client-chosen key used to poll (/extra/generate/check) and abort
    # (/extra/abort) this generation
    genkey: Optional[str] = None

    # KAI-named aliases for max_tokens / penalty_range / repetition_penalty;
    # translated in to_gen_params below
    max_length: Optional[int] = Field(
        default_factory=lambda: get_default_sampler_value("max_tokens"),
        examples=[150],
    )
    rep_pen_range: Optional[int] = Field(
        default_factory=lambda: get_default_sampler_value("penalty_range", -1),
    )
    rep_pen: Optional[float] = Field(
        default_factory=lambda: get_default_sampler_value("repetition_penalty", 1.0),
    )

    def to_gen_params(self, **kwargs):
        """Translate KAI field names to the OAI/Exl2 sampler fields,
        then delegate to the base class."""

        # Swap kobold generation params to OAI/Exl2 ones
        self.max_tokens = self.max_length
        self.repetition_penalty = self.rep_pen
        # KAI uses 0 for "unbounded" penalty range; the backend expects -1
        self.penalty_range = -1 if self.rep_pen_range == 0 else self.rep_pen_range

        return super().to_gen_params(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class GenerateResponseResult(BaseModel):
    """A single generated text fragment in a KAI response."""

    text: str
|
||||||
|
|
||||||
|
|
||||||
|
class GenerateResponse(BaseModel):
    """KAI generation response envelope; empty results mean no output yet."""

    results: List[GenerateResponseResult] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
class StreamGenerateChunk(BaseModel):
    """One streamed text chunk in a KAI SSE generation stream."""

    token: str
|
||||||
|
|
||||||
|
|
||||||
|
class AbortRequest(BaseModel):
    """Request to abort the generation identified by genkey."""

    genkey: str
|
||||||
|
|
||||||
|
|
||||||
|
class AbortResponse(BaseModel):
    """Result of an abort request."""

    success: bool
|
||||||
|
|
||||||
|
|
||||||
|
class CheckGenerateRequest(BaseModel):
    """Request to poll the status of the generation identified by genkey."""

    genkey: str
|
||||||
15
endpoints/Kobold/types/token.py
Normal file
15
endpoints/Kobold/types/token.py
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
from pydantic import BaseModel
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
|
||||||
|
class TokenCountRequest(BaseModel):
    """Represents a KAI tokenization request."""

    # Text to tokenize
    prompt: str
|
||||||
|
|
||||||
|
|
||||||
|
class TokenCountResponse(BaseModel):
    """Represents a KAI tokenization response."""

    # Number of tokens in the prompt
    value: int
    # The token IDs themselves
    ids: List[int]
|
||||||
151
endpoints/Kobold/utils/generation.py
Normal file
151
endpoints/Kobold/utils/generation.py
Normal file
@@ -0,0 +1,151 @@
|
|||||||
|
import asyncio
|
||||||
|
from asyncio import CancelledError
|
||||||
|
from fastapi import HTTPException, Request
|
||||||
|
from loguru import logger
|
||||||
|
from sse_starlette import ServerSentEvent
|
||||||
|
|
||||||
|
from common import model
|
||||||
|
from common.networking import (
|
||||||
|
get_generator_error,
|
||||||
|
handle_request_disconnect,
|
||||||
|
handle_request_error,
|
||||||
|
request_disconnect_loop,
|
||||||
|
)
|
||||||
|
from common.utils import unwrap
|
||||||
|
from endpoints.Kobold.types.generation import (
|
||||||
|
AbortResponse,
|
||||||
|
GenerateRequest,
|
||||||
|
GenerateResponse,
|
||||||
|
GenerateResponseResult,
|
||||||
|
StreamGenerateChunk,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
generation_cache = {}
|
||||||
|
|
||||||
|
|
||||||
|
async def override_request_id(request: Request, data: GenerateRequest):
    """Overrides the request ID with a KAI genkey if present."""

    # Kobold clients poll and abort by genkey, so make it the canonical
    # request ID when one is supplied
    if data.genkey:
        request.state.id = data.genkey
|
||||||
|
|
||||||
|
|
||||||
|
def _create_response(text: str):
    """Wrap generated text in a KAI GenerateResponse envelope."""

    return GenerateResponse(results=[GenerateResponseResult(text=text)])
|
||||||
|
|
||||||
|
|
||||||
|
def _create_stream_chunk(text: str):
    """Wrap a streamed text fragment in a KAI stream chunk model."""

    return StreamGenerateChunk(token=text)
|
||||||
|
|
||||||
|
|
||||||
|
async def _stream_collector(data: GenerateRequest, request: Request):
    """Common async generator for generation streams.

    Yields text chunks from the model generator while mirroring them into
    the module-level generation cache, so /extra/generate/check can poll
    partial output and /extra/abort can cancel by genkey.

    Fixes over the original:
    - disconnect_task is now cancelled in the finally block; previously the
      watcher task kept running after every completed generation (task leak).
    - cache cleanup uses pop(..., None) instead of del, which tolerates the
      entry having been removed concurrently.
    """

    abort_event = asyncio.Event()
    disconnect_task = asyncio.create_task(request_disconnect_loop(request))

    # Create a new entry in the cache
    generation_cache[data.genkey] = {"abort": abort_event, "text": ""}

    try:
        logger.info(f"Received Kobold generation request {data.genkey}")

        generator = model.container.generate_gen(
            data.prompt, data.genkey, abort_event, **data.to_gen_params()
        )
        async for generation in generator:
            # Client went away: signal the backend to stop generating
            if disconnect_task.done():
                abort_event.set()
                handle_request_disconnect(
                    f"Kobold generation {data.genkey} cancelled by user."
                )

            text = generation.get("text")

            # Update the generation cache with the new chunk
            if text:
                generation_cache[data.genkey]["text"] += text
                yield text

            if "finish_reason" in generation:
                logger.info(f"Finished streaming Kobold request {data.genkey}")
                break
    except CancelledError:
        # If the request disconnects, break out
        if not disconnect_task.done():
            abort_event.set()
            handle_request_disconnect(
                f"Kobold generation {data.genkey} cancelled by user."
            )
    finally:
        # Stop the disconnect watcher so it does not outlive the generation
        disconnect_task.cancel()

        # Cleanup the cache
        generation_cache.pop(data.genkey, None)
|
||||||
|
|
||||||
|
|
||||||
|
async def stream_generation(data: GenerateRequest, request: Request):
    """Wrapper for stream generations.

    Yields each generated chunk as a KAI-formatted server-sent event.
    """

    # Fall back to the request ID when no genkey was provided
    if not data.genkey:
        data.genkey = request.state.id

    try:
        async for piece in _stream_collector(data, request):
            chunk = _create_stream_chunk(piece)
            event = ServerSentEvent(
                event="message", data=chunk.model_dump_json(), sep="\n"
            )
            yield event
    except Exception:
        yield get_generator_error(
            f"Kobold generation {data.genkey} aborted. "
            "Please check the server console."
        )
|
||||||
|
|
||||||
|
|
||||||
|
async def get_generation(data: GenerateRequest, request: Request):
    """Wrapper to get a static generation.

    Drains the stream collector and returns the concatenated text as a
    single KAI response; raises HTTP 503 if generation fails.
    """

    # Fall back to the request ID when no genkey was provided
    if not data.genkey:
        data.genkey = request.state.id

    try:
        pieces = []
        async for piece in _stream_collector(data, request):
            pieces.append(piece)

        return _create_response("".join(pieces))
    except Exception as exc:
        error_message = handle_request_error(
            f"Completion {request.state.id} aborted. Maybe the model was unloaded? "
            "Please check the server console."
        ).error.message

        # Server error if there's a generation exception
        raise HTTPException(503, error_message) from exc
|
||||||
|
|
||||||
|
|
||||||
|
async def abort_generation(genkey: str):
    """Aborts a generation from the cache.

    Always reports success, matching KAI client expectations even when the
    genkey is unknown.
    """

    cache_entry = unwrap(generation_cache.get(genkey), {})
    abort_event = cache_entry.get("abort")

    if abort_event:
        abort_event.set()
        handle_request_disconnect(f"Kobold generation {genkey} cancelled by user.")

    return AbortResponse(success=True)
|
||||||
|
|
||||||
|
|
||||||
|
async def generation_status(genkey: str):
    """Fetches the status of a generation from the cache.

    Returns the text accumulated so far, or an empty response when the
    genkey is unknown or nothing has been generated yet.
    """

    text_so_far = unwrap(generation_cache.get(genkey), {}).get("text")

    if not text_so_far:
        return GenerateResponse()

    return _create_response(text_so_far)
|
||||||
@@ -9,6 +9,7 @@ from common.logger import UVICORN_LOG_CONFIG
|
|||||||
from common.networking import get_global_depends
|
from common.networking import get_global_depends
|
||||||
from common.utils import unwrap
|
from common.utils import unwrap
|
||||||
from endpoints.core.router import router as CoreRouter
|
from endpoints.core.router import router as CoreRouter
|
||||||
|
from endpoints.Kobold.router import router as KoboldRouter
|
||||||
from endpoints.OAI.router import router as OAIRouter
|
from endpoints.OAI.router import router as OAIRouter
|
||||||
|
|
||||||
|
|
||||||
@@ -37,7 +38,7 @@ def setup_app():
|
|||||||
api_servers = unwrap(config.network_config().get("api_servers"), [])
|
api_servers = unwrap(config.network_config().get("api_servers"), [])
|
||||||
|
|
||||||
# Map for API id to server router
|
# Map for API id to server router
|
||||||
router_mapping = {"oai": OAIRouter}
|
router_mapping = {"oai": OAIRouter, "kobold": KoboldRouter}
|
||||||
|
|
||||||
# Include the OAI api by default
|
# Include the OAI api by default
|
||||||
if api_servers:
|
if api_servers:
|
||||||
|
|||||||
Reference in New Issue
Block a user