Mirror of https://github.com/theroyallab/tabbyAPI.git (synced 2026-05-05 21:51:16 +00:00)
API: Split into separate folder

Moving the API into its own directory compartmentalizes it and allows the main file to be cleaned up so it contains only bootstrapping and the entry point.

Signed-off-by: kingbri <bdashore3@proton.me>
endpoints/OAI/types/chat_completion.py (new file, 63 lines)
@@ -0,0 +1,63 @@
from pydantic import BaseModel, Field
from time import time
from typing import Union, List, Optional, Dict
from uuid import uuid4

from endpoints.OAI.types.common import UsageStats, CommonCompletionRequest


class ChatCompletionLogprob(BaseModel):
    token: str
    logprob: float
    top_logprobs: Optional[List["ChatCompletionLogprob"]] = None


class ChatCompletionLogprobs(BaseModel):
    content: List[ChatCompletionLogprob] = Field(default_factory=list)


class ChatCompletionMessage(BaseModel):
    role: Optional[str] = None
    content: Optional[str] = None


class ChatCompletionRespChoice(BaseModel):
    # Index is 0 since we aren't using multiple choices
    index: int = 0
    finish_reason: str
    message: ChatCompletionMessage
    logprobs: Optional[ChatCompletionLogprobs] = None


class ChatCompletionStreamChoice(BaseModel):
    # Index is 0 since we aren't using multiple choices
    index: int = 0
    finish_reason: Optional[str]
    delta: Union[ChatCompletionMessage, dict] = {}
    logprobs: Optional[ChatCompletionLogprobs] = None


# Inherited from common request
class ChatCompletionRequest(CommonCompletionRequest):
    # Messages
    # Take in a string as well even though it's not part of the OAI spec
    messages: Union[str, List[Dict[str, str]]]
    prompt_template: Optional[str] = None
    add_generation_prompt: Optional[bool] = True


class ChatCompletionResponse(BaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-{uuid4().hex}")
    choices: List[ChatCompletionRespChoice]
    created: int = Field(default_factory=lambda: int(time()))
    model: str
    object: str = "chat.completion"
    usage: Optional[UsageStats] = None


class ChatCompletionStreamChunk(BaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-{uuid4().hex}")
    choices: List[ChatCompletionStreamChoice]
    created: int = Field(default_factory=lambda: int(time()))
    model: str
    object: str = "chat.completion.chunk"
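These are plain Pydantic v2 models, so a handler can build a response and serialize it without extra plumbing. A minimal usage sketch (not part of this commit; the message content and model name are placeholders):

# Minimal sketch: building a non-streaming chat completion response.
# The role/content/model values below are illustrative placeholders.
from endpoints.OAI.types.chat_completion import (
    ChatCompletionMessage,
    ChatCompletionRespChoice,
    ChatCompletionResponse,
)

choice = ChatCompletionRespChoice(
    finish_reason="stop",
    message=ChatCompletionMessage(role="assistant", content="Hello!"),
)
response = ChatCompletionResponse(choices=[choice], model="example-model")

# id, created, and object are filled in by their default factories
print(response.model_dump_json())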
endpoints/OAI/types/common.py (new file, 47 lines)
@@ -0,0 +1,47 @@
""" Common types for OAI. """
from pydantic import BaseModel, Field
from typing import Optional

from common.sampling import BaseSamplerRequest


class UsageStats(BaseModel):
    """Represents usage stats."""

    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


class CommonCompletionRequest(BaseSamplerRequest):
    """Represents a common completion request."""

    # Model information
    # This parameter is not used, the loaded model is used instead
    model: Optional[str] = None

    # Generation info (remainder is in BaseSamplerRequest superclass)
    stream: Optional[bool] = False
    logprobs: Optional[int] = 0

    # Extra OAI request stuff
    best_of: Optional[int] = Field(
        description="Not parsed. Only used for OAI compliance.", default=None
    )
    echo: Optional[bool] = Field(
        description="Not parsed. Only used for OAI compliance.", default=False
    )
    n: Optional[int] = Field(
        description="Not parsed. Only used for OAI compliance.", default=1
    )
    suffix: Optional[str] = Field(
        description="Not parsed. Only used for OAI compliance.", default=None
    )
    user: Optional[str] = Field(
        description="Not parsed. Only used for OAI compliance.", default=None
    )

    def to_gen_params(self):
        extra_gen_params = {"logprobs": self.logprobs}

        return super().to_gen_params(**extra_gen_params)
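The to_gen_params override layers the OAI-specific logprobs value on top of whatever BaseSamplerRequest produces; that base class lives in common/sampling.py and is not part of this diff, so the sketch below uses a hypothetical stand-in base to show only the pattern:

# Sketch of the to_gen_params pattern with a stand-in base class.
# _FakeSamplerBase and its temperature field are hypothetical; the real
# sampler fields come from common.sampling.BaseSamplerRequest (not shown).
from typing import Optional
from pydantic import BaseModel


class _FakeSamplerBase(BaseModel):
    temperature: Optional[float] = 1.0  # hypothetical sampler field

    def to_gen_params(self, **extra):
        return {"temperature": self.temperature, **extra}


class _FakeCompletionRequest(_FakeSamplerBase):
    logprobs: Optional[int] = 0

    def to_gen_params(self):
        # Same shape as CommonCompletionRequest.to_gen_params
        return super().to_gen_params(logprobs=self.logprobs)


print(_FakeCompletionRequest(logprobs=5).to_gen_params())
# -> {'temperature': 1.0, 'logprobs': 5}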
endpoints/OAI/types/completion.py (new file, 46 lines)
@@ -0,0 +1,46 @@
""" Completion API protocols """
from pydantic import BaseModel, Field
from time import time
from typing import Dict, List, Optional, Union
from uuid import uuid4

from endpoints.OAI.types.common import CommonCompletionRequest, UsageStats


class CompletionLogProbs(BaseModel):
    """Represents log probabilities for a completion request."""

    text_offset: List[int] = Field(default_factory=list)
    token_logprobs: List[Optional[float]] = Field(default_factory=list)
    tokens: List[str] = Field(default_factory=list)
    top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list)


class CompletionRespChoice(BaseModel):
    """Represents a single choice in a completion response."""

    # Index is 0 since we aren't using multiple choices
    index: int = 0
    finish_reason: str
    logprobs: Optional[CompletionLogProbs] = None
    text: str


# Inherited from common request
class CompletionRequest(CommonCompletionRequest):
    """Represents a completion request."""

    # Prompt can also contain token ids, but that's out of scope
    # for this project.
    prompt: Union[str, List[str]]


class CompletionResponse(BaseModel):
    """Represents a completion response."""

    id: str = Field(default_factory=lambda: f"cmpl-{uuid4().hex}")
    choices: List[CompletionRespChoice]
    created: int = Field(default_factory=lambda: int(time()))
    model: str
    object: str = "text_completion"
    usage: Optional[UsageStats] = None
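CompletionLogProbs follows the legacy OAI parallel-list layout, where index i of each list describes the i-th generated token. A small illustrative instance, not part of this commit (all token strings and probabilities are made up):

# Illustrative CompletionLogProbs for two generated tokens.
# All values here are placeholders, not real model output.
from endpoints.OAI.types.completion import CompletionLogProbs

logprobs = CompletionLogProbs(
    text_offset=[0, 5],
    token_logprobs=[-0.12, -1.34],
    tokens=["Hello", ","],
    top_logprobs=[{"Hello": -0.12, "Hi": -2.3}, {",": -1.34, "!": -1.9}],
)
print(logprobs.model_dump())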
endpoints/OAI/types/lora.py (new file, 42 lines)
@@ -0,0 +1,42 @@
""" Lora types """
from pydantic import BaseModel, Field
from time import time
from typing import Optional, List


class LoraCard(BaseModel):
    """Represents a single Lora card."""

    id: str = "test"
    object: str = "lora"
    created: int = Field(default_factory=lambda: int(time()))
    owned_by: str = "tabbyAPI"
    scaling: Optional[float] = None


class LoraList(BaseModel):
    """Represents a list of Lora cards."""

    object: str = "list"
    data: List[LoraCard] = Field(default_factory=list)


class LoraLoadInfo(BaseModel):
    """Represents a single Lora load info."""

    name: str
    scaling: Optional[float] = 1.0


class LoraLoadRequest(BaseModel):
    """Represents a Lora load request."""

    loras: List[LoraLoadInfo]
    skip_queue: bool = False


class LoraLoadResponse(BaseModel):
    """Represents a Lora load response."""

    success: List[str] = Field(default_factory=list)
    failure: List[str] = Field(default_factory=list)
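The load request maps directly onto the JSON body a client would post. A sketch, not part of this commit, with placeholder adapter names:

# Sketch of a request body for loading two LoRA adapters.
# The adapter names are placeholders.
from endpoints.OAI.types.lora import LoraLoadInfo, LoraLoadRequest

request = LoraLoadRequest(
    loras=[
        LoraLoadInfo(name="my-style-lora", scaling=0.8),
        LoraLoadInfo(name="my-task-lora"),  # scaling defaults to 1.0
    ],
    skip_queue=False,
)
print(request.model_dump_json(indent=2))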
endpoints/OAI/types/model.py (new file, 109 lines)
@@ -0,0 +1,109 @@
""" Contains model card types. """
from pydantic import BaseModel, Field, ConfigDict
from time import time
from typing import List, Optional

from common.gen_logging import GenLogPreferences


class ModelCardParameters(BaseModel):
    """Represents model card parameters."""

    # Safe to do this since it's guaranteed to fetch a max seq len
    # from model_container
    max_seq_len: Optional[int] = None
    rope_scale: Optional[float] = 1.0
    rope_alpha: Optional[float] = 1.0
    cache_mode: Optional[str] = "FP16"
    prompt_template: Optional[str] = None
    num_experts_per_token: Optional[int] = None
    use_cfg: Optional[bool] = None

    # Draft is another model, so include it in the card params
    draft: Optional["ModelCard"] = None


class ModelCard(BaseModel):
    """Represents a single model card."""

    id: str = "test"
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time()))
    owned_by: str = "tabbyAPI"
    logging: Optional[GenLogPreferences] = None
    parameters: Optional[ModelCardParameters] = None


class ModelList(BaseModel):
    """Represents a list of model cards."""

    object: str = "list"
    data: List[ModelCard] = Field(default_factory=list)


class DraftModelLoadRequest(BaseModel):
    """Represents a draft model load request."""

    draft_model_name: str
    draft_rope_scale: Optional[float] = 1.0
    draft_rope_alpha: Optional[float] = Field(
        description="Automatically calculated if not present",
        default=None,
        examples=[1.0],
    )


class ModelLoadRequest(BaseModel):
    """Represents a model load request."""

    name: str

    # Max seq len is fetched from config.json of the model by default
    max_seq_len: Optional[int] = Field(
        description="Leave this blank to use the model's base sequence length",
        default=None,
        examples=[4096],
    )
    override_base_seq_len: Optional[int] = Field(
        description=(
            "Overrides the model's base sequence length. " "Leave blank if unsure"
        ),
        default=None,
        examples=[4096],
    )
    gpu_split_auto: Optional[bool] = True
    autosplit_reserve: Optional[List[float]] = [96]
    gpu_split: Optional[List[float]] = Field(
        default_factory=list, examples=[[24.0, 20.0]]
    )
    rope_scale: Optional[float] = Field(
        description="Automatically pulled from the model's config if not present",
        default=None,
        examples=[1.0],
    )
    rope_alpha: Optional[float] = Field(
        description="Automatically calculated if not present",
        default=None,
        examples=[1.0],
    )
    no_flash_attention: Optional[bool] = False
    # low_mem: Optional[bool] = False
    cache_mode: Optional[str] = "FP16"
    prompt_template: Optional[str] = None
    num_experts_per_token: Optional[int] = None
    use_cfg: Optional[bool] = None
    fasttensors: Optional[bool] = False
    draft: Optional[DraftModelLoadRequest] = None
    skip_queue: Optional[bool] = False


class ModelLoadResponse(BaseModel):
    """Represents a model load response."""

    # Avoids pydantic namespace warning
    model_config = ConfigDict(protected_namespaces=[])

    model_type: str = "model"
    module: int
    modules: int
    status: str
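ModelLoadRequest only requires the model name; everything left unset falls back to the defaults above or is derived from the model's own config. A sketch, not part of this commit (model names and GPU split values are placeholders):

# Sketch of a model load request with an explicit GPU split and a draft model.
# "MyModel-exl2" and "MyDraftModel-exl2" are placeholder names.
from endpoints.OAI.types.model import DraftModelLoadRequest, ModelLoadRequest

request = ModelLoadRequest(
    name="MyModel-exl2",
    max_seq_len=8192,
    gpu_split_auto=False,
    gpu_split=[24.0, 20.0],
    draft=DraftModelLoadRequest(draft_model_name="MyDraftModel-exl2"),
)
print(request.model_dump_json(indent=2))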
endpoints/OAI/types/sampler_overrides.py (new file, 26 lines)
@@ -0,0 +1,26 @@
from pydantic import BaseModel, Field
from typing import Optional


class SamplerOverrideSwitchRequest(BaseModel):
    """Sampler override switch request"""

    preset: Optional[str] = Field(
        default=None, description="Pass a sampler override preset name"
    )

    overrides: Optional[dict] = Field(
        default=None,
        description=(
            "Sampling override parent takes in individual keys and overrides. "
            + "Ignored if preset is provided."
        ),
        examples=[
            {
                "top_p": {
                    "override": 1.5,
                    "force": False,
                }
            }
        ],
    )
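A request can name a preset or pass inline overrides keyed by sampler name; per the field description, inline overrides are ignored when a preset is given. A sketch, not part of this commit (the preset name is a placeholder; the overrides dict mirrors the examples entry above):

# Sketch of a sampler override switch request.
from endpoints.OAI.types.sampler_overrides import SamplerOverrideSwitchRequest

# Either a named preset ("some-preset" is a placeholder)...
by_preset = SamplerOverrideSwitchRequest(preset="some-preset")

# ...or inline overrides keyed by sampler name (ignored if a preset is set).
inline = SamplerOverrideSwitchRequest(
    overrides={"top_p": {"override": 1.5, "force": False}}
)
print(inline.model_dump_json())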
endpoints/OAI/types/template.py (new file, 15 lines)
@@ -0,0 +1,15 @@
from pydantic import BaseModel, Field
from typing import List


class TemplateList(BaseModel):
    """Represents a list of templates."""

    object: str = "list"
    data: List[str] = Field(default_factory=list)


class TemplateSwitchRequest(BaseModel):
    """Request to switch a template."""

    name: str
endpoints/OAI/types/token.py (new file, 50 lines)
@@ -0,0 +1,50 @@
""" Tokenization types """
from pydantic import BaseModel
from typing import List


class CommonTokenRequest(BaseModel):
    """Represents a common tokenization request."""

    add_bos_token: bool = True
    encode_special_tokens: bool = True
    decode_special_tokens: bool = True

    def get_params(self):
        """Get the parameters for tokenization."""
        return {
            "add_bos_token": self.add_bos_token,
            "encode_special_tokens": self.encode_special_tokens,
            "decode_special_tokens": self.decode_special_tokens,
        }


class TokenEncodeRequest(CommonTokenRequest):
    """Represents a tokenization request."""

    text: str


class TokenEncodeResponse(BaseModel):
    """Represents a tokenization response."""

    tokens: List[int]
    length: int


class TokenDecodeRequest(CommonTokenRequest):
    """Represents a detokenization request."""

    tokens: List[int]


class TokenDecodeResponse(BaseModel):
    """Represents a detokenization response."""

    text: str


class TokenCountResponse(BaseModel):
    """Represents a token count response."""

    length: int
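get_params bundles the tokenizer flags so an endpoint can pass them straight through to whatever backend does the encoding. A sketch, not part of this commit (the encode function below is a stand-in, not the real backend call):

# Sketch of how a tokenization endpoint might unpack a request.
# fake_encode is a placeholder for the model backend's tokenizer.
from endpoints.OAI.types.token import TokenEncodeRequest, TokenEncodeResponse

request = TokenEncodeRequest(text="Hello world", add_bos_token=False)


def fake_encode(text, **params):
    # Placeholder tokenizer: one "token" per whitespace-separated word
    return list(range(len(text.split())))


tokens = fake_encode(request.text, **request.get_params())
response = TokenEncodeResponse(tokens=tokens, length=len(tokens))
print(response.model_dump_json())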