Add chromadb

This commit is contained in:
Mark Ceter
2023-05-17 17:10:16 +00:00
parent 586f6dbd74
commit 6c281c3b59
4 changed files with 161 additions and 2 deletions

3
.gitignore vendored
View File

@@ -132,4 +132,5 @@ debug.png
test.wav
/tts_samples
model.pt
.DS_Store
.DS_Store
.chroma

View File

@@ -113,6 +113,7 @@ cd SillyTavern-extras
| `prompt` | SD prompt generation from text | ✔️ Yes |
| `sd` | Stable Diffusion image generation | :x: No (✔️ remote) |
| `tts` | [Silero TTS server](https://github.com/ouoertheo/silero-api-server) | :x: |
| `chromadb` | Infinity context server | :x: No |
## Additional options
@@ -128,6 +129,7 @@ cd SillyTavern-extras
| `--captioning-model` | Load a custom captioning model.<br>Expects a HuggingFace model ID.<br>Default: [Salesforce/blip-image-captioning-large](https://huggingface.co/Salesforce/blip-image-captioning-large) |
| `--keyphrase-model` | Load a custom key phrase extraction model.<br>Expects a HuggingFace model ID.<br>Default: [ml6team/keyphrase-extraction-distilbert-inspec](https://huggingface.co/ml6team/keyphrase-extraction-distilbert-inspec) |
| `--prompt-model` | Load a custom prompt generation model.<br>Expects a HuggingFace model ID.<br>Default: [FredZhang7/anime-anything-promptgen-v2](https://huggingface.co/FredZhang7/anime-anything-promptgen-v2) |
| `--embedding-model` | Load a custom text embedding model.<br>Expects a HuggingFace model ID.<br>Default: [sentence-transformers/all-mpnet-base-v2](https://huggingface.co/sentence-transformers/all-mpnet-base-v2) |
| `--sd-model` | Load a custom Stable Diffusion image generation model.<br>Expects a HuggingFace model ID.<br>Default: [ckpt/anything-v4.5-vae-swapped](https://huggingface.co/ckpt/anything-v4.5-vae-swapped)<br>*Must have VAE pre-baked in PyTorch format or the output will look drab!* |
| `--sd-cpu` | Force the Stable Diffusion generation pipeline to run on the CPU.<br>**SLOW!** |
| `--sd-remote` | Use a remote SD backend.<br>**Supported APIs: [sd-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui)** |
@@ -319,3 +321,63 @@ WAV audio file.
`GET /api/tts/sample/<voice_id>`
#### **Output**
WAV audio file.
### Add messages to chromadb
`POST /api/chromadb`
#### **Input**
```
{
"chat_id": "chat1 - 2023-12-31",
"messages": [
{
"id": "633a4bd1-8350-46b5-9ef2-f5d27acdecb7",
"date": 1684164339877,
"role": "user",
"content": "Hello, AI world!"
},
{
"id": "8a2ed36b-c212-4a1b-84a3-0ffbe0896506",
"date": 1684164411759,
"role": "assistant",
"content": "Hello, Hooman!"
},
]
}
```
#### **Output**
```
{ "count": 2 }
```
### Query chromadb
`POST /api/chromadb/query`
#### **Input**
```
{
"chat_id": "chat1 - 2023-12-31",
"query": "Hello",
"n_results": 2,
}
```
#### **Output**
```
{
"chat_id": "chat1 - 2023-12-31",
"messages": [
{
"id": "633a4bd1-8350-46b5-9ef2-f5d27acdecb7",
"date": 1684164339877,
"role": "user",
"content": "Hello, AI world!",
"distance": 0.31
},
{
"id": "8a2ed36b-c212-4a1b-84a3-0ffbe0896506",
"date": 1684164411759,
"role": "assistant",
"content": "Hello, Hooman!",
"distance": 0.29
},
]
}
```

View File

@@ -12,4 +12,5 @@ torchaudio==2.0.1+cu117
accelerate
transformers==4.28.1
diffusers==0.16.1
silero-api-server
silero-api-server
chromadb

View File

@@ -16,6 +16,7 @@ import base64
from io import BytesIO
from random import randint
import webuiapi
import hashlib
from colorama import Fore, Style, init as colorama_init
colorama_init()
@@ -31,6 +32,7 @@ DEFAULT_CAPTIONING_MODEL = 'Salesforce/blip-image-captioning-large'
DEFAULT_KEYPHRASE_MODEL = 'ml6team/keyphrase-extraction-distilbert-inspec'
DEFAULT_PROMPT_MODEL = 'FredZhang7/anime-anything-promptgen-v2'
DEFAULT_SD_MODEL = "ckpt/anything-v4.5-vae-swapped"
DEFAULT_EMBEDDING_MODEL = "sentence-transformers/all-mpnet-base-v2"
DEFAULT_REMOTE_SD_HOST = "127.0.0.1"
DEFAULT_REMOTE_SD_PORT = 7860
SILERO_SAMPLES_PATH = 'tts_samples'
@@ -88,13 +90,17 @@ parser.add_argument('--keyphrase-model',
help="Load a custom keyphrase extraction model")
parser.add_argument('--prompt-model',
help="Load a custom prompt generation model")
parser.add_argument('--embedding-model',
help="Load a custom text embedding model")
sd_group = parser.add_mutually_exclusive_group()
local_sd = sd_group.add_argument_group('sd-local')
local_sd.add_argument('--sd-model',
help="Load a custom SD image generation model")
local_sd.add_argument('--sd-cpu',
help="Force the SD pipeline to run on the CPU")
remote_sd = sd_group.add_argument_group('sd-remote')
remote_sd.add_argument('--sd-remote', action='store_true',
help="Use a remote backend for SD")
@@ -119,6 +125,7 @@ classification_model = args.classification_model if args.classification_model el
captioning_model = args.captioning_model if args.captioning_model else DEFAULT_CAPTIONING_MODEL
keyphrase_model = args.keyphrase_model if args.keyphrase_model else DEFAULT_KEYPHRASE_MODEL
prompt_model = args.prompt_model if args.prompt_model else DEFAULT_PROMPT_MODEL
embedding_model = args.embedding_model if args.embedding_model else DEFAULT_EMBEDDING_MODEL
sd_use_remote = False if args.sd_model else True
sd_model = args.sd_model if args.sd_model else DEFAULT_SD_MODEL
@@ -203,6 +210,20 @@ if 'tts' in modules:
tts_service.update_sample_text(SILERO_SAMPLE_TEXT)
tts_service.generate_samples()
if 'chromadb' in modules:
print('Initializing ChromaDB')
import chromadb
import posthog
from chromadb.config import Settings
from sentence_transformers import SentenceTransformer
# disable chromadb telemetry
posthog.capture = lambda *args, **kwargs: None
chromadb_client = chromadb.Client(Settings(anonymized_telemetry=False))
chromadb_embedder = SentenceTransformer(embedding_model)
chromadb_embed_fn = chromadb_embedder.encode
PROMPT_PREFIX = "best quality, absurdres, "
NEGATIVE_PROMPT = """lowres, bad anatomy, error body, error hair, error arm,
error hands, bad hands, error fingers, bad fingers, missing fingers
@@ -601,6 +622,80 @@ def tts_generate():
def tts_play_sample(speaker: str):
return send_from_directory(SILERO_SAMPLES_PATH, f"{speaker}.wav")
@app.route("/api/chromadb", methods=['POST'])
@require_module('chromadb')
def chromadb_add_messages():
data = request.get_json()
if 'chat_id' not in data or not isinstance(data['chat_id'], str):
abort(400, '"chat_id" is required')
if 'messages' not in data or not isinstance(data['messages'], list):
abort(400, '"messages" is required')
chat_id_md5 = hashlib.md5(data['chat_id'].encode()).hexdigest()
collection = chromadb_client.get_or_create_collection(
name=f'chat-{chat_id_md5}',
embedding_function=chromadb_embed_fn
)
documents = [m['content'] for m in data['messages']]
ids = [m['id'] for m in data['messages']]
metadatas = [{'role': m['role'], 'date': m['date']} for m in data['messages']]
collection.upsert(
ids=ids,
documents=documents,
metadatas=metadatas,
)
return jsonify({'count': len(ids)})
@app.route("/api/chromadb/query", methods=['POST'])
@require_module('chromadb')
def chromadb_query():
data = request.get_json()
if 'chat_id' not in data or not isinstance(data['chat_id'], str):
abort(400, '"chat_id" is required')
if 'query' not in data or not isinstance(data['query'], str):
abort(400, '"query" is required')
if 'n_results' not in data or not isinstance(data['n_results'], int):
n_results = 1
else:
n_results = data['n_results']
chat_id_md5 = hashlib.md5(data['chat_id'].encode()).hexdigest()
collection = chromadb_client.get_or_create_collection(
name=f'chat-{chat_id_md5}',
embedding_function=chromadb_embed_fn
)
n_results = min(collection.count(), n_results)
query_result = collection.query(
query_texts=[data['query']],
n_results=n_results,
)
print(query_result)
documents = query_result['documents'][0]
ids = query_result['ids'][0]
metadatas = query_result['metadatas'][0]
distances = query_result['distances'][0]
messages = [
{
'id': ids[i],
'date': metadatas[i]['date'],
'role': metadatas[i]['role'],
'content': documents[i],
'distance': distances[i]
} for i in range(len(ids))
]
return jsonify(messages)
if args.share:
from flask_cloudflared import _run_cloudflared
import inspect