mirror of
https://github.com/SillyTavern/SillyTavern-Extras.git
synced 2026-04-28 18:31:19 +00:00
Add chromadb
This commit is contained in:
95
server.py
95
server.py
@@ -16,6 +16,7 @@ import base64
|
||||
from io import BytesIO
|
||||
from random import randint
|
||||
import webuiapi
|
||||
import hashlib
|
||||
from colorama import Fore, Style, init as colorama_init
|
||||
|
||||
colorama_init()
|
||||
@@ -31,6 +32,7 @@ DEFAULT_CAPTIONING_MODEL = 'Salesforce/blip-image-captioning-large'
|
||||
DEFAULT_KEYPHRASE_MODEL = 'ml6team/keyphrase-extraction-distilbert-inspec'
|
||||
DEFAULT_PROMPT_MODEL = 'FredZhang7/anime-anything-promptgen-v2'
|
||||
DEFAULT_SD_MODEL = "ckpt/anything-v4.5-vae-swapped"
|
||||
DEFAULT_EMBEDDING_MODEL = "sentence-transformers/all-mpnet-base-v2"
|
||||
DEFAULT_REMOTE_SD_HOST = "127.0.0.1"
|
||||
DEFAULT_REMOTE_SD_PORT = 7860
|
||||
SILERO_SAMPLES_PATH = 'tts_samples'
|
||||
@@ -88,13 +90,17 @@ parser.add_argument('--keyphrase-model',
|
||||
help="Load a custom keyphrase extraction model")
|
||||
parser.add_argument('--prompt-model',
|
||||
help="Load a custom prompt generation model")
|
||||
parser.add_argument('--embedding-model',
|
||||
help="Load a custom text embedding model")
|
||||
|
||||
sd_group = parser.add_mutually_exclusive_group()
|
||||
|
||||
local_sd = sd_group.add_argument_group('sd-local')
|
||||
local_sd.add_argument('--sd-model',
|
||||
help="Load a custom SD image generation model")
|
||||
local_sd.add_argument('--sd-cpu',
|
||||
help="Force the SD pipeline to run on the CPU")
|
||||
|
||||
remote_sd = sd_group.add_argument_group('sd-remote')
|
||||
remote_sd.add_argument('--sd-remote', action='store_true',
|
||||
help="Use a remote backend for SD")
|
||||
@@ -119,6 +125,7 @@ classification_model = args.classification_model if args.classification_model el
|
||||
captioning_model = args.captioning_model if args.captioning_model else DEFAULT_CAPTIONING_MODEL
|
||||
keyphrase_model = args.keyphrase_model if args.keyphrase_model else DEFAULT_KEYPHRASE_MODEL
|
||||
prompt_model = args.prompt_model if args.prompt_model else DEFAULT_PROMPT_MODEL
|
||||
embedding_model = args.embedding_model if args.embedding_model else DEFAULT_EMBEDDING_MODEL
|
||||
|
||||
sd_use_remote = False if args.sd_model else True
|
||||
sd_model = args.sd_model if args.sd_model else DEFAULT_SD_MODEL
|
||||
@@ -203,6 +210,20 @@ if 'tts' in modules:
|
||||
tts_service.update_sample_text(SILERO_SAMPLE_TEXT)
|
||||
tts_service.generate_samples()
|
||||
|
||||
if 'chromadb' in modules:
|
||||
print('Initializing ChromaDB')
|
||||
import chromadb
|
||||
import posthog
|
||||
from chromadb.config import Settings
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
# disable chromadb telemetry
|
||||
posthog.capture = lambda *args, **kwargs: None
|
||||
chromadb_client = chromadb.Client(Settings(anonymized_telemetry=False))
|
||||
chromadb_embedder = SentenceTransformer(embedding_model)
|
||||
chromadb_embed_fn = chromadb_embedder.encode
|
||||
|
||||
|
||||
PROMPT_PREFIX = "best quality, absurdres, "
|
||||
NEGATIVE_PROMPT = """lowres, bad anatomy, error body, error hair, error arm,
|
||||
error hands, bad hands, error fingers, bad fingers, missing fingers
|
||||
@@ -601,6 +622,80 @@ def tts_generate():
|
||||
def tts_play_sample(speaker: str):
|
||||
return send_from_directory(SILERO_SAMPLES_PATH, f"{speaker}.wav")
|
||||
|
||||
|
||||
@app.route("/api/chromadb", methods=['POST'])
|
||||
@require_module('chromadb')
|
||||
def chromadb_add_messages():
|
||||
data = request.get_json()
|
||||
if 'chat_id' not in data or not isinstance(data['chat_id'], str):
|
||||
abort(400, '"chat_id" is required')
|
||||
if 'messages' not in data or not isinstance(data['messages'], list):
|
||||
abort(400, '"messages" is required')
|
||||
|
||||
chat_id_md5 = hashlib.md5(data['chat_id'].encode()).hexdigest()
|
||||
collection = chromadb_client.get_or_create_collection(
|
||||
name=f'chat-{chat_id_md5}',
|
||||
embedding_function=chromadb_embed_fn
|
||||
)
|
||||
|
||||
documents = [m['content'] for m in data['messages']]
|
||||
ids = [m['id'] for m in data['messages']]
|
||||
metadatas = [{'role': m['role'], 'date': m['date']} for m in data['messages']]
|
||||
|
||||
collection.upsert(
|
||||
ids=ids,
|
||||
documents=documents,
|
||||
metadatas=metadatas,
|
||||
)
|
||||
|
||||
return jsonify({'count': len(ids)})
|
||||
|
||||
|
||||
@app.route("/api/chromadb/query", methods=['POST'])
|
||||
@require_module('chromadb')
|
||||
def chromadb_query():
|
||||
data = request.get_json()
|
||||
if 'chat_id' not in data or not isinstance(data['chat_id'], str):
|
||||
abort(400, '"chat_id" is required')
|
||||
if 'query' not in data or not isinstance(data['query'], str):
|
||||
abort(400, '"query" is required')
|
||||
|
||||
if 'n_results' not in data or not isinstance(data['n_results'], int):
|
||||
n_results = 1
|
||||
else:
|
||||
n_results = data['n_results']
|
||||
|
||||
chat_id_md5 = hashlib.md5(data['chat_id'].encode()).hexdigest()
|
||||
collection = chromadb_client.get_or_create_collection(
|
||||
name=f'chat-{chat_id_md5}',
|
||||
embedding_function=chromadb_embed_fn
|
||||
)
|
||||
|
||||
n_results = min(collection.count(), n_results)
|
||||
query_result = collection.query(
|
||||
query_texts=[data['query']],
|
||||
n_results=n_results,
|
||||
)
|
||||
|
||||
print(query_result)
|
||||
|
||||
documents = query_result['documents'][0]
|
||||
ids = query_result['ids'][0]
|
||||
metadatas = query_result['metadatas'][0]
|
||||
distances = query_result['distances'][0]
|
||||
|
||||
messages = [
|
||||
{
|
||||
'id': ids[i],
|
||||
'date': metadatas[i]['date'],
|
||||
'role': metadatas[i]['role'],
|
||||
'content': documents[i],
|
||||
'distance': distances[i]
|
||||
} for i in range(len(ids))
|
||||
]
|
||||
|
||||
return jsonify(messages)
|
||||
|
||||
if args.share:
|
||||
from flask_cloudflared import _run_cloudflared
|
||||
import inspect
|
||||
|
||||
Reference in New Issue
Block a user