mirror of
https://github.com/SillyTavern/SillyTavern-Extras.git
synced 2026-04-29 19:01:20 +00:00
Initial work on multichat chroma
This commit is contained in:
66
server.py
66
server.py
@@ -808,6 +808,68 @@ def chromadb_query():
|
||||
|
||||
return jsonify(messages)
|
||||
|
||||
@app.route("/api/chromadb/multiquery", methods=["POST"])
@require_module("chromadb")
def chromadb_multiquery():
    """Query several chat collections at once and merge the matches.

    Expects a JSON object in the request body with:
      * "chat_list" (list, required): chat ids to search; entries that are
        not strings are skipped.
      * "query" (str, required): the text to search for.
      * "n_results" (int, optional): maximum number of matches requested
        from each chat. Anything that is not a positive integer falls
        back to 1.

    Each chat id is mapped to the collection named "chat-<md5 of chat id>".
    Chats whose collection is empty, or whose lookup/query raises, are
    skipped so that one bad chat does not fail the whole request.

    Returns a JSON list of message objects with the keys "id", "date",
    "role", "meta", "content" and "distance". Messages are listed in
    "chat_list" order; when several share the same "content", only the
    first one encountered is kept.

    Aborts with HTTP 400 when "chat_list" or "query" is missing or has the
    wrong type.
    """
    data = request.get_json()

    # The parsed body may be something other than a dict (e.g. JSON null or
    # an array). Such a body cannot carry "chat_list", so answer with the
    # same 400 rather than letting the lookup below raise.
    if not isinstance(data, dict) or not isinstance(data.get("chat_list"), list):
        abort(400, '"chat_list" is required and should be a list')
    if not isinstance(data.get("query"), str):
        abort(400, '"query" is required')

    n_results = data.get("n_results")
    # bool is a subclass of int, and a value below 1 would make every
    # per-chat query fail, so treat both like a missing value.
    if isinstance(n_results, bool) or not isinstance(n_results, int) or n_results < 1:
        n_results = 1

    messages = []

    for chat_id in data["chat_list"]:
        if not isinstance(chat_id, str):
            continue

        try:
            chat_id_md5 = hashlib.md5(chat_id.encode()).hexdigest()
            collection = chromadb_client.get_collection(
                name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
            )

            # Skip this chat if the collection is empty
            collection_size = collection.count()
            if collection_size == 0:
                continue

            # Never ask for more results than the collection holds.
            query_result = collection.query(
                query_texts=[data["query"]],
                n_results=min(collection_size, n_results),
            )

            # A single query text was sent, so only the first result row
            # of each field is relevant.
            documents = query_result["documents"][0]
            ids = query_result["ids"][0]
            metadatas = query_result["metadatas"][0]
            distances = query_result["distances"][0]

            messages.extend(
                {
                    "id": message_id,
                    "date": metadata["date"],
                    "role": metadata["role"],
                    "meta": metadata["meta"],
                    "content": document,
                    "distance": distance,
                }
                for message_id, metadata, document, distance in zip(
                    ids, metadatas, documents, distances
                )
            )
        except Exception as e:
            # Best effort: a chat that cannot be queried is reported and
            # skipped instead of failing the whole request.
            print(f"chromadb multiquery skipped chat {chat_id!r}: {e}")

    # Remove duplicate messages, keeping the first occurrence of each content.
    seen_contents = set()
    unique_messages = []
    for message in messages:
        content = message["content"]
        if content in seen_contents:
            continue
        seen_contents.add(content)
        unique_messages.append(message)

    return jsonify(unique_messages)
|
||||
|
||||
|
||||
@app.route("/api/chromadb/export", methods=["POST"])
|
||||
@require_module("chromadb")
|
||||
@@ -817,7 +879,7 @@ def chromadb_export():
|
||||
abort(400, '"chat_id" is required')
|
||||
|
||||
chat_id_md5 = hashlib.md5(data["chat_id"].encode()).hexdigest()
|
||||
collection = chromadb_client.get_or_create_collection(
|
||||
collection = chromadb_client.get_collection(
|
||||
name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
|
||||
)
|
||||
collection_content = collection.get()
|
||||
@@ -855,7 +917,7 @@ def chromadb_import():
|
||||
collection = chromadb_client.get_or_create_collection(
|
||||
name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
|
||||
)
|
||||
|
||||
|
||||
documents = [item['document'] for item in content]
|
||||
metadatas = [item['metadata'] for item in content]
|
||||
ids = [item['id'] for item in content]
|
||||
|
||||
Reference in New Issue
Block a user