composable_kernel/dispatcher/heuristics/collect_additional.sh
Yaswanth Raparti c1127a36f5 [rocm-libraries] ROCm/rocm-libraries#5676 (commit 1d18339)
[CK][CK TILE] Autotuning heuristics infra for universal GEMM kernel selection (#5676)

## Motivation

This PR adds ML-based kernel selection heuristics to the CK Tile
dispatcher, enabling fast and accurate automatic kernel selection for
universal GEMM kernels. Instead of requiring an exhaustive search through
4600+ kernel configurations (about 46 seconds per problem shape), the
ML heuristic predicts the best kernel in microseconds while achieving more
than 98% of oracle-best performance.

## Technical Details

**ML infrastructure**

https://github.com/ROCm/rocm-libraries/tree/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics
* Feature Engine
([feature_engine.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/feature_engine.py)):
55-feature extraction including problem dimensions, kernel
configuration, tile efficiency, and hardware profile
* Training Pipeline
([train.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/train.py)):
LightGBM regression with log-transform, GroupKFold cross-validation,
warm-start support
* Predictor
([predict.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/predict.py)):
Kernel ranking and TFLOPS prediction for problem shapes
* Evaluation
([evaluate.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/evaluate.py)):
Comprehensive metrics including efficiency, NDCG@k, shape family
analysis
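
The efficiency and NDCG@k metrics named above can be illustrated with a minimal sketch. It is not taken from evaluate.py: the function names, the use of measured TFLOPS as the NDCG gain, and the log2 position discount are assumptions chosen for illustration, and the kernel names in the usage lines are hypothetical.

```python
import math


def top1_efficiency(predicted, measured):
    """Measured TFLOPS of the kernel the model ranks first, divided by
    the best measured TFLOPS for the same shape (the oracle)."""
    chosen = max(predicted, key=predicted.get)
    return measured[chosen] / max(measured.values())


def ndcg_at_k(predicted, measured, k):
    """NDCG@k with measured TFLOPS as the gain (an assumption of this sketch)."""
    # Kernels ordered by predicted score, best first, truncated to k.
    ranked = sorted(predicted, key=predicted.get, reverse=True)[:k]
    # The ideal ordering uses the true measurements.
    ideal = sorted(measured.values(), reverse=True)[:k]
    dcg = sum(measured[name] / math.log2(i + 2) for i, name in enumerate(ranked))
    idcg = sum(gain / math.log2(i + 2) for i, gain in enumerate(ideal))
    return dcg / idcg if idcg > 0 else 0.0


# Hypothetical kernels for a single problem shape.
predicted = {"kernel_a": 10.0, "kernel_b": 8.0, "kernel_c": 5.0}
measured = {"kernel_a": 90.0, "kernel_b": 100.0, "kernel_c": 50.0}
print(top1_efficiency(predicted, measured))
print(ndcg_at_k(predicted, measured, k=2))
```

Here the model picks `kernel_a`, which reaches 90% of the oracle's throughput; the NDCG@2 score is below 1 because the top two kernels are ranked in the wrong order.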

**Data Generation Tools:**

* [generate_benchmark_data.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/generate_benchmark_data.py):
Build and benchmark kernels across diverse problem shapes
* [convert_json_to_parquet.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/convert_json_to_parquet.py):
Convert benchmark JSON to training-ready parquet format
* [data_pipeline.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/data_pipeline.py):
Parse streaming benchmark logs into canonical datasets

**Examples**
* [09_ml_heuristic.cpp](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/examples/gemm/cpp/09_ml_heuristic.cpp):
C++ example demonstrating ML-based kernel selection
* [09_ml_heuristic.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/examples/gemm/python/09_ml_heuristic.py):
Python example with validation

**Pre-trained Models
(projects/composablekernel/dispatcher/heuristics/models/):**
* gemm_universal_fp8_gfx950/: fp8 RCR model (42K trees, 97.51% mean
efficiency)
* gemm_universal_fp16_gfx950/: fp16 RCR model (20K trees, 99.36% mean
efficiency)

## Test Plan

* Evaluated on 25 diverse shapes for fp16, 168 shapes for fp8
* All shape families tested: tiny M (M<8), small M, medium M, large M
(M≥1024)
* All pipeline types: compv3, compv4, mem

## Test Result

**fp16 Model (gfx950, RCR layout)**
* Mean Efficiency: 99.36%
* P10 Efficiency: 98.05% (10th percentile; 90% of shapes achieve ≥98% of
oracle best)
* Min Efficiency: 95.45%

**fp8 Model (gfx950, RCR layout)**
* Mean Efficiency: 98.28% (original), 97.51% (wide coverage)
* P10 Efficiency: 94.64% (original), 93.89% (wide coverage)
* Min Efficiency: 84.5%
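
For readers unfamiliar with the summary statistics above, the following sketch shows one way to compute mean, P10, and minimum efficiency from per-shape values. The function names are hypothetical, and the linear-interpolation percentile is an assumption; evaluate.py may use a different percentile convention.

```python
import math


def percentile(values, q):
    """q-th percentile (0-100) using linear interpolation between ranks."""
    ordered = sorted(values)
    pos = (len(ordered) - 1) * q / 100
    lo, hi = math.floor(pos), math.ceil(pos)
    return ordered[lo] + (ordered[hi] - ordered[lo]) * (pos - lo)


def summarize_efficiency(efficiencies):
    """Summarize per-shape efficiency (selected kernel TFLOPS / oracle TFLOPS)."""
    return {
        "mean": sum(efficiencies) / len(efficiencies),
        # P10 is the value that 90% of shapes meet or exceed.
        "p10": percentile(efficiencies, 10),
        "min": min(efficiencies),
    }


# Illustrative values only, not measurements from this PR.
print(summarize_efficiency([0.99, 0.97, 1.00, 0.95, 0.98]))
```

Reporting P10 and the minimum alongside the mean exposes the worst-served shapes, which a mean alone would hide.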

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
2026-04-02 02:26:32 +00:00

68 lines
1.9 KiB
Bash
Executable File

#!/bin/bash
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# Generate additional benchmark data for shapes NOT in the original log.
# Intended to be run in the background; writes a streaming log of JSON blocks
# that can be parsed by data_pipeline.py.
BIN_DIR="/workspace/ck_tile/bin"
OUT_LOG="data/additional_shapes.log"
WARMUP=3
REPEAT=10
mkdir -p data
# Additional shapes: square powers-of-2 and common ML sizes not in original DeepSeek set
SHAPES=(
    "64,64,64"
    "128,128,128"
    "256,256,256"
    "512,512,512"
    "1024,1024,1024"
    "2048,2048,2048"
    "4096,4096,4096"
    "1,4096,4096"
    "8,4096,4096"
    "32,4096,4096"
    "128,4096,4096"
    "1,4096,11008"
    "32,4096,11008"
    "1,8192,8192"
    "32,8192,8192"
    "1,8192,28672"
    "32,8192,28672"
    "256,256,8192"
    "8192,8192,256"
    "1024,4096,1024"
    "4096,1024,4096"
    "2048,8192,2048"
)
echo "CK Tile Additional Shapes Benchmark" > "$OUT_LOG"
echo "GPU ID: 0" >> "$OUT_LOG"
echo "Implementation: gemm_universal" >> "$OUT_LOG"
echo "" >> "$OUT_LOG"
SHAPE_IDX=0
for SHAPE in "${SHAPES[@]}"; do
    # Each entry is "M,N,K"
    IFS=',' read -r M N K <<< "$SHAPE"
    SHAPE_IDX=$((SHAPE_IDX + 1))
    echo "========================================" >> "$OUT_LOG"
    echo "Shape $SHAPE_IDX: M=$M N=$N K=$K dtype=fp8 layout=rcr" >> "$OUT_LOG"
    echo "========================================" >> "$OUT_LOG"
    KERNEL_COUNT=0
    for EXE in "$BIN_DIR"/benchmark_gemm_universal_fp8_rcr_*; do
        # Skip non-executable matches, including the literal pattern that
        # remains when the glob matches nothing.
        [[ -x "$EXE" ]] || continue
        KERNEL_COUNT=$((KERNEL_COUNT + 1))
        OUTPUT=$("$EXE" -m="$M" -n="$N" -k="$K" -warmup="$WARMUP" -repeat="$REPEAT" -verify=0 2>/dev/null)
        # Extract just the JSON block
        echo "$OUTPUT" | sed -n '/{/,/^}/p' >> "$OUT_LOG"
    done
    echo "Found $KERNEL_COUNT kernels" >> "$OUT_LOG"
    echo "Completed shape $SHAPE_IDX: M=$M N=$N K=$K ($KERNEL_COUNT kernels)" >&2
done
echo "Done generating additional data" >&2