mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-04-20 06:19:15 +00:00
API keys are not allowed to view all the admin's models, templates, draft models, loras, etc. Basically anything that can be viewed on the filesystem outside of anything that's currently loaded is not allowed to be returned unless an admin key is present. This change helps preserve user privacy while not erroring out on list endpoints that the OAI spec requires. Signed-off-by: kingbri <bdashore3@proton.me>
124 lines
3.8 KiB
Python
124 lines
3.8 KiB
Python
import pathlib
|
|
from asyncio import CancelledError
|
|
from typing import Optional
|
|
|
|
from common import gen_logging, model
|
|
from common.networking import get_generator_error, handle_request_disconnect
|
|
from common.utils import unwrap
|
|
from endpoints.OAI.types.model import (
|
|
ModelCard,
|
|
ModelCardParameters,
|
|
ModelList,
|
|
ModelLoadRequest,
|
|
ModelLoadResponse,
|
|
)
|
|
|
|
|
|
def get_model_list(model_path: pathlib.Path, draft_model_path: Optional[str] = None):
    """Get the list of models from the provided path.

    Every immediate subdirectory of ``model_path`` becomes a ``ModelCard``
    whose id is the directory name. The draft model directory, when one is
    given, is left out of the listing.

    Args:
        model_path: Directory that is scanned for model folders.
        draft_model_path: Optional path of the draft model directory that
            must be excluded from the result.

    Returns:
        A ``ModelList`` holding one card per model directory found.
    """

    # Convert the provided draft model path to a resolved pathlib path for
    # equality comparisons
    excluded_path = None
    if draft_model_path:
        excluded_path = pathlib.Path(draft_model_path).resolve()

    model_card_list = ModelList()
    for path in model_path.iterdir():
        if not path.is_dir():
            continue

        # Don't include the draft models path. Entries yielded by iterdir()
        # keep model_path's prefix (which may be relative), so they have to
        # be resolved as well or the comparison can never match.
        if excluded_path is not None and path.resolve() == excluded_path:
            continue

        model_card = ModelCard(id=path.name)
        model_card_list.data.append(model_card)  # pylint: disable=no-member

    return model_card_list
|
|
|
|
|
|
async def get_current_model_list(is_draft: bool = False):
    """Return the currently loaded model wrapped in a ``ModelList``.

    The card only carries the final component of the model path as its id.
    The list is empty when no container exists or when the container
    reports no path for the requested (main or draft) model.
    """

    container = model.container

    # Without a container there is nothing loaded, so there is no path to ask for
    loaded_path = container.get_model_path(is_draft) if container else None

    cards = [ModelCard(id=loaded_path.name)] if loaded_path else []
    return ModelList(data=cards)
|
|
|
|
|
|
def get_current_model():
    """Build the card of the currently loaded model, parameters included.

    The parameter dict comes from the active model container. When it holds
    a non-empty "draft" entry, a nested card describing the draft model is
    attached to the returned card's parameters.

    Returns:
        A ``ModelCard`` for the loaded model, carrying the logging preferences.
    """

    params = model.container.get_model_parameters()

    # Take the draft section out; a missing or empty one counts as "no draft"
    draft_params = params.pop("draft", {}) or None

    if draft_params:
        # NOTE(review): the draft card is assembled here and once more after
        # the main card is created, where it overwrites this one. Looks
        # redundant - confirm before collapsing the two.
        params["draft"] = ModelCard(
            id=unwrap(draft_params.get("name"), "unknown"),
            parameters=ModelCardParameters.model_validate(draft_params),
        )

    # The name is removed first so it is not part of the validated parameters
    card_id = unwrap(params.pop("name", None), "unknown")
    card = ModelCard(
        id=card_id,
        parameters=ModelCardParameters.model_validate(params),
        logging=gen_logging.PREFERENCES,
    )

    if draft_params:
        draft_id = unwrap(draft_params.pop("name", None), "unknown")
        card.parameters.draft = ModelCard(
            id=draft_id,
            parameters=ModelCardParameters.model_validate(draft_params),
        )

    return card
|
|
|
|
|
|
async def stream_model_load(
    data: ModelLoadRequest,
    model_path: pathlib.Path,
    draft_model_path: str,
):
    """Drive a model load and stream its progress as JSON strings.

    Args:
        data: The load request; its dumped fields are forwarded to the loader.
        model_path: Path of the model that should be loaded.
        draft_model_path: Directory of the draft model. When truthy it is
            written into the "draft" section of the load arguments.

    Yields:
        A serialized ``ModelLoadResponse`` with status "processing" for every
        progress tuple whose module index is not 0, followed by one with
        status "finished" once the module index equals the module count. If
        loading raises, the result of ``get_generator_error`` is yielded.
    """

    load_kwargs = data.model_dump()

    # Set the draft model path if it exists
    # NOTE(review): assumes the dumped "draft" entry is a dict whenever a
    # draft path is passed in - confirm against the caller.
    if draft_model_path:
        load_kwargs["draft"]["draft_model_dir"] = draft_model_path

    progress = model.load_model_gen(
        model_path, skip_wait=data.skip_queue, **load_kwargs
    )

    def _packet(status, model_type, module, modules):
        """Serialize a single progress update."""
        update = ModelLoadResponse(
            model_type=model_type,
            module=module,
            modules=modules,
            status=status,
        )
        return update.model_dump_json()

    try:
        async for module, modules, model_type in progress:
            if module != 0:
                yield _packet("processing", model_type, module, modules)

            # The last module gets an extra packet marking completion
            if module == modules:
                yield _packet("finished", model_type, module, modules)
    except CancelledError:
        # Get out if the request gets disconnected
        handle_request_disconnect(
            "Model load cancelled by user. "
            "Please make sure to run unload to free up resources."
        )
    except Exception as exc:
        yield get_generator_error(str(exc))
|