Files
composable_kernel/dispatcher/scripts/stress_test_autocorrect.py
Vidyasagar Ananthan 9e049a32a1 Adding dispatcher architecture (#3300)
* WIP POC of dispatcher

* Dispatcher python workflow setup.

* Dispatcher cleanup and updates.

Further dispatcher cleanup and updates.

Build fixes

Improvements and python to CK example

Improvements to readme

* Fixes to python paths

* Cleaning up code

* Improving dispatcher support for different arch

Fixing typos

* Fix formatting errors

* Cleaning up examples

* Improving codegeneration

* Improving and fixing C++ examples

* Adding conv functionality (fwd,bwd,bwdw) and examples.

* Fixes based on feedback.

* Further fixes based on feedback.

* Adding stress test for autogeneration and autocorrection, and fixing preshuffle bug.

* Another round of improvements  based on feedback.

* Trimming out unnecessary code.

* Fixing the multi-D implementation.

* Using gpu verification for gemms and fixing convolutions tflops calculation.

* Fix counter usage issue and arch filtering per ops.

* Adding changelog and other fixes.

* Improve examples and resolve critical bugs.

* Reduce build time for python examples.

* Fixing minor bug.

* Fix compilation error.

* Improve installation instructions for dispatcher.

* Add docker based  installation instructions for dispatcher.

* Fixing arch-based filtering to match tile engine.

* Remove dead code and fix arch filtering.

* Minor bugfix.

* Updates after rebase.

* Trimming code.

* Fix copyright headers.

* Consolidate examples, cut down code.

* Minor fixes.

* Improving python examples.

* Update readmes.

* Remove conv functionality.

* Cleanup following conv removable.
2026-01-22 09:34:33 -08:00

541 lines
16 KiB
Python

#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Stress Test for Auto-Correction and Codegen
This script tests the robustness of:
1. GEMM auto-correction (Python)
2. Conv auto-correction (Python)
3. C++ kernel declaration validation and wildcard expansion
4. Architecture filtering
Usage:
python3 scripts/stress_test_autocorrect.py [--arch gfx942] [--samples 50] [--verbose]
"""
import argparse
import random
import sys
from pathlib import Path
# Add paths for imports
dispatcher_root = Path(__file__).parent.parent
sys.path.insert(0, str(dispatcher_root / "python"))
sys.path.insert(0, str(dispatcher_root / "codegen"))
sys.path.insert(0, str(dispatcher_root / "scripts"))
from ctypes_utils import auto_correct_kernel_config, KernelConfig # noqa: E402
# Import validation/expansion functions from compile scripts
from compile_gemm_examples import ( # noqa: E402
validate_kernel_config,
expand_declaration_with_arch_filter,
)
from compile_conv_examples import ( # noqa: E402
validate_conv_kernel_config,
expand_conv_declaration_with_arch_filter,
)
# =============================================================================
# TEST PARAMETERS
# =============================================================================
# Valid dtypes
DTYPES = ["fp16", "bf16", "fp32", "fp8", "bf8", "int8"]
# Valid layouts
LAYOUTS = ["rcr", "rrr", "crr", "ccr"]
# Tile sizes (some valid, some invalid)
TILE_SIZES = [
(32, 32, 16),
(64, 64, 32),
(128, 128, 32),
(256, 256, 64),
(128, 256, 32),
(256, 128, 32),
# Invalid sizes to test auto-correction
(100, 100, 50),
(17, 17, 17),
(512, 512, 128),
]
# Wave configs (some valid, some invalid)
WAVE_CONFIGS = [
(1, 1, 1),
(1, 2, 1),
(2, 1, 1),
(2, 2, 1),
(1, 4, 1),
(4, 1, 1),
(2, 4, 1),
(4, 2, 1),
# Invalid configs to test auto-correction
(3, 3, 1),
(5, 5, 1),
(1, 1, 2),
]
# Warp tile sizes (some valid, some invalid)
WARP_TILES = [
(16, 16, 16),
(16, 16, 32),
(32, 32, 8),
(32, 32, 16),
# Invalid tiles to test auto-correction
(48, 48, 24),
(64, 64, 32),
]
# Pipelines and schedulers
PIPELINES = ["compv3", "compv4", "flatmma", "invalid_pipeline"]
SCHEDULERS = ["intrawave", "interwave", "invalid_scheduler"]
# Architectures
ARCHS = ["gfx90a", "gfx942", "gfx950", "gfx1100", "gfx1200", "gfx1201"]
# =============================================================================
# TEST FUNCTIONS
# =============================================================================
def generate_random_gemm_config():
"""Generate a random GEMM configuration (may be invalid)."""
dtype = random.choice(DTYPES)
layout = random.choice(LAYOUTS)
tile = random.choice(TILE_SIZES)
wave = random.choice(WAVE_CONFIGS)
warp = random.choice(WARP_TILES)
pipeline = random.choice(PIPELINES)
scheduler = random.choice(SCHEDULERS)
arch = random.choice(ARCHS)
return {
"name": f"test_{dtype}_{layout}_{tile[0]}x{tile[1]}x{tile[2]}",
"dtype_a": dtype,
"dtype_b": dtype,
"dtype_c": dtype,
"dtype_acc": "fp32",
"layout": layout,
"tile_m": tile[0],
"tile_n": tile[1],
"tile_k": tile[2],
"wave_m": wave[0],
"wave_n": wave[1],
"wave_k": wave[2],
"warp_m": warp[0],
"warp_n": warp[1],
"warp_k": warp[2],
"pipeline": pipeline,
"scheduler": scheduler,
"arch": arch,
}
def generate_random_conv_config():
"""Generate a random Conv configuration (may be invalid)."""
dtype = random.choice(["fp16", "bf16"])
tile_k = random.choice([64, 128, 256])
tile_c = random.choice([64, 128, 256])
wave = random.choice(WAVE_CONFIGS)
warp = random.choice(WARP_TILES)
pipeline = random.choice(["compv3", "compv4"])
scheduler = random.choice(["intrawave"])
arch = random.choice(ARCHS)
return {
"name": f"test_conv_{dtype}_{tile_k}x{tile_c}",
"dtype": dtype,
"layout": "nhwgc",
"conv_type": "forward",
"tile_k": tile_k,
"tile_c": tile_c,
"wave_m": wave[0],
"wave_n": wave[1],
"wave_k": wave[2],
"warp_m": warp[0],
"warp_n": warp[1],
"warp_k": warp[2],
"pipeline": pipeline,
"scheduler": scheduler,
"arch": arch,
}
def test_gemm_validation(config, verbose=False):
"""Test GEMM validation and auto-correction."""
arch = config.get("arch", "gfx942")
is_valid, error_msg = validate_kernel_config(config, arch)
result = {
"config": config,
"is_valid": is_valid,
"error_msg": error_msg,
"expanded": [],
"auto_corrected": None,
}
if not is_valid:
# Try wildcard expansion
wildcard_config = config.copy()
wildcard_config["wave_m"] = -1
wildcard_config["wave_n"] = -1
wildcard_config["warp_m"] = -1
wildcard_config["warp_n"] = -1
wildcard_config["pipeline"] = "*"
wildcard_config["scheduler"] = "*"
expanded = expand_declaration_with_arch_filter(wildcard_config, arch)
result["expanded"] = expanded
if verbose:
print(f"\n Config: {config['name']}")
print(f" Valid: {is_valid}")
if not is_valid:
print(f" Error: {error_msg[:80]}...")
print(f" Expanded to: {len(result['expanded'])} configurations")
return result
def test_python_autocorrect(verbose=False):
"""Test Python auto-correction for GEMM KernelConfig."""
print("\n" + "=" * 70)
print(" PYTHON AUTO-CORRECTION TEST (GEMM KernelConfig)")
print("=" * 70)
test_cases = [
# Valid config
{
"name": "valid_fp16",
"dtype_a": "fp16",
"dtype_b": "fp16",
"dtype_c": "fp16",
"dtype_acc": "fp32",
"layout": "rcr",
"tile_m": 128,
"tile_n": 128,
"tile_k": 32,
"wave_m": 2,
"wave_n": 2,
"wave_k": 1,
"warp_m": 32,
"warp_n": 32,
"warp_k": 16,
"pipeline": "compv4",
"scheduler": "intrawave",
"gfx_arch": "gfx942",
},
# Invalid wave config
{
"name": "invalid_wave",
"dtype_a": "fp16",
"dtype_b": "fp16",
"dtype_c": "fp16",
"dtype_acc": "fp32",
"layout": "rcr",
"tile_m": 128,
"tile_n": 128,
"tile_k": 32,
"wave_m": 1,
"wave_n": 1,
"wave_k": 1, # Invalid for gfx942
"warp_m": 32,
"warp_n": 32,
"warp_k": 16,
"pipeline": "compv4",
"scheduler": "intrawave",
"gfx_arch": "gfx942",
},
# Invalid scheduler
{
"name": "invalid_scheduler",
"dtype_a": "fp16",
"dtype_b": "fp16",
"dtype_c": "fp16",
"dtype_acc": "fp32",
"layout": "rcr",
"tile_m": 128,
"tile_n": 128,
"tile_k": 32,
"wave_m": 2,
"wave_n": 2,
"wave_k": 1,
"warp_m": 32,
"warp_n": 32,
"warp_k": 16,
"pipeline": "compv4",
"scheduler": "interwave", # May not be valid for all archs
"gfx_arch": "gfx942",
},
]
results = {"passed": 0, "failed": 0, "details": []}
for tc in test_cases:
try:
config = KernelConfig()
config.dtype_a = tc["dtype_a"]
config.dtype_b = tc["dtype_b"]
config.dtype_c = tc["dtype_c"]
config.dtype_acc = tc["dtype_acc"]
config.tile_m = tc["tile_m"]
config.tile_n = tc["tile_n"]
config.tile_k = tc["tile_k"]
config.wave_m = tc["wave_m"]
config.wave_n = tc["wave_n"]
config.wave_k = tc["wave_k"]
config.warp_m = tc["warp_m"]
config.warp_n = tc["warp_n"]
config.warp_k = tc["warp_k"]
config.pipeline = tc["pipeline"]
config.scheduler = tc["scheduler"]
config.gfx_arch = tc["gfx_arch"]
corrected, was_modified, corrections = auto_correct_kernel_config(
config, verbose=verbose
)
results["passed"] += 1
results["details"].append(
{
"name": tc["name"],
"status": "PASS",
"was_modified": was_modified,
"corrections": corrections,
}
)
if verbose:
print(f"\n {tc['name']}: PASS")
if was_modified:
print(f" Modified: {len(corrections)} correction(s)")
for c in corrections:
print(f"{c}")
except Exception as e:
results["failed"] += 1
results["details"].append(
{"name": tc["name"], "status": "FAIL", "error": str(e)}
)
if verbose:
print(f"\n {tc['name']}: FAIL - {e}")
print(f"\n Summary: {results['passed']} passed, {results['failed']} failed")
return results
def run_stress_test(arch, num_samples, verbose):
"""Run the full stress test."""
print("\n" + "=" * 70)
print(" DISPATCHER AUTO-CORRECTION & CODEGEN STRESS TEST")
print("=" * 70)
print(f" Target Architecture: {arch}")
print(f" Number of Samples: {num_samples}")
print("=" * 70)
# Test 1: GEMM Validation
print("\n" + "-" * 70)
print(" TEST 1: GEMM Validation & Wildcard Expansion")
print("-" * 70)
gemm_results = {"valid": 0, "invalid": 0, "expanded": 0, "expansion_failed": 0}
for i in range(num_samples):
config = generate_random_gemm_config()
config["arch"] = arch # Override with target arch
result = test_gemm_validation(config, verbose)
if result["is_valid"]:
gemm_results["valid"] += 1
else:
gemm_results["invalid"] += 1
if result["expanded"]:
gemm_results["expanded"] += 1
else:
gemm_results["expansion_failed"] += 1
print("\n GEMM Results:")
print(f" Valid configs: {gemm_results['valid']}")
print(f" Invalid configs: {gemm_results['invalid']}")
print(f" Successfully expanded: {gemm_results['expanded']}")
print(f" Expansion failed: {gemm_results['expansion_failed']}")
# Test 2: Conv Validation
print("\n" + "-" * 70)
print(" TEST 2: Conv Validation & Wildcard Expansion")
print("-" * 70)
conv_results = {"valid": 0, "invalid": 0, "expanded": 0, "expansion_failed": 0}
for i in range(num_samples):
config = generate_random_conv_config()
config["arch"] = arch # Override with target arch
is_valid, error_msg = validate_conv_kernel_config(config, arch)
if is_valid:
conv_results["valid"] += 1
else:
conv_results["invalid"] += 1
# Try wildcard expansion
wildcard_config = config.copy()
wildcard_config["wave_m"] = -1
wildcard_config["wave_n"] = -1
wildcard_config["warp_m"] = -1
wildcard_config["warp_n"] = -1
expanded = expand_conv_declaration_with_arch_filter(wildcard_config, arch)
if expanded:
conv_results["expanded"] += 1
else:
conv_results["expansion_failed"] += 1
print("\n Conv Results:")
print(f" Valid configs: {conv_results['valid']}")
print(f" Invalid configs: {conv_results['invalid']}")
print(f" Successfully expanded: {conv_results['expanded']}")
print(f" Expansion failed: {conv_results['expansion_failed']}")
# Test 3: Python Auto-Correction
print("\n" + "-" * 70)
print(" TEST 3: Python Auto-Correction (KernelConfig)")
print("-" * 70)
py_results = test_python_autocorrect(verbose)
# Test 4: Architecture-specific tests
print("\n" + "-" * 70)
print(" TEST 4: Architecture-Specific Validation")
print("-" * 70)
arch_test_configs = [
# fp16 should work on all archs
{"dtype": "fp16", "expected_archs": ARCHS},
# bf16 works on all archs that have bf16_bf16_fp32 in warp_tile_combos
{
"dtype": "bf16",
"expected_archs": [
"gfx908",
"gfx90a",
"gfx942",
"gfx950",
"gfx1100",
"gfx1200",
"gfx1201",
],
},
# fp8 works on archs that have fp8_fp8_fp32 in warp_tile_combos
{
"dtype": "fp8",
"expected_archs": ["gfx90a", "gfx942", "gfx950", "gfx1200", "gfx1201"],
},
]
for test in arch_test_configs:
dtype = test["dtype"]
print(f"\n Testing {dtype}:")
for test_arch in ARCHS:
config = {
"name": f"arch_test_{dtype}_{test_arch}",
"dtype_a": dtype,
"dtype_b": dtype,
"dtype_c": dtype,
"dtype_acc": "fp32",
"layout": "rcr",
"tile_m": 128,
"tile_n": 128,
"tile_k": 32,
"wave_m": -1, # Wildcard
"wave_n": -1,
"wave_k": 1,
"warp_m": -1,
"warp_n": -1,
"warp_k": -1,
"pipeline": "*",
"scheduler": "*",
"arch": test_arch,
}
expanded = expand_declaration_with_arch_filter(config, test_arch)
status = "" if expanded else ""
expected = test_arch in test["expected_archs"]
match = "OK" if (bool(expanded) == expected) else "MISMATCH"
if verbose or match == "MISMATCH":
print(f" {test_arch}: {status} ({len(expanded)} configs) [{match}]")
# Summary
print("\n" + "=" * 70)
print(" STRESS TEST SUMMARY")
print("=" * 70)
print(
f" GEMM: {gemm_results['valid'] + gemm_results['expanded']}/{num_samples} handled"
)
print(
f" Conv: {conv_results['valid'] + conv_results['expanded']}/{num_samples} handled"
)
print(
f" Python Auto-Correct: {py_results['passed']}/{py_results['passed'] + py_results['failed']} passed"
)
total_success = (
gemm_results["valid"]
+ gemm_results["expanded"]
+ conv_results["valid"]
+ conv_results["expanded"]
+ py_results["passed"]
)
total_tests = num_samples * 2 + py_results["passed"] + py_results["failed"]
print(f"\n Overall: {total_success}/{total_tests} tests handled successfully")
print("=" * 70)
return (
gemm_results["expansion_failed"] == 0 and conv_results["expansion_failed"] == 0
)
def main():
parser = argparse.ArgumentParser(
description="Stress test auto-correction and codegen"
)
parser.add_argument(
"--arch",
default="gfx942",
choices=ARCHS,
help="Target GPU architecture (default: gfx942)",
)
parser.add_argument(
"--samples",
type=int,
default=50,
help="Number of random samples to test (default: 50)",
)
parser.add_argument(
"--verbose", "-v", action="store_true", help="Show detailed output"
)
parser.add_argument(
"--seed", type=int, default=None, help="Random seed for reproducibility"
)
args = parser.parse_args()
if args.seed is not None:
random.seed(args.seed)
success = run_stress_test(args.arch, args.samples, args.verbose)
return 0 if success else 1
if __name__ == "__main__":
sys.exit(main())