mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 09:16:52 +00:00
[CK][CK TILE] Autotuning heuristics infra for universal GEMM kernel selection (#5676)

## Motivation

This PR adds ML-based kernel selection heuristics to the CK Tile dispatcher, enabling fast and accurate automatic kernel selection for Universal GEMM kernels. Instead of requiring an exhaustive search through 4600+ kernel configurations (taking ~46 seconds per problem shape), the ML heuristic predicts optimal kernels in microseconds while achieving >98% of oracle-best performance.

## Technical Details

**ML infrastructure**

https://github.com/ROCm/rocm-libraries/tree/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics

* Feature Engine ([feature_engine.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/feature_engine.py)): 55-feature extraction including problem dimensions, kernel configuration, tile efficiency, and hardware profile
* Training Pipeline ([train.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/train.py)): LightGBM regression with log-transform, GroupKFold cross-validation, warm-start support
* Predictor ([predict.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/predict.py)): Kernel ranking and TFLOPS prediction for problem shapes
* Evaluation ([evaluate.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/evaluate.py)): Comprehensive metrics including efficiency, NDCG@k, shape family analysis

**Data Generation Tools:**

* [generate_benchmark_data.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/generate_benchmark_data.py): Build and benchmark kernels across diverse problem shapes
* [convert_json_to_parquet.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/convert_json_to_parquet.py): Convert benchmark JSON to training-ready parquet format
* [data_pipeline.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/data_pipeline.py): Parse streaming benchmark logs into canonical datasets

**Examples**

* [09_ml_heuristic.cpp](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/examples/gemm/cpp/09_ml_heuristic.cpp): C++ example demonstrating ML-based kernel selection
* [09_ml_heuristic.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/examples/gemm/python/09_ml_heuristic.py): Python example with validation

**Pre-trained Models (projects/composablekernel/dispatcher/heuristics/models/):**

* gemm_universal_fp8_gfx950/: fp8 RCR model (42K trees, 97.51% mean efficiency)
* gemm_universal_fp16_gfx950/: fp16 RCR model (20K trees, 99.36% mean efficiency)

## Test Plan

* Evaluated on 25 diverse shapes for fp16, 168 shapes for fp8
* All shape families tested: tiny M (M<8), small M, medium M, large M (M≥1024)
* All pipeline types: compv3, compv4, mem

## Test Result

**fp16 Model (gfx950, RCR layout)**

* Mean Efficiency: 99.36%
* P10 Efficiency: 98.05% (90% of shapes achieve ≥98% of oracle best)
* Min Efficiency: 95.45%

**fp8 Model (gfx950, RCR layout)**

* Mean Efficiency: 98.28% (original), 97.51% (wide coverage)
* P10 Efficiency: 94.64% (original), 93.89% (wide coverage)
* Min Efficiency: 84.5%

## Submission Checklist

- [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
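The efficiency figures quoted above compare the kernel a predictor selects against the oracle-best kernel for the same shape. The following is a minimal sketch of that arithmetic, not the code in evaluate.py: the function names are hypothetical, and the nearest-rank percentile is an assumption about how P10 might be computed.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <iterator>
#include <vector>

// Efficiency for one problem shape: measured TFLOPS of the kernel with the
// highest *predicted* TFLOPS, divided by the best measured TFLOPS (the oracle).
// Both vectors are indexed by kernel. Hypothetical helper, for illustration.
double selection_efficiency(const std::vector<double>& predicted_tflops,
                            const std::vector<double>& measured_tflops)
{
    assert(!measured_tflops.empty());
    assert(predicted_tflops.size() == measured_tflops.size());
    const auto pick_it = std::max_element(predicted_tflops.begin(), predicted_tflops.end());
    const auto pick =
        static_cast<std::size_t>(std::distance(predicted_tflops.begin(), pick_it));
    const double oracle = *std::max_element(measured_tflops.begin(), measured_tflops.end());
    assert(oracle > 0.0);
    return measured_tflops[pick] / oracle;
}

// Nearest-rank percentile (assumed definition): the smallest sample such that
// at least p percent of the samples are less than or equal to it.
double percentile(std::vector<double> samples, double p)
{
    assert(!samples.empty() && p > 0.0 && p <= 100.0);
    std::sort(samples.begin(), samples.end());
    const auto rank = static_cast<std::size_t>(
        std::ceil(p * static_cast<double>(samples.size()) / 100.0));
    return samples[rank - 1];
}
```

Collecting selection_efficiency over all evaluated shapes and calling percentile(efficiencies, 10.0) would give a P10 figure in this sense: roughly 90% of shapes reach at least that efficiency.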
GEMM C++ Examples
CK Tile Dispatcher C++ examples for GEMM (General Matrix Multiplication) operations.
Main Documentation: Dispatcher README | Examples Overview
Quick Start
Build and Run
cd /path/to/composable_kernel/dispatcher
mkdir -p build && cd build
cmake .. \
-DCMAKE_PREFIX_PATH=/opt/rocm \
-DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
-DBUILD_DISPATCHER_EXAMPLES=ON
# Build (kernels generated automatically by CMake)
make -j$(nproc)
# Run examples
cd examples
./gemm_01_basic
./gemm_03_benchmark_validation
./gemm_04_heuristics
Examples
| Example | Description | Complexity |
|---|---|---|
| 01_basic_gemm.cpp | Basic GEMM with declarative API, autofill, autocorrect | ★☆☆☆☆ |
| 02_multi_size.cpp | Wildcard expansion for multiple configurations | ★★☆☆☆ |
| 03_benchmark_validation.cpp | Performance benchmarking with CPU reference validation | ★★☆☆☆ |
| 04_heuristics.cpp | Heuristic-based kernel selection | ★★★☆☆ |
| 05_json_export.cpp | Registry JSON export for external tools | ★★☆☆☆ |
| 06_multi_registry.cpp | Multiple registries with named kernel sets | ★★★☆☆ |
Example Details
01_basic_gemm.cpp - Basic GEMM
Demonstrates the declarative kernel API with three patterns:
- Autofill Pattern - Minimal specification, defaults filled automatically
- Autocorrect Pattern - Invalid parameters corrected at build time
- Full Specification Pattern - Complete kernel configuration
DECL_KERNEL_SET(basic_kernels,
    // Pattern 1: Autofill - minimal specification
    .add(
        Signature().dtype("fp16").layout("rcr"),
        Algorithm(),  // Defaults filled by autofill
        "gfx942"
    )
    // Pattern 2: Full specification
    .add(
        Signature().dtype("fp16").layout("rcr"),
        Algorithm().tile(256, 256, 32).wave(2, 2, 1).warp(32, 32, 16)
                   .pipeline("compv4").scheduler("intrawave"),
        "gfx942"
    )
);
Features:
- Uses generic REGISTER_GENERATED_KERNELS macro
- print_registered_kernels() utility for debugging
- Demonstrates autofill messages during build
02_multi_size.cpp - Wildcard Expansion
Demonstrates automatic generation of multiple kernel configurations:
DECL_KERNEL_SET(multi_kernels,
    .add(
        Signature().dtype("fp16").layout("rcr"),
        Algorithm().tile(*, *, 32)  // Wildcard tile M and N
                   .wave(2, 2, 1)
                   .warp(32, 32, 16)
                   .pipeline("compv4")
                   .scheduler("intrawave"),
        "gfx942"
    )
);
Wildcard Values:
- *, -1, or ANY_INT expand to all valid configurations
- Architecture filter prunes invalid combinations automatically
- Example generates 5 valid kernels after arch filtering (from 7 expansions)
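Wildcard expansion followed by architecture filtering can be pictured as a Cartesian product followed by a predicate. The sketch below is illustrative only: TileConfig, kAnyInt, the candidate list, and the byte-budget rule are all hypothetical stand-ins, and the dispatcher's real expansion values and validity rules are not reproduced here.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical wildcard marker, mirroring the -1 / ANY_INT convention above.
constexpr int kAnyInt = -1;

// Hypothetical tile description (M, N, K); not the dispatcher's own type.
struct TileConfig {
    int m;
    int n;
    int k;
};

// A wildcard field expands to every candidate; a fixed field stays as it is.
static std::vector<int> values_for(int field, const std::vector<int>& candidates)
{
    if (field == kAnyInt) {
        return candidates;
    }
    return {field};
}

// Expand every wildcard field of `pattern` into the Cartesian product.
std::vector<TileConfig> expand_wildcards(const TileConfig& pattern,
                                         const std::vector<int>& candidates)
{
    assert(!candidates.empty());
    std::vector<TileConfig> out;
    for (int m : values_for(pattern.m, candidates)) {
        for (int n : values_for(pattern.n, candidates)) {
            for (int k : values_for(pattern.k, candidates)) {
                out.push_back({m, n, k});
            }
        }
    }
    return out;
}

// Illustrative architecture filter (an assumption, not the real rule):
// keep a tile only if its A (m x k) and B (n x k) slabs fit a byte budget.
std::vector<TileConfig> filter_for_arch(const std::vector<TileConfig>& configs,
                                        std::size_t budget_bytes,
                                        std::size_t elem_bytes)
{
    std::vector<TileConfig> kept;
    for (const TileConfig& t : configs) {
        const std::size_t bytes = static_cast<std::size_t>(t.m + t.n) *
                                  static_cast<std::size_t>(t.k) * elem_bytes;
        if (bytes <= budget_bytes) {
            kept.push_back(t);
        }
    }
    return kept;
}
```

The counts this sketch produces depend entirely on the made-up candidate list and budget, so they will not match the 7 expansions and 5 surviving kernels reported for the actual example.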
03_benchmark_validation.cpp - Benchmark + Validation
Consolidated example combining performance benchmarking with correctness validation:
# Benchmark only
./gemm_03_benchmark_validation --warmup 10 --iterations 100
# With CPU validation
./gemm_03_benchmark_validation --verify 1 --rtol 1e-3 --atol 1e-3
# With GPU reference validation (faster for large matrices)
./gemm_03_benchmark_validation --verify 2
Features:
- Warmup iterations (discarded from timing)
- Benchmark iterations with statistics (min/max/mean/median)
- CPU reference validation using ck_tile::reference_gemm
- GPU reference validation using ck_tile::reference_gemm_gpu
- Configurable tolerances
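To make the timing statistics and the --rtol / --atol options concrete, here is a small sketch of how such values are commonly computed. It assumes the widely used mixed criterion |out - ref| <= atol + rtol * |ref|; whether the example applies exactly this formula is not stated here, and summarize and all_close are hypothetical helper names.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <numeric>
#include <vector>

struct TimingStats {
    double min;
    double max;
    double mean;
    double median;
};

// Summarize per-iteration timings (warmup iterations already discarded).
// Takes the vector by value so it can be sorted without touching the caller's copy.
TimingStats summarize(std::vector<double> times_ms)
{
    assert(!times_ms.empty());
    std::sort(times_ms.begin(), times_ms.end());
    const std::size_t n = times_ms.size();
    const double mean =
        std::accumulate(times_ms.begin(), times_ms.end(), 0.0) / static_cast<double>(n);
    const double median =
        (n % 2 == 1) ? times_ms[n / 2] : 0.5 * (times_ms[n / 2 - 1] + times_ms[n / 2]);
    return {times_ms.front(), times_ms.back(), mean, median};
}

// Element-wise comparison against a reference result.
// Assumed criterion: |out - ref| <= atol + rtol * |ref|; NaN never passes.
bool all_close(const std::vector<float>& out, const std::vector<float>& ref,
               double rtol, double atol)
{
    assert(out.size() == ref.size());
    for (std::size_t i = 0; i < out.size(); ++i) {
        const double o = static_cast<double>(out[i]);
        const double r = static_cast<double>(ref[i]);
        if (!(std::fabs(o - r) <= atol + rtol * std::fabs(r))) {
            return false;
        }
    }
    return true;
}
```

The relative term keeps the check meaningful for large output values, while the absolute term prevents spurious failures when the reference is close to zero.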
04_heuristics.cpp - Heuristic Selection
Demonstrates custom kernel selection based on problem characteristics:
// Problem size analysis
auto heuristic = [](const Problem& p) -> std::optional<KernelKey> {
    if (p.M() * p.N() < 256 * 256) {
        return small_kernel_key;  // Memory-bound heuristic
    } else {
        return large_kernel_key;  // Compute-bound heuristic
    }
};
dispatcher.set_heuristic(heuristic);
Features:
- Problem size analysis (small vs large matrices)
- Compute-bound vs memory-bound selection
- Custom heuristic function registration
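The snippet above depends on dispatcher types, so the following self-contained sketch mirrors the same idea with hypothetical stand-ins. Problem, KernelKey, Heuristic, and make_size_heuristic are defined here purely for illustration and are not the dispatcher's own definitions; the 256 * 256 threshold is taken from the snippet.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <optional>
#include <string>

// Hypothetical stand-ins for the dispatcher's Problem and KernelKey types.
struct Problem {
    std::int64_t m;
    std::int64_t n;
    std::int64_t k;
    std::int64_t M() const { return m; }
    std::int64_t N() const { return n; }
    std::int64_t K() const { return k; }
};

struct KernelKey {
    std::string name;
};

using Heuristic = std::function<std::optional<KernelKey>(const Problem&)>;

// Build a size-based heuristic: outputs smaller than `threshold` elements get
// the small-tile kernel, everything else gets the large-tile kernel.
// Returning std::nullopt means the heuristic declines to choose.
Heuristic make_size_heuristic(KernelKey small_key, KernelKey large_key,
                              std::int64_t threshold = 256 * 256)
{
    return [=](const Problem& p) -> std::optional<KernelKey> {
        if (p.M() <= 0 || p.N() <= 0 || p.K() <= 0) {
            return std::nullopt;  // degenerate problem: no opinion
        }
        if (p.M() * p.N() < threshold) {
            return small_key;
        }
        return large_key;
    };
}
```

Returning std::optional lets a heuristic abstain, so a caller can fall back to some default selection path; how the real dispatcher treats an empty result is not covered here.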
05_json_export.cpp - JSON Export
Exports registry information to JSON for external tool integration:
auto json = registry.to_json();
std::ofstream file("kernels.json");
file << json;
Use Cases:
- Kernel metadata serialization
- External analysis tools
- Configuration management
06_multi_registry.cpp - Multiple Registries
Demonstrates using multiple registries with named kernel sets:
// Define separate kernel sets
DECL_KERNEL_SET(compute_optimized, ...);
DECL_KERNEL_SET(latency_optimized, ...);
// Register to specific registries
Registry compute_registry, latency_registry;
REGISTER_KERNEL_SET(compute_optimized, compute_registry);
REGISTER_KERNEL_SET(latency_optimized, latency_registry);
// Use appropriate registry based on workload
Dispatcher compute_dispatcher(compute_registry);
Dispatcher latency_dispatcher(latency_registry);
Features:
- Named kernel set registration with REGISTER_KERNEL_SET macro
- Separate registries for different optimization goals
- Dynamic kernel set selection by name
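Selecting a kernel set by name at run time can be sketched as a lookup table keyed by the set name. The types below are hypothetical stand-ins: this Registry only stores kernel names, and RegistryTable is not part of the dispatcher API.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for a registry; it only records kernel names.
struct Registry {
    std::vector<std::string> kernel_names;
};

// Hypothetical table mapping a kernel-set name to its registry.
class RegistryTable {
public:
    void add(const std::string& set_name, Registry registry)
    {
        assert(!set_name.empty());
        table_[set_name] = std::move(registry);
    }

    // Returns nullptr when no set with that name was registered.
    const Registry* find(const std::string& set_name) const
    {
        const auto it = table_.find(set_name);
        return it == table_.end() ? nullptr : &it->second;
    }

private:
    std::map<std::string, Registry> table_;
};
```

A caller could then look up a set such as "latency_optimized" from a command-line option and build a dispatcher from the registry that find returns, after checking for nullptr.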
Benchmark Parameters (stream_config)
CK Tile uses stream_config for benchmark control:
ck_tile::stream_config cfg{
    nullptr,  // stream_id - HIP stream (nullptr = default)
    true,     // time_kernel - Enable timing
    1,        // log_level - Verbosity (0=quiet, 1=normal)
    5,        // cold_niters - Warmup iterations
    20,       // nrepeat - Benchmark iterations
    true,     // is_gpu_timer - Use GPU events vs CPU chrono
    false,    // flush_cache - Flush L2 cache between iterations
    1         // rotating_count - Rotating buffers for cache simulation
};
| Parameter | CLI Option | Default | Description |
|---|---|---|---|
| cold_niters_ | --warmup | 5 | Warmup iterations |
| nrepeat_ | --iterations | 100 | Benchmark iterations |
| flush_cache_ | - | false | Flush L2 cache |
| rotating_count_ | - | 1 | Rotating buffers |
| is_gpu_timer_ | - | true | GPU timer vs CPU |
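The roles of the warmup and repeat counts can be illustrated with a host-side timing loop. This is a simplified sketch using std::chrono, which corresponds to the CPU-timer path only; GPU event timing, cache flushing, and rotating buffers are not modeled. BenchConfig, time_mean_ms, and gemm_tflops are hypothetical names, and the only outside fact relied on is that a GEMM performs 2*M*N*K floating-point operations.

```cpp
#include <cassert>
#include <chrono>
#include <functional>

// Hypothetical subset of the benchmark settings discussed above.
struct BenchConfig {
    int cold_niters = 5;  // warmup launches, excluded from timing
    int nrepeat = 20;     // timed launches
};

// Mean wall-clock milliseconds per timed launch.
double time_mean_ms(const std::function<void()>& launch, const BenchConfig& cfg)
{
    assert(cfg.cold_niters >= 0 && cfg.nrepeat > 0);
    for (int i = 0; i < cfg.cold_niters; ++i) {
        launch();  // warm caches and clocks; result discarded
    }
    const auto start = std::chrono::steady_clock::now();
    for (int i = 0; i < cfg.nrepeat; ++i) {
        launch();
    }
    const auto stop = std::chrono::steady_clock::now();
    const std::chrono::duration<double, std::milli> elapsed = stop - start;
    return elapsed.count() / cfg.nrepeat;
}

// Convert a mean time into TFLOPS for an M x N x K GEMM (2*M*N*K FLOPs).
double gemm_tflops(long long m, long long n, long long k, double mean_ms)
{
    assert(mean_ms > 0.0);
    const double flops = 2.0 * static_cast<double>(m) * static_cast<double>(n) *
                         static_cast<double>(k);
    return flops / (mean_ms * 1e-3) / 1e12;
}
```

For asynchronous GPU launches a host timer like this would also need a device synchronization before each clock read, which is one reason event-based GPU timing is offered as an option.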
Declarative Kernel Pattern
All examples use the declarative DECL_KERNEL_SET macro:
DECL_KERNEL_SET(my_kernels,
    .add(
        Signature()                    // WHAT: operation signature
            .dtype("fp16")             // Data type
            .layout("rcr"),            // Matrix layouts (A=row, B=col, C=row)
        Algorithm()                    // HOW: implementation details
            .tile(256, 256, 32)        // Tile sizes (M, N, K)
            .wave(2, 2, 1)             // Wave configuration
            .warp(32, 32, 16)          // Warp tile sizes
            .pipeline("compv4")        // Pipeline type
            .scheduler("intrawave"),   // Scheduler type
        "gfx942"                       // WHERE: target architecture
    )
);
Key Macros:
- DECL_KERNEL_SET(name, ...) - Declare a kernel set
- REGISTER_GENERATED_KERNELS - Register all kernels from this example
- REGISTER_KERNEL_SET(name, registry) - Register specific kernel set to a registry