API: Add sampler override switching

Allow users to switch the currently overriden samplers via the API
so a restart isn't required to switch the overrides.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri
2024-01-24 01:20:58 -05:00
committed by Brian Dashore
parent de0ba7214c
commit b14c5443fd
3 changed files with 87 additions and 6 deletions

View File

@@ -0,0 +1,26 @@
from pydantic import BaseModel, Field
from typing import Optional
class SamplerOverrideSwitchRequest(BaseModel):
"""Sampler override switch request"""
preset: Optional[str] = Field(
default=None, description="Pass a sampler override preset name"
)
overrides: Optional[dict] = Field(
default=None,
description=(
"Sampling override parent takes in individual keys and overrides."
+ "Ignored if preset is provided."
),
examples=[
{
"top_p": {
"override": 1.5,
"force": False,
}
}
],
)

View File

@@ -166,6 +166,10 @@ class SamplerParams(BaseModel):
DEFAULT_OVERRIDES = {}
def get_sampler_overrides():
return DEFAULT_OVERRIDES
def set_overrides_from_dict(new_overrides: dict):
"""Wrapper function to update sampler overrides"""
@@ -174,10 +178,10 @@ def set_overrides_from_dict(new_overrides: dict):
if isinstance(new_overrides, dict):
DEFAULT_OVERRIDES = new_overrides
else:
raise TypeError("new sampler overrides must be a dict!")
raise TypeError("New sampler overrides must be a dict!")
def get_overrides_from_file(preset_name: str):
def set_overrides_from_file(preset_name: str):
"""Fetches an override preset from a file"""
preset_path = pathlib.Path(f"sampler_overrides/{preset_name}.yml")
@@ -188,11 +192,13 @@ def get_overrides_from_file(preset_name: str):
logger.info("Applied sampler overrides from file.")
else:
logger.warn(
f"Sampler override file named \"{preset_name}\" was not found. "
error_message = (
f'Sampler override file named "{preset_name}" was not found. '
+ "Make sure it's located in the sampler_overrides folder."
)
raise FileNotFoundError(error_message)
# TODO: Maybe move these into the class
# Classmethods aren't recognized in pydantic default_factories

53
main.py
View File

@@ -26,7 +26,11 @@ from common.config import (
get_network_config,
)
from common.generators import call_with_semaphore, generate_with_semaphore
from common.sampling import get_overrides_from_file
from common.sampling import (
get_sampler_overrides,
set_overrides_from_file,
set_overrides_from_dict,
)
from common.templating import (
get_all_templates,
get_prompt_from_template,
@@ -43,6 +47,7 @@ from OAI.types.model import (
ModelLoadResponse,
ModelCardParameters,
)
from OAI.types.sampler_overrides import SamplerOverrideSwitchRequest
from OAI.types.template import TemplateList, TemplateSwitchRequest
from OAI.types.token import (
TokenEncodeRequest,
@@ -288,6 +293,47 @@ async def unload_template():
MODEL_CONTAINER.prompt_template = None
# Sampler override endpoints
@app.get("/v1/sampling/overrides", dependencies=[Depends(check_api_key)])
@app.get("/v1/sampling/override/list", dependencies=[Depends(check_api_key)])
async def list_sampler_overrides():
"""API wrapper to list all currently applied sampler overrides"""
return get_sampler_overrides()
@app.post(
"/v1/sampling/override/switch",
dependencies=[Depends(check_admin_key)],
)
async def switch_sampler_override(data: SamplerOverrideSwitchRequest):
"""Switch the currently loaded override preset"""
if data.preset:
try:
set_overrides_from_file(data.preset)
except FileNotFoundError as e:
raise HTTPException(
400, "Sampler override preset does not exist. Check the name?"
) from e
elif data.overrides:
set_overrides_from_dict(data.overrides)
else:
raise HTTPException(
400, "A sampler override preset or dictionary wasn't provided."
)
@app.post(
"/v1/sampling/override/unload",
dependencies=[Depends(check_admin_key)],
)
async def unload_sampler_override():
"""Unloads the currently selected override preset"""
set_overrides_from_dict({})
# Lora list endpoint
@app.get("/v1/loras", dependencies=[Depends(check_api_key)])
@app.get("/v1/lora/list", dependencies=[Depends(check_api_key)])
@@ -558,7 +604,10 @@ def entrypoint(args: Optional[dict] = None):
sampling_config = get_sampling_config()
sampling_override_preset = sampling_config.get("override_preset")
if sampling_override_preset:
get_overrides_from_file(sampling_override_preset)
try:
set_overrides_from_file(sampling_override_preset)
except FileNotFoundError as e:
logger.warning(str(e))
# If an initial model name is specified, create a container
# and load the model