diff --git a/.gitignore b/.gitignore
index b3cb608..0d6f7c8 100644
--- a/.gitignore
+++ b/.gitignore
@@ -134,4 +134,5 @@ test.wav
model.pt
.DS_Store
.chroma
-/.chroma_db
\ No newline at end of file
+/.chroma_db
+api_key.txt
\ No newline at end of file
diff --git a/README.md b/README.md
index 45d84b6..cdf1791 100644
--- a/README.md
+++ b/README.md
@@ -134,6 +134,7 @@ cd SillyTavern-extras
| `--port` | Specify the port on which the application is hosted. Default: **5100** |
| `--listen` | Host the app on the local network |
| `--share` | Share the app on CloudFlare tunnel |
+| `--secure` | Adds API key authentication requirements. Highly recommended when paired with share! |
| `--cpu` | Run the models on the CPU instead of CUDA |
| `--summarization-model` | Load a custom summarization model.
Expects a HuggingFace model ID.
Default: [Qiliang/bart-large-cnn-samsum-ChatGPT_v3](https://huggingface.co/Qiliang/bart-large-cnn-samsum-ChatGPT_v3) |
| `--classification-model` | Load a custom sentiment classification model.
Expects a HuggingFace model ID.
Default (6 emotions): [nateraw/bert-base-uncased-emotion](https://huggingface.co/nateraw/bert-base-uncased-emotion)
Other solid option is (28 emotions): [joeddav/distilbert-base-uncased-go-emotions-student](https://huggingface.co/joeddav/distilbert-base-uncased-go-emotions-student)
For Chinese language: [touch20032003/xuyuan-trial-sentiment-bert-chinese](https://huggingface.co/touch20032003/xuyuan-trial-sentiment-bert-chinese) |
diff --git a/server.py b/server.py
index 1136a6c..d1de3d2 100644
--- a/server.py
+++ b/server.py
@@ -19,6 +19,7 @@ import torch
import time
import os
import gc
+import secrets
from PIL import Image
import base64
from io import BytesIO
@@ -65,6 +66,9 @@ parser.add_argument("--embedding-model", help="Load a custom text embedding mode
parser.add_argument("--chroma-host", help="Host IP for a remote ChromaDB instance")
parser.add_argument("--chroma-port", help="HTTP port for a remote ChromaDB instance (defaults to 8000)")
parser.add_argument("--chroma-folder", help="Path for chromadb persistence folder", default='.chroma_db')
+parser.add_argument(
+ "--secure", action="store_true", help="Enforces the use of an API key"
+)
sd_group = parser.add_mutually_exclusive_group()
@@ -420,12 +424,39 @@ def image_to_base64(image: Image, quality: int = 75) -> str:
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
return img_str
+# Reads an API key from an already existing file. If that file doesn't exist, create it.
+if args.secure:
+ try:
+ with open("api_key.txt", "r") as txt:
+ api_key = txt.read().replace('\n', '')
+ except:
+ api_key = secrets.token_hex(5)
+ with open("api_key.txt", "w") as txt:
+ txt.write(api_key)
+
+ print(f"Your API key is {api_key}")
+elif args.share and args.secure != True:
+ print("WARNING: This instance is publicly exposed without an API key! It is highly recommended to restart with the \"--secure\" argument!")
+else:
+ print("No API key given because you are running locally.")
@app.before_request
-# Request time measuring
def before_request():
+ # Request time measuring
request.start_time = time.time()
+ # Checks if an API key is present and valid, otherwise return unauthorized
+ # The options check is required so CORS doesn't get angry
+ try:
+ if request.method != 'OPTIONS' and args.secure and request.authorization.token != api_key:
+ print(f"WARNING: Unauthorized API key access from {request.remote_addr}")
+ response = jsonify({ 'error': '401: Invalid API key' })
+ response.status_code = 401
+ return response
+ except Exception as e:
+ print(f"API key check error: {e}")
+ return "401 Unauthorized\n{}\n\n".format(e), 401
+
@app.after_request
def after_request(response):