"""Cross-framework comparison benchmark for diffusion serving. Launches servers (SGLang, vLLM-Omni, LightX2V) for each test case, sends a single request, measures end-to-end latency, and writes comparison-results.json. Usage: # Full run (requires GPU) python3 scripts/ci/utils/diffusion/run_comparison.py # Dry-run (config parsing + command preview only) python3 scripts/ci/utils/diffusion/run_comparison.py --dry-run # Run only specific case(s) python3 scripts/ci/utils/diffusion/run_comparison.py --case-ids flux1_dev_t2i_1024 # Run only specific framework(s) python3 scripts/ci/utils/diffusion/run_comparison.py --frameworks sglang """ import argparse import base64 import io import json import os import signal import subprocess import sys import tempfile import threading import time from datetime import datetime, timezone from pathlib import Path import requests # --------------------------------------------------------------------------- # Constants # --------------------------------------------------------------------------- CONFIGS_PATH = Path(__file__).parent / "comparison_configs.json" INSTALL_SCRIPT = Path(__file__).parents[1] / "install_comparison_frameworks.sh" DEFAULT_HOST = "127.0.0.1" DEFAULT_PORT = 30000 HEALTH_TIMEOUT = ( 2400 # seconds (40 min — FLUX.2-dev needs ~10 min download + torch.compile) ) REQUEST_TIMEOUT = 1200 # seconds GPU_CLEAR_WAIT = 15 # seconds between framework runs # Frameworks that need separate installation (conflict with sglang's deps) INSTALLABLE_FRAMEWORKS = {"vllm-omni", "lightx2v"} # Cached reference image (downloaded once) _cached_ref_image: bytes | None = None _cached_ref_image_path: str | None = None # --------------------------------------------------------------------------- # Server lifecycle — command builders # --------------------------------------------------------------------------- def _build_sglang_cmd(case: dict, fw_cfg: dict, port: int) -> list[str]: cmd = [ "sglang", "serve", "--model-path", case["model"], "--port", str(port), "--host", DEFAULT_HOST, ] if case["num_gpus"] > 1: cmd += ["--num-gpus", str(case["num_gpus"])] if fw_cfg.get("serve_args", "").strip(): cmd += fw_cfg["serve_args"].strip().split() return cmd def _build_vllm_cmd(case: dict, fw_cfg: dict, port: int) -> list[str]: cmd = [ "vllm", "serve", case["model"], "--omni", "--port", str(port), "--host", DEFAULT_HOST, ] if fw_cfg.get("serve_args", "").strip(): cmd += fw_cfg["serve_args"].strip().split() return cmd def _resolve_hf_model_path(model_id: str) -> str: """Resolve a HuggingFace model ID to a local cache path, or return as-is.""" if os.path.isdir(model_id): return model_id try: from huggingface_hub import snapshot_download path = snapshot_download(model_id) print(f" Resolved {model_id} -> {path}") return path except Exception: return model_id def _write_lightx2v_config(case: dict) -> str: """Write a minimal LightX2V config JSON and return its path.""" cfg = { "infer_steps": case.get("num_inference_steps", 50), "guidance_scale": case.get("guidance_scale", 4.0), "seed": case.get("seed", 42), } if "num_frames" in case: cfg["target_video_length"] = case["num_frames"] if "height" in case: cfg["height"] = case["height"] if "width" in case: cfg["width"] = case["width"] config_path = os.path.join( tempfile.gettempdir(), f"lightx2v_config_{case['id']}.json" ) with open(config_path, "w") as f: json.dump(cfg, f) return config_path def _build_lightx2v_cmd(case: dict, fw_cfg: dict, port: int) -> list[str]: """Build LightX2V server launch command. Single GPU: python -m lightx2v.server --model_path ... --model_cls ... --task ... --port ... Multi GPU: torchrun --nproc_per_node=N -m lightx2v.server ... LightX2V requires a local model path and a config JSON with infer params. """ model_cls = fw_cfg["model_cls"] task = fw_cfg["lightx2v_task"] num_gpus = case["num_gpus"] model_path = _resolve_hf_model_path(case["model"]) config_path = _write_lightx2v_config(case) server_args = [ "--model_path", model_path, "--model_cls", model_cls, "--task", task, "--config_json", config_path, "--host", DEFAULT_HOST, "--port", str(port), ] if fw_cfg.get("serve_args", "").strip(): server_args += fw_cfg["serve_args"].strip().split() if num_gpus > 1: cmd = [ "torchrun", f"--nproc_per_node={num_gpus}", "-m", "lightx2v.server", ] + server_args else: cmd = ["python3", "-m", "lightx2v.server"] + server_args return cmd def build_server_cmd(framework: str, case: dict, fw_cfg: dict, port: int) -> list[str]: builders = { "sglang": _build_sglang_cmd, "vllm-omni": _build_vllm_cmd, "lightx2v": _build_lightx2v_cmd, } builder = builders.get(framework) if builder is None: raise ValueError(f"Unknown framework: {framework}") return builder(case, fw_cfg, port) # --------------------------------------------------------------------------- # Server lifecycle — health check & cleanup # --------------------------------------------------------------------------- # Health check endpoints per framework HEALTH_ENDPOINTS = { "sglang": "/health", "vllm-omni": "/health", "lightx2v": "/v1/service/status", } def wait_for_health( base_url: str, framework: str = "sglang", timeout: int = HEALTH_TIMEOUT ) -> None: """Poll health endpoint until 200, then verify model is loaded.""" endpoint = HEALTH_ENDPOINTS.get(framework, "/health") health_url = f"{base_url}{endpoint}" print(f" Waiting for server at {health_url} ...") start = time.time() while True: try: resp = requests.get(health_url, timeout=2) if resp.status_code == 200: break except requests.exceptions.RequestException: pass if time.time() - start > timeout: raise TimeoutError( f"Server at {health_url} did not start within {timeout}s" ) time.sleep(2) # For SGLang, /health can return 200 before model routes are registered. # Poll /v1/models to confirm the model is fully loaded. if framework == "sglang": models_url = f"{base_url}/v1/models" while True: try: resp = requests.get(models_url, timeout=5) if resp.status_code == 200: break except requests.exceptions.RequestException: pass if time.time() - start > timeout: raise TimeoutError(f"Model at {models_url} not ready within {timeout}s") time.sleep(2) elapsed = time.time() - start print(f" Server ready in {elapsed:.1f}s") KILLALL_SCRIPT = Path(__file__).parents[3] / "killall_sglang.sh" def kill_server(proc: subprocess.Popen) -> None: """Kill server process tree and clean up GPU processes.""" if proc.poll() is not None: return try: os.killpg(os.getpgid(proc.pid), signal.SIGTERM) except (ProcessLookupError, PermissionError): pass try: proc.wait(timeout=30) except subprocess.TimeoutExpired: try: os.killpg(os.getpgid(proc.pid), signal.SIGKILL) except (ProcessLookupError, PermissionError): pass proc.wait(timeout=10) # Use killall_sglang.sh for thorough cleanup (esp. multi-GPU workers) if KILLALL_SCRIPT.exists(): subprocess.run( ["bash", str(KILLALL_SCRIPT)], timeout=30, capture_output=True, ) # --------------------------------------------------------------------------- # Reference image helpers # --------------------------------------------------------------------------- def _get_ref_image_bytes(config: dict) -> bytes: """Download and cache the shared test reference image.""" global _cached_ref_image if _cached_ref_image is not None: return _cached_ref_image url = config.get("test_image_url", "") if not url: raise RuntimeError("No test_image_url in config for image-conditioned case") print(f" Downloading reference image from {url} ...") resp = requests.get(url, timeout=60) resp.raise_for_status() _cached_ref_image = resp.content return _cached_ref_image def _get_ref_image_b64(config: dict) -> str: """Get reference image as base64 string.""" return base64.b64encode(_get_ref_image_bytes(config)).decode("utf-8") def _get_ref_image_path(config: dict) -> str: """Save reference image to a temp file and return path.""" global _cached_ref_image_path if _cached_ref_image_path and os.path.exists(_cached_ref_image_path): return _cached_ref_image_path data = _get_ref_image_bytes(config) fd, path = tempfile.mkstemp(suffix=".png") with os.fdopen(fd, "wb") as f: f.write(data) _cached_ref_image_path = path return path # --------------------------------------------------------------------------- # Request helpers — SGLang (OpenAI-compatible) # --------------------------------------------------------------------------- def _build_sglang_payload(case: dict) -> dict: """Build common SGLang request payload.""" payload = { "model": case["model"], "prompt": case["prompt"], "size": f"{case['width']}x{case['height']}", "n": 1, "response_format": "b64_json", } for key in ( "num_inference_steps", "guidance_scale", "seed", "num_frames", "fps", "negative_prompt", ): if key in case: payload[key] = case[key] return payload def _read_perf_dump(perf_dump_path: str, timeout: float = 10.0) -> float | None: """Read total_duration_ms from a perf dump JSON written by the server. The server writes the file asynchronously after the HTTP response, so we poll briefly. """ deadline = time.time() + timeout while time.time() < deadline: try: with open(perf_dump_path) as f: data = json.load(f) total_ms = data.get("total_duration_ms") if total_ms is not None: return total_ms / 1000.0 except (FileNotFoundError, json.JSONDecodeError): pass time.sleep(0.5) return None def send_image_request_sglang( base_url: str, case: dict, perf_dump_path: str | None = None ) -> float: """Send a single T2I request via SGLang's /v1/images/generations.""" payload = _build_sglang_payload(case) if perf_dump_path: payload["perf_dump_path"] = perf_dump_path start = time.time() resp = requests.post( f"{base_url}/v1/images/generations", json=payload, timeout=REQUEST_TIMEOUT, ) client_latency = time.time() - start resp.raise_for_status() data = resp.json() if "data" not in data or len(data["data"]) == 0: raise RuntimeError(f"Image request returned no data: {data}") if perf_dump_path: server_latency = _read_perf_dump(perf_dump_path) if server_latency is not None: print( f" Image generated in {server_latency:.2f}s (server-side), " f"client={client_latency:.2f}s" ) return server_latency print(f" Image generated in {client_latency:.2f}s") return client_latency def send_video_request_sglang( base_url: str, case: dict, perf_dump_path: str | None = None ) -> float: """Send a single T2V request via SGLang's /v1/videos (async).""" payload = _build_sglang_payload(case) if perf_dump_path: payload["perf_dump_path"] = perf_dump_path start = time.time() # Submit job resp = requests.post( f"{base_url}/v1/videos", json=payload, timeout=REQUEST_TIMEOUT, ) resp.raise_for_status() job = resp.json() job_id = job.get("id") if not job_id: raise RuntimeError(f"Video submit returned no job id: {job}") # Poll for completion poll_url = f"{base_url}/v1/videos/{job_id}" while True: time.sleep(1) poll_resp = requests.get(poll_url, timeout=30) poll_resp.raise_for_status() poll_data = poll_resp.json() status = poll_data.get("status") if status == "completed": break elif status == "failed": raise RuntimeError(f"Video generation failed: {poll_data}") if time.time() - start > REQUEST_TIMEOUT: raise TimeoutError(f"Video generation timed out after {REQUEST_TIMEOUT}s") client_latency = time.time() - start if perf_dump_path: server_latency = _read_perf_dump(perf_dump_path) if server_latency is not None: print( f" Video generated in {server_latency:.2f}s (server-side), " f"client={client_latency:.2f}s" ) return server_latency print(f" Video generated in {client_latency:.2f}s") return client_latency def send_image_conditioned_request_sglang( base_url: str, case: dict, config: dict, perf_dump_path: str | None = None ) -> float: """Send an image-conditioned request (edit/I2V/TI2V) via SGLang multipart API.""" task = case["task"] ref_bytes = _get_ref_image_bytes(config) # Build multipart form — field name depends on endpoint: # image edits use "image", video (I2V/TI2V) uses "input_reference" if task in ("image-to-video", "text-image-to-video"): file_field = "input_reference" else: file_field = "image" files = {file_field: ("ref.png", io.BytesIO(ref_bytes), "image/png")} data = { "model": case["model"], "prompt": case["prompt"], "size": f"{case['width']}x{case['height']}", "n": "1", "response_format": "b64_json", } for key in ( "num_inference_steps", "guidance_scale", "seed", "num_frames", "fps", "negative_prompt", ): if key in case: data[key] = str(case[key]) if perf_dump_path: data["perf_dump_path"] = perf_dump_path # Choose endpoint based on task if task in ("image-edit", "image-to-image"): endpoint = "/v1/images/edits" elif task in ("image-to-video", "text-image-to-video"): endpoint = "/v1/videos" else: endpoint = "/v1/images/generations" start = time.time() resp = requests.post( f"{base_url}{endpoint}", files=files, data=data, timeout=REQUEST_TIMEOUT, ) # For video endpoints, need to poll if task in ("image-to-video", "text-image-to-video"): resp.raise_for_status() job = resp.json() job_id = job.get("id") if not job_id: raise RuntimeError(f"Video submit returned no job id: {job}") poll_url = f"{base_url}/v1/videos/{job_id}" while True: time.sleep(1) poll_resp = requests.get(poll_url, timeout=30) poll_resp.raise_for_status() poll_data = poll_resp.json() status = poll_data.get("status") if status == "completed": break elif status == "failed": raise RuntimeError(f"Video generation failed: {poll_data}") if time.time() - start > REQUEST_TIMEOUT: raise TimeoutError(f"Timed out after {REQUEST_TIMEOUT}s") else: resp.raise_for_status() client_latency = time.time() - start if perf_dump_path: server_latency = _read_perf_dump(perf_dump_path) if server_latency is not None: print( f" Generated in {server_latency:.2f}s (server-side), " f"client={client_latency:.2f}s" ) return server_latency print(f" Generated in {client_latency:.2f}s (sglang, image-conditioned)") return client_latency # --------------------------------------------------------------------------- # Request helpers — vLLM-Omni # --------------------------------------------------------------------------- def send_request_vllm_omni(base_url: str, case: dict, config: dict) -> float: """Send request via vLLM-Omni's /v1/chat/completions endpoint.""" extra_body = { "height": case["height"], "width": case["width"], "num_inference_steps": case.get("num_inference_steps", 50), "guidance_scale": case.get("guidance_scale", 4.0), "seed": case.get("seed", 42), } if "num_frames" in case: extra_body["num_frames"] = case["num_frames"] if "fps" in case: extra_body["fps"] = case["fps"] if "negative_prompt" in case: extra_body["negative_prompt"] = case["negative_prompt"] # Build message content (text or text+image) content: list[dict] | str = case["prompt"] if case.get("reference_image"): ref_b64 = _get_ref_image_b64(config) content = [ { "type": "image_url", "image_url": {"url": f"data:image/png;base64,{ref_b64}"}, }, {"type": "text", "text": case["prompt"]}, ] payload = { "model": case["model"], "messages": [{"role": "user", "content": content}], "extra_body": extra_body, } start = time.time() resp = requests.post( f"{base_url}/v1/chat/completions", json=payload, timeout=REQUEST_TIMEOUT, ) latency = time.time() - start resp.raise_for_status() data = resp.json() choices = data.get("choices", []) if not choices: raise RuntimeError(f"vLLM-Omni request returned no choices: {data}") print(f" Generated in {latency:.2f}s (vllm-omni)") return latency # --------------------------------------------------------------------------- # Request helpers — LightX2V # --------------------------------------------------------------------------- def send_request_lightx2v(base_url: str, case: dict, config: dict) -> float: """Send request via LightX2V's async task API.""" task = case["task"] if task in ("text-to-image", "image-edit"): endpoint = "/v1/tasks/image" else: endpoint = "/v1/tasks/video" payload = { "prompt": case["prompt"], "seed": case.get("seed", 42), "infer_steps": case.get("num_inference_steps", 50), } # LightX2V uses target_video_length for frames, height/width directly if "num_frames" in case: payload["target_video_length"] = case["num_frames"] if "height" in case: payload["height"] = case["height"] if "width" in case: payload["width"] = case["width"] if "guidance_scale" in case: payload["guidance_scale"] = case["guidance_scale"] if "fps" in case: payload["fps"] = case["fps"] if "negative_prompt" in case: payload["negative_prompt"] = case["negative_prompt"] # Image-conditioned: LightX2V accepts image_path (URL or local path) if case.get("reference_image"): payload["image_path"] = config.get("test_image_url", "") start = time.time() # Submit task resp = requests.post( f"{base_url}{endpoint}", json=payload, timeout=REQUEST_TIMEOUT, ) resp.raise_for_status() task_data = resp.json() task_id = task_data.get("task_id") if not task_id: raise RuntimeError(f"LightX2V submit returned no task_id: {task_data}") # Poll for completion poll_url = f"{base_url}/v1/tasks/{task_id}/status" while True: time.sleep(1) poll_resp = requests.get(poll_url, timeout=30) poll_resp.raise_for_status() poll_data = poll_resp.json() status = poll_data.get("task_status", "").upper() if status == "COMPLETED": break elif status in ("FAILED", "CANCELLED"): raise RuntimeError(f"LightX2V task {status}: {poll_data}") if time.time() - start > REQUEST_TIMEOUT: raise TimeoutError(f"LightX2V task timed out after {REQUEST_TIMEOUT}s") latency = time.time() - start print(f" Generated in {latency:.2f}s (lightx2v)") return latency # --------------------------------------------------------------------------- # Unified request dispatcher # --------------------------------------------------------------------------- def send_request( base_url: str, case: dict, framework: str = "sglang", config: dict | None = None, perf_dump_path: str | None = None, ) -> float: config = config or {} if framework == "vllm-omni": return send_request_vllm_omni(base_url, case, config) elif framework == "lightx2v": return send_request_lightx2v(base_url, case, config) # SGLang — use OpenAI-compatible endpoints with optional perf log task = case["task"] if case.get("reference_image"): return send_image_conditioned_request_sglang( base_url, case, config, perf_dump_path ) elif task == "text-to-image": return send_image_request_sglang(base_url, case, perf_dump_path) elif task == "text-to-video": return send_video_request_sglang(base_url, case, perf_dump_path) else: raise ValueError(f"Unknown task type: {task}") # --------------------------------------------------------------------------- # Main orchestrator # --------------------------------------------------------------------------- def run_single( case: dict, framework: str, fw_cfg: dict, port: int, log_dir: Path, config: dict | None = None, ) -> dict: """Run a single (case, framework) combination. Returns result dict.""" result = { "case_id": case["id"], "framework": framework, "model": case["model"], "task": case["task"], "latency_s": None, "error": None, } cmd = build_server_cmd(framework, case, fw_cfg, port) print(f"\n Command: {' '.join(cmd)}") env = os.environ.copy() env.update(fw_cfg.get("extra_env", {})) # perf_dump_path for SGLang server-side timing (passed in request, zero overhead when None) perf_dump_path = None if framework == "sglang": perf_dump_path = os.path.join(str(log_dir), f"perf_{case['id']}_measured.json") log_file = log_dir / f"{case['id']}_{framework}.log" log_fh = open(log_file, "w", encoding="utf-8", buffering=1) log_thread = None proc = None try: proc = subprocess.Popen( cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, env=env, preexec_fn=os.setsid, text=True, bufsize=1, ) # Tee server output to both log file and stdout (like test_server_utils) def _log_pipe(pipe, fh): try: for line in iter(pipe.readline, ""): sys.stdout.write(f" [server] {line}") sys.stdout.flush() fh.write(line) except ValueError: pass # pipe closed log_thread = threading.Thread(target=_log_pipe, args=(proc.stdout, log_fh)) log_thread.daemon = True log_thread.start() base_url = f"http://{DEFAULT_HOST}:{port}" wait_for_health(base_url, framework) # Warmup requests (not measured, no perf dump) # Use few steps to be fast — server's own warmup (warmup_steps=3) handles # torch.compile compilation; these external warmups just stabilize triton # kernel specializations across requests. WARMUP_STEPS = 3 warmup_case = {**case, "num_inference_steps": WARMUP_STEPS} for wi in range(1, 3): print(f" Sending warmup request ({wi}/2, {WARMUP_STEPS} steps)...") try: send_request(base_url, warmup_case, framework, config) except Exception as e: print(f" Warmup request {wi} failed (non-fatal): {e}") # Measured request — pass perf_dump_path for SGLang server-side timing if perf_dump_path and os.path.exists(perf_dump_path): os.remove(perf_dump_path) print(" Sending measured request...") latency = send_request( base_url, case, framework, config, perf_dump_path=perf_dump_path ) result["latency_s"] = round(latency, 3) except Exception as e: result["error"] = str(e) print(f" ERROR: {e}") finally: if proc: kill_server(proc) if log_thread: log_thread.join(timeout=5) log_fh.close() return result def _install_framework(fw_name: str, dry_run: bool = False) -> bool: """Install a comparison framework via the install script. Returns True on success.""" if fw_name not in INSTALLABLE_FRAMEWORKS: return True if not INSTALL_SCRIPT.exists(): print(f" WARNING: Install script not found at {INSTALL_SCRIPT}") return False if dry_run: print(f" [DRY-RUN] Would install: bash {INSTALL_SCRIPT} {fw_name}") return True print(f"\n{'='*60}") print(f"Installing framework: {fw_name}") print(f"{'='*60}") ret = subprocess.run( ["bash", str(INSTALL_SCRIPT), fw_name], timeout=600, ) if ret.returncode != 0: print(f" WARNING: {fw_name} installation failed (exit {ret.returncode})") return False return True def run_comparison( config: dict, case_ids: list[str] | None = None, frameworks: list[str] | None = None, port: int = DEFAULT_PORT, output: str = "comparison-results.json", dry_run: bool = False, ) -> dict: """Run all comparison cases, grouped by framework to minimize installs. Order: sglang first (already installed), then vllm-omni, then lightx2v. Each non-sglang framework is installed right before its cases run. """ timestamp = datetime.now(timezone.utc).isoformat() commit_sha = os.environ.get("GITHUB_SHA", "unknown") run_id = os.environ.get("GITHUB_RUN_ID", "local") log_dir = Path("comparison-logs") log_dir.mkdir(exist_ok=True) # Collect all (case, framework) pairs, grouped by framework fw_order = ["sglang", "vllm-omni", "lightx2v"] fw_cases: dict[str, list[tuple[dict, dict]]] = {fw: [] for fw in fw_order} for case in config["cases"]: if case_ids and case["id"] not in case_ids: continue for fw_name, fw_cfg in case["frameworks"].items(): if frameworks and fw_name not in frameworks: continue if fw_name not in fw_cases: fw_cases[fw_name] = [] fw_cases[fw_name].append((case, fw_cfg)) results = [] installed_fws: set[str] = set() for fw_name in fw_order: pairs = fw_cases.get(fw_name, []) if not pairs: continue # Install framework if needed (once per framework) if fw_name not in installed_fws and fw_name in INSTALLABLE_FRAMEWORKS: if not _install_framework(fw_name, dry_run): # Skip all cases for this framework for case, _ in pairs: results.append( { "case_id": case["id"], "framework": fw_name, "model": case["model"], "task": case["task"], "latency_s": None, "error": f"{fw_name} installation failed", } ) continue installed_fws.add(fw_name) for case, fw_cfg in pairs: print(f"\n{'='*60}") print(f"Case: {case['id']} | Model: {case['model']} | Framework: {fw_name}") print(f"{'='*60}") if dry_run: cmd = build_server_cmd(fw_name, case, fw_cfg, port) print(f" [DRY-RUN] Would run: {' '.join(cmd)}") results.append( { "case_id": case["id"], "framework": fw_name, "model": case["model"], "task": case["task"], "latency_s": None, "error": "dry-run", } ) continue result = run_single(case, fw_name, fw_cfg, port, log_dir, config) results.append(result) # Wait for GPU memory to clear print(f" Waiting {GPU_CLEAR_WAIT}s for GPU memory to clear...") time.sleep(GPU_CLEAR_WAIT) output_data = { "timestamp": timestamp, "commit_sha": commit_sha, "run_id": run_id, "results": results, } os.makedirs(os.path.dirname(output) or ".", exist_ok=True) with open(output, "w") as f: json.dump(output_data, f, indent=2) print(f"\nResults written to {output}") # Print summary table print(f"\n{'='*60}") print("SUMMARY") print(f"{'='*60}") for r in results: lat = f"{r['latency_s']:.2f}s" if r["latency_s"] else r.get("error", "N/A") print(f" {r['case_id']:30s} | {r['framework']:12s} | {lat}") return output_data # --------------------------------------------------------------------------- # CLI # --------------------------------------------------------------------------- def main(): parser = argparse.ArgumentParser( description="Cross-framework diffusion serving comparison benchmark" ) parser.add_argument( "--config", default=str(CONFIGS_PATH), help="Path to comparison_configs.json", ) parser.add_argument( "--case-ids", nargs="+", default=None, help="Only run specific case IDs", ) parser.add_argument( "--frameworks", nargs="+", default=None, help="Only run specific frameworks (sglang, vllm-omni, lightx2v)", ) parser.add_argument( "--port", type=int, default=DEFAULT_PORT, help="Server port", ) parser.add_argument( "--output", default="comparison-results.json", help="Output JSON path", ) parser.add_argument( "--dry-run", action="store_true", help="Parse config and print commands without launching servers", ) args = parser.parse_args() with open(args.config) as f: config = json.load(f) print(f"Loaded {len(config['cases'])} comparison case(s) from {args.config}") output_data = run_comparison( config=config, case_ids=args.case_ids, frameworks=args.frameworks, port=args.port, output=args.output, dry_run=args.dry_run, ) # Exit with non-zero if any case had an error errors = [r for r in output_data.get("results", []) if r.get("error")] if errors and not args.dry_run: print(f"\n{len(errors)} case(s) had errors:") for e in errors: print(f" {e['case_id']} ({e['framework']}): {e['error']}") sys.exit(1) if __name__ == "__main__": main()