diff --git a/talkinghead/tha3/app/postprocessor.py b/talkinghead/tha3/app/postprocessor.py index a00991e..be8b0da 100644 --- a/talkinghead/tha3/app/postprocessor.py +++ b/talkinghead/tha3/app/postprocessor.py @@ -46,6 +46,8 @@ T = TypeVar("T") Atom = Union[str, bool, int, float] MaybeContained = Union[T, List[T], Dict[str, T]] +VHS_GLITCH_BLANK = object() # nonce value meaning the dynamic VHS glitch effect already decided that no glitches should appear during the current frame + class Postprocessor: """ `chain`: Postprocessor filter chain configuration. @@ -116,6 +118,8 @@ class Postprocessor: # Caches for individual dynamic effects self.alphanoise_last_image = None + self.vhs_glitch_last_image = None + self.vhs_glitch_last_mask = None def render_into(self, image): """Apply current postprocess chain, modifying `image`.""" @@ -457,20 +461,36 @@ class Postprocessor: `max_glitches`: Maximum number of glitches in the video frame. `min_glitch_height`, `max_glitch_height`: in pixels. The height is randomized separately for each glitch. """ - # TODO: FPS correction for `analog_vhsglitches` (need to store glitching line metadata and noise images) - c, h, w = image.shape - n_glitches = torch.rand(1, device="cpu")**unboost # higher probability of having none or few glitching lines - n_glitches = int(max_glitches * n_glitches[0]) - if not n_glitches: - return - glitch_start_lines = torch.rand(n_glitches, device="cpu") - glitch_start_lines = [int((h - (max_glitch_height - 1)) * x) for x in glitch_start_lines] - for line in glitch_start_lines: - glitch_height = torch.rand(1, device="cpu") - glitch_height = int(min_glitch_height + (max_glitch_height - min_glitch_height) * glitch_height[0]) - noise_image = self._vhs_noise(image, height=glitch_height) + # Re-randomize the glitch noise image whenever the normalized frame changes + # TODO: Add `hold_min`, `hold_max` parameters (similarly to how blink and sway work) to set how long a set of glitches persists. + # TODO: Especially useful if we add support for multiple copies of the same kind of effect, since then they could have different settings. + if self.vhs_glitch_last_image is None or int(self.frame_no) > int(self.last_frame_no): + n_glitches = torch.rand(1, device="cpu")**unboost # unboost: increase probability of having none or few glitching lines + n_glitches = int(max_glitches * n_glitches[0]) + if not n_glitches: + vhs_glitch_image = VHS_GLITCH_BLANK # use a nonce value instead of None to distinguish between "uninitialized" and "no glitches during current frame" + vhs_glitch_mask = None + else: + c, h, w = image.shape + vhs_glitch_image = torch.zeros(1, h, w, dtype=image.dtype, device=self.device) # monochrome + vhs_glitch_mask = torch.zeros(1, h, w, dtype=image.dtype, device=self.device) # alpha only + glitch_start_lines = torch.rand(n_glitches, device="cpu") + glitch_start_lines = [int((h - (max_glitch_height - 1)) * x) for x in glitch_start_lines] + for line in glitch_start_lines: + glitch_height = torch.rand(1, device="cpu") + glitch_height = int(min_glitch_height + (max_glitch_height - min_glitch_height) * glitch_height[0]) + vhs_glitch_image[0, line:(line + glitch_height), :] = self._vhs_noise(image, height=glitch_height) # [1, h, w] + vhs_glitch_mask[0, line:(line + glitch_height), :] = 1.0 + self.vhs_glitch_last_image = vhs_glitch_image + self.vhs_glitch_last_mask = vhs_glitch_mask + else: + vhs_glitch_image = self.vhs_glitch_last_image + vhs_glitch_mask = self.vhs_glitch_last_mask + + if vhs_glitch_image is not VHS_GLITCH_BLANK: # Apply glitch to RGB only, so fully transparent parts stay transparent (important to make the effect less distracting). - image[:3, line:(line + glitch_height), :] = (1.0 - strength) * image[:3, line:(line + glitch_height), :] + strength * noise_image + strength_field = strength * vhs_glitch_mask # "field" as in physics, NOT as in CRT TV + image[:3, :, :] = (1.0 - strength_field) * image[:3, :, :] + strength_field * vhs_glitch_image def analog_vhstracking(self, image: torch.tensor, *, base_offset: float = 0.03, @@ -593,7 +613,7 @@ class Postprocessor: # - As the hue difference approaches zero, the pixel is fully passed through. # - The 1.0 - ... together with the square makes a sharp spike at the reference hue. desat_diff2 = (1.0 - torch.clamp(desat_hue_distance / bandpass_q, max=1.0))**2 - strength_field = strength * (1.0 - desat_diff2) # [h, w] + strength_field = strength * (1.0 - desat_diff2) # [h, w]; "field" as in physics, NOT as in CRT TV else: strength_field = strength # just a scalar!