Auth: Fix methods for writing and validation

These were not working properly. Make the YAML file get written
to properly and the validator to return a 401 when the bearer
token is invalid.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri
2023-11-15 00:19:15 -05:00
parent cb8da7f092
commit bbb59d0747

37
auth.py
View File

@@ -1,7 +1,7 @@
import secrets import secrets
import yaml import yaml
from fastapi import Header, HTTPException from fastapi import Header, HTTPException
from typing import Optional from typing import Optional, Dict
""" """
This method of authorization is pretty insecure, but since TabbyAPI is a local This method of authorization is pretty insecure, but since TabbyAPI is a local
@@ -16,13 +16,18 @@ class AuthKeys:
self.api_key = api_key self.api_key = api_key
self.admin_key = admin_key self.admin_key = admin_key
def __init__(self, d: Dict[str, str] = None):
for key, value in d.items():
setattr(self, key, value)
auth_keys: Optional[AuthKeys] = None auth_keys: Optional[AuthKeys] = None
def load_auth_keys(): def load_auth_keys():
global auth_keys global auth_keys
try: try:
with open("api_tokens.yml", "r") as auth_file: with open("api_tokens.yml", "r") as auth_file:
auth_keys = yaml.safe_load(auth_file) auth_keys_dict = yaml.safe_load(auth_file)
auth_keys = AuthKeys(d = auth_keys_dict)
except: except:
new_auth_keys = AuthKeys( new_auth_keys = AuthKeys(
api_key = secrets.token_hex(16), api_key = secrets.token_hex(16),
@@ -31,24 +36,34 @@ def load_auth_keys():
auth_keys = new_auth_keys auth_keys = new_auth_keys
with open("api_tokens.yml", "w") as auth_file: with open("api_tokens.yml", "w") as auth_file:
yaml.dump(auth_keys, auth_file) yaml.safe_dump(vars(auth_keys), auth_file, default_flow_style=False)
def check_api_key(x_api_key: str = Header(None), authorization: str = Header(None)): def check_api_key(x_api_key: str = Header(None), authorization: str = Header(None)):
if x_api_key and x_api_key == auth_keys.api_key: if x_api_key:
return x_api_key if x_api_key in auth_keys.api_key:
return x_api_key
else:
raise HTTPException(401, "Invalid API key")
elif authorization: elif authorization:
split_key = authorization.split(" ") split_key = authorization.split(" ")
if split_key[0].lower() == "bearer" and split_key[1] == auth_keys.api_key: if split_key[0].lower() == "bearer" and split_key[1] in auth_keys.api_key:
return authorization return authorization
else:
raise HTTPException(401, "Invalid API key")
else: else:
raise HTTPException(401, "Invalid API key") raise HTTPException(401, "Please provide an API key")
def check_admin_key(x_admin_key: str = Header(None), authorization: str = Header(None)): def check_admin_key(x_admin_key: str = Header(None), authorization: str = Header(None)):
if x_admin_key and x_admin_key == auth_keys.admin_key: if x_admin_key:
return x_admin_key if x_admin_key == auth_keys.admin_key:
return x_admin_key
else:
raise HTTPException(401, "Invalid admin key")
elif authorization: elif authorization:
split_key = authorization.split(" ") split_key = authorization.split(" ")
if split_key[0].lower() == "bearer" and split_key[1] == auth_keys.admin_key: if split_key[0].lower() == "bearer" and split_key[1] in auth_keys.admin_key:
return authorization return authorization
else:
raise HTTPException(401, "Invalid admin key")
else: else:
raise HTTPException(401, "Invalid admin key") raise HTTPException(401, "Please provide an admin key")