GEMM C++ Examples
CK Tile Dispatcher C++ examples for GEMM (General Matrix Multiplication) operations.
Main Documentation: Dispatcher README | Examples Overview
Quick Start
Build and Run
cd /path/to/composable_kernel/dispatcher
mkdir -p build && cd build
cmake .. \
-DCMAKE_PREFIX_PATH=/opt/rocm \
-DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
-DBUILD_DISPATCHER_EXAMPLES=ON
# Build (kernels generated automatically by CMake)
make -j$(nproc)
# Run examples
cd examples
./gemm_01_basic
./gemm_03_benchmark_validation
./gemm_04_heuristics
Examples
| Example | Description |
|---|---|
| 01_basic_gemm.cpp | Basic GEMM with declarative API, autofill, autocorrect |
| 02_multi_size.cpp | Wildcard expansion for multiple configurations |
| 03_benchmark_validation.cpp | Performance benchmarking with CPU reference validation |
| 04_heuristics.cpp | Heuristic-based kernel selection |
| 05_json_export.cpp | Registry JSON export for external tools |
| 06_multi_registry.cpp | Multiple registries with named kernel sets |
Example Details
01_basic_gemm.cpp - Basic GEMM
Demonstrates the declarative kernel API with three patterns:
- Autofill Pattern - Minimal specification, defaults filled automatically
- Autocorrect Pattern - Invalid parameters corrected at build time
- Full Specification Pattern - Complete kernel configuration
DECL_KERNEL_SET(basic_kernels,
// Pattern 1: Autofill - minimal specification
.add(
Signature().dtype("fp16").layout("rcr"),
Algorithm(), // Defaults filled by autofill
"gfx942"
)
// Pattern 3: Full specification (the autocorrect pattern is omitted from this excerpt)
.add(
Signature().dtype("fp16").layout("rcr"),
Algorithm().tile(256, 256, 32).wave(2, 2, 1).warp(32, 32, 16)
.pipeline("compv4").scheduler("intrawave"),
"gfx942"
)
);
Features:
- Uses generic REGISTER_GENERATED_KERNELS macro
- print_registered_kernels() utility for debugging
- Demonstrates autofill messages during build
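The autofill and autocorrect ideas can be illustrated with a small stand-in. Everything in this sketch is hypothetical: `TileSpec`, `ResolvedTile`, `autofill`, `autocorrect`, the default values, and the round-up rule are assumptions chosen for illustration, not the dispatcher's actual types or correction rules.

```cpp
#include <cassert>
#include <optional>

// Hypothetical stand-in for a partially specified tile; not a dispatcher type.
struct TileSpec {
    std::optional<int> m, n, k; // unset fields are left for autofill
};

struct ResolvedTile {
    int m, n, k;
};

// Autofill: replace every unset field with an assumed default.
inline ResolvedTile autofill(const TileSpec& s) {
    return {s.m.value_or(256), s.n.value_or(256), s.k.value_or(32)};
}

// Round v up to the next multiple of `multiple`.
inline int round_up(int v, int multiple) {
    assert(multiple > 0);
    return ((v + multiple - 1) / multiple) * multiple;
}

// Autocorrect (assumed rule): make each tile dimension a multiple of the
// corresponding warp tile dimension.
inline ResolvedTile autocorrect(ResolvedTile t, int warp_m, int warp_n, int warp_k) {
    return {round_up(t.m, warp_m), round_up(t.n, warp_n), round_up(t.k, warp_k)};
}
```

With these assumptions, an empty `TileSpec` resolves to 256 x 256 x 32, and a tile of 100 x 250 x 20 is corrected to 128 x 256 x 32 for a 32 x 32 x 16 warp tile.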
02_multi_size.cpp - Wildcard Expansion
Demonstrates automatic generation of multiple kernel configurations:
DECL_KERNEL_SET(multi_kernels,
.add(
Signature().dtype("fp16").layout("rcr"),
Algorithm().tile(ANY_INT, ANY_INT, 32) // Wildcard tile M and N (ANY_INT is the C++ spelling of *)
.wave(2, 2, 1)
.warp(32, 32, 16)
.pipeline("compv4")
.scheduler("intrawave"),
"gfx942"
)
);
Wildcard Values:
- *, -1, or ANY_INT expand to all valid configurations
- Architecture filter prunes invalid combinations automatically
- Example generates 5 valid kernels after arch filtering (from 7 expansions)
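Wildcard expansion followed by architecture filtering amounts to a filtered Cartesian product. The sketch below is a stand-in: `Tile`, `kAnyInt`, `expand`, the candidate list, and the filter predicate are all hypothetical, and it does not reproduce the 7-to-5 pruning reported above.

```cpp
#include <cassert>
#include <functional>
#include <vector>

// Hypothetical tile triple; not a dispatcher type.
struct Tile {
    int m, n, k;
};

// Hypothetical wildcard marker, mirroring the -1 wildcard value.
constexpr int kAnyInt = -1;

// Expand every wildcard dimension over `candidates`, keep fixed dimensions
// as given, and drop combinations rejected by the architecture predicate.
inline std::vector<Tile> expand(Tile pattern,
                                const std::vector<int>& candidates,
                                const std::function<bool(const Tile&)>& arch_ok) {
    auto options = [&](int v) -> std::vector<int> {
        if (v == kAnyInt) return candidates;
        return {v};
    };
    std::vector<Tile> out;
    for (int m : options(pattern.m))
        for (int n : options(pattern.n))
            for (int k : options(pattern.k)) {
                Tile t{m, n, k};
                if (arch_ok(t)) out.push_back(t);
            }
    return out;
}
```

For example, expanding `{kAnyInt, kAnyInt, 32}` over candidates `{64, 128, 256}` yields nine tiles before filtering; a predicate that rejects tiles with `m + n > 384` removes only the 256 x 256 tile.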
03_benchmark_validation.cpp - Benchmark + Validation
Consolidated example combining performance benchmarking with correctness validation:
# Benchmark only
./gemm_03_benchmark_validation --warmup 10 --iterations 100
# With CPU validation
./gemm_03_benchmark_validation --verify 1 --rtol 1e-3 --atol 1e-3
# With GPU reference validation (faster for large matrices)
./gemm_03_benchmark_validation --verify 2
Features:
- Warmup iterations (discarded from timing)
- Benchmark iterations with statistics (min/max/mean/median)
- CPU reference validation using ck_tile::reference_gemm
- GPU reference validation using ck_tile::reference_gemm_gpu
- Configurable tolerances
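The --rtol and --atol options suggest a combined relative/absolute tolerance check. A minimal sketch of one common formulation follows, assuming the test `|out - ref| <= atol + rtol * |ref|`; whether the example applies exactly this formula is an assumption, and `all_close` is a hypothetical helper name.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Returns true when every element of `out` matches `ref` within
// atol + rtol * |ref|. NaN in either input fails the comparison.
inline bool all_close(const std::vector<float>& out,
                      const std::vector<float>& ref,
                      float rtol, float atol) {
    if (out.size() != ref.size()) return false;
    for (std::size_t i = 0; i < out.size(); ++i) {
        const float err = std::fabs(out[i] - ref[i]);
        // Written as !(err <= bound) so that a NaN error is rejected.
        if (!(err <= atol + rtol * std::fabs(ref[i]))) return false;
    }
    return true;
}
```

The relative term scales the allowed error with the magnitude of the reference value, while the absolute term keeps the check meaningful for reference values near zero.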
04_heuristics.cpp - Heuristic Selection
Demonstrates custom kernel selection based on problem characteristics:
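// Problem size analysis. small_kernel_key and large_kernel_key are KernelKey
// values assumed to be defined earlier in the example; [=] copies them if they are locals.
auto heuristic = [=](const Problem& p) -> std::optional<KernelKey> {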
if (p.M() * p.N() < 256 * 256) {
return small_kernel_key; // Memory-bound heuristic
} else {
return large_kernel_key; // Compute-bound heuristic
}
};
dispatcher.set_heuristic(heuristic);
Features:
- Problem size analysis (small vs large matrices)
- Compute-bound vs memory-bound selection
- Custom heuristic function registration
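The size-based selection above can be tried without the dispatcher by using stand-in types. In this sketch `Problem`, `KernelKey`, `Heuristic`, and `make_size_heuristic` are hypothetical placeholders (the real `Problem` exposes `M()` and `N()` accessors, and the real key type is not shown here); only the threshold logic mirrors the example.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <optional>
#include <string>

// Hypothetical stand-ins for the dispatcher's problem and key types.
struct Problem {
    std::int64_t m, n, k;
};
using KernelKey = std::string;
using Heuristic = std::function<std::optional<KernelKey>(const Problem&)>;

// Builds a heuristic that picks `small_key` when M*N is below `threshold`
// and `large_key` otherwise. The keys are captured by value, so the
// returned callable stays valid after this function returns.
inline Heuristic make_size_heuristic(KernelKey small_key, KernelKey large_key,
                                     std::int64_t threshold = 256 * 256) {
    return [=](const Problem& p) -> std::optional<KernelKey> {
        // Degenerate problems get no kernel; the caller can fall back.
        if (p.m <= 0 || p.n <= 0 || p.k <= 0) return std::nullopt;
        return (p.m * p.n < threshold) ? small_key : large_key;
    };
}
```

Returning `std::nullopt` for problems the heuristic cannot classify leaves room for a fallback selection path, which is one reason to return an optional rather than a bare key.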
05_json_export.cpp - JSON Export
Exports registry information to JSON for external tool integration:
auto json = registry.to_json();
std::ofstream file("kernels.json");
file << json;
Use Cases:
- Kernel metadata serialization
- External analysis tools
- Configuration management
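To make the serialization use case concrete, here is a sketch of turning kernel metadata into a JSON array. The `KernelInfo` struct, its two fields, and the output schema are hypothetical; the actual layout produced by registry.to_json() is not documented in this README.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical metadata record; the real registry schema may differ.
struct KernelInfo {
    std::string name;
    std::string arch;
};

// Escapes double quotes and backslashes only; control characters are
// not handled in this sketch.
inline std::string json_escape(const std::string& s) {
    std::string out;
    for (char c : s) {
        if (c == '"' || c == '\\') out += '\\';
        out += c;
    }
    return out;
}

// Serializes the records as a compact JSON array of objects.
inline std::string to_json(const std::vector<KernelInfo>& kernels) {
    std::string out = "[";
    for (std::size_t i = 0; i < kernels.size(); ++i) {
        if (i > 0) out += ",";
        out += "{\"name\":\"" + json_escape(kernels[i].name) +
               "\",\"arch\":\"" + json_escape(kernels[i].arch) + "\"}";
    }
    return out + "]";
}
```

The resulting string can be written to a file with std::ofstream in the same way as the snippet above.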
06_multi_registry.cpp - Multiple Registries
Demonstrates using multiple registries with named kernel sets:
// Define separate kernel sets
DECL_KERNEL_SET(compute_optimized, ...);
DECL_KERNEL_SET(latency_optimized, ...);
// Register to specific registries
Registry compute_registry, latency_registry;
REGISTER_KERNEL_SET(compute_optimized, compute_registry);
REGISTER_KERNEL_SET(latency_optimized, latency_registry);
// Use appropriate registry based on workload
Dispatcher compute_dispatcher(compute_registry);
Dispatcher latency_dispatcher(latency_registry);
Features:
- Named kernel set registration with REGISTER_KERNEL_SET macro
- Separate registries for different optimization goals
- Dynamic kernel set selection by name
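Selecting a registry by name at runtime can be sketched with a plain lookup table. The `Registry` struct and `RegistryTable` class below are hypothetical stand-ins, not the dispatcher's types; the sketch only shows the name-to-registry lookup idea.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in: a registry is reduced to a list of kernel names.
struct Registry {
    std::vector<std::string> kernels;
};

// Maps kernel-set names to registries so the workload can pick one at runtime.
class RegistryTable {
public:
    void add(const std::string& name, Registry r) { table_[name] = std::move(r); }

    // Returns nullptr when no registry was registered under `name`.
    const Registry* find(const std::string& name) const {
        auto it = table_.find(name);
        return it == table_.end() ? nullptr : &it->second;
    }

private:
    std::map<std::string, Registry> table_;
};
```

Returning a null pointer for unknown names makes the missing-registry case explicit at the call site instead of silently creating an empty registry.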
Benchmark Parameters (stream_config)
CK Tile uses stream_config for benchmark control:
ck_tile::stream_config cfg{
nullptr, // stream_id - HIP stream (nullptr = default)
true, // time_kernel - Enable timing
1, // log_level - Verbosity (0=quiet, 1=normal)
5, // cold_niters - Warmup iterations
20, // nrepeat - Benchmark iterations
true, // is_gpu_timer - Use GPU events vs CPU chrono
false, // flush_cache - Flush L2 cache between iterations
1 // rotating_count - Rotating buffers for cache simulation
};
| Parameter | CLI Option | Default | Description |
|---|---|---|---|
| cold_niters_ | --warmup | 5 | Warmup iterations |
| nrepeat_ | --iterations | 100 | Benchmark iterations |
| flush_cache_ | - | false | Flush L2 cache |
| rotating_count_ | - | 1 | Rotating buffers |
| is_gpu_timer_ | - | true | GPU timer vs CPU |
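A timed run is usually reported as throughput. The sketch below converts a mean kernel time into TFLOPS using the standard GEMM operation count of 2 * M * N * K (one multiply and one add per inner-product term). The helper name `gemm_tflops` is hypothetical, and the sketch assumes the measured time is in milliseconds.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// Converts a mean GEMM kernel time (assumed to be in milliseconds)
// into TFLOPS, counting 2 * M * N * K floating-point operations.
inline double gemm_tflops(std::int64_t m, std::int64_t n, std::int64_t k,
                          double mean_ms) {
    assert(mean_ms > 0.0);
    const double flops = 2.0 * static_cast<double>(m) *
                         static_cast<double>(n) * static_cast<double>(k);
    const double seconds = mean_ms * 1e-3;
    return flops / seconds / 1e12;
}
```

For a 4096 x 4096 x 4096 problem that averages 1 ms, this gives roughly 137.4 TFLOPS.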
Declarative Kernel Pattern
All examples use the declarative DECL_KERNEL_SET macro:
DECL_KERNEL_SET(my_kernels,
.add(
Signature() // WHAT: operation signature
.dtype("fp16") // Data type
.layout("rcr"), // Matrix layouts (A=row, B=col, C=row)
Algorithm() // HOW: implementation details
.tile(256, 256, 32) // Tile sizes (M, N, K)
.wave(2, 2, 1) // Wave configuration
.warp(32, 32, 16) // Warp tile sizes
.pipeline("compv4") // Pipeline type
.scheduler("intrawave"), // Scheduler type
"gfx942" // WHERE: target architecture
)
);
Key Macros:
- DECL_KERNEL_SET(name, ...) - Declare a kernel set
- REGISTER_GENERATED_KERNELS - Register all kernels from this example
- REGISTER_KERNEL_SET(name, registry) - Register a specific kernel set to a registry