Merge pull request #44 from bdashore3/main

This commit is contained in:
Cohee
2023-06-03 13:39:12 +03:00
committed by GitHub
3 changed files with 35 additions and 2 deletions

3
.gitignore vendored
View File

@@ -134,4 +134,5 @@ test.wav
model.pt
.DS_Store
.chroma
/.chroma_db
/.chroma_db
api_key.txt

View File

@@ -134,6 +134,7 @@ cd SillyTavern-extras
| `--port` | Specify the port on which the application is hosted. Default: **5100** |
| `--listen` | Host the app on the local network |
| `--share` | Share the app on CloudFlare tunnel |
| `--secure` | Adds API key authentication requirements. Highly recommended when paired with share! |
| `--cpu` | Run the models on the CPU instead of CUDA |
| `--summarization-model` | Load a custom summarization model.<br>Expects a HuggingFace model ID.<br>Default: [Qiliang/bart-large-cnn-samsum-ChatGPT_v3](https://huggingface.co/Qiliang/bart-large-cnn-samsum-ChatGPT_v3) |
| `--classification-model` | Load a custom sentiment classification model.<br>Expects a HuggingFace model ID.<br>Default (6 emotions): [nateraw/bert-base-uncased-emotion](https://huggingface.co/nateraw/bert-base-uncased-emotion)<br>Other solid option is (28 emotions): [joeddav/distilbert-base-uncased-go-emotions-student](https://huggingface.co/joeddav/distilbert-base-uncased-go-emotions-student)<br>For Chinese language: [touch20032003/xuyuan-trial-sentiment-bert-chinese](https://huggingface.co/touch20032003/xuyuan-trial-sentiment-bert-chinese) |

View File

@@ -19,6 +19,7 @@ import torch
import time
import os
import gc
import secrets
from PIL import Image
import base64
from io import BytesIO
@@ -65,6 +66,9 @@ parser.add_argument("--embedding-model", help="Load a custom text embedding mode
parser.add_argument("--chroma-host", help="Host IP for a remote ChromaDB instance")
parser.add_argument("--chroma-port", help="HTTP port for a remote ChromaDB instance (defaults to 8000)")
parser.add_argument("--chroma-folder", help="Path for chromadb persistence folder", default='.chroma_db')
parser.add_argument(
"--secure", action="store_true", help="Enforces the use of an API key"
)
sd_group = parser.add_mutually_exclusive_group()
@@ -420,12 +424,39 @@ def image_to_base64(image: Image, quality: int = 75) -> str:
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
return img_str
# Reads an API key from an already existing file. If that file doesn't exist, create it.
if args.secure:
try:
with open("api_key.txt", "r") as txt:
api_key = txt.read().replace('\n', '')
except:
api_key = secrets.token_hex(5)
with open("api_key.txt", "w") as txt:
txt.write(api_key)
print(f"Your API key is {api_key}")
elif args.share and args.secure != True:
print("WARNING: This instance is publicly exposed without an API key! It is highly recommended to restart with the \"--secure\" argument!")
else:
print("No API key given because you are running locally.")
@app.before_request
# Request time measuring
def before_request():
# Request time measuring
request.start_time = time.time()
# Checks if an API key is present and valid, otherwise return unauthorized
# The options check is required so CORS doesn't get angry
try:
if request.method != 'OPTIONS' and args.secure and request.authorization.token != api_key:
print(f"WARNING: Unauthorized API key access from {request.remote_addr}")
response = jsonify({ 'error': '401: Invalid API key' })
response.status_code = 401
return response
except Exception as e:
print(f"API key check error: {e}")
return "401 Unauthorized\n{}\n\n".format(e), 401
@app.after_request
def after_request(response):