mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-03-15 00:07:28 +00:00
API: Add sampler override switching
Allow users to switch the currently overridden samplers via the API so a restart isn't required to switch the overrides. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
53
main.py
53
main.py
@@ -26,7 +26,11 @@ from common.config import (
|
||||
get_network_config,
|
||||
)
|
||||
from common.generators import call_with_semaphore, generate_with_semaphore
|
||||
from common.sampling import get_overrides_from_file
|
||||
from common.sampling import (
|
||||
get_sampler_overrides,
|
||||
set_overrides_from_file,
|
||||
set_overrides_from_dict,
|
||||
)
|
||||
from common.templating import (
|
||||
get_all_templates,
|
||||
get_prompt_from_template,
|
||||
@@ -43,6 +47,7 @@ from OAI.types.model import (
|
||||
ModelLoadResponse,
|
||||
ModelCardParameters,
|
||||
)
|
||||
from OAI.types.sampler_overrides import SamplerOverrideSwitchRequest
|
||||
from OAI.types.template import TemplateList, TemplateSwitchRequest
|
||||
from OAI.types.token import (
|
||||
TokenEncodeRequest,
|
||||
@@ -288,6 +293,47 @@ async def unload_template():
|
||||
MODEL_CONTAINER.prompt_template = None
|
||||
|
||||
|
||||
# Sampler override endpoints
|
||||
@app.get("/v1/sampling/overrides", dependencies=[Depends(check_api_key)])
@app.get("/v1/sampling/override/list", dependencies=[Depends(check_api_key)])
async def list_sampler_overrides():
    """
    API wrapper that reports the sampler overrides currently in effect.

    Registered under two GET routes, both guarded by the regular API key
    dependency. The response is whatever get_sampler_overrides() returns,
    passed through unchanged.
    """

    active_overrides = get_sampler_overrides()
    return active_overrides
|
||||
|
||||
|
||||
@app.post(
    "/v1/sampling/override/switch",
    dependencies=[Depends(check_admin_key)],
)
async def switch_sampler_override(data: SamplerOverrideSwitchRequest):
    """
    Switch the currently loaded override preset.

    A named preset takes priority: when data.preset is truthy, it is handed
    to set_overrides_from_file and data.overrides is ignored. Otherwise a
    truthy data.overrides is applied via set_overrides_from_dict.

    Raises:
        HTTPException (400): the preset lookup raised FileNotFoundError, or
            the request carried neither a preset nor an overrides dictionary.
    """

    # Preset path: translate a missing file into a client-facing 400
    if data.preset:
        try:
            set_overrides_from_file(data.preset)
        except FileNotFoundError as exc:
            raise HTTPException(
                400, "Sampler override preset does not exist. Check the name?"
            ) from exc
        return

    # Nothing usable was supplied in the request
    if not data.overrides:
        raise HTTPException(
            400, "A sampler override preset or dictionary wasn't provided."
        )

    # Inline dictionary path
    set_overrides_from_dict(data.overrides)
|
||||
|
||||
|
||||
@app.post(
    "/v1/sampling/override/unload",
    dependencies=[Depends(check_admin_key)],
)
async def unload_sampler_override():
    """
    Unload the currently selected override preset.

    Implemented by handing an empty dictionary to set_overrides_from_dict;
    presumably that leaves no overrides active (confirm in common.sampling).
    """

    no_overrides = {}
    set_overrides_from_dict(no_overrides)
|
||||
|
||||
|
||||
# Lora list endpoint
|
||||
@app.get("/v1/loras", dependencies=[Depends(check_api_key)])
|
||||
@app.get("/v1/lora/list", dependencies=[Depends(check_api_key)])
|
||||
@@ -558,7 +604,10 @@ def entrypoint(args: Optional[dict] = None):
|
||||
sampling_config = get_sampling_config()
|
||||
sampling_override_preset = sampling_config.get("override_preset")
|
||||
if sampling_override_preset:
|
||||
get_overrides_from_file(sampling_override_preset)
|
||||
try:
|
||||
set_overrides_from_file(sampling_override_preset)
|
||||
except FileNotFoundError as e:
|
||||
logger.warning(str(e))
|
||||
|
||||
# If an initial model name is specified, create a container
|
||||
# and load the model
|
||||
|
||||
Reference in New Issue
Block a user