Files
composable_kernel/dispatcher/tests/test_grouped_conv_utils.py
Vidyasagar Ananthan 920acd2c12 [rocm-libraries] ROCm/rocm-libraries#5168 (commit 8b5afcb)
[CK] [CK_Tile] Add GroupConv to Kernel Dispatcher

## Motivation

This PR adds CK Tile grouped convolution (forward, backward-data,
backward-weight) support to the kernel dispatcher, matching and unifying
with the existing dispatcher GEMM infrastructure in architecture and
usability. The dispatcher provides a unified kernel dispatch system with
both C++ and Python frontends, and until now only supported GEMM
operations. This PR enables framework integrators to use the same
declarative kernel workflow for convolutions as they do for GEMM:
declare kernels, build a registry JIT, select kernels within the
registry at runtime, and dispatch to GPU. Future PRs will include
runtime kernel selection heuristics for autotuning of kernel parameters
based on (problem, hardware arch).

## Technical Details

Grouped convolution support has been added to the CK Tile Dispatcher
with generated_conv_backend.hpp enabling dispatcher.run(in, wei, out,
problem) for all 6 conv variants (fwd/bwdd/bwdw x 2D/3D), runtime
heuristic kernel selection, and GroupedConvKernelKey with full
ConvConfigBase fields. Python side adds parallel JIT via
registry.build(max_workers) and heuristic registry.select(). Includes 7
C++ and 6 Python examples covering all directions with CPU reference
validation, and shared infrastructure improvements (BaseRegistry CRTP,
structured exceptions). As a sanity check, the JIT compile time for a
single kernel remains the same, and for multiple kernels there is better
parallelism:
| Kernels | 1 worker | 8 workers |
|---------|----------|-----------|
| 1 | 7.7 s | 7.7 s |
| 2 | 15.9 s | 8.2 s |
| 4 | 33.4 s | 9.7 s |
| 6 | 52.3 s | 10.2 s |

## Test Plan

145 ephemeral unit tests have been added to test basic functionality.
All 30 examples/integration tests run end-to-end on gfx950 (MI350): 7
C++ conv, 7 C++ GEMM, 6 Python conv, 10 Python GEMM. CPU reference
validation for forward, backward-data, and backward-weight (2D) passes in
both C++ and Python examples.

## Test Result

30 examples pass. Peak performance: 132 TFLOPS (Batch-32 forward 56x56),
53 TFLOPS (pointwise 1x1). CPU reference accuracy: max_abs_diff < 0.002
for all directions (fp16 vs fp32 reference).

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
2026-04-09 17:39:35 +00:00

350 lines
13 KiB
Python

#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
TDD tests for python/grouped_conv_utils.py -- grouped convolution Python utilities.
Phase 1 TDD: tests written BEFORE implementation exists.
Run: python3 -m pytest tests/test_grouped_conv_utils.py -v
"""
import sys
import unittest
from pathlib import Path
# Absolute path of the directory that contains this test file.
SCRIPT_DIR = Path(__file__).parent.resolve()
# Parent of SCRIPT_DIR; the "python" and "codegen" directories below are
# looked up relative to it.
DISPATCHER_DIR = SCRIPT_DIR.parent
# Put both directories at the front of sys.path so the project imports that
# follow can be resolved from them without installing a package.
sys.path.insert(0, str(DISPATCHER_DIR / "python"))
sys.path.insert(0, str(DISPATCHER_DIR / "codegen"))
from dispatcher_common import ValidationResultBase # noqa: E402
from grouped_conv_utils import ( # noqa: E402
GroupedConvValidationResult,
validate_grouped_conv_config,
auto_correct_grouped_conv_config,
get_grouped_conv_default_config,
GroupedConvDataType,
format_grouped_conv_summary,
)
# =============================================================================
# VALID CONFIG FIXTURES
# =============================================================================
def make_valid_grouped_conv_config() -> dict:
    """Return a valid grouped conv config dict for gfx942.

    A brand-new dict (including new nested ``tile_config`` and
    ``trait_config`` dicts) is built on every call, so a test may mutate the
    result freely without affecting other tests.

    Returns:
        dict: config with the keys ``tile_config``, ``trait_config``,
        ``variant``, ``ndim_spatial``, ``arch``, ``layout`` and ``dtype``.
    """
    tile_config = {
        "tile_k": 128,
        "tile_c": 128,
        "wave_m": 2,
        "wave_n": 2,
        "wave_k": 1,
        "warp_m": 32,
        "warp_n": 32,
        "warp_k": 16,
    }
    trait_config = {
        "pipeline": "compv4",
        "epilogue": "cshuffle",
        "scheduler": "intrawave",
    }
    return {
        "tile_config": tile_config,
        "trait_config": trait_config,
        "variant": "2d_fwd",
        "ndim_spatial": 2,
        "arch": "gfx942",
        "layout": "nhwgc",
        "dtype": "fp16",
    }
# =============================================================================
# TestGroupedConvValidationResult
# =============================================================================
class TestGroupedConvValidationResult(unittest.TestCase):
    """Tests for GroupedConvValidationResult dataclass.

    Checks the base class, the ``is_valid`` flag, and the ``errors``,
    ``warnings`` and ``suggested_fixes`` attributes together with their
    defaults when they are not passed to the constructor.
    """

    def test_inherits_from_validation_result_base(self):
        """GroupedConvValidationResult should inherit from ValidationResultBase."""
        self.assertTrue(
            issubclass(GroupedConvValidationResult, ValidationResultBase),
            "GroupedConvValidationResult must inherit from ValidationResultBase",
        )

    def test_valid_result_has_is_valid(self):
        """Valid result has is_valid=True."""
        outcome = GroupedConvValidationResult(is_valid=True)
        self.assertTrue(outcome.is_valid)

    def test_invalid_result_has_is_valid_false(self):
        """Invalid result has is_valid=False."""
        outcome = GroupedConvValidationResult(is_valid=False, errors=["bad config"])
        self.assertFalse(outcome.is_valid)

    def test_has_errors_list(self):
        """Result has errors list."""
        outcome = GroupedConvValidationResult(
            is_valid=False,
            errors=["invalid wave", "invalid trait"],
        )
        self.assertEqual(len(outcome.errors), 2)
        self.assertIn("invalid wave", outcome.errors)
        self.assertIn("invalid trait", outcome.errors)

    def test_has_warnings_list(self):
        """Result has warnings list."""
        outcome = GroupedConvValidationResult(
            is_valid=True,
            warnings=["deprecated option"],
        )
        self.assertEqual(len(outcome.warnings), 1)
        self.assertIn("deprecated option", outcome.warnings)

    def test_has_suggested_fixes_dict(self):
        """Result has suggested_fixes dict."""
        outcome = GroupedConvValidationResult(
            is_valid=False,
            suggested_fixes={"wave_m": 2, "wave_n": 2},
        )
        self.assertIn("wave_m", outcome.suggested_fixes)
        self.assertEqual(outcome.suggested_fixes["wave_m"], 2)
        self.assertIn("wave_n", outcome.suggested_fixes)
        self.assertEqual(outcome.suggested_fixes["wave_n"], 2)

    def test_default_empty_errors_warnings_fixes(self):
        """Default result has empty errors, warnings, suggested_fixes."""
        outcome = GroupedConvValidationResult(is_valid=True)
        self.assertEqual(outcome.errors, [])
        self.assertEqual(outcome.warnings, [])
        self.assertEqual(outcome.suggested_fixes, {})
# =============================================================================
# TestValidateGroupedConvConfig
# =============================================================================
class TestValidateGroupedConvConfig(unittest.TestCase):
    """Tests for validate_grouped_conv_config.

    Each test starts from the known-good fixture (or a deliberately sparse
    dict), passes it to the validator, and inspects ``is_valid`` and
    ``errors`` on the returned result.
    """

    def test_valid_config_passes(self):
        """Valid config should pass validation."""
        config = make_valid_grouped_conv_config()
        result = validate_grouped_conv_config(config)
        self.assertTrue(result.is_valid, f"Expected valid, got errors: {result.errors}")
        self.assertEqual(result.errors, [])

    def test_invalid_wave_config_fails(self):
        """Invalid wave config should fail validation."""
        config = make_valid_grouped_conv_config()
        config["tile_config"]["wave_m"] = 3
        config["tile_config"]["wave_n"] = 3
        result = validate_grouped_conv_config(config)
        self.assertFalse(result.is_valid)
        self.assertGreater(len(result.errors), 0)
        # At least one error message is expected to mention the wave settings.
        error_str = " ".join(result.errors).lower()
        self.assertIn("wave", error_str)

    def test_invalid_trait_fails(self):
        """Invalid trait combination should fail validation."""
        config = make_valid_grouped_conv_config()
        # pipeline and epilogue are restated so the offending combination is
        # visible in one place, independent of the fixture's defaults.
        config["trait_config"]["pipeline"] = "compv4"
        config["trait_config"]["epilogue"] = "cshuffle"
        config["trait_config"]["scheduler"] = "interwave"  # Invalid combo
        result = validate_grouped_conv_config(config)
        self.assertFalse(result.is_valid)
        self.assertGreater(len(result.errors), 0)
        # At least one error message is expected to mention the trait.
        error_str = " ".join(result.errors).lower()
        self.assertIn("trait", error_str)

    def test_missing_fields_fails(self):
        """Config with missing required fields should fail validation."""
        config = {"arch": "gfx942"}  # Missing tile_config, trait_config, etc.
        result = validate_grouped_conv_config(config)
        self.assertFalse(result.is_valid)
        self.assertGreater(len(result.errors), 0)
# =============================================================================
# TestAutoCorrectGroupedConvConfig
# =============================================================================
class TestAutoCorrectGroupedConvConfig(unittest.TestCase):
    """Tests for auto_correct_grouped_conv_config.

    The function under test is unpacked as a ``(corrected, result)`` pair;
    the tests check the types of both and the corrected field values.
    """

    def test_invalid_wave_gets_corrected(self):
        """Invalid wave config should be auto-corrected."""
        config = make_valid_grouped_conv_config()
        config["tile_config"]["wave_m"] = 3
        config["tile_config"]["wave_n"] = 3
        corrected, result = auto_correct_grouped_conv_config(config)
        self.assertIsInstance(corrected, dict)
        self.assertIsInstance(result, GroupedConvValidationResult)
        # Corrected wave should be valid for arch. A missing tile_config or
        # wave key yields None here, which also fails the membership checks.
        corrected_tile = corrected.get("tile_config", {})
        self.assertIn(corrected_tile.get("wave_m"), [1, 2, 4])
        self.assertIn(corrected_tile.get("wave_n"), [1, 2, 4])

    def test_invalid_trait_gets_corrected(self):
        """Invalid trait combination should be auto-corrected."""
        config = make_valid_grouped_conv_config()
        config["trait_config"]["scheduler"] = "interwave"
        config["trait_config"]["pipeline"] = "compv4"
        config["trait_config"]["epilogue"] = "cshuffle"
        corrected, result = auto_correct_grouped_conv_config(config)
        self.assertIsInstance(corrected, dict)
        self.assertIsInstance(result, GroupedConvValidationResult)
        # Scheduler should be corrected to intrawave for compv4+cshuffle
        scheduler = corrected.get("trait_config", {}).get("scheduler")
        self.assertEqual(scheduler, "intrawave")
# =============================================================================
# TestGetGroupedConvDefaultConfig
# =============================================================================
class TestGetGroupedConvDefaultConfig(unittest.TestCase):
    """Tests for get_grouped_conv_default_config.

    Every test requests the default config for the "2d_fwd" variant and
    inspects its dict form.
    """

    @staticmethod
    def _default_config_as_dict():
        """Fetch the default "2d_fwd" config and normalise it to a dict.

        Accepts both return styles: an object exposing ``to_dict()`` is
        converted through it, anything else is returned unchanged.
        """
        config = get_grouped_conv_default_config("2d_fwd")
        return config.to_dict() if hasattr(config, "to_dict") else config

    def test_returns_config(self):
        """Should return a GroupedConvKernelConfig (or dict via to_dict)."""
        self.assertIsInstance(self._default_config_as_dict(), dict)

    def test_has_tile_config(self):
        """Returned config has tile_config key."""
        d = self._default_config_as_dict()
        self.assertIn("tile_config", d)
        self.assertIsInstance(d["tile_config"], dict)

    def test_has_trait_config(self):
        """Returned config has trait_config key."""
        d = self._default_config_as_dict()
        self.assertIn("trait_config", d)
        self.assertIsInstance(d["trait_config"], dict)

    def test_has_variant(self):
        """Returned config has variant."""
        self.assertIn("variant", self._default_config_as_dict())

    def test_has_ndim_spatial(self):
        """Returned config has ndim_spatial."""
        self.assertIn("ndim_spatial", self._default_config_as_dict())

    def test_has_arch(self):
        """Returned config has arch."""
        self.assertIn("arch", self._default_config_as_dict())

    def test_has_layout(self):
        """Returned config has layout."""
        self.assertIn("layout", self._default_config_as_dict())
# =============================================================================
# TestGroupedConvDataType
# =============================================================================
class TestGroupedConvDataType(unittest.TestCase):
    """Tests for GroupedConvDataType enum.

    The existence tests rely on attribute access: a missing member raises
    AttributeError before the assertion is evaluated, which fails the test.
    """

    def test_fp16_exists(self):
        """GroupedConvDataType has FP16."""
        self.assertIsNotNone(GroupedConvDataType.FP16)

    def test_bf16_exists(self):
        """GroupedConvDataType has BF16."""
        self.assertIsNotNone(GroupedConvDataType.BF16)

    def test_fp32_exists(self):
        """GroupedConvDataType has FP32."""
        self.assertIsNotNone(GroupedConvDataType.FP32)

    def test_fp8_exists(self):
        """GroupedConvDataType has FP8."""
        self.assertIsNotNone(GroupedConvDataType.FP8)

    def test_bf8_exists(self):
        """GroupedConvDataType has BF8."""
        self.assertIsNotNone(GroupedConvDataType.BF8)

    def test_int8_exists(self):
        """GroupedConvDataType has INT8."""
        self.assertIsNotNone(GroupedConvDataType.INT8)

    def test_enum_values_unique(self):
        """All enum values should be unique."""
        members = [
            GroupedConvDataType.FP16,
            GroupedConvDataType.BF16,
            GroupedConvDataType.FP32,
            GroupedConvDataType.FP8,
            GroupedConvDataType.BF8,
            GroupedConvDataType.INT8,
        ]
        # Building a set drops equal entries, so a shorter set means two of
        # the names compare equal.
        self.assertEqual(len(members), len(set(members)))
# =============================================================================
# TestFormatGroupedConvSummary
# =============================================================================
class TestFormatGroupedConvSummary(unittest.TestCase):
    """Tests for format_grouped_conv_summary.

    Checks the return type and that the text produced for the known-good
    fixture mentions at least one identifying piece of the config.
    """

    def test_returns_non_empty_string(self):
        """Should return a non-empty string."""
        config = make_valid_grouped_conv_config()
        summary = format_grouped_conv_summary(config)
        self.assertIsInstance(summary, str)
        self.assertGreater(len(summary), 0)

    def test_contains_key_info(self):
        """Summary should contain key config info (variant, arch, layout, dtype)."""
        config = make_valid_grouped_conv_config()
        summary = format_grouped_conv_summary(config)
        # Should mention at least some of: variant, arch, layout, dtype
        summary_lower = summary.lower()
        expected_fragments = ("2d", "fwd", "gfx", "nhwgc", "fp16")
        has_key_info = any(
            fragment in summary_lower for fragment in expected_fragments
        )
        self.assertTrue(
            has_key_info,
            f"Summary should contain key info, got: {summary}",
        )

    def test_empty_config_returns_something(self):
        """Empty or minimal config should still return a string."""
        summary = format_grouped_conv_summary({})
        # Only the type is checked: an empty string is an accepted result, so
        # a length assertion of ">= 0" would be a tautology and is omitted.
        self.assertIsInstance(summary, str)
if __name__ == "__main__":
    # Run every TestCase defined in this module when the file is executed
    # directly rather than through a test runner.
    unittest.main()