mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-04-20 06:19:15 +00:00
API: Add preset listing for sampler overrides
Querying the overrides list endpoint now returns the selected preset and a list of presets to use. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
@@ -313,17 +313,20 @@ class BaseSamplerRequest(BaseModel):
|
||||
return {**gen_params, **kwargs}
|
||||
|
||||
|
||||
class SamplerOverridesContainer(BaseModel):
|
||||
selected_preset: Optional[str] = None
|
||||
overrides: dict = {}
|
||||
|
||||
|
||||
# Global for default overrides
|
||||
overrides = {}
|
||||
overrides_container = SamplerOverridesContainer()
|
||||
|
||||
|
||||
def overrides_from_dict(new_overrides: dict):
|
||||
"""Wrapper function to update sampler overrides"""
|
||||
|
||||
global overrides
|
||||
|
||||
if isinstance(new_overrides, dict):
|
||||
overrides = prune_dict(new_overrides)
|
||||
overrides_container.overrides = prune_dict(new_overrides)
|
||||
else:
|
||||
raise TypeError("New sampler overrides must be a dict!")
|
||||
|
||||
@@ -333,6 +336,7 @@ def overrides_from_file(preset_name: str):
|
||||
|
||||
preset_path = pathlib.Path(f"sampler_overrides/{preset_name}.yml")
|
||||
if preset_path.exists():
|
||||
overrides_container.selected_preset = preset_path.stem
|
||||
with open(preset_path, "r", encoding="utf8") as raw_preset:
|
||||
preset = yaml.safe_load(raw_preset)
|
||||
overrides_from_dict(preset)
|
||||
@@ -347,18 +351,27 @@ def overrides_from_file(preset_name: str):
|
||||
raise FileNotFoundError(error_message)
|
||||
|
||||
|
||||
def get_all_presets():
|
||||
"""Fetches all sampler override presets from the overrides directory"""
|
||||
|
||||
override_directory = pathlib.Path("sampler_overrides")
|
||||
preset_files = map(lambda file: file.stem, override_directory.glob("*.yml"))
|
||||
|
||||
return preset_files
|
||||
|
||||
|
||||
# TODO: Maybe move these into the class
|
||||
# Classmethods aren't recognized in pydantic default_factories
|
||||
def get_default_sampler_value(key, fallback=None):
|
||||
"""Gets an overridden default sampler value"""
|
||||
|
||||
return unwrap(overrides.get(key, {}).get("override"), fallback)
|
||||
return unwrap(overrides_container.overrides.get(key, {}).get("override"), fallback)
|
||||
|
||||
|
||||
def apply_forced_sampler_overrides(params: BaseSamplerRequest):
|
||||
"""Forcefully applies overrides if specified by the user"""
|
||||
|
||||
for var, value in overrides.items():
|
||||
for var, value in overrides_container.overrides.items():
|
||||
override = value.get("override")
|
||||
original_value = getattr(params, var, None)
|
||||
|
||||
|
||||
@@ -32,7 +32,10 @@ from endpoints.OAI.types.model import (
|
||||
ModelLoadRequest,
|
||||
ModelCardParameters,
|
||||
)
|
||||
from endpoints.OAI.types.sampler_overrides import SamplerOverrideSwitchRequest
|
||||
from endpoints.OAI.types.sampler_overrides import (
|
||||
SamplerOverrideListResponse,
|
||||
SamplerOverrideSwitchRequest,
|
||||
)
|
||||
from endpoints.OAI.types.template import TemplateList, TemplateSwitchRequest
|
||||
from endpoints.OAI.types.token import (
|
||||
TokenEncodeRequest,
|
||||
@@ -248,7 +251,9 @@ async def unload_template():
|
||||
async def list_sampler_overrides():
|
||||
"""API wrapper to list all currently applied sampler overrides"""
|
||||
|
||||
return sampling.overrides
|
||||
return SamplerOverrideListResponse(
|
||||
presets=sampling.get_all_presets(), **sampling.overrides_container.model_dump()
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
|
||||
@@ -1,5 +1,13 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
|
||||
from common.sampling import SamplerOverridesContainer
|
||||
|
||||
|
||||
class SamplerOverrideListResponse(SamplerOverridesContainer):
|
||||
"""Sampler override list response"""
|
||||
|
||||
presets: Optional[List[str]]
|
||||
|
||||
|
||||
class SamplerOverrideSwitchRequest(BaseModel):
|
||||
|
||||
Reference in New Issue
Block a user