Files
tabbyAPI/endpoints/OAI/utils/model.py
kingbri 1f46a1130c OAI: Restrict list permissions for API keys
API keys are not allowed to view all of the admin's models, templates,
draft models, LoRAs, etc. In short, nothing on the filesystem other
than what is currently loaded may be returned unless an admin key is
present.

This change helps preserve user privacy while not erroring out on
list endpoints that the OAI spec requires.

Signed-off-by: kingbri <bdashore3@proton.me>
2024-07-11 14:22:50 -04:00

124 lines
3.8 KiB
Python

import pathlib
from asyncio import CancelledError
from typing import Optional
from common import gen_logging, model
from common.networking import get_generator_error, handle_request_disconnect
from common.utils import unwrap
from endpoints.OAI.types.model import (
ModelCard,
ModelCardParameters,
ModelList,
ModelLoadRequest,
ModelLoadResponse,
)
def get_model_list(model_path: pathlib.Path, draft_model_path: Optional[str] = None):
    """Get the list of models found directly under ``model_path``.

    Args:
        model_path: Directory whose immediate subdirectories are listed as
            models. Entries that are not directories are skipped.
        draft_model_path: Optional path of the draft models directory. If it
            is one of the subdirectories of ``model_path``, it is left out of
            the listing.

    Returns:
        A ModelList holding one ModelCard per subdirectory, with the
        subdirectory name used as the card id.
    """

    # Use a separate local instead of rebinding the str parameter to a Path.
    excluded_dir = None
    if draft_model_path:
        excluded_dir = pathlib.Path(draft_model_path).resolve()

    model_card_list = ModelList()
    for path in model_path.iterdir():
        if not path.is_dir():
            continue

        # iterdir() yields paths built from model_path as given (possibly
        # relative or through symlinks), so resolve before comparing against
        # the resolved draft path; otherwise the exclusion never matches.
        if excluded_dir is not None and path.resolve() == excluded_dir:
            continue

        model_card_list.data.append(ModelCard(id=path.name))  # pylint: disable=no-member

    return model_card_list
async def get_current_model_list(is_draft: bool = False):
    """Return the currently loaded model (or draft model) as a ModelList.

    The list contains at most one ModelCard, whose id is the final component
    (``.name``) of the path reported by the container. It is empty when no
    container is present or the container reports no path for the requested
    model.

    Args:
        is_draft: Forwarded to ``get_model_path`` to select the draft model's
            path instead of the main model's.
    """

    if not model.container:
        return ModelList(data=[])

    loaded_path = model.container.get_model_path(is_draft)
    cards = [ModelCard(id=loaded_path.name)] if loaded_path else []

    return ModelList(data=cards)
def get_current_model():
    """Gets the current model with all parameters.

    Reads the parameter dict from ``model.container.get_model_parameters()``.
    The ``"name"`` entry becomes the card id (``"unknown"`` when missing or
    None) and the remaining entries are validated into ModelCardParameters.
    A non-empty ``"draft"`` entry is turned into its own ModelCard (built the
    same way) and attached as ``parameters.draft``.

    NOTE(review): this dereferences ``model.container`` unconditionally, so it
    raises AttributeError if no container is loaded; callers presumably check
    for a loaded model first — confirm.
    """

    model_params = model.container.get_model_parameters()

    # Split the draft section off so it is not validated as part of the main
    # model's parameters; it gets its own card below.
    draft_model_params = model_params.pop("draft", {})

    model_card = ModelCard(
        id=unwrap(model_params.pop("name", None), "unknown"),
        parameters=ModelCardParameters.model_validate(model_params),
        logging=gen_logging.PREFERENCES,
    )

    # Build the draft card exactly once. The draft's "name" is popped so it
    # becomes the card id rather than a parameter.
    if draft_model_params:
        model_card.parameters.draft = ModelCard(
            id=unwrap(draft_model_params.pop("name", None), "unknown"),
            parameters=ModelCardParameters.model_validate(draft_model_params),
        )

    return model_card
async def stream_model_load(
    data: ModelLoadRequest,
    model_path: pathlib.Path,
    draft_model_path: Optional[str],
):
    """Request generation wrapper for the loading process.

    Args:
        data: The load request. Its dumped fields are forwarded as keyword
            arguments to ``model.load_model_gen``, and ``data.skip_queue`` is
            passed as ``skip_wait``.
        model_path: Path of the model to load.
        draft_model_path: Draft model directory to inject into the request's
            draft section, or None/empty to leave the request unchanged.

    Yields:
        JSON strings: a "processing" ModelLoadResponse for every progress
        update with a non-zero module index, a "finished" response when the
        module index equals the module count, or a generator error payload if
        loading raises.
    """

    load_data = data.model_dump()

    # Set the draft model dir only when the request actually carries a draft
    # section. Without one, no draft model was requested, so the dir has
    # nothing to apply to; writing into a None section would raise TypeError
    # here, outside the error reporting below.
    draft_section = load_data.get("draft")
    if draft_model_path and draft_section is not None:
        draft_section["draft_model_dir"] = draft_model_path

    load_status = model.load_model_gen(
        model_path, skip_wait=data.skip_queue, **load_data
    )

    try:
        async for module, modules, model_type in load_status:
            if module != 0:
                response = ModelLoadResponse(
                    model_type=model_type,
                    module=module,
                    modules=modules,
                    status="processing",
                )

                yield response.model_dump_json()

            if module == modules:
                response = ModelLoadResponse(
                    model_type=model_type,
                    module=module,
                    modules=modules,
                    status="finished",
                )

                yield response.model_dump_json()
    except CancelledError:
        # Get out if the request gets disconnected
        handle_request_disconnect(
            "Model load cancelled by user. "
            "Please make sure to run unload to free up resources."
        )
    except Exception as exc:
        # Report any load failure to the stream consumer instead of raising.
        yield get_generator_error(str(exc))