mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-27 03:29:51 +00:00
- Add a `VideoOp` base class and `SliceOp` in `_input/video_types.py`.
- Add a `sliced()` method to `VideoInput` that returns a copy with the operation appended.
- Each subclass applies pending operations in `get_components()` and `get_frame_count()`.
- After materialization, `VideoFromFile` delegates to an internal `VideoFromComponents`.
- Add a `VideoSlice` node that calls `video.sliced(start_frame, frame_count)`.
- Add tests for `SliceOp`, `sliced()` behavior, and materialization.
151 lines
5.0 KiB
Python
151 lines
5.0 KiB
Python
import pytest
|
|
import torch
|
|
import tempfile
|
|
import os
|
|
import av
|
|
from fractions import Fraction
|
|
from comfy_api.input_impl.video_types import (
|
|
VideoFromFile,
|
|
VideoFromComponents,
|
|
SliceOp,
|
|
)
|
|
from comfy_api.util.video_types import VideoComponents
|
|
|
|
|
|
def create_test_video(width=4, height=4, frames=10, fps=30):
    """Encode a small synthetic H.264 ``.mp4`` file and return its path.

    Frame ``i`` is a uniform image whose RGB value is ``(i * 25) % 256``, so the
    first 11 frames have distinct brightness levels after decoding.

    Args:
        width: Frame width in pixels (yuv420p needs an even value).
        height: Frame height in pixels (yuv420p needs an even value).
        frames: Number of frames to encode.
        fps: Frame rate given to the H.264 stream.

    Returns:
        Path of the created file. It is created with ``delete=False``, so the
        caller is responsible for removing it.
    """
    # Only reserve a unique file name here and close our handle immediately:
    # PyAV reopens the path itself, and an open handle would leak a descriptor
    # (and block the reopen on Windows).
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
        path = tmp.name

    try:
        with av.open(path, mode="w") as container:
            stream = container.add_stream("h264", rate=fps)
            stream.width = width
            stream.height = height
            stream.pix_fmt = "yuv420p"

            for i in range(frames):
                # Wrap explicitly so the fill value always fits in uint8,
                # whatever the frame count.
                value = (i * 25) % 256
                frame_data = torch.full((height, width, 3), value, dtype=torch.uint8)
                frame = av.VideoFrame.from_ndarray(frame_data.numpy(), format="rgb24")
                frame = frame.reformat(format="yuv420p")
                container.mux(stream.encode(frame))

            # Encoding None flushes frames still buffered inside the encoder.
            container.mux(stream.encode(None))
    except BaseException:
        # Best-effort cleanup of the partially written file; the original
        # error is re-raised either way.
        try:
            os.unlink(path)
        except OSError:
            pass
        raise

    return path
|
|
|
|
|
|
@pytest.fixture
def video_file_10_frames():
    """Provide the path of a freshly encoded 10-frame test video.

    The file is deleted during fixture teardown, after the test finishes.
    """
    path = create_test_video(frames=10)
    yield path
    # Teardown: reached once the consuming test has completed.
    os.unlink(path)
|
|
|
|
|
|
@pytest.fixture
def video_components_10_frames():
    """Provide in-memory ``VideoComponents`` for a 10-frame clip at 30 fps.

    The images tensor has shape (10, 4, 4, 3) and holds uniform random values
    in [0, 1), so every frame is (almost surely) distinct.
    """
    frames = torch.rand(10, 4, 4, 3)
    rate = Fraction(30)
    return VideoComponents(images=frames, frame_rate=rate)
|
|
|
|
|
|
class TestSliceOp:
    """Exercise ``SliceOp`` on its own, without going through a video object."""

    def test_apply_slices_correctly(self, video_components_10_frames):
        # Three frames starting at index 2 should be exactly frames [2, 5).
        sliced = SliceOp(start_frame=2, frame_count=3).apply(video_components_10_frames)
        expected = video_components_10_frames.images[2:5]

        assert sliced.images.shape[0] == 3
        assert torch.equal(sliced.images, expected)

    def test_compute_frame_count(self):
        # 2 + 5 <= 10, so the requested count fits and is returned unchanged.
        assert SliceOp(start_frame=2, frame_count=5).compute_frame_count(10) == 5

    def test_compute_frame_count_clamps(self):
        # Only frames 8 and 9 remain after the start, so 5 is clamped to 2.
        assert SliceOp(start_frame=8, frame_count=5).compute_frame_count(10) == 2
|
|
|
|
|
|
class TestVideoSliced:
    """Tests for ``sliced()`` on video inputs and the lazy operation pipeline."""

    def test_sliced_returns_new_instance(self, video_components_10_frames):
        """``sliced()`` returns a new object and leaves the source untouched."""
        video = VideoFromComponents(video_components_10_frames)
        sliced = video.sliced(2, 3)

        assert video is not sliced
        assert len(video._operations) == 0
        assert len(sliced._operations) == 1

    def test_get_components_applies_operations(self, video_components_10_frames):
        """Pending slice operations are applied when components are requested."""
        video = VideoFromComponents(video_components_10_frames)
        sliced = video.sliced(2, 3)

        components = sliced.get_components()

        assert components.images.shape[0] == 3
        assert torch.equal(components.images, video_components_10_frames.images[2:5])

    def test_get_frame_count(self, video_components_10_frames):
        """The reported frame count reflects the pending slice."""
        video = VideoFromComponents(video_components_10_frames)
        sliced = video.sliced(2, 3)

        assert sliced.get_frame_count() == 3

    def test_get_duration(self, video_components_10_frames):
        """Three frames at 30 fps last 0.1 seconds."""
        video = VideoFromComponents(video_components_10_frames)
        sliced = video.sliced(0, 3)

        assert sliced.get_duration() == pytest.approx(0.1)

    def test_chained_slices_compose(self, video_components_10_frames):
        """Slice offsets are relative to the previous slice, not the source."""
        video = VideoFromComponents(video_components_10_frames)
        # [2, 8) followed by [1, 4) of that window selects source frames [3, 6).
        sliced = video.sliced(2, 6).sliced(1, 3)

        components = sliced.get_components()

        assert components.images.shape[0] == 3
        assert torch.equal(components.images, video_components_10_frames.images[3:6])

    def test_operations_list_is_immutable(self, video_components_10_frames):
        """Each ``sliced()`` call extends a copy of the list, never the parent's."""
        video = VideoFromComponents(video_components_10_frames)
        sliced1 = video.sliced(0, 5)
        sliced2 = sliced1.sliced(1, 2)

        assert len(video._operations) == 0
        assert len(sliced1._operations) == 1
        assert len(sliced2._operations) == 2

    def test_from_file(self, video_file_10_frames):
        """Slicing a file-backed video works both before and after decoding."""
        video = VideoFromFile(video_file_10_frames)
        sliced = video.sliced(2, 3)

        # Query the count first so the lazy (not yet materialized) path is
        # covered; the later assertions only run after get_components().
        assert sliced.get_frame_count() == 3

        components = sliced.get_components()

        assert components.images.shape[0] == 3
        assert sliced.get_frame_count() == 3

        # The test frames have distinct brightness levels, so comparing with
        # the same frames from an unsliced decode checks *which* frames were
        # kept, not just how many.
        full = VideoFromFile(video_file_10_frames).get_components()
        assert torch.allclose(components.images, full.images[2:5])

    def test_save_sliced_video(self, video_components_10_frames, tmp_path):
        """Saving a sliced video writes only the selected frames."""
        video = VideoFromComponents(video_components_10_frames)
        sliced = video.sliced(2, 3)

        output_path = str(tmp_path / "sliced_output.mp4")
        sliced.save_to(output_path)

        saved_video = VideoFromFile(output_path)
        assert saved_video.get_frame_count() == 3

    def test_materialization_clears_ops(self, video_components_10_frames):
        """``get_components()`` consumes the pending operations."""
        video = VideoFromComponents(video_components_10_frames)
        sliced = video.sliced(2, 3)

        assert len(sliced._operations) == 1
        sliced.get_components()
        assert len(sliced._operations) == 0

    def test_second_get_components_uses_cache(self, video_components_10_frames):
        """A second call must not re-apply the slice to already-sliced data."""
        video = VideoFromComponents(video_components_10_frames)
        sliced = video.sliced(2, 3)

        first = sliced.get_components()
        second = sliced.get_components()

        # Re-applying slice(2, 3) to a 3-frame result would yield 1 frame.
        assert first.images.shape == second.images.shape
        assert torch.equal(first.images, second.images)
|