mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-03-14 15:57:27 +00:00
* returning stop str if exists from gen * added chat template for firefunctionv2 * pulling tool vars from template * adding parsing for tool inputs/outputs * passing tool data from endpoint to chat template, adding tool_start to the stop list * loosened typing on the response tool call, leaning more on the user supplying a quality schema if they want a particular format * non streaming generation prototype * cleaning template * Continued work with type, ingestion into template, and chat template for fire func * Correction - streaming toolcall comes back as delta obj not inside chatcomprespchoice per chat_completion_chunk.py inside OAI lib. * Ruff Formating * Moved stop string and tool updates out of prompt creation func Updated tool pydantic to match OAI Support for streaming Updated generate tool calls to use flag within chat_template and insert tool reminder * Llama 3.1 chat templates Updated fire func template * renamed llama3.1 to chatml_with_headers.. * update name of template * Support for calling a tool start token rather than the string. Simplified tool_params Warning when gen_settings are being overidden becuase user set temp to 0 Corrected schema and tools to correct types for function args. Str for some reason * draft groq tool use model template * changed headers to vars for readablity (but mostly because some models are weird about newlines after headers, so this is an easier way to change globally) * Clean up comments and code in chat comp * Post processed tool call to meet OAI spec rather than forcing model to write json in a string in the middle of the call. * changes example back to args as json rather than string of json * Standardize chat templates to each other * cleaning/rewording * stop elements can also be ints (tokens) * Cleaning/formatting * added special tokens for tools and tool_response as specified in description * Cleaning * removing aux templates - going to live in llm-promp-templates repo instead * Tree: Format Signed-off-by: kingbri <bdashore3@proton.me> * Chat Completions: Don't include internal tool variables in OpenAPI Use SkipJsonSchema to supress inclusion with the OpenAPI JSON. The location of these variables may need to be changed in the future. Signed-off-by: kingbri <bdashore3@proton.me> * Templates: Deserialize metadata on template load Since we're only looking for specific template variables that are static in the template, it makes more sense to render when the template is initialized. Signed-off-by: kingbri <bdashore3@proton.me> * Tools: Fix comments Adhere to the format style of comments in the rest of the project. Signed-off-by: kingbri <bdashore3@proton.me> --------- Co-authored-by: Ben Gitter <gitterbd@gmail.com> Signed-off-by: kingbri <bdashore3@proton.me>
418 lines
14 KiB
Python
418 lines
14 KiB
Python
"""Common functions for sampling parameters"""
|
|
|
|
import pathlib
|
|
import yaml
|
|
from copy import deepcopy
|
|
from loguru import logger
|
|
from pydantic import AliasChoices, BaseModel, Field
|
|
from typing import Dict, List, Optional, Union
|
|
|
|
from common.utils import unwrap, prune_dict
|
|
|
|
|
|
# Common class for sampler params
|
|
class BaseSamplerRequest(BaseModel):
|
|
"""Common class for sampler params that are used in APIs"""
|
|
|
|
max_tokens: Optional[int] = Field(
|
|
default_factory=lambda: get_default_sampler_value("max_tokens"),
|
|
validation_alias=AliasChoices("max_tokens", "max_length"),
|
|
description="Aliases: max_length",
|
|
examples=[150],
|
|
)
|
|
|
|
min_tokens: Optional[int] = Field(
|
|
default_factory=lambda: get_default_sampler_value("min_tokens", 0),
|
|
validation_alias=AliasChoices("min_tokens", "min_length"),
|
|
description="Aliases: min_length",
|
|
examples=[0],
|
|
)
|
|
|
|
generate_window: Optional[int] = Field(
|
|
default_factory=lambda: get_default_sampler_value("generate_window"),
|
|
examples=[512],
|
|
)
|
|
|
|
stop: Optional[Union[str, List[Union[str, int]]]] = Field(
|
|
default_factory=lambda: get_default_sampler_value("stop", []),
|
|
validation_alias=AliasChoices("stop", "stop_sequence"),
|
|
description="Aliases: stop_sequence",
|
|
)
|
|
|
|
banned_strings: Optional[Union[str, List[str]]] = Field(
|
|
default_factory=lambda: get_default_sampler_value("banned_strings", [])
|
|
)
|
|
|
|
banned_tokens: Optional[Union[List[int], str]] = Field(
|
|
default_factory=lambda: get_default_sampler_value("banned_tokens", []),
|
|
validation_alias=AliasChoices("banned_tokens", "custom_token_bans"),
|
|
description="Aliases: custom_token_bans",
|
|
examples=[[128, 330]],
|
|
)
|
|
|
|
token_healing: Optional[bool] = Field(
|
|
default_factory=lambda: get_default_sampler_value("token_healing", False)
|
|
)
|
|
|
|
temperature: Optional[float] = Field(
|
|
default_factory=lambda: get_default_sampler_value("temperature", 1.0),
|
|
examples=[1.0],
|
|
)
|
|
|
|
temperature_last: Optional[bool] = Field(
|
|
default_factory=lambda: get_default_sampler_value("temperature_last", False)
|
|
)
|
|
|
|
smoothing_factor: Optional[float] = Field(
|
|
default_factory=lambda: get_default_sampler_value("smoothing_factor", 0.0),
|
|
)
|
|
|
|
top_k: Optional[int] = Field(
|
|
default_factory=lambda: get_default_sampler_value("top_k", 0),
|
|
)
|
|
|
|
top_p: Optional[float] = Field(
|
|
default_factory=lambda: get_default_sampler_value("top_p", 1.0),
|
|
examples=[1.0],
|
|
)
|
|
|
|
top_a: Optional[float] = Field(
|
|
default_factory=lambda: get_default_sampler_value("top_a", 0.0)
|
|
)
|
|
|
|
min_p: Optional[float] = Field(
|
|
default_factory=lambda: get_default_sampler_value("min_p", 0.0)
|
|
)
|
|
|
|
tfs: Optional[float] = Field(
|
|
default_factory=lambda: get_default_sampler_value("tfs", 1.0),
|
|
examples=[1.0],
|
|
)
|
|
|
|
typical: Optional[float] = Field(
|
|
default_factory=lambda: get_default_sampler_value("typical", 1.0),
|
|
validation_alias=AliasChoices("typical", "typical_p"),
|
|
description="Aliases: typical_p",
|
|
examples=[1.0],
|
|
)
|
|
|
|
skew: Optional[float] = Field(
|
|
default_factory=lambda: get_default_sampler_value("skew", 0.0),
|
|
examples=[0.0],
|
|
)
|
|
|
|
frequency_penalty: Optional[float] = Field(
|
|
default_factory=lambda: get_default_sampler_value("frequency_penalty", 0.0)
|
|
)
|
|
|
|
presence_penalty: Optional[float] = Field(
|
|
default_factory=lambda: get_default_sampler_value("presence_penalty", 0.0)
|
|
)
|
|
|
|
repetition_penalty: Optional[float] = Field(
|
|
default_factory=lambda: get_default_sampler_value("repetition_penalty", 1.0),
|
|
validation_alias=AliasChoices("repetition_penalty", "rep_pen"),
|
|
description="Aliases: rep_pen",
|
|
examples=[1.0],
|
|
)
|
|
|
|
penalty_range: Optional[int] = Field(
|
|
default_factory=lambda: get_default_sampler_value("penalty_range", -1),
|
|
validation_alias=AliasChoices(
|
|
"penalty_range",
|
|
"repetition_range",
|
|
"repetition_penalty_range",
|
|
"rep_pen_range",
|
|
),
|
|
description=(
|
|
"Aliases: repetition_range, repetition_penalty_range, " "rep_pen_range"
|
|
),
|
|
)
|
|
|
|
repetition_decay: Optional[int] = Field(
|
|
default_factory=lambda: get_default_sampler_value("repetition_decay", 0)
|
|
)
|
|
|
|
mirostat_mode: Optional[int] = Field(
|
|
default_factory=lambda: get_default_sampler_value("mirostat_mode", 0)
|
|
)
|
|
|
|
mirostat_tau: Optional[float] = Field(
|
|
default_factory=lambda: get_default_sampler_value("mirostat_tau", 1.5),
|
|
examples=[1.5],
|
|
)
|
|
|
|
mirostat_eta: Optional[float] = Field(
|
|
default_factory=lambda: get_default_sampler_value("mirostat_eta", 0.3),
|
|
examples=[0.3],
|
|
)
|
|
|
|
add_bos_token: Optional[bool] = Field(
|
|
default_factory=lambda: get_default_sampler_value("add_bos_token", True)
|
|
)
|
|
|
|
ban_eos_token: Optional[bool] = Field(
|
|
default_factory=lambda: get_default_sampler_value("ban_eos_token", False),
|
|
validation_alias=AliasChoices("ban_eos_token", "ignore_eos"),
|
|
description="Aliases: ignore_eos",
|
|
examples=[False],
|
|
)
|
|
|
|
skip_special_tokens: Optional[bool] = Field(
|
|
default_factory=lambda: get_default_sampler_value("skip_special_tokens", True),
|
|
examples=[True],
|
|
)
|
|
|
|
logit_bias: Optional[Dict[int, float]] = Field(
|
|
default_factory=lambda: get_default_sampler_value("logit_bias"),
|
|
examples=[{"1": 10, "2": 50}],
|
|
)
|
|
|
|
negative_prompt: Optional[str] = Field(
|
|
default_factory=lambda: get_default_sampler_value("negative_prompt")
|
|
)
|
|
|
|
json_schema: Optional[object] = Field(
|
|
default_factory=lambda: get_default_sampler_value("json_schema"),
|
|
)
|
|
|
|
regex_pattern: Optional[str] = Field(
|
|
default_factory=lambda: get_default_sampler_value("regex_pattern"),
|
|
)
|
|
|
|
grammar_string: Optional[str] = Field(
|
|
default_factory=lambda: get_default_sampler_value("grammar_string"),
|
|
)
|
|
|
|
speculative_ngram: Optional[bool] = Field(
|
|
default_factory=lambda: get_default_sampler_value("speculative_ngram"),
|
|
)
|
|
|
|
cfg_scale: Optional[float] = Field(
|
|
default_factory=lambda: get_default_sampler_value("cfg_scale", 1.0),
|
|
validation_alias=AliasChoices("cfg_scale", "guidance_scale"),
|
|
description="Aliases: guidance_scale",
|
|
examples=[1.0],
|
|
)
|
|
|
|
max_temp: Optional[float] = Field(
|
|
default_factory=lambda: get_default_sampler_value("max_temp", 1.0),
|
|
validation_alias=AliasChoices("max_temp", "dynatemp_high"),
|
|
description="Aliases: dynatemp_high",
|
|
examples=[1.0],
|
|
)
|
|
|
|
min_temp: Optional[float] = Field(
|
|
default_factory=lambda: get_default_sampler_value("min_temp", 1.0),
|
|
validation_alias=AliasChoices("min_temp", "dynatemp_low"),
|
|
description="Aliases: dynatemp_low",
|
|
examples=[1.0],
|
|
)
|
|
|
|
temp_exponent: Optional[float] = Field(
|
|
default_factory=lambda: get_default_sampler_value("temp_exponent", 1.0),
|
|
validation_alias=AliasChoices("temp_exponent", "dynatemp_exponent"),
|
|
examples=[1.0],
|
|
)
|
|
|
|
# TODO: Return back to adaptable class-based validation But that's just too much
|
|
# abstraction compared to simple if statements at the moment
|
|
def validate_params(self):
|
|
"""
|
|
Validates sampler parameters to be within sane ranges.
|
|
"""
|
|
|
|
# Temperature
|
|
if self.temperature < 0.0:
|
|
raise ValueError(
|
|
"Temperature must be a non-negative value. " f"Got {self.temperature}"
|
|
)
|
|
|
|
# Smoothing factor
|
|
if self.smoothing_factor < 0.0:
|
|
raise ValueError(
|
|
"Smoothing factor must be a non-negative value. "
|
|
f"Got {self.smoothing_factor}"
|
|
)
|
|
|
|
# Top K
|
|
if self.top_k < 0:
|
|
raise ValueError("Top K must be a non-negative value. " f"Got {self.top_k}")
|
|
|
|
# Top P
|
|
if self.top_p < 0.0 or self.top_p > 1.0:
|
|
raise ValueError("Top P must be in [0, 1]. " f"Got {self.top_p}")
|
|
|
|
# Repetition Penalty
|
|
if self.repetition_penalty <= 0.0:
|
|
raise ValueError(
|
|
"Repetition penalty must be a positive value. "
|
|
f"Got {self.repetition_penalty}"
|
|
)
|
|
|
|
# Typical
|
|
if self.typical <= 0 and self.typical > 1:
|
|
raise ValueError("Typical must be in (0, 1]. " f"Got {self.typical}")
|
|
|
|
# Dynatemp values
|
|
if self.max_temp < 0.0:
|
|
raise ValueError(
|
|
"Max temp must be a non-negative value. ", f"Got {self.max_temp}"
|
|
)
|
|
|
|
if self.min_temp < 0.0:
|
|
raise ValueError(
|
|
"Min temp must be a non-negative value. ", f"Got {self.min_temp}"
|
|
)
|
|
|
|
if self.temp_exponent < 0.0:
|
|
raise ValueError(
|
|
"Temp exponent must be a non-negative value. ",
|
|
f"Got {self.temp_exponent}",
|
|
)
|
|
|
|
def to_gen_params(self, **kwargs):
|
|
"""Converts samplers to internal generation params"""
|
|
|
|
# Add forced overrides if present
|
|
apply_forced_sampler_overrides(self)
|
|
|
|
self.validate_params()
|
|
|
|
# Convert stop to an array of strings
|
|
if self.stop and isinstance(self.stop, str):
|
|
self.stop = [self.stop]
|
|
|
|
# Convert banned_strings to an array of strings
|
|
if self.banned_strings and isinstance(self.banned_strings, str):
|
|
self.banned_strings = [self.banned_strings]
|
|
|
|
# Convert string banned tokens to an integer list
|
|
if self.banned_tokens and isinstance(self.banned_tokens, str):
|
|
self.banned_tokens = [
|
|
int(x) for x in self.banned_tokens.split(",") if x.isdigit()
|
|
]
|
|
|
|
gen_params = {
|
|
"max_tokens": self.max_tokens,
|
|
"min_tokens": self.min_tokens,
|
|
"generate_window": self.generate_window,
|
|
"stop": self.stop,
|
|
"banned_strings": self.banned_strings,
|
|
"add_bos_token": self.add_bos_token,
|
|
"ban_eos_token": self.ban_eos_token,
|
|
"skip_special_tokens": self.skip_special_tokens,
|
|
"token_healing": self.token_healing,
|
|
"logit_bias": self.logit_bias,
|
|
"banned_tokens": self.banned_tokens,
|
|
"temperature": self.temperature,
|
|
"temperature_last": self.temperature_last,
|
|
"min_temp": self.min_temp,
|
|
"max_temp": self.max_temp,
|
|
"temp_exponent": self.temp_exponent,
|
|
"smoothing_factor": self.smoothing_factor,
|
|
"top_k": self.top_k,
|
|
"top_p": self.top_p,
|
|
"top_a": self.top_a,
|
|
"typical": self.typical,
|
|
"min_p": self.min_p,
|
|
"tfs": self.tfs,
|
|
"skew": self.skew,
|
|
"frequency_penalty": self.frequency_penalty,
|
|
"presence_penalty": self.presence_penalty,
|
|
"repetition_penalty": self.repetition_penalty,
|
|
"penalty_range": self.penalty_range,
|
|
"repetition_decay": self.repetition_decay,
|
|
"mirostat": self.mirostat_mode == 2,
|
|
"mirostat_tau": self.mirostat_tau,
|
|
"mirostat_eta": self.mirostat_eta,
|
|
"cfg_scale": self.cfg_scale,
|
|
"negative_prompt": self.negative_prompt,
|
|
"json_schema": self.json_schema,
|
|
"regex_pattern": self.regex_pattern,
|
|
"grammar_string": self.grammar_string,
|
|
"speculative_ngram": self.speculative_ngram,
|
|
}
|
|
|
|
return {**gen_params, **kwargs}
|
|
|
|
|
|
class SamplerOverridesContainer(BaseModel):
|
|
selected_preset: Optional[str] = None
|
|
overrides: dict = {}
|
|
|
|
|
|
# Global for default overrides
|
|
overrides_container = SamplerOverridesContainer()
|
|
|
|
|
|
def overrides_from_dict(new_overrides: dict):
|
|
"""Wrapper function to update sampler overrides"""
|
|
|
|
if isinstance(new_overrides, dict):
|
|
overrides_container.overrides = prune_dict(new_overrides)
|
|
else:
|
|
raise TypeError("New sampler overrides must be a dict!")
|
|
|
|
|
|
def overrides_from_file(preset_name: str):
|
|
"""Fetches an override preset from a file"""
|
|
|
|
preset_path = pathlib.Path(f"sampler_overrides/{preset_name}.yml")
|
|
if preset_path.exists():
|
|
overrides_container.selected_preset = preset_path.stem
|
|
with open(preset_path, "r", encoding="utf8") as raw_preset:
|
|
preset = yaml.safe_load(raw_preset)
|
|
overrides_from_dict(preset)
|
|
|
|
logger.info("Applied sampler overrides from file.")
|
|
else:
|
|
error_message = (
|
|
f'Sampler override file named "{preset_name}" was not found. '
|
|
+ "Make sure it's located in the sampler_overrides folder."
|
|
)
|
|
|
|
raise FileNotFoundError(error_message)
|
|
|
|
|
|
def get_all_presets():
|
|
"""Fetches all sampler override presets from the overrides directory"""
|
|
|
|
override_directory = pathlib.Path("sampler_overrides")
|
|
preset_files = [file.stem for file in override_directory.glob("*.yml")]
|
|
|
|
return preset_files
|
|
|
|
|
|
# TODO: Maybe move these into the class
|
|
# Classmethods aren't recognized in pydantic default_factories
|
|
def get_default_sampler_value(key, fallback=None):
|
|
"""Gets an overridden default sampler value"""
|
|
|
|
default_value = unwrap(
|
|
deepcopy(overrides_container.overrides.get(key, {}).get("override")),
|
|
fallback,
|
|
)
|
|
|
|
return default_value
|
|
|
|
|
|
def apply_forced_sampler_overrides(params: BaseSamplerRequest):
|
|
"""Forcefully applies overrides if specified by the user"""
|
|
|
|
for var, value in overrides_container.overrides.items():
|
|
override = deepcopy(value.get("override"))
|
|
original_value = getattr(params, var, None)
|
|
|
|
# Force takes precedence over additive
|
|
# Additive only works on lists and doesn't remove duplicates
|
|
if override:
|
|
if unwrap(value.get("force"), False):
|
|
setattr(params, var, override)
|
|
elif (
|
|
unwrap(value.get("additive"), False)
|
|
and isinstance(override, list)
|
|
and isinstance(original_value, list)
|
|
):
|
|
setattr(params, var, override + original_value)
|