mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-03-15 00:07:28 +00:00
API: Add sampler override switching
Allow users to switch the currently overridden samplers via the API so a restart isn't required to switch the overrides. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
26
OAI/types/sampler_overrides.py
Normal file
26
OAI/types/sampler_overrides.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class SamplerOverrideSwitchRequest(BaseModel):
    """Request body for switching the active sampler overrides.

    Carries either the name of an override preset or an inline dictionary
    of overrides. Both fields are optional at the schema level.
    """

    # Name of a sampler override preset.
    preset: Optional[str] = Field(
        default=None, description="Pass a sampler override preset name"
    )

    # Inline overrides keyed by sampler parameter name (see `examples`).
    overrides: Optional[dict] = Field(
        default=None,
        # The two sentences are joined by implicit literal concatenation, so
        # the first one must end with a space to keep them separated.
        description=(
            "Sampling override parent takes in individual keys and overrides. "
            "Ignored if preset is provided."
        ),
        examples=[
            {
                "top_p": {
                    "override": 1.5,
                    "force": False,
                }
            }
        ],
    )
|
||||
@@ -166,6 +166,10 @@ class SamplerParams(BaseModel):
|
||||
DEFAULT_OVERRIDES = {}
|
||||
|
||||
|
||||
def get_sampler_overrides():
    """Return the module-level sampler overrides (DEFAULT_OVERRIDES)."""

    return DEFAULT_OVERRIDES
|
||||
|
||||
|
||||
def set_overrides_from_dict(new_overrides: dict):
|
||||
"""Wrapper function to update sampler overrides"""
|
||||
|
||||
@@ -174,10 +178,10 @@ def set_overrides_from_dict(new_overrides: dict):
|
||||
if isinstance(new_overrides, dict):
|
||||
DEFAULT_OVERRIDES = new_overrides
|
||||
else:
|
||||
raise TypeError("new sampler overrides must be a dict!")
|
||||
raise TypeError("New sampler overrides must be a dict!")
|
||||
|
||||
|
||||
def get_overrides_from_file(preset_name: str):
|
||||
def set_overrides_from_file(preset_name: str):
|
||||
"""Fetches an override preset from a file"""
|
||||
|
||||
preset_path = pathlib.Path(f"sampler_overrides/{preset_name}.yml")
|
||||
@@ -188,11 +192,13 @@ def get_overrides_from_file(preset_name: str):
|
||||
|
||||
logger.info("Applied sampler overrides from file.")
|
||||
else:
|
||||
logger.warn(
|
||||
f"Sampler override file named \"{preset_name}\" was not found. "
|
||||
error_message = (
|
||||
f'Sampler override file named "{preset_name}" was not found. '
|
||||
+ "Make sure it's located in the sampler_overrides folder."
|
||||
)
|
||||
|
||||
raise FileNotFoundError(error_message)
|
||||
|
||||
|
||||
# TODO: Maybe move these into the class
|
||||
# Classmethods aren't recognized in pydantic default_factories
|
||||
|
||||
53
main.py
53
main.py
@@ -26,7 +26,11 @@ from common.config import (
|
||||
get_network_config,
|
||||
)
|
||||
from common.generators import call_with_semaphore, generate_with_semaphore
|
||||
from common.sampling import get_overrides_from_file
|
||||
from common.sampling import (
|
||||
get_sampler_overrides,
|
||||
set_overrides_from_file,
|
||||
set_overrides_from_dict,
|
||||
)
|
||||
from common.templating import (
|
||||
get_all_templates,
|
||||
get_prompt_from_template,
|
||||
@@ -43,6 +47,7 @@ from OAI.types.model import (
|
||||
ModelLoadResponse,
|
||||
ModelCardParameters,
|
||||
)
|
||||
from OAI.types.sampler_overrides import SamplerOverrideSwitchRequest
|
||||
from OAI.types.template import TemplateList, TemplateSwitchRequest
|
||||
from OAI.types.token import (
|
||||
TokenEncodeRequest,
|
||||
@@ -288,6 +293,47 @@ async def unload_template():
|
||||
MODEL_CONTAINER.prompt_template = None
|
||||
|
||||
|
||||
# Sampler override endpoints
|
||||
@app.get("/v1/sampling/overrides", dependencies=[Depends(check_api_key)])
@app.get("/v1/sampling/override/list", dependencies=[Depends(check_api_key)])
async def list_sampler_overrides():
    """Return every sampler override that is currently applied."""

    active_overrides = get_sampler_overrides()
    return active_overrides
|
||||
|
||||
|
||||
@app.post(
    "/v1/sampling/override/switch",
    dependencies=[Depends(check_admin_key)],
)
async def switch_sampler_override(data: SamplerOverrideSwitchRequest):
    """Replace the active sampler overrides.

    A preset name takes priority over an inline overrides dictionary.
    Raises an HTTP 400 error when the preset file cannot be found or when
    neither a preset nor an overrides dictionary was supplied.
    """

    preset_name = data.preset
    inline_overrides = data.overrides

    # Nothing usable in the request: reject it up front.
    if not preset_name and not inline_overrides:
        raise HTTPException(
            400, "A sampler override preset or dictionary wasn't provided."
        )

    # No preset given, so the inline dictionary is what gets applied.
    if not preset_name:
        set_overrides_from_dict(inline_overrides)
        return

    try:
        set_overrides_from_file(preset_name)
    except FileNotFoundError as e:
        raise HTTPException(
            400, "Sampler override preset does not exist. Check the name?"
        ) from e
|
||||
|
||||
|
||||
@app.post(
    "/v1/sampling/override/unload",
    dependencies=[Depends(check_admin_key)],
)
async def unload_sampler_override():
    """Clear the active sampler overrides by applying an empty dictionary."""

    no_overrides = {}
    set_overrides_from_dict(no_overrides)
|
||||
|
||||
|
||||
# Lora list endpoint
|
||||
@app.get("/v1/loras", dependencies=[Depends(check_api_key)])
|
||||
@app.get("/v1/lora/list", dependencies=[Depends(check_api_key)])
|
||||
@@ -558,7 +604,10 @@ def entrypoint(args: Optional[dict] = None):
|
||||
sampling_config = get_sampling_config()
|
||||
sampling_override_preset = sampling_config.get("override_preset")
|
||||
if sampling_override_preset:
|
||||
get_overrides_from_file(sampling_override_preset)
|
||||
try:
|
||||
set_overrides_from_file(sampling_override_preset)
|
||||
except FileNotFoundError as e:
|
||||
logger.warning(str(e))
|
||||
|
||||
# If an initial model name is specified, create a container
|
||||
# and load the model
|
||||
|
||||
Reference in New Issue
Block a user