mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-04-20 14:28:54 +00:00
Tree: Unify sampler parameters and add override support
Unify the API sampler parameters into a superclass, which should make them easier to manage and lets request types inherit generic functions from it. Not all frontends expose every sampling parameter because they target the OAI API (which handles sampling itself, with the exception of a few sliders). Add the ability for the user to customize fallback parameter values from the server side. In addition, parameters can be forced to a certain value server-side in case the repo automatically sets other sampler values in the background that the user doesn't want. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
""" Common types for OAI. """
|
||||
from pydantic import BaseModel, Field, AliasChoices
|
||||
from typing import List, Dict, Optional, Union
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Dict, Optional
|
||||
|
||||
from common.sampling import SamplerParams
|
||||
|
||||
|
||||
class LogProbs(BaseModel):
|
||||
@@ -20,7 +22,7 @@ class UsageStats(BaseModel):
|
||||
total_tokens: int
|
||||
|
||||
|
||||
class CommonCompletionRequest(BaseModel):
|
||||
class CommonCompletionRequest(SamplerParams):
|
||||
"""Represents a common completion request."""
|
||||
|
||||
# Model information
|
||||
@@ -47,87 +49,5 @@ class CommonCompletionRequest(BaseModel):
|
||||
description="Not parsed. Only used for OAI compliance.", default=None
|
||||
)
|
||||
|
||||
# Generation info
|
||||
# seed: Optional[int] = -1
|
||||
# Generation info (remainder is in SamplerParams superclass)
|
||||
stream: Optional[bool] = False
|
||||
stop: Optional[Union[str, List[str]]] = []
|
||||
|
||||
# Default to 150 as 16 makes no sense as a default
|
||||
max_tokens: Optional[int] = 150
|
||||
|
||||
# Sampling params
|
||||
token_healing: Optional[bool] = False
|
||||
temperature: Optional[float] = 1.0
|
||||
temperature_last: Optional[bool] = False
|
||||
top_k: Optional[int] = 0
|
||||
top_p: Optional[float] = 1.0
|
||||
top_a: Optional[float] = 0.0
|
||||
min_p: Optional[float] = 0.0
|
||||
tfs: Optional[float] = 1.0
|
||||
frequency_penalty: Optional[float] = 0.0
|
||||
presence_penalty: Optional[float] = 0.0
|
||||
repetition_penalty: Optional[float] = 1.0
|
||||
repetition_decay: Optional[int] = 0
|
||||
mirostat_mode: Optional[int] = 0
|
||||
mirostat_tau: Optional[float] = 1.5
|
||||
mirostat_eta: Optional[float] = 0.1
|
||||
add_bos_token: Optional[bool] = True
|
||||
ban_eos_token: Optional[bool] = False
|
||||
logit_bias: Optional[Dict[int, float]] = Field(default=None, examples=[[{"1": 10}]])
|
||||
negative_prompt: Optional[str] = None
|
||||
|
||||
# Aliased variables
|
||||
typical: Optional[float] = Field(
|
||||
default=1.0,
|
||||
validation_alias=AliasChoices("typical", "typical_p"),
|
||||
description="Aliases: typical_p",
|
||||
)
|
||||
|
||||
penalty_range: Optional[int] = Field(
|
||||
default=-1,
|
||||
validation_alias=AliasChoices(
|
||||
"penalty_range",
|
||||
"repetition_range",
|
||||
"repetition_penalty_range",
|
||||
),
|
||||
description="Aliases: repetition_range, repetition_penalty_range",
|
||||
)
|
||||
|
||||
cfg_scale: Optional[float] = Field(
|
||||
default=1.0,
|
||||
validation_alias=AliasChoices("cfg_scale", "guidance_scale"),
|
||||
description="Aliases: guidance_scale",
|
||||
)
|
||||
|
||||
def to_gen_params(self):
|
||||
"""Converts to internal generation parameters."""
|
||||
# Convert stop to an array of strings
|
||||
if isinstance(self.stop, str):
|
||||
self.stop = [self.stop]
|
||||
|
||||
return {
|
||||
"stop": self.stop,
|
||||
"max_tokens": self.max_tokens,
|
||||
"add_bos_token": self.add_bos_token,
|
||||
"ban_eos_token": self.ban_eos_token,
|
||||
"token_healing": self.token_healing,
|
||||
"logit_bias": self.logit_bias,
|
||||
"temperature": self.temperature,
|
||||
"temperature_last": self.temperature_last,
|
||||
"top_k": self.top_k,
|
||||
"top_p": self.top_p,
|
||||
"top_a": self.top_a,
|
||||
"typical": self.typical,
|
||||
"min_p": self.min_p,
|
||||
"tfs": self.tfs,
|
||||
"frequency_penalty": self.frequency_penalty,
|
||||
"presence_penalty": self.presence_penalty,
|
||||
"repetition_penalty": self.repetition_penalty,
|
||||
"penalty_range": self.penalty_range,
|
||||
"repetition_decay": self.repetition_decay,
|
||||
"mirostat": self.mirostat_mode == 2,
|
||||
"mirostat_tau": self.mirostat_tau,
|
||||
"mirostat_eta": self.mirostat_eta,
|
||||
"cfg_scale": self.cfg_scale,
|
||||
"negative_prompt": self.negative_prompt,
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user