add vector embeddings provider

This commit is contained in:
Juha Jeronen
2024-01-24 15:12:15 +02:00
parent 1b1ae11c70
commit 0e506b1ce8

View File

@@ -12,6 +12,7 @@ from random import randint
import secrets
import sys
import time
from typing import List, Union
import unicodedata
from colorama import Fore, Style, init as colorama_init
@@ -19,6 +20,7 @@ import markdown
from PIL import Image
import numpy as np
import torch
from transformers import pipeline
@@ -561,6 +563,55 @@ def api_edge_tts_generate():
print(e)
abort(500, data["voice"])
# ----------------------------------------
# embeddings
sentence_embedder = None # populated when the module is loaded
@app.route("/api/embeddings/compute", methods=["POST"])
@require_module("embeddings")
def api_embeddings_compute():
"""For making vector DB keys. Compute the vector embedding of one or more sentences of text.
Input format is JSON::
{"text": "Blah blah blah."}
or::
{"text": ["Blah blah blah.",
...]}
Output is also JSON::
{"embedding": array}
or::
{"embedding": [array0,
...]}
respectively.
This is the Extras backend for computing embeddings in the Vector Storage builtin extension.
"""
data = request.get_json()
if "text" not in data:
abort(400, '"text" is required')
sentences: Union[str, List[str]] = data["text"]
if not (isinstance(sentences, str) or (isinstance(sentences, list) and all(isinstance(x, str) for x in sentences))):
abort(400, '"text" must be string or array of strings')
vectors: Union[np.array, List[np.array]] = sentence_embedder.encode(sentences,
show_progress_bar=True, # on ST-extras console
convert_to_numpy=True,
normalize_embeddings=True)
# NumPy arrays are not JSON serializable, so convert to Python lists
if isinstance(vectors, np.ndarray):
vectors = vectors.tolist()
else: # isinstance(vectors, list) and all(isinstance(x, np.ndarray) for x in vectors)
vectors = [x.tolist() for x in vectors]
return jsonify({"embedding": vectors})
# ----------------------------------------
# chromadb
@@ -1045,6 +1096,11 @@ if "edge-tts" in modules:
print("Initializing Edge TTS client")
import tts_edge as edge
if "embeddings" in modules:
print("Initializing embeddings")
from sentence_transformers import SentenceTransformer
sentence_embedder = SentenceTransformer(embedding_model, device=device_string)
if "chromadb" in modules:
print("Initializing ChromaDB")
import chromadb