mirror of
https://github.com/SillyTavern/SillyTavern-Extras.git
synced 2026-04-28 10:21:20 +00:00
add framerate correction to most postproc filters
This commit is contained in:
@@ -3,12 +3,21 @@
|
||||
These effects work in linear intensity space, before gamma correction.
|
||||
"""
|
||||
|
||||
__all__ = ["Postprocessor"]
|
||||
|
||||
import logging
|
||||
import math
|
||||
import time
|
||||
from typing import Dict, List, Optional, Tuple, TypeVar, Union
|
||||
|
||||
import torch
|
||||
import torchvision
|
||||
|
||||
from tha3.app.util import RunningAverage
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# # Default configuration for the postprocessor.
|
||||
# # This documents the correct ordering of the filters.
|
||||
# # Feel free to improvise, but make sure to understand why your filter chain makes sense.
|
||||
@@ -67,22 +76,51 @@ class Postprocessor:
|
||||
taking effect immediately. It is recommended to update the chain atomically, by::
|
||||
|
||||
my_postprocessor.chain = my_new_chain
|
||||
|
||||
In filter descriptions:
|
||||
[static] := depends only on input image, no explicit time dependence.
|
||||
[dynamic] := besides the input image, also depends on time. In other words,
|
||||
produces animation even if the input image stays the same.
|
||||
"""
|
||||
|
||||
def __init__(self, device: torch.device, chain: Optional[List[Tuple[str, Dict[str, MaybeContained[Atom]]]]] = None):
|
||||
# We intentionally keep very little state in this class, for a more FP/REST approach with fewer bugs.
|
||||
# There's just the device info, a frame counter, and the current filter chain config (which is read at every frame).
|
||||
# The filters themselves are stateless; but note that they overwrite the image being processed.
|
||||
# Filters for static effects are stateless.
|
||||
#
|
||||
# We deviate from FP in that:
|
||||
# - The filters MUTATE, i.e. they overwrite the image being processed.
|
||||
# This is to allow optimizing their implementations for memory usage and speed.
|
||||
# - The filter for a dynamic effect may store state, if needed for performing FPS correction.
|
||||
self.device = device
|
||||
self.frame_no = 0
|
||||
if chain is None:
|
||||
chain = default_chain
|
||||
self.chain = chain
|
||||
|
||||
# Meshgrid cache for geometric position of each pixel
|
||||
self._yy = None
|
||||
self._xx = None
|
||||
self._meshy = None
|
||||
self._meshx = None
|
||||
self._prev_h = None
|
||||
self._prev_w = None
|
||||
|
||||
# FPS correction
|
||||
self.CALIBRATION_FPS = 25 # design FPS for dynamic effects (for automatic FPS correction)
|
||||
self.stream_start_timestamp = time.time_ns() # for updating frame counter reliably (no accumulation)
|
||||
self.frame_no = -1 # float, frame counter for *normalized* frame number *at CALIBRATION_FPS*
|
||||
self.last_frame_no = -1
|
||||
|
||||
# Performance measurement
|
||||
self.render_duration_statistics = RunningAverage()
|
||||
self.last_report_time = None
|
||||
|
||||
# Caches for individual dynamic effects
|
||||
self.alphanoise_last_image = None
|
||||
|
||||
def render_into(self, image):
|
||||
"""Apply current postprocess chain, modifying `image`."""
|
||||
time_render_start = time.time_ns()
|
||||
|
||||
c, h, w = image.shape
|
||||
if h != self._prev_h or w != self._prev_w:
|
||||
# Compute base meshgrid for the geometric position of each pixel.
|
||||
@@ -101,11 +139,77 @@ class Postprocessor:
|
||||
self._prev_h = h
|
||||
self._prev_w = w
|
||||
|
||||
for filter_name, settings in self.chain:
|
||||
# Update the frame counter.
|
||||
#
|
||||
# We consider the frame number to be a float, so that dynamic filters can decide what
|
||||
# to do at fractional frame positions. For continuously animated effects (e.g. banding)
|
||||
# it makes sense to interpolate continuously, whereas other effects (e.g. scanlines)
|
||||
# can make their decisions based on the integer part.
|
||||
#
|
||||
# As always with floats, we must be careful. Note that we operate in a mindset of robust
|
||||
# engineering. Since doing the Right Thing here does not cost significantly more engineering
|
||||
# effort than doing the intuitive but Wrong Thing, it is preferable to go for the proper solution,
|
||||
# regardless of whether it would take a centuries-long session to actually trigger a failure
|
||||
# in the less robust approach.
|
||||
#
|
||||
# So, floating point accuracy considerations? First, we note that accumulation invites
|
||||
# disaster in two ways:
|
||||
#
|
||||
# - Accumulating the result accumulates also representation error and roundoff error.
|
||||
# - When accumulating small positive numbers to a sum total, the update eventually
|
||||
# becomes too small to add, causing the counter to get stuck. (For floats, `x + ϵ = x`
|
||||
# for sufficiently small ϵ dependent on the magnitude of `x`.)
|
||||
#
|
||||
# Fortunately, frame number is a linear function of time, and time diffs can be measured
|
||||
# precisely. Thus, we can freshly compute the current frame number at each frame, completely
|
||||
# bypassing the need for accumulation:
|
||||
#
|
||||
seconds_since_stream_start = (time_render_start - self.stream_start_timestamp) / 10**9
|
||||
self.last_frame_no = self.frame_no
|
||||
self.frame_no = self.CALIBRATION_FPS * seconds_since_stream_start # float!
|
||||
|
||||
# That leaves just the questions of how accurate the calculation is, and for how long.
|
||||
# As to the first question:
|
||||
#
|
||||
# - Timestamps are an integer number of nanoseconds, so they are exact.
|
||||
# - Dividing by 10**9, we move the decimal point. But floats are base-2, so 0.1
|
||||
# is not representable in IEEE-754. So there will be some small representation error,
|
||||
# which for float64 likely appears in the ~15th significant digit.
|
||||
# - Basic arithmetic, such as multiplication, is guaranteed by IEEE-754
|
||||
# to be correctly rounded, i.e. accurate to within half a ULP.
|
||||
#
|
||||
# Thus, as the result, we obtain the closest number that is representable in IEEE-754,
|
||||
# and the strategy works for the whole range of float64.
|
||||
#
|
||||
# As for the second question, floats are logarithmically spaced. So if this is left running
|
||||
# "for long enough" during the same session, accuracy will eventually suffer. Instead of the
|
||||
# counter getting stuck, however, this will manifest as the frame number updating by more
|
||||
# than `1.0` each time it updates (i.e. whenever the elapsed number of frames reaches the
|
||||
# next representable float).
|
||||
#
|
||||
# This could be fixed by resetting `stream_start_timestamp` once the frame number
|
||||
# becomes too large. But in practice, how long does it take for this issue to occur?
|
||||
# The ULP becomes 1.0 at ~5e15. To reach frame number 5e15, at the reference 25 FPS,
|
||||
# the time required is 2e14 seconds, i.e. 2.31e9 days, or 6.34 million years.
|
||||
# While I can almost imagine the eventual bug report, I think it's safe to ignore this.
|
||||
|
||||
# Apply the current filter chain.
|
||||
chain = self.chain # read just once; other threads might reassign it while we're rendering
|
||||
for filter_name, settings in chain:
|
||||
apply_filter = getattr(self, filter_name)
|
||||
apply_filter(image, **settings)
|
||||
|
||||
self.frame_no += 1
|
||||
time_now = time.time_ns()
|
||||
render_elapsed_sec = (time_now - time_render_start) / 10**9
|
||||
self.render_duration_statistics.add_datapoint(render_elapsed_sec)
|
||||
|
||||
# Log the FPS counter in 5-second intervals.
|
||||
if (self.last_report_time is None or time_now - self.last_report_time > 5e9):
|
||||
avg_render_sec = self.render_duration_statistics.average()
|
||||
msec = round(1000 * avg_render_sec, 1)
|
||||
fps = round(1 / avg_render_sec, 1) if avg_render_sec > 0.0 else 0.0
|
||||
logger.info(f"postproc: {msec:.1f}ms [{fps} FPS available]")
|
||||
self.last_report_time = time_now
|
||||
|
||||
# --------------------------------------------------------------------------------
|
||||
# Physical input signal
|
||||
@@ -113,7 +217,7 @@ class Postprocessor:
|
||||
def bloom(self, image: torch.tensor, *,
|
||||
luma_threshold: float = 0.8,
|
||||
hdr_exposure: float = 0.7) -> None:
|
||||
"""Bloom effect (fake HDR). Popular in early 2000s anime.
|
||||
"""[static] Bloom effect (fake HDR). Popular in early 2000s anime.
|
||||
|
||||
Makes bright parts of the image bleed light into their surroundings, enhancing perceived contrast.
|
||||
Only makes sense when the talkinghead is rendered on a dark-ish background.
|
||||
@@ -156,7 +260,7 @@ class Postprocessor:
|
||||
def chromatic_aberration(self, image: torch.tensor, *,
|
||||
transverse_sigma: float = 0.5,
|
||||
axial_scale: float = 0.005) -> None:
|
||||
"""Simulate the two types of chromatic aberration in a camera lens.
|
||||
"""[static] Simulate the two types of chromatic aberration in a camera lens.
|
||||
|
||||
Like everything else here, this is of course made of smoke and mirrors. We simulate the axial effect
|
||||
(index of refraction varying w.r.t. wavelength) by geometrically scaling the RGB channels individually,
|
||||
@@ -206,7 +310,7 @@ class Postprocessor:
|
||||
|
||||
def vignetting(self, image: torch.tensor, *,
|
||||
strength: float = 0.42) -> None:
|
||||
"""Simulate vignetting (less light hitting the corners of a film frame or CCD sensor).
|
||||
"""[static] Simulate vignetting (less light hitting the corners of a film frame or CCD sensor).
|
||||
|
||||
The profile used here is [cos(strength * d * pi)]**2, where `d` is the distance
|
||||
from the center, scaled such that `d = 1.0` is reached at the corners.
|
||||
@@ -222,7 +326,7 @@ class Postprocessor:
|
||||
|
||||
def translucency(self, image: torch.tensor, *,
|
||||
alpha: float = 0.9) -> None:
|
||||
"""A simple translucency filter for a hologram look.
|
||||
"""[static] A simple translucency filter for a hologram look.
|
||||
|
||||
Multiplicatively adjusts the alpha channel.
|
||||
"""
|
||||
@@ -234,7 +338,7 @@ class Postprocessor:
|
||||
def alphanoise(self, image: torch.tensor, *,
|
||||
magnitude: float = 0.1,
|
||||
sigma: float = 0.0) -> None:
|
||||
"""Dynamic noise to alpha channel. A cheap alternative to luma noise.
|
||||
"""[dynamic] Dynamic noise to alpha channel. A cheap alternative to luma noise.
|
||||
|
||||
`magnitude`: How much noise to apply. 0 is off, 1 is as much noise as possible.
|
||||
|
||||
@@ -249,12 +353,17 @@ class Postprocessor:
|
||||
Scifi hologram: magnitude=0.1, sigma=0.0
|
||||
Analog VHS tape: magnitude=0.2, sigma=2.0
|
||||
"""
|
||||
c, h, w = image.shape
|
||||
noise_image = torch.rand(h, w, device=self.device, dtype=image.dtype)
|
||||
if sigma > 0.0:
|
||||
noise_image = noise_image.unsqueeze(0) # [h, w] -> [c, h, w] (where c=1)
|
||||
noise_image = torchvision.transforms.GaussianBlur((5, 5), sigma=sigma)(noise_image)
|
||||
noise_image = noise_image.squeeze(0) # -> [h, w]
|
||||
# Re-randomize the noise image whenever the normalized frame changes
|
||||
if self.alphanoise_last_image is None or int(self.frame_no) > int(self.last_frame_no):
|
||||
c, h, w = image.shape
|
||||
noise_image = torch.rand(h, w, device=self.device, dtype=image.dtype)
|
||||
if sigma > 0.0:
|
||||
noise_image = noise_image.unsqueeze(0) # [h, w] -> [c, h, w] (where c=1)
|
||||
noise_image = torchvision.transforms.GaussianBlur((5, 5), sigma=sigma)(noise_image)
|
||||
noise_image = noise_image.squeeze(0) # -> [h, w]
|
||||
self.alphanoise_last_image = noise_image
|
||||
else:
|
||||
noise_image = self.alphanoise_last_image
|
||||
base_magnitude = 1.0 - magnitude
|
||||
image[3, :, :].mul_(base_magnitude + magnitude * noise_image)
|
||||
|
||||
@@ -264,7 +373,7 @@ class Postprocessor:
|
||||
def analog_lowres(self, image: torch.tensor, *,
|
||||
kernel_size: int = 5,
|
||||
sigma: float = 0.75) -> None:
|
||||
"""Low-resolution analog video signal, simulated by blurring.
|
||||
"""[static] Low-resolution analog video signal, simulated by blurring.
|
||||
|
||||
`kernel_size`: size of the Gaussian blur kernel, in pixels.
|
||||
`sigma`: standard deviation of the Gaussian blur kernel, in pixels.
|
||||
@@ -282,7 +391,7 @@ class Postprocessor:
|
||||
amplitude1: float = 0.001, density1: float = 4.0,
|
||||
amplitude2: Optional[float] = 0.001, density2: Optional[float] = 13.0,
|
||||
amplitude3: Optional[float] = 0.001, density3: Optional[float] = 27.0) -> None:
|
||||
"""Analog video signal with fluctuating hsync.
|
||||
"""[dynamic] Analog video signal with fluctuating hsync.
|
||||
|
||||
We superpose three waves with different densities (1 / cycle length)
|
||||
to make the pattern look more irregular.
|
||||
@@ -295,6 +404,7 @@ class Postprocessor:
|
||||
c, h, w = image.shape
|
||||
|
||||
# Animation
|
||||
# FPS correction happens automatically, because `frame_no` is normalized to CALIBRATION_FPS.
|
||||
cycle_pos = (self.frame_no / h) * speed
|
||||
cycle_pos = cycle_pos - float(int(cycle_pos)) # fractional part
|
||||
cycle_pos = 1.0 - cycle_pos # -> motion from top toward bottom
|
||||
@@ -335,7 +445,7 @@ class Postprocessor:
|
||||
unboost: float = 4.0,
|
||||
max_glitches: int = 3,
|
||||
min_glitch_height: int = 3, max_glitch_height: int = 6) -> None:
|
||||
"""Damaged 1980s VHS video tape, with transient (per-frame) glitching lines.
|
||||
"""[dynamic] Damaged 1980s VHS video tape, with transient (per-frame) glitching lines.
|
||||
|
||||
This leaves the alpha channel alone, so the effect only affects parts that already show something.
|
||||
This is an artistic interpretation that makes the effect less distracting when used with RGBA data.
|
||||
@@ -347,6 +457,7 @@ class Postprocessor:
|
||||
`max_glitches`: Maximum number of glitches in the video frame.
|
||||
`min_glitch_height`, `max_glitch_height`: in pixels. The height is randomized separately for each glitch.
|
||||
"""
|
||||
# TODO: FPS correction for `analog_vhsglitches` (need to store glitching line metadata and noise images)
|
||||
c, h, w = image.shape
|
||||
n_glitches = torch.rand(1, device="cpu")**unboost # higher probability of having none or few glitching lines
|
||||
n_glitches = int(max_glitches * n_glitches[0])
|
||||
@@ -365,7 +476,7 @@ class Postprocessor:
|
||||
base_offset: float = 0.03,
|
||||
max_dynamic_offset: float = 0.01,
|
||||
speed: float = 2.5) -> None:
|
||||
"""1980s VHS tape with bad tracking.
|
||||
"""[dynamic] 1980s VHS tape with bad tracking.
|
||||
|
||||
Image floats up and down, and a band of black and white noise appears at the bottom.
|
||||
|
||||
@@ -374,6 +485,7 @@ class Postprocessor:
|
||||
c, h, w = image.shape
|
||||
|
||||
# Animation
|
||||
# FPS correction happens automatically, because `frame_no` is normalized to CALIBRATION_FPS.
|
||||
cycle_pos = (self.frame_no / h) * speed
|
||||
cycle_pos = cycle_pos - float(int(cycle_pos)) # fractional part
|
||||
cycle_pos *= 2.0 # full cycle = 2 units
|
||||
@@ -424,7 +536,7 @@ class Postprocessor:
|
||||
strength: float = 1.0,
|
||||
tint_rgb: List[float] = [1.0, 1.0, 1.0],
|
||||
bandpass_reference_rgb: List[float] = [1.0, 0.0, 0.0], bandpass_q: float = 0.0) -> None:
|
||||
"""Desaturation with bells and whistles.
|
||||
"""[static] Desaturation with bells and whistles.
|
||||
|
||||
Does not touch the alpha channel.
|
||||
|
||||
@@ -498,7 +610,7 @@ class Postprocessor:
|
||||
strength: float = 0.4,
|
||||
density: float = 2.0,
|
||||
speed: float = 16.0) -> None:
|
||||
"""Bad analog video signal, with traveling brighter and darker bands.
|
||||
"""[dynamic] Bad analog video signal, with traveling brighter and darker bands.
|
||||
|
||||
This simulates a CRT display as it looks when filmed on video without syncing.
|
||||
|
||||
@@ -510,6 +622,7 @@ class Postprocessor:
|
||||
yy = torch.linspace(0, math.pi, h, dtype=image.dtype, device=self.device)
|
||||
|
||||
# Animation
|
||||
# FPS correction happens automatically, because `frame_no` is normalized to CALIBRATION_FPS.
|
||||
cycle_pos = (self.frame_no / h) * speed
|
||||
cycle_pos = cycle_pos - float(int(cycle_pos)) # fractional part
|
||||
cycle_pos = 1.0 - cycle_pos # -> motion from top toward bottom
|
||||
@@ -523,14 +636,14 @@ class Postprocessor:
|
||||
def scanlines(self, image: torch.tensor, *,
|
||||
field: int = 0,
|
||||
dynamic: bool = True) -> None:
|
||||
"""CRT TV like scanlines.
|
||||
"""[dynamic] CRT TV like scanlines.
|
||||
|
||||
`field`: Which CRT field is dimmed at the first frame. 0 = top, 1 = bottom.
|
||||
`dynamic`: If `True`, the dimmed field will alternate each frame (top, bottom, top, bottom, ...)
|
||||
for a more authentic CRT look (like Phosphor deinterlacer in VLC).
|
||||
"""
|
||||
if dynamic:
|
||||
start = (field + self.frame_no) % 2
|
||||
start = (field + int(self.frame_no)) % 2
|
||||
else:
|
||||
start = field
|
||||
# We should ideally modify just the Y channel in YUV space, but modifying the alpha instead looks alright.
|
||||
|
||||
Reference in New Issue
Block a user