Auth: Make key permission check work on Requests

Pass a request and internally unwrap the headers. In addition, allow
X-admin-key to get checked in an API key request.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri
2024-07-11 11:16:24 -04:00
parent ff15eed85d
commit b9a58ff01b
2 changed files with 35 additions and 17 deletions

View File

@@ -1,16 +1,15 @@
import asyncio
import pathlib
from fastapi import APIRouter, Depends, HTTPException, Header, Request
from fastapi import APIRouter, Depends, HTTPException, Request
from sse_starlette import EventSourceResponse
from sys import maxsize
from typing import Optional
from common import config, model, gen_logging, sampling
from common.auth import check_admin_key, check_api_key, validate_key_permission
from common.auth import check_admin_key, check_api_key, get_key_permission
from common.downloader import hf_repo_download
from common.networking import handle_request_error, run_with_request_disconnect
from common.templating import PromptTemplate, get_all_templates
from common.utils import coalesce, unwrap
from common.utils import unwrap
from endpoints.OAI.types.auth import AuthPermissionResponse
from endpoints.OAI.types.completion import CompletionRequest, CompletionResponse
from endpoints.OAI.types.chat_completion import (
@@ -432,24 +431,18 @@ async def decode_tokens(data: TokenDecodeRequest) -> TokenDecodeResponse:
@router.get("/v1/auth/permission", dependencies=[Depends(check_api_key)])
async def get_key_permission(
x_admin_key: Optional[str] = Header(None),
x_api_key: Optional[str] = Header(None),
authorization: Optional[str] = Header(None),
) -> AuthPermissionResponse:
async def key_permission(request: Request) -> AuthPermissionResponse:
"""
Gets the access level/permission of a provided key in headers.
Priority:
- X-api-key
- X-admin-key
- Authorization
- X-admin-key
- X-api-key
"""
test_key = coalesce(x_admin_key, x_api_key, authorization)
try:
permission = await validate_key_permission(test_key)
permission = get_key_permission(request)
return AuthPermissionResponse(permission=permission)
except ValueError as exc:
error_message = handle_request_error(str(exc)).error.message