mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-12 08:40:02 +00:00
feat: Add VideoSlice node with lazy operations on VideoInput
- Add VideoOp base class and SliceOp in _input/video_types.py - Add sliced() method to VideoInput that returns a copy with operation appended - Each subclass applies operations in get_components() and get_frame_count() - After materialization, VideoFromFile delegates to internal VideoFromComponents - Add VideoSlice node that uses video.sliced(start_frame, frame_count) - Add tests for SliceOp, sliced() behavior, and materialization
This commit is contained in:
150
tests-unit/comfy_api_test/video_slice_test.py
Normal file
150
tests-unit/comfy_api_test/video_slice_test.py
Normal file
@@ -0,0 +1,150 @@
|
||||
import pytest
|
||||
import torch
|
||||
import tempfile
|
||||
import os
|
||||
import av
|
||||
from fractions import Fraction
|
||||
from comfy_api.input_impl.video_types import (
|
||||
VideoFromFile,
|
||||
VideoFromComponents,
|
||||
SliceOp,
|
||||
)
|
||||
from comfy_api.util.video_types import VideoComponents
|
||||
|
||||
|
||||
def create_test_video(width=4, height=4, frames=10, fps=30):
|
||||
"""Helper to create a temporary video file."""
|
||||
tmp = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
|
||||
with av.open(tmp.name, mode="w") as container:
|
||||
stream = container.add_stream("h264", rate=fps)
|
||||
stream.width = width
|
||||
stream.height = height
|
||||
stream.pix_fmt = "yuv420p"
|
||||
|
||||
for i in range(frames):
|
||||
frame_data = torch.ones(height, width, 3, dtype=torch.uint8) * (i * 25)
|
||||
frame = av.VideoFrame.from_ndarray(frame_data.numpy(), format="rgb24")
|
||||
frame = frame.reformat(format="yuv420p")
|
||||
packet = stream.encode(frame)
|
||||
container.mux(packet)
|
||||
|
||||
packet = stream.encode(None)
|
||||
container.mux(packet)
|
||||
|
||||
return tmp.name
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def video_file_10_frames():
|
||||
file_path = create_test_video(frames=10)
|
||||
yield file_path
|
||||
os.unlink(file_path)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def video_components_10_frames():
|
||||
images = torch.rand(10, 4, 4, 3)
|
||||
return VideoComponents(images=images, frame_rate=Fraction(30))
|
||||
|
||||
|
||||
class TestSliceOp:
|
||||
def test_apply_slices_correctly(self, video_components_10_frames):
|
||||
op = SliceOp(start_frame=2, frame_count=3)
|
||||
result = op.apply(video_components_10_frames)
|
||||
|
||||
assert result.images.shape[0] == 3
|
||||
assert torch.equal(result.images, video_components_10_frames.images[2:5])
|
||||
|
||||
def test_compute_frame_count(self):
|
||||
op = SliceOp(start_frame=2, frame_count=5)
|
||||
assert op.compute_frame_count(10) == 5
|
||||
|
||||
def test_compute_frame_count_clamps(self):
|
||||
op = SliceOp(start_frame=8, frame_count=5)
|
||||
assert op.compute_frame_count(10) == 2
|
||||
|
||||
|
||||
class TestVideoSliced:
|
||||
def test_sliced_returns_new_instance(self, video_components_10_frames):
|
||||
video = VideoFromComponents(video_components_10_frames)
|
||||
sliced = video.sliced(2, 3)
|
||||
|
||||
assert video is not sliced
|
||||
assert len(video._operations) == 0
|
||||
assert len(sliced._operations) == 1
|
||||
|
||||
def test_get_components_applies_operations(self, video_components_10_frames):
|
||||
video = VideoFromComponents(video_components_10_frames)
|
||||
sliced = video.sliced(2, 3)
|
||||
|
||||
components = sliced.get_components()
|
||||
|
||||
assert components.images.shape[0] == 3
|
||||
assert torch.equal(components.images, video_components_10_frames.images[2:5])
|
||||
|
||||
def test_get_frame_count(self, video_components_10_frames):
|
||||
video = VideoFromComponents(video_components_10_frames)
|
||||
sliced = video.sliced(2, 3)
|
||||
|
||||
assert sliced.get_frame_count() == 3
|
||||
|
||||
def test_get_duration(self, video_components_10_frames):
|
||||
video = VideoFromComponents(video_components_10_frames)
|
||||
sliced = video.sliced(0, 3)
|
||||
|
||||
assert sliced.get_duration() == pytest.approx(0.1)
|
||||
|
||||
def test_chained_slices_compose(self, video_components_10_frames):
|
||||
video = VideoFromComponents(video_components_10_frames)
|
||||
sliced = video.sliced(2, 6).sliced(1, 3)
|
||||
|
||||
components = sliced.get_components()
|
||||
|
||||
assert components.images.shape[0] == 3
|
||||
assert torch.equal(components.images, video_components_10_frames.images[3:6])
|
||||
|
||||
def test_operations_list_is_immutable(self, video_components_10_frames):
|
||||
video = VideoFromComponents(video_components_10_frames)
|
||||
sliced1 = video.sliced(0, 5)
|
||||
sliced2 = sliced1.sliced(1, 2)
|
||||
|
||||
assert len(video._operations) == 0
|
||||
assert len(sliced1._operations) == 1
|
||||
assert len(sliced2._operations) == 2
|
||||
|
||||
def test_from_file(self, video_file_10_frames):
|
||||
video = VideoFromFile(video_file_10_frames)
|
||||
sliced = video.sliced(2, 3)
|
||||
|
||||
components = sliced.get_components()
|
||||
|
||||
assert components.images.shape[0] == 3
|
||||
assert sliced.get_frame_count() == 3
|
||||
|
||||
def test_save_sliced_video(self, video_components_10_frames, tmp_path):
|
||||
video = VideoFromComponents(video_components_10_frames)
|
||||
sliced = video.sliced(2, 3)
|
||||
|
||||
output_path = str(tmp_path / "sliced_output.mp4")
|
||||
sliced.save_to(output_path)
|
||||
|
||||
saved_video = VideoFromFile(output_path)
|
||||
assert saved_video.get_frame_count() == 3
|
||||
|
||||
def test_materialization_clears_ops(self, video_components_10_frames):
|
||||
video = VideoFromComponents(video_components_10_frames)
|
||||
sliced = video.sliced(2, 3)
|
||||
|
||||
assert len(sliced._operations) == 1
|
||||
sliced.get_components()
|
||||
assert len(sliced._operations) == 0
|
||||
|
||||
def test_second_get_components_uses_cache(self, video_components_10_frames):
|
||||
video = VideoFromComponents(video_components_10_frames)
|
||||
sliced = video.sliced(2, 3)
|
||||
|
||||
first = sliced.get_components()
|
||||
second = sliced.get_components()
|
||||
|
||||
assert first.images.shape == second.images.shape
|
||||
assert torch.equal(first.images, second.images)
|
||||
Reference in New Issue
Block a user