API: Add preset listing for sampler overrides

Querying the overrides list endpoint now returns the selected preset
and a list of presets to use.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri
2024-05-12 01:34:51 -04:00
parent b4bc941cbe
commit 6f4012d20d
3 changed files with 35 additions and 9 deletions

View File

@@ -313,17 +313,20 @@ class BaseSamplerRequest(BaseModel):
return {**gen_params, **kwargs}
class SamplerOverridesContainer(BaseModel):
selected_preset: Optional[str] = None
overrides: dict = {}
# Global for default overrides
overrides = {}
overrides_container = SamplerOverridesContainer()
def overrides_from_dict(new_overrides: dict):
"""Wrapper function to update sampler overrides"""
global overrides
if isinstance(new_overrides, dict):
overrides = prune_dict(new_overrides)
overrides_container.overrides = prune_dict(new_overrides)
else:
raise TypeError("New sampler overrides must be a dict!")
@@ -333,6 +336,7 @@ def overrides_from_file(preset_name: str):
preset_path = pathlib.Path(f"sampler_overrides/{preset_name}.yml")
if preset_path.exists():
overrides_container.selected_preset = preset_path.stem
with open(preset_path, "r", encoding="utf8") as raw_preset:
preset = yaml.safe_load(raw_preset)
overrides_from_dict(preset)
@@ -347,18 +351,27 @@ def overrides_from_file(preset_name: str):
raise FileNotFoundError(error_message)
def get_all_presets():
"""Fetches all sampler override presets from the overrides directory"""
override_directory = pathlib.Path("sampler_overrides")
preset_files = map(lambda file: file.stem, override_directory.glob("*.yml"))
return preset_files
# TODO: Maybe move these into the class
# Classmethods aren't recognized in pydantic default_factories
def get_default_sampler_value(key, fallback=None):
"""Gets an overridden default sampler value"""
return unwrap(overrides.get(key, {}).get("override"), fallback)
return unwrap(overrides_container.overrides.get(key, {}).get("override"), fallback)
def apply_forced_sampler_overrides(params: BaseSamplerRequest):
"""Forcefully applies overrides if specified by the user"""
for var, value in overrides.items():
for var, value in overrides_container.overrides.items():
override = value.get("override")
original_value = getattr(params, var, None)

View File

@@ -32,7 +32,10 @@ from endpoints.OAI.types.model import (
ModelLoadRequest,
ModelCardParameters,
)
from endpoints.OAI.types.sampler_overrides import SamplerOverrideSwitchRequest
from endpoints.OAI.types.sampler_overrides import (
SamplerOverrideListResponse,
SamplerOverrideSwitchRequest,
)
from endpoints.OAI.types.template import TemplateList, TemplateSwitchRequest
from endpoints.OAI.types.token import (
TokenEncodeRequest,
@@ -248,7 +251,9 @@ async def unload_template():
async def list_sampler_overrides():
"""API wrapper to list all currently applied sampler overrides"""
return sampling.overrides
return SamplerOverrideListResponse(
presets=sampling.get_all_presets(), **sampling.overrides_container.model_dump()
)
@router.post(

View File

@@ -1,5 +1,13 @@
from pydantic import BaseModel, Field
from typing import Optional
from typing import List, Optional
from common.sampling import SamplerOverridesContainer
class SamplerOverrideListResponse(SamplerOverridesContainer):
"""Sampler override list response"""
presets: Optional[List[str]]
class SamplerOverrideSwitchRequest(BaseModel):