OAI: Restrict list permissions for API keys

API keys are not allowed to view all the admin's models, templates,
draft models, loras, etc. Basically anything that can be viewed
on the filesystem outside of anything that's currently loaded is
not allowed to be returned unless an admin key is present.

This change helps preserve user privacy while not erroring out on
list endpoints that the OAI spec requires.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri
2024-07-11 14:06:03 -04:00
parent 10890913b8
commit 1f46a1130c
5 changed files with 119 additions and 60 deletions

View File

@@ -1,5 +1,6 @@
import pathlib
from common import model
from endpoints.OAI.types.lora import LoraCard, LoraList
@@ -12,3 +13,18 @@ def get_lora_list(lora_path: pathlib.Path):
lora_list.data.append(lora_card)
return lora_list
def get_active_loras():
if model.container:
active_loras = [
LoraCard(
id=pathlib.Path(lora.lora_path).parent.name,
scaling=lora.lora_scaling * lora.lora_r / lora.lora_alpha,
)
for lora in model.container.get_loras()
]
else:
active_loras = []
return LoraList(data=active_loras)

View File

@@ -2,11 +2,12 @@ import pathlib
from asyncio import CancelledError
from typing import Optional
from common import model
from common import gen_logging, model
from common.networking import get_generator_error, handle_request_disconnect
from common.utils import unwrap
from endpoints.OAI.types.model import (
ModelCard,
ModelCardParameters,
ModelList,
ModelLoadRequest,
ModelLoadResponse,
@@ -31,6 +32,50 @@ def get_model_list(model_path: pathlib.Path, draft_model_path: Optional[str] = N
return model_card_list
async def get_current_model_list(is_draft: bool = False):
"""Gets the current model in list format and with path only."""
current_models = []
# Make sure the model container exists
if model.container:
model_path = model.container.get_model_path(is_draft)
if model_path:
current_models.append(ModelCard(id=model_path.name))
return ModelList(data=current_models)
def get_current_model():
"""Gets the current model with all parameters."""
model_params = model.container.get_model_parameters()
draft_model_params = model_params.pop("draft", {})
if draft_model_params:
model_params["draft"] = ModelCard(
id=unwrap(draft_model_params.get("name"), "unknown"),
parameters=ModelCardParameters.model_validate(draft_model_params),
)
else:
draft_model_params = None
model_card = ModelCard(
id=unwrap(model_params.pop("name", None), "unknown"),
parameters=ModelCardParameters.model_validate(model_params),
logging=gen_logging.PREFERENCES,
)
if draft_model_params:
draft_card = ModelCard(
id=unwrap(draft_model_params.pop("name", None), "unknown"),
parameters=ModelCardParameters.model_validate(draft_model_params),
)
model_card.parameters.draft = draft_card
return model_card
async def stream_model_load(
data: ModelLoadRequest,
model_path: pathlib.Path,