From e4f3d335dceb741bd016af15d1c8151362952bf9 Mon Sep 17 00:00:00 2001 From: bymyself Date: Fri, 23 Jan 2026 20:32:57 -0800 Subject: [PATCH] feat: Add VideoSlice node with lazy operations on VideoInput - Add VideoOp base class and SliceOp in _input/video_types.py - Add sliced() method to VideoInput that returns a copy with operation appended - Each subclass applies operations in get_components() and get_frame_count() - After materialization, VideoFromFile delegates to internal VideoFromComponents - Add VideoSlice node that uses video.sliced(start_frame, frame_count) - Add tests for SliceOp, sliced() behavior, and materialization --- comfy_api/latest/_input/__init__.py | 4 +- comfy_api/latest/_input/video_types.py | 43 +++++ comfy_api/latest/_input_impl/__init__.py | 3 +- comfy_api/latest/_input_impl/video_types.py | 52 ++++-- comfy_extras/nodes_video.py | 24 +++ tests-unit/comfy_api_test/video_slice_test.py | 150 ++++++++++++++++++ 6 files changed, 263 insertions(+), 13 deletions(-) create mode 100644 tests-unit/comfy_api_test/video_slice_test.py diff --git a/comfy_api/latest/_input/__init__.py b/comfy_api/latest/_input/__init__.py index 14f0e72f4..e51c97372 100644 --- a/comfy_api/latest/_input/__init__.py +++ b/comfy_api/latest/_input/__init__.py @@ -1,10 +1,12 @@ from .basic_types import ImageInput, AudioInput, MaskInput, LatentInput -from .video_types import VideoInput +from .video_types import VideoInput, VideoOp, SliceOp __all__ = [ "ImageInput", "AudioInput", "VideoInput", + "VideoOp", + "SliceOp", "MaskInput", "LatentInput", ] diff --git a/comfy_api/latest/_input/video_types.py b/comfy_api/latest/_input/video_types.py index e634a0311..3dacad178 100644 --- a/comfy_api/latest/_input/video_types.py +++ b/comfy_api/latest/_input/video_types.py @@ -1,11 +1,48 @@ from __future__ import annotations from abc import ABC, abstractmethod +from dataclasses import dataclass from fractions import Fraction from typing import Optional, Union, IO +import copy import io 
class VideoOp(ABC):
    """Base class for lazy video operations.

    An operation describes a transformation of a video's components.
    ``apply`` performs it on decoded components, while
    ``compute_frame_count`` predicts the resulting length from the input
    length alone, without touching any frame data.
    """

    @abstractmethod
    def apply(self, components: "VideoComponents") -> "VideoComponents":
        """Return new components with this operation applied to ``components``."""

    @abstractmethod
    def compute_frame_count(self, input_frame_count: int) -> int:
        """Return the number of frames produced for an input of ``input_frame_count`` frames."""


@dataclass(frozen=True)
class SliceOp(VideoOp):
    """Extract a contiguous range of frames from the video.

    Out-of-range values are clamped instead of rejected: a start before the
    first frame begins at frame 0, a start past the end yields an empty
    slice, and a negative ``frame_count`` is treated as 0.
    """
    # Index of the first frame to keep (0-indexed).
    start_frame: int
    # Maximum number of frames to keep; fewer are kept if the input is shorter.
    frame_count: int

    def _frame_range(self, total: int) -> "tuple[int, int]":
        """Return the clamped half-open ``(start, end)`` frame range for ``total`` input frames."""
        start = max(0, min(self.start_frame, total))
        # Clamp the requested length at 0 so that ``end`` can never fall below
        # ``start``. A negative ``end`` would otherwise be interpreted as an
        # index counted from the back when slicing, and ``apply`` would
        # disagree with ``compute_frame_count``.
        end = min(start + max(0, self.frame_count), total)
        return start, end

    @staticmethod
    def _slice_audio(audio, start: int, end: int, frame_rate):
        """Return ``audio`` trimmed to the samples spanned by frames ``[start, end)``.

        ``audio`` is expected to be a mapping with a ``'waveform'`` whose last
        axis indexes samples and a ``'sample_rate'``. Anything else (``None``,
        an empty mapping, missing keys, a falsy frame rate) is returned
        unchanged. The input mapping is never mutated.
        """
        import math  # local import: only needed on this path

        if not audio or not frame_rate:
            return audio
        try:
            waveform = audio['waveform']
            sample_rate = int(audio['sample_rate'])
        except (KeyError, TypeError):
            return audio

        # Exact rational arithmetic avoids float drift on long videos.
        samples_per_frame = Fraction(sample_rate) / Fraction(frame_rate)
        first_sample = math.floor(start * samples_per_frame)
        # Round the end up so the last kept frame is fully covered by audio.
        last_sample = math.ceil(end * samples_per_frame)
        return {**audio, 'waveform': waveform[..., first_sample:last_sample]}

    def apply(self, components: "VideoComponents") -> "VideoComponents":
        """Return components restricted to the selected frames and their matching audio."""
        start, end = self._frame_range(components.images.shape[0])
        return VideoComponents(
            images=components.images[start:end],
            # Keep audio aligned with the kept frames; passing it through
            # untouched would leave a slice that starts at frame N playing
            # the audio from time 0.
            audio=self._slice_audio(components.audio, start, end, components.frame_rate),
            frame_rate=components.frame_rate,
            metadata=getattr(components, 'metadata', None),
        )

    def compute_frame_count(self, input_frame_count: int) -> int:
        """Return how many frames ``apply`` keeps for an input of ``input_frame_count`` frames."""
        start, end = self._frame_range(input_frame_count)
        return end - start
@@ -21,6 +58,12 @@ class VideoInput(ABC): """ pass + def sliced(self, start_frame: int, frame_count: int) -> "VideoInput": + """Return a copy of this video with a slice operation appended.""" + new = copy.copy(self) + new._operations = getattr(self, '_operations', []) + [SliceOp(start_frame, frame_count)] + return new + @abstractmethod def save_to( self, diff --git a/comfy_api/latest/_input_impl/__init__.py b/comfy_api/latest/_input_impl/__init__.py index 02901b8b9..fe8e0bf35 100644 --- a/comfy_api/latest/_input_impl/__init__.py +++ b/comfy_api/latest/_input_impl/__init__.py @@ -1,7 +1,8 @@ from .video_types import VideoFromFile, VideoFromComponents +from .._input import SliceOp __all__ = [ - # Implementations "VideoFromFile", "VideoFromComponents", + "SliceOp", ] diff --git a/comfy_api/latest/_input_impl/video_types.py b/comfy_api/latest/_input_impl/video_types.py index 1405d0b81..709d53e60 100644 --- a/comfy_api/latest/_input_impl/video_types.py +++ b/comfy_api/latest/_input_impl/video_types.py @@ -3,7 +3,7 @@ from av.container import InputContainer from av.subtitles.stream import SubtitleStream from fractions import Fraction from typing import Optional -from .._input import AudioInput, VideoInput +from .._input import AudioInput, VideoInput, VideoOp import av import io import json @@ -63,6 +63,8 @@ class VideoFromFile(VideoInput): containing the file contents. 
""" self.__file = file + self._operations: list[VideoOp] = [] + self.__materialized: Optional[VideoFromComponents] = None def get_stream_source(self) -> str | io.BytesIO: """ @@ -161,6 +163,10 @@ class VideoFromFile(VideoInput): if frame_count == 0: raise ValueError(f"Could not determine frame count for file '{self.__file}'") + + # Apply operations to get final frame count + for op in self._operations: + frame_count = op.compute_frame_count(frame_count) return frame_count def get_frame_rate(self) -> Fraction: @@ -239,10 +245,18 @@ class VideoFromFile(VideoInput): return VideoComponents(images=images, audio=audio, frame_rate=frame_rate, metadata=metadata) def get_components(self) -> VideoComponents: + if self.__materialized is not None: + return self.__materialized.get_components() + if isinstance(self.__file, io.BytesIO): self.__file.seek(0) # Reset the BytesIO object to the beginning with av.open(self.__file, mode='r') as container: - return self.get_components_internal(container) + components = self.get_components_internal(container) + for op in self._operations: + components = op.apply(components) + self.__materialized = VideoFromComponents(components) + self._operations = [] + return components raise ValueError(f"No video stream found in file '{self.__file}'") def save_to( @@ -317,14 +331,27 @@ class VideoFromComponents(VideoInput): def __init__(self, components: VideoComponents): self.__components = components + self._operations: list[VideoOp] = [] def get_components(self) -> VideoComponents: + if self._operations: + components = self.__components + for op in self._operations: + components = op.apply(components) + self.__components = components + self._operations = [] return VideoComponents( images=self.__components.images, audio=self.__components.audio, frame_rate=self.__components.frame_rate ) + def get_frame_count(self) -> int: + count = int(self.__components.images.shape[0]) + for op in self._operations: + count = op.compute_frame_count(count) + return 
count + def save_to( self, path: str, @@ -332,6 +359,9 @@ class VideoFromComponents(VideoInput): codec: VideoCodec = VideoCodec.AUTO, metadata: Optional[dict] = None ): + # Materialize ops before saving + components = self.get_components() + if format != VideoContainer.AUTO and format != VideoContainer.MP4: raise ValueError("Only MP4 format is supported for now") if codec != VideoCodec.AUTO and codec != VideoCodec.H264: @@ -345,22 +375,22 @@ class VideoFromComponents(VideoInput): for key, value in metadata.items(): output.metadata[key] = json.dumps(value) - frame_rate = Fraction(round(self.__components.frame_rate * 1000), 1000) + frame_rate = Fraction(round(components.frame_rate * 1000), 1000) # Create a video stream video_stream = output.add_stream('h264', rate=frame_rate) - video_stream.width = self.__components.images.shape[2] - video_stream.height = self.__components.images.shape[1] + video_stream.width = components.images.shape[2] + video_stream.height = components.images.shape[1] video_stream.pix_fmt = 'yuv420p' # Create an audio stream audio_sample_rate = 1 audio_stream: Optional[av.AudioStream] = None - if self.__components.audio: - audio_sample_rate = int(self.__components.audio['sample_rate']) + if components.audio: + audio_sample_rate = int(components.audio['sample_rate']) audio_stream = output.add_stream('aac', rate=audio_sample_rate) # Encode video - for i, frame in enumerate(self.__components.images): + for i, frame in enumerate(components.images): img = (frame * 255).clamp(0, 255).byte().cpu().numpy() # shape: (H, W, 3) frame = av.VideoFrame.from_ndarray(img, format='rgb24') frame = frame.reformat(format='yuv420p') # Convert to YUV420P as required by h264 @@ -371,9 +401,9 @@ class VideoFromComponents(VideoInput): packet = video_stream.encode(None) output.mux(packet) - if audio_stream and self.__components.audio: - waveform = self.__components.audio['waveform'] - waveform = waveform[:, :, :math.ceil((audio_sample_rate / frame_rate) * 
self.__components.images.shape[0])] + if audio_stream and components.audio: + waveform = components.audio['waveform'] + waveform = waveform[:, :, :math.ceil((audio_sample_rate / frame_rate) * components.images.shape[0])] frame = av.AudioFrame.from_ndarray(waveform.movedim(2, 1).reshape(1, -1).float().cpu().numpy(), format='flt', layout='mono' if waveform.shape[1] == 1 else 'stereo') frame.sample_rate = audio_sample_rate frame.pts = 0 diff --git a/comfy_extras/nodes_video.py b/comfy_extras/nodes_video.py index c609e03da..dc3a77be7 100644 --- a/comfy_extras/nodes_video.py +++ b/comfy_extras/nodes_video.py @@ -159,6 +159,29 @@ class GetVideoComponents(io.ComfyNode): return io.NodeOutput(components.images, components.audio, float(components.frame_rate)) +class VideoSlice(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="VideoSlice", + display_name="Video Slice", + category="image/video", + description="Extract a range of frames from a video.", + inputs=[ + io.Video.Input("video", tooltip="The video to slice."), + io.Int.Input("start_frame", default=0, min=0, tooltip="The frame index to start from (0-indexed)."), + io.Int.Input("frame_count", default=1, min=1, tooltip="Number of frames to extract."), + ], + outputs=[ + io.Video.Output(tooltip="The sliced video."), + ], + ) + + @classmethod + def execute(cls, video: Input.Video, start_frame: int, frame_count: int) -> io.NodeOutput: + return io.NodeOutput(video.sliced(start_frame, frame_count)) + + class LoadVideo(io.ComfyNode): @classmethod def define_schema(cls): @@ -206,6 +229,7 @@ class VideoExtension(ComfyExtension): SaveVideo, CreateVideo, GetVideoComponents, + VideoSlice, LoadVideo, ] diff --git a/tests-unit/comfy_api_test/video_slice_test.py b/tests-unit/comfy_api_test/video_slice_test.py new file mode 100644 index 000000000..b6d97ef3a --- /dev/null +++ b/tests-unit/comfy_api_test/video_slice_test.py @@ -0,0 +1,150 @@ +import pytest +import torch +import tempfile +import os 
+import av +from fractions import Fraction +from comfy_api.input_impl.video_types import ( + VideoFromFile, + VideoFromComponents, + SliceOp, +) +from comfy_api.util.video_types import VideoComponents + + +def create_test_video(width=4, height=4, frames=10, fps=30): + """Helper to create a temporary video file.""" + tmp = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) + with av.open(tmp.name, mode="w") as container: + stream = container.add_stream("h264", rate=fps) + stream.width = width + stream.height = height + stream.pix_fmt = "yuv420p" + + for i in range(frames): + frame_data = torch.ones(height, width, 3, dtype=torch.uint8) * (i * 25) + frame = av.VideoFrame.from_ndarray(frame_data.numpy(), format="rgb24") + frame = frame.reformat(format="yuv420p") + packet = stream.encode(frame) + container.mux(packet) + + packet = stream.encode(None) + container.mux(packet) + + return tmp.name + + +@pytest.fixture +def video_file_10_frames(): + file_path = create_test_video(frames=10) + yield file_path + os.unlink(file_path) + + +@pytest.fixture +def video_components_10_frames(): + images = torch.rand(10, 4, 4, 3) + return VideoComponents(images=images, frame_rate=Fraction(30)) + + +class TestSliceOp: + def test_apply_slices_correctly(self, video_components_10_frames): + op = SliceOp(start_frame=2, frame_count=3) + result = op.apply(video_components_10_frames) + + assert result.images.shape[0] == 3 + assert torch.equal(result.images, video_components_10_frames.images[2:5]) + + def test_compute_frame_count(self): + op = SliceOp(start_frame=2, frame_count=5) + assert op.compute_frame_count(10) == 5 + + def test_compute_frame_count_clamps(self): + op = SliceOp(start_frame=8, frame_count=5) + assert op.compute_frame_count(10) == 2 + + +class TestVideoSliced: + def test_sliced_returns_new_instance(self, video_components_10_frames): + video = VideoFromComponents(video_components_10_frames) + sliced = video.sliced(2, 3) + + assert video is not sliced + assert 
len(video._operations) == 0 + assert len(sliced._operations) == 1 + + def test_get_components_applies_operations(self, video_components_10_frames): + video = VideoFromComponents(video_components_10_frames) + sliced = video.sliced(2, 3) + + components = sliced.get_components() + + assert components.images.shape[0] == 3 + assert torch.equal(components.images, video_components_10_frames.images[2:5]) + + def test_get_frame_count(self, video_components_10_frames): + video = VideoFromComponents(video_components_10_frames) + sliced = video.sliced(2, 3) + + assert sliced.get_frame_count() == 3 + + def test_get_duration(self, video_components_10_frames): + video = VideoFromComponents(video_components_10_frames) + sliced = video.sliced(0, 3) + + assert sliced.get_duration() == pytest.approx(0.1) + + def test_chained_slices_compose(self, video_components_10_frames): + video = VideoFromComponents(video_components_10_frames) + sliced = video.sliced(2, 6).sliced(1, 3) + + components = sliced.get_components() + + assert components.images.shape[0] == 3 + assert torch.equal(components.images, video_components_10_frames.images[3:6]) + + def test_operations_list_is_immutable(self, video_components_10_frames): + video = VideoFromComponents(video_components_10_frames) + sliced1 = video.sliced(0, 5) + sliced2 = sliced1.sliced(1, 2) + + assert len(video._operations) == 0 + assert len(sliced1._operations) == 1 + assert len(sliced2._operations) == 2 + + def test_from_file(self, video_file_10_frames): + video = VideoFromFile(video_file_10_frames) + sliced = video.sliced(2, 3) + + components = sliced.get_components() + + assert components.images.shape[0] == 3 + assert sliced.get_frame_count() == 3 + + def test_save_sliced_video(self, video_components_10_frames, tmp_path): + video = VideoFromComponents(video_components_10_frames) + sliced = video.sliced(2, 3) + + output_path = str(tmp_path / "sliced_output.mp4") + sliced.save_to(output_path) + + saved_video = VideoFromFile(output_path) + 
assert saved_video.get_frame_count() == 3 + + def test_materialization_clears_ops(self, video_components_10_frames): + video = VideoFromComponents(video_components_10_frames) + sliced = video.sliced(2, 3) + + assert len(sliced._operations) == 1 + sliced.get_components() + assert len(sliced._operations) == 0 + + def test_second_get_components_uses_cache(self, video_components_10_frames): + video = VideoFromComponents(video_components_10_frames) + sliced = video.sliced(2, 3) + + first = sliced.get_components() + second = sliced.get_components() + + assert first.images.shape == second.images.shape + assert torch.equal(first.images, second.images)