Add persistence to chromadb

Saves chromadb to specified persistence folder (with default being .chroma_db). Purge correctly purges persistence folder.
This commit is contained in:
BlipRanger
2023-05-30 18:14:00 -04:00
committed by GitHub
parent f011989b25
commit bf59b363ca

View File

@@ -64,6 +64,7 @@ parser.add_argument("--prompt-model", help="Load a custom prompt generation mode
parser.add_argument("--embedding-model", help="Load a custom text embedding model")
parser.add_argument("--chroma-host", help="Host IP for a remote ChromaDB instance")
parser.add_argument("--chroma-port", help="HTTP port for a remote ChromaDB instance (defaults to 8000)")
parser.add_argument("--chroma-folder", help="Path for chromadb persistence folder", default='.chroma_db')
sd_group = parser.add_mutually_exclusive_group()
@@ -247,8 +248,8 @@ if "chromadb" in modules:
# Also disable chromadb telemetry
posthog.capture = lambda *args, **kwargs: None
if args.chroma_host is None:
chromadb_client = chromadb.Client(Settings(anonymized_telemetry=False))
print("ChromaDB is running in-memory. It will be cleared when the server is restarted!")
chromadb_client = chromadb.Client(Settings(anonymized_telemetry=False, persist_directory=args.chroma_folder, chroma_db_impl='duckdb+parquet'))
print(f"ChromaDB is running in-memory with persistence. Persistence is stored in {args.chroma_folder}. Can be cleared by deleting the folder or purging db.")
else:
chroma_port=(
args.chroma_port if args.chroma_port else DEFAULT_CHROMA_PORT
@@ -732,8 +733,11 @@ def chromadb_purge():
name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
)
deleted = collection.delete()
print("ChromaDB embeddings deleted", len(deleted))
count = collection.count()
collection.delete()
#Write deletion to persistent folder
chromadb_client.persist()
print("ChromaDB embeddings deleted", count)
return 'Ok', 200
@@ -788,7 +792,7 @@ def chromadb_export():
data = request.get_json()
if "chat_id" not in data or not isinstance(data["chat_id"], str):
abort(400, '"chat_id" is required')
chat_id_md5 = hashlib.md5(data["chat_id"].encode()).hexdigest()
collection = chromadb_client.get_or_create_collection(
name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn