Tree: Unify sampler parameters and add override support

Unify API sampler params into a superclass which should make them
easier to manage and inherit generic functions from.

Not all frontends expose all sampling parameters due to connections
with OAI (that handles sampling themselves with the exception of
a few sliders).

Add the ability for the user to customize fallback parameters from
server-side.

In addition, parameters can be forced to a certain value server-side
in case the repo automatically sets other sampler values in the
background that the user doesn't want.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri
2024-01-21 23:34:44 -05:00
committed by Brian Dashore
parent 78f920eeda
commit 6c30f24c83
7 changed files with 337 additions and 86 deletions

View File

@@ -1,6 +1,8 @@
""" Common types for OAI. """
from pydantic import BaseModel, Field, AliasChoices
from typing import List, Dict, Optional, Union
from pydantic import BaseModel, Field
from typing import List, Dict, Optional
from common.sampling import SamplerParams
class LogProbs(BaseModel):
@@ -20,7 +22,7 @@ class UsageStats(BaseModel):
total_tokens: int
class CommonCompletionRequest(BaseModel):
class CommonCompletionRequest(SamplerParams):
"""Represents a common completion request."""
# Model information
@@ -47,87 +49,5 @@ class CommonCompletionRequest(BaseModel):
description="Not parsed. Only used for OAI compliance.", default=None
)
# Generation info
# seed: Optional[int] = -1
# Generation info (remainder is in SamplerParams superclass)
stream: Optional[bool] = False
stop: Optional[Union[str, List[str]]] = []
# Default to 150 as 16 makes no sense as a default
max_tokens: Optional[int] = 150
# Sampling params
token_healing: Optional[bool] = False
temperature: Optional[float] = 1.0
temperature_last: Optional[bool] = False
top_k: Optional[int] = 0
top_p: Optional[float] = 1.0
top_a: Optional[float] = 0.0
min_p: Optional[float] = 0.0
tfs: Optional[float] = 1.0
frequency_penalty: Optional[float] = 0.0
presence_penalty: Optional[float] = 0.0
repetition_penalty: Optional[float] = 1.0
repetition_decay: Optional[int] = 0
mirostat_mode: Optional[int] = 0
mirostat_tau: Optional[float] = 1.5
mirostat_eta: Optional[float] = 0.1
add_bos_token: Optional[bool] = True
ban_eos_token: Optional[bool] = False
logit_bias: Optional[Dict[int, float]] = Field(default=None, examples=[[{"1": 10}]])
negative_prompt: Optional[str] = None
# Aliased variables
typical: Optional[float] = Field(
default=1.0,
validation_alias=AliasChoices("typical", "typical_p"),
description="Aliases: typical_p",
)
penalty_range: Optional[int] = Field(
default=-1,
validation_alias=AliasChoices(
"penalty_range",
"repetition_range",
"repetition_penalty_range",
),
description="Aliases: repetition_range, repetition_penalty_range",
)
cfg_scale: Optional[float] = Field(
default=1.0,
validation_alias=AliasChoices("cfg_scale", "guidance_scale"),
description="Aliases: guidance_scale",
)
def to_gen_params(self):
"""Converts to internal generation parameters."""
# Convert stop to an array of strings
if isinstance(self.stop, str):
self.stop = [self.stop]
return {
"stop": self.stop,
"max_tokens": self.max_tokens,
"add_bos_token": self.add_bos_token,
"ban_eos_token": self.ban_eos_token,
"token_healing": self.token_healing,
"logit_bias": self.logit_bias,
"temperature": self.temperature,
"temperature_last": self.temperature_last,
"top_k": self.top_k,
"top_p": self.top_p,
"top_a": self.top_a,
"typical": self.typical,
"min_p": self.min_p,
"tfs": self.tfs,
"frequency_penalty": self.frequency_penalty,
"presence_penalty": self.presence_penalty,
"repetition_penalty": self.repetition_penalty,
"penalty_range": self.penalty_range,
"repetition_decay": self.repetition_decay,
"mirostat": self.mirostat_mode == 2,
"mirostat_tau": self.mirostat_tau,
"mirostat_eta": self.mirostat_eta,
"cfg_scale": self.cfg_scale,
"negative_prompt": self.negative_prompt,
}