Files
composable_kernel/dispatcher/tests/test_autocorrect.py
Vidyasagar Ananthan 40290297cd [CK] [CK_Tile] Add GroupConv to Kernel Dispatcher (#5168)
## Motivation

This PR adds CK Tile group convolution (forward, backward-data,
backward-weight) support to the kernel dispatcher, matching and unifying
with the existing dispatcher GEMM infrastructure in architecture and
usability. The dispatcher provides a unified kernel dispatch system with
both C++ and Python frontends, and until now only supported GEMM
operations. This PR enables framework integrators to use the same
declarative kernel workflow for convolutions as they do for GEMM:
declare kernels, build a registry JIT, select kernels within the
registry at runtime, and dispatch to GPU. Future PRs will include
runtime kernel selection heuristics for autotuning of kernel parameters
based on (problem, hardware arch).

## Technical Details

Grouped convolution support has been added to the CK Tile Dispatcher
with generated_conv_backend.hpp enabling dispatcher.run(in, wei, out,
problem) for all 6 conv variants (fwd/bwdd/bwdw x 2D/3D), runtime
heuristic kernel selection, and GroupedConvKernelKey with full
ConvConfigBase fields. Python side adds parallel JIT via
registry.build(max_workers) and heuristic registry.select(). Includes 7
C++ and 6 Python examples covering all directions with CPU reference
validation, and shared infrastructure improvements (BaseRegistry CRTP,
structured exceptions). As a sanity check, the JIT compile time for a
single kernel remains the same, and multiple kernels compile with better
parallelism:
Kernels | 1 worker | 8 workers
1 | 7.7 s | 7.7 s
2 | 15.9 s | 8.2 s
4 | 33.4 s | 9.7 s
6 | 52.3 s | 10.2 s

## Test Plan

145 ephemeral unit tests have been added to test basic functionality.
All 30 examples/integration tests run end-to-end on gfx950 (MI350): 7
C++ conv, 7 C++ GEMM, 6 Python conv, 10 Python GEMM. CPU reference
validation for forward, backward-data, and backward-weight (2D) in both
C++ and Python examples pass.

## Test Result

30 examples pass. Peak performance: 132 TFLOPS (Batch-32 forward 56x56),
53 TFLOPS (pointwise 1x1). CPU reference accuracy: max_abs_diff < 0.002
for all directions (fp16 vs fp32 reference).

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.

---------

Co-authored-by: Yaswanth Raparti <113389104+yraparti@users.noreply.github.com>
2026-04-09 10:38:33 -07:00

626 lines
20 KiB
Python

#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Comprehensive Test Suite for Auto-Correction and Validation
Tests:
1. GEMM validation and wildcard expansion
2. Conv validation and wildcard expansion
3. Python KernelConfig auto-correction
4. Architecture-specific dtype support
5. Edge cases and error handling
Can be run as:
python3 tests/test_autocorrect.py # Run all tests
python3 tests/test_autocorrect.py -v # Verbose output
python3 tests/test_autocorrect.py TestGemmValidation # Run specific test class
ctest -R test_autocorrect # Via ctest
Exit codes:
0 = All tests passed
1 = Some tests failed
"""
import sys
import unittest
import random
from pathlib import Path
# Setup paths
# SCRIPT_DIR is the resolved directory containing this file and DISPATCHER_DIR
# is its parent. Three sub-directories of DISPATCHER_DIR are pushed onto the
# front of sys.path before the imports that follow.
# NOTE(review): presumably these directories hold compile_gemm_examples,
# compile_grouped_conv_examples and ctypes_utils -- confirm which lives where.
SCRIPT_DIR = Path(__file__).parent.resolve()
DISPATCHER_DIR = SCRIPT_DIR.parent
# Every insert targets index 0, so the directory inserted last ("scripts")
# ends up first in the search order, followed by "codegen", then "python".
sys.path.insert(0, str(DISPATCHER_DIR / "python"))
sys.path.insert(0, str(DISPATCHER_DIR / "codegen"))
sys.path.insert(0, str(DISPATCHER_DIR / "scripts"))
# Import modules under test
from compile_gemm_examples import ( # noqa: E402
validate_kernel_config,
expand_declaration_with_arch_filter,
is_wildcard_declaration,
)
from compile_grouped_conv_examples import ( # noqa: E402
validate_grouped_conv_kernel_config as validate_conv_kernel_config,
expand_grouped_conv_declaration_with_arch_filter as expand_conv_declaration_with_arch_filter,
is_grouped_conv_wildcard_declaration as is_conv_wildcard_declaration,
)
from ctypes_utils import auto_correct_kernel_config, KernelConfig # noqa: E402
# =============================================================================
# TEST DATA
# =============================================================================
# Reference values describing configurations the suite considers valid.
# NOTE(review): none of these constants is referenced by the tests visible in
# this file (the tests spell their values out inline) -- confirm whether they
# are still needed or whether the tests should use them.
VALID_ARCHS = ["gfx90a", "gfx942", "gfx950"]
VALID_DTYPES = ["fp16", "bf16"]
VALID_LAYOUTS = ["rcr", "rrr"]
VALID_PIPELINES = ["compv3", "compv4"]
VALID_SCHEDULERS = ["intrawave"]
# Known valid wave configs for gfx942
# (presumably ordered [wave_m, wave_n, wave_k] -- TODO confirm)
VALID_WAVE_CONFIGS_GFX942 = [[1, 4, 1], [2, 2, 1], [4, 1, 1]]
# Known valid warp tiles for fp16 on gfx942
# (presumably ordered [warp_m, warp_n, warp_k] -- TODO confirm)
VALID_WARP_TILES_FP16_GFX942 = [[16, 16, 16], [16, 16, 32], [32, 32, 8], [32, 32, 16]]
# =============================================================================
# GEMM VALIDATION TESTS
# =============================================================================
class TestGemmValidation(unittest.TestCase):
    """Exercise validate_kernel_config() on GEMM kernel declarations."""

    # Field groups reused by the declarations below. They are always unpacked
    # into a fresh dict, so the shared fragments themselves are never handed
    # to the code under test.
    _WAVE_2X2X1 = {"wave_m": 2, "wave_n": 2, "wave_k": 1}
    _WARP_32X32X16 = {"warp_m": 32, "warp_n": 32, "warp_k": 16}
    _COMPV4_INTRAWAVE = {"pipeline": "compv4", "scheduler": "intrawave"}

    def test_valid_config(self):
        """A fully specified declaration is accepted for gfx942."""
        declaration = {
            "name": "test_valid",
            "dtype_a": "fp16",
            "dtype_b": "fp16",
            "dtype_c": "fp16",
            "layout": "rcr",
            "tile_m": 128,
            "tile_n": 128,
            "tile_k": 32,
            **self._WAVE_2X2X1,
            **self._WARP_32X32X16,
            **self._COMPV4_INTRAWAVE,
        }

        ok, message = validate_kernel_config(declaration, "gfx942")

        self.assertTrue(ok, f"Expected valid, got error: {message}")

    def test_invalid_wave_config(self):
        """A 3x3 wave layout is rejected and the error mentions 'wave'."""
        declaration = {
            "name": "test_invalid_wave",
            "dtype_a": "fp16",
            "wave_m": 3,  # expected to be rejected
            "wave_n": 3,  # expected to be rejected
            "wave_k": 1,
            **self._WARP_32X32X16,
            **self._COMPV4_INTRAWAVE,
        }

        ok, message = validate_kernel_config(declaration, "gfx942")

        self.assertFalse(ok)
        self.assertIn("wave", message.lower())

    def test_invalid_scheduler(self):
        """compv4 + cshuffle + interwave is rejected with a 'trait' error."""
        declaration = {
            "name": "test_invalid_scheduler",
            "dtype_a": "fp16",
            **self._WAVE_2X2X1,
            **self._WARP_32X32X16,
            "pipeline": "compv4",
            "epilogue": "cshuffle",
            "scheduler": "interwave",  # expected to clash with compv4+cshuffle
        }

        ok, message = validate_kernel_config(declaration, "gfx942")

        self.assertFalse(ok)
        self.assertIn("trait", message.lower())

    def test_wildcard_skips_validation(self):
        """A declaration with -1 wave fields is a wildcard and validates."""
        declaration = {
            "name": "test_wildcard",
            "dtype_a": "fp16",
            "wave_m": -1,  # wildcard
            "wave_n": -1,  # wildcard
            "wave_k": 1,
            **self._WARP_32X32X16,
            **self._COMPV4_INTRAWAVE,
        }

        self.assertTrue(is_wildcard_declaration(declaration))

        ok, _ = validate_kernel_config(declaration, "gfx942")
        self.assertTrue(ok)

    def test_unsupported_arch(self):
        """An unknown architecture string is rejected as 'unsupported'."""
        declaration = {
            "name": "test_bad_arch",
            "dtype_a": "fp16",
            **self._WAVE_2X2X1,
            **self._WARP_32X32X16,
            **self._COMPV4_INTRAWAVE,
        }

        ok, message = validate_kernel_config(declaration, "gfx_invalid")

        self.assertFalse(ok)
        self.assertIn("unsupported", message.lower())
class TestGemmExpansion(unittest.TestCase):
    """Exercise expand_declaration_with_arch_filter() for GEMM declarations."""

    @staticmethod
    def _declaration(name, **overrides):
        """Return a fresh explicit fp16 'rcr' 128x128x32 declaration.

        ``overrides`` replace the values of existing fields; because every
        overridden key is already present, the key order of the result is the
        same for every test.
        """
        declaration = {
            "name": name,
            "dtype_a": "fp16",
            "dtype_b": "fp16",
            "dtype_c": "fp16",
            "layout": "rcr",
            "tile_m": 128,
            "tile_n": 128,
            "tile_k": 32,
            "wave_m": 2,
            "wave_n": 2,
            "wave_k": 1,
            "warp_m": 32,
            "warp_n": 32,
            "warp_k": 16,
            "pipeline": "compv4",
            "scheduler": "intrawave",
        }
        declaration.update(overrides)
        return declaration

    def test_wave_expansion(self):
        """Wildcard wave_m/wave_n expand to configs that all validate."""
        wildcard = self._declaration("test_wave_expand", wave_m=-1, wave_n=-1)

        candidates = expand_declaration_with_arch_filter(wildcard, "gfx942")

        self.assertGreater(len(candidates), 0, "Should expand to at least one config")
        # Each concrete config produced by the expansion must validate.
        for candidate in candidates:
            ok, message = validate_kernel_config(candidate, "gfx942")
            self.assertTrue(ok, f"Expanded config invalid: {message}")

    def test_full_wildcard_expansion(self):
        """Wildcarding wave, warp, pipeline and scheduler yields >1 config."""
        wildcard = self._declaration(
            "test_full_wildcard",
            wave_m=-1,
            wave_n=-1,
            warp_m=-1,
            warp_n=-1,
            warp_k=-1,
            pipeline="*",
            scheduler="*",
        )

        candidates = expand_declaration_with_arch_filter(wildcard, "gfx942")

        self.assertGreater(
            len(candidates), 1, "Full wildcard should expand to multiple configs"
        )

    def test_explicit_config_not_expanded(self):
        """A declaration without wildcards comes back as exactly one config."""
        explicit = self._declaration("test_explicit")

        candidates = expand_declaration_with_arch_filter(explicit, "gfx942")

        self.assertEqual(len(candidates), 1, "Explicit config should not expand")
# =============================================================================
# CONV VALIDATION TESTS
# =============================================================================
class TestConvValidation(unittest.TestCase):
    """Exercise grouped-conv declaration validation and wildcard detection."""

    def test_valid_conv_config(self):
        """A fully specified forward-conv declaration is accepted for gfx942."""
        declaration = dict(
            name="test_valid_conv",
            dtype="fp16",
            layout="nhwgc",
            conv_type="forward",
            tile_k=128,
            tile_c=128,
            wave_m=2,
            wave_n=2,
            wave_k=1,
            warp_m=32,
            warp_n=32,
            warp_k=16,
            pipeline="compv4",
            scheduler="intrawave",
        )

        ok, message = validate_conv_kernel_config(declaration, "gfx942")

        self.assertTrue(ok, f"Expected valid, got error: {message}")

    def test_invalid_conv_wave(self):
        """A 5x5 wave layout is rejected and the error mentions 'wave'."""
        declaration = dict(
            name="test_invalid_conv_wave",
            dtype="fp16",
            wave_m=5,  # expected to be rejected
            wave_n=5,  # expected to be rejected
            wave_k=1,
            warp_m=32,
            warp_n=32,
            warp_k=16,
            pipeline="compv4",
            scheduler="intrawave",
        )

        ok, message = validate_conv_kernel_config(declaration, "gfx942")

        self.assertFalse(ok)
        self.assertIn("wave", message.lower())

    def test_conv_wildcard_detection(self):
        """wave_m == -1 marks a wildcard; an all-explicit declaration does not."""

        def declaration(wave_m):
            # Only wave_m differs between the wildcard and explicit variants.
            return {
                "wave_m": wave_m,
                "wave_n": 2,
                "warp_m": 32,
                "warp_n": 32,
                "pipeline": "compv4",
                "scheduler": "intrawave",
            }

        self.assertTrue(is_conv_wildcard_declaration(declaration(-1)))
        self.assertFalse(is_conv_wildcard_declaration(declaration(2)))
class TestConvExpansion(unittest.TestCase):
    """Exercise wildcard expansion for grouped-conv declarations."""

    def test_conv_wave_expansion(self):
        """Wildcard wave_m/wave_n expand to at least one config on gfx942."""
        wildcard = dict(
            name="test_conv_wave_expand",
            dtype="fp16",
            layout="nhwgc",
            conv_type="forward",
            tile_k=128,
            tile_c=128,
            wave_m=-1,
            wave_n=-1,
            wave_k=1,
            warp_m=32,
            warp_n=32,
            warp_k=16,
            pipeline="compv4",
            scheduler="intrawave",
        )

        candidates = expand_conv_declaration_with_arch_filter(wildcard, "gfx942")

        self.assertGreater(len(candidates), 0, "Should expand to at least one config")
# =============================================================================
# PYTHON AUTO-CORRECTION TESTS
# =============================================================================
class TestPythonAutoCorrect(unittest.TestCase):
    """Exercise auto_correct_kernel_config() on Python KernelConfig objects."""

    @staticmethod
    def _make_config(wave_m, wave_n):
        """Return a KernelConfig for gfx942 with the given wave_m/wave_n.

        Attributes are assigned one by one, in the order listed, on a
        default-constructed KernelConfig.
        """
        settings = (
            ("dtype_a", "fp16"),
            ("dtype_b", "fp16"),
            ("dtype_c", "fp16"),
            ("dtype_acc", "fp32"),
            ("layout_a", "row"),
            ("layout_b", "col"),
            ("layout_c", "row"),
            ("tile_m", 128),
            ("tile_n", 128),
            ("tile_k", 32),
            ("wave_m", wave_m),
            ("wave_n", wave_n),
            ("wave_k", 1),
            ("warp_m", 32),
            ("warp_n", 32),
            ("warp_k", 16),
            ("pipeline", "compv4"),
            ("scheduler", "intrawave"),
            ("gfx_arch", "gfx942"),
        )
        config = KernelConfig()
        for attribute, value in settings:
            setattr(config, attribute, value)
        return config

    def test_autocorrect_invalid_wave(self):
        """A 1x1 wave layout yields a config, with corrections if modified."""
        # wave_m = wave_n = 1 may or may not be valid; either outcome is fine.
        config = self._make_config(wave_m=1, wave_n=1)

        corrected, was_modified, corrections = auto_correct_kernel_config(
            config, verbose=False
        )

        # The result is either the already-valid config or a corrected one.
        self.assertIsNotNone(corrected)
        if was_modified:
            self.assertGreater(len(corrections), 0)

    def test_autocorrect_returns_three_values(self):
        """The result is a 3-item (config, was_modified, corrections)."""
        config = self._make_config(wave_m=2, wave_n=2)

        outcome = auto_correct_kernel_config(config, verbose=False)

        self.assertEqual(len(outcome), 3, "Should return 3 values")
        _, was_modified, corrections = outcome
        self.assertIsInstance(was_modified, bool)
        self.assertIsInstance(corrections, list)
# =============================================================================
# STRESS TESTS
# =============================================================================
class TestStressRandom(unittest.TestCase):
    """Stress tests that feed pseudo-random declarations to the validators.

    Each test draws from its own ``random.Random(42)`` instance rather than
    reseeding the module-level generator. The sequence of draws is the same as
    after ``random.seed(42)``, but the process-wide RNG state is left alone,
    so these tests cannot make other users of ``random`` deterministic.
    """

    def test_random_gemm_configs(self):
        """More than half of 30 random GEMM configs validate or expand."""
        rng = random.Random(42)  # private, reproducible generator

        dtypes = ["fp16", "bf16"]
        layouts = ["rcr", "rrr"]
        tiles = [(64, 64, 32), (128, 128, 32), (256, 256, 64)]
        waves = [(1, 1, 1), (2, 2, 1), (1, 4, 1), (3, 3, 1)]  # Some invalid
        warps = [(16, 16, 16), (32, 32, 16), (48, 48, 24)]  # Some invalid
        pipelines = ["compv3", "compv4", "invalid"]
        schedulers = ["intrawave", "interwave"]

        success_count = 0
        total_count = 30

        for _ in range(total_count):
            # Every field is an independent draw, so the m/n/k components of a
            # tile, wave or warp may come from different tuples.
            # NOTE(review): confirm this mixing is intended.
            config = {
                "name": "random_test",
                "dtype_a": rng.choice(dtypes),
                "dtype_b": rng.choice(dtypes),
                "dtype_c": rng.choice(dtypes),
                "layout": rng.choice(layouts),
                "tile_m": rng.choice(tiles)[0],
                "tile_n": rng.choice(tiles)[1],
                "tile_k": rng.choice(tiles)[2],
                "wave_m": rng.choice(waves)[0],
                "wave_n": rng.choice(waves)[1],
                "wave_k": rng.choice(waves)[2],
                "warp_m": rng.choice(warps)[0],
                "warp_n": rng.choice(warps)[1],
                "warp_k": rng.choice(warps)[2],
                "pipeline": rng.choice(pipelines),
                "scheduler": rng.choice(schedulers),
            }

            is_valid, _ = validate_kernel_config(config, "gfx942")
            if is_valid:
                success_count += 1
                continue

            # Fallback: wildcard the wave/warp m,n fields and the traits, and
            # count the config as handleable if expansion yields anything.
            wildcard = {
                **config,
                "wave_m": -1,
                "wave_n": -1,
                "warp_m": -1,
                "warp_n": -1,
                "pipeline": "*",
                "scheduler": "*",
            }
            if expand_declaration_with_arch_filter(wildcard, "gfx942"):
                success_count += 1

        # At least 50% should be handleable
        self.assertGreater(
            success_count / total_count,
            0.5,
            f"Only {success_count}/{total_count} configs were handleable",
        )

    def test_random_conv_configs(self):
        """More than half of 20 random conv configs validate or expand."""
        rng = random.Random(42)  # private, reproducible generator

        dtypes = ["fp16", "bf16"]
        tiles = [(64, 64), (128, 128), (256, 256)]
        waves = [(2, 2, 1), (1, 4, 1), (3, 3, 1)]
        warps = [(16, 16, 16), (32, 32, 16)]

        success_count = 0
        total_count = 20

        for _ in range(total_count):
            # As in the GEMM test, each field is drawn independently.
            config = {
                "name": "random_conv_test",
                "dtype": rng.choice(dtypes),
                "layout": "nhwgc",
                "conv_type": "forward",
                "tile_k": rng.choice(tiles)[0],
                "tile_c": rng.choice(tiles)[1],
                "wave_m": rng.choice(waves)[0],
                "wave_n": rng.choice(waves)[1],
                "wave_k": rng.choice(waves)[2],
                "warp_m": rng.choice(warps)[0],
                "warp_n": rng.choice(warps)[1],
                "warp_k": rng.choice(warps)[2],
                "pipeline": "compv4",
                "scheduler": "intrawave",
            }

            is_valid, _ = validate_conv_kernel_config(config, "gfx942")
            if is_valid:
                success_count += 1
                continue

            # Fallback: wildcard only the wave/warp m,n fields (pipeline and
            # scheduler are already fixed to valid-looking values here).
            wildcard = {
                **config,
                "wave_m": -1,
                "wave_n": -1,
                "warp_m": -1,
                "warp_n": -1,
            }
            if expand_conv_declaration_with_arch_filter(wildcard, "gfx942"):
                success_count += 1

        self.assertGreater(
            success_count / total_count,
            0.5,
            f"Only {success_count}/{total_count} conv configs were handleable",
        )
# =============================================================================
# ARCHITECTURE TESTS
# =============================================================================
class TestArchitectureSupport(unittest.TestCase):
    """Check that wildcard GEMM declarations expand on specific architectures."""

    def _assert_wildcard_expands(self, dtype_a, arch, failure_message):
        """Assert a fully wildcarded declaration of ``dtype_a`` expands on ``arch``."""
        wildcard = {
            "dtype_a": dtype_a,
            "wave_m": -1,
            "wave_n": -1,
            "warp_m": -1,
            "warp_n": -1,
            "pipeline": "*",
            "scheduler": "*",
        }
        candidates = expand_declaration_with_arch_filter(wildcard, arch)
        self.assertGreater(len(candidates), 0, failure_message)

    def test_gfx942_fp16_support(self):
        """fp16 wildcard expansion on gfx942 yields at least one config."""
        self._assert_wildcard_expands("fp16", "gfx942", "gfx942 should support fp16")

    def test_gfx942_bf16_support(self):
        """bf16 wildcard expansion on gfx942 yields at least one config."""
        self._assert_wildcard_expands("bf16", "gfx942", "gfx942 should support bf16")

    def test_gfx90a_support(self):
        """fp16 wildcard expansion on gfx90a yields at least one config."""
        self._assert_wildcard_expands("fp16", "gfx90a", "gfx90a should support fp16")
# =============================================================================
# MAIN
# =============================================================================
def main(argv=None):
    """Run the test suite and return a process exit code.

    The module docstring advertises running a single test class from the
    command line; this function honours that by treating every non-option
    argument as a test name.

    Args:
        argv: Command-line arguments without the program name. Defaults to
            ``sys.argv[1:]``. ``-v`` / ``--verbose`` select verbose output.
            Any argument that does not start with ``-`` is a test name
            resolved inside this module, e.g. ``TestGemmValidation`` or
            ``TestGemmValidation.test_valid_config``. With no test names,
            every test class is run.

    Returns:
        0 if every selected test passed, 1 otherwise. A name that cannot be
        resolved is reported by unittest as an error, which also yields 1.
    """
    args = sys.argv[1:] if argv is None else list(argv)

    verbosity = 2 if ("-v" in args or "--verbose" in args) else 1
    selected = [arg for arg in args if not arg.startswith("-")]

    loader = unittest.TestLoader()
    suite = unittest.TestSuite()

    if selected:
        # Resolve names against this module; the loader turns an unknown name
        # into a synthetic failing test instead of raising here.
        suite.addTest(loader.loadTestsFromNames(selected, module=sys.modules[__name__]))
    else:
        for case in (
            TestGemmValidation,
            TestGemmExpansion,
            TestConvValidation,
            TestConvExpansion,
            TestPythonAutoCorrect,
            TestStressRandom,
            TestArchitectureSupport,
        ):
            suite.addTests(loader.loadTestsFromTestCase(case))

    runner = unittest.TextTestRunner(verbosity=verbosity)
    result = runner.run(suite)

    return 0 if result.wasSuccessful() else 1
# Script entry point: exit with the status code returned by main().
if __name__ == "__main__":
    raise SystemExit(main())