mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 17:26:00 +00:00
[CK] [CK_Tile] Add GroupConv to Kernel Dispatcher

## Motivation

This PR adds CK Tile grouped convolution (forward, backward-data, backward-weight) support to the kernel dispatcher, matching and unifying it with the existing dispatcher GEMM infrastructure in architecture and usability. The dispatcher provides a unified kernel dispatch system with both C++ and Python frontends, and until now supported only GEMM operations. This PR enables framework integrators to use the same declarative kernel workflow for convolutions as they do for GEMM: declare kernels, JIT-build a registry, select kernels from the registry at runtime, and dispatch to the GPU. Future PRs will add runtime kernel-selection heuristics for autotuning kernel parameters based on (problem, hardware arch).

## Technical Details

Grouped convolution support has been added to the CK Tile Dispatcher, with generated_conv_backend.hpp enabling dispatcher.run(in, wei, out, problem) for all 6 conv variants (fwd/bwdd/bwdw x 2D/3D), runtime heuristic kernel selection, and GroupedConvKernelKey with full ConvConfigBase fields. On the Python side, it adds parallel JIT via registry.build(max_workers) and heuristic registry.select(). The PR includes 7 C++ and 6 Python examples covering all directions with CPU reference validation, plus shared infrastructure improvements (BaseRegistry CRTP, structured exceptions). As a sanity check, JIT compile time for a single kernel remains the same, and multiple kernels compile with better parallelism:

| Kernels | 1 worker | 8 workers |
|---------|----------|-----------|
| 1       | 7.7 s    | 7.7 s     |
| 2       | 15.9 s   | 8.2 s     |
| 4       | 33.4 s   | 9.7 s     |
| 6       | 52.3 s   | 10.2 s    |

## Test Plan

145 ephemeral unit tests have been added to test basic functionality. All 30 examples/integration tests run end-to-end on gfx950 (MI350): 7 C++ conv, 7 C++ GEMM, 6 Python conv, 10 Python GEMM. CPU reference validation passes for forward, backward-data, and backward-weight (2D) in both the C++ and Python examples.

## Test Result

30 examples pass.
Peak performance: 132 TFLOPS (batch-32 forward, 56x56), 53 TFLOPS (pointwise 1x1). CPU reference accuracy: max_abs_diff < 0.002 for all directions (fp16 vs. fp32 reference).

## Submission Checklist

- [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
245 lines
8.9 KiB
Python
245 lines
8.9 KiB
Python
#!/usr/bin/env python3
|
|
|
|
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
# SPDX-License-Identifier: MIT
|
|
|
|
"""
|
|
Tests for codegen/codegen_common.py -- shared infrastructure for GEMM and grouped conv codegen.
|
|
|
|
Phase 1a TDD: these tests were written before the implementation existed.
|
|
Run: python3 -m pytest tests/test_codegen_common.py -v
|
|
"""
|
|
|
|
import sys
|
|
import unittest
|
|
from pathlib import Path
|
|
|
|
# Locate the dispatcher root (the parent of this tests directory) and put its
# codegen/ folder at the front of sys.path, so codegen_common can be imported
# without installing anything.
SCRIPT_DIR = Path(__file__).parent.resolve()
DISPATCHER_DIR = SCRIPT_DIR.parent
sys.path.insert(0, f"{DISPATCHER_DIR / 'codegen'}")
|
|
|
|
from codegen_common import ( # noqa: E402
|
|
TileConfig,
|
|
TraitConfigBase,
|
|
CommonTypeMappings,
|
|
generate_cpp_compilation_unit,
|
|
parallel_generate,
|
|
valid_wave_configs,
|
|
valid_warp_configs,
|
|
valid_trait_configs,
|
|
needs_wave_expansion,
|
|
needs_warp_expansion,
|
|
needs_pipeline_expansion,
|
|
)
|
|
|
|
|
|
class TestTileConfig(unittest.TestCase):
    """Exercise construction and validation of the TileConfig dataclass.

    Every test builds TileConfig positionally, in the order given by
    ``_FIELD_ORDER`` below.
    """

    # Positional argument order of TileConfig, as pinned by
    # test_all_fields_accessible.
    _FIELD_ORDER = (
        "tile_m",
        "tile_n",
        "tile_k",
        "warp_m",
        "warp_n",
        "warp_k",
        "warp_tile_m",
        "warp_tile_n",
        "warp_tile_k",
    )

    def test_valid_config(self):
        config = TileConfig(128, 128, 32, 2, 2, 1, 32, 32, 16)
        self.assertTrue(config.is_valid())

    def test_zero_tile_invalid(self):
        # A zero-sized tile dimension must be rejected.
        config = TileConfig(0, 128, 32, 2, 2, 1, 32, 32, 16)
        self.assertFalse(config.is_valid())

    def test_non_divisible_invalid(self):
        # 127 is prime, so it cannot be an exact multiple of the warp tiling
        # along M.
        config = TileConfig(127, 128, 32, 2, 2, 1, 32, 32, 16)
        self.assertFalse(config.is_valid())

    def test_all_fields_accessible(self):
        values = (256, 128, 64, 4, 1, 1, 32, 32, 16)
        config = TileConfig(*values)
        for field_name, expected in zip(self._FIELD_ORDER, values):
            self.assertEqual(getattr(config, field_name), expected)

    def test_small_valid_config(self):
        config = TileConfig(16, 16, 16, 1, 1, 1, 16, 16, 16)
        self.assertTrue(config.is_valid())
|
|
|
|
|
|
class TestTraitConfigBase(unittest.TestCase):
    """Validate pipeline/scheduler combinations accepted by TraitConfigBase.

    Compute pipelines (compv3, compv4) must reject the interwave scheduler.
    The mem pipeline accepts both schedulers.
    """

    @staticmethod
    def _trait(pipeline, scheduler, pad=False):
        """Build a cshuffle TraitConfigBase with the same pad flag for M, N and K."""
        return TraitConfigBase(pipeline, "cshuffle", scheduler, pad, pad, pad)

    def test_valid_intrawave(self):
        self.assertTrue(self._trait("compv3", "intrawave").is_valid())

    def test_invalid_interwave_compv3(self):
        self.assertFalse(self._trait("compv3", "interwave").is_valid())

    def test_invalid_interwave_compv4(self):
        self.assertFalse(self._trait("compv4", "interwave").is_valid())

    def test_valid_mem_interwave(self):
        self.assertTrue(self._trait("mem", "interwave").is_valid())

    def test_valid_mem_intrawave(self):
        self.assertTrue(self._trait("mem", "intrawave").is_valid())

    def test_padding_fields(self):
        trait = self._trait("compv3", "intrawave", pad=True)
        for flag_name in ("pad_m", "pad_n", "pad_k"):
            self.assertTrue(getattr(trait, flag_name))
|
|
|
|
|
|
class TestCommonTypeMappings(unittest.TestCase):
    """Check the lookup tables and helpers exposed by CommonTypeMappings."""

    def _assert_has_keys(self, table, keys):
        """Assert every key in ``keys`` is present in ``table``, in order."""
        for key in keys:
            self.assertIn(key, table)

    def test_dtype_to_ck(self):
        expected = {
            "fp16": "fp16_t",
            "bf16": "bf16_t",
            "fp32": "float",
            "fp8": "fp8_t",
        }
        table = CommonTypeMappings.DTYPE_TO_CK
        for dtype, ck_name in expected.items():
            self.assertEqual(table[dtype], ck_name)

    def test_pipeline_to_ck(self):
        table = CommonTypeMappings.PIPELINE_TO_CK
        self.assertEqual(table["mem"], "GemmPipelineAgBgCrMem")
        self._assert_has_keys(table, ("compv3", "compv4"))

    def test_pipeline_to_base(self):
        self._assert_has_keys(
            CommonTypeMappings.PIPELINE_TO_BASE, ("mem", "compv3", "compv4")
        )

    def test_scheduler_to_ck(self):
        self._assert_has_keys(
            CommonTypeMappings.SCHEDULER_TO_CK, ("intrawave", "interwave")
        )

    def test_epilogue_to_dispatcher(self):
        self._assert_has_keys(
            CommonTypeMappings.EPILOGUE_TO_DISPATCHER, ("cshuffle", "default")
        )

    def test_layout_to_ck(self):
        self._assert_has_keys(CommonTypeMappings.LAYOUT_TO_CK, ("r", "c"))

    def test_get_output_dtype(self):
        # 8-bit float inputs (fp8/bf8) are expected to produce fp16 output;
        # the other dtypes checked here map to themselves.
        expected = {
            "fp8": "fp16",
            "bf8": "fp16",
            "fp16": "fp16",
            "fp32": "fp32",
        }
        for in_dtype, out_dtype in expected.items():
            self.assertEqual(CommonTypeMappings.get_output_dtype(in_dtype), out_dtype)
|
|
|
|
|
|
class TestGenerateCppCompilationUnit(unittest.TestCase):
    """Check the C++ compilation-unit text produced by generate_cpp_compilation_unit."""

    def test_includes_kernel_header(self):
        source = generate_cpp_compilation_unit("my_kernel")
        self.assertIn('#include "my_kernel.hpp"', source)

    def test_contains_pragma_once_or_guard(self):
        # NOTE(review): despite its name, this only verifies that the kernel
        # name appears somewhere in the generated text. It does not look for
        # '#pragma once' or an include guard.
        source = generate_cpp_compilation_unit("test_kernel")
        self.assertIn("test_kernel", source)

    def test_different_names_different_output(self):
        first, second = (
            generate_cpp_compilation_unit(name) for name in ("kernel_a", "kernel_b")
        )
        self.assertNotEqual(first, second)
|
|
|
|
|
|
class TestParallelGenerate(unittest.TestCase):
    """Tests for the parallel_generate helper in parallel and sequential modes."""

    @staticmethod
    def _dummy_generate(item):
        """Deterministic stand-in generator: map ``item`` to ``generated_<item>``.

        This is a staticmethod so that the callable handed to
        parallel_generate does not carry the running TestCase instance.
        A bound method would need ``self`` to be pickled if the helper
        dispatches work to worker processes (not visible from here).
        """
        return f"generated_{item}"

    def _assert_all_generated(self, results, items):
        """Assert ``results`` has one entry per item and contains each
        expected output. Order is not checked.
        """
        self.assertEqual(len(results), len(items))
        for item in items:
            self.assertIn(f"generated_{item}", results)

    def test_parallel_returns_all(self):
        items = ["a", "b", "c", "d"]
        results = parallel_generate(self._dummy_generate, items, parallel=True)
        self._assert_all_generated(results, items)

    def test_sequential_returns_all(self):
        items = ["x", "y", "z"]
        results = parallel_generate(self._dummy_generate, items, parallel=False)
        self._assert_all_generated(results, items)

    def test_empty_items(self):
        results = parallel_generate(self._dummy_generate, [], parallel=True)
        self.assertEqual(len(results), 0)

    def test_logs_per_kernel_progress(self):
        items = ["k1", "k2"]
        # assertLogs() with no logger argument listens on the root logger.
        # This test therefore relies on parallel_generate's records
        # propagating there at INFO level or above.
        with self.assertLogs(level="INFO") as cm:
            parallel_generate(self._dummy_generate, items, parallel=False)
        log_output = "\n".join(cm.output)
        self.assertIn("k1", log_output)
        self.assertIn("k2", log_output)
|
|
|
|
|
|
class TestArchAwareExpansion(unittest.TestCase):
    """Tests for arch-aware expansion helpers (best-of-conv).

    Covers the following:

    - the per-arch wave and warp configuration lists;
    - the valid (pipeline, scheduler) trait combinations;
    - the ``needs_*_expansion`` predicates, which detect wildcard fields
      (``-1`` for numeric wave/warp fields, ``"*"`` for the pipeline).
    """

    def test_valid_wave_configs_gfx942(self):
        configs = valid_wave_configs("gfx942")
        self.assertIsInstance(configs, list)
        self.assertIn([2, 2, 1], configs)
        self.assertIn([1, 4, 1], configs)

    def test_valid_wave_configs_unknown_arch(self):
        # An unrecognised arch must still yield a non-empty list.
        configs = valid_wave_configs("gfx_unknown")
        self.assertIsInstance(configs, list)
        self.assertGreater(len(configs), 0)

    def test_valid_warp_configs_gfx942_fp16(self):
        configs = valid_warp_configs("gfx942", "fp16")
        self.assertIsInstance(configs, list)
        self.assertIn([32, 32, 16], configs)

    def test_valid_warp_configs_unknown_arch(self):
        # An unrecognised arch must still yield a non-empty list.
        configs = valid_warp_configs("gfx_unknown", "fp16")
        self.assertIsInstance(configs, list)
        self.assertGreater(len(configs), 0)

    def test_valid_trait_configs_excludes_interwave_compute(self):
        configs = valid_trait_configs()
        self.assertIsInstance(configs, list)
        # Entries are (pipeline, scheduler) pairs: the mem/interwave test
        # below unpacks them exactly that way. The previous check compared
        # against 3-tuples that included "cshuffle", which can never equal a
        # pair, so it passed vacuously. Unpacking here also fails loudly if
        # the entry shape ever changes.
        forbidden = {("compv3", "interwave"), ("compv4", "interwave")}
        offending = [(p, s) for p, s in configs if (p, s) in forbidden]
        self.assertEqual(offending, [])

    def test_valid_trait_configs_includes_mem_interwave(self):
        configs = valid_trait_configs()
        has_mem_interwave = any(p == "mem" and s == "interwave" for p, s in configs)
        self.assertTrue(has_mem_interwave)

    def test_needs_wave_expansion_wildcard(self):
        # -1 in either wave dimension marks it as a wildcard.
        self.assertTrue(needs_wave_expansion({"wave_m": -1, "wave_n": 2}))
        self.assertTrue(needs_wave_expansion({"wave_m": 2, "wave_n": -1}))

    def test_needs_wave_expansion_explicit(self):
        self.assertFalse(needs_wave_expansion({"wave_m": 2, "wave_n": 2}))

    def test_needs_warp_expansion_wildcard(self):
        self.assertTrue(needs_warp_expansion({"warp_m": -1, "warp_n": 32}))

    def test_needs_warp_expansion_explicit(self):
        self.assertFalse(needs_warp_expansion({"warp_m": 32, "warp_n": 32}))

    def test_needs_pipeline_expansion_wildcard(self):
        self.assertTrue(needs_pipeline_expansion({"pipeline": "*"}))

    def test_needs_pipeline_expansion_explicit(self):
        self.assertFalse(needs_pipeline_expansion({"pipeline": "compv4"}))
|
|
|
|
|
|
if __name__ == "__main__":
    # Running the file directly hands control to unittest's command-line
    # runner. module="__main__" is unittest.main's default, spelled out here.
    unittest.main(module="__main__")
|