Yaswanth Raparti c1127a36f5 [rocm-libraries] ROCm/rocm-libraries#5676 (commit 1d18339)
[CK][CK TILE] Autotuning heuristics infra for universal GEMM kernel selection (#5676)

## Motivation

This PR adds ML-based kernel selection heuristics to the CK Tile
dispatcher, enabling fast and accurate automatic kernel selection for
Universal Gemm kernels. Instead of requiring exhaustive search through
4600+ kernel configurations (taking ~46 seconds per problem shape), the
ML heuristic predicts optimal kernels in microseconds while achieving
>98% of oracle-best performance.

## Technical Details

**ML infrastructure**

https://github.com/ROCm/rocm-libraries/tree/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics
* Feature Engine
([feature_engine.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/feature_engine.py)):
55-feature extraction including problem dimensions, kernel
configuration, tile efficiency, and hardware profile
* Training Pipeline
([train.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/train.py)):
LightGBM regression with log-transform, GroupKFold cross-validation,
warm-start support
* Predictor
([predict.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/predict.py)):
Kernel ranking and TFLOPS prediction for problem shapes
* Evaluation
([evaluate.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/evaluate.py)):
Comprehensive metrics including efficiency, NDCG@k, shape family
analysis
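
As a rough illustration of the ranking step, the sketch below turns per-kernel model outputs into a top-k list. It assumes the regressor was trained on the natural log of TFLOPS, which is one reading of "log-transform" and is not confirmed by this README. `Candidate`, `RankedKernel`, and `rank_by_predicted_tflops` are hypothetical names, and no LightGBM call is made; the predictions are taken as given.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical record: one candidate kernel plus the raw model output for it.
struct Candidate {
    std::string kernel_name;      // illustrative identifier, not a real CK kernel name
    double predicted_log_tflops;  // assumption: model was trained on ln(TFLOPS)
};

struct RankedKernel {
    std::string kernel_name;
    double predicted_tflops;
};

// Undo the assumed log transform and return at most top_k kernels,
// ordered from highest to lowest predicted TFLOPS.
std::vector<RankedKernel> rank_by_predicted_tflops(const std::vector<Candidate>& candidates,
                                                   std::size_t top_k) {
    std::vector<RankedKernel> ranked;
    ranked.reserve(candidates.size());
    for (const auto& c : candidates) {
        ranked.push_back({c.kernel_name, std::exp(c.predicted_log_tflops)});
    }
    // stable_sort keeps the input order for ties, so results are deterministic.
    std::stable_sort(ranked.begin(), ranked.end(),
                     [](const RankedKernel& a, const RankedKernel& b) {
                         return a.predicted_tflops > b.predicted_tflops;
                     });
    if (ranked.size() > top_k) {
        ranked.resize(top_k);
    }
    return ranked;
}
```

Returning a short list rather than a single winner leaves room for a caller to benchmark the top few candidates when a prediction is close.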

**Data Generation Tools:**

* [generate_benchmark_data.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/generate_benchmark_data.py):
Build and benchmark kernels across diverse problem shapes
* [convert_json_to_parquet.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/convert_json_to_parquet.py):
Convert benchmark JSON to training-ready parquet format
* [data_pipeline.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/data_pipeline.py):
Parse streaming benchmark logs into canonical datasets

**Examples**
* [09_ml_heuristic.cpp](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/examples/gemm/cpp/09_ml_heuristic.cpp):
C++ example demonstrating ML-based kernel selection
* [09_ml_heuristic.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/examples/gemm/python/09_ml_heuristic.py):
Python example with validation

**Pre-trained Models
(projects/composablekernel/dispatcher/heuristics/models/):**
* gemm_universal_fp8_gfx950/: fp8 RCR model (42K trees, 97.51% mean
efficiency)
* gemm_universal_fp16_gfx950/: fp16 RCR model (20K trees, 99.36% mean
efficiency)

## Test Plan

* Evaluated on 25 diverse shapes for fp16, 168 shapes for fp8
* All shape families tested: tiny M (M<8), small M, medium M, large M
(M≥1024)
* All pipeline types: compv3, compv4, mem

## Test Result

**fp16 Model (gfx950, RCR layout)**
* Mean Efficiency: 99.36%
* P10 Efficiency: 98.05% (90% of shapes achieve at least 98.05% of
oracle-best performance)
* Min Efficiency: 95.45%

**fp8 Model (gfx950, RCR layout)**
* Mean Efficiency: 98.28% (original), 97.51% (wide coverage)
* P10 Efficiency: 94.64% (original), 93.89% (wide coverage)
* Min Efficiency: 84.5%
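
For readers who want to reproduce these summary numbers from their own measurements, here is a minimal sketch of how mean, P10, and min efficiency could be computed. It assumes efficiency means selected-kernel TFLOPS divided by oracle-best TFLOPS per shape, and that P10 is the 10th percentile with linear interpolation; the percentile convention used by `evaluate.py` is not stated here. `ShapeResult`, `EfficiencySummary`, and `summarize_efficiency` are hypothetical names.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <numeric>
#include <stdexcept>
#include <vector>

// Hypothetical per-shape measurement.
struct ShapeResult {
    double selected_tflops;  // TFLOPS of the kernel the heuristic picked
    double oracle_tflops;    // best TFLOPS over all benchmarked kernels
};

struct EfficiencySummary {
    double mean;
    double p10;  // 10th percentile: 90% of shapes are at or above this value
    double min;
};

EfficiencySummary summarize_efficiency(const std::vector<ShapeResult>& results) {
    if (results.empty()) {
        throw std::invalid_argument("summarize_efficiency: no results");
    }
    std::vector<double> eff;
    eff.reserve(results.size());
    for (const auto& r : results) {
        eff.push_back(r.selected_tflops / r.oracle_tflops);
    }
    std::sort(eff.begin(), eff.end());

    // Assumed convention: linear interpolation between the two nearest ranks.
    const double pos = 0.10 * static_cast<double>(eff.size() - 1);
    const std::size_t lo = static_cast<std::size_t>(std::floor(pos));
    const std::size_t hi = std::min(lo + 1, eff.size() - 1);
    const double p10 = eff[lo] + (pos - static_cast<double>(lo)) * (eff[hi] - eff[lo]);

    const double mean =
        std::accumulate(eff.begin(), eff.end(), 0.0) / static_cast<double>(eff.size());
    return {mean, p10, eff.front()};
}
```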

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
2026-04-02 02:26:32 +00:00

GEMM C++ Examples

CK Tile Dispatcher C++ examples for GEMM (General Matrix Multiplication) operations.

Main Documentation: Dispatcher README | Examples Overview

Quick Start

Build and Run

cd /path/to/composable_kernel/dispatcher
mkdir -p build && cd build

cmake .. \
  -DCMAKE_PREFIX_PATH=/opt/rocm \
  -DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
  -DBUILD_DISPATCHER_EXAMPLES=ON

# Build (kernels generated automatically by CMake)
make -j$(nproc)

# Run examples
cd examples
./gemm_01_basic
./gemm_03_benchmark_validation
./gemm_04_heuristics

Examples

| Example | Description | Complexity |
|---|---|---|
| 01_basic_gemm.cpp | Basic GEMM with declarative API, autofill, autocorrect | ★☆☆☆☆ |
| 02_multi_size.cpp | Wildcard expansion for multiple configurations | ★★☆☆☆ |
| 03_benchmark_validation.cpp | Performance benchmarking with CPU reference validation | ★★☆☆☆ |
| 04_heuristics.cpp | Heuristic-based kernel selection | ★★★☆☆ |
| 05_json_export.cpp | Registry JSON export for external tools | ★★☆☆☆ |
| 06_multi_registry.cpp | Multiple registries with named kernel sets | ★★★☆☆ |

Example Details

01_basic_gemm.cpp - Basic GEMM

Demonstrates the declarative kernel API with three patterns:

  1. Autofill Pattern - Minimal specification, defaults filled automatically
  2. Autocorrect Pattern - Invalid parameters corrected at build time
  3. Full Specification Pattern - Complete kernel configuration
DECL_KERNEL_SET(basic_kernels,
    // Pattern 1: Autofill - minimal specification
    .add(
        Signature().dtype("fp16").layout("rcr"),
        Algorithm(),  // Defaults filled by autofill
        "gfx942"
    )
    // Pattern 3: Full specification
    .add(
        Signature().dtype("fp16").layout("rcr"),
        Algorithm().tile(256, 256, 32).wave(2, 2, 1).warp(32, 32, 16)
                   .pipeline("compv4").scheduler("intrawave"),
        "gfx942"
    )
);

Features:

  • Uses generic REGISTER_GENERATED_KERNELS macro
  • print_registered_kernels() utility for debugging
  • Demonstrates autofill messages during build

02_multi_size.cpp - Wildcard Expansion

Demonstrates automatic generation of multiple kernel configurations:

DECL_KERNEL_SET(multi_kernels,
    .add(
        Signature().dtype("fp16").layout("rcr"),
        Algorithm().tile(*, *, 32)     // Wildcard tile M and N
                   .wave(2, 2, 1)
                   .warp(32, 32, 16)
                   .pipeline("compv4")
                   .scheduler("intrawave"),
        "gfx942"
    )
);

Wildcard Values:

  • *, -1, or ANY_INT expand to all valid configurations
  • Architecture filter prunes invalid combinations automatically
  • Example generates 5 valid kernels after arch filtering (from 7 expansions)
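
The expand-then-filter idea can be sketched independently of the dispatcher. The code below is an illustration only: `Tile`, `expand_tiles`, the candidate list, and the `arch_ok` predicate are hypothetical stand-ins, and the real generator's candidate values and architecture rules are not shown in this README. `ANY_INT` is redefined locally as -1 to mirror the convention above.

```cpp
#include <functional>
#include <vector>

// Local stand-in for the wildcard marker; mirrors the -1 / ANY_INT convention.
constexpr int ANY_INT = -1;

// Hypothetical tile description (M, N, K).
struct Tile {
    int m;
    int n;
    int k;
};

// A fixed value yields one choice; a wildcard yields every candidate.
static std::vector<int> axis_values(int value, const std::vector<int>& candidates) {
    if (value == ANY_INT) {
        return candidates;
    }
    return {value};
}

// Expand every wildcard axis over the candidate sizes, then keep only the
// combinations accepted by the architecture filter.
std::vector<Tile> expand_tiles(const Tile& pattern,
                               const std::vector<int>& candidates,
                               const std::function<bool(const Tile&)>& arch_ok) {
    std::vector<Tile> out;
    for (int m : axis_values(pattern.m, candidates)) {
        for (int n : axis_values(pattern.n, candidates)) {
            for (int k : axis_values(pattern.k, candidates)) {
                const Tile t{m, n, k};
                if (arch_ok(t)) {
                    out.push_back(t);
                }
            }
        }
    }
    return out;
}
```

Calling `expand_tiles({ANY_INT, ANY_INT, 32}, {64, 128, 256}, filter)` produces nine raw combinations before filtering; the counts quoted above for the actual example come from the dispatcher's own rules, not from this sketch.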

03_benchmark_validation.cpp - Benchmark + Validation

Consolidated example combining performance benchmarking with correctness validation:

# Benchmark only
./gemm_03_benchmark_validation --warmup 10 --iterations 100

# With CPU validation
./gemm_03_benchmark_validation --verify 1 --rtol 1e-3 --atol 1e-3

# With GPU reference validation (faster for large matrices)
./gemm_03_benchmark_validation --verify 2

Features:

  • Warmup iterations (discarded from timing)
  • Benchmark iterations with statistics (min/max/mean/median)
  • CPU reference validation using ck_tile::reference_gemm
  • GPU reference validation using ck_tile::reference_gemm_gpu
  • Configurable tolerances
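
To make the role of `--rtol` and `--atol` concrete, here is a minimal sketch of an element-wise tolerance check. It assumes the common combined form |out - ref| <= atol + rtol * |ref|; the exact rule applied by the example is not stated in this README and may differ. `all_close` is a hypothetical helper, not a CK Tile function.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Returns true when every element of `out` is within the combined
// absolute/relative tolerance of the matching element of `ref`.
bool all_close(const std::vector<float>& out,
               const std::vector<float>& ref,
               double rtol,
               double atol) {
    if (out.size() != ref.size()) {
        return false;
    }
    for (std::size_t i = 0; i < out.size(); ++i) {
        const double diff = std::fabs(static_cast<double>(out[i]) - static_cast<double>(ref[i]));
        const double bound = atol + rtol * std::fabs(static_cast<double>(ref[i]));
        // Written as !(diff <= bound) so that a NaN difference counts as a failure.
        if (!(diff <= bound)) {
            return false;
        }
    }
    return true;
}
```

The relative term scales the allowed error with the magnitude of the reference, which matters for fp16 and fp8 results, while the absolute term keeps values near zero from failing on tiny differences.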

04_heuristics.cpp - Heuristic Selection

Demonstrates custom kernel selection based on problem characteristics:

// Problem size analysis
auto heuristic = [](const Problem& p) -> std::optional<KernelKey> {
    if (p.M() * p.N() < 256 * 256) {
        return small_kernel_key;   // Memory-bound heuristic
    } else {
        return large_kernel_key;   // Compute-bound heuristic
    }
};

dispatcher.set_heuristic(heuristic);

Features:

  • Problem size analysis (small vs large matrices)
  • Compute-bound vs memory-bound selection
  • Custom heuristic function registration
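
The snippet above relies on dispatcher types that are not defined in this README. A self-contained sketch of the same size-threshold idea follows, using stand-ins: `Problem` here is a plain struct with fields rather than the `M()`/`N()` accessors, `KernelKey` is an alias for `std::string`, and `make_size_heuristic` is a hypothetical factory. It requires C++17 for `std::optional`.

```cpp
#include <cstdint>
#include <functional>
#include <optional>
#include <string>

// Stand-ins for the dispatcher's types (assumptions, not the real definitions).
struct Problem {
    std::int64_t M;
    std::int64_t N;
    std::int64_t K;
};
using KernelKey = std::string;
using Heuristic = std::function<std::optional<KernelKey>(const Problem&)>;

// Builds a heuristic that picks small_key when the output has fewer than
// threshold_elems elements, and large_key otherwise.
Heuristic make_size_heuristic(KernelKey small_key,
                              KernelKey large_key,
                              std::int64_t threshold_elems = 256 * 256) {
    return [=](const Problem& p) -> std::optional<KernelKey> {
        if (p.M <= 0 || p.N <= 0 || p.K <= 0) {
            return std::nullopt;  // no opinion on degenerate shapes
        }
        // 64-bit arithmetic avoids overflow in M * N for large problems.
        return (p.M * p.N < threshold_elems) ? small_key : large_key;
    };
}
```

Capturing the keys by value keeps the returned callable valid after the factory returns, and returning `std::nullopt` gives the caller a way to fall back to another selection strategy.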

05_json_export.cpp - JSON Export

Exports registry information to JSON for external tool integration:

auto json = registry.to_json();
std::ofstream file("kernels.json");
file << json;

Use Cases:

  • Kernel metadata serialization
  • External analysis tools
  • Configuration management

06_multi_registry.cpp - Multiple Registries

Demonstrates using multiple registries with named kernel sets:

// Define separate kernel sets
DECL_KERNEL_SET(compute_optimized, ...);
DECL_KERNEL_SET(latency_optimized, ...);

// Register to specific registries
Registry compute_registry, latency_registry;
REGISTER_KERNEL_SET(compute_optimized, compute_registry);
REGISTER_KERNEL_SET(latency_optimized, latency_registry);

// Use appropriate registry based on workload
Dispatcher compute_dispatcher(compute_registry);
Dispatcher latency_dispatcher(latency_registry);

Features:

  • Named kernel set registration with REGISTER_KERNEL_SET macro
  • Separate registries for different optimization goals
  • Dynamic kernel set selection by name
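
Selecting a kernel set by name at run time can be sketched with a plain lookup table. This is only an illustration of the idea: `Registry` below is a stand-in struct holding kernel names, and `RegistryTable` is a hypothetical helper that is not part of the dispatcher API.

```cpp
#include <map>
#include <string>
#include <utility>
#include <vector>

// Stand-in for the dispatcher's Registry type (assumption).
struct Registry {
    std::vector<std::string> kernels;
};

// Hypothetical name-to-registry table for choosing a kernel set at run time.
class RegistryTable {
public:
    // Adds a registry under `name`, replacing any previous entry.
    void add(const std::string& name, Registry registry) {
        table_[name] = std::move(registry);
    }

    // Returns the registry stored under `name`, or nullptr if there is none.
    const Registry* find(const std::string& name) const {
        const auto it = table_.find(name);
        return it == table_.end() ? nullptr : &it->second;
    }

private:
    std::map<std::string, Registry> table_;
};
```

A caller could look up `"compute_optimized"` or `"latency_optimized"` from a command-line option and build a dispatcher from the result, handling the `nullptr` case for unknown names.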

Benchmark Parameters (stream_config)

CK Tile uses stream_config for benchmark control:

ck_tile::stream_config cfg{
    nullptr,    // stream_id       - HIP stream (nullptr = default)
    true,       // time_kernel     - Enable timing
    1,          // log_level       - Verbosity (0=quiet, 1=normal)
    5,          // cold_niters     - Warmup iterations
    20,         // nrepeat         - Benchmark iterations
    true,       // is_gpu_timer    - Use GPU events vs CPU chrono
    false,      // flush_cache     - Flush L2 cache between iterations
    1           // rotating_count  - Rotating buffers for cache simulation
};

| Parameter | CLI Option | Default | Description |
|---|---|---|---|
| cold_niters_ | --warmup | 5 | Warmup iterations |
| nrepeat_ | --iterations | 100 | Benchmark iterations |
| flush_cache_ | - | false | Flush L2 cache |
| rotating_count_ | - | 1 | Rotating buffers |
| is_gpu_timer_ | - | true | GPU timer vs CPU |
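
The interaction between warmup and timed iterations can be sketched with a host-side timer. The code below is not how CK Tile measures kernels; it is a minimal CPU-clock illustration in which `TimingConfig` and `time_callable_ms` are hypothetical names that borrow the `cold_niters` and `nrepeat` terminology.

```cpp
#include <chrono>
#include <functional>

// Hypothetical subset of the benchmark settings.
struct TimingConfig {
    int cold_niters = 5;  // warmup calls, excluded from timing
    int nrepeat = 20;     // timed calls
};

// Runs `fn` cold_niters times untimed, then nrepeat times timed, and returns
// the mean wall-clock time per timed call in milliseconds.
double time_callable_ms(const std::function<void()>& fn, const TimingConfig& cfg) {
    for (int i = 0; i < cfg.cold_niters; ++i) {
        fn();
    }
    if (cfg.nrepeat <= 0) {
        return 0.0;
    }
    const auto start = std::chrono::steady_clock::now();
    for (int i = 0; i < cfg.nrepeat; ++i) {
        fn();
    }
    const auto stop = std::chrono::steady_clock::now();
    const std::chrono::duration<double, std::milli> elapsed = stop - start;
    return elapsed.count() / static_cast<double>(cfg.nrepeat);
}
```

GPU kernel launches are asynchronous, so a host clock like this only gives meaningful numbers if `fn` synchronizes with the device before returning; that is one reason a GPU event timer is the default above.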

Declarative Kernel Pattern

All examples use the declarative DECL_KERNEL_SET macro:

DECL_KERNEL_SET(my_kernels,
    .add(
        Signature()               // WHAT: operation signature
            .dtype("fp16")        // Data type
            .layout("rcr"),       // Matrix layouts (A=row, B=col, C=row)
        Algorithm()               // HOW: implementation details  
            .tile(256, 256, 32)   // Tile sizes (M, N, K)
            .wave(2, 2, 1)        // Wave configuration
            .warp(32, 32, 16)     // Warp tile sizes
            .pipeline("compv4")   // Pipeline type
            .scheduler("intrawave"), // Scheduler type
        "gfx942"                  // WHERE: target architecture
    )
);

Key Macros:

  • DECL_KERNEL_SET(name, ...) - Declare a kernel set
  • REGISTER_GENERATED_KERNELS - Register all kernels from this example
  • REGISTER_KERNEL_SET(name, registry) - Register specific kernel set to a registry