mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-26 19:19:53 +00:00
feat(preview): add per-queue live preview method override (#11261)
- Add set_preview_method() to override live preview method per queue item - Read extra_data.preview_method from /prompt request - Support values: taesd, latent2rgb, none, auto, default - "default" or unset uses server's CLI --preview-method setting - Add 44 tests (37 unit + 7 E2E)
This commit is contained in:
358
tests/execution/test_preview_method.py
Normal file
358
tests/execution/test_preview_method.py
Normal file
@@ -0,0 +1,358 @@
|
||||
"""
|
||||
E2E tests for Queue-specific Preview Method Override feature.
|
||||
|
||||
Tests actual execution with different preview_method values.
|
||||
Requires a running ComfyUI server with models.
|
||||
|
||||
Usage:
|
||||
COMFYUI_SERVER=http://localhost:8988 pytest test_preview_method_e2e.py -v -m preview_method
|
||||
|
||||
Note:
|
||||
These tests execute actual image generation and wait for completion.
|
||||
Tests verify preview image transmission based on preview_method setting.
|
||||
"""
|
||||
import os
|
||||
import json
|
||||
import pytest
|
||||
import uuid
|
||||
import time
|
||||
import random
|
||||
import websocket
|
||||
import urllib.request
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
# Server configuration
|
||||
SERVER_URL = os.environ.get("COMFYUI_SERVER", "http://localhost:8988")
|
||||
SERVER_HOST = SERVER_URL.replace("http://", "").replace("https://", "")
|
||||
|
||||
# Use existing inference graph fixture
|
||||
GRAPH_FILE = Path(__file__).parent.parent / "inference" / "graphs" / "default_graph_sdxl1_0.json"
|
||||
|
||||
|
||||
def is_server_running() -> bool:
|
||||
"""Check if ComfyUI server is running."""
|
||||
try:
|
||||
request = urllib.request.Request(f"{SERVER_URL}/system_stats")
|
||||
with urllib.request.urlopen(request, timeout=2.0):
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def prepare_graph_for_test(graph: dict, steps: int = 5) -> dict:
|
||||
"""Prepare graph for testing: randomize seeds and reduce steps."""
|
||||
adapted = json.loads(json.dumps(graph)) # Deep copy
|
||||
for node_id, node in adapted.items():
|
||||
inputs = node.get("inputs", {})
|
||||
# Handle both "seed" and "noise_seed" (used by KSamplerAdvanced)
|
||||
if "seed" in inputs:
|
||||
inputs["seed"] = random.randint(0, 2**32 - 1)
|
||||
if "noise_seed" in inputs:
|
||||
inputs["noise_seed"] = random.randint(0, 2**32 - 1)
|
||||
# Reduce steps for faster testing (default 20 -> 5)
|
||||
if "steps" in inputs:
|
||||
inputs["steps"] = steps
|
||||
return adapted
|
||||
|
||||
|
||||
# Alias for backward compatibility
|
||||
randomize_seed = prepare_graph_for_test
|
||||
|
||||
|
||||
class PreviewMethodClient:
|
||||
"""Client for testing preview_method with WebSocket execution tracking."""
|
||||
|
||||
def __init__(self, server_address: str):
|
||||
self.server_address = server_address
|
||||
self.client_id = str(uuid.uuid4())
|
||||
self.ws = None
|
||||
|
||||
def connect(self):
|
||||
"""Connect to WebSocket."""
|
||||
self.ws = websocket.WebSocket()
|
||||
self.ws.settimeout(120) # 2 minute timeout for sampling
|
||||
self.ws.connect(f"ws://{self.server_address}/ws?clientId={self.client_id}")
|
||||
|
||||
def close(self):
|
||||
"""Close WebSocket connection."""
|
||||
if self.ws:
|
||||
self.ws.close()
|
||||
|
||||
def queue_prompt(self, prompt: dict, extra_data: dict = None) -> dict:
|
||||
"""Queue a prompt and return response with prompt_id."""
|
||||
data = {
|
||||
"prompt": prompt,
|
||||
"client_id": self.client_id,
|
||||
"extra_data": extra_data or {}
|
||||
}
|
||||
req = urllib.request.Request(
|
||||
f"http://{self.server_address}/prompt",
|
||||
data=json.dumps(data).encode("utf-8"),
|
||||
headers={"Content-Type": "application/json"}
|
||||
)
|
||||
return json.loads(urllib.request.urlopen(req).read())
|
||||
|
||||
def wait_for_execution(self, prompt_id: str, timeout: float = 120.0) -> dict:
|
||||
"""
|
||||
Wait for execution to complete via WebSocket.
|
||||
|
||||
Returns:
|
||||
dict with keys: completed, error, preview_count, execution_time
|
||||
"""
|
||||
result = {
|
||||
"completed": False,
|
||||
"error": None,
|
||||
"preview_count": 0,
|
||||
"execution_time": 0.0
|
||||
}
|
||||
|
||||
start_time = time.time()
|
||||
self.ws.settimeout(timeout)
|
||||
|
||||
try:
|
||||
while True:
|
||||
out = self.ws.recv()
|
||||
elapsed = time.time() - start_time
|
||||
|
||||
if isinstance(out, str):
|
||||
message = json.loads(out)
|
||||
msg_type = message.get("type")
|
||||
data = message.get("data", {})
|
||||
|
||||
if data.get("prompt_id") != prompt_id:
|
||||
continue
|
||||
|
||||
if msg_type == "executing":
|
||||
if data.get("node") is None:
|
||||
# Execution complete
|
||||
result["completed"] = True
|
||||
result["execution_time"] = elapsed
|
||||
break
|
||||
|
||||
elif msg_type == "execution_error":
|
||||
result["error"] = data
|
||||
result["execution_time"] = elapsed
|
||||
break
|
||||
|
||||
elif msg_type == "progress":
|
||||
# Progress update during sampling
|
||||
pass
|
||||
|
||||
elif isinstance(out, bytes):
|
||||
# Binary data = preview image
|
||||
result["preview_count"] += 1
|
||||
|
||||
except websocket.WebSocketTimeoutException:
|
||||
result["error"] = "Timeout waiting for execution"
|
||||
result["execution_time"] = time.time() - start_time
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def load_graph() -> dict:
|
||||
"""Load the SDXL graph fixture with randomized seed."""
|
||||
with open(GRAPH_FILE) as f:
|
||||
graph = json.load(f)
|
||||
return randomize_seed(graph) # Avoid caching
|
||||
|
||||
|
||||
# Skip all tests if server is not running
|
||||
pytestmark = [
|
||||
pytest.mark.skipif(
|
||||
not is_server_running(),
|
||||
reason=f"ComfyUI server not running at {SERVER_URL}"
|
||||
),
|
||||
pytest.mark.preview_method,
|
||||
pytest.mark.execution,
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
"""Create and connect a test client."""
|
||||
c = PreviewMethodClient(SERVER_HOST)
|
||||
c.connect()
|
||||
yield c
|
||||
c.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def graph():
|
||||
"""Load the test graph."""
|
||||
return load_graph()
|
||||
|
||||
|
||||
class TestPreviewMethodExecution:
|
||||
"""Test actual execution with different preview methods."""
|
||||
|
||||
def test_execution_with_latent2rgb(self, client, graph):
|
||||
"""
|
||||
Execute with preview_method=latent2rgb.
|
||||
Should complete and potentially receive preview images.
|
||||
"""
|
||||
extra_data = {"preview_method": "latent2rgb"}
|
||||
|
||||
response = client.queue_prompt(graph, extra_data)
|
||||
assert "prompt_id" in response
|
||||
|
||||
result = client.wait_for_execution(response["prompt_id"])
|
||||
|
||||
# Should complete (may error if model missing, but that's separate)
|
||||
assert result["completed"] or result["error"] is not None
|
||||
# Execution should take some time (sampling)
|
||||
if result["completed"]:
|
||||
assert result["execution_time"] > 0.5, "Execution too fast - likely didn't run"
|
||||
# latent2rgb should produce previews
|
||||
print(f"latent2rgb: {result['preview_count']} previews in {result['execution_time']:.2f}s") # noqa: T201
|
||||
|
||||
def test_execution_with_taesd(self, client, graph):
|
||||
"""
|
||||
Execute with preview_method=taesd.
|
||||
TAESD provides higher quality previews.
|
||||
"""
|
||||
extra_data = {"preview_method": "taesd"}
|
||||
|
||||
response = client.queue_prompt(graph, extra_data)
|
||||
assert "prompt_id" in response
|
||||
|
||||
result = client.wait_for_execution(response["prompt_id"])
|
||||
|
||||
assert result["completed"] or result["error"] is not None
|
||||
if result["completed"]:
|
||||
assert result["execution_time"] > 0.5
|
||||
# taesd should also produce previews
|
||||
print(f"taesd: {result['preview_count']} previews in {result['execution_time']:.2f}s") # noqa: T201
|
||||
|
||||
def test_execution_with_none_preview(self, client, graph):
|
||||
"""
|
||||
Execute with preview_method=none.
|
||||
No preview images should be generated.
|
||||
"""
|
||||
extra_data = {"preview_method": "none"}
|
||||
|
||||
response = client.queue_prompt(graph, extra_data)
|
||||
assert "prompt_id" in response
|
||||
|
||||
result = client.wait_for_execution(response["prompt_id"])
|
||||
|
||||
assert result["completed"] or result["error"] is not None
|
||||
if result["completed"]:
|
||||
# With "none", should receive no preview images
|
||||
assert result["preview_count"] == 0, \
|
||||
f"Expected no previews with 'none', got {result['preview_count']}"
|
||||
print(f"none: {result['preview_count']} previews in {result['execution_time']:.2f}s") # noqa: T201
|
||||
|
||||
def test_execution_with_default(self, client, graph):
|
||||
"""
|
||||
Execute with preview_method=default.
|
||||
Should use server's CLI default setting.
|
||||
"""
|
||||
extra_data = {"preview_method": "default"}
|
||||
|
||||
response = client.queue_prompt(graph, extra_data)
|
||||
assert "prompt_id" in response
|
||||
|
||||
result = client.wait_for_execution(response["prompt_id"])
|
||||
|
||||
assert result["completed"] or result["error"] is not None
|
||||
if result["completed"]:
|
||||
print(f"default: {result['preview_count']} previews in {result['execution_time']:.2f}s") # noqa: T201
|
||||
|
||||
def test_execution_without_preview_method(self, client, graph):
|
||||
"""
|
||||
Execute without preview_method in extra_data.
|
||||
Should use server's default preview method.
|
||||
"""
|
||||
extra_data = {} # No preview_method
|
||||
|
||||
response = client.queue_prompt(graph, extra_data)
|
||||
assert "prompt_id" in response
|
||||
|
||||
result = client.wait_for_execution(response["prompt_id"])
|
||||
|
||||
assert result["completed"] or result["error"] is not None
|
||||
if result["completed"]:
|
||||
print(f"(no override): {result['preview_count']} previews in {result['execution_time']:.2f}s") # noqa: T201
|
||||
|
||||
|
||||
class TestPreviewMethodComparison:
|
||||
"""Compare preview behavior between different methods."""
|
||||
|
||||
def test_none_vs_latent2rgb_preview_count(self, client, graph):
|
||||
"""
|
||||
Compare preview counts: 'none' should have 0, others should have >0.
|
||||
This is the key verification that preview_method actually works.
|
||||
"""
|
||||
results = {}
|
||||
|
||||
# Run with none (randomize seed to avoid caching)
|
||||
graph_none = randomize_seed(graph)
|
||||
extra_data_none = {"preview_method": "none"}
|
||||
response = client.queue_prompt(graph_none, extra_data_none)
|
||||
results["none"] = client.wait_for_execution(response["prompt_id"])
|
||||
|
||||
# Run with latent2rgb (randomize seed again)
|
||||
graph_rgb = randomize_seed(graph)
|
||||
extra_data_rgb = {"preview_method": "latent2rgb"}
|
||||
response = client.queue_prompt(graph_rgb, extra_data_rgb)
|
||||
results["latent2rgb"] = client.wait_for_execution(response["prompt_id"])
|
||||
|
||||
# Verify both completed
|
||||
assert results["none"]["completed"], f"'none' execution failed: {results['none']['error']}"
|
||||
assert results["latent2rgb"]["completed"], f"'latent2rgb' execution failed: {results['latent2rgb']['error']}"
|
||||
|
||||
# Key assertion: 'none' should have 0 previews
|
||||
assert results["none"]["preview_count"] == 0, \
|
||||
f"'none' should have 0 previews, got {results['none']['preview_count']}"
|
||||
|
||||
# 'latent2rgb' should have at least 1 preview (depends on steps)
|
||||
assert results["latent2rgb"]["preview_count"] > 0, \
|
||||
f"'latent2rgb' should have >0 previews, got {results['latent2rgb']['preview_count']}"
|
||||
|
||||
print("\nPreview count comparison:") # noqa: T201
|
||||
print(f" none: {results['none']['preview_count']} previews") # noqa: T201
|
||||
print(f" latent2rgb: {results['latent2rgb']['preview_count']} previews") # noqa: T201
|
||||
|
||||
|
||||
class TestPreviewMethodSequential:
|
||||
"""Test sequential execution with different preview methods."""
|
||||
|
||||
def test_sequential_different_methods(self, client, graph):
|
||||
"""
|
||||
Execute multiple prompts sequentially with different preview methods.
|
||||
Each should complete independently with correct preview behavior.
|
||||
"""
|
||||
methods = ["latent2rgb", "none", "default"]
|
||||
results = []
|
||||
|
||||
for method in methods:
|
||||
# Randomize seed for each execution to avoid caching
|
||||
graph_run = randomize_seed(graph)
|
||||
extra_data = {"preview_method": method}
|
||||
response = client.queue_prompt(graph_run, extra_data)
|
||||
|
||||
result = client.wait_for_execution(response["prompt_id"])
|
||||
results.append({
|
||||
"method": method,
|
||||
"completed": result["completed"],
|
||||
"preview_count": result["preview_count"],
|
||||
"execution_time": result["execution_time"],
|
||||
"error": result["error"]
|
||||
})
|
||||
|
||||
# All should complete or have clear errors
|
||||
for r in results:
|
||||
assert r["completed"] or r["error"] is not None, \
|
||||
f"Method {r['method']} neither completed nor errored"
|
||||
|
||||
# "none" should have zero previews if completed
|
||||
none_result = next(r for r in results if r["method"] == "none")
|
||||
if none_result["completed"]:
|
||||
assert none_result["preview_count"] == 0, \
|
||||
f"'none' should have 0 previews, got {none_result['preview_count']}"
|
||||
|
||||
print("\nSequential execution results:") # noqa: T201
|
||||
for r in results:
|
||||
status = "✓" if r["completed"] else f"✗ ({r['error']})"
|
||||
print(f" {r['method']}: {status}, {r['preview_count']} previews, {r['execution_time']:.2f}s") # noqa: T201
|
||||
Reference in New Issue
Block a user