API: Add KoboldAI server

Used for interacting with applications that use KoboldAI's API,
such as the AI Horde.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri
2024-07-26 16:37:30 -04:00
parent 4e808cbed7
commit b7cb6f0b91
6 changed files with 330 additions and 3 deletions

View File

@@ -828,10 +828,14 @@ class ExllamaV2Container:
return dict(zip_longest(top_tokens, cleaned_values)) return dict(zip_longest(top_tokens, cleaned_values))
async def generate(self, prompt: str, request_id: str, **kwargs): async def generate(
self, prompt: str, request_id: str, abort_event: asyncio.Event = None, **kwargs
):
"""Generate a response to a prompt""" """Generate a response to a prompt"""
generations = [] generations = []
async for generation in self.generate_gen(prompt, request_id, **kwargs): async for generation in self.generate_gen(
prompt, request_id, abort_event, **kwargs
):
generations.append(generation) generations.append(generation)
joined_generation = { joined_generation = {

103
endpoints/Kobold/router.py Normal file
View File

@@ -0,0 +1,103 @@
from sys import maxsize
from fastapi import APIRouter, Depends, Request
from sse_starlette import EventSourceResponse
from common import model
from common.auth import check_api_key
from common.model import check_model_container
from common.utils import unwrap
from endpoints.Kobold.types.generation import (
AbortRequest,
CheckGenerateRequest,
GenerateRequest,
GenerateResponse,
)
from endpoints.Kobold.types.token import TokenCountRequest, TokenCountResponse
from endpoints.Kobold.utils.generation import (
abort_generation,
generation_status,
get_generation,
stream_generation,
)
from endpoints.core.utils.model import get_current_model
# All KoboldAI-compatible endpoints are mounted under /api
# (e.g. /api/v1/generate, /api/extra/generate/stream).
router = APIRouter(prefix="/api")
@router.post(
    "/v1/generate",
    dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def generate(request: Request, data: GenerateRequest) -> GenerateResponse:
    """KAI-compatible blocking generation endpoint.

    Runs the whole generation and returns the finished text in one response.
    """

    return await get_generation(data, request)
@router.post(
    "/extra/generate/stream",
    dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def generate_stream(request: Request, data: GenerateRequest):
    """KAI-compatible streaming generation endpoint (SSE).

    FIX: the return annotation previously claimed ``GenerateResponse``, but
    the handler actually returns an SSE ``EventSourceResponse``. FastAPI uses
    the annotation as the response model, so the generated OpenAPI schema was
    wrong; dropping it is behavior-compatible since a Response instance
    bypasses response-model handling at runtime anyway.
    """

    # ping=maxsize effectively disables sse-starlette's periodic ping events,
    # which some Kobold clients do not expect in the stream.
    return EventSourceResponse(stream_generation(data, request), ping=maxsize)
@router.post(
    "/extra/abort",
    dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def abort_generate(data: AbortRequest):
    """Aborts an in-flight generation identified by its genkey."""

    return await abort_generation(data.genkey)
@router.get(
    "/extra/generate/check",
    dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
@router.post(
    "/extra/generate/check",
    dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def check_generate(data: CheckGenerateRequest) -> GenerateResponse:
    """Polls the text generated so far for an in-flight request.

    Registered for both GET and POST to match KAI client behavior.
    """

    return await generation_status(data.genkey)
@router.get(
    "/v1/model", dependencies=[Depends(check_api_key), Depends(check_model_container)]
)
async def current_model():
    """Fetches the current model and who owns it."""

    card = get_current_model()
    return {"result": f"{card.owned_by}/{card.id}"}
@router.post(
    "/extra/tokencount",
    dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def get_tokencount(data: TokenCountRequest):
    """Tokenizes a prompt and reports the token count and IDs."""

    # unwrap guards against the tokenizer returning None
    token_ids = unwrap(model.container.encode_tokens(data.prompt), [])
    return TokenCountResponse(value=len(token_ids), ids=token_ids)
@router.get("/v1/info/version")
async def get_version():
    """Reports a KoboldAI United version string for client compatibility."""

    kai_united_version = "1.2.5"
    return {"result": kai_united_version}
@router.get("/extra/version")
async def get_extra_version():
    """Reports a KoboldCpp name/version pair for client compatibility."""

    impersonated_backend = {"result": "KoboldCpp", "version": "1.61"}
    return impersonated_backend

View File

@@ -0,0 +1,53 @@
from typing import List, Optional
from pydantic import BaseModel, Field
from common.sampling import BaseSamplerRequest, get_default_sampler_value
class GenerateRequest(BaseSamplerRequest):
    """KoboldAI generation request.

    Extends BaseSamplerRequest with KAI-specific parameter names, which
    to_gen_params translates into the base sampler fields.
    """

    # Input prompt text
    prompt: str
    # NOTE(review): declared but not read anywhere in this view — presumably
    # accepted only for KAI client compatibility; confirm before relying on it
    use_default_badwordsids: Optional[bool] = False
    # Client-supplied generation key used to track/abort this request
    genkey: Optional[str] = None
    # KAI name for max_tokens
    max_length: Optional[int] = Field(
        default_factory=lambda: get_default_sampler_value("max_tokens"),
        examples=[150],
    )
    # KAI name for penalty_range
    rep_pen_range: Optional[int] = Field(
        default_factory=lambda: get_default_sampler_value("penalty_range", -1),
    )
    # KAI name for repetition_penalty
    rep_pen: Optional[float] = Field(
        default_factory=lambda: get_default_sampler_value("repetition_penalty", 1.0),
    )

    def to_gen_params(self, **kwargs):
        """Copy the KAI-named fields onto the base sampler fields, then
        defer to BaseSamplerRequest.to_gen_params for the actual mapping."""

        # Swap kobold generation params to OAI/Exl2 ones
        self.max_tokens = self.max_length
        self.repetition_penalty = self.rep_pen
        # KAI uses 0 to mean "no range"; the base sampler expects -1 for that
        self.penalty_range = -1 if self.rep_pen_range == 0 else self.rep_pen_range

        return super().to_gen_params(**kwargs)
class GenerateResponseResult(BaseModel):
    """A single generated completion inside a KAI response."""

    # Generated completion text (full or partial, depending on the endpoint)
    text: str
class GenerateResponse(BaseModel):
    """Top-level KAI generation response; empty results means no text yet."""

    # List of completions; KAI clients read results[0].text
    results: List[GenerateResponseResult] = Field(default_factory=list)
class StreamGenerateChunk(BaseModel):
    """One SSE chunk of a KAI streaming generation."""

    # Newly generated text delta for this chunk
    token: str
class AbortRequest(BaseModel):
    """Request body for /extra/abort."""

    # Generation key identifying the in-flight request to abort
    genkey: str
class AbortResponse(BaseModel):
    """Response body for /extra/abort."""

    # Whether the abort request was accepted
    success: bool
class CheckGenerateRequest(BaseModel):
    """Request body for /extra/generate/check."""

    # Generation key identifying the in-flight request to poll
    genkey: str

View File

@@ -0,0 +1,15 @@
from pydantic import BaseModel
from typing import List
class TokenCountRequest(BaseModel):
    """Represents a KAI tokenization request."""

    # Raw text to tokenize with the loaded model's tokenizer
    prompt: str
class TokenCountResponse(BaseModel):
    """Represents a KAI tokenization response."""

    # Number of tokens in the prompt
    value: int
    # The token IDs themselves
    ids: List[int]

View File

@@ -0,0 +1,151 @@
import asyncio
from asyncio import CancelledError
from fastapi import HTTPException, Request
from loguru import logger
from sse_starlette import ServerSentEvent
from common import model
from common.networking import (
get_generator_error,
handle_request_disconnect,
handle_request_error,
request_disconnect_loop,
)
from common.utils import unwrap
from endpoints.Kobold.types.generation import (
AbortResponse,
GenerateRequest,
GenerateResponse,
GenerateResponseResult,
StreamGenerateChunk,
)
# Maps a generation's genkey -> {"abort": asyncio.Event, "text": str} so that
# /extra/abort and /extra/generate/check can locate in-flight requests.
# Entries are created and removed by _stream_collector.
generation_cache = {}
async def override_request_id(request: Request, data: GenerateRequest):
    """Adopts the client-supplied KAI genkey as this request's ID."""

    if not data.genkey:
        return

    request.state.id = data.genkey
def _create_response(text: str):
    """Wraps raw generated text in a KAI GenerateResponse envelope."""

    return GenerateResponse(results=[GenerateResponseResult(text=text)])
def _create_stream_chunk(text: str):
    """Wraps a streamed text delta in a KAI SSE chunk model."""
    return StreamGenerateChunk(token=text)
async def _stream_collector(data: GenerateRequest, request: Request):
    """Common async generator for generation streams.

    Registers the request in generation_cache (keyed on genkey) so that
    /extra/abort and /extra/generate/check can see it, yields text chunks
    from the model container, and removes the cache entry on exit.
    """

    abort_event = asyncio.Event()
    # Resolves when the HTTP client disconnects; polled each iteration below.
    # NOTE(review): this task is never cancelled in the finally block —
    # verify it terminates on its own when the request completes normally.
    disconnect_task = asyncio.create_task(request_disconnect_loop(request))

    # Create a new entry in the cache
    generation_cache[data.genkey] = {"abort": abort_event, "text": ""}

    try:
        logger.info(f"Received Kobold generation request {data.genkey}")

        generator = model.container.generate_gen(
            data.prompt, data.genkey, abort_event, **data.to_gen_params()
        )
        async for generation in generator:
            if disconnect_task.done():
                # NOTE(review): no break here — after signalling abort the
                # loop relies on the generator honoring abort_event and
                # stopping itself. Confirm against the container behavior.
                abort_event.set()
                handle_request_disconnect(
                    f"Kobold generation {data.genkey} cancelled by user."
                )

            text = generation.get("text")

            # Update the generation cache with the new chunk
            if text:
                generation_cache[data.genkey]["text"] += text
                yield text

            # The backend marks the final chunk with a finish_reason key
            if "finish_reason" in generation:
                logger.info(f"Finished streaming Kobold request {data.genkey}")
                break
    except CancelledError:
        # If the request disconnects, break out
        if not disconnect_task.done():
            abort_event.set()
            handle_request_disconnect(
                f"Kobold generation {data.genkey} cancelled by user."
            )
    finally:
        # Cleanup the cache
        del generation_cache[data.genkey]
async def stream_generation(data: GenerateRequest, request: Request):
    """Wrapper for stream generations."""

    # Fall back to the request ID when the client didn't supply a genkey
    if not data.genkey:
        data.genkey = request.state.id

    try:
        async for text in _stream_collector(data, request):
            chunk = _create_stream_chunk(text)

            yield ServerSentEvent(
                event="message", data=chunk.model_dump_json(), sep="\n"
            )
    except Exception:
        yield get_generator_error(
            f"Kobold generation {data.genkey} aborted. "
            "Please check the server console."
        )
async def get_generation(data: GenerateRequest, request: Request):
    """Wrapper to get a static (non-streaming) generation.

    Collects every streamed chunk into one string and returns a full
    GenerateResponse. Raises HTTPException(503) on generation failure.
    """

    # If the genkey doesn't exist, set it to the request ID
    if not data.genkey:
        data.genkey = request.state.id

    try:
        full_text = ""
        async for chunk in _stream_collector(data, request):
            full_text += chunk

        return _create_response(full_text)
    except Exception as exc:
        # FIX: reference the Kobold genkey (not the OAI-style request ID)
        # so the error matches the logs and cache entries, which are all
        # keyed on genkey elsewhere in this module.
        error_message = handle_request_error(
            f"Kobold generation {data.genkey} aborted. Maybe the model was "
            "unloaded? Please check the server console."
        ).error.message

        # Server error if there's a generation exception
        raise HTTPException(503, error_message) from exc
async def abort_generation(genkey: str):
    """Signals the abort event for a cached generation, if one exists."""

    cache_entry = generation_cache.get(genkey)
    abort_event = cache_entry.get("abort") if cache_entry else None

    if abort_event:
        abort_event.set()
        handle_request_disconnect(f"Kobold generation {genkey} cancelled by user.")

    # KAI clients expect success regardless of whether the genkey was found
    return AbortResponse(success=True)
async def generation_status(genkey: str):
    """Fetches the status of a generation from the cache."""

    cache_entry = generation_cache.get(genkey)
    partial_text = cache_entry.get("text") if cache_entry else None

    # An unknown genkey or empty text yields an empty results list
    return _create_response(partial_text) if partial_text else GenerateResponse()

View File

@@ -9,6 +9,7 @@ from common.logger import UVICORN_LOG_CONFIG
from common.networking import get_global_depends from common.networking import get_global_depends
from common.utils import unwrap from common.utils import unwrap
from endpoints.core.router import router as CoreRouter from endpoints.core.router import router as CoreRouter
from endpoints.Kobold.router import router as KoboldRouter
from endpoints.OAI.router import router as OAIRouter from endpoints.OAI.router import router as OAIRouter
@@ -37,7 +38,7 @@ def setup_app():
api_servers = unwrap(config.network_config().get("api_servers"), []) api_servers = unwrap(config.network_config().get("api_servers"), [])
# Map for API id to server router # Map for API id to server router
router_mapping = {"oai": OAIRouter} router_mapping = {"oai": OAIRouter, "kobold": KoboldRouter}
# Include the OAI api by default # Include the OAI api by default
if api_servers: if api_servers: