Add chromadb

This commit is contained in:
Mark Ceter
2023-05-17 17:10:16 +00:00
parent 586f6dbd74
commit 6c281c3b59
4 changed files with 161 additions and 2 deletions

View File

@@ -16,6 +16,7 @@ import base64
from io import BytesIO
from random import randint
import webuiapi
import hashlib
from colorama import Fore, Style, init as colorama_init
colorama_init()
@@ -31,6 +32,7 @@ DEFAULT_CAPTIONING_MODEL = 'Salesforce/blip-image-captioning-large'
DEFAULT_KEYPHRASE_MODEL = 'ml6team/keyphrase-extraction-distilbert-inspec'
DEFAULT_PROMPT_MODEL = 'FredZhang7/anime-anything-promptgen-v2'
DEFAULT_SD_MODEL = "ckpt/anything-v4.5-vae-swapped"
DEFAULT_EMBEDDING_MODEL = "sentence-transformers/all-mpnet-base-v2"
DEFAULT_REMOTE_SD_HOST = "127.0.0.1"
DEFAULT_REMOTE_SD_PORT = 7860
SILERO_SAMPLES_PATH = 'tts_samples'
@@ -88,13 +90,17 @@ parser.add_argument('--keyphrase-model',
help="Load a custom keyphrase extraction model")
parser.add_argument('--prompt-model',
help="Load a custom prompt generation model")
parser.add_argument('--embedding-model',
help="Load a custom text embedding model")
sd_group = parser.add_mutually_exclusive_group()
local_sd = sd_group.add_argument_group('sd-local')
local_sd.add_argument('--sd-model',
help="Load a custom SD image generation model")
local_sd.add_argument('--sd-cpu',
help="Force the SD pipeline to run on the CPU")
remote_sd = sd_group.add_argument_group('sd-remote')
remote_sd.add_argument('--sd-remote', action='store_true',
help="Use a remote backend for SD")
@@ -119,6 +125,7 @@ classification_model = args.classification_model if args.classification_model el
captioning_model = args.captioning_model if args.captioning_model else DEFAULT_CAPTIONING_MODEL
keyphrase_model = args.keyphrase_model if args.keyphrase_model else DEFAULT_KEYPHRASE_MODEL
prompt_model = args.prompt_model if args.prompt_model else DEFAULT_PROMPT_MODEL
embedding_model = args.embedding_model if args.embedding_model else DEFAULT_EMBEDDING_MODEL
sd_use_remote = False if args.sd_model else True
sd_model = args.sd_model if args.sd_model else DEFAULT_SD_MODEL
@@ -203,6 +210,20 @@ if 'tts' in modules:
tts_service.update_sample_text(SILERO_SAMPLE_TEXT)
tts_service.generate_samples()
if 'chromadb' in modules:
print('Initializing ChromaDB')
import chromadb
import posthog
from chromadb.config import Settings
from sentence_transformers import SentenceTransformer
# disable chromadb telemetry
posthog.capture = lambda *args, **kwargs: None
chromadb_client = chromadb.Client(Settings(anonymized_telemetry=False))
chromadb_embedder = SentenceTransformer(embedding_model)
chromadb_embed_fn = chromadb_embedder.encode
PROMPT_PREFIX = "best quality, absurdres, "
NEGATIVE_PROMPT = """lowres, bad anatomy, error body, error hair, error arm,
error hands, bad hands, error fingers, bad fingers, missing fingers
@@ -601,6 +622,80 @@ def tts_generate():
def tts_play_sample(speaker: str):
return send_from_directory(SILERO_SAMPLES_PATH, f"{speaker}.wav")
@app.route("/api/chromadb", methods=['POST'])
@require_module('chromadb')
def chromadb_add_messages():
data = request.get_json()
if 'chat_id' not in data or not isinstance(data['chat_id'], str):
abort(400, '"chat_id" is required')
if 'messages' not in data or not isinstance(data['messages'], list):
abort(400, '"messages" is required')
chat_id_md5 = hashlib.md5(data['chat_id'].encode()).hexdigest()
collection = chromadb_client.get_or_create_collection(
name=f'chat-{chat_id_md5}',
embedding_function=chromadb_embed_fn
)
documents = [m['content'] for m in data['messages']]
ids = [m['id'] for m in data['messages']]
metadatas = [{'role': m['role'], 'date': m['date']} for m in data['messages']]
collection.upsert(
ids=ids,
documents=documents,
metadatas=metadatas,
)
return jsonify({'count': len(ids)})
@app.route("/api/chromadb/query", methods=['POST'])
@require_module('chromadb')
def chromadb_query():
data = request.get_json()
if 'chat_id' not in data or not isinstance(data['chat_id'], str):
abort(400, '"chat_id" is required')
if 'query' not in data or not isinstance(data['query'], str):
abort(400, '"query" is required')
if 'n_results' not in data or not isinstance(data['n_results'], int):
n_results = 1
else:
n_results = data['n_results']
chat_id_md5 = hashlib.md5(data['chat_id'].encode()).hexdigest()
collection = chromadb_client.get_or_create_collection(
name=f'chat-{chat_id_md5}',
embedding_function=chromadb_embed_fn
)
n_results = min(collection.count(), n_results)
query_result = collection.query(
query_texts=[data['query']],
n_results=n_results,
)
print(query_result)
documents = query_result['documents'][0]
ids = query_result['ids'][0]
metadatas = query_result['metadatas'][0]
distances = query_result['distances'][0]
messages = [
{
'id': ids[i],
'date': metadatas[i]['date'],
'role': metadatas[i]['role'],
'content': documents[i],
'distance': distances[i]
} for i in range(len(ids))
]
return jsonify(messages)
if args.share:
from flask_cloudflared import _run_cloudflared
import inspect