Files
sglang/scripts/ci/utils/diffusion/diffusion_case_parser.py

373 lines
12 KiB
Python
Executable File

#!/usr/bin/env python3
"""
AST-based parser for diffusion test cases.
This module parses the diffusion case source and run_suite.py using AST to
extract test case information without requiring sglang dependencies. The case
source file is discovered from ONE_GPU_CASES/TWO_GPU_CASES imports in
run_suite.py so CI keeps a single source of truth.
Usage:
# From sibling scripts in this directory:
from diffusion_case_parser import collect_diffusion_suites, resolve_case_config_path
case_config_path = resolve_case_config_path(repo_root, run_suite_path)
suites = collect_diffusion_suites(case_config_path, run_suite_path, baseline_path)
"""
import ast
import json
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional
# Suite assignment for every parametrized case list. The keys are the list
# variable names that the AST visitor looks for in the case config module.
CASE_LIST_TO_SUITE = {
    "ONE_GPU_CASES_A": "1-gpu",
    "ONE_GPU_CASES_B": "1-gpu",
    "ONE_GPU_CASES_C": "1-gpu-b200",
    "TWO_GPU_CASES_A": "2-gpu",
    "TWO_GPU_CASES_B": "2-gpu",
}

# Fallback estimate, in seconds, for a case with no baseline entry (5 minutes).
DEFAULT_EST_TIME_SECONDS = 5 * 60.0

# Server start-up cost, in seconds, added to expected_e2e_ms whenever a
# baseline scenario does not set estimated_full_test_time_s itself.
STARTUP_OVERHEAD_SECONDS = 2 * 60.0

# Input locations, relative to the repository root.
BASELINE_REL_PATH = "python/sglang/multimodal_gen/test/server/perf_baselines.json"
RUN_SUITE_REL_PATH = "python/sglang/multimodal_gen/test/run_suite.py"
@dataclass
class DiffusionCaseInfo:
    """Information about a single parametrized diffusion test case."""

    # First positional argument of DiffusionTestCase(...), e.g. "qwen_image_t2i".
    case_id: str
    # Suite name taken from CASE_LIST_TO_SUITE, e.g. "1-gpu", "2-gpu", "1-gpu-b200".
    suite: str
    # Estimated run time in seconds (from the perf baselines, or a default).
    est_time: float
@dataclass
class DiffusionSuiteInfo:
    """Complete information for one test suite."""

    # Suite name, e.g. "1-gpu", "2-gpu", "1-gpu-b200".
    suite: str
    # Parametrized test cases assigned to this suite.
    cases: List[DiffusionCaseInfo]
    # Standalone test files run by this suite.
    standalone_files: List[str]
    # Standalone file -> estimated seconds.
    standalone_est_times: Dict[str, float]
    # Standalone files of this suite that have no configured estimate.
    missing_standalone_estimates: List[str]
class DiffusionTestCaseVisitor(ast.NodeVisitor):
    """
    AST visitor to extract DiffusionTestCase definitions from the case config.

    Parses assignments to the list names in CASE_LIST_TO_SUITE, e.g.:
        ONE_GPU_CASES_A: list[DiffusionTestCase] = [
            DiffusionTestCase("case_id", ...),
            ...
        ]

    Also understood:
        * literal list concatenation: ``[...] + [...]``
        * in-place extension: ``ONE_GPU_CASES_A += [...]``
        * qualified constructors: ``module.DiffusionTestCase("case_id", ...)``

    Only a string literal passed as the first positional argument is treated
    as a case ID. Anything that would need the config to be executed
    (variables, calls, comprehensions) is skipped.
    """

    def __init__(self):
        self.cases: Dict[str, List[str]] = {}  # list_name -> [case_id, ...]

    def visit_Assign(self, node: ast.Assign):
        self._process_assignment(node.targets, node.value)
        self.generic_visit(node)

    def visit_AnnAssign(self, node: ast.AnnAssign):
        # A bare annotation (``X: list[...]``) has no value to parse.
        if node.target and node.value:
            self._process_assignment([node.target], node.value)
        self.generic_visit(node)

    def visit_AugAssign(self, node: ast.AugAssign):
        # ``LIST += [...]`` appends to whatever was recorded for LIST so far.
        target = node.target
        if (
            isinstance(node.op, ast.Add)
            and isinstance(target, ast.Name)
            and target.id in CASE_LIST_TO_SUITE
        ):
            case_ids = self._extract_case_ids_from_list(node.value)
            if case_ids is not None:
                self.cases.setdefault(target.id, []).extend(case_ids)
        self.generic_visit(node)

    def _process_assignment(self, targets: List[ast.AST], value: ast.AST):
        """Record case IDs for every known list name among ``targets``.

        A later assignment to the same name replaces the earlier one, as
        rebinding does in Python.
        """
        for target in targets:
            if isinstance(target, ast.Name) and target.id in CASE_LIST_TO_SUITE:
                case_ids = self._extract_case_ids_from_list(value)
                if case_ids is not None:
                    self.cases[target.id] = case_ids

    def _extract_case_ids_from_list(self, node: ast.AST) -> Optional[List[str]]:
        """Extract case IDs from a list literal or a ``+`` chain of list literals.

        Returns None when ``node`` is not such a literal, so callers can leave
        any previously recorded value untouched.
        """
        if isinstance(node, ast.BinOp) and isinstance(node.op, ast.Add):
            left = self._extract_case_ids_from_list(node.left)
            right = self._extract_case_ids_from_list(node.right)
            if left is None or right is None:
                return None
            return left + right
        if not isinstance(node, ast.List):
            return None
        case_ids = []
        for elt in node.elts:
            case_id = self._extract_case_id_from_call(elt)
            if case_id:
                case_ids.append(case_id)
        return case_ids

    def _extract_case_id_from_call(self, node: ast.AST) -> Optional[str]:
        """Return the case_id of a ``DiffusionTestCase(...)`` call, else None.

        The call may be spelled ``DiffusionTestCase(...)`` or
        ``something.DiffusionTestCase(...)``. The case_id must be a string
        literal given as the first positional argument.
        """
        if not isinstance(node, ast.Call):
            return None
        func = node.func
        if isinstance(func, ast.Name):
            func_name = func.id
        elif isinstance(func, ast.Attribute):
            func_name = func.attr
        else:
            return None
        if func_name != "DiffusionTestCase" or not node.args:
            return None
        first_arg = node.args[0]
        # Non-str literals (ints, None, ...) are not valid case IDs.
        if isinstance(first_arg, ast.Constant) and isinstance(first_arg.value, str):
            return first_arg.value
        return None
def resolve_case_config_path(repo_root: Path, run_suite_path: Path) -> Path:
    """
    Resolve the diffusion case config path from run_suite imports.

    run_suite.py must import BOTH ONE_GPU_CASES and TWO_GPU_CASES from the same
    module. That imported module is treated as the single source of truth.

    Absolute imports (``from pkg.mod import ...``) are looked up under
    ``repo_root`` and ``repo_root / "python"``. Relative imports
    (``from .mod import ...``) are resolved against the directory that holds
    ``run_suite_path``. Each extra leading dot climbs one directory.

    Raises:
        RuntimeError: if either list is not imported, or the two come from
            different modules.
        FileNotFoundError: if the resolved module file does not exist.
    """
    with open(run_suite_path, "r", encoding="utf-8") as f:
        content = f.read()
    tree = ast.parse(content, filename=str(run_suite_path))

    # Module specs keep the leading dots of relative imports, so ".x" and "x"
    # are not mistaken for the same module.
    one_gpu_module: Optional[str] = None
    two_gpu_module: Optional[str] = None
    for node in ast.walk(tree):
        if not isinstance(node, ast.ImportFrom) or not node.module:
            continue
        module_spec = "." * (node.level or 0) + node.module
        imported_names = {alias.name for alias in node.names}
        if "ONE_GPU_CASES" in imported_names:
            one_gpu_module = module_spec
        if "TWO_GPU_CASES" in imported_names:
            two_gpu_module = module_spec

    if one_gpu_module is None or two_gpu_module is None:
        raise RuntimeError(
            "run_suite.py must import BOTH ONE_GPU_CASES and TWO_GPU_CASES."
        )
    if one_gpu_module != two_gpu_module:
        raise RuntimeError(
            "run_suite.py imports ONE_GPU_CASES and TWO_GPU_CASES from different "
            f"modules: {one_gpu_module} vs {two_gpu_module}"
        )

    module_name = one_gpu_module.lstrip(".")
    level = len(one_gpu_module) - len(module_name)
    rel_path = Path(*module_name.split(".")).with_suffix(".py")
    if level == 0:
        candidates = [repo_root / rel_path, repo_root / "python" / rel_path]
    else:
        # One dot means run_suite.py's own directory; each additional dot
        # goes up one more level.
        base_dir = Path(run_suite_path).resolve().parent
        for _ in range(level - 1):
            base_dir = base_dir.parent
        candidates = [base_dir / rel_path]

    case_config_path = next((path for path in candidates if path.exists()), None)
    if case_config_path is None:
        raise FileNotFoundError(
            "Resolved case config from run_suite does not exist. Checked: "
            + ", ".join(str(path) for path in candidates)
        )
    return case_config_path
class RunSuiteVisitor(ast.NodeVisitor):
"""
AST visitor to extract standalone metadata from run_suite.py.
Parses:
STANDALONE_FILES = {
"1-gpu": ["test_lora_format_adapter.py"],
"2-gpu": [],
}
"""
def __init__(self):
self.standalone_files: Dict[str, List[str]] = {}
self.standalone_est_times: Dict[str, Dict[str, float]] = {}
def visit_Assign(self, node: ast.Assign):
for target in node.targets:
if isinstance(target, ast.Name) and target.id == "STANDALONE_FILES":
self.standalone_files = self._extract_file_dict(node.value)
if (
isinstance(target, ast.Name)
and target.id == "STANDALONE_FILE_EST_TIMES"
):
self.standalone_est_times = self._extract_est_time_dict(node.value)
self.generic_visit(node)
def _extract_file_dict(self, node: ast.AST) -> Dict[str, List[str]]:
"""Extract dictionary of suite -> file list."""
result = {}
if isinstance(node, ast.Dict):
for key, value in zip(node.keys, node.values):
if isinstance(key, ast.Constant) and isinstance(value, ast.List):
suite = key.value
files = [
elt.value for elt in value.elts if isinstance(elt, ast.Constant)
]
result[suite] = files
return result
def _extract_est_time_dict(self, node: ast.AST) -> Dict[str, Dict[str, float]]:
"""Extract dictionary of suite -> standalone file -> estimated seconds."""
result = {}
if not isinstance(node, ast.Dict):
return result
for key, value in zip(node.keys, node.values):
if not isinstance(key, ast.Constant) or not isinstance(value, ast.Dict):
continue
suite = key.value
suite_est_times = {}
for inner_key, inner_value in zip(value.keys, value.values):
if not (
isinstance(inner_key, ast.Constant)
and isinstance(inner_value, ast.Constant)
):
continue
suite_est_times[inner_key.value] = float(inner_value.value)
result[suite] = suite_est_times
return result
def load_baselines(baseline_path: Path) -> Dict[str, float]:
    """
    Load performance baselines from JSON file.

    Each entry of the top-level ``"scenarios"`` object produces one estimate:
    ``estimated_full_test_time_s`` when it is set, otherwise
    ``expected_e2e_ms / 1000 + STARTUP_OVERHEAD_SECONDS``. A missing or null
    ``expected_e2e_ms`` counts as 0.

    Args:
        baseline_path: Path to perf_baselines.json. A missing file yields an
            empty mapping.

    Returns:
        Dictionary mapping case_id to estimated time in seconds.
    """
    if not baseline_path.exists():
        return {}
    with open(baseline_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    # ``or {}`` also covers an explicit ``"scenarios": null``.
    scenarios = data.get("scenarios") or {}
    baselines: Dict[str, float] = {}
    for case_id, scenario in scenarios.items():
        if not isinstance(scenario, dict):
            # Malformed entry: leave it out so the case uses the default estimate.
            continue
        full_test_time = scenario.get("estimated_full_test_time_s")
        if full_test_time is not None:
            baselines[case_id] = float(full_test_time)
        else:
            # ``or 0`` also covers an explicit ``"expected_e2e_ms": null``.
            expected_e2e_ms = scenario.get("expected_e2e_ms") or 0
            baselines[case_id] = expected_e2e_ms / 1000.0 + STARTUP_OVERHEAD_SECONDS
    return baselines
def get_case_est_time(case_id: str, baselines: Dict[str, float]) -> float:
    """Return the baseline estimate in seconds for ``case_id``.

    A case without a baseline entry gets DEFAULT_EST_TIME_SECONDS.
    """
    if case_id in baselines:
        return baselines[case_id]
    return DEFAULT_EST_TIME_SECONDS
def parse_testcase_configs(config_path: Path) -> Dict[str, List[str]]:
    """
    Collect case IDs from a diffusion case config file using only ``ast``.

    The file is parsed, not imported, so sglang does not need to be installed.

    Returns:
        Mapping from case-list variable name to the case IDs it declares,
        e.g. {"ONE_GPU_CASES_A": ["qwen_image_t2i", ...], ...}.
    """
    with open(config_path, "r", encoding="utf-8") as handle:
        source_text = handle.read()
    collector = DiffusionTestCaseVisitor()
    collector.visit(ast.parse(source_text, filename=str(config_path)))
    return collector.cases
def parse_run_suite_standalone_data(
    run_suite_path: Path,
) -> tuple[Dict[str, List[str]], Dict[str, Dict[str, float]]]:
    """
    Read the standalone-test metadata declared in run_suite.py.

    Returns:
        A pair ``(files, est_times)`` where ``files`` maps suite -> standalone
        file list and ``est_times`` maps suite -> standalone file -> estimated
        seconds.
    """
    with open(run_suite_path, "r", encoding="utf-8") as handle:
        module_tree = ast.parse(handle.read(), filename=str(run_suite_path))
    collector = RunSuiteVisitor()
    collector.visit(module_tree)
    return collector.standalone_files, collector.standalone_est_times
def validate_standalone_est_times(
    standalone_files: Dict[str, List[str]],
    standalone_est_times: Dict[str, Dict[str, float]],
) -> Dict[str, List[str]]:
    """
    Find standalone files that have no configured time estimate.

    Args:
        standalone_files: suite -> standalone file list.
        standalone_est_times: suite -> standalone file -> estimated seconds.

    Returns:
        suite -> files lacking an estimate, in their listed order. Suites whose
        files all have estimates are omitted.
    """

    def _unestimated(suite_name: str, file_names: List[str]) -> List[str]:
        # Files of one suite that have no entry in that suite's estimate table.
        known = standalone_est_times.get(suite_name, {})
        return [name for name in file_names if name not in known]

    per_suite = (
        (suite_name, _unestimated(suite_name, file_names))
        for suite_name, file_names in standalone_files.items()
    )
    return {suite_name: missing for suite_name, missing in per_suite if missing}
def collect_diffusion_suites(
    case_config_path: Path,
    run_suite_path: Path,
    baseline_path: Path,
) -> Dict[str, DiffusionSuiteInfo]:
    """
    Collect all diffusion test suite information using AST parsing.

    Suites are created in CASE_LIST_TO_SUITE order. They are followed by any
    suite that appears only in run_suite.py's STANDALONE_FILES, so its
    standalone files are not silently dropped.

    Args:
        case_config_path: Path to case config (resolved from run_suite.py)
        run_suite_path: Path to run_suite.py
        baseline_path: Path to perf_baselines.json

    Returns:
        Dictionary mapping suite name to DiffusionSuiteInfo.
    """
    # Parse case IDs from the single source case config.
    case_lists = parse_testcase_configs(case_config_path)
    # Parse standalone files from run_suite.py
    standalone_files, standalone_est_times = parse_run_suite_standalone_data(
        run_suite_path
    )
    missing_standalone_estimates = validate_standalone_est_times(
        standalone_files, standalone_est_times
    )
    # Load baselines for time estimation
    baselines = load_baselines(baseline_path)

    def _new_suite(suite: str) -> DiffusionSuiteInfo:
        # Empty suite carrying copies of its standalone metadata.
        return DiffusionSuiteInfo(
            suite=suite,
            cases=[],
            standalone_files=list(standalone_files.get(suite, [])),
            standalone_est_times=dict(standalone_est_times.get(suite, {})),
            missing_standalone_estimates=list(
                missing_standalone_estimates.get(suite, [])
            ),
        )

    suites: Dict[str, DiffusionSuiteInfo] = {}
    for list_name, suite in CASE_LIST_TO_SUITE.items():
        if suite not in suites:
            suites[suite] = _new_suite(suite)
        suites[suite].cases.extend(
            DiffusionCaseInfo(
                case_id=cid,
                suite=suite,
                est_time=get_case_est_time(cid, baselines),
            )
            for cid in case_lists.get(list_name, [])
        )

    # Suites that only have standalone files (no parametrized case list).
    for suite in standalone_files:
        if suite not in suites:
            suites[suite] = _new_suite(suite)
    return suites