Files
composable_kernel/dispatcher/tests/test_autocorrect.py
Vidyasagar Ananthan 9e049a32a1 Adding dispatcher architecture (#3300)
* WIP POC of dispatcher

* Dispatcher python workflow setup.

* Dispatcher cleanup and updates.

Further dispatcher cleanup and updates.

Build fixes

Improvements and python to CK example

Improvements to readme

* Fixes to python paths

* Cleaning up code

* Improving dispatcher support for different arch

Fixing typos

* Fix formatting errors

* Cleaning up examples

* Improving codegeneration

* Improving and fixing C++ examples

* Adding conv functionality (fwd,bwd,bwdw) and examples.

* Fixes based on feedback.

* Further fixes based on feedback.

* Adding stress test for autogeneration and autocorrection, and fixing preshuffle bug.

* Another round of improvements  based on feedback.

* Trimming out unnecessary code.

* Fixing the multi-D implementation.

* Using gpu verification for gemms and fixing convolutions tflops calculation.

* Fix counter usage issue and arch filtering per ops.

* Adding changelog and other fixes.

* Improve examples and resolve critical bugs.

* Reduce build time for python examples.

* Fixing minor bug.

* Fix compilation error.

* Improve installation instructions for dispatcher.

* Add docker based  installation instructions for dispatcher.

* Fixing arch-based filtering to match tile engine.

* Remove dead code and fix arch filtering.

* Minor bugfix.

* Updates after rebase.

* Trimming code.

* Fix copyright headers.

* Consolidate examples, cut down code.

* Minor fixes.

* Improving python examples.

* Update readmes.

* Remove conv functionality.

* Cleanup following conv removable.
2026-01-22 09:34:33 -08:00

626 lines
20 KiB
Python

#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Comprehensive Test Suite for Auto-Correction and Validation
Tests:
1. GEMM validation and wildcard expansion
2. Conv validation and wildcard expansion
3. Python KernelConfig auto-correction
4. Architecture-specific dtype support
5. Edge cases and error handling
Can be run as:
python3 tests/test_autocorrect.py # Run all tests
python3 tests/test_autocorrect.py -v # Verbose output
python3 tests/test_autocorrect.py TestGemmValidation # Run specific test class
ctest -R test_autocorrect # Via ctest
Exit codes:
0 = All tests passed
1 = Some tests failed
"""
import sys
import unittest
import random
from pathlib import Path
# Setup paths
SCRIPT_DIR = Path(__file__).parent.resolve()
DISPATCHER_DIR = SCRIPT_DIR.parent
sys.path.insert(0, str(DISPATCHER_DIR / "python"))
sys.path.insert(0, str(DISPATCHER_DIR / "codegen"))
sys.path.insert(0, str(DISPATCHER_DIR / "scripts"))
# Import modules under test
from compile_gemm_examples import ( # noqa: E402
validate_kernel_config,
expand_declaration_with_arch_filter,
is_wildcard_declaration,
)
from compile_conv_examples import ( # noqa: E402
validate_conv_kernel_config,
expand_conv_declaration_with_arch_filter,
is_conv_wildcard_declaration,
)
from ctypes_utils import auto_correct_kernel_config, KernelConfig # noqa: E402
# =============================================================================
# TEST DATA
# =============================================================================
VALID_ARCHS = ["gfx90a", "gfx942", "gfx950"]
VALID_DTYPES = ["fp16", "bf16"]
VALID_LAYOUTS = ["rcr", "rrr"]
VALID_PIPELINES = ["compv3", "compv4"]
VALID_SCHEDULERS = ["intrawave"]
# Known valid wave configs for gfx942
VALID_WAVE_CONFIGS_GFX942 = [[1, 4, 1], [2, 2, 1], [4, 1, 1]]
# Known valid warp tiles for fp16 on gfx942
VALID_WARP_TILES_FP16_GFX942 = [[16, 16, 16], [16, 16, 32], [32, 32, 8], [32, 32, 16]]
# =============================================================================
# GEMM VALIDATION TESTS
# =============================================================================
class TestGemmValidation(unittest.TestCase):
"""Test GEMM kernel validation."""
def test_valid_config(self):
"""Valid configuration should pass validation."""
config = {
"name": "test_valid",
"dtype_a": "fp16",
"dtype_b": "fp16",
"dtype_c": "fp16",
"layout": "rcr",
"tile_m": 128,
"tile_n": 128,
"tile_k": 32,
"wave_m": 2,
"wave_n": 2,
"wave_k": 1,
"warp_m": 32,
"warp_n": 32,
"warp_k": 16,
"pipeline": "compv4",
"scheduler": "intrawave",
}
is_valid, error = validate_kernel_config(config, "gfx942")
self.assertTrue(is_valid, f"Expected valid, got error: {error}")
def test_invalid_wave_config(self):
"""Invalid wave config should fail validation."""
config = {
"name": "test_invalid_wave",
"dtype_a": "fp16",
"wave_m": 3, # Invalid
"wave_n": 3, # Invalid
"wave_k": 1,
"warp_m": 32,
"warp_n": 32,
"warp_k": 16,
"pipeline": "compv4",
"scheduler": "intrawave",
}
is_valid, error = validate_kernel_config(config, "gfx942")
self.assertFalse(is_valid)
self.assertIn("wave", error.lower())
def test_invalid_scheduler(self):
"""Invalid scheduler should fail validation."""
config = {
"name": "test_invalid_scheduler",
"dtype_a": "fp16",
"wave_m": 2,
"wave_n": 2,
"wave_k": 1,
"warp_m": 32,
"warp_n": 32,
"warp_k": 16,
"pipeline": "compv4",
"epilogue": "cshuffle",
"scheduler": "interwave", # Invalid with compv4+cshuffle
}
is_valid, error = validate_kernel_config(config, "gfx942")
self.assertFalse(is_valid)
self.assertIn("trait", error.lower())
def test_wildcard_skips_validation(self):
"""Wildcard declarations should skip validation."""
config = {
"name": "test_wildcard",
"dtype_a": "fp16",
"wave_m": -1, # Wildcard
"wave_n": -1, # Wildcard
"wave_k": 1,
"warp_m": 32,
"warp_n": 32,
"warp_k": 16,
"pipeline": "compv4",
"scheduler": "intrawave",
}
self.assertTrue(is_wildcard_declaration(config))
is_valid, _ = validate_kernel_config(config, "gfx942")
self.assertTrue(is_valid)
def test_unsupported_arch(self):
"""Unsupported architecture should fail validation."""
config = {
"name": "test_bad_arch",
"dtype_a": "fp16",
"wave_m": 2,
"wave_n": 2,
"wave_k": 1,
"warp_m": 32,
"warp_n": 32,
"warp_k": 16,
"pipeline": "compv4",
"scheduler": "intrawave",
}
is_valid, error = validate_kernel_config(config, "gfx_invalid")
self.assertFalse(is_valid)
self.assertIn("unsupported", error.lower())
class TestGemmExpansion(unittest.TestCase):
"""Test GEMM wildcard expansion."""
def test_wave_expansion(self):
"""Wave wildcard should expand to valid configs."""
config = {
"name": "test_wave_expand",
"dtype_a": "fp16",
"dtype_b": "fp16",
"dtype_c": "fp16",
"layout": "rcr",
"tile_m": 128,
"tile_n": 128,
"tile_k": 32,
"wave_m": -1, # Wildcard
"wave_n": -1, # Wildcard
"wave_k": 1,
"warp_m": 32,
"warp_n": 32,
"warp_k": 16,
"pipeline": "compv4",
"scheduler": "intrawave",
}
expanded = expand_declaration_with_arch_filter(config, "gfx942")
self.assertGreater(len(expanded), 0, "Should expand to at least one config")
# All expanded configs should be valid
for exp in expanded:
is_valid, error = validate_kernel_config(exp, "gfx942")
self.assertTrue(is_valid, f"Expanded config invalid: {error}")
def test_full_wildcard_expansion(self):
"""Full wildcard should expand to multiple valid configs."""
config = {
"name": "test_full_wildcard",
"dtype_a": "fp16",
"dtype_b": "fp16",
"dtype_c": "fp16",
"layout": "rcr",
"tile_m": 128,
"tile_n": 128,
"tile_k": 32,
"wave_m": -1,
"wave_n": -1,
"wave_k": 1,
"warp_m": -1,
"warp_n": -1,
"warp_k": -1,
"pipeline": "*",
"scheduler": "*",
}
expanded = expand_declaration_with_arch_filter(config, "gfx942")
self.assertGreater(
len(expanded), 1, "Full wildcard should expand to multiple configs"
)
def test_explicit_config_not_expanded(self):
"""Explicit (non-wildcard) config should not expand."""
config = {
"name": "test_explicit",
"dtype_a": "fp16",
"dtype_b": "fp16",
"dtype_c": "fp16",
"layout": "rcr",
"tile_m": 128,
"tile_n": 128,
"tile_k": 32,
"wave_m": 2,
"wave_n": 2,
"wave_k": 1,
"warp_m": 32,
"warp_n": 32,
"warp_k": 16,
"pipeline": "compv4",
"scheduler": "intrawave",
}
expanded = expand_declaration_with_arch_filter(config, "gfx942")
self.assertEqual(len(expanded), 1, "Explicit config should not expand")
# =============================================================================
# CONV VALIDATION TESTS
# =============================================================================
class TestConvValidation(unittest.TestCase):
"""Test Conv kernel validation."""
def test_valid_conv_config(self):
"""Valid conv configuration should pass validation."""
config = {
"name": "test_valid_conv",
"dtype": "fp16",
"layout": "nhwgc",
"conv_type": "forward",
"tile_k": 128,
"tile_c": 128,
"wave_m": 2,
"wave_n": 2,
"wave_k": 1,
"warp_m": 32,
"warp_n": 32,
"warp_k": 16,
"pipeline": "compv4",
"scheduler": "intrawave",
}
is_valid, error = validate_conv_kernel_config(config, "gfx942")
self.assertTrue(is_valid, f"Expected valid, got error: {error}")
def test_invalid_conv_wave(self):
"""Invalid wave config should fail conv validation."""
config = {
"name": "test_invalid_conv_wave",
"dtype": "fp16",
"wave_m": 5, # Invalid
"wave_n": 5, # Invalid
"wave_k": 1,
"warp_m": 32,
"warp_n": 32,
"warp_k": 16,
"pipeline": "compv4",
"scheduler": "intrawave",
}
is_valid, error = validate_conv_kernel_config(config, "gfx942")
self.assertFalse(is_valid)
self.assertIn("wave", error.lower())
def test_conv_wildcard_detection(self):
"""Should correctly detect conv wildcards."""
wildcard_config = {
"wave_m": -1,
"wave_n": 2,
"warp_m": 32,
"warp_n": 32,
"pipeline": "compv4",
"scheduler": "intrawave",
}
self.assertTrue(is_conv_wildcard_declaration(wildcard_config))
explicit_config = {
"wave_m": 2,
"wave_n": 2,
"warp_m": 32,
"warp_n": 32,
"pipeline": "compv4",
"scheduler": "intrawave",
}
self.assertFalse(is_conv_wildcard_declaration(explicit_config))
class TestConvExpansion(unittest.TestCase):
"""Test Conv wildcard expansion."""
def test_conv_wave_expansion(self):
"""Conv wave wildcard should expand to valid configs."""
config = {
"name": "test_conv_wave_expand",
"dtype": "fp16",
"layout": "nhwgc",
"conv_type": "forward",
"tile_k": 128,
"tile_c": 128,
"wave_m": -1,
"wave_n": -1,
"wave_k": 1,
"warp_m": 32,
"warp_n": 32,
"warp_k": 16,
"pipeline": "compv4",
"scheduler": "intrawave",
}
expanded = expand_conv_declaration_with_arch_filter(config, "gfx942")
self.assertGreater(len(expanded), 0, "Should expand to at least one config")
# =============================================================================
# PYTHON AUTO-CORRECTION TESTS
# =============================================================================
class TestPythonAutoCorrect(unittest.TestCase):
"""Test Python KernelConfig auto-correction."""
def test_autocorrect_invalid_wave(self):
"""Auto-correction should fix invalid wave config."""
config = KernelConfig()
config.dtype_a = "fp16"
config.dtype_b = "fp16"
config.dtype_c = "fp16"
config.dtype_acc = "fp32"
config.layout_a = "row"
config.layout_b = "col"
config.layout_c = "row"
config.tile_m = 128
config.tile_n = 128
config.tile_k = 32
config.wave_m = 1 # May be invalid
config.wave_n = 1 # May be invalid
config.wave_k = 1
config.warp_m = 32
config.warp_n = 32
config.warp_k = 16
config.pipeline = "compv4"
config.scheduler = "intrawave"
config.gfx_arch = "gfx942"
corrected, was_modified, corrections = auto_correct_kernel_config(
config, verbose=False
)
# Should either be valid or corrected
self.assertIsNotNone(corrected)
if was_modified:
self.assertGreater(len(corrections), 0)
def test_autocorrect_returns_three_values(self):
"""Auto-correction should return (config, was_modified, corrections)."""
config = KernelConfig()
config.dtype_a = "fp16"
config.dtype_b = "fp16"
config.dtype_c = "fp16"
config.dtype_acc = "fp32"
config.layout_a = "row"
config.layout_b = "col"
config.layout_c = "row"
config.tile_m = 128
config.tile_n = 128
config.tile_k = 32
config.wave_m = 2
config.wave_n = 2
config.wave_k = 1
config.warp_m = 32
config.warp_n = 32
config.warp_k = 16
config.pipeline = "compv4"
config.scheduler = "intrawave"
config.gfx_arch = "gfx942"
result = auto_correct_kernel_config(config, verbose=False)
self.assertEqual(len(result), 3, "Should return 3 values")
corrected, was_modified, corrections = result
self.assertIsInstance(was_modified, bool)
self.assertIsInstance(corrections, list)
# =============================================================================
# STRESS TESTS
# =============================================================================
class TestStressRandom(unittest.TestCase):
"""Stress test with random configurations."""
def test_random_gemm_configs(self):
"""Random GEMM configs should either validate or expand successfully."""
random.seed(42) # Reproducible
dtypes = ["fp16", "bf16"]
layouts = ["rcr", "rrr"]
tiles = [(64, 64, 32), (128, 128, 32), (256, 256, 64)]
waves = [(1, 1, 1), (2, 2, 1), (1, 4, 1), (3, 3, 1)] # Some invalid
warps = [(16, 16, 16), (32, 32, 16), (48, 48, 24)] # Some invalid
pipelines = ["compv3", "compv4", "invalid"]
schedulers = ["intrawave", "interwave"]
success_count = 0
total_count = 30
for _ in range(total_count):
config = {
"name": "random_test",
"dtype_a": random.choice(dtypes),
"dtype_b": random.choice(dtypes),
"dtype_c": random.choice(dtypes),
"layout": random.choice(layouts),
"tile_m": random.choice(tiles)[0],
"tile_n": random.choice(tiles)[1],
"tile_k": random.choice(tiles)[2],
"wave_m": random.choice(waves)[0],
"wave_n": random.choice(waves)[1],
"wave_k": random.choice(waves)[2],
"warp_m": random.choice(warps)[0],
"warp_n": random.choice(warps)[1],
"warp_k": random.choice(warps)[2],
"pipeline": random.choice(pipelines),
"scheduler": random.choice(schedulers),
}
is_valid, _ = validate_kernel_config(config, "gfx942")
if is_valid:
success_count += 1
else:
# Try wildcard expansion
wildcard = config.copy()
wildcard["wave_m"] = -1
wildcard["wave_n"] = -1
wildcard["warp_m"] = -1
wildcard["warp_n"] = -1
wildcard["pipeline"] = "*"
wildcard["scheduler"] = "*"
expanded = expand_declaration_with_arch_filter(wildcard, "gfx942")
if expanded:
success_count += 1
# At least 50% should be handleable
self.assertGreater(
success_count / total_count,
0.5,
f"Only {success_count}/{total_count} configs were handleable",
)
def test_random_conv_configs(self):
"""Random Conv configs should either validate or expand successfully."""
random.seed(42)
dtypes = ["fp16", "bf16"]
tiles = [(64, 64), (128, 128), (256, 256)]
waves = [(2, 2, 1), (1, 4, 1), (3, 3, 1)]
warps = [(16, 16, 16), (32, 32, 16)]
success_count = 0
total_count = 20
for _ in range(total_count):
config = {
"name": "random_conv_test",
"dtype": random.choice(dtypes),
"layout": "nhwgc",
"conv_type": "forward",
"tile_k": random.choice(tiles)[0],
"tile_c": random.choice(tiles)[1],
"wave_m": random.choice(waves)[0],
"wave_n": random.choice(waves)[1],
"wave_k": random.choice(waves)[2],
"warp_m": random.choice(warps)[0],
"warp_n": random.choice(warps)[1],
"warp_k": random.choice(warps)[2],
"pipeline": "compv4",
"scheduler": "intrawave",
}
is_valid, _ = validate_conv_kernel_config(config, "gfx942")
if is_valid:
success_count += 1
else:
# Try wildcard expansion
wildcard = config.copy()
wildcard["wave_m"] = -1
wildcard["wave_n"] = -1
wildcard["warp_m"] = -1
wildcard["warp_n"] = -1
expanded = expand_conv_declaration_with_arch_filter(wildcard, "gfx942")
if expanded:
success_count += 1
self.assertGreater(
success_count / total_count,
0.5,
f"Only {success_count}/{total_count} conv configs were handleable",
)
# =============================================================================
# ARCHITECTURE TESTS
# =============================================================================
class TestArchitectureSupport(unittest.TestCase):
"""Test architecture-specific support."""
def test_gfx942_fp16_support(self):
"""gfx942 should support fp16."""
config = {
"dtype_a": "fp16",
"wave_m": -1,
"wave_n": -1,
"warp_m": -1,
"warp_n": -1,
"pipeline": "*",
"scheduler": "*",
}
expanded = expand_declaration_with_arch_filter(config, "gfx942")
self.assertGreater(len(expanded), 0, "gfx942 should support fp16")
def test_gfx942_bf16_support(self):
"""gfx942 should support bf16."""
config = {
"dtype_a": "bf16",
"wave_m": -1,
"wave_n": -1,
"warp_m": -1,
"warp_n": -1,
"pipeline": "*",
"scheduler": "*",
}
expanded = expand_declaration_with_arch_filter(config, "gfx942")
self.assertGreater(len(expanded), 0, "gfx942 should support bf16")
def test_gfx90a_support(self):
"""gfx90a should support fp16."""
config = {
"dtype_a": "fp16",
"wave_m": -1,
"wave_n": -1,
"warp_m": -1,
"warp_n": -1,
"pipeline": "*",
"scheduler": "*",
}
expanded = expand_declaration_with_arch_filter(config, "gfx90a")
self.assertGreater(len(expanded), 0, "gfx90a should support fp16")
# =============================================================================
# MAIN
# =============================================================================
def main():
"""Run tests."""
# Parse args for verbosity
verbosity = 2 if "-v" in sys.argv or "--verbose" in sys.argv else 1
# Create test suite
loader = unittest.TestLoader()
suite = unittest.TestSuite()
# Add all test classes
suite.addTests(loader.loadTestsFromTestCase(TestGemmValidation))
suite.addTests(loader.loadTestsFromTestCase(TestGemmExpansion))
suite.addTests(loader.loadTestsFromTestCase(TestConvValidation))
suite.addTests(loader.loadTestsFromTestCase(TestConvExpansion))
suite.addTests(loader.loadTestsFromTestCase(TestPythonAutoCorrect))
suite.addTests(loader.loadTestsFromTestCase(TestStressRandom))
suite.addTests(loader.loadTestsFromTestCase(TestArchitectureSupport))
# Run tests
runner = unittest.TextTestRunner(verbosity=verbosity)
result = runner.run(suite)
# Return exit code
return 0 if result.wasSuccessful() else 1
if __name__ == "__main__":
sys.exit(main())