API: Add sampler override switching

Allow users to switch the currently overridden samplers via the API
so a restart isn't required to switch the overrides.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri
2024-01-24 01:20:58 -05:00
committed by Brian Dashore
parent de0ba7214c
commit b14c5443fd
3 changed files with 87 additions and 6 deletions

53
main.py
View File

@@ -26,7 +26,11 @@ from common.config import (
get_network_config,
)
from common.generators import call_with_semaphore, generate_with_semaphore
from common.sampling import get_overrides_from_file
from common.sampling import (
get_sampler_overrides,
set_overrides_from_file,
set_overrides_from_dict,
)
from common.templating import (
get_all_templates,
get_prompt_from_template,
@@ -43,6 +47,7 @@ from OAI.types.model import (
ModelLoadResponse,
ModelCardParameters,
)
from OAI.types.sampler_overrides import SamplerOverrideSwitchRequest
from OAI.types.template import TemplateList, TemplateSwitchRequest
from OAI.types.token import (
TokenEncodeRequest,
@@ -288,6 +293,47 @@ async def unload_template():
MODEL_CONTAINER.prompt_template = None
# Sampler override endpoints
@app.get("/v1/sampling/overrides", dependencies=[Depends(check_api_key)])
@app.get("/v1/sampling/override/list", dependencies=[Depends(check_api_key)])
async def list_sampler_overrides():
    """Return the sampler overrides that are currently applied.

    Served on both the plural and the ``/list`` route; each is guarded by
    the regular API key check. The response is whatever
    ``get_sampler_overrides()`` returns, passed through unchanged.
    """
    active_overrides = get_sampler_overrides()
    return active_overrides
@app.post(
    "/v1/sampling/override/switch",
    dependencies=[Depends(check_admin_key)],
)
async def switch_sampler_override(data: SamplerOverrideSwitchRequest):
    """Replace the active sampler overrides from a preset or a dictionary.

    A preset name takes priority: when ``data.preset`` is set it is loaded
    via ``set_overrides_from_file`` and ``data.overrides`` is ignored.
    Otherwise a non-empty ``data.overrides`` is applied via
    ``set_overrides_from_dict``.

    Raises:
        HTTPException: 400 when loading the preset raises
            ``FileNotFoundError``, or when neither a preset nor an
            overrides dictionary was supplied.
    """
    # Preset path: load the named preset and finish.
    if data.preset:
        try:
            set_overrides_from_file(data.preset)
        except FileNotFoundError as e:
            raise HTTPException(
                400, "Sampler override preset does not exist. Check the name?"
            ) from e
        return

    # Inline path: apply the dictionary sent in the request body.
    if data.overrides:
        set_overrides_from_dict(data.overrides)
        return

    # Nothing usable was sent in the request.
    raise HTTPException(
        400, "A sampler override preset or dictionary wasn't provided."
    )
@app.post(
    "/v1/sampling/override/unload",
    dependencies=[Depends(check_admin_key)],
)
async def unload_sampler_override():
    """Drop the currently selected override preset.

    Done by handing an empty dictionary to ``set_overrides_from_dict``.
    Requires the admin key.
    """
    no_overrides = {}
    set_overrides_from_dict(no_overrides)
# Lora list endpoint
@app.get("/v1/loras", dependencies=[Depends(check_api_key)])
@app.get("/v1/lora/list", dependencies=[Depends(check_api_key)])
@@ -558,7 +604,10 @@ def entrypoint(args: Optional[dict] = None):
sampling_config = get_sampling_config()
sampling_override_preset = sampling_config.get("override_preset")
if sampling_override_preset:
get_overrides_from_file(sampling_override_preset)
try:
set_overrides_from_file(sampling_override_preset)
except FileNotFoundError as e:
logger.warning(str(e))
# If an initial model name is specified, create a container
# and load the model