add FPS correction to VHS glitch filter

This commit is contained in:
Juha Jeronen
2024-01-18 14:10:29 +02:00
parent 5905b87b96
commit 8b13beeb3f

View File

@@ -46,6 +46,8 @@ T = TypeVar("T")
# Scalar value type: one of `str`, `bool`, `int`, or `float`.
Atom = Union[str, bool, int, float]
# Either a bare `T`, a list of `T`, or a dict mapping string keys to `T`.
MaybeContained = Union[T, List[T], Dict[str, T]]
# Unique sentinel object; it is compared by identity (`is` / `is not`), and is used instead of `None`
# so that "cache not yet initialized" and "no glitches this frame" can be told apart.
VHS_GLITCH_BLANK = object() # nonce value meaning the dynamic VHS glitch effect already decided that no glitches should appear during the current frame
class Postprocessor:
"""
`chain`: Postprocessor filter chain configuration.
@@ -116,6 +118,8 @@ class Postprocessor:
# Caches for individual dynamic effects
self.alphanoise_last_image = None
self.vhs_glitch_last_image = None
self.vhs_glitch_last_mask = None
def render_into(self, image):
"""Apply current postprocess chain, modifying `image`."""
@@ -457,20 +461,36 @@ class Postprocessor:
`max_glitches`: Maximum number of glitches in the video frame.
`min_glitch_height`, `max_glitch_height`: in pixels. The height is randomized separately for each glitch.
"""
# TODO: FPS correction for `analog_vhsglitches` (need to store glitching line metadata and noise images)
c, h, w = image.shape
n_glitches = torch.rand(1, device="cpu")**unboost # higher probability of having none or few glitching lines
n_glitches = int(max_glitches * n_glitches[0])
if not n_glitches:
return
glitch_start_lines = torch.rand(n_glitches, device="cpu")
glitch_start_lines = [int((h - (max_glitch_height - 1)) * x) for x in glitch_start_lines]
for line in glitch_start_lines:
glitch_height = torch.rand(1, device="cpu")
glitch_height = int(min_glitch_height + (max_glitch_height - min_glitch_height) * glitch_height[0])
noise_image = self._vhs_noise(image, height=glitch_height)
# Re-randomize the glitch noise image whenever the normalized frame changes
# TODO: Add `hold_min`, `hold_max` parameters (similarly to how blink and sway work) to set how long a set of glitches persists.
# TODO: Especially useful if we add support for multiple copies of the same kind of effect, since then they could have different settings.
if self.vhs_glitch_last_image is None or int(self.frame_no) > int(self.last_frame_no):
n_glitches = torch.rand(1, device="cpu")**unboost # unboost: increase probability of having none or few glitching lines
n_glitches = int(max_glitches * n_glitches[0])
if not n_glitches:
vhs_glitch_image = VHS_GLITCH_BLANK # use a nonce value instead of None to distinguish between "uninitialized" and "no glitches during current frame"
vhs_glitch_mask = None
else:
c, h, w = image.shape
vhs_glitch_image = torch.zeros(1, h, w, dtype=image.dtype, device=self.device) # monochrome
vhs_glitch_mask = torch.zeros(1, h, w, dtype=image.dtype, device=self.device) # alpha only
glitch_start_lines = torch.rand(n_glitches, device="cpu")
glitch_start_lines = [int((h - (max_glitch_height - 1)) * x) for x in glitch_start_lines]
for line in glitch_start_lines:
glitch_height = torch.rand(1, device="cpu")
glitch_height = int(min_glitch_height + (max_glitch_height - min_glitch_height) * glitch_height[0])
vhs_glitch_image[0, line:(line + glitch_height), :] = self._vhs_noise(image, height=glitch_height) # [1, h, w]
vhs_glitch_mask[0, line:(line + glitch_height), :] = 1.0
self.vhs_glitch_last_image = vhs_glitch_image
self.vhs_glitch_last_mask = vhs_glitch_mask
else:
vhs_glitch_image = self.vhs_glitch_last_image
vhs_glitch_mask = self.vhs_glitch_last_mask
if vhs_glitch_image is not VHS_GLITCH_BLANK:
# Apply glitch to RGB only, so fully transparent parts stay transparent (important to make the effect less distracting).
image[:3, line:(line + glitch_height), :] = (1.0 - strength) * image[:3, line:(line + glitch_height), :] + strength * noise_image
strength_field = strength * vhs_glitch_mask # "field" as in physics, NOT as in CRT TV
image[:3, :, :] = (1.0 - strength_field) * image[:3, :, :] + strength_field * vhs_glitch_image
def analog_vhstracking(self, image: torch.tensor, *,
base_offset: float = 0.03,
@@ -593,7 +613,7 @@ class Postprocessor:
# - As the hue difference approaches zero, the pixel is fully passed through.
# - The 1.0 - ... together with the square makes a sharp spike at the reference hue.
desat_diff2 = (1.0 - torch.clamp(desat_hue_distance / bandpass_q, max=1.0))**2
strength_field = strength * (1.0 - desat_diff2) # [h, w]
strength_field = strength * (1.0 - desat_diff2) # [h, w]; "field" as in physics, NOT as in CRT TV
else:
strength_field = strength # just a scalar!