diff --git a/auth.py b/auth.py index 38e3a48..3314034 100644 --- a/auth.py +++ b/auth.py @@ -1,7 +1,7 @@ import secrets import yaml from fastapi import Header, HTTPException -from typing import Optional +from typing import Optional, Dict """ This method of authorization is pretty insecure, but since TabbyAPI is a local @@ -16,13 +16,18 @@ class AuthKeys: self.api_key = api_key self.admin_key = admin_key + def __init__(self, d: Dict[str, str] = None): + for key, value in d.items(): + setattr(self, key, value) + auth_keys: Optional[AuthKeys] = None def load_auth_keys(): global auth_keys try: with open("api_tokens.yml", "r") as auth_file: - auth_keys = yaml.safe_load(auth_file) + auth_keys_dict = yaml.safe_load(auth_file) + auth_keys = AuthKeys(d = auth_keys_dict) except: new_auth_keys = AuthKeys( api_key = secrets.token_hex(16), @@ -31,24 +36,34 @@ def load_auth_keys(): auth_keys = new_auth_keys with open("api_tokens.yml", "w") as auth_file: - yaml.dump(auth_keys, auth_file) + yaml.safe_dump(vars(auth_keys), auth_file, default_flow_style=False) def check_api_key(x_api_key: str = Header(None), authorization: str = Header(None)): - if x_api_key and x_api_key == auth_keys.api_key: - return x_api_key + if x_api_key: + if x_api_key in auth_keys.api_key: + return x_api_key + else: + raise HTTPException(401, "Invalid API key") elif authorization: split_key = authorization.split(" ") - if split_key[0].lower() == "bearer" and split_key[1] == auth_keys.api_key: + if split_key[0].lower() == "bearer" and split_key[1] in auth_keys.api_key: return authorization + else: + raise HTTPException(401, "Invalid API key") else: - raise HTTPException(401, "Invalid API key") + raise HTTPException(401, "Please provide an API key") def check_admin_key(x_admin_key: str = Header(None), authorization: str = Header(None)): - if x_admin_key and x_admin_key == auth_keys.admin_key: - return x_admin_key + if x_admin_key: + if x_admin_key == auth_keys.admin_key: + return x_admin_key + else: + raise HTTPException(401, "Invalid admin key") elif authorization: split_key = authorization.split(" ") - if split_key[0].lower() == "bearer" and split_key[1] == auth_keys.admin_key: + if split_key[0].lower() == "bearer" and split_key[1] in auth_keys.admin_key: return authorization + else: + raise HTTPException(401, "Invalid admin key") else: - raise HTTPException(401, "Invalid admin key") + raise HTTPException(401, "Please provide an admin key")