API: Split into separate folder

Moving the API into its own directory compartmentalizes it and lets the
main file be trimmed down to just bootstrapping and the entry point.

Signed-off-by: kingbri <bdashore3@proton.me>
Author: kingbri
Date: 2024-03-11 22:45:30 -04:00
Committed by: Brian Dashore
Parent: 5a2de30066
Commit: 104a6121cb

13 changed files with 635 additions and 621 deletions

@@ -0,0 +1,63 @@
""" Chat completion types. """
from pydantic import BaseModel, Field
from time import time
from typing import Union, List, Optional, Dict
from uuid import uuid4

from endpoints.OAI.types.common import UsageStats, CommonCompletionRequest


class ChatCompletionLogprob(BaseModel):
    """Represents a single token logprob entry."""

    token: str
    logprob: float
    top_logprobs: Optional[List["ChatCompletionLogprob"]] = None


class ChatCompletionLogprobs(BaseModel):
    """Represents logprobs for a chat completion choice."""

    content: List[ChatCompletionLogprob] = Field(default_factory=list)


class ChatCompletionMessage(BaseModel):
    """Represents a chat completion message."""

    role: Optional[str] = None
    content: Optional[str] = None


class ChatCompletionRespChoice(BaseModel):
    """Represents a single choice in a chat completion response."""

    # Index is 0 since we aren't using multiple choices
    index: int = 0
    finish_reason: str
    message: ChatCompletionMessage
    logprobs: Optional[ChatCompletionLogprobs] = None


class ChatCompletionStreamChoice(BaseModel):
    """Represents a single choice in a chat completion stream chunk."""

    # Index is 0 since we aren't using multiple choices
    index: int = 0
    finish_reason: Optional[str]
    delta: Union[ChatCompletionMessage, dict] = {}
    logprobs: Optional[ChatCompletionLogprobs] = None


# Inherited from common request
class ChatCompletionRequest(CommonCompletionRequest):
    """Represents a chat completion request."""

    # Messages
    # Take in a string as well even though it's not part of the OAI spec
    messages: Union[str, List[Dict[str, str]]]
    prompt_template: Optional[str] = None
    add_generation_prompt: Optional[bool] = True


class ChatCompletionResponse(BaseModel):
    """Represents a chat completion response."""

    id: str = Field(default_factory=lambda: f"chatcmpl-{uuid4().hex}")
    choices: List[ChatCompletionRespChoice]
    created: int = Field(default_factory=lambda: int(time()))
    model: str
    object: str = "chat.completion"
    usage: Optional[UsageStats] = None


class ChatCompletionStreamChunk(BaseModel):
    """Represents a chat completion stream chunk."""

    id: str = Field(default_factory=lambda: f"chatcmpl-{uuid4().hex}")
    choices: List[ChatCompletionStreamChoice]
    created: int = Field(default_factory=lambda: int(time()))
    model: str
    object: str = "chat.completion.chunk"

@@ -0,0 +1,47 @@
""" Common types for OAI. """
from pydantic import BaseModel, Field
from typing import Optional
from common.sampling import BaseSamplerRequest
class UsageStats(BaseModel):
"""Represents usage stats."""
prompt_tokens: int
completion_tokens: int
total_tokens: int
class CommonCompletionRequest(BaseSamplerRequest):
"""Represents a common completion request."""
# Model information
# This parameter is not used, the loaded model is used instead
model: Optional[str] = None
# Generation info (remainder is in BaseSamplerRequest superclass)
stream: Optional[bool] = False
logprobs: Optional[int] = 0
# Extra OAI request stuff
best_of: Optional[int] = Field(
description="Not parsed. Only used for OAI compliance.", default=None
)
echo: Optional[bool] = Field(
description="Not parsed. Only used for OAI compliance.", default=False
)
n: Optional[int] = Field(
description="Not parsed. Only used for OAI compliance.", default=1
)
suffix: Optional[str] = Field(
description="Not parsed. Only used for OAI compliance.", default=None
)
user: Optional[str] = Field(
description="Not parsed. Only used for OAI compliance.", default=None
)
def to_gen_params(self):
extra_gen_params = {"logprobs": self.logprobs}
return super().to_gen_params(**extra_gen_params)
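
A quick sketch of how this request behaves (not part of the commit), assuming BaseSamplerRequest's fields all have defaults and its to_gen_params folds extra keyword arguments into the returned dict, as the override above implies:

# Hypothetical: the OAI-compliance fields are accepted on input but are
# never forwarded to generation; only logprobs is merged into gen params.
request = CommonCompletionRequest.model_validate(
    {"stream": True, "logprobs": 5, "best_of": 3}
)
gen_params = request.to_gen_params()  # includes {"logprobs": 5}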

@@ -0,0 +1,46 @@
""" Completion API protocols """
from pydantic import BaseModel, Field
from time import time
from typing import Dict, List, Optional, Union
from uuid import uuid4
from endpoints.OAI.types.common import CommonCompletionRequest, UsageStats
class CompletionLogProbs(BaseModel):
"""Represents log probabilities for a completion request."""
text_offset: List[int] = Field(default_factory=list)
token_logprobs: List[Optional[float]] = Field(default_factory=list)
tokens: List[str] = Field(default_factory=list)
top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list)
class CompletionRespChoice(BaseModel):
"""Represents a single choice in a completion response."""
# Index is 0 since we aren't using multiple choices
index: int = 0
finish_reason: str
logprobs: Optional[CompletionLogProbs] = None
text: str
# Inherited from common request
class CompletionRequest(CommonCompletionRequest):
"""Represents a completion request."""
# Prompt can also contain token ids, but that's out of scope
# for this project.
prompt: Union[str, List[str]]
class CompletionResponse(BaseModel):
"""Represents a completion response."""
id: str = Field(default_factory=lambda: f"cmpl-{uuid4().hex}")
choices: List[CompletionRespChoice]
created: int = Field(default_factory=lambda: int(time()))
model: str
object: str = "text_completion"
usage: Optional[UsageStats] = None
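
A minimal sketch (not part of the commit), assuming the classes above are in scope; the model name and text are illustrative:

# Hypothetical: a plain text-completion response with one choice.
response = CompletionResponse(
    model="example-model",
    choices=[CompletionRespChoice(finish_reason="length", text="generated text")],
)
print(response.object)  # "text_completion"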

@@ -0,0 +1,42 @@
""" Lora types """
from pydantic import BaseModel, Field
from time import time
from typing import Optional, List
class LoraCard(BaseModel):
"""Represents a single Lora card."""
id: str = "test"
object: str = "lora"
created: int = Field(default_factory=lambda: int(time()))
owned_by: str = "tabbyAPI"
scaling: Optional[float] = None
class LoraList(BaseModel):
"""Represents a list of Lora cards."""
object: str = "list"
data: List[LoraCard] = Field(default_factory=list)
class LoraLoadInfo(BaseModel):
"""Represents a single Lora load info."""
name: str
scaling: Optional[float] = 1.0
class LoraLoadRequest(BaseModel):
"""Represents a Lora load request."""
loras: List[LoraLoadInfo]
skip_queue: bool = False
class LoraLoadResponse(BaseModel):
"""Represents a Lora load response."""
success: List[str] = Field(default_factory=list)
failure: List[str] = Field(default_factory=list)
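
A hypothetical wire payload for a Lora load request (not part of the commit), assuming LoraLoadRequest is in scope; the adapter names are illustrative:

payload = {
    "loras": [
        {"name": "style-lora", "scaling": 0.8},
        {"name": "other-lora"},  # scaling falls back to 1.0
    ]
}
request = LoraLoadRequest.model_validate(payload)
print(request.loras[1].scaling)  # 1.0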

@@ -0,0 +1,109 @@
""" Contains model card types. """
from pydantic import BaseModel, Field, ConfigDict
from time import time
from typing import List, Optional
from common.gen_logging import GenLogPreferences
class ModelCardParameters(BaseModel):
"""Represents model card parameters."""
# Safe to do this since it's guaranteed to fetch a max seq len
# from model_container
max_seq_len: Optional[int] = None
rope_scale: Optional[float] = 1.0
rope_alpha: Optional[float] = 1.0
cache_mode: Optional[str] = "FP16"
prompt_template: Optional[str] = None
num_experts_per_token: Optional[int] = None
use_cfg: Optional[bool] = None
# Draft is another model, so include it in the card params
draft: Optional["ModelCard"] = None
class ModelCard(BaseModel):
"""Represents a single model card."""
id: str = "test"
object: str = "model"
created: int = Field(default_factory=lambda: int(time()))
owned_by: str = "tabbyAPI"
logging: Optional[GenLogPreferences] = None
parameters: Optional[ModelCardParameters] = None
class ModelList(BaseModel):
"""Represents a list of model cards."""
object: str = "list"
data: List[ModelCard] = Field(default_factory=list)
class DraftModelLoadRequest(BaseModel):
"""Represents a draft model load request."""
draft_model_name: str
draft_rope_scale: Optional[float] = 1.0
draft_rope_alpha: Optional[float] = Field(
description="Automatically calculated if not present",
default=None,
examples=[1.0],
)
class ModelLoadRequest(BaseModel):
"""Represents a model load request."""
name: str
# Max seq len is fetched from config.json of the model by default
max_seq_len: Optional[int] = Field(
description="Leave this blank to use the model's base sequence length",
default=None,
examples=[4096],
)
override_base_seq_len: Optional[int] = Field(
description=(
"Overrides the model's base sequence length. " "Leave blank if unsure"
),
default=None,
examples=[4096],
)
gpu_split_auto: Optional[bool] = True
autosplit_reserve: Optional[List[float]] = [96]
gpu_split: Optional[List[float]] = Field(
default_factory=list, examples=[[24.0, 20.0]]
)
rope_scale: Optional[float] = Field(
description="Automatically pulled from the model's config if not present",
default=None,
examples=[1.0],
)
rope_alpha: Optional[float] = Field(
description="Automatically calculated if not present",
default=None,
examples=[1.0],
)
no_flash_attention: Optional[bool] = False
# low_mem: Optional[bool] = False
cache_mode: Optional[str] = "FP16"
prompt_template: Optional[str] = None
num_experts_per_token: Optional[int] = None
use_cfg: Optional[bool] = None
fasttensors: Optional[bool] = False
draft: Optional[DraftModelLoadRequest] = None
skip_queue: Optional[bool] = False
class ModelLoadResponse(BaseModel):
"""Represents a model load response."""
# Avoids pydantic namespace warning
model_config = ConfigDict(protected_namespaces=[])
model_type: str = "model"
module: int
modules: int
status: str
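
A hypothetical load request (not part of the commit), assuming ModelLoadRequest is in scope; the model names are illustrative. Only name is required; omitted fields fall back to their defaults or are derived from the model's own config at load time:

request = ModelLoadRequest.model_validate(
    {
        "name": "example-model",
        "max_seq_len": 4096,
        "gpu_split": [24.0, 20.0],
        "draft": {"draft_model_name": "example-draft"},
    }
)
print(request.rope_scale)  # None -> pulled from the model config later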

@@ -0,0 +1,26 @@
""" Sampler override types """
from pydantic import BaseModel, Field
from typing import Optional


class SamplerOverrideSwitchRequest(BaseModel):
    """Sampler override switch request"""

    preset: Optional[str] = Field(
        default=None, description="Pass a sampler override preset name"
    )
    overrides: Optional[dict] = Field(
        default=None,
        description=(
            "Sampling override parent takes in individual keys and overrides. "
            "Ignored if preset is provided."
        ),
        examples=[
            {
                "top_p": {
                    "override": 1.5,
                    "force": False,
                }
            }
        ],
    )
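
A minimal sketch mirroring the documented example payload (not part of the commit), assuming the class above is in scope:

request = SamplerOverrideSwitchRequest.model_validate(
    {"overrides": {"top_p": {"override": 1.5, "force": False}}}
)
print(request.preset)  # None -> the raw overrides dict is used instead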

@@ -0,0 +1,15 @@
""" Template types """
from pydantic import BaseModel, Field
from typing import List


class TemplateList(BaseModel):
    """Represents a list of templates."""

    object: str = "list"
    data: List[str] = Field(default_factory=list)


class TemplateSwitchRequest(BaseModel):
    """Request to switch a template."""

    name: str
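
A tiny sketch (not part of the commit), assuming the classes above are in scope; the template names are illustrative:

templates = TemplateList(data=["alpaca", "chatml"])
switch = TemplateSwitchRequest(name=templates.data[0])
print(switch.name)  # "alpaca"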

@@ -0,0 +1,50 @@
""" Tokenization types """
from pydantic import BaseModel
from typing import List
class CommonTokenRequest(BaseModel):
"""Represents a common tokenization request."""
add_bos_token: bool = True
encode_special_tokens: bool = True
decode_special_tokens: bool = True
def get_params(self):
"""Get the parameters for tokenization."""
return {
"add_bos_token": self.add_bos_token,
"encode_special_tokens": self.encode_special_tokens,
"decode_special_tokens": self.decode_special_tokens,
}
class TokenEncodeRequest(CommonTokenRequest):
"""Represents a tokenization request."""
text: str
class TokenEncodeResponse(BaseModel):
"""Represents a tokenization response."""
tokens: List[int]
length: int
class TokenDecodeRequest(CommonTokenRequest):
""" " Represents a detokenization request."""
tokens: List[int]
class TokenDecodeResponse(BaseModel):
"""Represents a detokenization response."""
text: str
class TokenCountResponse(BaseModel):
"""Represents a token count response."""
length: int
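
A minimal sketch (not part of the commit), assuming the classes above are in scope:

# Hypothetical: get_params() bundles the shared token-handling flags
# for the tokenizer backend.
request = TokenEncodeRequest(text="Hello, world!", add_bos_token=False)
print(request.get_params())
# {'add_bos_token': False, 'encode_special_tokens': True,
#  'decode_special_tokens': True}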