mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-04-20 14:28:54 +00:00
Sampling: Cleanup and update
Cleanup how overrides are handled, class naming, and adopt exllamav2's model class to enforce latest stable version methods rather than adding multiple backwards compatability checks. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
@@ -6,14 +6,14 @@ from pydantic import AliasChoices, BaseModel, Field
|
||||
import yaml
|
||||
|
||||
from common.logger import init_logger
|
||||
from common.utils import unwrap
|
||||
from common.utils import unwrap, prune_dict
|
||||
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
# Common class for sampler params
|
||||
class SamplerParams(BaseModel):
|
||||
class BaseSamplerRequest(BaseModel):
|
||||
"""Common class for sampler params that are used in APIs"""
|
||||
|
||||
max_tokens: Optional[int] = Field(
|
||||
@@ -164,7 +164,7 @@ class SamplerParams(BaseModel):
|
||||
if isinstance(self.stop, str):
|
||||
self.stop = [self.stop]
|
||||
|
||||
return {
|
||||
gen_params = {
|
||||
"max_tokens": self.max_tokens,
|
||||
"generate_window": self.generate_window,
|
||||
"stop": self.stop,
|
||||
@@ -196,6 +196,8 @@ class SamplerParams(BaseModel):
|
||||
"negative_prompt": self.negative_prompt,
|
||||
}
|
||||
|
||||
return gen_params
|
||||
|
||||
|
||||
# Global for default overrides
|
||||
DEFAULT_OVERRIDES = {}
|
||||
@@ -211,7 +213,7 @@ def set_overrides_from_dict(new_overrides: dict):
|
||||
global DEFAULT_OVERRIDES
|
||||
|
||||
if isinstance(new_overrides, dict):
|
||||
DEFAULT_OVERRIDES = new_overrides
|
||||
DEFAULT_OVERRIDES = prune_dict(new_overrides)
|
||||
else:
|
||||
raise TypeError("New sampler overrides must be a dict!")
|
||||
|
||||
@@ -243,7 +245,7 @@ def get_default_sampler_value(key, fallback=None):
|
||||
return unwrap(DEFAULT_OVERRIDES.get(key, {}).get("override"), fallback)
|
||||
|
||||
|
||||
def apply_forced_sampler_overrides(params: SamplerParams):
|
||||
def apply_forced_sampler_overrides(params: BaseSamplerRequest):
|
||||
"""Forcefully applies overrides if specified by the user"""
|
||||
|
||||
for var, value in DEFAULT_OVERRIDES.items():
|
||||
|
||||
Reference in New Issue
Block a user