From 8381305f70c296c7285c8dccbe92b4ac6085409e Mon Sep 17 00:00:00 2001 From: kingbri Date: Fri, 2 Jun 2023 22:16:50 -0400 Subject: [PATCH] Server: Add API key for security When the extras server was hosted publicly, there was a huge security risk of anyone finding a cloudflare tunnel URL and directly querying API routes. However, this had a simple solution of implementing middleware to check if a generated API key is valid. Since the server is simple, the API key is a string of bytes stored in a textfile. If that textfile is deleted, extras will automatically create a new API key/textfile. Additionally, this is enabled via an optional argument to prevent local user irritation. Signed-off-by: kingbri --- .gitignore | 3 ++- README.md | 1 + server.py | 33 ++++++++++++++++++++++++++++++++- 3 files changed, 35 insertions(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index b3cb608..0d6f7c8 100644 --- a/.gitignore +++ b/.gitignore @@ -134,4 +134,5 @@ test.wav model.pt .DS_Store .chroma -/.chroma_db \ No newline at end of file +/.chroma_db +api_key.txt \ No newline at end of file diff --git a/README.md b/README.md index 45d84b6..cdf1791 100644 --- a/README.md +++ b/README.md @@ -134,6 +134,7 @@ cd SillyTavern-extras | `--port` | Specify the port on which the application is hosted. Default: **5100** | | `--listen` | Host the app on the local network | | `--share` | Share the app on CloudFlare tunnel | +| `--secure` | Adds API key authentication requirements. Highly recommended when paired with share! | | `--cpu` | Run the models on the CPU instead of CUDA | | `--summarization-model` | Load a custom summarization model.
Expects a HuggingFace model ID.
Default: [Qiliang/bart-large-cnn-samsum-ChatGPT_v3](https://huggingface.co/Qiliang/bart-large-cnn-samsum-ChatGPT_v3) | | `--classification-model` | Load a custom sentiment classification model.
Expects a HuggingFace model ID.
Default (6 emotions): [nateraw/bert-base-uncased-emotion](https://huggingface.co/nateraw/bert-base-uncased-emotion)
Other solid option is (28 emotions): [joeddav/distilbert-base-uncased-go-emotions-student](https://huggingface.co/joeddav/distilbert-base-uncased-go-emotions-student)
For Chinese language: [touch20032003/xuyuan-trial-sentiment-bert-chinese](https://huggingface.co/touch20032003/xuyuan-trial-sentiment-bert-chinese) | diff --git a/server.py b/server.py index 1136a6c..d1de3d2 100644 --- a/server.py +++ b/server.py @@ -19,6 +19,7 @@ import torch import time import os import gc +import secrets from PIL import Image import base64 from io import BytesIO @@ -65,6 +66,9 @@ parser.add_argument("--embedding-model", help="Load a custom text embedding mode parser.add_argument("--chroma-host", help="Host IP for a remote ChromaDB instance") parser.add_argument("--chroma-port", help="HTTP port for a remote ChromaDB instance (defaults to 8000)") parser.add_argument("--chroma-folder", help="Path for chromadb persistence folder", default='.chroma_db') +parser.add_argument( + "--secure", action="store_true", help="Enforces the use of an API key" +) sd_group = parser.add_mutually_exclusive_group() @@ -420,12 +424,39 @@ def image_to_base64(image: Image, quality: int = 75) -> str: img_str = base64.b64encode(buffered.getvalue()).decode("utf-8") return img_str +# Reads an API key from an already existing file. If that file doesn't exist, create it. +if args.secure: + try: + with open("api_key.txt", "r") as txt: + api_key = txt.read().replace('\n', '') + except: + api_key = secrets.token_hex(5) + with open("api_key.txt", "w") as txt: + txt.write(api_key) + + print(f"Your API key is {api_key}") +elif args.share and args.secure != True: + print("WARNING: This instance is publicly exposed without an API key! It is highly recommended to restart with the \"--secure\" argument!") +else: + print("No API key given because you are running locally.") @app.before_request -# Request time measuring def before_request(): + # Request time measuring request.start_time = time.time() + # Checks if an API key is present and valid, otherwise return unauthorized + # The options check is required so CORS doesn't get angry + try: + if request.method != 'OPTIONS' and args.secure and request.authorization.token != api_key: + print(f"WARNING: Unauthorized API key access from {request.remote_addr}") + response = jsonify({ 'error': '401: Invalid API key' }) + response.status_code = 401 + return response + except Exception as e: + print(f"API key check error: {e}") + return "401 Unauthorized\n{}\n\n".format(e), 401 + @app.after_request def after_request(response):