Merge pull request #221 from Technologicat/talkinghead-next6

Talkinghead next6, also vector DB embeddings
Cohee
2024-01-24 16:37:19 +02:00
committed by GitHub
4 changed files with 219 additions and 33 deletions

View File

@@ -12,6 +12,7 @@ from random import randint
import secrets
import sys
import time
from typing import List, Union
import unicodedata
from colorama import Fore, Style, init as colorama_init
@@ -19,6 +20,7 @@ import markdown
from PIL import Image
import numpy as np
import torch
from transformers import pipeline
@@ -125,6 +127,10 @@ def index():
content = f.read()
return render_template_string(markdown.markdown(content, extensions=["tables"]))
@app.route("/api/modules", methods=["GET"])
def get_modules():
return jsonify({"modules": modules})
@app.route("/api/extensions", methods=["GET"])
def get_extensions():
extensions = dict(
@@ -477,10 +483,6 @@ def api_image_samplers():
return jsonify({"samplers": samplers})
@app.route("/api/modules", methods=["GET"])
def get_modules():
return jsonify({"modules": modules})
# ----------------------------------------
# tts
@@ -488,7 +490,7 @@ tts_service = None # populated when the module is loaded
@app.route("/api/tts/speakers", methods=["GET"])
@require_module("silero-tts")
def tts_speakers():
def api_tts_speakers():
voices = [
{
"name": speaker,
@@ -502,7 +504,7 @@ def tts_speakers():
# Added fix for Silero not working as new files were unable to be created if one already existed. - Rolyat 7/7/23
@app.route("/api/tts/generate", methods=["POST"])
@require_module("silero-tts")
def tts_generate():
def api_tts_generate():
voice = request.get_json()
if "text" not in voice or not isinstance(voice["text"], str):
abort(400, '"text" is required')
@@ -526,7 +528,7 @@ def tts_generate():
@app.route("/api/tts/sample/<speaker>", methods=["GET"])
@require_module("silero-tts")
def tts_play_sample(speaker: str):
def api_tts_play_sample(speaker: str):
return send_from_directory(SILERO_SAMPLES_PATH, f"{speaker}.wav")
# ----------------------------------------
@@ -536,13 +538,13 @@ edge = None # populated when the module is loaded
@app.route("/api/edge-tts/list", methods=["GET"])
@require_module("edge-tts")
def edge_tts_list():
def api_edge_tts_list():
voices = edge.get_voices()
return jsonify(voices)
@app.route("/api/edge-tts/generate", methods=["POST"])
@require_module("edge-tts")
def edge_tts_generate():
def api_edge_tts_generate():
data = request.get_json()
if "text" not in data or not isinstance(data["text"], str):
abort(400, '"text" is required')
@@ -561,6 +563,60 @@ def edge_tts_generate():
print(e)
abort(500, data["voice"])
# ----------------------------------------
# embeddings
sentence_embedder = None # populated when the module is loaded
@app.route("/api/embeddings/compute", methods=["POST"])
@require_module("embeddings")
def api_embeddings_compute():
"""For making vector DB keys. Compute the vector embedding of one or more sentences of text.
Input format is JSON::
{"text": "Blah blah blah."}
or::
{"text": ["Blah blah blah.",
...]}
Output is also JSON::
{"embedding": array}
or::
{"embedding": [array0,
...]}
respectively.
This is the Extras backend for computing embeddings in the Vector Storage builtin extension.
"""
data = request.get_json()
if "text" not in data:
abort(400, '"text" is required')
sentences: Union[str, List[str]] = data["text"]
if not (isinstance(sentences, str) or (isinstance(sentences, list) and all(isinstance(x, str) for x in sentences))):
abort(400, '"text" must be string or array of strings')
if isinstance(sentences, str):
nitems = 1
else:
nitems = len(sentences)
print(f"Computing vector embedding for {nitems} item{'s' if nitems != 1 else ''}")
vectors: Union[np.ndarray, List[np.ndarray]] = sentence_embedder.encode(sentences,
show_progress_bar=True, # on ST-extras console
convert_to_numpy=True,
normalize_embeddings=True)
# NumPy arrays are not JSON serializable, so convert to Python lists
if isinstance(vectors, np.ndarray):
vectors = vectors.tolist()
else: # isinstance(vectors, list) and all(isinstance(x, np.ndarray) for x in vectors)
vectors = [x.tolist() for x in vectors]
return jsonify({"embedding": vectors})
# ----------------------------------------
# chromadb
@@ -569,7 +625,7 @@ chromadb_embed_fn = None
@app.route("/api/chromadb", methods=["POST"])
@require_module("chromadb")
def chromadb_add_messages():
def api_chromadb_add_messages():
data = request.get_json()
if "chat_id" not in data or not isinstance(data["chat_id"], str):
abort(400, '"chat_id" is required')
@@ -598,7 +654,7 @@ def chromadb_add_messages():
@app.route("/api/chromadb/purge", methods=["POST"])
@require_module("chromadb")
def chromadb_purge():
def api_chromadb_purge():
data = request.get_json()
if "chat_id" not in data or not isinstance(data["chat_id"], str):
abort(400, '"chat_id" is required')
@@ -615,7 +671,7 @@ def chromadb_purge():
@app.route("/api/chromadb/query", methods=["POST"])
@require_module("chromadb")
def chromadb_query():
def api_chromadb_query():
data = request.get_json()
if "chat_id" not in data or not isinstance(data["chat_id"], str):
abort(400, '"chat_id" is required')
@@ -663,7 +719,7 @@ def chromadb_query():
@app.route("/api/chromadb/multiquery", methods=["POST"])
@require_module("chromadb")
def chromadb_multiquery():
def api_chromadb_multiquery():
data = request.get_json()
if "chat_list" not in data or not isinstance(data["chat_list"], list):
abort(400, '"chat_list" is required and should be a list')
@@ -726,7 +782,7 @@ def chromadb_multiquery():
@app.route("/api/chromadb/export", methods=["POST"])
@require_module("chromadb")
def chromadb_export():
def api_chromadb_export():
data = request.get_json()
if "chat_id" not in data or not isinstance(data["chat_id"], str):
abort(400, '"chat_id" is required')
@@ -765,7 +821,7 @@ def chromadb_export():
@app.route("/api/chromadb/import", methods=["POST"])
@require_module("chromadb")
def chromadb_import():
def api_chromadb_import():
data = request.get_json()
content = data['content']
if "chat_id" not in data or not isinstance(data["chat_id"], str):
@@ -1045,6 +1101,11 @@ if "edge-tts" in modules:
print("Initializing Edge TTS client")
import tts_edge as edge
if "embeddings" in modules:
print("Initializing embeddings")
from sentence_transformers import SentenceTransformer
sentence_embedder = SentenceTransformer(embedding_model, device=device_string)
if "chromadb" in modules:
print("Initializing ChromaDB")
import chromadb
@@ -1199,6 +1260,6 @@ if args.share:
cloudflare = _run_cloudflared(port)
print(f"{Fore.GREEN}{Style.NORMAL}Running on: {cloudflare}{Style.RESET_ALL}")
ignore_auth.append(tts_play_sample)
ignore_auth.append(api_tts_play_sample)
ignore_auth.append(api_talkinghead_result_feed)
app.run(host=host, port=port)

View File

@@ -222,12 +222,12 @@ The following postprocessing filters are available. Options for each filter are
**Transport**:
Currently, we provide some filters that simulate a lo-fi analog video look.
- `analog_lowres`: Simulates a low-resolution analog video signal by blurring the image.
- `analog_badhsync`: Simulates bad horizontal synchronization (hsync) of an analog video signal, causing a wavy effect that makes the outline of the character ripple.
- `analog_distort`: Simulates a rippling, runaway hsync near the top or bottom edge of an image. This can happen with some equipment if the video cable is too long.
- `analog_vhsglitches`: Simulates a damaged 1980s VHS tape. In each 25 FPS frame, causes random lines to glitch with VHS noise.
- `analog_vhstracking`: Simulates a 1980s VHS tape with bad tracking. The image floats up and down, and a band of VHS noise appears at the bottom.
- `shift_distort`: A glitchy digital video transport as sometimes depicted in sci-fi, with random blocks of lines briefly shifted horizontally.
**Display**:
@@ -261,6 +261,15 @@ The bloom works best on a dark background. We use `lumanoise` to add an imperfec
Note that we could also use the `translucency` filter to make the character translucent, e.g.: `["translucency", {"alpha": 0.7}]`.
Also, to simulate a glitchy video transport that shifts random blocks of lines horizontally, we could add these:
```
["shift_distort", {"strength": 0.05, "name": "shift_right"}],
["shift_distort", {"strength": -0.05, "name": "shift_left"}],
```
Having a unique name for each instance is important, because the name acts as a cache key.
#### Postprocessor example: cheap video camera, amber monochrome computer monitor
We first simulate a cheap video camera with low-quality optics via the `chromatic_aberration` and `vignetting` filters.

View File

@@ -7,6 +7,8 @@ As of January 2024, preferably to be completed before the next release.
#### Frontend
- See if we can get this working also with local classify now that we have a `set_emotion` API endpoint.
- Responsibilities: the client end should set the emotion when it calls classify, instead of relying on the extras server doing it internally when extras classify is called.
- Figure out why the crop filter doesn't help in positioning the `talkinghead` sprite in *MovingUI* mode.
- There must be some logic at the frontend side that reserves a square shape for the talkinghead sprite output,
regardless of the image dimensions or aspect ratio of the actual `result_feed`.
@@ -63,20 +65,13 @@ Not scheduled for now.
- The effect on speed will be small; the compute-heaviest part is the inference of the THA3 deep-learning model.
- Add more postprocessing filters. Possible ideas, no guarantee I'll ever get around to them:
- Pixelize, posterize (8-bit look)
- Analog video glitches
- Partition image into bands, move some left/right temporarily (for a few frames now that we can do that)
- Another effect of bad VHS hsync: dynamic "bending" effect near top edge:
- Distortion by horizontal movement
- Topmost row of pixels moves the most, then a smoothly decaying offset profile as a function of height (decaying to zero at maybe 20% of image height, measured from the top)
- The maximum offset flutters dynamically in a semi-regular, semi-unpredictable manner (use a superposition of three sine waves at different frequencies, as functions of time)
- Digital data connection glitches
- Apply to random rectangles; may need to persist for a few frames to animate and/or make them more noticeable
- May need to protect important regions like the character's head (approximately, from the template); we're after "Hollywood glitchy", not actually glitchy
- Types:
- Constant-color rectangle
- Missing data (zero out the alpha?)
- Blur (effectively replaces the block with its average color, with controllable sigma)
- Zigzag deformation
- Zigzag deformation (perhaps not needed now that we have `shift_distort`, which is similar, but with a rectangular shape, and applied to full lines of video)
- Investigate if some particular emotions could use a small random per-frame oscillation applied to "iris_small",
for that anime "intense emotion" effect (since THA3 doesn't have a morph specifically for the specular reflections in the eyes).
@@ -93,7 +88,14 @@ Not scheduled for now.
- To save GPU resources, automatically pause animation when the web browser window with SillyTavern is not in focus. Resume when it regains focus.
- Needs a new API endpoint for pause/resume. Note the current `/api/talkinghead/unload` is actually a pause function (the client pauses, and
then just hides the live image), but there is currently no resume function (except `/api/talkinghead/load`, which requires sending an image file).
- Lip-sync talking animation to TTS output.
- THA3 has morphs for A, I, U, E, O, and the "mouth delta" shape Δ.
- This needs either:
- Realtime data from client
- Exists already! See `SillyTavern/public/scripts/extensions/tts/index.js`, function `playAudioData`. There's lip sync for VRM (VRoid).
Still need to investigate how the VRM plugin extracts phonemes from the audio data.
- Or if ST-extras generates the TTS output, then at least a start timestamp for the playback of a given TTS output audio file,
and a possibility to stop animating if the user stops the audio.
### Far future
@@ -104,10 +106,4 @@ Definitely not scheduled. Ideas for future enhancements.
- The algorithm should be cartoon-aware, some modern-day equivalent of waifu2x. A GAN such as 4x-AnimeSharp or Remacri would be nice, but too slow.
- Maybe the scaler should run at the client side to avoid the need to stream 1024x1024 PNGs.
- What JavaScript anime scalers are there, or which algorithms are simple enough for a small custom implementation?
- Lip-sync talking animation to TTS output.
- THA3 has morphs for A, I, U, E, O, and the "mouth delta" shape Δ.
- This needs either:
- Realtime data from client
- Or if ST-extras generates the TTS output, then at least a start timestamp for the playback of a given TTS output audio file,
and a possibility to stop animating if the user stops the audio.
- Group chats / visual novel mode / several talkingheads running simultaneously.

View File

@@ -124,6 +124,9 @@ class Postprocessor:
self.vhs_glitch_last_frame_no = defaultdict(lambda: 0.0)
self.vhs_glitch_last_image = defaultdict(lambda: None)
self.vhs_glitch_last_mask = defaultdict(lambda: None)
self.shift_distort_interval = defaultdict(lambda: 0.0)
self.shift_distort_last_frame_no = defaultdict(lambda: 0.0)
self.shift_distort_grid = defaultdict(lambda: None)
def render_into(self, image):
"""Apply current postprocess chain, modifying `image` in-place."""
@@ -492,6 +495,70 @@ class Postprocessor:
warped = warped.squeeze(0) # [1, c, h, w] -> [c, h, w]
image[:, :, :] = warped
def analog_distort(self, image: torch.tensor, *,
speed: float = 8.0,
strength: float = 0.1,
ripple_amplitude: float = 0.05,
ripple_density1: float = 4.0,
ripple_density2: Optional[float] = 13.0,
ripple_density3: Optional[float] = 27.0,
edge: str = "top") -> None:
"""[dynamic] Analog video signal distorted by a runaway hsync near the top or bottom edge.
A bad video cable connection can do this, e.g. when connecting a game console to a display
with an analog YPbPr component cable 10m in length. In reality, when I ran into this phenomenon,
the distortion only occurred for near-white images, but as glitch art, it looks better if it's
always applied at full strength.
`speed`: At speed 1.0, a full cycle of the rippling effect completes every `image_height` frames.
So effectively the cycle position updates by `speed * (1 / image_height)` at each frame.
`strength`: Base strength for maximum distortion at the edge of the image.
In units where the height and width of the image are both 2.0.
`ripple_amplitude`: Variation on top of `strength`.
`ripple_density1`: Like `density` in `analog_badhsync`, but in time. How many cycles the first
component wave completes per one cycle of the ripple effect.
`ripple_density2`: Like `ripple_density1`, but for the second component wave.
Set to `None` or to 0.0 to disable the second component wave.
`ripple_density3`: Like `ripple_density1`, but for the third component wave.
Set to `None` or to 0.0 to disable the third component wave.
`edge`: one of "top", "bottom". Near which edge of the image to apply the maximal distortion.
The distortion then decays to zero, with a quadratic profile, in 1/8 of the image height.
Note that "frame" here refers to the normalized frame number, at a reference of 25 FPS.
"""
c, h, w = image.shape
# Animation
# FPS correction happens automatically, because `frame_no` is normalized to CALIBRATION_FPS.
cycle_pos = (self.frame_no / h) * speed
cycle_pos = cycle_pos - float(int(cycle_pos)) # fractional part
cycle_pos *= 2.0 # full cycle = 2 units
# Deformation
# The spatial distort profile is a quadratic curve [0, 1], for 1/8 of the image height.
meshy = self._meshy
if edge == "top":
spatial_distort_profile = (torch.clamp(meshy + 0.75, max=0.0) * 4.0)**2 # distort near y = -1
else: # edge == "bottom":
spatial_distort_profile = (torch.clamp(meshy - 0.75, min=0.0) * 4.0)**2 # distort near y = +1
ripple = math.sin(ripple_density1 * cycle_pos * math.pi)
if ripple_density2:
ripple += math.sin(ripple_density2 * cycle_pos * math.pi)
if ripple_density3:
ripple += math.sin(ripple_density3 * cycle_pos * math.pi)
instantaneous_strength = (1.0 - ripple_amplitude) * strength + ripple_amplitude * ripple
# The minus sign: read coordinates toward the left -> shift the image toward the right.
meshx = self._meshx - instantaneous_strength * spatial_distort_profile
# Then just the usual incantation for applying a geometric distortion in Torch:
grid = torch.stack((meshx, meshy), 2)
grid = grid.unsqueeze(0) # batch of one
image_batch = image.unsqueeze(0) # batch of one -> [1, c, h, w]
warped = torch.nn.functional.grid_sample(image_batch, grid, mode="bilinear", padding_mode="border", align_corners=False)
warped = warped.squeeze(0) # [1, c, h, w] -> [c, h, w]
image[:, :, :] = warped
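# A hedged example of enabling this filter in a postprocessor chain, using the
# `["filter_name", {options}]` config format shown in the README (the parameter
# values here are illustrative, not recommendations; the parameter names match
# the keyword arguments above):
#
#   ["analog_distort", {"strength": 0.1, "ripple_amplitude": 0.05, "edge": "top"}],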
def _vhs_noise(self, image: torch.tensor, *,
height: int) -> torch.tensor:
"""Generate a horizontal band of noise that looks as if it came from a blank VHS tape.
@@ -615,6 +682,59 @@ class Postprocessor:
# fade = fade.unsqueeze(0) # [1, w]
# image[3, -noise_pixels:, :] = fade
def shift_distort(self, image: torch.tensor, *,
strength: float = 0.05,
unboost: float = 4.0,
max_glitches: int = 3,
min_glitch_height: int = 20, max_glitch_height: int = 30,
hold_min: int = 1, hold_max: int = 3,
name: str = "shift_distort0") -> None:
"""[dynamic] Glitchy digital video transport, with transient (per-frame) blocks of lines shifted left or right.
`strength`: Amount of the horizontal shift, in units where 2.0 is the width of the full image.
Positive values shift toward the right.
For shifting both left and right, use two copies of the filter in your chain,
one with `strength > 0` and one with `strength < 0`.
`unboost`: Use this to adjust the probability profile for the appearance of glitches.
The higher `unboost` is, the less probable it is for glitches to appear at all,
and there will be fewer of them (in the same video frame) when they do appear.
`max_glitches`: Maximum number of glitches in the video frame.
`min_glitch_height`, `max_glitch_height`: in pixels. The height is randomized separately for each glitch.
`hold_min`, `hold_max`: in frames (at a reference of 25 FPS). Limits for the random time that the
filter holds one glitch pattern before randomizing the next one.
`name`: Optional name for this filter instance in the chain. Used as cache key.
If you have more than one `shift_distort` in the chain, they should have
different names so that each one gets its own cache.
"""
# Re-randomize the glitch pattern whenever enough frames have elapsed after last randomization
if self.shift_distort_grid[name] is None or (int(self.frame_no) - int(self.shift_distort_last_frame_no[name])) >= self.shift_distort_interval[name]:
n_glitches = torch.rand(1, device="cpu")**unboost # unboost: increase probability of having none or few glitching lines
n_glitches = int(max_glitches * n_glitches[0])
meshy = self._meshy
meshx = self._meshx.clone() # don't modify the original; also, make sure each element has a unique memory address
if n_glitches:
c, h, w = image.shape
glitch_start_lines = torch.rand(n_glitches, device="cpu")
glitch_start_lines = [int((h - (max_glitch_height - 1)) * x) for x in glitch_start_lines]
for line in glitch_start_lines:
glitch_height = torch.rand(1, device="cpu")
glitch_height = int(min_glitch_height + (max_glitch_height - min_glitch_height) * glitch_height[0])
meshx[line:(line + glitch_height), :] -= strength
shift_distort_grid = torch.stack((meshx, meshy), 2)
shift_distort_grid = shift_distort_grid.unsqueeze(0) # batch of one
self.shift_distort_grid[name] = shift_distort_grid
# Randomize time until next change of glitch pattern
self.shift_distort_interval[name] = round(hold_min + float(torch.rand(1, device="cpu")[0]) * (hold_max - hold_min))
self.shift_distort_last_frame_no[name] = self.frame_no
else:
shift_distort_grid = self.shift_distort_grid[name]
image_batch = image.unsqueeze(0) # batch of one -> [1, c, h, w]
warped = torch.nn.functional.grid_sample(image_batch, shift_distort_grid, mode="bilinear", padding_mode="border", align_corners=False)
warped = warped.squeeze(0) # [1, c, h, w] -> [c, h, w]
image[:, :, :] = warped
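# Example chain entries for this filter, taken from the README change in this
# commit (two instances with distinct names, since the name acts as a cache key):
#
#   ["shift_distort", {"strength": 0.05, "name": "shift_right"}],
#   ["shift_distort", {"strength": -0.05, "name": "shift_left"}],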
# --------------------------------------------------------------------------------
# CRT TV output