mirror of
https://github.com/SillyTavern/SillyTavern-Extras.git
synced 2026-03-11 06:20:12 +00:00
Add chromadb
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -132,4 +132,5 @@ debug.png
|
||||
test.wav
|
||||
/tts_samples
|
||||
model.pt
|
||||
.DS_Store
|
||||
.DS_Store
|
||||
.chroma
|
||||
62
README.md
62
README.md
@@ -113,6 +113,7 @@ cd SillyTavern-extras
|
||||
| `prompt` | SD prompt generation from text | ✔️ Yes |
|
||||
| `sd` | Stable Diffusion image generation | :x: No (✔️ remote) |
|
||||
| `tts` | [Silero TTS server](https://github.com/ouoertheo/silero-api-server) | :x: |
|
||||
| `chromadb` | Infinity context server | :x: No |
|
||||
|
||||
|
||||
## Additional options
|
||||
@@ -128,6 +129,7 @@ cd SillyTavern-extras
|
||||
| `--captioning-model` | Load a custom captioning model.<br>Expects a HuggingFace model ID.<br>Default: [Salesforce/blip-image-captioning-large](https://huggingface.co/Salesforce/blip-image-captioning-large) |
|
||||
| `--keyphrase-model` | Load a custom key phrase extraction model.<br>Expects a HuggingFace model ID.<br>Default: [ml6team/keyphrase-extraction-distilbert-inspec](https://huggingface.co/ml6team/keyphrase-extraction-distilbert-inspec) |
|
||||
| `--prompt-model` | Load a custom prompt generation model.<br>Expects a HuggingFace model ID.<br>Default: [FredZhang7/anime-anything-promptgen-v2](https://huggingface.co/FredZhang7/anime-anything-promptgen-v2) |
|
||||
| `--embedding-model` | Load a custom text embedding model.<br>Expects a HuggingFace model ID.<br>Default: [sentence-transformers/all-mpnet-base-v2](https://huggingface.co/sentence-transformers/all-mpnet-base-v2) |
|
||||
| `--sd-model` | Load a custom Stable Diffusion image generation model.<br>Expects a HuggingFace model ID.<br>Default: [ckpt/anything-v4.5-vae-swapped](https://huggingface.co/ckpt/anything-v4.5-vae-swapped)<br>*Must have VAE pre-baked in PyTorch format or the output will look drab!* |
|
||||
| `--sd-cpu` | Force the Stable Diffusion generation pipeline to run on the CPU.<br>**SLOW!** |
|
||||
| `--sd-remote` | Use a remote SD backend.<br>**Supported APIs: [sd-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui)** |
|
||||
@@ -319,3 +321,63 @@ WAV audio file.
|
||||
`GET /api/tts/sample/<voice_id>`
|
||||
#### **Output**
|
||||
WAV audio file.
|
||||
|
||||
### Add messages to chromadb
|
||||
`POST /api/chromadb`
|
||||
#### **Input**
|
||||
```
|
||||
{
|
||||
"chat_id": "chat1 - 2023-12-31",
|
||||
"messages": [
|
||||
{
|
||||
"id": "633a4bd1-8350-46b5-9ef2-f5d27acdecb7",
|
||||
"date": 1684164339877,
|
||||
"role": "user",
|
||||
"content": "Hello, AI world!"
|
||||
},
|
||||
{
|
||||
"id": "8a2ed36b-c212-4a1b-84a3-0ffbe0896506",
|
||||
"date": 1684164411759,
|
||||
"role": "assistant",
|
||||
"content": "Hello, Hooman!"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
#### **Output**
|
||||
```
|
||||
{ "count": 2 }
|
||||
```
|
||||
|
||||
### Query chromadb
|
||||
`POST /api/chromadb/query`
|
||||
#### **Input**
|
||||
```
|
||||
{
|
||||
"chat_id": "chat1 - 2023-12-31",
|
||||
"query": "Hello",
|
||||
"n_results": 2
|
||||
}
|
||||
```
|
||||
#### **Output**
|
||||
```
|
||||
{
|
||||
"chat_id": "chat1 - 2023-12-31",
|
||||
"messages": [
|
||||
{
|
||||
"id": "633a4bd1-8350-46b5-9ef2-f5d27acdecb7",
|
||||
"date": 1684164339877,
|
||||
"role": "user",
|
||||
"content": "Hello, AI world!",
|
||||
"distance": 0.31
|
||||
},
|
||||
{
|
||||
"id": "8a2ed36b-c212-4a1b-84a3-0ffbe0896506",
|
||||
"date": 1684164411759,
|
||||
"role": "assistant",
|
||||
"content": "Hello, Hooman!",
|
||||
"distance": 0.29
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
@@ -12,4 +12,5 @@ torchaudio==2.0.1+cu117
|
||||
accelerate
|
||||
transformers==4.28.1
|
||||
diffusers==0.16.1
|
||||
silero-api-server
|
||||
silero-api-server
|
||||
chromadb
|
||||
95
server.py
95
server.py
@@ -16,6 +16,7 @@ import base64
|
||||
from io import BytesIO
|
||||
from random import randint
|
||||
import webuiapi
|
||||
import hashlib
|
||||
from colorama import Fore, Style, init as colorama_init
|
||||
|
||||
colorama_init()
|
||||
@@ -31,6 +32,7 @@ DEFAULT_CAPTIONING_MODEL = 'Salesforce/blip-image-captioning-large'
|
||||
DEFAULT_KEYPHRASE_MODEL = 'ml6team/keyphrase-extraction-distilbert-inspec'
|
||||
DEFAULT_PROMPT_MODEL = 'FredZhang7/anime-anything-promptgen-v2'
|
||||
DEFAULT_SD_MODEL = "ckpt/anything-v4.5-vae-swapped"
|
||||
DEFAULT_EMBEDDING_MODEL = "sentence-transformers/all-mpnet-base-v2"
|
||||
DEFAULT_REMOTE_SD_HOST = "127.0.0.1"
|
||||
DEFAULT_REMOTE_SD_PORT = 7860
|
||||
SILERO_SAMPLES_PATH = 'tts_samples'
|
||||
@@ -88,13 +90,17 @@ parser.add_argument('--keyphrase-model',
|
||||
help="Load a custom keyphrase extraction model")
|
||||
parser.add_argument('--prompt-model',
|
||||
help="Load a custom prompt generation model")
|
||||
parser.add_argument('--embedding-model',
|
||||
help="Load a custom text embedding model")
|
||||
|
||||
sd_group = parser.add_mutually_exclusive_group()
|
||||
|
||||
local_sd = sd_group.add_argument_group('sd-local')
|
||||
local_sd.add_argument('--sd-model',
|
||||
help="Load a custom SD image generation model")
|
||||
local_sd.add_argument('--sd-cpu',
|
||||
help="Force the SD pipeline to run on the CPU")
|
||||
|
||||
remote_sd = sd_group.add_argument_group('sd-remote')
|
||||
remote_sd.add_argument('--sd-remote', action='store_true',
|
||||
help="Use a remote backend for SD")
|
||||
@@ -119,6 +125,7 @@ classification_model = args.classification_model if args.classification_model el
|
||||
captioning_model = args.captioning_model if args.captioning_model else DEFAULT_CAPTIONING_MODEL
|
||||
keyphrase_model = args.keyphrase_model if args.keyphrase_model else DEFAULT_KEYPHRASE_MODEL
|
||||
prompt_model = args.prompt_model if args.prompt_model else DEFAULT_PROMPT_MODEL
|
||||
embedding_model = args.embedding_model if args.embedding_model else DEFAULT_EMBEDDING_MODEL
|
||||
|
||||
sd_use_remote = False if args.sd_model else True
|
||||
sd_model = args.sd_model if args.sd_model else DEFAULT_SD_MODEL
|
||||
@@ -203,6 +210,20 @@ if 'tts' in modules:
|
||||
tts_service.update_sample_text(SILERO_SAMPLE_TEXT)
|
||||
tts_service.generate_samples()
|
||||
|
||||
if 'chromadb' in modules:
    print('Initializing ChromaDB')
    import chromadb
    import posthog
    from chromadb.config import Settings
    from sentence_transformers import SentenceTransformer

    # disable chromadb telemetry: stub out posthog's capture entry point so
    # nothing is sent even if the Settings flag below is not honored.
    posthog.capture = lambda *args, **kwargs: None
    chromadb_client = chromadb.Client(Settings(anonymized_telemetry=False))
    # Shared sentence-transformers embedder for all chat collections; the
    # model name comes from --embedding-model (see embedding_model above).
    chromadb_embedder = SentenceTransformer(embedding_model)
    # Passed as embedding_function to get_or_create_collection in the
    # /api/chromadb routes below.
    chromadb_embed_fn = chromadb_embedder.encode
|
||||
|
||||
|
||||
PROMPT_PREFIX = "best quality, absurdres, "
|
||||
NEGATIVE_PROMPT = """lowres, bad anatomy, error body, error hair, error arm,
|
||||
error hands, bad hands, error fingers, bad fingers, missing fingers
|
||||
@@ -601,6 +622,80 @@ def tts_generate():
|
||||
def tts_play_sample(speaker: str):
|
||||
return send_from_directory(SILERO_SAMPLES_PATH, f"{speaker}.wav")
|
||||
|
||||
|
||||
@app.route("/api/chromadb", methods=['POST'])
@require_module('chromadb')
def chromadb_add_messages():
    """Upsert chat messages into a per-chat ChromaDB collection.

    Expects JSON of the form:
        {"chat_id": str,
         "messages": [{"id": str, "date": int, "role": str, "content": str}, ...]}

    Returns JSON ``{"count": <number of messages upserted>}``.
    Aborts with 400 on a malformed payload.
    """
    data = request.get_json()
    if 'chat_id' not in data or not isinstance(data['chat_id'], str):
        abort(400, '"chat_id" is required')
    if 'messages' not in data or not isinstance(data['messages'], list):
        abort(400, '"messages" is required')
    # Validate message shape up front so a bad payload yields a clean 400
    # instead of a KeyError -> 500 while building the upsert lists below.
    for message in data['messages']:
        if not isinstance(message, dict) or not all(
                key in message for key in ('id', 'date', 'role', 'content')):
            abort(400, 'each message requires "id", "date", "role" and "content"')

    # Collection names are constrained by chromadb, so key the collection on
    # an MD5 digest of the chat id (stable naming, not security-sensitive).
    chat_id_md5 = hashlib.md5(data['chat_id'].encode()).hexdigest()
    collection = chromadb_client.get_or_create_collection(
        name=f'chat-{chat_id_md5}',
        embedding_function=chromadb_embed_fn
    )

    documents = [m['content'] for m in data['messages']]
    ids = [m['id'] for m in data['messages']]
    metadatas = [{'role': m['role'], 'date': m['date']} for m in data['messages']]

    # upsert (not add) so re-sending the same chat history is idempotent.
    collection.upsert(
        ids=ids,
        documents=documents,
        metadatas=metadatas,
    )

    return jsonify({'count': len(ids)})
|
||||
|
||||
|
||||
@app.route("/api/chromadb/query", methods=['POST'])
@require_module('chromadb')
def chromadb_query():
    """Query a chat's ChromaDB collection for the messages nearest to a text.

    Expects JSON of the form:
        {"chat_id": str, "query": str, "n_results": int (optional, default 1)}

    Returns a JSON list of message objects, each carrying the stored
    ``id``/``date``/``role``/``content`` plus the embedding ``distance``.
    Aborts with 400 on a malformed payload.
    """
    data = request.get_json()
    if 'chat_id' not in data or not isinstance(data['chat_id'], str):
        abort(400, '"chat_id" is required')
    if 'query' not in data or not isinstance(data['query'], str):
        abort(400, '"query" is required')

    # Missing or non-integer n_results silently falls back to 1 result.
    if 'n_results' not in data or not isinstance(data['n_results'], int):
        n_results = 1
    else:
        n_results = data['n_results']

    # Same MD5-derived collection name as chromadb_add_messages.
    chat_id_md5 = hashlib.md5(data['chat_id'].encode()).hexdigest()
    collection = chromadb_client.get_or_create_collection(
        name=f'chat-{chat_id_md5}',
        embedding_function=chromadb_embed_fn
    )

    # Never request more results than the collection holds.
    n_results = min(collection.count(), n_results)
    # An empty collection would drive n_results to 0 and make the query
    # below blow up with a 500; short-circuit with an empty result instead.
    if n_results < 1:
        return jsonify([])

    query_result = collection.query(
        query_texts=[data['query']],
        n_results=n_results,
    )

    # query() nests results per query text; we sent exactly one, hence [0].
    documents = query_result['documents'][0]
    ids = query_result['ids'][0]
    metadatas = query_result['metadatas'][0]
    distances = query_result['distances'][0]

    messages = [
        {
            'id': ids[i],
            'date': metadatas[i]['date'],
            'role': metadatas[i]['role'],
            'content': documents[i],
            'distance': distances[i]
        } for i in range(len(ids))
    ]

    # NOTE(review): the README documents the response as
    # {"chat_id": ..., "messages": [...]} but this returns the bare list;
    # kept as-is for backward compatibility — confirm against the frontend.
    return jsonify(messages)
|
||||
|
||||
if args.share:
|
||||
from flask_cloudflared import _run_cloudflared
|
||||
import inspect
|
||||
|
||||
Reference in New Issue
Block a user