mirror of
https://github.com/SillyTavern/SillyTavern-Extras.git
synced 2026-05-01 03:41:24 +00:00
add vector embeddings provider
This commit is contained in:
56
server.py
56
server.py
@@ -12,6 +12,7 @@ from random import randint
|
||||
import secrets
|
||||
import sys
|
||||
import time
|
||||
from typing import List, Union
|
||||
import unicodedata
|
||||
|
||||
from colorama import Fore, Style, init as colorama_init
|
||||
@@ -19,6 +20,7 @@ import markdown
|
||||
|
||||
from PIL import Image
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import pipeline
|
||||
|
||||
@@ -561,6 +563,55 @@ def api_edge_tts_generate():
|
||||
print(e)
|
||||
abort(500, data["voice"])
|
||||
|
||||
# ----------------------------------------
|
||||
# embeddings
|
||||
|
||||
sentence_embedder = None # populated when the module is loaded
|
||||
|
||||
@app.route("/api/embeddings/compute", methods=["POST"])
|
||||
@require_module("embeddings")
|
||||
def api_embeddings_compute():
|
||||
"""For making vector DB keys. Compute the vector embedding of one or more sentences of text.
|
||||
|
||||
Input format is JSON::
|
||||
|
||||
{"text": "Blah blah blah."}
|
||||
|
||||
or::
|
||||
|
||||
{"text": ["Blah blah blah.",
|
||||
...]}
|
||||
|
||||
Output is also JSON::
|
||||
|
||||
{"embedding": array}
|
||||
|
||||
or::
|
||||
|
||||
{"embedding": [array0,
|
||||
...]}
|
||||
|
||||
respectively.
|
||||
|
||||
This is the Extras backend for computing embeddings in the Vector Storage builtin extension.
|
||||
"""
|
||||
data = request.get_json()
|
||||
if "text" not in data:
|
||||
abort(400, '"text" is required')
|
||||
sentences: Union[str, List[str]] = data["text"]
|
||||
if not (isinstance(sentences, str) or (isinstance(sentences, list) and all(isinstance(x, str) for x in sentences))):
|
||||
abort(400, '"text" must be string or array of strings')
|
||||
vectors: Union[np.array, List[np.array]] = sentence_embedder.encode(sentences,
|
||||
show_progress_bar=True, # on ST-extras console
|
||||
convert_to_numpy=True,
|
||||
normalize_embeddings=True)
|
||||
# NumPy arrays are not JSON serializable, so convert to Python lists
|
||||
if isinstance(vectors, np.ndarray):
|
||||
vectors = vectors.tolist()
|
||||
else: # isinstance(vectors, list) and all(isinstance(x, np.ndarray) for x in vectors)
|
||||
vectors = [x.tolist() for x in vectors]
|
||||
return jsonify({"embedding": vectors})
|
||||
|
||||
# ----------------------------------------
|
||||
# chromadb
|
||||
|
||||
@@ -1045,6 +1096,11 @@ if "edge-tts" in modules:
|
||||
print("Initializing Edge TTS client")
|
||||
import tts_edge as edge
|
||||
|
||||
if "embeddings" in modules:
|
||||
print("Initializing embeddings")
|
||||
from sentence_transformers import SentenceTransformer
|
||||
sentence_embedder = SentenceTransformer(embedding_model, device=device_string)
|
||||
|
||||
if "chromadb" in modules:
|
||||
print("Initializing ChromaDB")
|
||||
import chromadb
|
||||
|
||||
Reference in New Issue
Block a user