Mirror of https://github.com/SillyTavern/SillyTavern-Extras.git, synced 2026-01-26 17:20:04 +00:00

Merge pull request #221 from Technologicat/talkinghead-next6

Talkinghead next6, also vector DB embeddings
server.py: 93 changed lines
@@ -12,6 +12,7 @@ from random import randint
 import secrets
 import sys
 import time
+from typing import List, Union
 import unicodedata
 
 from colorama import Fore, Style, init as colorama_init
@@ -19,6 +20,7 @@ import markdown
 
 from PIL import Image
 
+import numpy as np
 import torch
 from transformers import pipeline
 
@@ -125,6 +127,10 @@ def index():
         content = f.read()
     return render_template_string(markdown.markdown(content, extensions=["tables"]))
 
+@app.route("/api/modules", methods=["GET"])
+def get_modules():
+    return jsonify({"modules": modules})
+
 @app.route("/api/extensions", methods=["GET"])
 def get_extensions():
     extensions = dict(
@@ -477,10 +483,6 @@ def api_image_samplers():
 
     return jsonify({"samplers": samplers})
 
-@app.route("/api/modules", methods=["GET"])
-def get_modules():
-    return jsonify({"modules": modules})
-
 # ----------------------------------------
 # tts
 
@@ -488,7 +490,7 @@ tts_service = None  # populated when the module is loaded
 
 @app.route("/api/tts/speakers", methods=["GET"])
 @require_module("silero-tts")
-def tts_speakers():
+def api_tts_speakers():
     voices = [
         {
             "name": speaker,
@@ -502,7 +504,7 @@ def tts_speakers():
 # Added fix for Silero not working as new files were unable to be created if one already existed. - Rolyat 7/7/23
 @app.route("/api/tts/generate", methods=["POST"])
 @require_module("silero-tts")
-def tts_generate():
+def api_tts_generate():
     voice = request.get_json()
     if "text" not in voice or not isinstance(voice["text"], str):
         abort(400, '"text" is required')
@@ -526,7 +528,7 @@ def tts_generate():
 
 @app.route("/api/tts/sample/<speaker>", methods=["GET"])
 @require_module("silero-tts")
-def tts_play_sample(speaker: str):
+def api_tts_play_sample(speaker: str):
     return send_from_directory(SILERO_SAMPLES_PATH, f"{speaker}.wav")
 
 # ----------------------------------------
@@ -536,13 +538,13 @@ edge = None  # populated when the module is loaded
 
 @app.route("/api/edge-tts/list", methods=["GET"])
 @require_module("edge-tts")
-def edge_tts_list():
+def api_edge_tts_list():
     voices = edge.get_voices()
     return jsonify(voices)
 
 @app.route("/api/edge-tts/generate", methods=["POST"])
 @require_module("edge-tts")
-def edge_tts_generate():
+def api_edge_tts_generate():
     data = request.get_json()
     if "text" not in data or not isinstance(data["text"], str):
         abort(400, '"text" is required')
@@ -561,6 +563,60 @@ def edge_tts_generate():
         print(e)
         abort(500, data["voice"])
 
+# ----------------------------------------
+# embeddings
+
+sentence_embedder = None  # populated when the module is loaded
+
+@app.route("/api/embeddings/compute", methods=["POST"])
+@require_module("embeddings")
+def api_embeddings_compute():
+    """For making vector DB keys. Compute the vector embedding of one or more sentences of text.
+
+    Input format is JSON::
+
+        {"text": "Blah blah blah."}
+
+    or::
+
+        {"text": ["Blah blah blah.",
+                  ...]}
+
+    Output is also JSON::
+
+        {"embedding": array}
+
+    or::
+
+        {"embedding": [array0,
+                       ...]}
+
+    respectively.
+
+    This is the Extras backend for computing embeddings in the Vector Storage builtin extension.
+    """
+    data = request.get_json()
+    if "text" not in data:
+        abort(400, '"text" is required')
+    sentences: Union[str, List[str]] = data["text"]
+    if not (isinstance(sentences, str) or (isinstance(sentences, list) and all(isinstance(x, str) for x in sentences))):
+        abort(400, '"text" must be string or array of strings')
+    if isinstance(sentences, str):
+        nitems = 1
+    else:
+        nitems = len(sentences)
+    print(f"Computing vector embedding for {nitems} item{'s' if nitems != 1 else ''}")
+    vectors: Union[np.ndarray, List[np.ndarray]] = sentence_embedder.encode(sentences,
+                                                                            show_progress_bar=True,  # on ST-extras console
+                                                                            convert_to_numpy=True,
+                                                                            normalize_embeddings=True)
+    # NumPy arrays are not JSON serializable, so convert to Python lists
+    if isinstance(vectors, np.ndarray):
+        vectors = vectors.tolist()
+    else:  # isinstance(vectors, list) and all(isinstance(x, np.ndarray) for x in vectors)
+        vectors = [x.tolist() for x in vectors]
+    return jsonify({"embedding": vectors})
+
 # ----------------------------------------
 # chromadb
 
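For illustration, a minimal client-side sketch of calling the new endpoint (not part of the diff; it assumes an ST-extras instance with the `embeddings` module enabled at a local address, so adjust the URL and port for your setup):

```
# Hypothetical client for /api/embeddings/compute; the URL/port are assumptions.
import requests

resp = requests.post("http://localhost:5100/api/embeddings/compute",
                     json={"text": ["Blah blah blah.", "More text."]})
resp.raise_for_status()
vectors = resp.json()["embedding"]  # list of two vectors, one per input string
print(len(vectors), len(vectors[0]))  # 2, and the embedding dimension
```

Since the server normalizes the embeddings, the dot product of two returned vectors is directly their cosine similarity.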
@@ -569,7 +625,7 @@ chromadb_embed_fn = None
 
 @app.route("/api/chromadb", methods=["POST"])
 @require_module("chromadb")
-def chromadb_add_messages():
+def api_chromadb_add_messages():
     data = request.get_json()
     if "chat_id" not in data or not isinstance(data["chat_id"], str):
         abort(400, '"chat_id" is required')
@@ -598,7 +654,7 @@ def chromadb_add_messages():
 
 @app.route("/api/chromadb/purge", methods=["POST"])
 @require_module("chromadb")
-def chromadb_purge():
+def api_chromadb_purge():
     data = request.get_json()
     if "chat_id" not in data or not isinstance(data["chat_id"], str):
         abort(400, '"chat_id" is required')
@@ -615,7 +671,7 @@ def chromadb_purge():
 
 @app.route("/api/chromadb/query", methods=["POST"])
 @require_module("chromadb")
-def chromadb_query():
+def api_chromadb_query():
     data = request.get_json()
     if "chat_id" not in data or not isinstance(data["chat_id"], str):
         abort(400, '"chat_id" is required')
@@ -663,7 +719,7 @@ def chromadb_query():
 
 @app.route("/api/chromadb/multiquery", methods=["POST"])
 @require_module("chromadb")
-def chromadb_multiquery():
+def api_chromadb_multiquery():
     data = request.get_json()
     if "chat_list" not in data or not isinstance(data["chat_list"], list):
         abort(400, '"chat_list" is required and should be a list')
@@ -726,7 +782,7 @@ def chromadb_multiquery():
 
 @app.route("/api/chromadb/export", methods=["POST"])
 @require_module("chromadb")
-def chromadb_export():
+def api_chromadb_export():
     data = request.get_json()
     if "chat_id" not in data or not isinstance(data["chat_id"], str):
         abort(400, '"chat_id" is required')
@@ -765,7 +821,7 @@ def chromadb_export():
 
 @app.route("/api/chromadb/import", methods=["POST"])
 @require_module("chromadb")
-def chromadb_import():
+def api_chromadb_import():
     data = request.get_json()
     content = data['content']
     if "chat_id" not in data or not isinstance(data["chat_id"], str):
@@ -1045,6 +1101,11 @@ if "edge-tts" in modules:
     print("Initializing Edge TTS client")
    import tts_edge as edge
 
+if "embeddings" in modules:
+    print("Initializing embeddings")
+    from sentence_transformers import SentenceTransformer
+    sentence_embedder = SentenceTransformer(embedding_model, device=device_string)
+
 if "chromadb" in modules:
     print("Initializing ChromaDB")
     import chromadb
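For reference, a standalone sketch of what this initialization plus the endpoint's `encode` call amount to (not part of the diff; the model name and device here are illustrative assumptions, as the real values come from the server's `embedding_model` setting and the detected device):

```
# Hypothetical standalone equivalent; model name and device are assumptions.
from sentence_transformers import SentenceTransformer

embedder = SentenceTransformer("sentence-transformers/all-mpnet-base-v2", device="cpu")
vecs = embedder.encode(["Blah blah blah."],
                       convert_to_numpy=True,
                       normalize_embeddings=True)
print(vecs.shape)  # (1, d), where d is the model's embedding dimension
```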
@@ -1199,6 +1260,6 @@ if args.share:
     cloudflare = _run_cloudflared(port)
     print(f"{Fore.GREEN}{Style.NORMAL}Running on: {cloudflare}{Style.RESET_ALL}")
 
-ignore_auth.append(tts_play_sample)
+ignore_auth.append(api_tts_play_sample)
 ignore_auth.append(api_talkinghead_result_feed)
 app.run(host=host, port=port)
 
@@ -222,12 +222,12 @@ The following postprocessing filters are available. Options for each filter are
 
 **Transport**:
 
 Currently, we provide some filters that simulate a lo-fi analog video look.
 
 - `analog_lowres`: Simulates a low-resolution analog video signal by blurring the image.
 - `analog_badhsync`: Simulates bad horizontal synchronization (hsync) of an analog video signal, causing a wavy effect that makes the outline of the character ripple.
+- `analog_distort`: Simulates a rippling, runaway hsync near the top or bottom edge of the image. This can happen with some equipment if the video cable is too long.
 - `analog_vhsglitches`: Simulates a damaged 1980s VHS tape. In each 25 FPS frame, causes random lines to glitch with VHS noise.
 - `analog_vhstracking`: Simulates a 1980s VHS tape with bad tracking. The image floats up and down, and a band of VHS noise appears at the bottom.
+- `shift_distort`: A glitchy digital video transport, as sometimes depicted in sci-fi, with random blocks of lines transiently shifted horizontally.
 
 **Display**:
 
@@ -261,6 +261,15 @@ The bloom works best on a dark background. We use `lumanoise` to add an imperfec
 
 Note that we could also use the `translucency` filter to make the character translucent, e.g. `["translucency", {"alpha": 0.7}]`.
 
+Also, for a glitching video transport that shifts random blocks of lines horizontally, we could add these:
+
+```
+["shift_distort", {"strength": 0.05, "name": "shift_right"}],
+["shift_distort", {"strength": -0.05, "name": "shift_left"}],
+```
+
+Having a unique name for each instance is important, because the name acts as a cache key.
+
 #### Postprocessor example: cheap video camera, amber monochrome computer monitor
 
 We first simulate a cheap video camera with low-quality optics via the `chromatic_aberration` and `vignetting` filters.
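Putting this section's pieces together, a small chain sketch using only the options already shown above (a hedged example, not from the diff; any further filters would be appended with their own options):

```
[
    ["translucency", {"alpha": 0.7}],
    ["shift_distort", {"strength": 0.05, "name": "shift_right"}],
    ["shift_distort", {"strength": -0.05, "name": "shift_left"}]
]
```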
@@ -7,6 +7,8 @@ As of January 2024, preferably to be completed before the next release.
 
 #### Frontend
 
+- See if we can get this working also with local classify, now that we have a `set_emotion` API endpoint.
+  - Responsibilities: the client end should set the emotion when it calls classify, instead of relying on the Extras server doing it internally when Extras classify is called.
 - Figure out why the crop filter doesn't help in positioning the `talkinghead` sprite in *MovingUI* mode.
   - There must be some logic at the frontend side that reserves a square shape for the talkinghead sprite output,
     regardless of the image dimensions or aspect ratio of the actual `result_feed`.
@@ -63,20 +65,13 @@ Not scheduled for now.
   - The effect on speed will be small; the compute-heaviest part is the inference of the THA3 deep-learning model.
 - Add more postprocessing filters. Possible ideas, no guarantee I'll ever get around to them:
   - Pixelize, posterize (8-bit look)
-  - Analog video glitches
-    - Partition image into bands, move some left/right temporarily (for a few frames, now that we can do that)
-    - Another effect of bad VHS hsync: a dynamic "bending" effect near the top edge:
-      - Distortion by horizontal movement
-      - Topmost row of pixels moves the most, then a smoothly decaying offset profile as a function of height (decaying to zero at maybe 20% of image height, measured from the top)
-      - The maximum offset flutters dynamically in a semi-regular, semi-unpredictable manner (use a superposition of three sine waves at different frequencies, as functions of time)
   - Digital data connection glitches
     - Apply to random rectangles; may need to persist for a few frames to animate and/or make them more noticeable
     - May need to protect important regions like the character's head (approximately, from the template); we're after "Hollywood glitchy", not actually glitchy
     - Types:
       - Constant-color rectangle
       - Missing data (zero out the alpha?)
       - Blur (leads to replacing by average color, with controllable sigma)
-      - Zigzag deformation
+      - Zigzag deformation (perhaps not needed now that we have `shift_distort`, which is similar, but with a rectangular shape, and applied to full lines of video)
 - Investigate if some particular emotions could use a small random per-frame oscillation applied to "iris_small",
   for that anime "intense emotion" effect (since THA3 doesn't have a morph specifically for the specular reflections in the eyes).
 
@@ -93,7 +88,14 @@ Not scheduled for now.
 - To save GPU resources, automatically pause animation when the web browser window with SillyTavern is not in focus. Resume when it regains focus.
   - Needs a new API endpoint for pause/resume. Note the current `/api/talkinghead/unload` is actually a pause function (the client pauses, and
     then just hides the live image), but there is currently no resume function (except `/api/talkinghead/load`, which requires sending an image file).
 
+- Lip-sync talking animation to TTS output.
+  - THA3 has morphs for A, I, U, E, O, and the "mouth delta" shape Δ.
+  - This needs either:
+    - Realtime data from client
+      - Exists already! See `SillyTavern/public/scripts/extensions/tts/index.js`, function `playAudioData`. There's lip sync for VRM (VRoid).
+        Still need to investigate how the VRM plugin extracts phonemes from the audio data.
+    - Or if ST-extras generates the TTS output, then at least a start timestamp for the playback of a given TTS output audio file,
+      and a possibility to stop animating if the user stops the audio.
 
 ### Far future
 
@@ -104,10 +106,4 @@ Definitely not scheduled. Ideas for future enhancements.
   - The algorithm should be cartoon-aware, some modern-day equivalent of waifu2x. A GAN such as 4x-AnimeSharp or Remacri would be nice, but too slow.
   - Maybe the scaler should run at the client side to avoid the need to stream 1024x1024 PNGs.
     - What JavaScript anime scalers are there, or which algorithms are simple enough for a small custom implementation?
-- Lip-sync talking animation to TTS output.
-  - THA3 has morphs for A, I, U, E, O, and the "mouth delta" shape Δ.
-  - This needs either:
-    - Realtime data from client
-    - Or if ST-extras generates the TTS output, then at least a start timestamp for the playback of a given TTS output audio file,
-      and a possibility to stop animating if the user stops the audio.
 - Group chats / visual novel mode / several talkingheads running simultaneously.
 
@@ -124,6 +124,9 @@ class Postprocessor:
         self.vhs_glitch_last_frame_no = defaultdict(lambda: 0.0)
         self.vhs_glitch_last_image = defaultdict(lambda: None)
         self.vhs_glitch_last_mask = defaultdict(lambda: None)
+        self.shift_distort_interval = defaultdict(lambda: 0.0)
+        self.shift_distort_last_frame_no = defaultdict(lambda: 0.0)
+        self.shift_distort_grid = defaultdict(lambda: None)
 
     def render_into(self, image):
         """Apply current postprocess chain, modifying `image` in-place."""
@@ -492,6 +495,70 @@ class Postprocessor:
         warped = warped.squeeze(0)  # [1, c, h, w] -> [c, h, w]
         image[:, :, :] = warped
 
+    def analog_distort(self, image: torch.Tensor, *,
+                       speed: float = 8.0,
+                       strength: float = 0.1,
+                       ripple_amplitude: float = 0.05,
+                       ripple_density1: float = 4.0,
+                       ripple_density2: Optional[float] = 13.0,
+                       ripple_density3: Optional[float] = 27.0,
+                       edge: str = "top") -> None:
+        """[dynamic] Analog video signal distorted by a runaway hsync near the top or bottom edge.
+
+        A bad video cable connection can do this, e.g. when connecting a game console to a display
+        with an analog YPbPr component cable 10m in length. In reality, when I ran into this phenomenon,
+        the distortion only occurred for near-white images, but as glitch art, it looks better if it's
+        always applied at full strength.
+
+        `speed`: At speed 1.0, a full cycle of the rippling effect completes every `image_height` frames.
+                 So effectively, the cycle position updates by `speed * (1 / image_height)` at each frame.
+        `strength`: Base strength for maximum distortion at the edge of the image.
+                    In units where the height and width of the image are both 2.0.
+        `ripple_amplitude`: Variation on top of `strength`.
+        `ripple_density1`: Like `density` in `analog_badhsync`, but in time. How many cycles the first
+                           component wave completes per one cycle of the ripple effect.
+        `ripple_density2`: Like `ripple_density1`, but for the second component wave.
+                           Set to `None` or to 0.0 to disable the second component wave.
+        `ripple_density3`: Like `ripple_density1`, but for the third component wave.
+                           Set to `None` or to 0.0 to disable the third component wave.
+        `edge`: One of "top", "bottom". Near which edge of the image to apply the maximal distortion.
+                The distortion then decays to zero, with a quadratic profile, in 1/8 of the image height.
+
+        Note that "frame" here refers to the normalized frame number, at a reference of 25 FPS.
+        """
+        c, h, w = image.shape
+
+        # Animation
+        # FPS correction happens automatically, because `frame_no` is normalized to CALIBRATION_FPS.
+        cycle_pos = (self.frame_no / h) * speed
+        cycle_pos = cycle_pos - float(int(cycle_pos))  # fractional part
+        cycle_pos *= 2.0  # full cycle = 2 units
+
+        # Deformation
+        # The spatial distort profile is a quadratic curve [0, 1], for 1/8 of the image height.
+        meshy = self._meshy
+        if edge == "top":
+            spatial_distort_profile = (torch.clamp(meshy + 0.75, max=0.0) * 4.0)**2  # distort near y = -1
+        else:  # edge == "bottom":
+            spatial_distort_profile = (torch.clamp(meshy - 0.75, min=0.0) * 4.0)**2  # distort near y = +1
+        ripple = math.sin(ripple_density1 * cycle_pos * math.pi)
+        if ripple_density2:
+            ripple += math.sin(ripple_density2 * cycle_pos * math.pi)
+        if ripple_density3:
+            ripple += math.sin(ripple_density3 * cycle_pos * math.pi)
+        instantaneous_strength = (1.0 - ripple_amplitude) * strength + ripple_amplitude * ripple
+        # The minus sign: read coordinates toward the left -> shift the image toward the right.
+        meshx = self._meshx - instantaneous_strength * spatial_distort_profile
+
+        # Then just the usual incantation for applying a geometric distortion in Torch:
+        grid = torch.stack((meshx, meshy), 2)
+        grid = grid.unsqueeze(0)  # batch of one
+        image_batch = image.unsqueeze(0)  # batch of one -> [1, c, h, w]
+        warped = torch.nn.functional.grid_sample(image_batch, grid, mode="bilinear", padding_mode="border", align_corners=False)
+        warped = warped.squeeze(0)  # [1, c, h, w] -> [c, h, w]
+        image[:, :, :] = warped
+
     def _vhs_noise(self, image: torch.Tensor, *,
                    height: int) -> torch.Tensor:
         """Generate a horizontal band of noise that looks as if it came from a blank VHS tape.
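To see the flutter that drives `analog_distort`, here is a standalone sketch of just the strength signal, extracted from the method above (assumes the default densities; not part of the diff):

```
# Standalone illustration of analog_distort's instantaneous strength signal.
import math

def instantaneous_strength(cycle_pos: float,
                           strength: float = 0.1,
                           ripple_amplitude: float = 0.05) -> float:
    # A superposition of three sine waves at different frequencies gives a
    # semi-regular, semi-unpredictable flutter, as the docstring describes.
    ripple = math.sin(4.0 * cycle_pos * math.pi)
    ripple += math.sin(13.0 * cycle_pos * math.pi)
    ripple += math.sin(27.0 * cycle_pos * math.pi)
    return (1.0 - ripple_amplitude) * strength + ripple_amplitude * ripple

for pos in (0.0, 0.5, 1.0, 1.5):  # a full cycle is 2 units
    print(pos, instantaneous_strength(pos))
```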
@@ -615,6 +682,59 @@ class Postprocessor:
         # fade = fade.unsqueeze(0)  # [1, w]
         # image[3, -noise_pixels:, :] = fade
 
+    def shift_distort(self, image: torch.Tensor, *,
+                      strength: float = 0.05,
+                      unboost: float = 4.0,
+                      max_glitches: int = 3,
+                      min_glitch_height: int = 20, max_glitch_height: int = 30,
+                      hold_min: int = 1, hold_max: int = 3,
+                      name: str = "shift_distort0") -> None:
+        """[dynamic] Glitchy digital video transport, with transient (per-frame) blocks of lines shifted left or right.
+
+        `strength`: Amount of the horizontal shift, in units where 2.0 is the width of the full image.
+                    Positive values shift toward the right.
+                    For shifting both left and right, use two copies of the filter in your chain,
+                    one with `strength > 0` and one with `strength < 0`.
+        `unboost`: Use this to adjust the probability profile for the appearance of glitches.
+                   The higher `unboost` is, the less probable it is for glitches to appear at all,
+                   and there will be fewer of them (in the same video frame) when they do appear.
+        `max_glitches`: Maximum number of glitches in the video frame.
+        `min_glitch_height`, `max_glitch_height`: In pixels. The height is randomized separately for each glitch.
+        `hold_min`, `hold_max`: In frames (at a reference of 25 FPS). Limits for the random time that the
+                                filter holds one glitch pattern before randomizing the next one.
+
+        `name`: Optional name for this filter instance in the chain. Used as cache key.
+                If you have more than one `shift_distort` in the chain, they should have
+                different names so that each one gets its own cache.
+        """
+        # Re-randomize the glitch pattern whenever enough frames have elapsed since the last randomization
+        if self.shift_distort_grid[name] is None or (int(self.frame_no) - int(self.shift_distort_last_frame_no[name])) >= self.shift_distort_interval[name]:
+            n_glitches = torch.rand(1, device="cpu")**unboost  # unboost: increase the probability of having no or few glitching lines
+            n_glitches = int(max_glitches * n_glitches[0])
+            meshy = self._meshy
+            meshx = self._meshx.clone()  # don't modify the original; also, make sure each element has a unique memory address
+            if n_glitches:
+                c, h, w = image.shape
+                glitch_start_lines = torch.rand(n_glitches, device="cpu")
+                glitch_start_lines = [int((h - (max_glitch_height - 1)) * x) for x in glitch_start_lines]
+                for line in glitch_start_lines:
+                    glitch_height = torch.rand(1, device="cpu")
+                    glitch_height = int(min_glitch_height + (max_glitch_height - min_glitch_height) * glitch_height[0])
+                    meshx[line:(line + glitch_height), :] -= strength
+            shift_distort_grid = torch.stack((meshx, meshy), 2)
+            shift_distort_grid = shift_distort_grid.unsqueeze(0)  # batch of one
+            self.shift_distort_grid[name] = shift_distort_grid
+            # Randomize the time until the next change of glitch pattern
+            self.shift_distort_interval[name] = round(hold_min + float(torch.rand(1, device="cpu")[0]) * (hold_max - hold_min))
+            self.shift_distort_last_frame_no[name] = self.frame_no
+        else:
+            shift_distort_grid = self.shift_distort_grid[name]
+
+        image_batch = image.unsqueeze(0)  # batch of one -> [1, c, h, w]
+        warped = torch.nn.functional.grid_sample(image_batch, shift_distort_grid, mode="bilinear", padding_mode="border", align_corners=False)
+        warped = warped.squeeze(0)  # [1, c, h, w] -> [c, h, w]
+        image[:, :, :] = warped
+
 # --------------------------------------------------------------------------------
 # CRT TV output
 
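The "usual incantation" both new filters share can be exercised on its own; a minimal standalone sketch of shifting a block of lines with `grid_sample` (illustration only, not part of the diff):

```
# Minimal standalone demo of the horizontal-shift warp used by shift_distort.
import torch

h = w = 8
ys = torch.linspace(-1.0, 1.0, h)
xs = torch.linspace(-1.0, 1.0, w)
meshy, meshx = torch.meshgrid(ys, xs, indexing="ij")
meshx = meshx.clone()
meshx[2:5, :] -= 0.5  # read coordinates toward the left -> rows 2..4 appear shifted right

grid = torch.stack((meshx, meshy), 2).unsqueeze(0)  # [1, h, w, 2]
image = torch.eye(h).repeat(3, 1, 1).unsqueeze(0)   # [1, c, h, w] diagonal test pattern
warped = torch.nn.functional.grid_sample(image, grid, mode="bilinear",
                                         padding_mode="border", align_corners=False)
print(warped.squeeze(0)[0])  # the diagonal bends to the right in rows 2..4
```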