diff --git a/.gitignore b/.gitignore
index 481e3e8..44f0991 100644
--- a/.gitignore
+++ b/.gitignore
@@ -132,4 +132,5 @@ debug.png
 test.wav
 /tts_samples
 model.pt
-.DS_Store
\ No newline at end of file
+.DS_Store
+.chroma
\ No newline at end of file
diff --git a/README.md b/README.md
index 9c7e541..777c4e5 100644
--- a/README.md
+++ b/README.md
@@ -113,6 +113,7 @@ cd SillyTavern-extras
 | `prompt` | SD prompt generation from text | ✔️ Yes |
 | `sd` | Stable Diffusion image generation | :x: No (✔️ remote) |
 | `tts` | [Silero TTS server](https://github.com/ouoertheo/silero-api-server) | :x: |
+| `chromadb` | Infinity context server | :x: No |
 
 ## Additional options
 
@@ -128,6 +129,7 @@ cd SillyTavern-extras
 | `--captioning-model` | Load a custom captioning model.<br>Expects a HuggingFace model ID.<br>Default: [Salesforce/blip-image-captioning-large](https://huggingface.co/Salesforce/blip-image-captioning-large) |
 | `--keyphrase-model` | Load a custom key phrase extraction model.<br>Expects a HuggingFace model ID.<br>Default: [ml6team/keyphrase-extraction-distilbert-inspec](https://huggingface.co/ml6team/keyphrase-extraction-distilbert-inspec) |
 | `--prompt-model` | Load a custom prompt generation model.<br>Expects a HuggingFace model ID.<br>Default: [FredZhang7/anime-anything-promptgen-v2](https://huggingface.co/FredZhang7/anime-anything-promptgen-v2) |
+| `--embedding-model` | Load a custom text embedding model.<br>Expects a HuggingFace model ID.<br>Default: [sentence-transformers/all-mpnet-base-v2](https://huggingface.co/sentence-transformers/all-mpnet-base-v2) |
 | `--sd-model` | Load a custom Stable Diffusion image generation model.<br>Expects a HuggingFace model ID.<br>Default: [ckpt/anything-v4.5-vae-swapped](https://huggingface.co/ckpt/anything-v4.5-vae-swapped)<br>*Must have VAE pre-baked in PyTorch format or the output will look drab!* |
 | `--sd-cpu` | Force the Stable Diffusion generation pipeline to run on the CPU.<br>**SLOW!** |
 | `--sd-remote` | Use a remote SD backend.<br>**Supported APIs: [sd-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui)** |
@@ -319,3 +321,63 @@ WAV audio file.
 `GET /api/tts/sample/`
 #### **Output**
 WAV audio file.
+
+### Add messages to chromadb
+`POST /api/chromadb`
+#### **Input**
+```
+{
+    "chat_id": "chat1 - 2023-12-31",
+    "messages": [
+        {
+            "id": "633a4bd1-8350-46b5-9ef2-f5d27acdecb7",
+            "date": 1684164339877,
+            "role": "user",
+            "content": "Hello, AI world!"
+        },
+        {
+            "id": "8a2ed36b-c212-4a1b-84a3-0ffbe0896506",
+            "date": 1684164411759,
+            "role": "assistant",
+            "content": "Hello, Hooman!"
+        }
+    ]
+}
+```
+#### **Output**
+```
+{ "count": 2 }
+```
+
+### Query chromadb
+`POST /api/chromadb/query`
+#### **Input**
+```
+{
+    "chat_id": "chat1 - 2023-12-31",
+    "query": "Hello",
+    "n_results": 2
+}
+```
+#### **Output**
+```
+{
+    "chat_id": "chat1 - 2023-12-31",
+    "messages": [
+        {
+            "id": "633a4bd1-8350-46b5-9ef2-f5d27acdecb7",
+            "date": 1684164339877,
+            "role": "user",
+            "content": "Hello, AI world!",
+            "distance": 0.31
+        },
+        {
+            "id": "8a2ed36b-c212-4a1b-84a3-0ffbe0896506",
+            "date": 1684164411759,
+            "role": "assistant",
+            "content": "Hello, Hooman!",
+            "distance": 0.29
+        }
+    ]
+}
+```
diff --git a/requirements-complete.txt b/requirements-complete.txt
index db35749..d08c980 100644
--- a/requirements-complete.txt
+++ b/requirements-complete.txt
@@ -12,4 +12,5 @@ torchaudio==2.0.1+cu117
 accelerate
 transformers==4.28.1
 diffusers==0.16.1
-silero-api-server
\ No newline at end of file
+silero-api-server
+chromadb
\ No newline at end of file
diff --git a/server.py b/server.py
index 927ba28..0021231 100644
--- a/server.py
+++ b/server.py
@@ -16,6 +16,7 @@ import base64
 from io import BytesIO
 from random import randint
 import webuiapi
+import hashlib
 from colorama import Fore, Style, init as colorama_init
 
 colorama_init()
@@ -31,6 +32,7 @@ DEFAULT_CAPTIONING_MODEL = 'Salesforce/blip-image-captioning-large'
 DEFAULT_KEYPHRASE_MODEL = 'ml6team/keyphrase-extraction-distilbert-inspec'
 DEFAULT_PROMPT_MODEL = 'FredZhang7/anime-anything-promptgen-v2'
 DEFAULT_SD_MODEL = "ckpt/anything-v4.5-vae-swapped"
+DEFAULT_EMBEDDING_MODEL = "sentence-transformers/all-mpnet-base-v2"
 DEFAULT_REMOTE_SD_HOST = "127.0.0.1"
 DEFAULT_REMOTE_SD_PORT = 7860
 SILERO_SAMPLES_PATH = 'tts_samples'
@@ -88,13 +90,17 @@ parser.add_argument('--keyphrase-model',
                     help="Load a custom keyphrase extraction model")
 parser.add_argument('--prompt-model',
                     help="Load a custom prompt generation model")
+parser.add_argument('--embedding-model',
+                    help="Load a custom text embedding model")
 
 sd_group = parser.add_mutually_exclusive_group()
+
 local_sd = sd_group.add_argument_group('sd-local')
 local_sd.add_argument('--sd-model',
                       help="Load a custom SD image generation model")
 local_sd.add_argument('--sd-cpu',
                       help="Force the SD pipeline to run on the CPU")
+
 remote_sd = sd_group.add_argument_group('sd-remote')
 remote_sd.add_argument('--sd-remote', action='store_true',
                        help="Use a remote backend for SD")
@@ -119,6 +125,7 @@ classification_model = args.classification_model if args.classification_model el
 captioning_model = args.captioning_model if args.captioning_model else DEFAULT_CAPTIONING_MODEL
 keyphrase_model = args.keyphrase_model if args.keyphrase_model else DEFAULT_KEYPHRASE_MODEL
 prompt_model = args.prompt_model if args.prompt_model else DEFAULT_PROMPT_MODEL
+embedding_model = args.embedding_model if args.embedding_model else DEFAULT_EMBEDDING_MODEL
 
 sd_use_remote = False if args.sd_model else True
 sd_model = args.sd_model if args.sd_model else DEFAULT_SD_MODEL
@@ -203,6 +210,20 @@ if 'tts' in modules:
     tts_service.update_sample_text(SILERO_SAMPLE_TEXT)
     tts_service.generate_samples()
 
+if 'chromadb' in modules:
+    print('Initializing ChromaDB')
+    import chromadb
+    import posthog
+    from chromadb.config import Settings
+    from sentence_transformers import SentenceTransformer
+
+    # disable chromadb telemetry
+    posthog.capture = lambda *args, **kwargs: None
+    chromadb_client = chromadb.Client(Settings(anonymized_telemetry=False))
+    chromadb_embedder = SentenceTransformer(embedding_model)
+    chromadb_embed_fn = chromadb_embedder.encode
+
+
 PROMPT_PREFIX = "best quality, absurdres, "
 NEGATIVE_PROMPT = """lowres, bad anatomy, error body, error hair, error arm,
 error hands, bad hands, error fingers, bad fingers, missing fingers
@@ -601,6 +622,80 @@ def tts_generate():
 def tts_play_sample(speaker: str):
     return send_from_directory(SILERO_SAMPLES_PATH, f"{speaker}.wav")
 
+
+@app.route("/api/chromadb", methods=['POST'])
+@require_module('chromadb')
+def chromadb_add_messages():
+    data = request.get_json()
+    if 'chat_id' not in data or not isinstance(data['chat_id'], str):
+        abort(400, '"chat_id" is required')
+    if 'messages' not in data or not isinstance(data['messages'], list):
+        abort(400, '"messages" is required')
+
+    chat_id_md5 = hashlib.md5(data['chat_id'].encode()).hexdigest()
+    collection = chromadb_client.get_or_create_collection(
+        name=f'chat-{chat_id_md5}',
+        embedding_function=chromadb_embed_fn
+    )
+
+    documents = [m['content'] for m in data['messages']]
+    ids = [m['id'] for m in data['messages']]
+    metadatas = [{'role': m['role'], 'date': m['date']} for m in data['messages']]
+
+    collection.upsert(
+        ids=ids,
+        documents=documents,
+        metadatas=metadatas,
+    )
+
+    return jsonify({'count': len(ids)})
+
+
+@app.route("/api/chromadb/query", methods=['POST'])
+@require_module('chromadb')
+def chromadb_query():
+    data = request.get_json()
+    if 'chat_id' not in data or not isinstance(data['chat_id'], str):
+        abort(400, '"chat_id" is required')
+    if 'query' not in data or not isinstance(data['query'], str):
+        abort(400, '"query" is required')
+
+    if 'n_results' not in data or not isinstance(data['n_results'], int):
+        n_results = 1
+    else:
+        n_results = data['n_results']
+
+    chat_id_md5 = hashlib.md5(data['chat_id'].encode()).hexdigest()
+    collection = chromadb_client.get_or_create_collection(
+        name=f'chat-{chat_id_md5}',
+        embedding_function=chromadb_embed_fn
+    )
+
+    n_results = min(collection.count(), n_results)
+    query_result = collection.query(
+        query_texts=[data['query']],
+        n_results=n_results,
+    )
+
+    print(query_result)
+
+    documents = query_result['documents'][0]
+    ids = query_result['ids'][0]
+    metadatas = query_result['metadatas'][0]
+    distances = query_result['distances'][0]
+
+    messages = [
+        {
+            'id': ids[i],
+            'date': metadatas[i]['date'],
+            'role': metadatas[i]['role'],
+            'content': documents[i],
+            'distance': distances[i]
+        } for i in range(len(ids))
+    ]
+
+    return jsonify(messages)
+
 if args.share:
     from flask_cloudflared import _run_cloudflared
     import inspect