Initial work on multichat chroma

This commit is contained in:
BlipRanger
2023-06-24 14:49:08 -04:00
committed by GitHub
parent ff0e9a0ba3
commit 7b733e7498

View File

@@ -808,6 +808,68 @@ def chromadb_query():
return jsonify(messages)
@app.route("/api/chromadb/multiquery", methods=["POST"])
@require_module("chromadb")
def chromadb_multiquery():
data = request.get_json()
if "chat_list" not in data or not isinstance(data["chat_list"], list):
abort(400, '"chat_list" is required and should be a list')
if "query" not in data or not isinstance(data["query"], str):
abort(400, '"query" is required')
if "n_results" not in data or not isinstance(data["n_results"], int):
n_results = 1
else:
n_results = data["n_results"]
messages = []
for chat_id in data["chat_list"]:
if not isinstance(chat_id, str):
continue
try:
chat_id_md5 = hashlib.md5(chat_id.encode()).hexdigest()
collection = chromadb_client.get_collection(
name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
)
# Skip this chat if the collection is empty
if collection.count() == 0:
continue
n_results_per_chat = min(collection.count(), n_results)
query_result = collection.query(
query_texts=[data["query"]],
n_results=n_results_per_chat,
)
documents = query_result["documents"][0]
ids = query_result["ids"][0]
metadatas = query_result["metadatas"][0]
distances = query_result["distances"][0]
chat_messages = [
{
"id": ids[i],
"date": metadatas[i]["date"],
"role": metadatas[i]["role"],
"meta": metadatas[i]["meta"],
"content": documents[i],
"distance": distances[i],
}
for i in range(len(ids))
]
messages.extend(chat_messages)
except Exception as e:
print(e)
#remove duplicate msgs
seen = set()
messages = [d for d in messages if not (d['content'] in seen or seen.add(d['content']))]
return jsonify(messages)
@app.route("/api/chromadb/export", methods=["POST"])
@require_module("chromadb")
@@ -817,7 +879,7 @@ def chromadb_export():
abort(400, '"chat_id" is required')
chat_id_md5 = hashlib.md5(data["chat_id"].encode()).hexdigest()
collection = chromadb_client.get_or_create_collection(
collection = chromadb_client.get_collection(
name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
)
collection_content = collection.get()
@@ -855,7 +917,7 @@ def chromadb_import():
collection = chromadb_client.get_or_create_collection(
name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
)
documents = [item['document'] for item in content]
metadatas = [item['metadata'] for item in content]
ids = [item['id'] for item in content]