composable_kernel/dispatcher/examples/gemm/cpp
Vidyasagar Ananthan 920acd2c12 [rocm-libraries] ROCm/rocm-libraries#5168 (commit 8b5afcb)
[CK] [CK_Tile] Add GroupConv to Kernel Dispatcher

## Motivation

This PR adds CK Tile group convolution (forward, backward-data,
backward-weight) support to the kernel dispatcher, matching and unifying
with the existing dispatcher GEMM infrastructure in architecture and
usability. The dispatcher provides a unified kernel dispatch system with
both C++ and Python frontends, and until now only supported GEMM
operations. This PR enables framework integrators to use the same
declarative kernel workflow for convolutions as they do for GEMM:
declare kernels, build a registry JIT, select kernels within the
registry at runtime, and dispatch to GPU. Future PRs will include
runtime kernel selection heuristics for autotuning of kernel parameters
based on (problem, hardware arch).

## Technical Details

Grouped convolution support has been added to the CK Tile Dispatcher
with generated_conv_backend.hpp enabling dispatcher.run(in, wei, out,
problem) for all 6 conv variants (fwd/bwdd/bwdw x 2D/3D), runtime
heuristic kernel selection, and GroupedConvKernelKey with full
ConvConfigBase fields. Python side adds parallel JIT via
registry.build(max_workers) and heuristic registry.select(). Includes 7
C++ and 6 Python examples covering all directions with CPU reference
validation, and shared infrastructure improvements (BaseRegistry CRTP,
structured exceptions). As a sanity check, JIT compile time for a
single kernel remains the same, and multiple kernels compile with better
parallelism:
Kernels | 1 worker | 8 workers
1 | 7.7 s | 7.7 s
2 | 15.9 s | 8.2 s
4 | 33.4 s | 9.7 s
6 | 52.3 s | 10.2 s

## Test Plan

145 ephemeral unit tests have been added to test basic functionality.
All 30 examples/integration tests run end-to-end on gfx950 (MI350): 7
C++ conv, 7 C++ GEMM, 6 Python conv, 10 Python GEMM. CPU reference
validation for forward, backward-data, and backward-weight (2D) passes in
both the C++ and Python examples.

## Test Result

30 examples pass. Peak performance: 132 TFLOPS (Batch-32 forward 56x56),
53 TFLOPS (pointwise 1x1). CPU reference accuracy: max_abs_diff < 0.002
for all directions (fp16 vs fp32 reference).

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
2026-04-09 17:39:35 +00:00

GEMM C++ Examples

CK Tile Dispatcher C++ examples for GEMM (General Matrix Multiplication) operations.

Main Documentation: Dispatcher README | Examples Overview

Quick Start

Build and Run

cd /path/to/composable_kernel/dispatcher
mkdir -p build && cd build

cmake .. \
  -DCMAKE_PREFIX_PATH=/opt/rocm \
  -DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
  -DBUILD_DISPATCHER_EXAMPLES=ON

# Build (kernels generated automatically by CMake)
make -j$(nproc)

# Run examples
cd examples
./gemm_01_basic
./gemm_03_benchmark_validation
./gemm_04_heuristics

Examples

Example | Description
01_basic_gemm.cpp | Basic GEMM with declarative API, autofill, autocorrect
02_multi_size.cpp | Wildcard expansion for multiple configurations
03_benchmark_validation.cpp | Performance benchmarking with CPU reference validation
04_heuristics.cpp | Heuristic-based kernel selection
05_json_export.cpp | Registry JSON export for external tools
06_multi_registry.cpp | Multiple registries with named kernel sets

Example Details

01_basic_gemm.cpp - Basic GEMM

Demonstrates the declarative kernel API with three patterns:

  1. Autofill Pattern - Minimal specification, defaults filled automatically
  2. Autocorrect Pattern - Invalid parameters corrected at build time
  3. Full Specification Pattern - Complete kernel configuration

DECL_KERNEL_SET(basic_kernels,
    // Pattern 1: Autofill - minimal specification
    .add(
        Signature().dtype("fp16").layout("rcr"),
        Algorithm(),  // Defaults filled by autofill
        "gfx942"
    )
    // Pattern 3: Full specification (Pattern 2, autocorrect, is not shown in this excerpt)
    .add(
        Signature().dtype("fp16").layout("rcr"),
        Algorithm().tile(256, 256, 32).wave(2, 2, 1).warp(32, 32, 16)
                   .pipeline("compv4").scheduler("intrawave"),
        "gfx942"
    )
);

Features:

  • Uses generic REGISTER_GENERATED_KERNELS macro
  • print_registered_kernels() utility for debugging
  • Demonstrates autofill messages during build

02_multi_size.cpp - Wildcard Expansion

Demonstrates automatic generation of multiple kernel configurations:

DECL_KERNEL_SET(multi_kernels,
    .add(
        Signature().dtype("fp16").layout("rcr"),
        Algorithm().tile(ANY_INT, ANY_INT, 32)  // Wildcard tile M and N
                   .wave(2, 2, 1)
                   .warp(32, 32, 16)
                   .pipeline("compv4")
                   .scheduler("intrawave"),
        "gfx942"
    )
);

Wildcard Values:

  • *, -1, or ANY_INT expand to all valid configurations
  • Architecture filter prunes invalid combinations automatically
  • Example generates 5 valid kernels after arch filtering (from 7 expansions)
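
As a rough illustration of expand-then-filter, the sketch below expands wildcard tile dimensions over a candidate list and keeps only the tiles accepted by a validity predicate. The names (ANY, TileSketch, expand_tiles), the candidate list, and the predicate are all assumptions made for this example; the real expansion rules and architecture filter live in the dispatcher's code generator and are not reproduced here.

```cpp
#include <cassert>
#include <vector>

constexpr int ANY = -1;  // hypothetical wildcard marker, mirroring the -1 / ANY_INT convention

struct TileSketch { int m, n, k; };

// A wildcard expands to every candidate; a concrete value expands to itself.
inline std::vector<int> expand_dim(int value, const std::vector<int>& candidates) {
    assert(!candidates.empty());
    if (value == ANY) return candidates;
    return {value};
}

// Expand M and N, then keep only the tiles the filter accepts.
template <typename Filter>
std::vector<TileSketch> expand_tiles(int m, int n, int k,
                                     const std::vector<int>& candidates,
                                     Filter is_valid) {
    std::vector<TileSketch> out;
    for (int mm : expand_dim(m, candidates)) {
        for (int nn : expand_dim(n, candidates)) {
            TileSketch t{mm, nn, k};
            if (is_valid(t)) out.push_back(t);
        }
    }
    return out;
}
```

With candidates {64, 128, 256} and a made-up filter that rejects tiles whose M x N exceeds 128 x 256, two wildcards produce 9 expansions of which 8 survive; the counts in the real example depend on its own candidate set and architecture rules.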

03_benchmark_validation.cpp - Benchmark + Validation

Consolidated example combining performance benchmarking with correctness validation:

# Benchmark only
./gemm_03_benchmark_validation --warmup 10 --iterations 100

# With CPU validation
./gemm_03_benchmark_validation --verify 1 --rtol 1e-3 --atol 1e-3

# With GPU reference validation (faster for large matrices)
./gemm_03_benchmark_validation --verify 2

Features:

  • Warmup iterations (discarded from timing)
  • Benchmark iterations with statistics (min/max/mean/median)
  • CPU reference validation using ck_tile::reference_gemm
  • GPU reference validation using ck_tile::reference_gemm_gpu
  • Configurable tolerances
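
How --rtol and --atol are combined inside the example is not spelled out here. A common convention, shown below purely as an assumption, accepts an element when |out - ref| <= atol + rtol * |ref|. The helper name count_mismatches is hypothetical and not part of ck_tile.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Count elements that violate |out - ref| <= atol + rtol * |ref|.
// NaN values compare false against the tolerance and are therefore counted as mismatches.
inline std::size_t count_mismatches(const std::vector<float>& out,
                                    const std::vector<float>& ref,
                                    float rtol, float atol) {
    assert(out.size() == ref.size());
    std::size_t bad = 0;
    for (std::size_t i = 0; i < out.size(); ++i) {
        const float err = std::fabs(out[i] - ref[i]);
        const float tol = atol + rtol * std::fabs(ref[i]);
        if (!(err <= tol)) ++bad;
    }
    return bad;
}
```

The relative term lets large outputs carry proportionally larger absolute error, while the absolute term keeps values near zero from failing on rounding noise.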

04_heuristics.cpp - Heuristic Selection

Demonstrates custom kernel selection based on problem characteristics:

// Problem size analysis
auto heuristic = [](const Problem& p) -> std::optional<KernelKey> {
    if (p.M() * p.N() < 256 * 256) {
        return small_kernel_key;   // Memory-bound heuristic
    } else {
        return large_kernel_key;   // Compute-bound heuristic
    }
};

dispatcher.set_heuristic(heuristic);

Features:

  • Problem size analysis (small vs large matrices)
  • Compute-bound vs memory-bound selection
  • Custom heuristic function registration
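
The snippet above relies on types defined by the dispatcher. For readers who want to experiment with the selection logic in isolation, here is a self-contained sketch. ProblemSketch, KernelKeySketch, and DispatcherSketch are hypothetical stand-ins, and the fallback behaviour when no heuristic is set (or when it returns an empty optional) is an assumption, not documented dispatcher behaviour.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <optional>
#include <string>
#include <utility>

struct ProblemSketch { std::int64_t m, n, k; };  // hypothetical problem description
using KernelKeySketch = std::string;             // hypothetical kernel identifier
using HeuristicSketch =
    std::function<std::optional<KernelKeySketch>(const ProblemSketch&)>;

class DispatcherSketch {
public:
    explicit DispatcherSketch(KernelKeySketch fallback) : fallback_(std::move(fallback)) {}

    void set_heuristic(HeuristicSketch h) { heuristic_ = std::move(h); }

    // Use the heuristic's choice when it yields one; otherwise use the fallback key.
    KernelKeySketch select(const ProblemSketch& p) const {
        assert(p.m > 0 && p.n > 0 && p.k > 0);
        if (heuristic_) {
            if (auto key = heuristic_(p)) return *key;
        }
        return fallback_;
    }

private:
    KernelKeySketch fallback_;
    HeuristicSketch heuristic_;
};

// Same size threshold as the snippet above: an output smaller than 256 x 256 picks the small kernel.
inline HeuristicSketch size_threshold_heuristic() {
    return [](const ProblemSketch& p) -> std::optional<KernelKeySketch> {
        if (p.m * p.n < 256 * 256) return KernelKeySketch{"small_tile"};
        return KernelKeySketch{"large_tile"};
    };
}
```

Using 64-bit dimensions keeps the M x N product from overflowing for large problems, which is worth checking if the real Problem accessors return 32-bit integers.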

05_json_export.cpp - JSON Export

Exports registry information to JSON for external tool integration:

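// Requires <fstream> for std::ofstream
auto json = registry.to_json();      // Serialize registry metadata
std::ofstream file("kernels.json");  // Output file, closed when `file` goes out of scope
file << json;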

Use Cases:

  • Kernel metadata serialization
  • External analysis tools
  • Configuration management

06_multi_registry.cpp - Multiple Registries

Demonstrates using multiple registries with named kernel sets:

// Define separate kernel sets
DECL_KERNEL_SET(compute_optimized, ...);
DECL_KERNEL_SET(latency_optimized, ...);

// Register to specific registries
Registry compute_registry, latency_registry;
REGISTER_KERNEL_SET(compute_optimized, compute_registry);
REGISTER_KERNEL_SET(latency_optimized, latency_registry);

// Use appropriate registry based on workload
Dispatcher compute_dispatcher(compute_registry);
Dispatcher latency_dispatcher(latency_registry);

Features:

  • Named kernel set registration with REGISTER_KERNEL_SET macro
  • Separate registries for different optimization goals
  • Dynamic kernel set selection by name
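
Selecting a kernel set by name amounts to a lookup from a string to a registry. The sketch below shows that idea with standard containers only. RegistrySketch and RegistryTableSketch are hypothetical names invented for this illustration and do not reflect how REGISTER_KERNEL_SET is implemented.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Hypothetical registry: just a list of kernel names.
struct RegistrySketch {
    std::vector<std::string> kernels;
};

// Hypothetical table mapping a kernel-set name to its registry.
class RegistryTableSketch {
public:
    // Append kernels to the named set, creating the set on first use.
    void register_set(const std::string& name, const std::vector<std::string>& kernels) {
        assert(!name.empty());
        auto& reg = table_[name];
        reg.kernels.insert(reg.kernels.end(), kernels.begin(), kernels.end());
    }

    // Return the named registry, or nullptr when no such set was registered.
    const RegistrySketch* find(const std::string& name) const {
        auto it = table_.find(name);
        return it == table_.end() ? nullptr : &it->second;
    }

private:
    std::map<std::string, RegistrySketch> table_;
};
```

Returning a pointer from find() makes the missing-set case explicit, so callers can fall back to another registry instead of silently creating an empty one.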

Benchmark Parameters (stream_config)

CK Tile uses stream_config for benchmark control:

ck_tile::stream_config cfg{
    nullptr,    // stream_id       - HIP stream (nullptr = default)
    true,       // time_kernel     - Enable timing
    1,          // log_level       - Verbosity (0=quiet, 1=normal)
    5,          // cold_niters     - Warmup iterations
    20,         // nrepeat         - Benchmark iterations
    true,       // is_gpu_timer    - Use GPU events vs CPU chrono
    false,      // flush_cache     - Flush L2 cache between iterations
    1           // rotating_count  - Rotating buffers for cache simulation
};

Parameter | CLI Option | Default | Description
cold_niters_ | --warmup | 5 | Warmup iterations
nrepeat_ | --iterations | 100 | Benchmark iterations
flush_cache_ | - | false | Flush L2 cache
rotating_count_ | - | 1 | Rotating buffers
is_gpu_timer_ | - | true | GPU timer vs CPU
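
To make the roles of the warmup and repeat counts concrete, here is a host-side sketch of the same measurement pattern: run the workload cold_niters times without recording, then record nrepeat timed runs and summarize them. It uses std::chrono rather than GPU events, and the names time_on_host, summarize, and StatsSketch are hypothetical. It illustrates the pattern only and is not how ck_tile implements its timer.

```cpp
#include <algorithm>
#include <cassert>
#include <chrono>
#include <cstddef>
#include <numeric>
#include <vector>

struct StatsSketch { double min, max, mean, median; };

// Summarize a non-empty set of samples (taken by value so it can be sorted locally).
inline StatsSketch summarize(std::vector<double> samples) {
    assert(!samples.empty());
    std::sort(samples.begin(), samples.end());
    const std::size_t n = samples.size();
    const double median = (n % 2 == 1) ? samples[n / 2]
                                       : 0.5 * (samples[n / 2 - 1] + samples[n / 2]);
    const double mean = std::accumulate(samples.begin(), samples.end(), 0.0) / static_cast<double>(n);
    return {samples.front(), samples.back(), mean, median};
}

// Run `fn` cold_niters times untimed, then return nrepeat timed samples in milliseconds.
template <typename Fn>
std::vector<double> time_on_host(Fn&& fn, int cold_niters, int nrepeat) {
    assert(cold_niters >= 0 && nrepeat > 0);
    for (int i = 0; i < cold_niters; ++i) fn();  // warmup runs are discarded
    std::vector<double> ms;
    ms.reserve(static_cast<std::size_t>(nrepeat));
    for (int i = 0; i < nrepeat; ++i) {
        const auto t0 = std::chrono::steady_clock::now();
        fn();
        const auto t1 = std::chrono::steady_clock::now();
        ms.push_back(std::chrono::duration<double, std::milli>(t1 - t0).count());
    }
    return ms;
}
```

For asynchronous GPU launches a host clock only measures launch overhead unless the device is synchronized, which is one reason a GPU-event timer such as the is_gpu_timer_ option exists.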

Declarative Kernel Pattern

All examples use the declarative DECL_KERNEL_SET macro:

DECL_KERNEL_SET(my_kernels,
    .add(
        Signature()               // WHAT: operation signature
            .dtype("fp16")        // Data type
            .layout("rcr"),       // Matrix layouts (A=row, B=col, C=row)
        Algorithm()               // HOW: implementation details  
            .tile(256, 256, 32)   // Tile sizes (M, N, K)
            .wave(2, 2, 1)        // Wave configuration
            .warp(32, 32, 16)     // Warp tile sizes
            .pipeline("compv4")   // Pipeline type
            .scheduler("intrawave"), // Scheduler type
        "gfx942"                  // WHERE: target architecture
    )
);

Key Macros:

  • DECL_KERNEL_SET(name, ...) - Declare a kernel set
  • REGISTER_GENERATED_KERNELS - Register all kernels from this example
  • REGISTER_KERNEL_SET(name, registry) - Register specific kernel set to a registry