Files
composable_kernel/dispatcher/examples/gemm/cpp/09_ml_heuristic.cpp
Yaswanth Raparti c1127a36f5 [rocm-libraries] ROCm/rocm-libraries#5676 (commit 1d18339)
[CK][CK TILE]Autotuning heuristics infra for universal GEMM
 kernel selection (#5676)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Motivation

This PR adds ML-based kernel selection heuristics to the CK Tile
dispatcher, enabling fast and accurate automatic kernel selection for
Universal Gemm kernels. Instead of requiring exhaustive search through
4600+ kernel configurations (taking ~46 seconds per problem shape), the
ML heuristic predicts optimal kernels in microseconds while achieving
>98% of oracle-best performance.

## Technical Details

**ML infrastructure**

https://github.com/ROCm/rocm-libraries/tree/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics
* Feature Engine
([feature_engine.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/feature_engine.py)):
55-feature extraction including problem dimensions, kernel
configuration, tile efficiency, and hardware profile
* Training Pipeline
([train.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/train.py)):
LightGBM regression with log-transform, GroupKFold cross-validation,
warm-start support
* Predictor
([predict.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/predict.py)):
Kernel ranking and TFLOPS prediction for problem shapes
* Evaluation
([evaluate.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/evaluate.py)):
Comprehensive metrics including efficiency, NDCG@k, shape family
analysis

**Data Generation Tools:**

*
[generate_benchmark_data.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/generate_benchmark_data.py):
Build and benchmark kernels across diverse problem shapes
*
[convert_json_to_parquet.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/convert_json_to_parquet.py):
Convert benchmark JSON to training-ready parquet format
*
[data_pipeline.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/data_pipeline.py):
Parse streaming benchmark logs into canonical datasets

**Examples**
*
[09_ml_heuristic.cpp](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/examples/gemm/cpp/09_ml_heuristic.cpp):
C++ example demonstrating ML-based kernel selection
*
[09_ml_heuristic.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/examples/gemm/python/09_ml_heuristic.py):
Python example with validation

**Pre-trained Models
(projects/composablekernel/dispatcher/heuristics/models/):**
* gemm_universal_fp8_gfx950/: fp8 RCR model (42K trees, 97.51% mean
efficiency)
* gemm_universal_fp16_gfx950/: fp16 RCR model (20K trees, 99.36% mean
efficiency)

## Test Plan

* Evaluated on 25 diverse shapes for fp16, 168 shapes for fp8
* All shape families tested: tiny M (M<8), small M, medium M, large M
(M≥1024)
* All pipeline types: compv3, compv4, mem

## Test Result

**fp16 Model (gfx950, RCR layout)**
* Mean Efficiency: 99.36%
* P10 Efficiency: 98.05% (10th-percentile efficiency — 90% of shapes
achieve ≥98.05% of oracle best)
* Min Efficiency: 95.45%

**fp8 Model (gfx950, RCR layout)**
* Mean Efficiency: 98.28% (original), 97.51% (wide coverage)
* P10 Efficiency: 94.64% (original), 93.89% (wide coverage)
* Min Efficiency: 84.5%

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
2026-04-02 02:26:32 +00:00

212 lines
8.3 KiB
C++

// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
/**
* Example 09: ML-Based Kernel Selection (Native C++)
*
* Uses a trained LightGBM model loaded via the C API to predict TFLOPS
* for each kernel in the registry and select the best one. The kernels
* are JIT-compiled at build time via DECL_KERNEL_SET (same as other examples).
*
* Build: cd dispatcher/build && cmake .. && make gemm_09_ml_heuristic
* Run: ./gemm_09_ml_heuristic --model <path_to_model.lgbm>
*/
#include <hip/hip_runtime.h>
#include <iostream>
#include <iomanip>
#include <vector>
#include <chrono>
#include "ck_tile/dispatcher.hpp"
#include "ck_tile/dispatcher/kernel_decl.hpp"
#include "ck_tile/dispatcher/example_args.hpp"
#include "ck_tile/dispatcher/ml_heuristic.hpp"
using namespace ck_tile::dispatcher;
using namespace ck_tile::dispatcher::utils;
using Signature = decl::Signature;
using Algorithm = decl::Algorithm;
// Multiple kernel configs for ML to choose from.
// Each .add() registers one fp16 RCR-layout GEMM kernel variant for gfx942;
// the ML heuristic later ranks these candidates per problem shape.
// Algorithm fields: tile(M,N,K) = block tile sizes, wave(x,y,z) = wave
// layout per block, warp(m,n,k) = per-warp MFMA tile, plus pipeline,
// scheduler, and epilogue selection.
DECL_KERNEL_SET(ml_kernels,
// Small tiles — favored for small-M problems
.add(Signature().dtype("fp16").layout("rcr"),
Algorithm()
.tile(64, 64, 32)
.wave(2, 2, 1)
.warp(16, 16, 16)
.pipeline("compv3")
.scheduler("intrawave")
.epilogue("cshuffle"),
"gfx942")
// Same tile M/N, deeper K slice per iteration
.add(Signature().dtype("fp16").layout("rcr"),
Algorithm()
.tile(64, 64, 64)
.wave(2, 2, 1)
.warp(16, 16, 16)
.pipeline("compv3")
.scheduler("intrawave")
.epilogue("cshuffle"),
"gfx942")
// Medium tiles
.add(Signature().dtype("fp16").layout("rcr"),
Algorithm()
.tile(128, 128, 32)
.wave(2, 2, 1)
.warp(32, 32, 16)
.pipeline("compv3")
.scheduler("intrawave")
.epilogue("cshuffle"),
"gfx942")
.add(Signature().dtype("fp16").layout("rcr"),
Algorithm()
.tile(128, 128, 64)
.wave(2, 2, 1)
.warp(32, 32, 16)
.pipeline("compv3")
.scheduler("intrawave")
.epilogue("cshuffle"),
"gfx942")
// Same shape as previous entry but compv4 pipeline, so the model can
// discriminate between pipeline variants at identical tile sizes.
.add(Signature().dtype("fp16").layout("rcr"),
Algorithm()
.tile(128, 128, 64)
.wave(2, 2, 1)
.warp(32, 32, 16)
.pipeline("compv4")
.scheduler("intrawave")
.epilogue("cshuffle"),
"gfx942")
// Large tiles — favored for large, compute-bound problems
.add(Signature().dtype("fp16").layout("rcr"),
Algorithm()
.tile(256, 256, 32)
.wave(2, 2, 1)
.warp(32, 32, 16)
.pipeline("compv3")
.scheduler("intrawave")
.epilogue("cshuffle"),
"gfx942")
// Asymmetric tiles (M-heavy and N-heavy) for skewed problem shapes
.add(Signature().dtype("fp16").layout("rcr"),
Algorithm()
.tile(256, 128, 32)
.wave(2, 2, 1)
.warp(32, 32, 16)
.pipeline("compv3")
.scheduler("intrawave")
.epilogue("cshuffle"),
"gfx942")
.add(Signature().dtype("fp16").layout("rcr"),
Algorithm()
.tile(128, 256, 32)
.wave(2, 2, 1)
.warp(32, 32, 16)
.pipeline("compv3")
.scheduler("intrawave")
.epilogue("cshuffle"),
"gfx942"));
/**
 * Entry point: loads a trained LightGBM model, wires it into the dispatcher
 * as the kernel-selection heuristic, and runs selection over a few GEMM
 * problem sizes, printing the chosen kernel, predicted TFLOPS, and
 * selection latency for each.
 *
 * @return 0 when a kernel was selected for every problem size, 1 otherwise
 *         (also 1 on missing --model or model-load failure).
 */
int main(int argc, char* argv[])
{
    ExampleArgs args("Example 09: ML-Based Kernel Selection",
                     "Uses trained LightGBM model for kernel selection");
    args.add_option("--arch", "gfx942", "GPU architecture");
    args.add_option("--model", "", "Path to LightGBM model file (.lgbm)");
    args.add_option("--log_transform", "false", "Model uses log1p transform");
    if(!args.parse(argc, argv))
        return 0; // parse() returns false for --help; not an error
    print_header("Example 09: ML-Based Kernel Selection");
    std::string gfx_arch   = args.get("--arch", "gfx942");
    std::string model_path = args.get("--model", "");
    // Model predicts log1p(TFLOPS) when trained with log-transform enabled.
    bool log_transform = (args.get("--log_transform", "false") == "true");
    if(model_path.empty())
    {
        std::cerr << "Error: --model <path> is required" << std::endl;
        std::cerr << "Usage: ./gemm_09_ml_heuristic --model path/to/model_tflops.lgbm" << std::endl;
        return 1;
    }
    // Setup Registry (kernels are JIT compiled from DECL_KERNEL_SET above)
    Registry registry;
    REGISTER_GENERATED_KERNELS(registry, gfx_arch);
    std::cout << "Registry: " << registry.size() << " kernel(s)" << std::endl;
    // Load ML model and create heuristic
    HardwareProfile hw;
    MLHeuristic ml_heuristic(model_path, &registry, hw, log_transform);
    if(!ml_heuristic.is_loaded())
    {
        std::cerr << "Failed to load model. Exiting." << std::endl;
        return 1;
    }
    // Wire ML heuristic into dispatcher: the lambda forwards each Problem
    // to the heuristic, which ranks registry kernels by predicted TFLOPS.
    Dispatcher dispatcher(&registry);
    dispatcher.set_strategy(Dispatcher::SelectionStrategy::Heuristic);
    dispatcher.set_heuristic([&ml_heuristic](const Problem& p) { return ml_heuristic(p); });
    std::cout << "Strategy: ML Heuristic (LightGBM)" << std::endl;
    // Test with different problem sizes (M, N, K)
    std::vector<std::tuple<int, int, int>> sizes = {
        {128, 128, 64},
        {512, 512, 256},
        {1024, 1024, 512},
        {2048, 2048, 1024},
    };
    std::cout << std::endl
              << std::setw(20) << "Shape" << std::setw(30) << "Selected Kernel" << std::setw(15)
              << "Pred TFLOPS" << std::setw(12) << "Select ms" << std::setw(10) << "Status"
              << std::endl;
    // 87 = sum of the column widths above (20+30+15+12+10)
    std::cout << std::string(87, '-') << std::endl;
    bool all_passed = true;
    for(const auto& [M, N, K] : sizes)
    {
        Problem problem;
        problem.M       = M;
        problem.N       = N;
        problem.K       = K;
        problem.k_batch = 1;
        // Time only the selection step — the point of the example is that
        // ML-based selection is fast compared to exhaustive benchmarking.
        auto t0          = std::chrono::high_resolution_clock::now();
        auto kernel      = dispatcher.select_kernel(problem);
        auto t1          = std::chrono::high_resolution_clock::now();
        double select_ms = std::chrono::duration<double, std::milli>(t1 - t0).count();
        std::string size_str =
            std::to_string(M) + "x" + std::to_string(N) + "x" + std::to_string(K);
        if(!kernel)
        {
            std::cout << std::setw(20) << size_str << std::setw(30) << "NONE" << std::setw(15)
                      << "N/A" << std::setw(12) << std::fixed << std::setprecision(2) << select_ms
                      << std::setw(10) << "FAIL" << std::endl;
            all_passed = false;
            continue;
        }
        double pred      = ml_heuristic.predict_tflops(problem, kernel->get_key());
        std::string name = kernel->get_key().encode_identifier();
        // Truncate long kernel identifiers so the table columns stay aligned.
        if(name.length() > 27)
            name = name.substr(0, 27) + "..";
        std::cout << std::setw(20) << size_str << std::setw(30) << name << std::setw(15)
                  << std::fixed << std::setprecision(2) << pred << std::setw(12)
                  << std::setprecision(2) << select_ms << std::setw(10) << "OK" << std::endl;
    }
    std::cout << std::endl
              << (all_passed ? "*** ALL TESTS PASSED ***" : "*** SOME TESTS FAILED ***")
              << std::endl;
    return all_passed ? 0 : 1;
}