diff --git a/server.py b/server.py index 5ffa085..79ea070 100644 --- a/server.py +++ b/server.py @@ -12,6 +12,7 @@ from random import randint import secrets import sys import time +from typing import List, Union import unicodedata from colorama import Fore, Style, init as colorama_init @@ -19,6 +20,7 @@ import markdown from PIL import Image +import numpy as np import torch from transformers import pipeline @@ -125,6 +127,10 @@ def index(): content = f.read() return render_template_string(markdown.markdown(content, extensions=["tables"])) +@app.route("/api/modules", methods=["GET"]) +def get_modules(): + return jsonify({"modules": modules}) + @app.route("/api/extensions", methods=["GET"]) def get_extensions(): extensions = dict( @@ -477,10 +483,6 @@ def api_image_samplers(): return jsonify({"samplers": samplers}) -@app.route("/api/modules", methods=["GET"]) -def get_modules(): - return jsonify({"modules": modules}) - # ---------------------------------------- # tts @@ -488,7 +490,7 @@ tts_service = None # populated when the module is loaded @app.route("/api/tts/speakers", methods=["GET"]) @require_module("silero-tts") -def tts_speakers(): +def api_tts_speakers(): voices = [ { "name": speaker, @@ -502,7 +504,7 @@ def tts_speakers(): # Added fix for Silero not working as new files were unable to be created if one already existed. - Rolyat 7/7/23 @app.route("/api/tts/generate", methods=["POST"]) @require_module("silero-tts") -def tts_generate(): +def api_tts_generate(): voice = request.get_json() if "text" not in voice or not isinstance(voice["text"], str): abort(400, '"text" is required') @@ -526,7 +528,7 @@ def tts_generate(): @app.route("/api/tts/sample/", methods=["GET"]) @require_module("silero-tts") -def tts_play_sample(speaker: str): +def api_tts_play_sample(speaker: str): return send_from_directory(SILERO_SAMPLES_PATH, f"{speaker}.wav") # ---------------------------------------- @@ -536,13 +538,13 @@ edge = None # populated when the module is loaded @app.route("/api/edge-tts/list", methods=["GET"]) @require_module("edge-tts") -def edge_tts_list(): +def api_edge_tts_list(): voices = edge.get_voices() return jsonify(voices) @app.route("/api/edge-tts/generate", methods=["POST"]) @require_module("edge-tts") -def edge_tts_generate(): +def api_edge_tts_generate(): data = request.get_json() if "text" not in data or not isinstance(data["text"], str): abort(400, '"text" is required') @@ -561,6 +563,60 @@ def edge_tts_generate(): print(e) abort(500, data["voice"]) +# ---------------------------------------- +# embeddings + +sentence_embedder = None # populated when the module is loaded + +@app.route("/api/embeddings/compute", methods=["POST"]) +@require_module("embeddings") +def api_embeddings_compute(): + """For making vector DB keys. Compute the vector embedding of one or more sentences of text. + + Input format is JSON:: + + {"text": "Blah blah blah."} + + or:: + + {"text": ["Blah blah blah.", + ...]} + + Output is also JSON:: + + {"embedding": array} + + or:: + + {"embedding": [array0, + ...]} + + respectively. + + This is the Extras backend for computing embeddings in the Vector Storage builtin extension. 
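+
+    A minimal client-side sketch (hypothetical; it assumes the requests library, that no
+    API key is required, and that this Extras server is listening on the default
+    http://localhost:5100)::
+
+        import requests  # assumed client-side dependency, not part of this server
+
+        resp = requests.post("http://localhost:5100/api/embeddings/compute",
+                             json={"text": ["Blah blah blah.", "More text."]})
+        vectors = resp.json()["embedding"]  # here: a list of two lists of floats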
+ """ + data = request.get_json() + if "text" not in data: + abort(400, '"text" is required') + sentences: Union[str, List[str]] = data["text"] + if not (isinstance(sentences, str) or (isinstance(sentences, list) and all(isinstance(x, str) for x in sentences))): + abort(400, '"text" must be string or array of strings') + if isinstance(sentences, str): + nitems = 1 + else: + nitems = len(sentences) + print(f"Computing vector embedding for {nitems} item{'s' if nitems != 1 else ''}") + vectors: Union[np.array, List[np.array]] = sentence_embedder.encode(sentences, + show_progress_bar=True, # on ST-extras console + convert_to_numpy=True, + normalize_embeddings=True) + # NumPy arrays are not JSON serializable, so convert to Python lists + if isinstance(vectors, np.ndarray): + vectors = vectors.tolist() + else: # isinstance(vectors, list) and all(isinstance(x, np.ndarray) for x in vectors) + vectors = [x.tolist() for x in vectors] + return jsonify({"embedding": vectors}) + # ---------------------------------------- # chromadb @@ -569,7 +625,7 @@ chromadb_embed_fn = None @app.route("/api/chromadb", methods=["POST"]) @require_module("chromadb") -def chromadb_add_messages(): +def api_chromadb_add_messages(): data = request.get_json() if "chat_id" not in data or not isinstance(data["chat_id"], str): abort(400, '"chat_id" is required') @@ -598,7 +654,7 @@ def chromadb_add_messages(): @app.route("/api/chromadb/purge", methods=["POST"]) @require_module("chromadb") -def chromadb_purge(): +def api_chromadb_purge(): data = request.get_json() if "chat_id" not in data or not isinstance(data["chat_id"], str): abort(400, '"chat_id" is required') @@ -615,7 +671,7 @@ def chromadb_purge(): @app.route("/api/chromadb/query", methods=["POST"]) @require_module("chromadb") -def chromadb_query(): +def api_chromadb_query(): data = request.get_json() if "chat_id" not in data or not isinstance(data["chat_id"], str): abort(400, '"chat_id" is required') @@ -663,7 +719,7 @@ def chromadb_query(): @app.route("/api/chromadb/multiquery", methods=["POST"]) @require_module("chromadb") -def chromadb_multiquery(): +def api_chromadb_multiquery(): data = request.get_json() if "chat_list" not in data or not isinstance(data["chat_list"], list): abort(400, '"chat_list" is required and should be a list') @@ -726,7 +782,7 @@ def chromadb_multiquery(): @app.route("/api/chromadb/export", methods=["POST"]) @require_module("chromadb") -def chromadb_export(): +def api_chromadb_export(): data = request.get_json() if "chat_id" not in data or not isinstance(data["chat_id"], str): abort(400, '"chat_id" is required') @@ -765,7 +821,7 @@ def chromadb_export(): @app.route("/api/chromadb/import", methods=["POST"]) @require_module("chromadb") -def chromadb_import(): +def api_chromadb_import(): data = request.get_json() content = data['content'] if "chat_id" not in data or not isinstance(data["chat_id"], str): @@ -1045,6 +1101,11 @@ if "edge-tts" in modules: print("Initializing Edge TTS client") import tts_edge as edge +if "embeddings" in modules: + print("Initializing embeddings") + from sentence_transformers import SentenceTransformer + sentence_embedder = SentenceTransformer(embedding_model, device=device_string) + if "chromadb" in modules: print("Initializing ChromaDB") import chromadb @@ -1199,6 +1260,6 @@ if args.share: cloudflare = _run_cloudflared(port) print(f"{Fore.GREEN}{Style.NORMAL}Running on: {cloudflare}{Style.RESET_ALL}") -ignore_auth.append(tts_play_sample) +ignore_auth.append(api_tts_play_sample) 
ignore_auth.append(api_talkinghead_result_feed)
app.run(host=host, port=port)
diff --git a/talkinghead/README.md b/talkinghead/README.md
index cc4b949..691940b 100644
--- a/talkinghead/README.md
+++ b/talkinghead/README.md
@@ -222,12 +222,12 @@ The following postprocessing filters are available. Options for each filter are
 **Transport**:
-Currently, we provide some filters that simulate a lo-fi analog video look.
-
 - `analog_lowres`: Simulates a low-resolution analog video signal by blurring the image.
 - `analog_badhsync`: Simulates bad horizontal synchronization (hsync) of an analog video signal, causing a wavy effect that causes the outline of the character to ripple.
+- `analog_distort`: Simulates a rippling, runaway hsync near the top or bottom edge of the image. This can happen with some equipment if the video cable is too long.
 - `analog_vhsglitches`: Simulates a damaged 1980s VHS tape. In each 25 FPS frame, causes random lines to glitch with VHS noise.
 - `analog_vhstracking`: Simulates a 1980s VHS tape with bad tracking. The image floats up and down, and a band of VHS noise appears at the bottom.
+- `shift_distort`: Simulates a glitchy digital video transport, as sometimes depicted in sci-fi, by briefly shifting random blocks of lines horizontally.
 **Display**:
@@ -261,6 +261,15 @@ The bloom works best on a dark background. We use `lumanoise` to add an imperfec
 Note that we could also use the `translucency` filter to make the character translucent, e.g.: `["translucency", {"alpha": 0.7}]`.
+Also, to simulate a glitchy video transport that briefly shifts random blocks of lines horizontally, we could add these:
+
+```
+["shift_distort", {"strength": 0.05, "name": "shift_right"}],
+["shift_distort", {"strength": -0.05, "name": "shift_left"}],
+```
+
+Having a unique name for each instance is important, because the name acts as a cache key.
+
 #### Postprocessor example: cheap video camera, amber monochrome computer monitor
 We first simulate a cheap video camera with low-quality optics via the `chromatic_aberration` and `vignetting` filters.
diff --git a/talkinghead/TODO.md b/talkinghead/TODO.md
index 5891a31..955c9ab 100644
--- a/talkinghead/TODO.md
+++ b/talkinghead/TODO.md
@@ -7,6 +7,8 @@ As of January 2024, preferably to be completed before the next release.
 #### Frontend
+- See if we can get this working also with local classify now that we have a `set_emotion` API endpoint.
+  - Responsibilities: the client end should set the emotion when it calls classify, instead of relying on the extras server doing it internally when extras classify is called.
 - Figure out why the crop filter doesn't help in positioning the `talkinghead` sprite in *MovingUI* mode.
   - There must be some logic at the frontend side that reserves a square shape for the talkinghead sprite output, regardless of the image dimensions or aspect ratio of the actual `result_feed`.
@@ -63,20 +65,13 @@ Not scheduled for now.
   - The effect on speed will be small; the compute-heaviest part is the inference of the THA3 deep-learning model.
 - Add more postprocessing filters.
Possible ideas, no guarantee I'll ever get around to them: - Pixelize, posterize (8-bit look) - - Analog video glitches - - Partition image into bands, move some left/right temporarily (for a few frames now that we can do that) - - Another effect of bad VHS hsync: dynamic "bending" effect near top edge: - - Distortion by horizontal movement - - Topmost row of pixels moves the most, then a smoothly decaying offset profile as a function of height (decaying to zero at maybe 20% of image height, measured from the top) - - The maximum offset flutters dynamically in a semi-regular, semi-unpredictable manner (use a superposition of three sine waves at different frequencies, as functions of time) - Digital data connection glitches - Apply to random rectangles; may need to persist for a few frames to animate and/or make them more noticeable - - May need to protect important regions like the character's head (approximately, from the template); we're after "Hollywood glitchy", not actually glitchy - Types: - Constant-color rectangle - Missing data (zero out the alpha?) - Blur (leads to replacing by average color, with controllable sigma) - - Zigzag deformation + - Zigzag deformation (perhaps not needed now that we have `shift_distort`, which is similar, but with a rectangular shape, and applied to full lines of video) - Investigate if some particular emotions could use a small random per-frame oscillation applied to "iris_small", for that anime "intense emotion" effect (since THA3 doesn't have a morph specifically for the specular reflections in the eyes). @@ -93,7 +88,14 @@ Not scheduled for now. - To save GPU resources, automatically pause animation when the web browser window with SillyTavern is not in focus. Resume when it regains focus. - Needs a new API endpoint for pause/resume. Note the current `/api/talkinghead/unload` is actually a pause function (the client pauses, and then just hides the live image), but there is currently no resume function (except `/api/talkinghead/load`, which requires sending an image file). - +- Lip-sync talking animation to TTS output. + - THA3 has morphs for A, I, U, E, O, and the "mouth delta" shape Δ. + - This needs either: + - Realtime data from client + - Exists already! See `SillyTavern/public/scripts/extensions/tts/index.js`, function `playAudioData`. There's lip sync for VRM (VRoid). + Still need to investigate how the VRM plugin extracts phonemes from the audio data. + - Or if ST-extras generates the TTS output, then at least a start timestamp for the playback of a given TTS output audio file, + and a possibility to stop animating if the user stops the audio. ### Far future @@ -104,10 +106,4 @@ Definitely not scheduled. Ideas for future enhancements. - The algorithm should be cartoon-aware, some modern-day equivalent of waifu2x. A GAN such as 4x-AnimeSharp or Remacri would be nice, but too slow. - Maybe the scaler should run at the client side to avoid the need to stream 1024x1024 PNGs. - What JavaScript anime scalers are there, or which algorithms are simple enough for a small custom implementation? -- Lip-sync talking animation to TTS output. - - THA3 has morphs for A, I, U, E, O, and the "mouth delta" shape Δ. - - This needs either: - - Realtime data from client - - Or if ST-extras generates the TTS output, then at least a start timestamp for the playback of a given TTS output audio file, - and a possibility to stop animating if the user stops the audio. - Group chats / visual novel mode / several talkingheads running simultaneously. 
diff --git a/talkinghead/tha3/app/postprocessor.py b/talkinghead/tha3/app/postprocessor.py index ec0c2fa..c835aa2 100644 --- a/talkinghead/tha3/app/postprocessor.py +++ b/talkinghead/tha3/app/postprocessor.py @@ -124,6 +124,9 @@ class Postprocessor: self.vhs_glitch_last_frame_no = defaultdict(lambda: 0.0) self.vhs_glitch_last_image = defaultdict(lambda: None) self.vhs_glitch_last_mask = defaultdict(lambda: None) + self.shift_distort_interval = defaultdict(lambda: 0.0) + self.shift_distort_last_frame_no = defaultdict(lambda: 0.0) + self.shift_distort_grid = defaultdict(lambda: None) def render_into(self, image): """Apply current postprocess chain, modifying `image` in-place.""" @@ -492,6 +495,70 @@ class Postprocessor: warped = warped.squeeze(0) # [1, c, h, w] -> [c, h, w] image[:, :, :] = warped + def analog_distort(self, image: torch.tensor, *, + speed: float = 8.0, + strength: float = 0.1, + ripple_amplitude: float = 0.05, + ripple_density1: float = 4.0, + ripple_density2: Optional[float] = 13.0, + ripple_density3: Optional[float] = 27.0, + edge: str = "top") -> None: + """[dynamic] Analog video signal distorted by a runaway hsync near the top or bottom edge. + + A bad video cable connection can do this, e.g. when connecting a game console to a display + with an analog YPbPr component cable 10m in length. In reality, when I ran into this phenomenon, + the distortion only occurred for near-white images, but as glitch art, it looks better if it's + always applied at full strength. + + `speed`: At speed 1.0, a full cycle of the rippling effect completes every `image_height` frames. + So effectively the cycle position updates by `speed * (1 / image_height)` at each frame. + `strength`: Base strength for maximum distortion at the edge of the image. + In units where the height and width of the image are both 2.0. + `ripple_amplitude`: Variation on top of `strength`. + `ripple_density1`: Like `density` in `analog_badhsync`, but in time. How many cycles the first + component wave completes per one cycle of the ripple effect. + `ripple_density2`: Like `ripple_density1`, but for the second component wave. + Set to `None` or to 0.0 to disable the second component wave. + `ripple_density3`: Like `ripple_density1`, but for the third component wave. + Set to `None` or to 0.0 to disable the third component wave. + `edge`: one of "top", "bottom". Near which edge of the image to apply the maximal distortion. + The distortion then decays to zero, with a quadratic profile, in 1/8 of the image height. + + Note that "frame" here refers to the normalized frame number, at a reference of 25 FPS. + """ + c, h, w = image.shape + + # Animation + # FPS correction happens automatically, because `frame_no` is normalized to CALIBRATION_FPS. + cycle_pos = (self.frame_no / h) * speed + cycle_pos = cycle_pos - float(int(cycle_pos)) # fractional part + cycle_pos *= 2.0 # full cycle = 2 units + + # Deformation + # The spatial distort profile is a quadratic curve [0, 1], for 1/8 of the image height. 
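+        # Illustration (assuming the normalized grid coordinates used by grid_sample, i.e. meshy in [-1, 1]):
+        # for edge == "top", clamp(meshy + 0.75, max=0.0) is nonzero only where meshy < -0.75, so the
+        # profile below is 1.0 at the top edge (y = -1) and falls quadratically to 0.0 at y = -0.75,
+        # i.e. over the topmost 1/8 of the image height (the full height being 2.0 units).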
+ meshy = self._meshy + if edge == "top": + spatial_distort_profile = (torch.clamp(meshy + 0.75, max=0.0) * 4.0)**2 # distort near y = -1 + else: # edge == "bottom": + spatial_distort_profile = (torch.clamp(meshy - 0.75, min=0.0) * 4.0)**2 # distort near y = +1 + ripple_amplitude = ripple_amplitude + ripple = math.sin(ripple_density1 * cycle_pos * math.pi) + if ripple_density2: + ripple += math.sin(ripple_density2 * cycle_pos * math.pi) + if ripple_density3: + ripple += math.sin(ripple_density3 * cycle_pos * math.pi) + instantaneous_strength = (1.0 - ripple_amplitude) * strength + ripple_amplitude * ripple + # The minus sign: read coordinates toward the left -> shift the image toward the right. + meshx = self._meshx - instantaneous_strength * spatial_distort_profile + + # Then just the usual incantation for applying a geometric distortion in Torch: + grid = torch.stack((meshx, meshy), 2) + grid = grid.unsqueeze(0) # batch of one + image_batch = image.unsqueeze(0) # batch of one -> [1, c, h, w] + warped = torch.nn.functional.grid_sample(image_batch, grid, mode="bilinear", padding_mode="border", align_corners=False) + warped = warped.squeeze(0) # [1, c, h, w] -> [c, h, w] + image[:, :, :] = warped + def _vhs_noise(self, image: torch.tensor, *, height: int) -> torch.tensor: """Generate a horizontal band of noise that looks as if it came from a blank VHS tape. @@ -615,6 +682,59 @@ class Postprocessor: # fade = fade.unsqueeze(0) # [1, w] # image[3, -noise_pixels:, :] = fade + def shift_distort(self, image: torch.tensor, *, + strength: float = 0.05, + unboost: float = 4.0, + max_glitches: int = 3, + min_glitch_height: int = 20, max_glitch_height: int = 30, + hold_min: int = 1, hold_max: int = 3, + name: str = "shift_distort0") -> None: + """[dynamic] Glitchy digital video transport, with transient (per-frame) blocks of lines shifted left or right. + + `strength`: Amount of the horizontal shift, in units where 2.0 is the width of the full image. + Positive values shift toward the right. + For shifting both left and right, use two copies of the filter in your chain, + one with `strength > 0` and one with `strength < 0`. + `unboost`: Use this to adjust the probability profile for the appearance of glitches. + The higher `unboost` is, the less probable it is for glitches to appear at all, + and there will be fewer of them (in the same video frame) when they do appear. + `max_glitches`: Maximum number of glitches in the video frame. + `min_glitch_height`, `max_glitch_height`: in pixels. The height is randomized separately for each glitch. + `hold_min`, `hold_max`: in frames (at a reference of 25 FPS). Limits for the random time that the + filter holds one glitch pattern before randomizing the next one. + + `name`: Optional name for this filter instance in the chain. Used as cache key. + If you have more than one `shift_distort` in the chain, they should have + different names so that each one gets its own cache. 
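+
+        For example, the README pairs two instances, one per shift direction, each with
+        its own name (and hence its own cached grid)::
+
+            ["shift_distort", {"strength": 0.05, "name": "shift_right"}],
+            ["shift_distort", {"strength": -0.05, "name": "shift_left"}],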
+ """ + # Re-randomize the glitch pattern whenever enough frames have elapsed after last randomization + if self.shift_distort_grid[name] is None or (int(self.frame_no) - int(self.shift_distort_last_frame_no[name])) >= self.shift_distort_interval[name]: + n_glitches = torch.rand(1, device="cpu")**unboost # unboost: increase probability of having none or few glitching lines + n_glitches = int(max_glitches * n_glitches[0]) + meshy = self._meshy + meshx = self._meshx.clone() # don't modify the original; also, make sure each element has a unique memory address + if n_glitches: + c, h, w = image.shape + glitch_start_lines = torch.rand(n_glitches, device="cpu") + glitch_start_lines = [int((h - (max_glitch_height - 1)) * x) for x in glitch_start_lines] + for line in glitch_start_lines: + glitch_height = torch.rand(1, device="cpu") + glitch_height = int(min_glitch_height + (max_glitch_height - min_glitch_height) * glitch_height[0]) + meshx[line:(line + glitch_height), :] -= strength + shift_distort_grid = torch.stack((meshx, meshy), 2) + shift_distort_grid = shift_distort_grid.unsqueeze(0) # batch of one + self.shift_distort_grid[name] = shift_distort_grid + # Randomize time until next change of glitch pattern + self.shift_distort_interval[name] = round(hold_min + float(torch.rand(1, device="cpu")[0]) * (hold_max - hold_min)) + self.shift_distort_last_frame_no[name] = self.frame_no + else: + shift_distort_grid = self.shift_distort_grid[name] + + image_batch = image.unsqueeze(0) # batch of one -> [1, c, h, w] + warped = torch.nn.functional.grid_sample(image_batch, shift_distort_grid, mode="bilinear", padding_mode="border", align_corners=False) + warped = warped.squeeze(0) # [1, c, h, w] -> [c, h, w] + image[:, :, :] = warped + # -------------------------------------------------------------------------------- # CRT TV output