From 6c281c3b5907df0feb89fb8b014362a75dbd6d4a Mon Sep 17 00:00:00 2001
From: Mark Ceter <133643956+maceter@users.noreply.github.com>
Date: Wed, 17 May 2023 17:10:16 +0000
Subject: [PATCH] Add chromadb
---
.gitignore | 3 +-
README.md | 62 +++++++++++++++++++++++++
requirements-complete.txt | 3 +-
server.py | 95 +++++++++++++++++++++++++++++++++++++++
4 files changed, 161 insertions(+), 2 deletions(-)
diff --git a/.gitignore b/.gitignore
index 481e3e8..44f0991 100644
--- a/.gitignore
+++ b/.gitignore
@@ -132,4 +132,5 @@ debug.png
test.wav
/tts_samples
model.pt
-.DS_Store
\ No newline at end of file
+.DS_Store
+.chroma
\ No newline at end of file
diff --git a/README.md b/README.md
index 9c7e541..777c4e5 100644
--- a/README.md
+++ b/README.md
@@ -113,6 +113,7 @@ cd SillyTavern-extras
| `prompt` | SD prompt generation from text | ✔️ Yes |
| `sd` | Stable Diffusion image generation | :x: No (✔️ remote) |
| `tts` | [Silero TTS server](https://github.com/ouoertheo/silero-api-server) | :x: |
+| `chromadb` | Infinity context server | :x: No |
## Additional options
@@ -128,6 +129,7 @@ cd SillyTavern-extras
| `--captioning-model` | Load a custom captioning model.
Expects a HuggingFace model ID.
Default: [Salesforce/blip-image-captioning-large](https://huggingface.co/Salesforce/blip-image-captioning-large) |
| `--keyphrase-model` | Load a custom key phrase extraction model.
Expects a HuggingFace model ID.
Default: [ml6team/keyphrase-extraction-distilbert-inspec](https://huggingface.co/ml6team/keyphrase-extraction-distilbert-inspec) |
| `--prompt-model` | Load a custom prompt generation model.
Expects a HuggingFace model ID.
Default: [FredZhang7/anime-anything-promptgen-v2](https://huggingface.co/FredZhang7/anime-anything-promptgen-v2) |
+| `--embedding-model` | Load a custom text embedding model.
Expects a HuggingFace model ID.
Default: [sentence-transformers/all-mpnet-base-v2](https://huggingface.co/sentence-transformers/all-mpnet-base-v2) |
| `--sd-model` | Load a custom Stable Diffusion image generation model.
Expects a HuggingFace model ID.
Default: [ckpt/anything-v4.5-vae-swapped](https://huggingface.co/ckpt/anything-v4.5-vae-swapped)
*Must have VAE pre-baked in PyTorch format or the output will look drab!* |
| `--sd-cpu` | Force the Stable Diffusion generation pipeline to run on the CPU.
**SLOW!** |
| `--sd-remote` | Use a remote SD backend.
**Supported APIs: [sd-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui)** |
@@ -319,3 +321,63 @@ WAV audio file.
`GET /api/tts/sample/`
#### **Output**
WAV audio file.
+
+### Add messages to chromadb
+`POST /api/chromadb`
+#### **Input**
+```
+{
+ "chat_id": "chat1 - 2023-12-31",
+ "messages": [
+ {
+ "id": "633a4bd1-8350-46b5-9ef2-f5d27acdecb7",
+ "date": 1684164339877,
+ "role": "user",
+ "content": "Hello, AI world!"
+ },
+ {
+ "id": "8a2ed36b-c212-4a1b-84a3-0ffbe0896506",
+ "date": 1684164411759,
+ "role": "assistant",
+ "content": "Hello, Hooman!"
+ },
+ ]
+}
+```
+#### **Output**
+```
+{ "count": 2 }
+```
+
+### Query chromadb
+`POST /api/chromadb/query`
+#### **Input**
+```
+{
+ "chat_id": "chat1 - 2023-12-31",
+ "query": "Hello",
+ "n_results": 2,
+}
+```
+#### **Output**
+```
+{
+ "chat_id": "chat1 - 2023-12-31",
+ "messages": [
+ {
+ "id": "633a4bd1-8350-46b5-9ef2-f5d27acdecb7",
+ "date": 1684164339877,
+ "role": "user",
+ "content": "Hello, AI world!",
+ "distance": 0.31
+ },
+ {
+ "id": "8a2ed36b-c212-4a1b-84a3-0ffbe0896506",
+ "date": 1684164411759,
+ "role": "assistant",
+ "content": "Hello, Hooman!",
+ "distance": 0.29
+ },
+ ]
+}
+```
diff --git a/requirements-complete.txt b/requirements-complete.txt
index db35749..d08c980 100644
--- a/requirements-complete.txt
+++ b/requirements-complete.txt
@@ -12,4 +12,5 @@ torchaudio==2.0.1+cu117
accelerate
transformers==4.28.1
diffusers==0.16.1
-silero-api-server
\ No newline at end of file
+silero-api-server
+chromadb
\ No newline at end of file
diff --git a/server.py b/server.py
index 927ba28..0021231 100644
--- a/server.py
+++ b/server.py
@@ -16,6 +16,7 @@ import base64
from io import BytesIO
from random import randint
import webuiapi
+import hashlib
from colorama import Fore, Style, init as colorama_init
colorama_init()
@@ -31,6 +32,7 @@ DEFAULT_CAPTIONING_MODEL = 'Salesforce/blip-image-captioning-large'
DEFAULT_KEYPHRASE_MODEL = 'ml6team/keyphrase-extraction-distilbert-inspec'
DEFAULT_PROMPT_MODEL = 'FredZhang7/anime-anything-promptgen-v2'
DEFAULT_SD_MODEL = "ckpt/anything-v4.5-vae-swapped"
+DEFAULT_EMBEDDING_MODEL = "sentence-transformers/all-mpnet-base-v2"
DEFAULT_REMOTE_SD_HOST = "127.0.0.1"
DEFAULT_REMOTE_SD_PORT = 7860
SILERO_SAMPLES_PATH = 'tts_samples'
@@ -88,13 +90,17 @@ parser.add_argument('--keyphrase-model',
help="Load a custom keyphrase extraction model")
parser.add_argument('--prompt-model',
help="Load a custom prompt generation model")
+parser.add_argument('--embedding-model',
+ help="Load a custom text embedding model")
sd_group = parser.add_mutually_exclusive_group()
+
local_sd = sd_group.add_argument_group('sd-local')
local_sd.add_argument('--sd-model',
help="Load a custom SD image generation model")
local_sd.add_argument('--sd-cpu',
help="Force the SD pipeline to run on the CPU")
+
remote_sd = sd_group.add_argument_group('sd-remote')
remote_sd.add_argument('--sd-remote', action='store_true',
help="Use a remote backend for SD")
@@ -119,6 +125,7 @@ classification_model = args.classification_model if args.classification_model el
captioning_model = args.captioning_model if args.captioning_model else DEFAULT_CAPTIONING_MODEL
keyphrase_model = args.keyphrase_model if args.keyphrase_model else DEFAULT_KEYPHRASE_MODEL
prompt_model = args.prompt_model if args.prompt_model else DEFAULT_PROMPT_MODEL
+embedding_model = args.embedding_model if args.embedding_model else DEFAULT_EMBEDDING_MODEL
sd_use_remote = False if args.sd_model else True
sd_model = args.sd_model if args.sd_model else DEFAULT_SD_MODEL
@@ -203,6 +210,20 @@ if 'tts' in modules:
tts_service.update_sample_text(SILERO_SAMPLE_TEXT)
tts_service.generate_samples()
+if 'chromadb' in modules:
+ print('Initializing ChromaDB')
+ import chromadb
+ import posthog
+ from chromadb.config import Settings
+ from sentence_transformers import SentenceTransformer
+
+ # disable chromadb telemetry
+ posthog.capture = lambda *args, **kwargs: None
+ chromadb_client = chromadb.Client(Settings(anonymized_telemetry=False))
+ chromadb_embedder = SentenceTransformer(embedding_model)
+ chromadb_embed_fn = chromadb_embedder.encode
+
+
PROMPT_PREFIX = "best quality, absurdres, "
NEGATIVE_PROMPT = """lowres, bad anatomy, error body, error hair, error arm,
error hands, bad hands, error fingers, bad fingers, missing fingers
@@ -601,6 +622,80 @@ def tts_generate():
def tts_play_sample(speaker: str):
return send_from_directory(SILERO_SAMPLES_PATH, f"{speaker}.wav")
+
+@app.route("/api/chromadb", methods=['POST'])
+@require_module('chromadb')
+def chromadb_add_messages():
+ data = request.get_json()
+ if 'chat_id' not in data or not isinstance(data['chat_id'], str):
+ abort(400, '"chat_id" is required')
+ if 'messages' not in data or not isinstance(data['messages'], list):
+ abort(400, '"messages" is required')
+
+ chat_id_md5 = hashlib.md5(data['chat_id'].encode()).hexdigest()
+ collection = chromadb_client.get_or_create_collection(
+ name=f'chat-{chat_id_md5}',
+ embedding_function=chromadb_embed_fn
+ )
+
+ documents = [m['content'] for m in data['messages']]
+ ids = [m['id'] for m in data['messages']]
+ metadatas = [{'role': m['role'], 'date': m['date']} for m in data['messages']]
+
+ collection.upsert(
+ ids=ids,
+ documents=documents,
+ metadatas=metadatas,
+ )
+
+ return jsonify({'count': len(ids)})
+
+
+@app.route("/api/chromadb/query", methods=['POST'])
+@require_module('chromadb')
+def chromadb_query():
+ data = request.get_json()
+ if 'chat_id' not in data or not isinstance(data['chat_id'], str):
+ abort(400, '"chat_id" is required')
+ if 'query' not in data or not isinstance(data['query'], str):
+ abort(400, '"query" is required')
+
+ if 'n_results' not in data or not isinstance(data['n_results'], int):
+ n_results = 1
+ else:
+ n_results = data['n_results']
+
+ chat_id_md5 = hashlib.md5(data['chat_id'].encode()).hexdigest()
+ collection = chromadb_client.get_or_create_collection(
+ name=f'chat-{chat_id_md5}',
+ embedding_function=chromadb_embed_fn
+ )
+
+ n_results = min(collection.count(), n_results)
+ query_result = collection.query(
+ query_texts=[data['query']],
+ n_results=n_results,
+ )
+
+ print(query_result)
+
+ documents = query_result['documents'][0]
+ ids = query_result['ids'][0]
+ metadatas = query_result['metadatas'][0]
+ distances = query_result['distances'][0]
+
+ messages = [
+ {
+ 'id': ids[i],
+ 'date': metadatas[i]['date'],
+ 'role': metadatas[i]['role'],
+ 'content': documents[i],
+ 'distance': distances[i]
+ } for i in range(len(ids))
+ ]
+
+ return jsonify(messages)
+
if args.share:
from flask_cloudflared import _run_cloudflared
import inspect