Files
ComfyUI/tests-unit/comfy_api_test/video_slice_test.py
bymyself e4f3d335dc feat: Add VideoSlice node with lazy operations on VideoInput
- Add VideoOp base class and SliceOp in _input/video_types.py
- Add sliced() method to VideoInput that returns a copy with operation appended
- Each subclass applies operations in get_components() and get_frame_count()
- After materialization, VideoFromFile delegates to internal VideoFromComponents
- Add VideoSlice node that uses video.sliced(start_frame, frame_count)
- Add tests for SliceOp, sliced() behavior, and materialization
2026-01-23 20:52:15 -08:00

151 lines
5.0 KiB
Python

import pytest
import torch
import tempfile
import os
import av
from fractions import Fraction
from comfy_api.input_impl.video_types import (
VideoFromFile,
VideoFromComponents,
SliceOp,
)
from comfy_api.util.video_types import VideoComponents
def create_test_video(width=4, height=4, frames=10, fps=30):
    """Write a small temporary H.264 ``.mp4`` file and return its path.

    Frame ``i`` is a uniform image whose RGB value is ``i * 25`` (wrapped into
    the uint8 range), converted to yuv420p before encoding.

    Args:
        width: Frame width in pixels.
        height: Frame height in pixels.
        frames: Number of frames to encode.
        fps: Frame rate handed to the encoder stream.

    Returns:
        Path of the written file. The file is not deleted automatically; the
        caller is responsible for removing it.
    """
    # mkstemp returns an open descriptor. Close it immediately so it is not
    # leaked and so PyAV can reopen the path by name on every platform.
    fd, path = tempfile.mkstemp(suffix=".mp4")
    os.close(fd)
    try:
        with av.open(path, mode="w") as container:
            stream = container.add_stream("h264", rate=fps)
            stream.width = width
            stream.height = height
            stream.pix_fmt = "yuv420p"
            for i in range(frames):
                # Wrap explicitly so more than 10 frames does not depend on
                # implicit uint8 overflow behavior.
                level = (i * 25) % 256
                frame_data = torch.full((height, width, 3), level, dtype=torch.uint8)
                frame = av.VideoFrame.from_ndarray(frame_data.numpy(), format="rgb24")
                frame = frame.reformat(format="yuv420p")
                container.mux(stream.encode(frame))
            # Encoding None flushes the packets still buffered in the encoder.
            container.mux(stream.encode(None))
    except BaseException:
        # Do not leave a partial file behind when encoding fails.
        os.unlink(path)
        raise
    return path
@pytest.fixture
def video_file_10_frames():
    """Yield the path of a temporary 10-frame video, removing it at teardown."""
    video_path = create_test_video(frames=10)
    yield video_path
    # Teardown: the helper leaves the file on disk, so remove it here.
    os.remove(video_path)
@pytest.fixture
def video_components_10_frames():
    """Return ``VideoComponents`` built from a random (10, 4, 4, 3) tensor at rate 30."""
    random_frames = torch.rand(10, 4, 4, 3)
    rate = Fraction(30)
    return VideoComponents(images=random_frames, frame_rate=rate)
class TestSliceOp:
    """Tests that exercise ``SliceOp`` directly."""

    def test_apply_slices_correctly(self, video_components_10_frames):
        """A (start=2, count=3) slice yields images equal to source[2:5]."""
        source = video_components_10_frames
        slice_op = SliceOp(start_frame=2, frame_count=3)
        sliced = slice_op.apply(source)
        assert sliced.images.shape[0] == 3
        assert torch.equal(sliced.images, source.images[2:5])

    def test_compute_frame_count(self):
        """A slice fully inside a 10-frame input reports its requested count."""
        slice_op = SliceOp(start_frame=2, frame_count=5)
        reported = slice_op.compute_frame_count(10)
        assert reported == 5

    def test_compute_frame_count_clamps(self):
        """Starting at 8 of 10 frames leaves only 2, despite requesting 5."""
        slice_op = SliceOp(start_frame=8, frame_count=5)
        reported = slice_op.compute_frame_count(10)
        assert reported == 2
class TestVideoSliced:
    """Tests for ``sliced()`` on ``VideoFromComponents`` and ``VideoFromFile``."""

    def test_sliced_returns_new_instance(self, video_components_10_frames):
        """``sliced()`` returns a different object and leaves the source's ops empty."""
        original = VideoFromComponents(video_components_10_frames)
        derived = original.sliced(2, 3)
        assert original is not derived
        assert len(original._operations) == 0
        assert len(derived._operations) == 1

    def test_get_components_applies_operations(self, video_components_10_frames):
        """``get_components()`` on a sliced video returns images equal to source[2:5]."""
        source = video_components_10_frames
        derived = VideoFromComponents(source).sliced(2, 3)
        result = derived.get_components()
        assert result.images.shape[0] == 3
        assert torch.equal(result.images, source.images[2:5])

    def test_get_frame_count(self, video_components_10_frames):
        """A 3-frame slice reports a frame count of 3."""
        derived = VideoFromComponents(video_components_10_frames).sliced(2, 3)
        assert derived.get_frame_count() == 3

    def test_get_duration(self, video_components_10_frames):
        """Three frames at the fixture's frame rate of 30 last about 0.1."""
        derived = VideoFromComponents(video_components_10_frames).sliced(0, 3)
        assert derived.get_duration() == pytest.approx(0.1)

    def test_chained_slices_compose(self, video_components_10_frames):
        """Slicing (2, 6) then (1, 3) yields images equal to source[3:6]."""
        source = video_components_10_frames
        outer = VideoFromComponents(source).sliced(2, 6)
        inner = outer.sliced(1, 3)
        result = inner.get_components()
        assert result.images.shape[0] == 3
        assert torch.equal(result.images, source.images[3:6])

    def test_operations_list_is_immutable(self, video_components_10_frames):
        """Each ``sliced()`` call adds one op to the new object only."""
        base = VideoFromComponents(video_components_10_frames)
        once = base.sliced(0, 5)
        twice = once.sliced(1, 2)
        assert len(base._operations) == 0
        assert len(once._operations) == 1
        assert len(twice._operations) == 2

    def test_from_file(self, video_file_10_frames):
        """Slicing a file-backed video gives 3 frames via components and frame count."""
        derived = VideoFromFile(video_file_10_frames).sliced(2, 3)
        result = derived.get_components()
        assert result.images.shape[0] == 3
        assert derived.get_frame_count() == 3

    def test_save_sliced_video(self, video_components_10_frames, tmp_path):
        """A saved 3-frame slice reloads from disk with a frame count of 3."""
        derived = VideoFromComponents(video_components_10_frames).sliced(2, 3)
        destination = str(tmp_path / "sliced_output.mp4")
        derived.save_to(destination)
        reloaded = VideoFromFile(destination)
        assert reloaded.get_frame_count() == 3

    def test_materialization_clears_ops(self, video_components_10_frames):
        """The pending op list goes from one entry to none after ``get_components()``."""
        derived = VideoFromComponents(video_components_10_frames).sliced(2, 3)
        assert len(derived._operations) == 1
        derived.get_components()
        assert len(derived._operations) == 0

    def test_second_get_components_uses_cache(self, video_components_10_frames):
        """Two consecutive ``get_components()`` calls return equal images."""
        derived = VideoFromComponents(video_components_10_frames).sliced(2, 3)
        first_result = derived.get_components()
        second_result = derived.get_components()
        assert first_result.images.shape == second_result.images.shape
        assert torch.equal(first_result.images, second_result.images)