Yaswanth Raparti, commit c1127a36f5: [rocm-libraries] ROCm/rocm-libraries#5676 (commit 1d18339)

[CK][CK TILE] Autotuning heuristics infra for universal GEMM kernel selection (#5676)

## Motivation

This PR adds ML-based kernel selection heuristics to the CK Tile
dispatcher, enabling fast and accurate automatic kernel selection for
universal GEMM kernels. Instead of exhaustively searching 4600+ kernel
configurations (~46 seconds per problem shape), the ML heuristic
predicts optimal kernels in microseconds while achieving >98% of
oracle-best performance.

## Technical Details

**ML Infrastructure:**

https://github.com/ROCm/rocm-libraries/tree/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics
* Feature Engine
([feature_engine.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/feature_engine.py)):
55-feature extraction including problem dimensions, kernel
configuration, tile efficiency, and hardware profile
* Training Pipeline
([train.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/train.py)):
LightGBM regression with log-transform, GroupKFold cross-validation,
warm-start support
* Predictor
([predict.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/predict.py)):
Kernel ranking and TFLOPS prediction for problem shapes
* Evaluation
([evaluate.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/evaluate.py)):
Comprehensive metrics including efficiency, NDCG@k, shape family
analysis
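
The training recipe named above (log-transformed throughput target, GroupKFold cross-validation so all samples from one problem shape stay in the same fold) can be sketched as follows. This is a hypothetical illustration with synthetic placeholder features, using scikit-learn's GradientBoostingRegressor as a stand-in for LightGBM; it is not the actual train.py code:

```python
import numpy as np
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.model_selection import GroupKFold

rng = np.random.default_rng(0)
X = rng.random((200, 5))          # placeholder problem/kernel features
tflops = 1.0 + 50.0 * X[:, 0]     # placeholder throughput target
groups = np.arange(200) // 8      # samples sharing one problem shape

y = np.log1p(tflops)              # log-transform stabilizes the target range
preds = np.empty_like(y)
for train_idx, test_idx in GroupKFold(n_splits=5).split(X, y, groups):
    # GroupKFold keeps every shape's samples in a single fold, so the
    # model is always evaluated on shapes it never saw during training
    model = GradientBoostingRegressor(random_state=0)
    model.fit(X[train_idx], y[train_idx])
    preds[test_idx] = model.predict(X[test_idx])

tflops_pred = np.expm1(preds)     # invert the log-transform
print(round(float(np.corrcoef(tflops, tflops_pred)[0, 1]), 3))
```

Grouping by shape matters because rows from the same shape are highly correlated; a plain random split would leak them across folds and overstate accuracy.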

**Data Generation Tools:**

*
[generate_benchmark_data.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/generate_benchmark_data.py):
Build and benchmark kernels across diverse problem shapes
*
[convert_json_to_parquet.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/convert_json_to_parquet.py):
Convert benchmark JSON to training-ready parquet format
*
[data_pipeline.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/data_pipeline.py):
Parse streaming benchmark logs into canonical datasets

**Examples:**
*
[09_ml_heuristic.cpp](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/examples/gemm/cpp/09_ml_heuristic.cpp):
C++ example demonstrating ML-based kernel selection
*
[09_ml_heuristic.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/examples/gemm/python/09_ml_heuristic.py):
Python example with validation

**Pre-trained Models
(projects/composablekernel/dispatcher/heuristics/models/):**
* gemm_universal_fp8_gfx950/: fp8 RCR model (42K trees, 97.51% mean
efficiency)
* gemm_universal_fp16_gfx950/: fp16 RCR model (20K trees, 99.36% mean
efficiency)

## Test Plan

* Evaluated on 25 diverse shapes for fp16, 168 shapes for fp8
* All shape families tested: tiny M (M<8), small M, medium M, large M
(M≥1024)
* All pipeline types: compv3, compv4, mem

## Test Result

**fp16 Model (gfx950, RCR layout)**
* Mean Efficiency: 99.36%
* P10 Efficiency: 98.05% (90% of shapes achieve ≥98% of oracle-best
performance)
* Min Efficiency: 95.45%

**fp8 Model (gfx950, RCR layout)**
* Mean Efficiency: 98.28% (original), 97.51% (wide coverage)
* P10 Efficiency: 94.64% (original), 93.89% (wide coverage)
* Min Efficiency: 84.5%
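
For reference, the efficiency metric behind these numbers is the TFLOPS of the ML-selected kernel divided by the TFLOPS of the oracle-best kernel for the same shape, summarized by mean, P10 (10th percentile), and minimum. A minimal sketch with placeholder values:

```python
import numpy as np

picked = np.array([118.0, 95.0, 240.0, 60.0])   # placeholder TFLOPS of ML-picked kernels
oracle = np.array([120.0, 100.0, 250.0, 62.0])  # placeholder oracle-best TFLOPS

eff = picked / oracle  # per-shape efficiency in [0, 1]
print(f"mean={eff.mean():.4f}  p10={np.percentile(eff, 10):.4f}  min={eff.min():.4f}")
```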

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.

Committed 2026-04-02 02:26:32 +00:00

GEMM Python Examples

CK Tile Dispatcher Python examples for GEMM (General Matrix Multiplication) operations.

Main Documentation: Dispatcher README | Examples Overview

Quick Start

Build Library

cd /path/to/composable_kernel/dispatcher
mkdir -p build && cd build

cmake .. \
  -DCMAKE_PREFIX_PATH=/opt/rocm \
  -DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
  -DBUILD_DISPATCHER_EXAMPLES=ON

# Build Python library (kernels generated automatically)
make dispatcher_gemm_lib -j$(nproc)

Run Examples

cd /path/to/composable_kernel/dispatcher

python3 examples/gemm/python/01_basic_gemm.py
python3 examples/gemm/python/04_validation.py
python3 examples/gemm/python/07_stress_test.py
python3 examples/gemm/python/08_heuristics.py

Examples

Example                     Description
01_basic_gemm.py            Basic GEMM with multi-kernel support
02_batch_gemm.py            Batched GEMM operations
03_benchmark.py             Performance benchmarking
04_validation.py            CPU reference validation
05_numpy_integration.py     NumPy array integration
06_json_export.py           Registry JSON export
07_stress_test.py           Multi-kernel stress testing
08_heuristics.py            Heuristic-based kernel selection
09_multi_registry.py        Multiple registries
10_advanced_benchmark.py    Advanced benchmark with full control
11_json_import.py           Import kernels from JSON

Example Details

01_basic_gemm.py - Basic GEMM

Demonstrates the Python API with multi-kernel support:

from ctypes_utils import KernelConfig, setup_gemm_dispatcher, print_kernel_config_table

# Define multiple kernel configurations
kernels = [
    KernelConfig(
        tile_m=128, tile_n=128, tile_k=32,
        wave_m=2, wave_n=2, wave_k=1,
        warp_tile_m=32, warp_tile_n=32, warp_tile_k=16,
        pipeline="compv3", scheduler="intrawave"
    ),
    KernelConfig(
        tile_m=256, tile_n=256, tile_k=32,
        wave_m=2, wave_n=2, wave_k=1,
        warp_tile_m=32, warp_tile_n=32, warp_tile_k=16,
        pipeline="compv4", scheduler="intrawave"
    ),
]

# Display configurations
print_kernel_config_table(kernels)

# Set up dispatcher with all kernels
lib, dispatcher, registry = setup_gemm_dispatcher(kernels)

# Run GEMM
elapsed_ms = run_gemm(lib, M, N, K, ...)

02_batch_gemm.py - Batch GEMM

Batched matrix multiplication:

  • Multiple independent GEMM operations
  • Batch dimension handling
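
The batched semantics (one independent GEMM per batch entry) match NumPy's stacked matmul, which also serves as a convenient CPU reference. A small sketch with placeholder sizes:

```python
import numpy as np

rng = np.random.default_rng(0)
B, M, N, K = 4, 64, 32, 16
A = rng.random((B, M, K), dtype=np.float32)
Bmat = rng.random((B, K, N), dtype=np.float32)

# np.matmul broadcasts over the leading batch dimension:
# C[b] = A[b] @ Bmat[b] for each batch entry b
C = np.matmul(A, Bmat)
assert np.allclose(C[0], A[0] @ Bmat[0])
```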

03_benchmark.py - Benchmarking

Performance measurement:

  • GPU timing
  • TFLOPS calculation
  • Multiple iterations
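
The TFLOPS calculation is standard for GEMM: a problem of shape (M, N, K) performs 2·M·N·K floating-point operations (one multiply and one add per accumulated term). A minimal sketch:

```python
def gemm_tflops(M, N, K, elapsed_ms):
    """TFLOPS for one GEMM: 2*M*N*K flops divided by elapsed seconds."""
    return (2.0 * M * N * K) / (elapsed_ms * 1e-3) / 1e12

# A 4096^3 GEMM in 1 ms corresponds to roughly 137 TFLOPS
print(gemm_tflops(4096, 4096, 4096, 1.0))
```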

04_validation.py - Validation

Correctness verification:

  • NumPy reference implementation
  • Tolerance-based validation
  • Error reporting
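
Tolerance-based validation against a NumPy reference can be sketched as below; `validate` is a hypothetical helper name, not the actual ctypes_utils API:

```python
import numpy as np

def validate(gpu_out, ref, rtol=1e-3, atol=1e-3):
    """Return (passed, max_abs_error) for a GPU result vs. a CPU reference."""
    max_err = float(np.max(np.abs(gpu_out - ref)))
    return bool(np.allclose(gpu_out, ref, rtol=rtol, atol=atol)), max_err

rng = np.random.default_rng(0)
A = rng.random((16, 8), dtype=np.float32)
B = rng.random((8, 4), dtype=np.float32)
ref = A @ B                     # CPU reference result
# simulate a GPU result with small floating-point noise
noisy = ref + rng.normal(0, 1e-5, ref.shape).astype(np.float32)
passed, max_err = validate(noisy, ref)
print(passed, f"{max_err:.3e}")
```

Relative plus absolute tolerance matters for fp16/fp8 outputs, where accumulation order alone produces small deviations from the fp64 reference.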

05_numpy_integration.py - NumPy Integration

Seamless NumPy integration:

  • NumPy arrays to GPU buffers
  • Results back to NumPy
  • Automatic type conversion
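
The conversion step can be sketched as follows: GPU buffers expect contiguous memory in the kernel's dtype, so arbitrary NumPy inputs (strided views, float64 arrays, ...) are normalized first. `to_gpu_ready` is a hypothetical helper name for illustration:

```python
import numpy as np

def to_gpu_ready(a, dtype=np.float16):
    # ascontiguousarray both converts dtype and copies strided views
    # into C-contiguous memory suitable for a raw GPU memcpy
    return np.ascontiguousarray(a, dtype=dtype)

# a non-contiguous float64 view becomes a contiguous fp16 array
x = to_gpu_ready(np.arange(6, dtype=np.float64).reshape(2, 3)[:, ::2])
assert x.flags["C_CONTIGUOUS"] and x.dtype == np.float16
```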

06_json_export.py - JSON Export

Registry serialization for tool integration:

  • Export kernel configurations
  • Machine-readable format

07_stress_test.py - Stress Testing

Comprehensive multi-kernel stress testing:

from ctypes_utils import KernelConfig, setup_gemm_dispatcher, print_kernel_config_table

# Define 48 unique kernel configurations
kernels = [
    KernelConfig(tile_m=128, tile_n=128, tile_k=32, pipeline="compv3", ...),
    KernelConfig(tile_m=256, tile_n=256, tile_k=32, pipeline="compv4", ...),
    KernelConfig(tile_m=128, tile_n=256, tile_k=64, pipeline="compv3", ...),
    # ... many more configurations
]

# Test each kernel
for i, kernel in enumerate(kernels):
    lib, dispatcher, registry = setup_gemm_dispatcher([kernel])
    result = run_and_validate(lib, M, N, K, seed=42 + i)  # Different seed per kernel
    print(f"Kernel {i}: {result.max_err:.6e} {'PASS' if result.passed else 'FAIL'}")

Features:

  • 48 unique kernel configurations
  • Various tile sizes, pipelines, and schedulers
  • Per-kernel validation with unique random seeds
  • Performance reporting

08_heuristics.py - Heuristic Selection

Custom kernel selection based on problem characteristics:

# Define kernel pools for different strategies
SMALL_KERNELS = [KernelConfig(tile_m=64, tile_n=64, ...), ...]
LARGE_KERNELS = [KernelConfig(tile_m=256, tile_n=256, ...), ...]
COMPUTE_KERNELS = [KernelConfig(pipeline="compv4", ...), ...]
MEMORY_KERNELS = [KernelConfig(pipeline="compv3", ...), ...]

# Size-based heuristic
def size_based_heuristic(M, N, K):
    if M * N < 512 * 512:
        return SMALL_KERNELS
    else:
        return LARGE_KERNELS

# Strategy-based selection
def compute_strategy():
    return COMPUTE_KERNELS  # Optimized for compute-bound problems

def memory_strategy():
    return MEMORY_KERNELS   # Optimized for memory-bound problems

# Test different strategies
for strategy in [size_based_heuristic, compute_strategy, memory_strategy]:
    kernels = strategy(M, N, K)
    lib, dispatcher, registry = setup_gemm_dispatcher(kernels)
    elapsed_ms = run_gemm(lib, M, N, K, ...)

Features:

  • 24 kernel configurations across 6 categories
  • Size-based heuristic (small vs large)
  • Optimization strategies (compute, memory, latency)
  • Performance comparison across strategies

09_multi_registry.py - Multiple Registries

Separate registries for different workloads:

  • Compute-optimized registry
  • Latency-optimized registry
  • Dynamic registry selection

10_advanced_benchmark.py - Advanced Benchmark

Full control over benchmark parameters:

  • Warmup iterations
  • Benchmark iterations
  • Statistical analysis
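
The warmup/iterations pattern can be sketched as below: warmup runs are excluded from the statistics, then median and standard deviation summarize the timed runs. This is a generic CPU-timer sketch; a real GPU benchmark would time kernel launches with device synchronization instead:

```python
import statistics
import time

def bench(fn, warmup=3, iters=10):
    # warmup iterations absorb one-time costs (caching, lazy init)
    for _ in range(warmup):
        fn()
    times_ms = []
    for _ in range(iters):
        t0 = time.perf_counter()
        fn()
        times_ms.append((time.perf_counter() - t0) * 1e3)
    # median resists outliers better than the mean for timing data
    return {"median_ms": statistics.median(times_ms),
            "stdev_ms": statistics.stdev(times_ms)}

stats = bench(lambda: sum(range(100000)))
print(stats["median_ms"] > 0.0)
```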

11_json_import.py - JSON Import

Import kernel configurations from JSON:

  • External configuration files
  • Dynamic kernel loading
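
A JSON export/import round trip might look like the sketch below; the field names are a hypothetical schema for illustration, and the real format used by 06_json_export.py / 11_json_import.py may differ:

```python
import json

# one exported registry entry (hypothetical field names)
entry = {
    "tile": [256, 256, 32],
    "wave": [2, 2, 1],
    "warp_tile": [32, 32, 16],
    "pipeline": "compv4",
    "scheduler": "intrawave",
}

blob = json.dumps({"kernels": [entry]}, indent=2)   # export side
restored = json.loads(blob)["kernels"]              # import side
print(len(restored), restored[0]["pipeline"])
```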

Utility Module: ctypes_utils.py

from ctypes_utils import (
    KernelConfig,              # Single kernel configuration
    setup_gemm_dispatcher,     # Set up dispatcher with kernels
    print_kernel_config_table, # Display kernel configurations
    Dispatcher,                # High-level dispatcher
    Registry,                  # Kernel registry
    Validator,                 # Validation utilities
)

KernelConfig

config = KernelConfig(
    # Tile sizes
    tile_m=256, tile_n=256, tile_k=32,
    # Wave configuration
    wave_m=2, wave_n=2, wave_k=1,
    # Warp tile sizes
    warp_tile_m=32, warp_tile_n=32, warp_tile_k=16,
    # Pipeline and scheduler
    pipeline="compv4",      # "compv3" or "compv4"
    scheduler="intrawave",  # "intrawave" or "interwave"
    # Optional
    epilogue="default",
    padding=True,
    double_buffer=True,
)

setup_gemm_dispatcher

# Single kernel
lib, dispatcher, registry = setup_gemm_dispatcher(config)

# Multiple kernels
lib, dispatcher, registry = setup_gemm_dispatcher([config1, config2, ...])

# With auto-rebuild
lib, dispatcher, registry = setup_gemm_dispatcher(config, auto_rebuild=True)

print_kernel_config_table

kernels = [config1, config2, config3]
print_kernel_config_table(kernels)
# Output:
# +---+------------+-------+----------+--------+-----------+
# | # | Tile       | Wave  | Warp     | Pipe   | Scheduler |
# +---+------------+-------+----------+--------+-----------+
# | 1 | 128x128x32 | 2x2x1 | 32x32x16 | compv3 | intrawave |
# | 2 | 256x256x32 | 2x2x1 | 32x32x16 | compv4 | intrawave |
# | 3 | 128x256x64 | 2x2x1 | 32x32x16 | compv3 | interwave |
# +---+------------+-------+----------+--------+-----------+

GPU Memory Management

import ctypes
import numpy as np

# Load HIP library
hip = ctypes.CDLL("libamdhip64.so")

# Allocate GPU memory
gpu_ptr = ctypes.c_void_p()
hip.hipMalloc(ctypes.byref(gpu_ptr), size_in_bytes)

# Copy to GPU (1 = hipMemcpyHostToDevice)
hip.hipMemcpy(gpu_ptr, host_array.ctypes.data, size, 1)

# Copy back (2 = hipMemcpyDeviceToHost)
hip.hipMemcpy(host_array.ctypes.data, gpu_ptr, size, 2)

# Free
hip.hipFree(gpu_ptr)

Performance Testing

Test compilation performance with different kernel counts:

# Test with 10 kernels (~15s compile time)
python3 01_basic_gemm.py --num-kernels 10

# Test with 20 kernels (~25s compile time)
python3 01_basic_gemm.py --num-kernels 20

# Test with 48 kernels (~50s compile time)
python3 01_basic_gemm.py --num-kernels 48

Compilation time scales roughly linearly with kernel count.