diff --git a/common/auth.py b/common/auth.py index fa53262..7ba13de 100644 --- a/common/auth.py +++ b/common/auth.py @@ -5,11 +5,13 @@ application, it should be fine. import secrets import yaml -from fastapi import Header, HTTPException +from fastapi import Header, HTTPException, Request from pydantic import BaseModel from loguru import logger from typing import Optional +from common.utils import coalesce + class AuthKeys(BaseModel): """ @@ -75,7 +77,23 @@ def load_auth_keys(disable_from_config: bool): ) -async def validate_key_permission(test_key: str): +def get_key_permission(request: Request): + """ + Gets the key permission from a request. + + Internal only! Use the depends functions for incoming requests. + """ + + # Hyphens are okay here + test_key = coalesce( + request.headers.get("authorization"), + request.headers.get("x-admin-key"), + request.headers.get("x-api-key"), + ) + + if test_key is None: + raise ValueError("The provided authentication key is missing.") + if test_key.lower().startswith("bearer"): test_key = test_key.split(" ")[1] @@ -88,7 +106,9 @@ async def validate_key_permission(test_key: str): async def check_api_key( - x_api_key: str = Header(None), authorization: str = Header(None) + x_api_key: str = Header(None), + x_admin_key: str = Header(None), + authorization: str = Header(None), ): """Check if the API key is valid.""" @@ -101,6 +121,11 @@ async def check_api_key( raise HTTPException(401, "Invalid API key") return x_api_key + if x_admin_key: + if not AUTH_KEYS.verify_key(x_admin_key, "admin_key"): + raise HTTPException(401, "Invalid API key") + return x_admin_key + if authorization: split_key = authorization.split(" ") if len(split_key) < 2: diff --git a/endpoints/OAI/router.py b/endpoints/OAI/router.py index 0e4f27b..78c29b1 100644 --- a/endpoints/OAI/router.py +++ b/endpoints/OAI/router.py @@ -1,16 +1,15 @@ import asyncio import pathlib -from fastapi import APIRouter, Depends, HTTPException, Header, Request +from fastapi import APIRouter, Depends, HTTPException, Request from sse_starlette import EventSourceResponse from sys import maxsize -from typing import Optional from common import config, model, gen_logging, sampling -from common.auth import check_admin_key, check_api_key, validate_key_permission +from common.auth import check_admin_key, check_api_key, get_key_permission from common.downloader import hf_repo_download from common.networking import handle_request_error, run_with_request_disconnect from common.templating import PromptTemplate, get_all_templates -from common.utils import coalesce, unwrap +from common.utils import unwrap from endpoints.OAI.types.auth import AuthPermissionResponse from endpoints.OAI.types.completion import CompletionRequest, CompletionResponse from endpoints.OAI.types.chat_completion import ( @@ -432,24 +431,18 @@ async def decode_tokens(data: TokenDecodeRequest) -> TokenDecodeResponse: @router.get("/v1/auth/permission", dependencies=[Depends(check_api_key)]) -async def get_key_permission( - x_admin_key: Optional[str] = Header(None), - x_api_key: Optional[str] = Header(None), - authorization: Optional[str] = Header(None), -) -> AuthPermissionResponse: +async def key_permission(request: Request) -> AuthPermissionResponse: """ Gets the access level/permission of a provided key in headers. Priority: - - X-api-key - - X-admin-key - Authorization + - X-admin-key + - X-api-key """ - test_key = coalesce(x_admin_key, x_api_key, authorization) - try: - permission = await validate_key_permission(test_key) + permission = get_key_permission(request) return AuthPermissionResponse(permission=permission) except ValueError as exc: error_message = handle_request_error(str(exc)).error.message