mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
[CK][CK TILE] Autotuning heuristics infra for universal GEMM kernel selection (#5676)
## Motivation

This PR adds ML-based kernel selection heuristics to the CK Tile dispatcher, enabling fast and accurate automatic kernel selection for Universal GEMM kernels. Instead of requiring an exhaustive search through 4600+ kernel configurations (taking ~46 seconds per problem shape), the ML heuristic predicts optimal kernels in microseconds while achieving >98% of oracle-best performance.

## Technical Details

**ML infrastructure:** https://github.com/ROCm/rocm-libraries/tree/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics

* Feature Engine ([feature_engine.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/feature_engine.py)): 55-feature extraction covering problem dimensions, kernel configuration, tile efficiency, and hardware profile
* Training Pipeline ([train.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/train.py)): LightGBM regression with log transform, GroupKFold cross-validation, and warm-start support
* Predictor ([predict.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/predict.py)): kernel ranking and TFLOPS prediction for problem shapes
* Evaluation ([evaluate.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/evaluate.py)): comprehensive metrics including efficiency, NDCG@k, and shape-family analysis

**Data Generation Tools:**

* [generate_benchmark_data.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/generate_benchmark_data.py): build and benchmark kernels across diverse problem shapes
* [convert_json_to_parquet.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/convert_json_to_parquet.py): convert benchmark JSON to training-ready parquet format
* [data_pipeline.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/heuristics/data_pipeline.py): parse streaming benchmark logs into canonical datasets

**Examples:**

* [09_ml_heuristic.cpp](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/examples/gemm/cpp/09_ml_heuristic.cpp): C++ example demonstrating ML-based kernel selection
* [09_ml_heuristic.py](https://github.com/ROCm/rocm-libraries/blob/users/vanantha/ck/dispatcher-heuristics/projects/composablekernel/dispatcher/examples/gemm/python/09_ml_heuristic.py): Python example with validation

**Pre-trained Models (projects/composablekernel/dispatcher/heuristics/models/):**

* gemm_universal_fp8_gfx950/: fp8 RCR model (42K trees, 97.51% mean efficiency)
* gemm_universal_fp16_gfx950/: fp16 RCR model (20K trees, 99.36% mean efficiency)

## Test Plan

* Evaluated on 25 diverse shapes for fp16 and 168 shapes for fp8
* All shape families tested: tiny M (M<8), small M, medium M, large M (M≥1024)
* All pipeline types: compv3, compv4, mem

## Test Result

**fp16 Model (gfx950, RCR layout)**

* Mean efficiency: 99.36%
* P10 efficiency: 98.05% (90% of shapes achieve ≥98% of oracle-best)
* Min efficiency: 95.45%

**fp8 Model (gfx950, RCR layout)**

* Mean efficiency: 98.28% (original), 97.51% (wide coverage)
* P10 efficiency: 94.64% (original), 93.89% (wide coverage)
* Min efficiency: 84.5%

## Submission Checklist

- [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
--------- Co-authored-by: Vidyasagar Ananthan <vidyasagar.ananthan@amd.com> Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
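The selection step described above — score every candidate kernel configuration with the trained regressor and take the argmax — can be sketched as follows. This is a hedged illustration only: `predict_tflops` is a stand-in for the trained LightGBM model in heuristics/predict.py, and the kernel pool, `base_tflops` values, and waste penalty are invented for the example.

```python
# Sketch of ML-based kernel ranking: predict a score for every candidate
# config and pick the best. The real predictor consumes ~55 engineered
# features; here a toy penalty on padded (wasted) tile area stands in.

def predict_tflops(problem, kernel):
    # Stand-in for the LightGBM regressor (hypothetical scoring rule):
    # penalize tiles that do not divide the problem shape evenly.
    m, n, _ = problem
    tm, tn, _ = kernel["tile"]
    waste = (-m % tm) * n + (-n % tn) * m  # padded elements along M and N
    return kernel["base_tflops"] - 1e-2 * waste

def select_kernel(problem, pool):
    """Rank all candidates by predicted TFLOPS; return the best one."""
    return max(pool, key=lambda k: predict_tflops(problem, k))

# Illustrative candidate pool (names and numbers are made up)
pool = [
    {"name": "64x64x32",   "tile": (64, 64, 32),   "base_tflops": 400.0},
    {"name": "128x128x64", "tile": (128, 128, 64), "base_tflops": 700.0},
    {"name": "256x256x32", "tile": (256, 256, 32), "base_tflops": 900.0},
]

print(select_kernel((2048, 2048, 1024), pool)["name"])  # large shape -> 256x256x32
print(select_kernel((96, 96, 64), pool)["name"])        # small shape -> 128x128x64
```

The large shape divides every tile evenly, so raw throughput wins; the small shape wastes most of a 256x256 tile, so a medium tile is preferred — the same trade-off the trained model learns from benchmark data.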
This commit is contained in:
25
.gitignore
vendored
@@ -114,3 +114,28 @@ experimental/grouped_convolution_tile_instances/instances/*
!experimental/grouped_convolution_tile_instances/instances/*.inc
!experimental/grouped_convolution_tile_instances/instances/*.hpp
experimental/grouped_convolution_tile_instances/*.inc
# Heuristics: benchmark data (never in git)
dispatcher/heuristics/data/

# Heuristics: experimental/training artifacts (exclude from git)
dispatcher/heuristics/models/**/oof_predictions.parquet
dispatcher/heuristics/models/**/cv_metrics_*.json
dispatcher/heuristics/models/**/eval_report.json
dispatcher/heuristics/models/**/feature_importances_*.json
dispatcher/heuristics/models/**/model_tflops_ihem.lgbm
dispatcher/heuristics/models/**/model_tflops_log.lgbm
dispatcher/heuristics/models/**/model_tflops_log_big.lgbm

# Heuristics: keep in git (production model files):
# models/{op}_{dtype}_{arch}/model_tflops.lgbm
# models/{op}_{dtype}_{arch}/model_latency.lgbm
# models/{op}_{dtype}_{arch}/model_bandwidth.lgbm
# models/{op}_{dtype}_{arch}/feature_spec.json
# models/{op}_{dtype}_{arch}/train_manifest.json

# Heuristics: logs and caches
dispatcher/heuristics/*.log
dispatcher/heuristics/__pycache__/
dispatcher/heuristics/tests/__pycache__/
dispatcher/heuristics/.pytest_cache/
@@ -154,6 +154,8 @@ rocminfo | grep -i "gfx"
### Install Python Dependencies

#### Core Dependencies (Required)

NumPy is required for Python examples and kernel generation. We recommend using a virtual environment:

**Option 1: Using standard venv**

@@ -165,8 +167,8 @@ python3 -m venv .venv
source .venv/bin/activate  # Linux/macOS
# .venv\Scripts\activate   # Windows

# Install NumPy
pip install numpy
# Install core dependencies
pip install -r python/requirements.txt
```

**Option 2: Using uv (faster alternative)**

@@ -179,17 +181,38 @@ uv venv .venv
source .venv/bin/activate  # Linux/macOS
# .venv\Scripts\activate   # Windows

# Install NumPy
uv pip install numpy
# Install core dependencies
uv pip install -r python/requirements.txt
```

**Option 3: System-wide install (not recommended)**

```bash
pip install numpy
pip install -r python/requirements.txt
```

> **Note:** Always activate your virtual environment before running CMake or Python examples.

#### ML Heuristics Dependencies (Optional)

For ML-based kernel selection (examples 09-11), install additional dependencies:

```bash
# Activate your virtual environment first
source .venv/bin/activate

# Install ML dependencies (LightGBM, pandas, pyarrow, scikit-learn)
pip install -r requirements-ml.txt
```

**Why separate?** ML dependencies are large (especially pyarrow) and not needed for basic dispatcher usage. Install them only if you need:
- ML-based kernel selection (`examples/gemm/python/09_ml_heuristic.py`)
- Model training (`heuristics/train.py`)
- Model evaluation (`heuristics/evaluate.py`)
- Automated benchmark analysis

**Core dependencies:** ~50 MB (NumPy only)
**With ML dependencies:** ~500 MB (includes LightGBM, pandas, pyarrow, scikit-learn)

### Supported Data Types

CK Tile supports a wide range of data types for GEMM operations:

@@ -470,6 +493,42 @@ python3 examples/gemm/python/10_advanced_benchmark.py \

---

## ML-Based Kernel Selection (Optional)

The dispatcher includes ML heuristics for automated kernel selection using trained LightGBM models.

**Prerequisites:** Install the ML dependencies first:

```bash
pip install -r requirements-ml.txt  # ~500 MB (LightGBM, pandas, pyarrow, scikit-learn)
```

**Documentation:** See [heuristics/README.md](heuristics/README.md) for:
- Training and evaluating models
- Feature engineering (72 features)
- Using pre-trained models
- Python API reference

**Examples:**
```bash
python3 examples/gemm/python/09_ml_heuristic.py  # ML-based kernel selection
python3 examples/gemm/python/10_rank_kernels.py  # Kernel ranking
```

**Model Compression:** Trained models are stored in compressed `.lgbm.gz` format to save space (~67% size reduction). Python tools automatically decompress models on first use. For C++ examples, decompress manually:

```bash
# If you have compressed models
cd heuristics/models/gemm_universal_fp16_gfx950
gunzip model_tflops.lgbm.gz

# Then use in C++ example
cd ../../../build
./gemm_09_ml_heuristic --model ../heuristics/models/gemm_universal_fp16_gfx950/model_tflops.lgbm
```

---

## External Integration

### Using Dispatcher in Your Own Project

@@ -346,6 +346,55 @@ add_declarative_gpu_example(gemm_04_heuristics gemm/cpp/04_heuristics.
add_declarative_gpu_example(gemm_05_json_export gemm/cpp/05_json_export.cpp)
add_declarative_gpu_example(gemm_06_multi_registry gemm/cpp/06_multi_registry.cpp)

# ML Heuristic example -- requires LightGBM shared library
# Derive site-packages from active Python interpreter (respects virtualenvs)
find_package(Python3 COMPONENTS Interpreter)

set(LIGHTGBM_SEARCH_PATHS)
if(Python3_FOUND AND Python3_EXECUTABLE)
  execute_process(
    COMMAND ${Python3_EXECUTABLE} -c "import sysconfig; print(sysconfig.get_path('purelib'))"
    OUTPUT_VARIABLE PYTHON_SITE_PACKAGES
    OUTPUT_STRIP_TRAILING_WHITESPACE
    ERROR_QUIET
  )
  if(PYTHON_SITE_PACKAGES)
    list(APPEND LIGHTGBM_SEARCH_PATHS "${PYTHON_SITE_PACKAGES}/lightgbm/lib")
  endif()
endif()

# Fallback to common Python 3.x site-packages if auto-detection failed
if(NOT PYTHON_SITE_PACKAGES)
  list(APPEND LIGHTGBM_SEARCH_PATHS
    "$ENV{HOME}/.local/lib/python3.12/site-packages/lightgbm/lib"
  )
endif()

find_library(LIGHTGBM_LIB NAMES LightGBM lib_lightgbm _lightgbm
  HINTS ${CMAKE_PREFIX_PATH}
  PATHS ${LIGHTGBM_SEARCH_PATHS}
  NO_DEFAULT_PATH
  DOC "LightGBM shared library for ML heuristics"
)

# Fallback: search default paths (respects LightGBM_DIR if set by user)
if(NOT LIGHTGBM_LIB)
  find_library(LIGHTGBM_LIB NAMES LightGBM lib_lightgbm)
endif()

if(LIGHTGBM_LIB)
  add_declarative_gpu_example(gemm_09_ml_heuristic gemm/cpp/09_ml_heuristic.cpp)
  target_link_libraries(gemm_09_ml_heuristic PRIVATE ${LIGHTGBM_LIB})
  message(STATUS "LightGBM found: ${LIGHTGBM_LIB} -- building gemm_09_ml_heuristic")
else()
  message(STATUS "LightGBM not found -- skipping gemm_09_ml_heuristic")
  message(STATUS "  To enable ML heuristic example:")
  message(STATUS "    1. Activate virtualenv: source .venv/bin/activate")
  message(STATUS "    2. Install: pip install -r ../requirements-ml.txt")
  message(STATUS "    3. Reconfigure: cmake ..")
  message(STATUS "  Or set CMAKE_PREFIX_PATH or LightGBM_DIR to LightGBM location")
endif()

# =============================================================================
# GEMM Python Library - Single Fallback Kernel
# =============================================================================

211
dispatcher/examples/gemm/cpp/09_ml_heuristic.cpp
Normal file
@@ -0,0 +1,211 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

/**
 * Example 09: ML-Based Kernel Selection (Native C++)
 *
 * Uses a trained LightGBM model loaded via the C API to predict TFLOPS
 * for each kernel in the registry and select the best one. The kernels
 * are JIT-compiled at build time via DECL_KERNEL_SET (same as other examples).
 *
 * Build: cd dispatcher/build && cmake .. && make gemm_09_ml_heuristic
 * Run:   ./gemm_09_ml_heuristic --model <path_to_model.lgbm>
 */

#include <hip/hip_runtime.h>
#include <iostream>
#include <iomanip>
#include <vector>
#include <chrono>

#include "ck_tile/dispatcher.hpp"
#include "ck_tile/dispatcher/kernel_decl.hpp"
#include "ck_tile/dispatcher/example_args.hpp"
#include "ck_tile/dispatcher/ml_heuristic.hpp"

using namespace ck_tile::dispatcher;
using namespace ck_tile::dispatcher::utils;
using Signature = decl::Signature;
using Algorithm = decl::Algorithm;

// Multiple kernel configs for ML to choose from
DECL_KERNEL_SET(ml_kernels,
    // Small tiles
    .add(Signature().dtype("fp16").layout("rcr"),
         Algorithm()
             .tile(64, 64, 32)
             .wave(2, 2, 1)
             .warp(16, 16, 16)
             .pipeline("compv3")
             .scheduler("intrawave")
             .epilogue("cshuffle"),
         "gfx942")
    .add(Signature().dtype("fp16").layout("rcr"),
         Algorithm()
             .tile(64, 64, 64)
             .wave(2, 2, 1)
             .warp(16, 16, 16)
             .pipeline("compv3")
             .scheduler("intrawave")
             .epilogue("cshuffle"),
         "gfx942")
    // Medium tiles
    .add(Signature().dtype("fp16").layout("rcr"),
         Algorithm()
             .tile(128, 128, 32)
             .wave(2, 2, 1)
             .warp(32, 32, 16)
             .pipeline("compv3")
             .scheduler("intrawave")
             .epilogue("cshuffle"),
         "gfx942")
    .add(Signature().dtype("fp16").layout("rcr"),
         Algorithm()
             .tile(128, 128, 64)
             .wave(2, 2, 1)
             .warp(32, 32, 16)
             .pipeline("compv3")
             .scheduler("intrawave")
             .epilogue("cshuffle"),
         "gfx942")
    .add(Signature().dtype("fp16").layout("rcr"),
         Algorithm()
             .tile(128, 128, 64)
             .wave(2, 2, 1)
             .warp(32, 32, 16)
             .pipeline("compv4")
             .scheduler("intrawave")
             .epilogue("cshuffle"),
         "gfx942")
    // Large tiles
    .add(Signature().dtype("fp16").layout("rcr"),
         Algorithm()
             .tile(256, 256, 32)
             .wave(2, 2, 1)
             .warp(32, 32, 16)
             .pipeline("compv3")
             .scheduler("intrawave")
             .epilogue("cshuffle"),
         "gfx942")
    .add(Signature().dtype("fp16").layout("rcr"),
         Algorithm()
             .tile(256, 128, 32)
             .wave(2, 2, 1)
             .warp(32, 32, 16)
             .pipeline("compv3")
             .scheduler("intrawave")
             .epilogue("cshuffle"),
         "gfx942")
    .add(Signature().dtype("fp16").layout("rcr"),
         Algorithm()
             .tile(128, 256, 32)
             .wave(2, 2, 1)
             .warp(32, 32, 16)
             .pipeline("compv3")
             .scheduler("intrawave")
             .epilogue("cshuffle"),
         "gfx942"));

int main(int argc, char* argv[])
{
    ExampleArgs args("Example 09: ML-Based Kernel Selection",
                     "Uses trained LightGBM model for kernel selection");
    args.add_option("--arch", "gfx942", "GPU architecture");
    args.add_option("--model", "", "Path to LightGBM model file (.lgbm)");
    args.add_option("--log_transform", "false", "Model uses log1p transform");

    if(!args.parse(argc, argv))
        return 0;

    print_header("Example 09: ML-Based Kernel Selection");

    std::string gfx_arch   = args.get("--arch", "gfx942");
    std::string model_path = args.get("--model", "");
    bool log_transform     = (args.get("--log_transform", "false") == "true");

    if(model_path.empty())
    {
        std::cerr << "Error: --model <path> is required" << std::endl;
        std::cerr << "Usage: ./gemm_09_ml_heuristic --model path/to/model_tflops.lgbm" << std::endl;
        return 1;
    }

    // Setup registry (kernels are JIT-compiled from DECL_KERNEL_SET above)
    Registry registry;
    REGISTER_GENERATED_KERNELS(registry, gfx_arch);
    std::cout << "Registry: " << registry.size() << " kernel(s)" << std::endl;

    // Load ML model and create heuristic
    HardwareProfile hw;
    MLHeuristic ml_heuristic(model_path, &registry, hw, log_transform);
    if(!ml_heuristic.is_loaded())
    {
        std::cerr << "Failed to load model. Exiting." << std::endl;
        return 1;
    }

    // Wire ML heuristic into dispatcher
    Dispatcher dispatcher(&registry);
    dispatcher.set_strategy(Dispatcher::SelectionStrategy::Heuristic);
    dispatcher.set_heuristic([&ml_heuristic](const Problem& p) { return ml_heuristic(p); });

    std::cout << "Strategy: ML Heuristic (LightGBM)" << std::endl;

    // Test with different problem sizes
    using DataType = ck_tile::fp16_t;
    std::vector<std::tuple<int, int, int>> sizes = {
        {128, 128, 64},
        {512, 512, 256},
        {1024, 1024, 512},
        {2048, 2048, 1024},
    };

    std::cout << std::endl
              << std::setw(20) << "Shape" << std::setw(30) << "Selected Kernel" << std::setw(15)
              << "Pred TFLOPS" << std::setw(12) << "Select ms" << std::setw(10) << "Status"
              << std::endl;
    std::cout << std::string(87, '-') << std::endl;

    bool all_passed = true;

    for(const auto& [M, N, K] : sizes)
    {
        Problem problem;
        problem.M       = M;
        problem.N       = N;
        problem.K       = K;
        problem.k_batch = 1;

        auto t0          = std::chrono::high_resolution_clock::now();
        auto kernel      = dispatcher.select_kernel(problem);
        auto t1          = std::chrono::high_resolution_clock::now();
        double select_ms = std::chrono::duration<double, std::milli>(t1 - t0).count();

        std::string size_str =
            std::to_string(M) + "x" + std::to_string(N) + "x" + std::to_string(K);

        if(!kernel)
        {
            std::cout << std::setw(20) << size_str << std::setw(30) << "NONE" << std::setw(15)
                      << "N/A" << std::setw(12) << std::fixed << std::setprecision(2) << select_ms
                      << std::setw(10) << "FAIL" << std::endl;
            all_passed = false;
            continue;
        }

        double pred      = ml_heuristic.predict_tflops(problem, kernel->get_key());
        std::string name = kernel->get_key().encode_identifier();
        if(name.length() > 27)
            name = name.substr(0, 27) + "..";

        std::cout << std::setw(20) << size_str << std::setw(30) << name << std::setw(15)
                  << std::fixed << std::setprecision(2) << pred << std::setw(12)
                  << std::setprecision(2) << select_ms << std::setw(10) << "OK" << std::endl;
    }

    std::cout << std::endl
              << (all_passed ? "*** ALL TESTS PASSED ***" : "*** SOME TESTS FAILED ***")
              << std::endl;

    return all_passed ? 0 : 1;
}
305
dispatcher/examples/gemm/python/09_ml_heuristic.py
Normal file
@@ -0,0 +1,305 @@
#!/usr/bin/env python3

# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT

"""
Example 09: ML-Based Kernel Selection

Uses a trained LightGBM model to select the optimal kernel for each problem
size. The model predicts TFLOPS for every candidate in the kernel pool and
picks the highest-scoring one, which is then JIT-compiled and run.

This replaces the hand-crafted rules in 08_heuristics.py with a data-driven
approach achieving 97-98% of oracle-best TFLOPS efficiency.

Complexity: *****

Prerequisites:
    - Trained model in dispatcher/heuristics/models/gemm_universal_fp8_gfx950/
    - lightgbm, pandas, numpy, pyarrow installed

Usage:
    python3 09_ml_heuristic.py
    python3 09_ml_heuristic.py --dtype fp16 --arch gfx942
"""

import sys
import argparse
import time
from pathlib import Path
from dataclasses import dataclass
from typing import List

sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "heuristics"))

import numpy as np

from ctypes_utils import (
    KernelConfig,
    setup_gemm_dispatcher,
    cleanup_gemm,
)

from predict import Predictor


@dataclass
class KernelSpec:
    """Kernel specification -- same structure as 08_heuristics.py"""

    name: str
    tile_m: int
    tile_n: int
    tile_k: int
    pipeline: str = "compv3"
    scheduler: str = "intrawave"
    wave_m: int = 2
    wave_n: int = 2
    wave_k: int = 1
    warp_m: int = 32
    warp_n: int = 32
    warp_k: int = 16


# Kernel pool: representative configs spanning small to large tiles,
# compv3/compv4/mem pipelines, and intrawave/interwave schedulers.
KERNEL_POOL = [
    # Small tiles
    KernelSpec("s_64x64_k32_v3", 64, 64, 32, "compv3", warp_m=16, warp_n=16),
    KernelSpec("s_64x64_k64_v3", 64, 64, 64, "compv3", warp_m=16, warp_n=16),
    KernelSpec("s_64x64_k128_v3", 64, 64, 128, "compv3", warp_m=16, warp_n=16),
    KernelSpec("s_64x64_k32_v4", 64, 64, 32, "compv4", warp_m=16, warp_n=16),
    KernelSpec("s_64x64_k64_mem", 64, 64, 64, "mem", warp_m=16, warp_n=16),
    KernelSpec("s_64x64_k128_mem", 64, 64, 128, "mem", warp_m=16, warp_n=16),
    # Medium tiles
    KernelSpec("m_128x128_k32_v3", 128, 128, 32, "compv3"),
    KernelSpec("m_128x128_k64_v3", 128, 128, 64, "compv3"),
    KernelSpec("m_128x128_k128_v3", 128, 128, 128, "compv3"),
    KernelSpec("m_128x128_k32_v4", 128, 128, 32, "compv4"),
    KernelSpec("m_128x128_k64_v4", 128, 128, 64, "compv4"),
    KernelSpec("m_128x128_k64_mem", 128, 128, 64, "mem"),
    KernelSpec("m_128x128_k128_mem", 128, 128, 128, "mem"),
    # Rectangular medium
    KernelSpec("r_64x128_k32", 64, 128, 32, "compv3", warp_m=16),
    KernelSpec("r_128x64_k32", 128, 64, 32, "compv3", warp_n=16),
    KernelSpec("r_64x128_k64", 64, 128, 64, "compv3", warp_m=16),
    KernelSpec("r_128x64_k64", 128, 64, 64, "compv3", warp_n=16),
    # Large tiles
    KernelSpec("l_256x128_k32", 256, 128, 32, "compv3"),
    KernelSpec("l_128x256_k32", 128, 256, 32, "compv3"),
    KernelSpec("l_256x256_k32", 256, 256, 32, "compv3"),
    KernelSpec("l_256x256_k64", 256, 256, 64, "compv3"),
    # Interwave variants
    KernelSpec("m_128x128_k64_iw", 128, 128, 64, "compv3", "interwave"),
    KernelSpec("m_128x128_k64_mem_iw", 128, 128, 64, "mem", "interwave"),
]


def spec_to_feature_dict(spec: KernelSpec, dtype: str, layout: str) -> dict:
    """Convert a KernelSpec to the dict format the feature engine expects.

    Note: pad_m/n/k default to True to match KernelConfig defaults and actual
    compiled kernels. This ensures the ML model receives the correct padding
    flags that will be used during JIT compilation.
    """
    return {
        "kernel_name": spec.name,
        "tile_m": spec.tile_m,
        "tile_n": spec.tile_n,
        "tile_k": spec.tile_k,
        "warp_m": spec.wave_m,
        "warp_n": spec.wave_n,
        "warp_k": spec.wave_k,
        "warp_tile_m": spec.warp_m,
        "warp_tile_n": spec.warp_n,
        "warp_tile_k": spec.warp_k,
        "pipeline": spec.pipeline,
        "scheduler": spec.scheduler,
        "epilogue": "cshuffle",
        "pad_m": True,  # Match KernelConfig default
        "pad_n": True,  # Match KernelConfig default
        "pad_k": True,  # Match KernelConfig default
        "persistent": False,
        "dtype": dtype,
        "layout": layout,
    }


def spec_to_kernel_config(spec: KernelSpec, dtype: str, arch: str) -> KernelConfig:
    """Convert a KernelSpec to the dispatcher's KernelConfig for JIT compilation."""
    return KernelConfig(
        dtype_a=dtype,
        dtype_b=dtype,
        dtype_c=dtype,
        dtype_acc="fp32",
        layout_a="row",
        layout_b="col",
        layout_c="row",
        tile_m=spec.tile_m,
        tile_n=spec.tile_n,
        tile_k=spec.tile_k,
        wave_m=spec.wave_m,
        wave_n=spec.wave_n,
        wave_k=spec.wave_k,
        warp_m=spec.warp_m,
        warp_n=spec.warp_n,
        warp_k=spec.warp_k,
        pipeline=spec.pipeline,
        scheduler=spec.scheduler,
        epilogue="cshuffle",
        gfx_arch=arch,
    )


def ml_select_kernel(
    predictor: Predictor,
    pool: List[KernelSpec],
    M: int,
    N: int,
    K: int,
    dtype: str,
    layout: str,
) -> tuple:
    """Score all kernels in the pool and return (best_spec, predicted_tflops)."""
    problem = {"m": M, "n": N, "k": K, "dtype": dtype, "layout": layout, "split_k": 1}
    kernel_dicts = [spec_to_feature_dict(s, dtype, layout) for s in pool]

    ranked = predictor.rank_kernels(problem, kernel_dicts)
    if not ranked:
        return pool[0], 0.0

    best_name, best_tflops = ranked[0]
    best_spec = next((s for s in pool if s.name == best_name), pool[0])
    return best_spec, best_tflops


def main():
    parser = argparse.ArgumentParser(description="ML-based kernel selection for GEMM")
    parser.add_argument("--dtype", default="fp16", choices=["fp16", "bf16", "fp8"])
    parser.add_argument("--arch", default="gfx942")
    parser.add_argument(
        "--model_dir",
        default=str(
            Path(__file__).parent.parent.parent.parent
            / "heuristics"
            / "models"
            / "gemm_universal_fp8_gfx950"
        ),
    )
    parser.add_argument(
        "--no_run", action="store_true", help="Only predict, don't run GEMMs"
    )
    args = parser.parse_args()

    print("=" * 75)
    print(" Example 09: ML-Based Kernel Selection")
    print("=" * 75)
    print(f"\n Model: {args.model_dir}")
    print(f" Dtype: {args.dtype}")
    print(f" Arch:  {args.arch}")
    print(f" Pool:  {len(KERNEL_POOL)} kernels")

    predictor = Predictor(args.model_dir)
    print(" Model loaded successfully")

    # Host buffers are staged as fp16 for all supported dtypes here
    np_dtype = np.float16

    test_sizes = [
        (128, 128, 64),
        (256, 256, 128),
        (512, 512, 256),
        (1024, 1024, 512),
        (2048, 2048, 1024),
    ]

    header = f"{'Shape':<20} {'Selected Kernel':<25} {'Pred TFLOPS':>12}"
    if not args.no_run:
        header += f" {'Time (ms)':>10} {'TFLOPS':>10} {'Status':<8}"
    print(f"\n {header}")
    print(" " + "-" * len(header))

    results = []

    for M, N, K in test_sizes:
        t0 = time.time()
        best_spec, pred_tflops = ml_select_kernel(
            predictor, KERNEL_POOL, M, N, K, args.dtype, "rcr"
        )
        _ = (time.time() - t0) * 1000  # ML selection time (unused)

        size_str = f"{M}x{N}x{K}"
        line = f" {size_str:<20} {best_spec.name:<25} {pred_tflops:>12.2f}"

        if args.no_run:
            print(line)
            results.append((size_str, best_spec.name, True, 0, pred_tflops))
            continue

        config = spec_to_kernel_config(best_spec, args.dtype, args.arch)

        setup = setup_gemm_dispatcher(
            config=config,
            registry_name=f"ml_{best_spec.name}",
            verbose=False,
            auto_rebuild=True,
        )

        if not setup.success:
            line += f" {'N/A':>10} {'N/A':>10} {'BUILD':>8}"
            print(line)
            results.append((size_str, best_spec.name, False, 0, 0))
            cleanup_gemm()
            continue

        dispatcher = setup.dispatcher
        if not dispatcher.is_supported(M, N, K):
            line += f" {'N/A':>10} {'N/A':>10} {'UNSUP':>8}"
            print(line)
            results.append((size_str, best_spec.name, False, 0, 0))
            cleanup_gemm()
            continue

        np.random.seed(42)
        A = (np.random.randn(M, K) * 0.1).astype(np_dtype)
        B = (np.random.randn(K, N) * 0.1).astype(np_dtype)

        result = dispatcher.run(A, B, M, N, K)

        if result.success:
            C_ref = np.matmul(A.astype(np.float32), B.astype(np.float32)).astype(
                np_dtype
            )
            max_err = np.max(np.abs(result.output - C_ref))
            passed = max_err < 1e-2
            status = "PASS" if passed else "FAIL"
            line += f" {result.time_ms:>10.4f} {result.tflops:>10.2f} {status:<8}"
            results.append(
                (size_str, best_spec.name, passed, result.time_ms, result.tflops)
            )
        else:
            line += f" {'N/A':>10} {'N/A':>10} {'FAIL':<8}"
            results.append((size_str, best_spec.name, False, 0, 0))

        print(line)
        cleanup_gemm()

    # Summary
    print("\n" + "=" * 75)
    print(" SUMMARY")
    print("=" * 75)
    passed = sum(1 for r in results if r[2])
    print(f"\n Results: {passed}/{len(results)} tests passed")
    valid = [r for r in results if r[2] and r[4] > 0]
    if valid:
        avg = sum(r[4] for r in valid) / len(valid)
        print(f" Average TFLOPS: {avg:.2f}")
    if passed == len(results):
        print("\n *** ALL TESTS PASSED ***")
    print("=" * 75)
    return 0 if passed == len(results) else 1


if __name__ == "__main__":
    sys.exit(main())
60
dispatcher/heuristics/.gitignore
vendored
Normal file
@@ -0,0 +1,60 @@
# Python bytecode and caches
__pycache__/
*.pyc
*.pyo
*.pyd
.Python

# Jupyter notebooks
*.ipynb
.ipynb_checkpoints/

# Virtual environments
.venv/
venv/
ENV/

# IDE and editor files
.vscode/
.idea/
*.swp
*.swo
*~

# Test output and logs
*.log
test_output.log
custom_shapes_gpu_test.log

# Benchmark and analysis output files
*.csv
*.json
!models/*/feature_spec.json
!models/*/train_manifest.json

# Data files (parquet, arrow)
*.parquet
*.arrow

# Temporary and NFS files
.nfs*
*.tmp
*.bak

# Decompressed model files (compressed .lgbm.gz versions are tracked)
models/**/*.lgbm

# User-specific test and analysis scripts
test_*.py
!tests/test_*.py
find_*.py
oracle_*.json
validation_results_*.csv
custom_shapes_*.csv
fp16_bf16_*.csv

# Ignore all markdown files except tracked documentation
*.md
!DATA_GENERATION.md
!LEARNINGS.md
!README.md
412
dispatcher/heuristics/DATA_GENERATION.md
Normal file
@@ -0,0 +1,412 @@
# Data Generation Guide

This document explains how to build benchmark binaries from the CK Tile engine,
generate benchmark datasets, and manage them for the ML kernel performance
prediction system.

## Overview

The ML heuristic needs benchmark data: measured TFLOPS, latency, and bandwidth
for every (problem shape, kernel config) pair. The tile engine builds one
executable per kernel configuration. Each executable benchmarks a single kernel
on a given problem size and outputs JSON with performance metrics.

```
CK source --> CMake configure --> ninja build --> benchmark binaries
                                                  (4608 per op/dtype/layout)

benchmark binaries --> run on GPU --> streaming log --> parquet dataset
     (per shape)                      (JSON blocks)     (canonical schema)
```
|
||||
|
||||
## Prerequisites
|
||||
|
||||
- **ROCm**: HIP >= 6.0.3 (for gfx950: HIP >= 6.0.4)
|
||||
- **Build tools**: CMake >= 3.21, Ninja, HIP-aware clang compiler
|
||||
- **Python**: 3.10+ with `pandas`, `pyarrow`
|
||||
- **GPU**: ROCm-capable AMD GPU (MI250X, MI300X, MI355X, etc.)
|
||||
|
||||
---
|
||||
|
||||
## Part 1: Building Benchmark Binaries from the Tile Engine

If you already have pre-built binaries (e.g., in `/workspace/ck_tile/bin/`),
skip to Part 2. This section explains how to build them from source.

### Step 1: CMake Configure

From the CK repository root:

```bash
cmake -S /workspace/rocm-libraries/projects/composablekernel \
  -B build \
  -DCMAKE_BUILD_TYPE=Release \
  -DGPU_TARGETS="gfx950" \
  -DGEMM_UNIVERSAL_DATATYPE="fp8" \
  -DGEMM_UNIVERSAL_LAYOUT="rcr" \
  -G Ninja
```

**Key CMake variables:**

| Variable | Default | Description |
|---|---|---|
| `GPU_TARGETS` | (required) | Target GPU architectures. Supported: `gfx90a`, `gfx942`, `gfx950`, `gfx1201`. Semicolon-separated for multiple. |
| `GEMM_UNIVERSAL_DATATYPE` | `"fp8;fp16"` | Data types to build. Options: `fp8`, `fp16`, `bf16`, `bf8`. Semicolon-separated. |
| `GEMM_UNIVERSAL_LAYOUT` | `"rcr;rrr;crr;ccr"` | Layouts to build. Semicolon-separated. |
| `GEMM_UNIVERSAL_CONFIG_FILE` | `"default_config.json"` | Kernel config file (in the `configs/` directory). Controls which tile sizes, warp configs, pipelines, etc. are enumerated. |
| `ENABLE_CCACHE_GEMM_UNIVERSAL` | `OFF` | Enable ccache for faster rebuilds. |

**Example: build only fp8 RCR for gfx950 (fastest, ~4608 kernels):**
```bash
cmake -S . -B build -DCMAKE_BUILD_TYPE=Release \
  -DGPU_TARGETS="gfx950" \
  -DGEMM_UNIVERSAL_DATATYPE="fp8" \
  -DGEMM_UNIVERSAL_LAYOUT="rcr" \
  -G Ninja
```

**Example: build all dtypes and layouts (slow, ~4608 * 4 * 4 = ~73K kernels):**
```bash
cmake -S . -B build -DCMAKE_BUILD_TYPE=Release \
  -DGPU_TARGETS="gfx950" \
  -DGEMM_UNIVERSAL_DATATYPE="fp8;fp16;bf16;bf8" \
  -DGEMM_UNIVERSAL_LAYOUT="rcr;rrr;crr;ccr" \
  -G Ninja
```

### What happens during configure

1. CMake calls `gemm_universal_instance_builder.py --list_kernels` to enumerate
   all valid kernel configurations from the config JSON.
2. It writes `gemm_universal_kernel_list.txt` (one kernel per line) and
   `gemm_universal_kernel_count.txt` to the build directory.
3. For each kernel, it creates a ninja build target.
### Step 2: Build

```bash
# Build all benchmarks for the configured dtypes/layouts
ninja -C build benchmark_gemm_universal_all

# Or build a specific dtype/layout combo
ninja -C build benchmark_gemm_universal_fp8_rcr

# Or build by pipeline type
ninja -C build benchmark_gemm_universal_compv4_pipeline
ninja -C build benchmark_gemm_universal_mem_pipeline

# Or build a single specific kernel
ninja -C build benchmark_gemm_universal_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_False_128x128x128_1x4x1_16x16x128
```

**Build time estimates:**
- ~4608 kernels (one dtype, one layout): 1-4 hours depending on CPU cores
- Use `-j <N>` to control parallelism: `ninja -C build -j 32 benchmark_gemm_universal_fp8_rcr`

### Step 3: Verify binaries

Binaries are placed in `build/bin/`:

```bash
ls build/bin/benchmark_gemm_universal_fp8_rcr_* | wc -l
# Expected: 4608 (for default config)

# Test one binary
./build/bin/benchmark_gemm_universal_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_False_128x128x128_1x4x1_16x16x128 \
  -m=1024 -n=1024 -k=1024 -warmup=3 -repeat=10 -verify=0
```
### Kernel config files

The config files live in:
```
tile_engine/ops/gemm/gemm_universal/configs/
    default_config.json        # Default: full enumeration
    default_ci_config.json     # CI: reduced set for fast testing
    user_provided_config.json  # Custom: your own subset
```

To use a custom config:
```bash
cmake ... -DGEMM_UNIVERSAL_CONFIG_FILE="user_provided_config.json"
```

The config controls which tile sizes (e.g., 128x128x64, 256x256x32), warp
configurations (e.g., 2x2x1, 1x4x1), pipelines (compv3, compv4, mem),
schedulers, and other parameters are included in the kernel enumeration.

### Building StreamK / other ops

The same pattern applies to other tile engine ops:

```bash
# StreamK
ninja -C build benchmark_gemm_streamk_fp8_rcr

# Grouped convolution
ninja -C build benchmark_grouped_conv_fwd_fp16_nhwgc
```

Each op has its own instance builder and config directory.

---
## Part 2: Running Benchmarks and Generating Data

## Quick Start

### 1. Run benchmarks for a set of shapes

Each binary accepts `-m=`, `-n=`, `-k=`, `-warmup=`, `-repeat=`, `-verify=` flags
and outputs JSON to stdout:

```bash
/workspace/ck_tile/bin/benchmark_gemm_universal_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_False_128x128x128_1x4x1_16x16x128 \
  -m=1024 -n=1024 -k=1024 -warmup=3 -repeat=10 -verify=0
```

Output:
```json
{
  "name": "gemm_universal_fp8_rcr_compv3_cshuffle_intrawave_...",
  "problem": {
    "split_k": 1, "m": 1024, "n": 1024, "k": 1024,
    "dtype_a": "fp8", "dtype_b": "fp8", ...
  },
  "perf_result": {
    "latency(ms)": 0.04,
    "tflops(TFlops)": 204.60,
    "bandwidth(GB/s)": 624.39
  }
}
```

### 2. Batch generation using provided scripts

**Wide coverage (diverse shapes across all regimes):**
```bash
python3 generate_wide_coverage.py \
  --bin_dir /workspace/ck_tile/bin \
  --out_dir data/wide_coverage \
  --batch_size 25 \
  --warmup 3 --repeat 10
```

**Edge-case dimensions (N=1, K=1, small N/K):**
```bash
python3 generate_edge_dims.py
```

Both scripts write streaming log files that `data_pipeline.py` can parse.

### 3. Parse logs into parquet

```bash
python3 data_pipeline.py <log_file> \
  -o data/my_dataset.parquet \
  --arch gfx950 \
  --capture_hw
```

The `--capture_hw` flag runs `rocminfo` once and injects the GPU hardware
profile (CU count, clock speed, cache sizes, etc.) into every row.
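The parsing step above scans the streaming log for brace-balanced JSON blocks like the one shown earlier. The sketch below is not `data_pipeline.py`'s actual parser (which is not shown in this document); it is a minimal, hedged illustration of extracting those blocks, and it assumes braces only appear as JSON structure, not inside string values:

```python
import json

def extract_json_blocks(log_text: str):
    """Yield dicts parsed from brace-balanced JSON blocks in a log stream."""
    depth, buf = 0, []
    for line in log_text.splitlines():
        stripped = line.strip()
        if depth == 0 and not stripped.startswith("{"):
            continue  # skip banner/progress lines between JSON blocks
        buf.append(line)
        depth += line.count("{") - line.count("}")
        if depth == 0 and buf:
            try:
                yield json.loads("\n".join(buf))
            except json.JSONDecodeError:
                pass  # tolerate truncated blocks from interrupted runs
            buf = []

log = """
benchmark banner line
{
  "name": "gemm_universal_fp8_rcr_example",
  "problem": {"m": 1024, "n": 1024, "k": 1024},
  "perf_result": {"latency(ms)": 0.04, "tflops(TFlops)": 204.6}
}
"""
rows = list(extract_json_blocks(log))
```

Each yielded dict maps directly onto one row of the canonical schema described below.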
## Canonical Data Schema

Every parquet file follows this schema:

| Column | Type | Description |
|---|---|---|
| `op_type` | str | `gemm_universal`, `gemm_streamk`, etc. |
| `dtype` | str | `fp8`, `fp16`, `bf16`, `bf8` |
| `layout` | str | `rcr`, `rrr`, `crr`, `ccr` |
| `arch` | str | `gfx942`, `gfx950`, etc. |
| `kernel_name` | str | Full kernel identifier |
| `m`, `n`, `k` | int | Problem dimensions |
| `split_k` | int | Split-K factor (1 = standard) |
| `measured_tflops` | float | Ground-truth TFLOPS |
| `latency_ms` | float | Measured latency |
| `bandwidth_gb_s` | float | Measured bandwidth |
| `is_valid` | bool | True if tflops > 0 and latency > 0 |
| `tile_m`, `tile_n`, `tile_k` | int | Tile dimensions |
| `warp_m`, `warp_n`, `warp_k` | int | Warp config |
| `warp_tile_m/n/k` | int | Warp tile dimensions |
| `pipeline` | str | `compv3`, `compv4`, `mem`, etc. |
| `scheduler` | str | `intrawave`, `interwave` |
| `epilogue` | str | `cshuffle`, `default` |
| `pad_m`, `pad_n`, `pad_k` | bool | Padding flags |
| `persistent` | bool | Persistent kernel flag |
| `run_id` | str | Unique collection run identifier |
## Shape Selection Guidelines

Good training data requires diverse shapes. Cover all of these regimes:

### By M dimension (batch size / output rows)
- **M=1**: single-token inference (hardest case for tiling)
- **Tiny M (2-16)**: small batch inference
- **Small M (32-128)**: medium batch
- **Medium M (256-2048)**: large batch / training
- **Large M (4096-20480)**: very large batch

### By N and K dimension
- **N=1**: vector-matrix multiply (degenerate)
- **K=1**: rank-1 update / outer product (degenerate)
- **Small N or K (2-16)**: stress tile efficiency
- **Deep K (K > 4096)**: compute-bound regime
- **Shallow K (K < 256)**: memory-bound regime

### By shape family
- **Square**: M ~ N ~ K (powers of 2)
- **Tall**: M >> N (tall output matrix)
- **Wide**: N >> M (wide output matrix)
- **Deep-K**: K >> M and K >> N

### Special cases
- **Prime dimensions**: 17, 31, 127, 251, 509, 1021, 2039, 4093
  (worst-case for tile alignment, tests padding logic)
- **Non-power-of-2**: 48, 96, 192, 384, 576, 768, 1536, 3072, 4608
  (common in LLM architectures)
- **LLM inference shapes**: DeepSeek, LLaMA-7B, LLaMA-70B MLP/attention dims

### Minimum recommended coverage

For a production-quality model, aim for:
- At least 200 unique (M, N, K) shapes
- At least 10 shapes per shape family
- All kernel configs (4608 for fp8 RCR) run against every shape
- Multiple layouts if training a cross-layout model
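The regimes above can be turned into a concrete shape list programmatically. The helper below is hypothetical (it is not one of the repo's scripts) and only sketches a small set touching each regime; a real run would enumerate far more shapes:

```python
# Hypothetical helper: enumerate a small (M, N, K) set touching each
# regime from the guidelines above.
def make_shapes():
    shapes = set()
    for m in (1, 4, 16, 64, 256, 1024, 4096):   # sweep the M regimes
        shapes.add((m, 4096, 4096))
    for d in (256, 1024, 4096):                 # square family
        shapes.add((d, d, d))
    for p in (127, 251, 1021):                  # prime dims: worst-case tiling
        shapes.add((p, p, p))
    shapes.add((1024, 1024, 8192))              # deep-K (compute-bound)
    shapes.add((1024, 1024, 128))               # shallow-K (memory-bound)
    return sorted(shapes)

shapes = make_shapes()
```

Feeding such a list to the batch scripts gives balanced coverage without hand-picking every shape.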
## Benchmark Quality Guidelines

### Warmup and repeat
- Minimum `warmup=3`, `repeat=10` for fast iteration
- Production quality: `warmup=5`, `repeat=20` for stable measurements
- The `perf_result` values are averaged over `repeat` iterations

### Noise handling
- Use **median** latency when aggregating multiple runs of the same benchmark
- Flag measurements where the coefficient of variation exceeds 10%
- Avoid benchmarking under thermal throttling (check GPU temperature)
- Lock GPU clocks if possible for reproducibility

### Environment metadata
Store with every dataset:
- GPU model and architecture (from `rocminfo`)
- ROCm driver version
- Clock mode (default / locked)
- Git hash of the CK tile engine build (if available)
- Timestamp
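The noise-handling rules above (median aggregation, 10% coefficient-of-variation flag) can be sketched in a few lines. This is an illustration, not code from the repo:

```python
import statistics

def aggregate_latency(samples_ms):
    """Median latency plus a noise flag, per the guidelines above.

    Flags the measurement when the coefficient of variation
    (stdev / mean) exceeds 10%.
    """
    med = statistics.median(samples_ms)
    cv = statistics.stdev(samples_ms) / statistics.mean(samples_ms)
    return med, cv > 0.10

# Five repeats of the same benchmark, in milliseconds.
med, noisy = aggregate_latency([0.040, 0.041, 0.039, 0.040, 0.042])
```

Rows flagged as noisy are candidates for re-running rather than silently entering the training set.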
## Adding Data for a New Op

To generate benchmark data for a new operation (e.g., `gemm_streamk`):

1. **Build the binaries** using the tile engine:
   ```bash
   ninja -C build benchmark_gemm_streamk_fp8_rcr
   ```

2. **Write a generation script** (or modify `generate_wide_coverage.py`):
   - Change the executable glob pattern to match the new op
   - Add any op-specific CLI flags the binaries need

3. **Run and parse**:
   ```bash
   python3 data_pipeline.py my_streamk_run.log \
       -o data/gemm_streamk_fp8_gfx950.parquet --arch gfx950
   ```

4. **Train**:
   ```bash
   python3 train.py --op gemm_streamk --dtype fp8 --arch gfx950 \
       --data_dir data/ --out_dir models/gemm_streamk_fp8_gfx950
   ```

## Adding Data for a New Layout

Same binaries, same shapes -- just change the layout filter:

```bash
# Build rrr binaries
ninja -C build benchmark_gemm_universal_fp8_rrr

# Generate and parse
# ... (same flow, different bin_dir or executable glob)

# Train a cross-layout model by putting all layouts in the same data_dir
python3 train.py --data_dir data/ --out_dir models/gemm_universal_fp8_gfx950_all_layouts
```

The feature engine includes `layout` as a categorical feature, so one model
can handle all layouts.

## Incremental Data Collection

When you have a trained model and want to add more data:

1. Generate new data (new shapes, new layouts, etc.)
2. Parse into parquet alongside existing data
3. Warm-start from the previous model:

```bash
python3 train.py --data_dir data/ --out_dir models/v2 \
    --warm_start models/v1 \
    --warm_start_n_estimators 200
```

This adds 200 new trees on top of the existing model. The feature schema
must match exactly (enforced automatically).
## File Organization

Recommended directory structure:

```
heuristics/
  data/
    gemm_universal_fp8_rcr_gfx950.parquet   # original 108 shapes
    wide_coverage/                          # batch log files
      wide_coverage_batch_001.log
      wide_coverage_batch_002.log
      ...
    edge_dims/                              # N=1, K=1 edge cases
      edge_dims_batch_001.log
      ...
  models/
    gemm_universal_fp8_gfx950/              # trained model artifacts
      model_tflops.lgbm
      model_latency.lgbm
      model_bandwidth.lgbm
      feature_spec.json
      train_manifest.json
      cv_metrics_tflops.json
      eval_report.json
      ...
```
## Troubleshooting

### Benchmark binary exits with non-zero code
Some kernel configs are invalid for certain problem sizes (e.g., tile_m=256
with M=16). The data pipeline marks these as `is_valid=False` and they are
filtered out during training. This is expected.

### Edge dims produce very few results
N=1 and K=1 shapes are degenerate -- most kernel configurations have minimum
dimension requirements and will fail or produce zero TFLOPS. The small number
of valid results is still useful (it tells the model which configs work for
these shapes).

### Benchmarks are slow
Each shape requires running all 4608 kernel executables sequentially. At
~0.01s per kernel, that is ~46 seconds per shape. For 700 shapes, expect
~9 hours. Tips:
- Run on a dedicated GPU (no other workloads)
- Use `--batch_size 25` to get incremental output
- Parse and train on partial data while generation continues

### Data from different GPUs / driver versions
Store `run_id` and hardware metadata with each dataset. Training on mixed
data is allowed but not recommended for production models. Filter to a
single `run_id` or `arch` for clean experiments.
dispatcher/heuristics/LEARNINGS.md (new file, 151 lines)
@@ -0,0 +1,151 @@
# Learnings and Design Decisions

Empirical findings from building the CK Tile kernel performance prediction system.
These inform the current defaults and explain why certain approaches were chosen.

## 1. Log-Transform is Essential for Cross-Scale Accuracy

**Problem**: GEMM TFLOPS spans 5 orders of magnitude across different problem
sizes. When training on raw TFLOPS, the regression loss (RMSE) is dominated by
large shapes, where absolute errors are biggest. The model learns to predict
large shapes accurately but ignores tiny shapes, where the TFLOPS values are
much lower.

**Evidence** (168 shapes, 626K rows, 5-fold GroupKFold CV):

| Model | Mean Eff | P10 Eff | tiny_m Eff | Min Eff |
|---|---|---|---|---|
| Raw TFLOPS (500 trees) | 92.73% | 80.24% | 84.55% | 4.26% |
| **log1p(TFLOPS)** (500 trees) | **96.92%** | **94.34%** | **94.89%** | **60.27%** |
| log1p(TFLOPS) (2000 trees) | 97.51% | 93.89% | 96.04% | 63.56% |

**Solution**: Train on `log1p(measured_tflops)` and apply `expm1()` to
predictions. This is now the default in `train.py`. Pass `--no_log_transform`
to revert to raw regression (not recommended).

**Why log1p, not log**: `log1p(x) = log(1 + x)` handles zero and near-zero
TFLOPS gracefully, whereas `log(x)` produces -inf for x=0.
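The transform pair described above is two one-liners; this sketch just makes the round trip explicit:

```python
import math

# The target transform described above: train on log1p(tflops),
# invert predictions with expm1.
def to_target(tflops):
    return math.log1p(tflops)   # log(1 + x): finite at x = 0

def from_prediction(y):
    return math.expm1(y)        # exact inverse of log1p

# TFLOPS spanning ~5 orders of magnitude maps to a well-conditioned range.
targets = [to_target(t) for t in (0.0, 0.02, 2.0, 200.0, 2230.0)]
round_trip = from_prediction(to_target(204.6))
```

Because `expm1(log1p(x)) == x` to floating-point precision, the inversion adds no systematic bias to predictions.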
## 2. Tiny-M Shapes are the Hardest Case

M=1 (single-token inference) shapes are fundamentally different from batch shapes:

- Most kernel configurations produce very low TFLOPS
- The "best" kernel is often only marginally better than the rest
- The oracle performance itself is very low, so any prediction error tanks efficiency
- Many kernels fail outright (tile_m=128 with M=1 wastes 127/128 of the tile)

The bottom shapes in our evaluation are all M=1, with efficiencies in the
63-70% range. These shapes have such low absolute performance that the model's
noise floor exceeds the performance difference between kernels.

**Mitigation**: Log-transform helps significantly (tiny_m improved from 84% to
96%). For production use with M=1, consider a dedicated fallback (e.g.,
hardcoded kernel selection for M < 4 based on known-good configs).
## 3. IHEM (Hard Example Mining) Hurts When Scale is the Issue

We tried Iterative Hard Example Mining with sample reweighting (2x-5x weight
on hard shapes). Result: it made things **worse**, degrading mean efficiency
from 94.31% to 92.90% over 3 iterations.

**Why**: The hard shapes are hard because of scale mismatch, not because the
model lacks capacity. Reweighting amplifies the small-TFLOPS rows, which
distorts the learned relationship between features and performance for the
majority of shapes. The log-transform was the correct fix -- it addresses the
root cause (scale) rather than the symptom (bad predictions on tiny shapes).

**Lesson**: IHEM is useful when the model has capacity gaps (e.g., certain
pipeline types are underrepresented). It is counterproductive when the issue
is target-variable scale. Always try target transforms before reweighting.
## 4. GroupKFold Key = (M, N, K) Forces Generalization

The validation uses `GroupKFold` where the group key is `(M, N, K)` -- all
kernels for the same shape go to the same fold. This means:

- The model is always evaluated on shapes it has **never seen** during training
- Layout is excluded from the key, forcing the model to generalize across layouts
- Since models are per-arch, `arch` is implicit (constant within one training run)

This is much stricter than random row splitting, where the model would see some
kernels for each shape during training. Our efficiency numbers are conservative
estimates of real-world performance on unseen shapes.
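The grouping property is easy to illustrate without scikit-learn. This is not `train.py`'s code; the rows are hypothetical, and the round-robin fold assignment merely stands in for GroupKFold's split logic:

```python
# Illustration of the GroupKFold invariant: every row sharing a
# (M, N, K) group key lands in the same fold.
rows = [
    {"m": 1024, "n": 1024, "k": 1024, "kernel": "compv3_128x128x128"},
    {"m": 1024, "n": 1024, "k": 1024, "kernel": "compv4_256x128x64"},
    {"m": 1,    "n": 7168, "k": 2048, "kernel": "mem_16x64x128"},
]

def group_key(row):
    return (row["m"], row["n"], row["k"])

# Assign each distinct shape (not each row) to a fold round-robin.
shapes = sorted({group_key(r) for r in rows})
fold_of_shape = {s: i % 2 for i, s in enumerate(shapes)}
folds = [fold_of_shape[group_key(r)] for r in rows]
```

At evaluation time, holding out a fold therefore holds out entire shapes, never individual kernels of a shape.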
## 5. Model Size vs Accuracy Tradeoff

| Config | Trees | Leaves | LR | Mean Eff | P10 Eff | Train Time |
|---|---|---|---|---|---|---|
| Small (default v1) | 500 | 127 | 0.05 | 96.92% | 94.34% | ~20s |
| **Big (current)** | **2000** | **255** | **0.02** | **97.51%** | **93.89%** | **~25s/fold** |

The bigger model improved mean efficiency by 0.6% but P10 didn't improve
(actually slightly worse). The extra capacity helps on medium shapes but
doesn't crack the tiny-M floor. This suggests the feature set, not model
capacity, is the limiting factor for the hardest shapes.

For C++ deployment, the bigger model (2000 trees, 255 leaves) is still fast
enough -- LightGBM inference is O(trees * log(leaves)) per sample, which is
~microseconds even at 2000 trees.
## 6. N=1 and K=1 Shapes are Degenerate

We generated benchmark data for 546 edge-case shapes (N=1, K=1, small N/K).
Result: **zero valid kernel results** across 94 shapes. All 4608 kernels either
fail or produce 0 TFLOPS for these degenerate dimensions.

This means:

- The tile engine kernels have hard minimum dimension requirements
- N=1 / K=1 shapes cannot be handled by the current kernel set
- These shapes need dedicated kernels (e.g., BLAS-1/BLAS-2 fallbacks)
- The ML model should not be expected to handle them -- they should be filtered
  out before reaching the heuristic
## 7. Feature Engineering Insights

From LightGBM feature importances on the log-target model:

**Top features** (by split count):

- `M, N, K` -- raw dimensions are always the most important
- `tile_m, tile_n, tile_k` -- the tile shape is the primary kernel differentiator
- `overall_tile_efficiency` -- how well the shape fits the tile (the interaction)
- `num_tiles_m, total_output_tiles` -- work decomposition
- `arithmetic_intensity` -- compute vs memory bound regime
- `pipeline` -- pipeline type (compv3 vs compv4 vs mem) significantly affects perf

**Low-importance features**:

- Hardware constants (CUs, clock, caches) -- they're constant within one arch
  model, so they provide no discriminative signal. They'll become important when
  training cross-arch models.
- `split_k` -- always 1 in current data
- `persistent` -- rarely True in current kernel set
## 8. Warm-Start Works for Incremental Updates

LightGBM's `init_model` parameter successfully continues training from an
existing model. New trees are added on top of existing ones. Key considerations:

- Feature schema must match exactly (enforced by `check_feature_compatibility`)
- Use fewer new trees (200-500) since we're refining, not starting fresh
- The `train_manifest.json` tracks the full lineage (total trees, data sizes)
- Quality should be at least as good as the base model (tested)
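The schema guard mentioned above can be sketched as a simple list comparison. The repo's `check_feature_compatibility` may differ in detail; this version only illustrates the idea that name *and order* must match, since LightGBM's `init_model` continues training on positional feature columns:

```python
# Sketch of a feature-schema guard for warm-start training.
def check_feature_compatibility(base_features, new_features):
    """Raise unless the new run uses exactly the base model's feature schema."""
    if list(base_features) != list(new_features):
        missing = set(base_features) - set(new_features)
        extra = set(new_features) - set(base_features)
        raise ValueError(
            f"feature schema mismatch: missing={sorted(missing)}, "
            f"extra={sorted(extra)} (order matters for init_model)"
        )

base = ["m", "n", "k", "tile_m", "tile_n", "tile_k"]
check_feature_compatibility(base, base)  # identical schema: no error
```

Failing loudly here is cheaper than silently training new trees against shuffled columns.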
## 9. Data Volume Matters More Than Model Complexity

| Dataset | Shapes | Rows | Mean Eff (log, 500 trees) |
|---|---|---|---|
| Original (DeepSeek only) | 108 | 418K | 98.28% (on seen distribution) |
| + Wide coverage | 168 | 626K | 96.92% (harder distribution) |

The original 108-shape model looked great (98.28%) but was overfitting to the
DeepSeek LLM inference M=1 distribution. Adding 60 diverse shapes (many M=1)
exposed the model's weakness on tiny shapes. More diverse training data is
always better than a bigger model on narrow data.

## Summary of Defaults

Based on these findings, the current defaults in `train.py` are:

- **Target transform**: `log1p` for TFLOPS and bandwidth (scale normalization)
- **Model**: 2000 trees, 255 leaves, max depth 15, LR 0.02
- **Validation**: 5-fold GroupKFold, key = (M, N, K)
- **Early stopping**: patience 100 (let trees fully converge)
- **Warm start**: 500 new trees (was 200, increased for bigger base model)
dispatcher/heuristics/README.md (new file, 271 lines)
@@ -0,0 +1,271 @@
# CK Tile Heuristics: ML-Based Kernel Selection

Fast, accurate kernel selection for CK Tile operations using LightGBM regression
with Origami-augmented feature engineering.

## What This Does

Instead of running all 4608+ kernel configurations on the GPU to find the best
one (exhaustive search taking ~46 seconds per shape), this system trains an ML
model that predicts TFLOPS for any (problem, kernel) pair in microseconds. It
scores all candidates instantly and picks the best kernel -- achieving 98.28%
of oracle-best TFLOPS efficiency across 108 tested shapes.
## Quick Start

### 1. Generate and convert benchmark data

**Step 1: Generate benchmark data**

```bash
python3 generate_benchmark_data.py \
  --build_dir /path/to/build \
  --output_dir data/fp16_original \
  --dtype fp16 \
  --layout rcr \
  --num_build_jobs 4 \
  --warmup 10 \
  --repeat 50
```

This outputs a JSON file with all benchmark results.

**Step 2: Convert JSON to parquet training format**

```bash
python3 convert_json_to_parquet.py \
  --input data/fp16_original/benchmark_results_fp16_rcr.json \
  --output data/fp16_original/fp16_training_data.parquet \
  --arch gfx950
```

The converter automatically fixes pad flags for `_mem` kernels and validates the data.

**Alternative: Parse existing logs**

If you have raw benchmark logs from CK Tile:

```bash
python3 data_pipeline.py ck_tile_testrun_2.log \
  -o data/gemm_universal_fp8_rcr_gfx950.parquet \
  --arch gfx950 --capture_hw
```
### 2. Train a model

```bash
python3 train.py \
  --data_dir data/ \
  --out_dir models/gemm_universal_fp8_gfx950 \
  --op gemm_universal --dtype fp8 --arch gfx950
```

**Note**: Trained models are automatically compressed to `.lgbm.gz` format to
save space (~67% reduction). The Python tools decompress them on first use and
cache the decompressed version; for warm-start training, decompression also
happens automatically.
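The decompress-and-cache behavior described in the note is standard-library gzip handling. This sketch is illustrative (the tools do this for you); `ensure_decompressed` is a hypothetical name, not the repo's API:

```python
import gzip
import pathlib
import shutil
import tempfile

def ensure_decompressed(gz_path: pathlib.Path) -> pathlib.Path:
    """Decompress model_tflops.lgbm.gz -> model_tflops.lgbm once, then cache."""
    lgbm_path = gz_path.with_suffix("")      # strip the trailing ".gz"
    if not lgbm_path.exists():               # cached copy wins on later calls
        with gzip.open(gz_path, "rb") as src, open(lgbm_path, "wb") as dst:
            shutil.copyfileobj(src, dst)
    return lgbm_path

# Demo with a throwaway file standing in for a model artifact.
tmp = pathlib.Path(tempfile.mkdtemp())
gz = tmp / "model_tflops.lgbm.gz"
with gzip.open(gz, "wb") as f:
    f.write(b"lightgbm-model-bytes")
out = ensure_decompressed(gz)
```

The cached `.lgbm` file is exactly what the C++ path in step 6 below expects.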
### 3. Evaluate

```bash
python3 evaluate.py \
  --model_dir models/gemm_universal_fp8_gfx950 \
  --data_dir data/ --op gemm_universal --dtype fp8
```

### 4. Predict the best kernel for a problem

```bash
python3 predict.py \
  --model_dir models/gemm_universal_fp8_gfx950 \
  --m 128 --n 1536 --k 7168 --layout rcr
```

### 5. Search for optimal configs (optional)

```bash
python3 search.py \
  --model_dir models/gemm_universal_fp8_gfx950 \
  --m 128 --n 1536 --k 7168 \
  --strategy random --budget 500 --top_k 10
```

### 6. Using models in C++ (requires decompression)

C++ code uses the LightGBM C API, which requires uncompressed `.lgbm` files. If
you have compressed models (`.lgbm.gz`), decompress them first:

```bash
cd models/gemm_universal_fp16_gfx950
gunzip model_tflops.lgbm.gz
```

Then use in C++ examples:

```bash
cd dispatcher/build
./gemm_09_ml_heuristic --model ../heuristics/models/gemm_universal_fp16_gfx950/model_tflops.lgbm
```

**Note**: Python tools automatically decompress `.lgbm.gz` files on first use,
so you can run the Python scripts first to trigger decompression, then use the
same models from C++.
## Architecture

```
Problem (M, N, K, dtype, layout)
        |
        v
FeatureEngine.extract_batch()   <-- 55 features: problem, kernel, interaction, hardware
        |
        v
LGBMRegressor.predict()         <-- predicts TFLOPS for each candidate kernel
        |
        v
Sort by predicted TFLOPS        <-- rank all candidates
        |
        v
Select Top-1 kernel             <-- 98.28% mean efficiency, <1ms inference
```

Three models are trained per (op, dtype, arch):
- **TFLOPS model** (primary): used for kernel ranking
- **Latency model** (auxiliary): for latency-sensitive workloads
- **Bandwidth model** (auxiliary): for memory-bound analysis
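The ranking stage in the diagram above reduces to "score every candidate, sort, take the top". The sketch below substitutes a trivial stub for the LightGBM model (`predict_tflops` is hypothetical, not the repo's API) purely to show the control flow:

```python
# Sketch of the rank-and-select step, with a stub in place of the model.
def predict_tflops(problem, kernel):
    # Stub scorer: favors kernels whose tile_m evenly divides M.
    return 100.0 if problem["m"] % kernel["tile_m"] == 0 else 10.0

def select_best_kernel(problem, candidates):
    scored = [(predict_tflops(problem, k), k) for k in candidates]
    scored.sort(key=lambda t: t[0], reverse=True)  # rank by predicted TFLOPS
    return scored[0][1]                            # Top-1 kernel

problem = {"m": 1024, "n": 1536, "k": 7168}
candidates = [{"tile_m": 96}, {"tile_m": 128}]
best = select_best_kernel(problem, candidates)
```

In the real system the scoring call is a single batched `predict()` over all candidates' feature vectors, which is what keeps selection in the microsecond range.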
## File Inventory

| File | Purpose |
|---|---|
| `generate_benchmark_data.py` | Build and run benchmarks across ~25 diverse problem sizes, output JSON |
| `convert_json_to_parquet.py` | Convert benchmark JSON to parquet training format, fix `_mem` pad flags |
| `data_pipeline.py` | Parse raw benchmark logs into canonical parquet datasets |
| `feature_engine.py` | 55-feature extraction: problem, kernel, interaction, hardware profile |
| `train.py` | Multi-target LGBMRegressor training with GroupKFold CV, IHEM, warm-start |
| `predict.py` | Predictor class: predict TFLOPS/latency/bandwidth, rank kernels |
| `evaluate.py` | Full evaluation: global metrics, per-shape/layout/pipeline slices |
| `search.py` | Surrogate search: discrete DE, random top-K |
| `generate_wide_coverage.py` | Generate benchmark data across 706 diverse shapes |
| `generate_edge_dims.py` | Generate N=1, K=1, and other edge-case shapes |
| `DATA_GENERATION.md` | Detailed guide for building binaries and generating data |
| `plan.md` | Full design plan with architecture, milestones, and rationale |

## Features Used (55 total)

### Problem features (13)
`M, N, K, split_k, log2(M), log2(N), log2(K), log2(MNK),
arithmetic_intensity, aspect_ratio_mn, aspect_ratio_mk, aspect_ratio_nk, layout`

### Kernel features (21)
`tile_m, tile_n, tile_k, warp_m, warp_n, warp_k, warp_tile_m, warp_tile_n,
warp_tile_k, pipeline, scheduler, epilogue, pad_m, pad_n, pad_k, persistent,
num_warps, tile_volume, tile_mn, lds_usage_estimate, lds_usage_ratio`

### Interaction features (9)
`num_tiles_m, num_tiles_n, num_tiles_k, total_output_tiles,
tile_eff_m, tile_eff_n, tile_eff_k, overall_tile_efficiency, cu_utilization`

### Hardware profile features (12)
`hw_num_cus, hw_simds_per_cu, hw_total_simds, hw_shader_engines,
hw_max_clock_mhz, hw_max_waves_per_cu, hw_wavefront_size, hw_lds_capacity,
hw_l1_cache_kb, hw_l2_cache_kb, hw_l3_cache_kb, hw_num_xcd`
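The interaction features listed above follow directly from problem and tile dimensions. The sketch uses the obvious definitions; `feature_engine.py` may differ in detail, and `num_cus=256` is a placeholder rather than a value queried from hardware:

```python
import math

# Sketch of the interaction features (obvious definitions, not the
# repo's exact implementation).
def interaction_features(m, n, k, tile_m, tile_n, tile_k, num_cus=256):
    num_tiles_m = math.ceil(m / tile_m)
    num_tiles_n = math.ceil(n / tile_n)
    num_tiles_k = math.ceil(k / tile_k)
    total_output_tiles = num_tiles_m * num_tiles_n
    # Fraction of each tile covered by real data (1.0 = perfect fit).
    tile_eff_m = m / (num_tiles_m * tile_m)
    tile_eff_n = n / (num_tiles_n * tile_n)
    tile_eff_k = k / (num_tiles_k * tile_k)
    return {
        "num_tiles_m": num_tiles_m,
        "num_tiles_n": num_tiles_n,
        "num_tiles_k": num_tiles_k,
        "total_output_tiles": total_output_tiles,
        "overall_tile_efficiency": tile_eff_m * tile_eff_n * tile_eff_k,
        "cu_utilization": min(1.0, total_output_tiles / num_cus),
    }

feats = interaction_features(1024, 1024, 1024, 128, 128, 128)
```

For M=1 with tile_m=128 the same formulas give `tile_eff_m = 1/128`, which is exactly the tile-waste effect the tiny-M discussion in LEARNINGS.md describes.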
## Model Performance

### fp8 RCR, gfx950

| Metric | 108 shapes (original) | 168 shapes (wide coverage) |
|---|---|---|
| Mean TFLOPS Efficiency | 98.28% | 97.51% |
| P10 TFLOPS Efficiency | 94.64% | 93.89% |
| tiny_m (M=1) Efficiency | 95.57% | 96.04% |
| R2 (TFLOPS) | 0.997 | 0.993 |

### fp16 RCR, gfx950

Trained on 25 shapes, 1,024 kernels, 21,920 valid benchmarks.

| Metric | Value |
|---|---|
| Mean TFLOPS Efficiency | 99.36% |
| P10 TFLOPS Efficiency | 98.05% |
| P50 TFLOPS Efficiency | 100.00% |
| Min Efficiency | 95.45% |
| NDCG@1 | 64.00% |
| Top-5 Hit Rate | 88.00% |

**Shape Family Breakdown:**

| Shape Family | Mean Eff | P10 Eff | Shapes |
|---|---|---|---|
| Large M (M≥1024) | 99.54% | 99.07% | 4 |
| Medium M (128≤M<1024) | 99.62% | 98.74% | 7 |
| Small M (8≤M<128) | 98.82% | 96.22% | 8 |
| Tiny M (M<8) | 99.65% | 98.96% | 6 |

**Pipeline Breakdown:**

| Pipeline | Mean Eff | P10 Eff |
|---|---|---|
| compv3 | 99.75% | 99.09% |
| compv4 | 99.40% | 98.54% |
| mem | 99.08% | 96.59% |

Training uses `log1p(TFLOPS)` as the target by default, which normalizes the
scale across shapes spanning 0.02 to 2230 TFLOPS. This was the key finding
that improved tiny-M shapes from 84% to 96% efficiency. See
[LEARNINGS.md](LEARNINGS.md) for details.
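The transform itself is just numpy's `log1p`/`expm1` pair, applied at train and predict time respectively; a minimal sketch of the round trip:

```python
import numpy as np

# Raw targets span roughly five orders of magnitude across shapes.
tflops = np.array([0.02, 1.5, 2230.0])

y = np.log1p(tflops)      # compressed regression target: log(1 + tflops)
recovered = np.expm1(y)   # inverse applied to model predictions

assert np.allclose(recovered, tflops)
```

Compressing the target this way keeps the squared-error loss from being dominated by the large-shape regime, which is why the tiny-M shapes benefit most.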

## Validation

Training uses `GroupKFold(n_splits=5)` with group key `(M, N, K)` to ensure
the model is evaluated on shapes it has never seen during training. Layout is
excluded from the group key to force cross-layout generalization.
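A sketch of how that split behaves with scikit-learn (illustrative toy data; the real group key is built from the dataset's `m`, `n`, `k` columns):

```python
import numpy as np
from sklearn.model_selection import GroupKFold

# One group id per unique (M, N, K); many kernel rows share a shape.
shapes = [(1024, 1024, 1024), (16, 1536, 7168), (1, 4096, 4096),
          (128, 4096, 4096), (8192, 8192, 256)]
groups = np.repeat([f"{m}x{n}x{k}" for m, n, k in shapes], 4)
X = np.arange(len(groups)).reshape(-1, 1)

gkf = GroupKFold(n_splits=5)
for train_idx, val_idx in gkf.split(X, groups=groups):
    # Every row of a given shape lands entirely on one side of the split.
    assert set(groups[train_idx]).isdisjoint(set(groups[val_idx]))
```

Without grouping, rows from the same shape would leak into both folds and the cross-validated efficiency numbers would be optimistic.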
## Incremental Training (Warm Start)

When new benchmark data arrives, update the model without retraining from scratch:

```bash
python3 train.py \
  --data_dir data/ \
  --out_dir models/v2 \
  --warm_start models/gemm_universal_fp8_gfx950 \
  --warm_start_n_estimators 200
```

This adds 200 new trees on top of the existing model. Feature schemas must
match exactly (automatically enforced).
## Extending to New Ops

Adding support for a new operation (e.g., `gemm_streamk`, `grouped_conv`):

1. **Build binaries**: `ninja -C build benchmark_gemm_streamk_fp8_rcr`
2. **Subclass `FeatureEngine`**: add op-specific features (e.g., StreamK split factor)
3. **Generate data**: run benchmarks across diverse shapes
4. **Train**: `python3 train.py --op gemm_streamk --dtype fp8 --data_dir data/ --out_dir models/`

The training, evaluation, prediction, and search infrastructure is fully
op-agnostic -- only the feature engine needs a new subclass.
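Step 2 might look like the following sketch. The base class and its `extract` hook are stubbed here for illustration, and `streamk_splits` is a hypothetical feature name; check `feature_engine.py` for the real interface:

```python
# Illustrative stand-in: the real base class lives in feature_engine.py.
class FeatureEngine:
    def extract(self, problem: dict, kernel: dict) -> dict:
        # Base features shared by all ops (problem dims, tile config, ...).
        return {"m": problem["m"], "n": problem["n"], "k": problem["k"],
                "tile_m": kernel["tile_m"]}

class StreamKFeatureEngine(FeatureEngine):
    def extract(self, problem: dict, kernel: dict) -> dict:
        feats = super().extract(problem, kernel)
        # Op-specific feature: hypothetical StreamK split factor.
        feats["streamk_splits"] = kernel.get("streamk_splits", 1)
        return feats

feats = StreamKFeatureEngine().extract(
    {"m": 1024, "n": 1024, "k": 1024},
    {"tile_m": 128, "streamk_splits": 4})
```

The subclass only widens the feature dict; training, prediction, and search consume whatever columns the engine emits.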
## Tests

102 tests covering all modules:

```bash
python3 -m pytest tests/ -v
```

Test coverage includes:
- Log parsing with malformed JSON, empty logs, single-kernel shapes
- Feature formula correctness (tile efficiency, LDS usage, arithmetic intensity)
- Corner-case shapes: M=1, N=1, K=1, prime dimensions, 20480x7168x256
- Batch vs single extraction parity
- Parameter space validation and projection
- Predictor: single/batch prediction, ranking, missing models, empty inputs
- Training: group keys, efficiency computation, warm-start, feature compatibility
- Search: random, DE, config validity, determinism

## Documentation

- **[README.md](README.md)**: This file -- quick start, architecture, performance
- **[DATA_GENERATION.md](DATA_GENERATION.md)**: Complete guide for building tile engine
  binaries, running benchmarks, managing datasets, and troubleshooting
- **[LEARNINGS.md](LEARNINGS.md)**: Empirical findings and design decisions (log-transform,
  IHEM results, tiny-M analysis, feature importance, N=1/K=1 edge cases)
4 dispatcher/heuristics/__init__.py Normal file

@@ -0,0 +1,4 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT

# CK Tile Heuristics: ML-based kernel selection
67 dispatcher/heuristics/collect_additional.sh Executable file

@@ -0,0 +1,67 @@
#!/bin/bash
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT

# Generate additional benchmark data for shapes NOT in the original log.
# Runs in background; outputs streaming JSON that can be parsed by data_pipeline.py.

BIN_DIR="/workspace/ck_tile/bin"
OUT_LOG="data/additional_shapes.log"
WARMUP=3
REPEAT=10

mkdir -p data

# Additional shapes: square powers-of-2 and common ML sizes not in original DeepSeek set
SHAPES=(
  "64,64,64"
  "128,128,128"
  "256,256,256"
  "512,512,512"
  "1024,1024,1024"
  "2048,2048,2048"
  "4096,4096,4096"
  "1,4096,4096"
  "8,4096,4096"
  "32,4096,4096"
  "128,4096,4096"
  "1,4096,11008"
  "32,4096,11008"
  "1,8192,8192"
  "32,8192,8192"
  "1,8192,28672"
  "32,8192,28672"
  "256,256,8192"
  "8192,8192,256"
  "1024,4096,1024"
  "4096,1024,4096"
  "2048,8192,2048"
)

echo "CK Tile Additional Shapes Benchmark" > "$OUT_LOG"
echo "GPU ID: 0" >> "$OUT_LOG"
echo "Implementation: gemm_universal" >> "$OUT_LOG"
echo "" >> "$OUT_LOG"

SHAPE_IDX=0
for SHAPE in "${SHAPES[@]}"; do
  IFS=',' read -r M N K <<< "$SHAPE"
  SHAPE_IDX=$((SHAPE_IDX + 1))

  echo "========================================" >> "$OUT_LOG"
  echo "Shape $SHAPE_IDX: M=$M N=$N K=$K dtype=fp8 layout=rcr" >> "$OUT_LOG"
  echo "========================================" >> "$OUT_LOG"

  KERNEL_COUNT=0
  for EXE in "$BIN_DIR"/benchmark_gemm_universal_fp8_rcr_*; do
    KERNEL_COUNT=$((KERNEL_COUNT + 1))
    OUTPUT=$("$EXE" -m="$M" -n="$N" -k="$K" -warmup=$WARMUP -repeat=$REPEAT -verify=0 2>/dev/null)
    # Extract just the JSON block
    echo "$OUTPUT" | sed -n '/{/,/^}/p' >> "$OUT_LOG"
  done

  echo "Found $KERNEL_COUNT kernels" >> "$OUT_LOG"
  echo "Completed shape $SHAPE_IDX: M=$M N=$N K=$K ($KERNEL_COUNT kernels)" >&2
done

echo "Done generating additional data" >&2
233 dispatcher/heuristics/convert_json_to_parquet.py Normal file

@@ -0,0 +1,233 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT

"""
Convert benchmark JSON results to parquet format for training.

Usage:
    python convert_json_to_parquet.py \
        --input benchmark_results_fp16_rcr.json \
        --output fp16_training_data.parquet

Features:
- Converts JSON benchmark results to flat row format
- Automatically fixes pad flags for _mem kernels
- Captures both successes and failures
- Compatible with existing training data format
"""

import argparse
import json
import pandas as pd
from pathlib import Path


def convert_json_to_parquet(json_file: Path, output_file: Path, arch: str = "gfx950"):
    """Convert benchmark JSON to parquet training data format."""

    print(f"Loading {json_file}...")
    with open(json_file) as f:
        data = json.load(f)

    metadata = data.get("metadata", {})
    dtype = metadata.get("dtype", "fp16")
    layout = metadata.get("layout", "rcr")

    print(f"  Data type: {dtype}")
    print(f"  Layout: {layout}")
    print(f"  Kernels: {metadata.get('num_kernels', 0)}")
    print(f"  Problem sizes: {metadata.get('num_problems', 0)}")
    print()

    rows = []
    for kernel_result in data["results"]:
        kernel_config = kernel_result["kernel_config"]

        for benchmark in kernel_result["benchmarks"]:
            # Common fields for both valid and invalid runs
            row = {
                "op_type": "gemm_universal",
                "dtype": dtype,
                "layout": layout,
                "arch": arch,
                "kernel_name": kernel_config["name"],
                "m": benchmark["m"],
                "n": benchmark["n"],
                "k": benchmark["k"],
                "split_k": 1,
                "is_valid": benchmark["is_valid"],
                "run_id": 0,
                "pipeline": kernel_config["pipeline"],
                "epilogue": kernel_config["epilogue"],
                "scheduler": kernel_config["scheduler"],
                "pad_m": kernel_config["pad_m"],
                "pad_n": kernel_config["pad_n"],
                "pad_k": kernel_config["pad_k"],
                "persistent": kernel_config["persistent"],
                "tile_m": kernel_config["tile_m"],
                "tile_n": kernel_config["tile_n"],
                "tile_k": kernel_config["tile_k"],
                "warp_m": kernel_config["warp_m"],
                "warp_n": kernel_config["warp_n"],
                "warp_k": kernel_config["warp_k"],
                "warp_tile_m": kernel_config["warp_tile_m"],
                "warp_tile_n": kernel_config["warp_tile_n"],
                "warp_tile_k": kernel_config["warp_tile_k"],
            }

            if benchmark["is_valid"]:
                # Valid run - include performance metrics
                row["measured_tflops"] = benchmark["tflops"]
                row["latency_ms"] = benchmark["avg_time_ms"]
                # Calculate bandwidth if needed
                m, n, k = benchmark["m"], benchmark["n"], benchmark["k"]
                bytes_transferred = (m * k + k * n + m * n) * 2  # FP16 = 2 bytes
                if benchmark["avg_time_ms"] > 0:
                    row["bandwidth_gb_s"] = (bytes_transferred / 1e9) / (
                        benchmark["avg_time_ms"] / 1000
                    )
                else:
                    row["bandwidth_gb_s"] = 0.0
            else:
                # Failed run - zero metrics
                row["measured_tflops"] = 0.0
                row["latency_ms"] = 0.0
                row["bandwidth_gb_s"] = 0.0

            rows.append(row)

    df = pd.DataFrame(rows)

    print(f"Converted {len(df):,} benchmark results")
    print(f"  Valid: {df['is_valid'].sum():,}")
    print(f"  Failed: {(~df['is_valid']).sum():,}")
    print()

    # Fix pad flags for _mem kernels (critical for P1 features!)
    print("Fixing pad flags for _mem kernels...")
    mem_mask = df["pipeline"] == "mem"
    mem_count = mem_mask.sum()

    if mem_count > 0:
        df.loc[mem_mask, "pad_m"] = True
        df.loc[mem_mask, "pad_n"] = True
        df.loc[mem_mask, "pad_k"] = True
        print(f"  ✓ Fixed {mem_count:,} _mem kernel rows")
    print()

    # Save to parquet
    df.to_parquet(output_file, index=False)
    print(f"✓ Saved to {output_file}")
    print()

    # Show statistics
    print("=" * 80)
    print("STATISTICS")
    print("=" * 80)
    print()

    print("Dimension ranges:")
    print(f"  M: {df['m'].min():,} - {df['m'].max():,}")
    print(f"  N: {df['n'].min():,} - {df['n'].max():,}")
    print(f"  K: {df['k'].min():,} - {df['k'].max():,}")
    print()

    print("Pipeline distribution:")
    print(df["pipeline"].value_counts())
    print()

    print("Pad flag distribution:")
    pad_combos = df[["pad_m", "pad_n", "pad_k"]].value_counts()
    print(pad_combos)
    print()

    if (~df["is_valid"]).sum() > 0:
        print("Failure analysis:")
        failed = df[~df["is_valid"]]
        print(f"  Total failures: {len(failed):,}")

        # Group by pipeline
        print("\n  By pipeline:")
        for pipeline, count in failed["pipeline"].value_counts().items():
            print(f"    {pipeline}: {count:,}")

        # Show sample failures
        print("\n  Sample failures:")
        for _, row in failed.head(5).iterrows():
            print(
                f"    {row['kernel_name'][:60]:60s} M={row['m']:4d} N={row['n']:4d} K={row['k']:4d}"
            )

    return df


def merge_datasets(parquet_files: list[Path], output_file: Path):
    """Merge multiple parquet files into one."""

    print("=" * 80)
    print("MERGING DATASETS")
    print("=" * 80)
    print()

    dfs = []
    for pq_file in parquet_files:
        if pq_file.exists():
            df = pd.read_parquet(pq_file)
            print(f"  {pq_file.name}: {len(df):,} rows")
            dfs.append(df)
        else:
            print(f"  ✗ {pq_file} not found, skipping")

    if not dfs:
        print("No files to merge!")
        return

    combined = pd.concat(dfs, ignore_index=True)
    combined.to_parquet(output_file, index=False)

    print()
    print(f"✓ Merged {len(combined):,} total rows to {output_file}")
    print()

    # Show dtype distribution
    print("Data type distribution:")
    print(combined["dtype"].value_counts())
    print()

    print("Layout distribution:")
    print(combined["layout"].value_counts())


def main():
    parser = argparse.ArgumentParser(
        description="Convert benchmark JSON to parquet training data",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )
    parser.add_argument(
        "--input", type=str, required=True, help="Input JSON file from benchmark"
    )
    parser.add_argument("--output", type=str, required=True, help="Output parquet file")
    parser.add_argument("--arch", type=str, default="gfx950", help="GPU architecture")
    parser.add_argument(
        "--merge_with", type=str, nargs="*", help="Additional parquet files to merge"
    )

    args = parser.parse_args()

    input_file = Path(args.input)
    output_file = Path(args.output)

    # Convert JSON to parquet
    df = convert_json_to_parquet(input_file, output_file, args.arch)

    # Merge if requested
    if args.merge_with:
        merge_files = [output_file] + [Path(f) for f in args.merge_with]
        merged_output = output_file.parent / f"{output_file.stem}_merged.parquet"
        merge_datasets(merge_files, merged_output)


if __name__ == "__main__":
    main()
394 dispatcher/heuristics/data_pipeline.py Normal file

@@ -0,0 +1,394 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT

"""
Data pipeline for CK Tile heuristics.

Parses benchmark logs and structured JSON into a canonical parquet dataset.
Supports:
- Streaming log format (Shape N: headers + inline JSON) from ck_tile profiling runs
- Structured JSON from generate_benchmark_data.py
- Direct parquet passthrough
"""

import json
import re
import subprocess
import hashlib
from pathlib import Path
from typing import Optional

import pandas as pd


CANONICAL_COLUMNS = [
    "op_type",
    "dtype",
    "layout",
    "arch",
    "kernel_name",
    "m",
    "n",
    "k",
    "split_k",
    "measured_tflops",
    "latency_ms",
    "bandwidth_gb_s",
    "is_valid",
    "tile_m",
    "tile_n",
    "tile_k",
    "warp_m",
    "warp_n",
    "warp_k",
    "warp_tile_m",
    "warp_tile_n",
    "warp_tile_k",
    "pipeline",
    "scheduler",
    "epilogue",
    "pad_m",
    "pad_n",
    "pad_k",
    "persistent",
    "run_id",
]


def parse_kernel_name(name: str) -> dict:
    """Extract kernel config fields from a gemm_universal kernel name.

    Name format:
        gemm_universal_{dtype}_{layout}_{pipeline}_{epilogue}_{scheduler}
        _{padM}_{padN}_{padK}_{persistent}_{tileM}x{tileN}x{tileK}
        _{warpM}x{warpN}x{warpK}_{warpTileM}x{warpTileN}x{warpTileK}
    """
    result = {}
    try:
        prefix_match = re.match(
            r"gemm_universal_(\w+?)_((?:rcr|rrr|crr|ccr))_(.*)", name
        )
        if not prefix_match:
            return result
        result["dtype"] = prefix_match.group(1)
        result["layout"] = prefix_match.group(2)
        remainder = prefix_match.group(3)

        parts = remainder.split("_")
        if len(parts) < 10:
            return result

        result["pipeline"] = parts[0]
        result["epilogue"] = parts[1]
        result["scheduler"] = parts[2]
        result["pad_m"] = parts[3] == "True"
        result["pad_n"] = parts[4] == "True"
        result["pad_k"] = parts[5] == "True"
        result["persistent"] = parts[6] == "True"

        tile_dims = parts[7].split("x")
        warp_dims = parts[8].split("x")
        warp_tile_dims = parts[9].split("x")

        result["tile_m"] = int(tile_dims[0])
        result["tile_n"] = int(tile_dims[1])
        result["tile_k"] = int(tile_dims[2])
        result["warp_m"] = int(warp_dims[0])
        result["warp_n"] = int(warp_dims[1])
        result["warp_k"] = int(warp_dims[2])
        result["warp_tile_m"] = int(warp_tile_dims[0])
        result["warp_tile_n"] = int(warp_tile_dims[1])
        result["warp_tile_k"] = int(warp_tile_dims[2])
    except (IndexError, ValueError):
        pass
    return result


def _layout_from_problem(problem: dict) -> str:
    """Derive layout shorthand (rcr/rrr/etc.) from problem JSON fields."""
    la = problem.get("layout_a", "")
    lb = problem.get("layout_b", "")
    lc = problem.get("layout_c", "")

    def _tag(s):
        s = s.lower()
        if "row" in s:
            return "r"
        if "col" in s:
            return "c"
        return "?"

    return _tag(la) + _tag(lb) + _tag(lc)


def parse_streaming_log(
    path: str | Path,
    arch: str = "unknown",
    run_id: Optional[str] = None,
    op_type: str = "gemm_universal",
) -> pd.DataFrame:
    """Parse a CK Tile streaming benchmark log into a canonical DataFrame.

    The log alternates between shape headers and JSON result blocks:
        Shape N: M=16 N=1536 K=7168 dtype=fp8 layout=rcr
        {
          "name": "gemm_universal_...",
          "problem": { ... },
          "perf_result": { "latency(ms)": ..., "tflops(TFlops)": ..., "bandwidth(GB/s)": ... }
        }
    """
    path = Path(path)
    if run_id is None:
        run_id = hashlib.md5(path.name.encode()).hexdigest()[:12]

    shape_re = re.compile(
        r"Shape\s+\d+:\s+M=(\d+)\s+N=(\d+)\s+K=(\d+)\s+dtype=(\w+)\s+layout=(\w+)"
    )

    rows = []
    current_m, current_n, current_k = 0, 0, 0
    current_dtype, current_layout = "", ""
    json_buf = []
    brace_depth = 0

    with open(path, "r") as f:
        for line in f:
            stripped = line.strip()

            shape_match = shape_re.search(stripped)
            if shape_match:
                current_m = int(shape_match.group(1))
                current_n = int(shape_match.group(2))
                current_k = int(shape_match.group(3))
                current_dtype = shape_match.group(4)
                current_layout = shape_match.group(5)
                continue

            if brace_depth == 0 and stripped.startswith("{"):
                json_buf = [stripped]
                brace_depth = stripped.count("{") - stripped.count("}")
                if brace_depth == 0:
                    raw = "\n".join(json_buf)
                    try:
                        obj = json.loads(raw)
                    except json.JSONDecodeError:
                        continue
                else:
                    continue
            elif brace_depth > 0:
                json_buf.append(stripped)
                brace_depth += stripped.count("{") - stripped.count("}")
                if brace_depth <= 0:
                    brace_depth = 0
                    raw = "\n".join(json_buf)
                    try:
                        obj = json.loads(raw)
                    except json.JSONDecodeError:
                        continue
                else:
                    continue
            else:
                continue

            # If we get here, obj was successfully parsed
            kernel_name = obj.get("name", "")
            problem = obj.get("problem", {})
            perf = obj.get("perf_result", {})

            m = problem.get("m", current_m)
            n = problem.get("n", current_n)
            k = problem.get("k", current_k)
            split_k = problem.get("split_k", 1)
            dtype = problem.get("dtype_a", current_dtype)
            layout = (
                _layout_from_problem(problem)
                if problem.get("layout_a")
                else current_layout
            )

            tflops = perf.get("tflops(TFlops)", 0.0)
            latency = perf.get("latency(ms)", 0.0)
            bandwidth = perf.get("bandwidth(GB/s)", 0.0)

            kp = parse_kernel_name(kernel_name)

            row = {
                "op_type": op_type,
                "dtype": dtype,
                "layout": layout,
                "arch": arch,
                "kernel_name": kernel_name,
                "m": m,
                "n": n,
                "k": k,
                "split_k": split_k,
                "measured_tflops": tflops,
                "latency_ms": latency,
                "bandwidth_gb_s": bandwidth,
                "is_valid": tflops > 0 and latency > 0,
                "run_id": run_id,
            }
            row.update(kp)
            rows.append(row)

    df = pd.DataFrame(rows)
    for col in CANONICAL_COLUMNS:
        if col not in df.columns:
            df[col] = None
    return df


def get_hardware_profile() -> dict:
    """Capture GPU hardware profile from rocminfo."""
    profile = {}
    try:
        result = subprocess.run(
            ["rocminfo"], capture_output=True, text=True, timeout=30
        )
        output = result.stdout

        gpu_section = False
        for line in output.split("\n"):
            line = line.strip()
            if "Device Type:" in line and "GPU" in line:
                gpu_section = True
                continue
            if gpu_section and "Device Type:" in line and "GPU" not in line:
                break
            if not gpu_section:
                continue

            if ":" not in line:
                continue
            key, _, val = line.partition(":")
            key = key.strip()
            val = val.strip()

            mapping = {
                "Name": "gfx_name",
                "Marketing Name": "marketing_name",
                "Compute Unit": "num_cus",
                "SIMDs per CU": "simds_per_cu",
                "Shader Engines": "shader_engines",
                "Shader Arrs. per Eng.": "shader_arrays_per_engine",
                "Max Clock Freq. (MHz)": "max_clock_mhz",
                "Wavefront Size": "wavefront_size",
                "Max Waves Per CU": "max_waves_per_cu",
                "Chip ID": "chip_id",
            }

            if key in mapping:
                raw = val.split("(")[0].strip()
                try:
                    profile[mapping[key]] = int(raw)
                except ValueError:
                    profile[mapping[key]] = raw

        for line in output.split("\n"):
            line = line.strip()
            if line.startswith("L1:") and "num_cus" in profile:
                raw = line.split(":")[1].strip().split("(")[0].strip()
                try:
                    profile["l1_cache_kb"] = int(raw)
                except ValueError:
                    pass
            elif line.startswith("L2:"):
                raw = line.split(":")[1].strip().split("(")[0].strip()
                try:
                    profile["l2_cache_kb"] = int(raw)
                except ValueError:
                    pass
            elif line.startswith("L3:"):
                raw = line.split(":")[1].strip().split("(")[0].strip()
                try:
                    profile["l3_cache_kb"] = int(raw)
                except ValueError:
                    pass

    except (subprocess.TimeoutExpired, FileNotFoundError):
        pass

    return profile


def load_parquet(path: str | Path) -> pd.DataFrame:
    """Load a canonical parquet dataset."""
    return pd.read_parquet(path)


def save_parquet(df: pd.DataFrame, path: str | Path):
    """Save a DataFrame in canonical parquet format."""
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    df.to_parquet(path, index=False, engine="pyarrow")


def build_training_dataset(
    data_dir: str | Path,
    op_type: str = "gemm_universal",
    dtype: str = "fp8",
) -> pd.DataFrame:
    """Load and merge all parquet files matching the given op/dtype from a directory."""
    data_dir = Path(data_dir)
    frames = []
    for f in sorted(data_dir.glob("*.parquet")):
        df = pd.read_parquet(f)
        if "op_type" in df.columns:
            df = df[df["op_type"] == op_type]
        if "dtype" in df.columns:
            df = df[df["dtype"] == dtype]
        if len(df) > 0:
            frames.append(df)
    if not frames:
        raise FileNotFoundError(
            f"No parquet files with op_type={op_type}, dtype={dtype} in {data_dir}"
        )
    return pd.concat(frames, ignore_index=True)


if __name__ == "__main__":
    import argparse
    import time

    parser = argparse.ArgumentParser(description="Parse CK Tile benchmark data")
    parser.add_argument("input", help="Input file (log or parquet)")
    parser.add_argument("--output", "-o", required=True, help="Output parquet path")
    parser.add_argument("--arch", default="gfx950", help="GPU architecture")
    parser.add_argument("--op_type", default="gemm_universal", help="Operation type")
    parser.add_argument(
        "--capture_hw",
        action="store_true",
        help="Capture hardware profile from rocminfo",
    )
    args = parser.parse_args()

    input_path = Path(args.input)

    print(f"Parsing {input_path}...")
    t0 = time.time()

    if input_path.suffix == ".parquet":
        df = load_parquet(input_path)
    else:
        df = parse_streaming_log(input_path, arch=args.arch, op_type=args.op_type)

    elapsed = time.time() - t0
    print(f"Parsed {len(df)} rows in {elapsed:.1f}s")
    print(f"  Unique shapes: {df.groupby(['m', 'n', 'k']).ngroups}")
    print(f"  Unique kernels: {df['kernel_name'].nunique()}")
    print(f"  Valid rows: {df['is_valid'].sum()} / {len(df)}")

    if df["measured_tflops"].max() > 0:
        print(
            f"  TFLOPS range: {df['measured_tflops'].min():.2f} - {df['measured_tflops'].max():.2f}"
        )

    if args.capture_hw:
        hw = get_hardware_profile()
        print(f"  Hardware profile: {hw}")
        for k, v in hw.items():
            df[f"hw_{k}"] = v

    save_parquet(df, args.output)
    print(f"Saved to {args.output}")
324 dispatcher/heuristics/dispatcher_integration.py Normal file

@@ -0,0 +1,324 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT

"""
Dispatcher integration for ML-based kernel selection.

Bridges the trained LightGBM Predictor with the CK Tile dispatcher's
kernel selection flow. Provides heuristic functions compatible with
both the Python pre-selection pattern (08_heuristics.py style) and
the C++ HeuristicFunction signature.

Name mapping between feature engine and dispatcher KernelConfig:
    Feature engine          Dispatcher KernelConfig
    ---------------------   ----------------------
    warp_m (warps/block)    wave_m
    warp_n                  wave_n
    warp_k                  wave_k
    warp_tile_m             warp_m
    warp_tile_n             warp_n
    warp_tile_k             warp_k

Usage:
    from dispatcher_integration import create_ml_heuristic

    heuristic = create_ml_heuristic("models/gemm_universal_fp8_gfx950")
    best_spec = heuristic(M=1024, N=1024, K=1024, kernel_pool=KERNEL_POOL)
"""

from dataclasses import dataclass
from pathlib import Path
from typing import Optional

from data_pipeline import parse_kernel_name
from predict import Predictor


LAYOUT_TO_DISPATCHER = {
    "rcr": ("row", "col", "row"),
    "rrr": ("row", "row", "row"),
    "crr": ("col", "row", "row"),
    "ccr": ("col", "col", "row"),
}

DTYPE_TO_C_DTYPE = {
    "fp8": "fp16",
    "fp16": "fp16",
    "bf16": "bf16",
    "fp32": "fp32",
}


@dataclass
class MLKernelSpec:
    """Kernel spec returned by the ML heuristic, compatible with the dispatcher
    example pattern. Carries both the feature-engine-space config and the
    dispatcher-space KernelConfig fields."""

    kernel_name: str
    predicted_tflops: float

    tile_m: int
    tile_n: int
    tile_k: int
    wave_m: int
    wave_n: int
    wave_k: int
    warp_m: int
    warp_n: int
    warp_k: int
    pipeline: str
    scheduler: str
    epilogue: str
    pad_m: bool
    pad_n: bool
    pad_k: bool
    persistent: bool


def kernel_config_to_feature_dict(kernel_name: str) -> dict:
    """Parse a tile-engine kernel name into a feature-engine-compatible dict.

    Returns a dict with fields matching what GemmUniversalFeatureEngine.extract()
    expects for the kernel parameter: tile_m/n/k, warp_m/n/k (warps per block),
    warp_tile_m/n/k, pipeline, scheduler, epilogue, pad_m/n/k, persistent.
    """
    parsed = parse_kernel_name(kernel_name)
    if not parsed:
        return {}
    parsed["kernel_name"] = kernel_name
    return parsed


def feature_dict_to_dispatcher_config(
    feat: dict, dtype: str = "fp8", arch: str = "gfx950"
) -> dict:
    """Convert a feature-engine kernel dict to dispatcher KernelConfig fields.

    Handles the naming inversion:
        feature engine warp_m      -> KernelConfig wave_m (warps per block)
        feature engine warp_tile_m -> KernelConfig warp_m (elements per warp)
    """
    layout = feat.get("layout", "rcr")
    la, lb, lc = LAYOUT_TO_DISPATCHER.get(layout, ("row", "col", "row"))
    c_dtype = DTYPE_TO_C_DTYPE.get(dtype, dtype)

    return {
        "dtype_a": dtype,
        "dtype_b": dtype,
        "dtype_c": c_dtype,
        "dtype_acc": "fp32",
        "layout_a": la,
        "layout_b": lb,
        "layout_c": lc,
        "tile_m": feat.get("tile_m", 128),
        "tile_n": feat.get("tile_n", 128),
        "tile_k": feat.get("tile_k", 64),
        "wave_m": feat.get("warp_m", 2),
        "wave_n": feat.get("warp_n", 2),
        "wave_k": feat.get("warp_k", 1),
        "warp_m": feat.get("warp_tile_m", 32),
        "warp_n": feat.get("warp_tile_n", 32),
        "warp_k": feat.get("warp_tile_k", 16),
        "pipeline": feat.get("pipeline", "compv3"),
        "scheduler": feat.get("scheduler", "intrawave"),
        "epilogue": feat.get("epilogue", "cshuffle"),
        "pad_m": feat.get("pad_m", True),
        "pad_n": feat.get("pad_n", True),
        "pad_k": feat.get("pad_k", True),
        "gfx_arch": arch,
    }


def feature_dict_to_ml_spec(feat: dict, predicted_tflops: float = 0.0) -> MLKernelSpec:
    """Convert a feature-engine kernel dict + prediction to an MLKernelSpec."""
    return MLKernelSpec(
        kernel_name=feat.get("kernel_name", "unknown"),
        predicted_tflops=predicted_tflops,
        tile_m=feat.get("tile_m", 128),
        tile_n=feat.get("tile_n", 128),
        tile_k=feat.get("tile_k", 64),
        wave_m=feat.get("warp_m", 2),
        wave_n=feat.get("warp_n", 2),
        wave_k=feat.get("warp_k", 1),
        warp_m=feat.get("warp_tile_m", 32),
        warp_n=feat.get("warp_tile_n", 32),
        warp_k=feat.get("warp_tile_k", 16),
        pipeline=feat.get("pipeline", "compv3"),
        scheduler=feat.get("scheduler", "intrawave"),
        epilogue=feat.get("epilogue", "cshuffle"),
        pad_m=feat.get("pad_m", False),
        pad_n=feat.get("pad_n", False),
        pad_k=feat.get("pad_k", False),
        persistent=feat.get("persistent", False),
    )


def load_kernel_pool_from_binaries(bin_dir: str | Path) -> list[dict]:
    """Discover benchmark executables and parse their names into feature dicts.

    Each executable name encodes the full kernel config. This creates the
    candidate pool for the ML heuristic without needing a registry JSON export.
    """
    bin_dir = Path(bin_dir)
    configs = []
    for exe in sorted(bin_dir.glob("benchmark_gemm_universal_*")):
        name = exe.stem.replace("benchmark_", "")
        feat = kernel_config_to_feature_dict(name)
        if feat and "tile_m" in feat:
            configs.append(feat)
    return configs


def create_ml_heuristic(
    model_dir: str | Path,
    dtype: str = "fp8",
    arch: str = "gfx950",
    layout: str = "rcr",
    kernel_pool: Optional[list[dict]] = None,
    bin_dir: Optional[str | Path] = None,
):
    """Create an ML heuristic function for kernel selection.

    Returns a callable with signature:
        (M: int, N: int, K: int) -> MLKernelSpec

    The returned function scores all candidate kernels using the trained
    LightGBM regressor and returns the best one as an MLKernelSpec.

    Parameters
    ----------
    model_dir : str or Path
        Path to trained model directory (must contain model_tflops.lgbm or
        model_tflops_log_big.lgbm and feature_spec.json).
    dtype : str
        Data type for the problem (fp8, fp16, bf16).
    arch : str
        GPU architecture (gfx942, gfx950).
    layout : str
        Matrix layout (rcr, rrr, crr, ccr).
    kernel_pool : list of dict, optional
        Pre-parsed kernel configs. If None, loads from bin_dir.
    bin_dir : str or Path, optional
        Directory with benchmark executables. Used to build kernel_pool if
        kernel_pool is not provided. Defaults to /workspace/ck_tile/bin.
    """
    model_dir = Path(model_dir)
    predictor = Predictor(model_dir)

    if kernel_pool is None:
        if bin_dir is None:
            bin_dir = Path("/workspace/ck_tile/bin")
        kernel_pool = load_kernel_pool_from_binaries(bin_dir)

    if not kernel_pool:
        raise ValueError(
            "No kernel configs found. Check bin_dir or provide kernel_pool."
|
||||
)
|
||||
|
||||
def heuristic(M: int, N: int, K: int) -> MLKernelSpec:
|
||||
problem = {
|
||||
"m": M,
|
||||
"n": N,
|
||||
"k": K,
|
||||
"dtype": dtype,
|
||||
"layout": layout,
|
||||
"split_k": 1,
|
||||
}
|
||||
|
||||
ranked = predictor.rank_kernels(problem, kernel_pool)
|
||||
|
||||
if not ranked:
|
||||
feat = kernel_pool[0]
|
||||
return feature_dict_to_ml_spec(feat, 0.0)
|
||||
|
||||
best_name, best_tflops = ranked[0]
|
||||
best_feat = next(
|
||||
(kp for kp in kernel_pool if kp.get("kernel_name") == best_name),
|
||||
kernel_pool[0],
|
||||
)
|
||||
return feature_dict_to_ml_spec(best_feat, best_tflops)
|
||||
|
||||
return heuristic
|
||||
|
||||
|
||||
def create_ranked_heuristic(
|
||||
model_dir: str | Path,
|
||||
dtype: str = "fp8",
|
||||
arch: str = "gfx950",
|
||||
layout: str = "rcr",
|
||||
kernel_pool: Optional[list[dict]] = None,
|
||||
bin_dir: Optional[str | Path] = None,
|
||||
top_k: int = 5,
|
||||
):
|
||||
"""Create an ML heuristic that returns the top-K ranked kernel specs.
|
||||
|
||||
Returns a callable with signature:
|
||||
(M: int, N: int, K: int) -> list[MLKernelSpec]
|
||||
|
||||
Useful when you want fallback options if the top-1 kernel fails to build.
|
||||
"""
|
||||
model_dir = Path(model_dir)
|
||||
predictor = Predictor(model_dir)
|
||||
|
||||
if kernel_pool is None:
|
||||
if bin_dir is None:
|
||||
bin_dir = Path("/workspace/ck_tile/bin")
|
||||
kernel_pool = load_kernel_pool_from_binaries(bin_dir)
|
||||
|
||||
name_to_feat = {kp.get("kernel_name", ""): kp for kp in kernel_pool}
|
||||
|
||||
def heuristic(M: int, N: int, K: int) -> list[MLKernelSpec]:
|
||||
problem = {
|
||||
"m": M,
|
||||
"n": N,
|
||||
"k": K,
|
||||
"dtype": dtype,
|
||||
"layout": layout,
|
||||
"split_k": 1,
|
||||
}
|
||||
|
||||
ranked = predictor.rank_kernels(problem, kernel_pool)
|
||||
results = []
|
||||
for name, tflops in ranked[:top_k]:
|
||||
feat = name_to_feat.get(name, kernel_pool[0])
|
||||
results.append(feature_dict_to_ml_spec(feat, tflops))
|
||||
return results
|
||||
|
||||
return heuristic
|
||||
|
||||
|
||||
def ml_spec_to_dispatcher_config(
|
||||
spec: MLKernelSpec, dtype: str = "fp8", arch: str = "gfx950"
|
||||
) -> dict:
|
||||
"""Convert an MLKernelSpec to a dict compatible with ctypes_utils.KernelConfig."""
|
||||
layout_a, layout_b, layout_c = "row", "col", "row"
|
||||
c_dtype = DTYPE_TO_C_DTYPE.get(dtype, dtype)
|
||||
|
||||
return {
|
||||
"dtype_a": dtype,
|
||||
"dtype_b": dtype,
|
||||
"dtype_c": c_dtype,
|
||||
"dtype_acc": "fp32",
|
||||
"layout_a": layout_a,
|
||||
"layout_b": layout_b,
|
||||
"layout_c": layout_c,
|
||||
"tile_m": spec.tile_m,
|
||||
"tile_n": spec.tile_n,
|
||||
"tile_k": spec.tile_k,
|
||||
"wave_m": spec.wave_m,
|
||||
"wave_n": spec.wave_n,
|
||||
"wave_k": spec.wave_k,
|
||||
"warp_m": spec.warp_m,
|
||||
"warp_n": spec.warp_n,
|
||||
"warp_k": spec.warp_k,
|
||||
"pipeline": spec.pipeline,
|
||||
"scheduler": spec.scheduler,
|
||||
"epilogue": spec.epilogue,
|
||||
"pad_m": spec.pad_m,
|
||||
"pad_n": spec.pad_n,
|
||||
"pad_k": spec.pad_k,
|
||||
"gfx_arch": arch,
|
||||
}
|
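As a sanity check of the selection flow above, the closure returned by `create_ml_heuristic` can be mimicked end-to-end with a stand-in scorer. Everything below is a toy sketch: `toy_score` is a hypothetical placeholder for the trained LightGBM `Predictor.rank_kernels` call, and the two-entry `kernel_pool` is made up for illustration.

```python
# Toy sketch of the create_ml_heuristic selection flow.
# toy_score stands in for Predictor.rank_kernels (NOT the real model).
kernel_pool = [
    {"kernel_name": "k_tile128", "tile_m": 128, "tile_n": 128, "tile_k": 64},
    {"kernel_name": "k_tile256", "tile_m": 256, "tile_n": 256, "tile_k": 64},
]


def toy_score(problem: dict, kernel: dict) -> float:
    # Crude proxy for the model: reward large tiles, but penalize tiles
    # that overhang the problem (mirrors the tile-efficiency features).
    eff_m = min(kernel["tile_m"], problem["m"]) / kernel["tile_m"]
    eff_n = min(kernel["tile_n"], problem["n"]) / kernel["tile_n"]
    return (eff_m * eff_n) ** 2 * (kernel["tile_m"] * kernel["tile_n"])


def toy_heuristic(M: int, N: int, K: int) -> str:
    problem = {"m": M, "n": N, "k": K}
    ranked = sorted(
        ((k["kernel_name"], toy_score(problem, k)) for k in kernel_pool),
        key=lambda t: t[1],
        reverse=True,
    )
    best_name, _best_score = ranked[0]
    return best_name


print(toy_heuristic(64, 64, 512))      # small problem -> "k_tile128"
print(toy_heuristic(4096, 4096, 512))  # large problem -> "k_tile256"
```

The real heuristic differs only in that the score comes from the trained regressor and the winner is converted to an `MLKernelSpec` via `feature_dict_to_ml_spec`.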
254 dispatcher/heuristics/evaluate.py Normal file
@@ -0,0 +1,254 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT

"""
Evaluation and reporting for CK Tile kernel performance models.

Computes:
  - Global metrics: TFLOPS efficiency (mean, p10, p50, min), R2, NDCG@1, Top-K hit rate
  - Per-slice breakdowns: by layout, shape family, K-depth regime, pipeline
  - Cross-target consistency checks
  - Feature importance analysis

Usage:
    python evaluate.py --model_dir models/gemm_universal_fp8_gfx950 --data_dir data/
"""

import argparse
import json

import numpy as np
import pandas as pd

from data_pipeline import build_training_dataset
from feature_engine import GemmUniversalFeatureEngine
from predict import Predictor
from train import compute_tflops_efficiency

def classify_shape_family(m: int, n: int, k: int) -> str:
    """Classify a GEMM shape into a family for sliced evaluation.

    Families (bucketed on M, the batch-like dimension):
    - tiny_m: M < 32 (single-token / very small batch inference)
    - small_m: 32 <= M < 256
    - medium_m: 256 <= M < 4096
    - large_m: M >= 4096
    """
    if m < 32:
        return "tiny_m"
    elif m < 256:
        return "small_m"
    elif m < 4096:
        return "medium_m"
    return "large_m"

def classify_k_regime(k: int) -> str:
    """Classify K dimension into depth regime."""
    if k < 512:
        return "shallow_k"
    elif k < 4096:
        return "medium_k"
    else:
        return "deep_k"


def evaluate_model(
    predictor: Predictor,
    df: pd.DataFrame,
    feature_engine: GemmUniversalFeatureEngine,
) -> dict:
    """Run full evaluation on a dataset. Returns a metrics dictionary.

    Parameters
    ----------
    predictor : Predictor
        Trained predictor with at least a TFLOPS model loaded.
    df : pd.DataFrame
        Benchmark data in canonical schema.
    feature_engine : GemmUniversalFeatureEngine
        Feature engine matching the trained model.

    Returns
    -------
    dict with keys: global_metrics, shape_family_metrics, k_regime_metrics,
    pipeline_metrics, per_shape_efficiency.
    """
    valid = df[df["is_valid"].fillna(False) & (df["measured_tflops"] > 0)].copy()
    valid = valid.reset_index(drop=True)

    X = feature_engine.extract_batch(valid)
    model = predictor._load_model("tflops")
    if model is None:
        raise FileNotFoundError("No TFLOPS model found")

    # Predict and apply inverse log transform if model was trained in log-space
    raw_pred = model.predict(X)
    if "tflops" in predictor._log_targets:
        valid["pred_tflops"] = np.expm1(raw_pred)
    else:
        # Clamp to non-negative even for non-log models
        valid["pred_tflops"] = np.maximum(0.0, raw_pred)

    y_true = valid["measured_tflops"].values
    y_pred = valid["pred_tflops"].values

    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)
    r2 = 1 - ss_res / max(ss_tot, 1e-10)
    rmse = np.sqrt(np.mean((y_true - y_pred) ** 2))
    mae = np.mean(np.abs(y_true - y_pred))

    eff_df = compute_tflops_efficiency(valid, "pred_tflops")

    ndcg1_count = 0
    total_shapes = 0
    topk_hits = {3: 0, 5: 0, 10: 0}

    for (m, n, k), group in valid.groupby(["m", "n", "k"]):
        if group["measured_tflops"].max() <= 0:
            continue
        total_shapes += 1
        oracle_idx = group["measured_tflops"].idxmax()
        pred_ranking = group.sort_values("pred_tflops", ascending=False).index.tolist()

        if pred_ranking[0] == oracle_idx:
            ndcg1_count += 1

        oracle_rank = pred_ranking.index(oracle_idx)
        for topk in topk_hits:
            if oracle_rank < topk:
                topk_hits[topk] += 1

    global_metrics = {
        "r2": r2,
        "rmse": rmse,
        "mae": mae,
        "num_valid_rows": len(valid),
        "num_shapes": total_shapes,
        "efficiency_mean": float(eff_df["efficiency"].mean()) if len(eff_df) > 0 else 0,
        "efficiency_p10": float(eff_df["efficiency"].quantile(0.1)) if len(eff_df) > 0 else 0,
        "efficiency_p50": float(eff_df["efficiency"].quantile(0.5)) if len(eff_df) > 0 else 0,
        "efficiency_min": float(eff_df["efficiency"].min()) if len(eff_df) > 0 else 0,
        "ndcg_at_1": ndcg1_count / max(total_shapes, 1),
        "top3_hit_rate": topk_hits[3] / max(total_shapes, 1),
        "top5_hit_rate": topk_hits[5] / max(total_shapes, 1),
        "top10_hit_rate": topk_hits[10] / max(total_shapes, 1),
    }

    def _slice_efficiency(slice_df):
        if len(slice_df) == 0:
            return {"count": 0}
        eff = compute_tflops_efficiency(slice_df, "pred_tflops")
        if len(eff) == 0:
            return {"count": 0}
        return {
            "count": len(eff),
            "mean": float(eff["efficiency"].mean()),
            "p10": float(eff["efficiency"].quantile(0.1)),
            "min": float(eff["efficiency"].min()),
        }

    valid["shape_family"] = valid.apply(
        lambda r: classify_shape_family(r["m"], r["n"], r["k"]), axis=1
    )
    valid["k_regime"] = valid["k"].apply(classify_k_regime)

    shape_family_metrics = {}
    for family, group in valid.groupby("shape_family"):
        shape_family_metrics[family] = _slice_efficiency(group)

    k_regime_metrics = {}
    for regime, group in valid.groupby("k_regime"):
        k_regime_metrics[regime] = _slice_efficiency(group)

    pipeline_metrics = {}
    if "pipeline" in valid.columns:
        for pipeline, group in valid.groupby("pipeline"):
            pipeline_metrics[str(pipeline)] = _slice_efficiency(group)

    return {
        "global_metrics": global_metrics,
        "shape_family_metrics": shape_family_metrics,
        "k_regime_metrics": k_regime_metrics,
        "pipeline_metrics": pipeline_metrics,
        "per_shape_efficiency": eff_df.to_dict(orient="records") if len(eff_df) > 0 else [],
    }

def main():
    parser = argparse.ArgumentParser(description="Evaluate CK Tile performance model")
    parser.add_argument(
        "--model_dir", required=True, help="Directory with trained models"
    )
    parser.add_argument("--data_dir", required=True, help="Directory with parquet data")
    parser.add_argument("--op", default="gemm_universal")
    parser.add_argument("--dtype", default="fp8")
    parser.add_argument("--output", "-o", help="Output JSON path for metrics")
    args = parser.parse_args()

    print(f"Loading data from {args.data_dir}...")
    df = build_training_dataset(args.data_dir, op_type=args.op, dtype=args.dtype)
    print(f"  {len(df)} rows, {df.groupby(['m', 'n', 'k']).ngroups} shapes")

    fe = GemmUniversalFeatureEngine()
    predictor = Predictor(args.model_dir, feature_engine=fe)

    print("Evaluating...")
    results = evaluate_model(predictor, df, fe)

    gm = results["global_metrics"]
    print("\nGlobal Metrics:")
    print(f"  R2: {gm['r2']:.4f}")
    print(f"  RMSE: {gm['rmse']:.2f}")
    print(f"  Efficiency Mean: {gm['efficiency_mean']:.4f}")
    print(f"  Efficiency P10: {gm['efficiency_p10']:.4f}")
    print(f"  Efficiency P50: {gm['efficiency_p50']:.4f}")
    print(f"  Efficiency Min: {gm['efficiency_min']:.4f}")
    print(f"  NDCG@1: {gm['ndcg_at_1']:.4f}")
    print(f"  Top-3 Hit Rate: {gm['top3_hit_rate']:.4f}")
    print(f"  Top-5 Hit Rate: {gm['top5_hit_rate']:.4f}")
    print(f"  Top-10 Hit Rate: {gm['top10_hit_rate']:.4f}")

    print("\nShape Family Breakdown:")
    for family, metrics in sorted(results["shape_family_metrics"].items()):
        if metrics.get("count", 0) > 0:
            print(
                f"  {family:12s}: mean={metrics['mean']:.4f} p10={metrics['p10']:.4f} min={metrics['min']:.4f} (n={metrics['count']})"
            )

    print("\nK-Depth Regime Breakdown:")
    for regime, metrics in sorted(results["k_regime_metrics"].items()):
        if metrics.get("count", 0) > 0:
            print(
                f"  {regime:12s}: mean={metrics['mean']:.4f} p10={metrics['p10']:.4f} min={metrics['min']:.4f} (n={metrics['count']})"
            )

    print("\nPipeline Breakdown:")
    for pipeline, metrics in sorted(results["pipeline_metrics"].items()):
        if metrics.get("count", 0) > 0:
            print(
                f"  {pipeline:15s}: mean={metrics['mean']:.4f} p10={metrics['p10']:.4f} (n={metrics['count']})"
            )

    if args.output:
        with open(args.output, "w") as f:
            json.dump(results, f, indent=2, default=str)
        print(f"\nFull results saved to {args.output}")


if __name__ == "__main__":
    main()
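The per-shape ranking loop in `evaluate_model` above (top-1 accuracy, reported as NDCG@1, plus top-K hit rates) can be illustrated on a handful of synthetic rows. This is a dependency-free toy sketch, not the evaluator itself; the numbers are made up for illustration.

```python
# Toy illustration of the ranking metrics in evaluate_model: per problem
# shape, check whether the oracle-best kernel (highest measured TFLOPS)
# lands at rank 1 / within the top-k of the predicted ranking.
rows = [
    # (shape_id, kernel, measured_tflops, pred_tflops)
    ("s1", "a", 90.0, 85.0),
    ("s1", "b", 120.0, 100.0),  # oracle best, predicted rank 1 -> top-1 hit
    ("s1", "c", 70.0, 60.0),
    ("s2", "a", 200.0, 150.0),  # oracle best, predicted rank 2 -> top-3 hit only
    ("s2", "b", 180.0, 190.0),
    ("s2", "c", 100.0, 90.0),
]


def ranking_metrics(rows, ks=(1, 3)):
    # Group rows by shape, as evaluate_model does via df.groupby(["m","n","k"]).
    by_shape = {}
    for shape, kernel, meas, pred in rows:
        by_shape.setdefault(shape, []).append((kernel, meas, pred))
    hits = {k: 0 for k in ks}
    for group in by_shape.values():
        oracle = max(group, key=lambda r: r[1])[0]          # best measured kernel
        ranked = [r[0] for r in sorted(group, key=lambda r: r[2], reverse=True)]
        for k in ks:
            if oracle in ranked[:k]:
                hits[k] += 1
    n = len(by_shape)
    return {k: hits[k] / n for k in ks}


print(ranking_metrics(rows))  # {1: 0.5, 3: 1.0}
```

Shape s1's oracle kernel is predicted first (a top-1 hit); s2's oracle is predicted second (a top-3 hit but a top-1 miss), hence 0.5 top-1 and 1.0 top-3.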
577 dispatcher/heuristics/feature_engine.py Normal file
@@ -0,0 +1,577 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT

"""
Feature engineering for CK Tile kernel performance prediction.

Provides a strict FeatureEngine interface with per-op subclasses.
All feature engines produce a consistent numpy array for LightGBM.
"""

import math
from abc import ABC, abstractmethod

import numpy as np
import pandas as pd


DTYPE_BYTES = {
    "fp32": 4.0,
    "fp16": 2.0,
    "bf16": 2.0,
    "fp8": 1.0,
    "bf8": 1.0,
    "int8": 1.0,
    "int4": 0.5,
}

LAYOUT_MAP = {"rcr": 0, "rrr": 1, "crr": 2, "ccr": 3}
PIPELINE_MAP = {"compv3": 0, "compv4": 1, "compv5": 2, "mem": 3, "preshufflev2": 4}
SCHEDULER_MAP = {"intrawave": 0, "interwave": 1}
EPILOGUE_MAP = {"default": 0, "cshuffle": 1}

class FeatureEngine(ABC):
    """Abstract base for per-op feature extraction."""

    @abstractmethod
    def get_feature_names(self) -> list[str]:
        """Ordered list of feature names matching the output array columns."""
        ...

    @abstractmethod
    def get_categorical_features(self) -> list[str]:
        """Feature names that should be treated as categorical by LightGBM."""
        ...

    @abstractmethod
    def extract(self, problem: dict, kernel: dict) -> np.ndarray:
        """Extract a single feature vector from a (problem, kernel) pair."""
        ...

    def extract_batch(self, df: pd.DataFrame) -> np.ndarray:
        """Generic row-by-row batch extraction from a DataFrame. Override for speed."""
        names = self.get_feature_names()
        result = np.zeros((len(df), len(names)), dtype=np.float64)
        for i in range(len(df)):
            row = df.iloc[i]
            prob = row.to_dict()
            kern = row.to_dict()
            result[i] = self.extract(prob, kern)
        return result

    def get_parameter_space(self) -> dict[str, list]:
        """Valid discrete values for each kernel parameter (for surrogate search)."""
        return {}

    def get_constraints(self) -> list:
        """Multi-param constraint functions returning True if config is valid."""
        return []

    def validate_config(self, config: dict) -> bool:
        """Check all constraints. Returns True if the config is valid."""
        ps = self.get_parameter_space()
        for k, valid_vals in ps.items():
            if k in config and config[k] not in valid_vals:
                return False
        for constraint in self.get_constraints():
            if not constraint(config):
                return False
        return True

    def project_to_valid(self, config: dict) -> dict:
        """Snap a config to the nearest valid discrete point."""
        ps = self.get_parameter_space()
        result = dict(config)
        for k, valid_vals in ps.items():
            if k not in result:
                continue
            v = result[k]
            if isinstance(valid_vals[0], (int, float)):
                result[k] = min(valid_vals, key=lambda x: abs(x - v))
            elif v not in valid_vals:
                result[k] = valid_vals[0]
        return result

class GemmUniversalFeatureEngine(FeatureEngine):
    """Feature engine for gemm_universal kernels."""

    def __init__(
        self,
        num_cus: int = 256,
        lds_capacity: int = 65536,
        max_clock_mhz: int = 2400,
        simds_per_cu: int = 4,
        shader_engines: int = 32,
        max_waves_per_cu: int = 32,
        wavefront_size: int = 64,
        l1_cache_kb: int = 32,
        l2_cache_kb: int = 4096,
        l3_cache_kb: int = 262144,
        num_xcd: int = 8,
    ):
        self._hw = {
            "num_cus": num_cus,
            "lds_capacity": lds_capacity,
            "max_clock_mhz": max_clock_mhz,
            "simds_per_cu": simds_per_cu,
            "shader_engines": shader_engines,
            "max_waves_per_cu": max_waves_per_cu,
            "wavefront_size": wavefront_size,
            "l1_cache_kb": l1_cache_kb,
            "l2_cache_kb": l2_cache_kb,
            "l3_cache_kb": l3_cache_kb,
            "num_xcd": num_xcd,
            "total_simds": num_cus * simds_per_cu,
        }

    def get_feature_names(self) -> list[str]:
        return [
            # Problem features
            "M",
            "N",
            "K",
            "split_k",
            "log2_M",
            "log2_N",
            "log2_K",
            "log2_MNK",
            "arithmetic_intensity",
            "aspect_ratio_mn",
            "aspect_ratio_mk",
            "aspect_ratio_nk",
            "layout",
            # Kernel features
            "tile_m",
            "tile_n",
            "tile_k",
            "warp_m",
            "warp_n",
            "warp_k",
            "warp_tile_m",
            "warp_tile_n",
            "warp_tile_k",
            "pipeline",
            "scheduler",
            "epilogue",
            "pad_m",
            "pad_n",
            "pad_k",
            "persistent",
            "num_warps",
            "tile_volume",
            "tile_mn",
            "lds_usage_estimate",
            "lds_usage_ratio",
            # Interaction features
            "num_tiles_m",
            "num_tiles_n",
            "num_tiles_k",
            "total_output_tiles",
            "tile_eff_m",
            "tile_eff_n",
            "tile_eff_k",
            "overall_tile_efficiency",
            "cu_utilization",
            # P0 FIX: Problem-to-tile ratio features
            "ratio_M_to_tile_m",
            "ratio_N_to_tile_n",
            "ratio_K_to_tile_k",
            "problem_smaller_than_tile_m",
            "problem_smaller_than_tile_n",
            "problem_smaller_than_tile_k",
            "any_dim_too_small",
            # P1 FIX: Padding requirement interaction features
            "needs_padding_m",
            "needs_padding_n",
            "needs_padding_k",
            "has_padding_when_needed_m",
            "has_padding_when_needed_n",
            "has_padding_when_needed_k",
            "missing_required_padding_m",
            "missing_required_padding_n",
            "missing_required_padding_k",
            "missing_any_required_padding",
            # Hardware features
            "hw_num_cus",
            "hw_simds_per_cu",
            "hw_total_simds",
            "hw_shader_engines",
            "hw_max_clock_mhz",
            "hw_max_waves_per_cu",
            "hw_wavefront_size",
            "hw_lds_capacity",
            "hw_l1_cache_kb",
            "hw_l2_cache_kb",
            "hw_l3_cache_kb",
            "hw_num_xcd",
        ]

    def get_categorical_features(self) -> list[str]:
        return ["layout", "pipeline", "scheduler", "epilogue"]

    def extract(self, problem: dict, kernel: dict) -> np.ndarray:
        M = int(problem.get("m", problem.get("M", 0)))
        N = int(problem.get("n", problem.get("N", 0)))
        K = int(problem.get("k", problem.get("K", 0)))
        split_k = int(problem.get("split_k", 1))
        dtype = str(problem.get("dtype", "fp8"))
        bpe = DTYPE_BYTES.get(dtype, 1.0)

        log2_M = math.log2(max(M, 1))
        log2_N = math.log2(max(N, 1))
        log2_K = math.log2(max(K, 1))
        log2_MNK = math.log2(max(M * N * K, 1))

        mem_bytes = (M * K + K * N + M * N) * bpe
        ai = (2.0 * M * N * K) / max(mem_bytes, 1)

        ar_mn = M / max(N, 1)
        ar_mk = M / max(K, 1)
        ar_nk = N / max(K, 1)

        layout_code = LAYOUT_MAP.get(str(problem.get("layout", "rcr")), 0)

        tile_m = int(kernel.get("tile_m", 128))
        tile_n = int(kernel.get("tile_n", 128))
        tile_k = int(kernel.get("tile_k", 64))
        warp_m = int(kernel.get("warp_m", 2))
        warp_n = int(kernel.get("warp_n", 2))
        warp_k = int(kernel.get("warp_k", 1))
        warp_tile_m = int(kernel.get("warp_tile_m", 32))
        warp_tile_n = int(kernel.get("warp_tile_n", 32))
        warp_tile_k = int(kernel.get("warp_tile_k", 16))

        pipeline_code = PIPELINE_MAP.get(str(kernel.get("pipeline", "compv4")), 0)
        scheduler_code = SCHEDULER_MAP.get(str(kernel.get("scheduler", "intrawave")), 0)
        epilogue_code = EPILOGUE_MAP.get(str(kernel.get("epilogue", "cshuffle")), 0)

        pad_m = float(kernel.get("pad_m", False))
        pad_n = float(kernel.get("pad_n", False))
        pad_k = float(kernel.get("pad_k", False))
        persistent = float(kernel.get("persistent", False))

        num_warps = warp_m * warp_n * warp_k
        tile_volume = tile_m * tile_n * tile_k
        tile_mn = tile_m * tile_n

        lds_est = (tile_m * tile_k + tile_n * tile_k) * bpe
        lds_cap = self._hw["lds_capacity"]
        if str(kernel.get("pipeline", "")).startswith("compv4"):
            lds_cap = 32768
        lds_ratio = lds_est / max(lds_cap, 1)

        num_tiles_m = math.ceil(M / max(tile_m, 1))
        num_tiles_n = math.ceil(N / max(tile_n, 1))
        num_tiles_k = math.ceil(K / max(tile_k, 1))
        total_output_tiles = num_tiles_m * num_tiles_n

        rem_m = M % tile_m if tile_m > 0 else 0
        tile_eff_m = rem_m / tile_m if rem_m > 0 else 1.0
        rem_n = N % tile_n if tile_n > 0 else 0
        tile_eff_n = rem_n / tile_n if rem_n > 0 else 1.0
        rem_k = K % tile_k if tile_k > 0 else 0
        tile_eff_k = rem_k / tile_k if rem_k > 0 else 1.0
        overall_eff = tile_eff_m * tile_eff_n * tile_eff_k

        cu_util = total_output_tiles / max(self._hw["num_cus"], 1)

        # P0 FIX: Problem-to-tile ratio features (avoid oversized tiles for tiny problems)
        ratio_M_to_tile_m = M / max(tile_m, 1)
        ratio_N_to_tile_n = N / max(tile_n, 1)
        ratio_K_to_tile_k = K / max(tile_k, 1)

        # Binary features: is problem dimension smaller than tile?
        problem_smaller_than_tile_m = float(M < tile_m)
        problem_smaller_than_tile_n = float(N < tile_n)
        problem_smaller_than_tile_k = float(K < tile_k)
        any_dim_too_small = float((M < tile_m) or (N < tile_n) or (K < tile_k))

        # P1 FIX: Padding requirement features (does this kernel have padding when needed?)
        needs_padding_m = float(M % tile_m != 0) if tile_m > 0 else 0.0
        needs_padding_n = float(N % tile_n != 0) if tile_n > 0 else 0.0
        needs_padding_k = float(K % tile_k != 0) if tile_k > 0 else 0.0

        # Interaction features: kernel has padding capability when problem needs it
        has_padding_when_needed_m = float(needs_padding_m and pad_m)
        has_padding_when_needed_n = float(needs_padding_n and pad_n)
        has_padding_when_needed_k = float(needs_padding_k and pad_k)

        # Critical feature: missing required padding (kernel will likely fail)
        missing_required_padding_m = float(needs_padding_m and not pad_m)
        missing_required_padding_n = float(needs_padding_n and not pad_n)
        missing_required_padding_k = float(needs_padding_k and not pad_k)
        missing_any_required_padding = float(
            missing_required_padding_m
            or missing_required_padding_n
            or missing_required_padding_k
        )

        hw = self._hw
        return np.array(
            [
                M,
                N,
                K,
                split_k,
                log2_M,
                log2_N,
                log2_K,
                log2_MNK,
                ai,
                ar_mn,
                ar_mk,
                ar_nk,
                layout_code,
                tile_m,
                tile_n,
                tile_k,
                warp_m,
                warp_n,
                warp_k,
                warp_tile_m,
                warp_tile_n,
                warp_tile_k,
                pipeline_code,
                scheduler_code,
                epilogue_code,
                pad_m,
                pad_n,
                pad_k,
                persistent,
                num_warps,
                tile_volume,
                tile_mn,
                lds_est,
                lds_ratio,
                num_tiles_m,
                num_tiles_n,
                num_tiles_k,
                total_output_tiles,
                tile_eff_m,
                tile_eff_n,
                tile_eff_k,
                overall_eff,
                cu_util,
                # P0 FIX: New ratio and binary features
                ratio_M_to_tile_m,
                ratio_N_to_tile_n,
                ratio_K_to_tile_k,
                problem_smaller_than_tile_m,
                problem_smaller_than_tile_n,
                problem_smaller_than_tile_k,
                any_dim_too_small,
                # P1 FIX: Padding requirement interaction features
                needs_padding_m,
                needs_padding_n,
                needs_padding_k,
                has_padding_when_needed_m,
                has_padding_when_needed_n,
                has_padding_when_needed_k,
                missing_required_padding_m,
                missing_required_padding_n,
                missing_required_padding_k,
                missing_any_required_padding,
                hw["num_cus"],
                hw["simds_per_cu"],
                hw["total_simds"],
                hw["shader_engines"],
                hw["max_clock_mhz"],
                hw["max_waves_per_cu"],
                hw["wavefront_size"],
                hw["lds_capacity"],
                hw["l1_cache_kb"],
                hw["l2_cache_kb"],
                hw["l3_cache_kb"],
                hw["num_xcd"],
            ],
            dtype=np.float64,
        )

    def extract_batch(self, df: pd.DataFrame) -> np.ndarray:
        """Vectorized batch extraction -- much faster than row-by-row."""
        n = len(df)
        names = self.get_feature_names()
        result = np.zeros((n, len(names)), dtype=np.float64)

        M = df["m"].values.astype(np.float64)
        N = df["n"].values.astype(np.float64)
        K = df["k"].values.astype(np.float64)
        split_k = df["split_k"].fillna(1).values.astype(np.float64)

        dtype_col = df["dtype"].fillna("fp8")
        bpe = dtype_col.map(DTYPE_BYTES).fillna(1.0).values

        result[:, 0] = M
        result[:, 1] = N
        result[:, 2] = K
        result[:, 3] = split_k
        result[:, 4] = np.log2(np.maximum(M, 1))
        result[:, 5] = np.log2(np.maximum(N, 1))
        result[:, 6] = np.log2(np.maximum(K, 1))
        result[:, 7] = np.log2(np.maximum(M * N * K, 1))

        mem = (M * K + K * N + M * N) * bpe
        result[:, 8] = (2.0 * M * N * K) / np.maximum(mem, 1)
        result[:, 9] = M / np.maximum(N, 1)
        result[:, 10] = M / np.maximum(K, 1)
        result[:, 11] = N / np.maximum(K, 1)

        result[:, 12] = df["layout"].map(LAYOUT_MAP).fillna(0).values

        tile_m = df["tile_m"].fillna(128).values.astype(np.float64)
        tile_n = df["tile_n"].fillna(128).values.astype(np.float64)
        tile_k = df["tile_k"].fillna(64).values.astype(np.float64)
        warp_m = df["warp_m"].fillna(2).values.astype(np.float64)
        warp_n = df["warp_n"].fillna(2).values.astype(np.float64)
        warp_k = df["warp_k"].fillna(1).values.astype(np.float64)
        warp_tile_m = df["warp_tile_m"].fillna(32).values.astype(np.float64)
        warp_tile_n = df["warp_tile_n"].fillna(32).values.astype(np.float64)
        warp_tile_k = df["warp_tile_k"].fillna(16).values.astype(np.float64)

        result[:, 13] = tile_m
        result[:, 14] = tile_n
        result[:, 15] = tile_k
        result[:, 16] = warp_m
        result[:, 17] = warp_n
        result[:, 18] = warp_k
        result[:, 19] = warp_tile_m
        result[:, 20] = warp_tile_n
        result[:, 21] = warp_tile_k

        result[:, 22] = df["pipeline"].map(PIPELINE_MAP).fillna(0).values
        result[:, 23] = df["scheduler"].map(SCHEDULER_MAP).fillna(0).values
        result[:, 24] = df["epilogue"].map(EPILOGUE_MAP).fillna(0).values

        result[:, 25] = df["pad_m"].fillna(False).astype(float).values
        result[:, 26] = df["pad_n"].fillna(False).astype(float).values
        result[:, 27] = df["pad_k"].fillna(False).astype(float).values
        result[:, 28] = df["persistent"].fillna(False).astype(float).values

        num_warps = warp_m * warp_n * warp_k
        result[:, 29] = num_warps
        result[:, 30] = tile_m * tile_n * tile_k
        result[:, 31] = tile_m * tile_n

        lds_est = (tile_m * tile_k + tile_n * tile_k) * bpe
        result[:, 32] = lds_est
        lds_cap = np.full(n, self._hw["lds_capacity"], dtype=np.float64)
        is_compv4 = df["pipeline"].fillna("").str.startswith("compv4")
        lds_cap[is_compv4] = 32768
        result[:, 33] = lds_est / np.maximum(lds_cap, 1)

        ntm = np.ceil(M / np.maximum(tile_m, 1))
        ntn = np.ceil(N / np.maximum(tile_n, 1))
        ntk = np.ceil(K / np.maximum(tile_k, 1))
        result[:, 34] = ntm
        result[:, 35] = ntn
        result[:, 36] = ntk
        result[:, 37] = ntm * ntn

        rem_m = np.mod(M, np.maximum(tile_m, 1))
        result[:, 38] = np.where(rem_m > 0, rem_m / tile_m, 1.0)
        rem_n = np.mod(N, np.maximum(tile_n, 1))
        result[:, 39] = np.where(rem_n > 0, rem_n / tile_n, 1.0)
        rem_k = np.mod(K, np.maximum(tile_k, 1))
        result[:, 40] = np.where(rem_k > 0, rem_k / tile_k, 1.0)
        result[:, 41] = result[:, 38] * result[:, 39] * result[:, 40]

        result[:, 42] = (ntm * ntn) / max(self._hw["num_cus"], 1)

        # P0 FIX: Problem-to-tile ratio features
        result[:, 43] = M / np.maximum(tile_m, 1)  # ratio_M_to_tile_m
        result[:, 44] = N / np.maximum(tile_n, 1)  # ratio_N_to_tile_n
        result[:, 45] = K / np.maximum(tile_k, 1)  # ratio_K_to_tile_k

        # Binary features: is problem smaller than tile?
        result[:, 46] = (M < tile_m).astype(float)  # problem_smaller_than_tile_m
        result[:, 47] = (N < tile_n).astype(float)  # problem_smaller_than_tile_n
        result[:, 48] = (K < tile_k).astype(float)  # problem_smaller_than_tile_k
        result[:, 49] = ((M < tile_m) | (N < tile_n) | (K < tile_k)).astype(float)  # any_dim_too_small

        # P1 FIX: Padding requirement features
        pad_m_bool = df["pad_m"].fillna(False).astype(bool).values
        pad_n_bool = df["pad_n"].fillna(False).astype(bool).values
        pad_k_bool = df["pad_k"].fillna(False).astype(bool).values

        needs_padding_m = np.mod(M, np.maximum(tile_m, 1)) != 0
        needs_padding_n = np.mod(N, np.maximum(tile_n, 1)) != 0
        needs_padding_k = np.mod(K, np.maximum(tile_k, 1)) != 0

        result[:, 50] = needs_padding_m.astype(float)
        result[:, 51] = needs_padding_n.astype(float)
        result[:, 52] = needs_padding_k.astype(float)

        # Interaction features: kernel has padding when problem needs it
        result[:, 53] = (needs_padding_m & pad_m_bool).astype(float)  # has_padding_when_needed_m
        result[:, 54] = (needs_padding_n & pad_n_bool).astype(float)  # has_padding_when_needed_n
        result[:, 55] = (needs_padding_k & pad_k_bool).astype(float)  # has_padding_when_needed_k

        # Critical feature: missing required padding
        result[:, 56] = (needs_padding_m & ~pad_m_bool).astype(float)  # missing_required_padding_m
        result[:, 57] = (needs_padding_n & ~pad_n_bool).astype(float)  # missing_required_padding_n
        result[:, 58] = (needs_padding_k & ~pad_k_bool).astype(float)  # missing_required_padding_k
        result[:, 59] = (
            (needs_padding_m & ~pad_m_bool)
            | (needs_padding_n & ~pad_n_bool)
            | (needs_padding_k & ~pad_k_bool)
        ).astype(float)  # missing_any_required_padding

        # Hardware profile features
        hw = self._hw
        result[:, 60] = hw["num_cus"]
        result[:, 61] = hw["simds_per_cu"]
        result[:, 62] = hw["total_simds"]
        result[:, 63] = hw["shader_engines"]
        result[:, 64] = hw["max_clock_mhz"]
        result[:, 65] = hw["max_waves_per_cu"]
        result[:, 66] = hw["wavefront_size"]
        result[:, 67] = hw["lds_capacity"]
        result[:, 68] = hw["l1_cache_kb"]
        result[:, 69] = hw["l2_cache_kb"]
        result[:, 70] = hw["l3_cache_kb"]
        result[:, 71] = hw["num_xcd"]

        return result

    def get_parameter_space(self) -> dict[str, list]:
|
||||
return {
|
||||
"tile_m": [32, 64, 128, 192, 256],
|
||||
"tile_n": [32, 64, 128, 192, 256],
|
||||
"tile_k": [32, 64, 128, 256],
|
||||
"warp_m": [1, 2, 4],
|
||||
"warp_n": [1, 2, 4],
|
||||
"warp_k": [1],
|
||||
"warp_tile_m": [4, 16, 32, 64],
|
||||
"warp_tile_n": [4, 16, 32, 64],
|
||||
"warp_tile_k": [8, 16, 32, 64, 128],
|
||||
"pipeline": list(PIPELINE_MAP.keys()),
|
||||
"scheduler": list(SCHEDULER_MAP.keys()),
|
||||
"epilogue": list(EPILOGUE_MAP.keys()),
|
||||
"pad_m": [True, False],
|
||||
"pad_n": [True, False],
|
||||
"pad_k": [True, False],
|
||||
"persistent": [True, False],
|
||||
}
|
||||
|
||||
def get_constraints(self) -> list:
|
||||
lds_cap = self._hw["lds_capacity"]
|
||||
|
||||
def _lds_constraint(cfg):
|
||||
tm = cfg.get("tile_m", 128)
|
||||
tn = cfg.get("tile_n", 128)
|
||||
tk = cfg.get("tile_k", 64)
|
||||
bpe = 1.0 # fp8 default
|
||||
est = (tm * tk + tn * tk) * bpe
|
||||
cap = (
|
||||
32768 if str(cfg.get("pipeline", "")).startswith("compv4") else lds_cap
|
||||
)
|
||||
return est <= cap
|
||||
|
||||
def _warp_constraint(cfg):
|
||||
wm = cfg.get("warp_m", 2)
|
||||
wn = cfg.get("warp_n", 2)
|
||||
wk = cfg.get("warp_k", 1)
|
||||
return (wm * wn * wk) in [2, 4, 8]
|
||||
|
||||
return [_lds_constraint, _warp_constraint]
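The LDS constraint above can be exercised standalone. A minimal sketch, assuming a 64 KiB LDS capacity and 1 byte per element (fp8); the `compv3` pipeline name is a placeholder:

```python
# Minimal sketch of the LDS-capacity constraint above.
# Assumptions: 64 KiB LDS, bpe=1.0 (fp8); pipeline names are illustrative.
def lds_ok(tile_m, tile_n, tile_k, pipeline="compv3", bpe=1.0, lds_cap=65536):
    # A-tile + B-tile staging bytes must fit in LDS
    est = (tile_m * tile_k + tile_n * tile_k) * bpe
    # compv4-style pipelines get only half the budget
    cap = 32768 if pipeline.startswith("compv4") else lds_cap
    return est <= cap

print(lds_ok(256, 256, 128))            # (256+256)*128*1 = 65536 bytes -> fits
print(lds_ok(256, 256, 128, "compv4"))  # 65536 > 32768 -> rejected
```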
553 dispatcher/heuristics/generate_benchmark_data.py Normal file
@@ -0,0 +1,553 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT

"""
GEMM Universal Benchmark Data Generation Script

This script generates training data for ML-based kernel selection heuristics by:
1. Reading kernel configurations from the tile engine
2. Building benchmark executables (in parallel)
3. Running benchmarks across multiple problem sizes
4. Outputting performance data in JSON format

Usage:
    python generate_benchmark_data.py \
        --build_dir /tmp/build \
        --output_dir /tmp/benchmark_data \
        --dtype fp16 \
        --layout rcr \
        --num_build_jobs 4 \
        --num_benchmark_jobs 1

Requirements:
- ROCm-capable GPU
- CK tile engine built with CMake
"""

import argparse
import json
import re
import subprocess
import time
from concurrent.futures import ProcessPoolExecutor, as_completed
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import List, Optional, Tuple


@dataclass
class KernelConfig:
    """Represents a single kernel configuration."""

    name: str
    dtype: str
    layout: str
    pipeline: str
    epilogue: str
    scheduler: str
    pad_m: bool
    pad_n: bool
    pad_k: bool
    persistent: bool
    tile_m: int
    tile_n: int
    tile_k: int
    warp_m: int
    warp_n: int
    warp_k: int
    warp_tile_m: int
    warp_tile_n: int
    warp_tile_k: int

    @classmethod
    def from_kernel_name(cls, name: str, dtype: str, layout: str) -> "KernelConfig":
        """Parse a kernel name to extract its configuration."""
        # Format: gemm_universal_{dtype}_{layout}_{pipeline}_{epilogue}_{scheduler}_{padM}_{padN}_{padK}_{persistent}_{tile_config}
        # tile_config: {tile_m}x{tile_n}x{tile_k}_{warp_m}x{warp_n}x{warp_k}_{warp_tile_m}x{warp_tile_n}x{warp_tile_k}
        prefix = f"gemm_universal_{dtype}_{layout}_"
        trait_and_tile = name[len(prefix):]
        trait_parts = trait_and_tile.split("_")

        pipeline = trait_parts[0]
        epilogue = trait_parts[1]
        scheduler = trait_parts[2]
        pad_m = trait_parts[3] == "True"
        pad_n = trait_parts[4] == "True"
        pad_k = trait_parts[5] == "True"
        persistent = trait_parts[6] == "True"

        # Parse tile config
        tile_dims = trait_parts[7].split("x")
        warp_dims = trait_parts[8].split("x")
        warp_tile_dims = trait_parts[9].split("x")

        return cls(
            name=name,
            dtype=dtype,
            layout=layout,
            pipeline=pipeline,
            epilogue=epilogue,
            scheduler=scheduler,
            pad_m=pad_m,
            pad_n=pad_n,
            pad_k=pad_k,
            persistent=persistent,
            tile_m=int(tile_dims[0]),
            tile_n=int(tile_dims[1]),
            tile_k=int(tile_dims[2]),
            warp_m=int(warp_dims[0]),
            warp_n=int(warp_dims[1]),
            warp_k=int(warp_dims[2]),
            warp_tile_m=int(warp_tile_dims[0]),
            warp_tile_n=int(warp_tile_dims[1]),
            warp_tile_k=int(warp_tile_dims[2]),
        )
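For illustration, the name-parsing logic above applied to a hypothetical kernel name in the documented format; the trait values (`compv3`, `cshuffle`, `intrawave`) are placeholders, not a guaranteed-existing kernel:

```python
# Hypothetical kernel name in the documented format (placeholder trait values)
name = ("gemm_universal_fp16_rcr_compv3_cshuffle_intrawave_"
        "True_False_True_False_128x128x64_2x2x1_32x32x16")
prefix = "gemm_universal_fp16_rcr_"
trait_parts = name[len(prefix):].split("_")

pipeline, epilogue, scheduler = trait_parts[:3]
pad_m, pad_n, pad_k = (p == "True" for p in trait_parts[3:6])
tile_m, tile_n, tile_k = (int(x) for x in trait_parts[7].split("x"))

print(pipeline, scheduler, pad_m, (tile_m, tile_n, tile_k))
```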

@dataclass
class BenchmarkResult:
    """Result of a single benchmark run."""

    kernel_name: str
    m: int
    n: int
    k: int
    avg_time_ms: float
    tflops: float
    is_valid: bool
    error: Optional[str] = None


@dataclass
class ProblemSize:
    """GEMM problem dimensions."""

    m: int
    n: int
    k: int


def get_problem_sizes() -> List[ProblemSize]:
    """
    Generate diverse problem sizes for benchmarking.

    Includes:
    - Square matrices (powers of 2)
    - Rectangular matrices (common in ML)
    - LLM-specific sizes (attention, MLP)
    - Edge cases (small, very large)
    """
    sizes = []

    # Powers of 2 (square)
    for p in [6, 7, 8, 9, 10, 11, 12, 13]:  # 64 to 8192
        dim = 2**p
        sizes.append(ProblemSize(dim, dim, dim))

    # Common ML sizes (batch x hidden)
    ml_sizes = [
        (1, 4096, 4096),  # Single-token inference
        (8, 4096, 4096),  # Small batch
        (32, 4096, 4096),  # Medium batch
        (128, 4096, 4096),  # Large batch
        (1, 4096, 11008),  # LLaMA MLP up-projection
        (1, 11008, 4096),  # LLaMA MLP down-projection
        (32, 4096, 11008),
        (32, 11008, 4096),
        (1, 8192, 8192),  # Large model
        (32, 8192, 8192),
        (1, 8192, 28672),  # LLaMA-70B MLP
        (32, 8192, 28672),
    ]
    for m, n, k in ml_sizes:
        sizes.append(ProblemSize(m, n, k))

    # Rectangular matrices
    rect_sizes = [
        (1024, 4096, 1024),
        (4096, 1024, 4096),
        (2048, 8192, 2048),
        (256, 256, 8192),  # Tall K
        (8192, 8192, 256),  # Short K
    ]
    for m, n, k in rect_sizes:
        sizes.append(ProblemSize(m, n, k))

    # Remove duplicates while preserving order
    seen = set()
    unique_sizes = []
    for s in sizes:
        key = (s.m, s.n, s.k)
        if key not in seen:
            seen.add(key)
            unique_sizes.append(s)

    return unique_sizes


def load_kernel_list(build_dir: Path, dtype: str, layout: str) -> List[KernelConfig]:
    """Load kernel configurations from the tile engine build."""
    kernel_list_path = (
        build_dir
        / "tile_engine"
        / "ops"
        / "gemm"
        / "gemm_universal"
        / dtype
        / layout
        / "gemm_universal_kernel_list.txt"
    )

    if not kernel_list_path.exists():
        raise FileNotFoundError(f"Kernel list not found: {kernel_list_path}")

    kernels = []
    with open(kernel_list_path, "r") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            # Format: kernel_name|tile_config|trait_combo
            parts = line.split("|")
            kernel_name = parts[0]
            kernels.append(KernelConfig.from_kernel_name(kernel_name, dtype, layout))

    return kernels


def build_kernel(build_dir: Path, kernel: KernelConfig) -> Tuple[str, bool, str]:
    """
    Build a single kernel benchmark executable.

    Returns: (kernel_name, success, error_message)
    """
    target_name = f"benchmark_{kernel.name}"

    try:
        result = subprocess.run(
            ["ninja", "-j1", target_name],
            cwd=build_dir,
            capture_output=True,
            text=True,
            timeout=300,  # 5 minute timeout
        )

        if result.returncode != 0:
            return (kernel.name, False, result.stderr[:500])

        return (kernel.name, True, "")
    except subprocess.TimeoutExpired:
        return (kernel.name, False, "Build timeout")
    except Exception as e:
        return (kernel.name, False, str(e))


def run_benchmark(
    build_dir: Path,
    kernel: KernelConfig,
    problem: ProblemSize,
    warmup: int = 10,
    repeat: int = 50,
) -> BenchmarkResult:
    """Run the benchmark for a single kernel and problem size."""
    exe_path = build_dir / "bin" / f"benchmark_{kernel.name}"

    if not exe_path.exists():
        return BenchmarkResult(
            kernel_name=kernel.name,
            m=problem.m,
            n=problem.n,
            k=problem.k,
            avg_time_ms=0,
            tflops=0,
            is_valid=False,
            error="Executable not found",
        )

    try:
        result = subprocess.run(
            [
                str(exe_path),
                f"-m={problem.m}",
                f"-n={problem.n}",
                f"-k={problem.k}",
                f"-warmup={warmup}",
                f"-repeat={repeat}",
                "-verify=0",
                "-json_output=true",
            ],
            capture_output=True,
            text=True,
            timeout=120,
        )

        if result.returncode != 0:
            # Try to surface the error message
            error = result.stderr[:200] if result.stderr else result.stdout[:200]
            return BenchmarkResult(
                kernel_name=kernel.name,
                m=problem.m,
                n=problem.n,
                k=problem.k,
                avg_time_ms=0,
                tflops=0,
                is_valid=False,
                error=error,
            )

        # Parse JSON output
        output = result.stdout.strip()

        # Try to find JSON in the output
        json_match = re.search(r"\{.*\}", output, re.DOTALL)
        if json_match:
            data = json.loads(json_match.group())
            # Extract from the nested perf_result object
            perf = data.get("perf_result", {})
            avg_time_ms = perf.get("latency(ms)", 0)
            tflops = perf.get("tflops(TFlops)", 0)

            return BenchmarkResult(
                kernel_name=kernel.name,
                m=problem.m,
                n=problem.n,
                k=problem.k,
                avg_time_ms=avg_time_ms,
                tflops=tflops,
                is_valid=True,
            )
        else:
            # Fall back to text parsing:
            # look for patterns like "avg_time: X ms" or "tflops: Y"
            avg_time = 0.0
            tflops = 0.0

            time_match = re.search(
                r"(?:avg[_\s]?time|latency)[:\s]+(\d+\.?\d*)\s*(?:ms)?", output, re.I
            )
            if time_match:
                avg_time = float(time_match.group(1))

            tflops_match = re.search(r"tflops[:\s]+(\d+\.?\d*)", output, re.I)
            if tflops_match:
                tflops = float(tflops_match.group(1))

            # Compute TFLOPS if not provided
            if tflops == 0 and avg_time > 0:
                flops = 2.0 * problem.m * problem.n * problem.k
                tflops = flops / (avg_time * 1e-3) / 1e12

            return BenchmarkResult(
                kernel_name=kernel.name,
                m=problem.m,
                n=problem.n,
                k=problem.k,
                avg_time_ms=avg_time,
                tflops=tflops,
                is_valid=avg_time > 0,
                error=None if avg_time > 0 else "Could not parse output",
            )

    except subprocess.TimeoutExpired:
        return BenchmarkResult(
            kernel_name=kernel.name,
            m=problem.m,
            n=problem.n,
            k=problem.k,
            avg_time_ms=0,
            tflops=0,
            is_valid=False,
            error="Benchmark timeout",
        )
    except Exception as e:
        return BenchmarkResult(
            kernel_name=kernel.name,
            m=problem.m,
            n=problem.n,
            k=problem.k,
            avg_time_ms=0,
            tflops=0,
            is_valid=False,
            error=str(e),
        )


def main():
    parser = argparse.ArgumentParser(
        description="Generate GEMM benchmark data for ML training"
    )
    parser.add_argument(
        "--build_dir", type=str, default="/tmp/build", help="CK build directory"
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="/tmp/benchmark_data",
        help="Output directory for benchmark results",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="fp16",
        choices=["fp16", "fp8", "bf16", "bf8"],
        help="Data type to benchmark",
    )
    parser.add_argument(
        "--layout",
        type=str,
        default="rcr",
        choices=["rcr", "rrr", "crr", "ccr"],
        help="Matrix layout to benchmark",
    )
    parser.add_argument(
        "--num_build_jobs", type=int, default=4, help="Number of parallel build jobs"
    )
    parser.add_argument(
        "--num_benchmark_jobs",
        type=int,
        default=1,
        help="Number of parallel benchmark jobs (use 1 for accurate timing)",
    )
    parser.add_argument(
        "--max_kernels",
        type=int,
        default=None,
        help="Maximum number of kernels to benchmark (for testing)",
    )
    parser.add_argument(
        "--skip_build",
        action="store_true",
        help="Skip building and only run benchmarks",
    )
    parser.add_argument(
        "--warmup", type=int, default=10, help="Number of warmup iterations"
    )
    parser.add_argument(
        "--repeat", type=int, default=50, help="Number of benchmark iterations"
    )

    args = parser.parse_args()

    build_dir = Path(args.build_dir)
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Load kernel configurations
    print(f"Loading kernel list for {args.dtype}/{args.layout}...")
    kernels = load_kernel_list(build_dir, args.dtype, args.layout)
    print(f"Found {len(kernels)} kernel configurations")

    if args.max_kernels:
        kernels = kernels[: args.max_kernels]
        print(f"Limiting to {len(kernels)} kernels")

    # Build kernels
    if not args.skip_build:
        print(
            f"\nBuilding {len(kernels)} kernels with {args.num_build_jobs} parallel jobs..."
        )
        build_results = {"success": 0, "failed": 0, "failed_kernels": []}

        with ProcessPoolExecutor(max_workers=args.num_build_jobs) as executor:
            futures = {executor.submit(build_kernel, build_dir, k): k for k in kernels}

            for i, future in enumerate(as_completed(futures)):
                kernel_name, success, error = future.result()
                if success:
                    build_results["success"] += 1
                else:
                    build_results["failed"] += 1
                    build_results["failed_kernels"].append(
                        {"name": kernel_name, "error": error}
                    )

                if (i + 1) % 10 == 0:
                    print(
                        f"  Built {i + 1}/{len(kernels)} ({build_results['success']} success, {build_results['failed']} failed)"
                    )

        print(
            f"\nBuild complete: {build_results['success']} success, {build_results['failed']} failed"
        )

        # Save build results
        with open(output_dir / "build_results.json", "w") as f:
            json.dump(build_results, f, indent=2)

    # Get problem sizes
    problem_sizes = get_problem_sizes()
    print(f"\nBenchmarking {len(problem_sizes)} problem sizes...")

    # Run benchmarks
    all_results = []
    total_benchmarks = len(kernels) * len(problem_sizes)
    completed = 0

    print(f"Total benchmarks to run: {total_benchmarks}")

    for kernel in kernels:
        kernel_results = {
            "kernel_config": asdict(kernel),
            "benchmarks": [],
        }

        for problem in problem_sizes:
            result = run_benchmark(
                build_dir,
                kernel,
                problem,
                warmup=args.warmup,
                repeat=args.repeat,
            )
            kernel_results["benchmarks"].append(asdict(result))
            completed += 1

            if completed % 100 == 0:
                print(f"  Progress: {completed}/{total_benchmarks} benchmarks complete")

        all_results.append(kernel_results)

        # Save intermediate results
        intermediate_file = (
            output_dir / f"benchmark_results_{args.dtype}_{args.layout}_partial.json"
        )
        with open(intermediate_file, "w") as f:
            json.dump(all_results, f, indent=2)

    # Save final results
    final_file = output_dir / f"benchmark_results_{args.dtype}_{args.layout}.json"
    with open(final_file, "w") as f:
        json.dump(
            {
                "metadata": {
                    "dtype": args.dtype,
                    "layout": args.layout,
                    "num_kernels": len(kernels),
                    "num_problems": len(problem_sizes),
                    "warmup": args.warmup,
                    "repeat": args.repeat,
                    "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
                },
                "problem_sizes": [asdict(p) for p in problem_sizes],
                "results": all_results,
            },
            f,
            indent=2,
        )

    print(f"\nResults saved to {final_file}")

    # Print summary
    valid_count = sum(
        1 for kr in all_results for br in kr["benchmarks"] if br["is_valid"]
    )
    print(f"Valid benchmarks: {valid_count}/{total_benchmarks}")


if __name__ == "__main__":
    main()
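The TFLOPS fallback in `run_benchmark` uses the standard dense-GEMM operation count of 2·M·N·K (one multiply plus one add per multiply-accumulate). A standalone sketch of that arithmetic:

```python
def gemm_tflops(m, n, k, avg_time_ms):
    # Standard dense-GEMM FLOP count: 2*M*N*K (multiply + add per MAC)
    flops = 2.0 * m * n * k
    # Convert ms -> s, then FLOP/s -> TFLOP/s
    return flops / (avg_time_ms * 1e-3) / 1e12

# A 4096^3 GEMM finishing in 1 ms sustains ~137.4 TFLOPS
print(round(gemm_tflops(4096, 4096, 4096, 1.0), 1))
```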
166 dispatcher/heuristics/generate_edge_dims.py Normal file
@@ -0,0 +1,166 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT

"""
Supplementary edge-case benchmark generator for N=1 and K=1 dimensions.

These shapes represent vector-matrix multiply (N=1), rank-1 updates (K=1),
and other degenerate GEMM cases that stress tile efficiency and padding logic.
"""

import json
import subprocess
import sys
from pathlib import Path


def generate_edge_shapes():
    """Generate shapes with N=1, K=1, and other single-dimension edge cases."""
    shapes = set()

    # --- N=1: vector-matrix multiply / single output column ---
    for m in [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]:
        for k in [1, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 7168, 8192]:
            shapes.add((m, 1, k))

    # --- K=1: rank-1 update / outer product ---
    for m in [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]:
        for n in [1, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 7168, 8192]:
            shapes.add((m, n, 1))

    # --- M=1, N=1: dot product ---
    for k in [1, 16, 64, 256, 1024, 4096, 8192]:
        shapes.add((1, 1, k))

    # --- M=1, K=1: scalar-vector ---
    for n in [1, 16, 64, 256, 1024, 4096, 8192]:
        shapes.add((1, n, 1))

    # --- N=1, K=1: scalar-vector ---
    for m in [1, 16, 64, 256, 1024, 4096, 8192]:
        shapes.add((m, 1, 1))

    # --- All ones: 1x1x1 ---
    shapes.add((1, 1, 1))

    # --- Small N (2-16) ---
    for m in [64, 256, 1024, 4096]:
        for n in [2, 3, 4, 7, 8, 15, 16]:
            for k in [64, 256, 1024, 4096]:
                shapes.add((m, n, k))

    # --- Small K (2-16) ---
    for m in [64, 256, 1024, 4096]:
        for n in [64, 256, 1024, 4096]:
            for k in [2, 3, 4, 7, 8, 15, 16]:
                shapes.add((m, n, k))

    return sorted(shapes)


def run_shapes(bin_dir, shapes, out_file, warmup=3, repeat=10):
    """Run all kernels against the shapes, writing a streaming log."""
    executables = sorted(Path(bin_dir).glob("benchmark_gemm_universal_fp8_rcr_*"))
    if not executables:
        print(f"ERROR: No executables found in {bin_dir}", file=sys.stderr)
        return 0

    total = 0
    for idx, (m, n, k) in enumerate(shapes):
        out_file.write("\n========================================\n")
        out_file.write(f"Shape {idx + 1}: M={m} N={n} K={k} dtype=fp8 layout=rcr\n")
        out_file.write("========================================\n")
        out_file.write(f"Found {len(executables)} kernels\n")
        out_file.flush()

        for exe in executables:
            try:
                result = subprocess.run(
                    [
                        str(exe),
                        f"-m={m}",
                        f"-n={n}",
                        f"-k={k}",
                        f"-warmup={warmup}",
                        f"-repeat={repeat}",
                        "-verify=0",
                    ],
                    capture_output=True,
                    text=True,
                    timeout=60,
                )
                output = result.stdout
                json_start = output.find("{")
                json_end = output.rfind("}") + 1
                if json_start >= 0 and json_end > json_start:
                    json_block = output[json_start:json_end]
                    try:
                        json.loads(json_block)
                        out_file.write(json_block + "\n")
                        total += 1
                    except json.JSONDecodeError:
                        pass
            except Exception:
                # Covers subprocess.TimeoutExpired and any launch failure
                pass

        out_file.flush()
        print(
            f"  Shape {idx + 1}/{len(shapes)}: M={m} N={n} K={k}",
            file=sys.stderr,
            flush=True,
        )

    return total
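The JSON-block extraction in `run_shapes` slices from the first `{` to the last `}` of the subprocess stdout and validates the slice with `json.loads` before logging it. A minimal sketch on a made-up benchmark log (the field names mirror the `perf_result` schema used elsewhere in this PR):

```python
import json

# Made-up benchmark log for illustration; real output varies
output = 'warmup done\n{"perf_result": {"tflops(TFlops)": 91.2}}\nexit 0'

# Slice from first "{" to last "}" as run_shapes does
json_start = output.find("{")
json_end = output.rfind("}") + 1
block = output[json_start:json_end]

# Validate before use; raises json.JSONDecodeError on a bad slice
data = json.loads(block)
print(data["perf_result"]["tflops(TFlops)"])
```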

if __name__ == "__main__":
    bin_dir = "/workspace/ck_tile/bin"
    out_dir = Path("data/edge_dims")
    out_dir.mkdir(parents=True, exist_ok=True)

    shapes = generate_edge_shapes()
    print(f"Generated {len(shapes)} edge-case shapes", file=sys.stderr, flush=True)

    n1_count = sum(1 for m, n, k in shapes if n == 1)
    k1_count = sum(1 for m, n, k in shapes if k == 1)
    both1 = sum(1 for m, n, k in shapes if n == 1 and k == 1)
    small_n = sum(1 for m, n, k in shapes if 2 <= n <= 16)
    small_k = sum(1 for m, n, k in shapes if 2 <= k <= 16)
    print(
        f"  N=1: {n1_count}, K=1: {k1_count}, both=1: {both1}",
        file=sys.stderr,
        flush=True,
    )
    print(
        f"  Small N(2-16): {small_n}, Small K(2-16): {small_k}",
        file=sys.stderr,
        flush=True,
    )

    batch_size = 25
    total = 0
    batch_idx = 0
    for i in range(0, len(shapes), batch_size):
        batch = shapes[i : i + batch_size]
        batch_idx += 1
        out_path = out_dir / f"edge_dims_batch_{batch_idx:03d}.log"
        print(
            f"\nBatch {batch_idx}: shapes {i + 1}-{i + len(batch)} -> {out_path}",
            file=sys.stderr,
            flush=True,
        )

        with open(out_path, "w") as f:
            f.write(f"CK Tile Edge Dims Benchmark Batch {batch_idx}\n")
            f.write("GPU ID: 0\nImplementation: gemm_universal\n\n")
            count = run_shapes(bin_dir, batch, f, warmup=3, repeat=10)
        total += count

        print(f"  Batch {batch_idx} done: {count} results", file=sys.stderr, flush=True)

    print(
        f"\nTotal: {total} benchmarks across {len(shapes)} shapes",
        file=sys.stderr,
        flush=True,
    )
289 dispatcher/heuristics/generate_wide_coverage.py Normal file
@@ -0,0 +1,289 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Wide-coverage benchmark data generator.
|
||||
|
||||
Generates benchmark results for hundreds of diverse GEMM shapes across all
|
||||
corner cases: skinny M, tall N, deep K, M=1, prime dimensions, power-of-2,
|
||||
LLM inference shapes, training shapes, and edge cases.
|
||||
|
||||
Runs all 4608 kernels in /workspace/ck_tile/bin/ against each shape and
|
||||
writes streaming log output parseable by data_pipeline.py.
|
||||
|
||||
Usage:
|
||||
python3 generate_wide_coverage.py --bin_dir /workspace/ck_tile/bin \
|
||||
--out_dir data/ --batch_size 20 --warmup 3 --repeat 10
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def generate_shape_list():
|
||||
"""Generate a comprehensive list of (M, N, K) shapes covering all corner cases.
|
||||
|
||||
Categories:
|
||||
1. M=1 (single token inference) -- the hardest case
|
||||
2. Tiny M (2-16) -- small batch inference
|
||||
3. Small M (32-128) -- medium batch
|
||||
4. Medium M (256-2048) -- large batch / training
|
||||
5. Large M (4096-20480) -- very large batch
|
||||
6. Square shapes (powers of 2)
|
||||
7. Skinny M, tall N (M << N)
|
||||
8. Tall M, skinny N (M >> N)
|
||||
9. Deep K (K >> M, N) -- compute-bound
|
||||
10. Shallow K (K << M, N) -- memory-bound
|
||||
11. Prime dimensions -- worst-case for tiling
|
||||
12. LLM-specific shapes (DeepSeek, LLaMA, etc.)
|
||||
13. Non-power-of-2 common sizes
|
||||
"""
|
||||
shapes = set()
|
||||
|
||||
# --- 1. M=1 (single token) across various N, K ---
|
||||
for n in [512, 1024, 1536, 2048, 3072, 4096, 4608, 7168, 8192, 11008, 14336, 28672]:
|
||||
for k in [256, 512, 1024, 1536, 2048, 2304, 4096, 7168, 8192]:
|
||||
shapes.add((1, n, k))
|
||||
|
||||
# --- 2. Tiny M (2-16) ---
|
||||
for m in [2, 4, 8, 16]:
|
||||
for n in [512, 1536, 4096, 7168]:
|
||||
for k in [256, 1024, 4096, 7168]:
|
||||
shapes.add((m, n, k))
|
||||
|
||||
# --- 3. Small M (32-128) ---
|
||||
for m in [32, 48, 64, 96, 128]:
|
||||
for n in [512, 1536, 4096, 7168, 8192]:
|
||||
for k in [256, 512, 2048, 4096, 7168]:
|
||||
shapes.add((m, n, k))
|
||||
|
||||
# --- 4. Medium M (256-2048) ---
|
||||
for m in [256, 384, 512, 768, 1024, 1536, 2048]:
|
||||
for n in [512, 1536, 4096, 7168]:
|
||||
for k in [256, 1024, 2048, 4096, 7168]:
|
||||
shapes.add((m, n, k))
|
||||
|
||||
# --- 5. Large M (4096-20480) ---
|
||||
for m in [4096, 8192, 12288, 16384, 20480]:
|
||||
for n in [512, 1536, 4096, 7168]:
|
||||
for k in [256, 1024, 2048, 7168]:
|
||||
shapes.add((m, n, k))
|
||||
|
||||
# --- 6. Square shapes (powers of 2) ---
|
||||
for p in range(5, 14): # 32 to 8192
|
||||
d = 2**p
|
||||
shapes.add((d, d, d))
|
||||
|
||||
# --- 7. Skinny M, tall N ---
|
||||
for m in [1, 4, 16, 64]:
|
||||
for n in [8192, 16384, 28672]:
|
||||
for k in [1024, 4096, 8192]:
|
||||
shapes.add((m, n, k))
|
||||
|
||||
# --- 8. Tall M, skinny N ---
|
||||
for m in [4096, 8192, 16384]:
|
||||
for n in [32, 64, 128, 256]:
|
||||
for k in [1024, 4096]:
|
||||
shapes.add((m, n, k))
|
||||
|
||||
# --- 9. Deep K (K >> M, N) ---
|
||||
for m in [16, 64, 256]:
|
||||
for n in [16, 64, 256]:
|
||||
for k in [4096, 8192, 16384, 32768]:
|
||||
shapes.add((m, n, k))
|
||||
|
||||
# --- 10. Shallow K (K << M, N) ---
|
||||
for m in [1024, 4096, 8192]:
|
||||
for n in [1024, 4096, 8192]:
|
||||
for k in [16, 32, 64, 128]:
|
||||
shapes.add((m, n, k))
|
||||
|
||||
# --- 11. Prime dimensions ---
|
||||
primes = [17, 31, 37, 127, 251, 509, 1021, 2039, 4093]
|
||||
for p in primes:
|
||||
shapes.add((p, p, p))
|
||||
for p in primes[:5]:
|
||||
shapes.add((p, 4096, 4096))
|
||||
shapes.add((4096, p, 4096))
|
||||
shapes.add((4096, 4096, p))
|
||||
|
||||
# --- 12. LLM-specific shapes ---
|
||||
llm_shapes = [
|
||||
# DeepSeek MoE
|
||||
(1, 1536, 7168),
|
||||
(1, 4608, 7168),
|
||||
(1, 7168, 2048),
|
||||
(1, 7168, 2304),
|
||||
(1, 7168, 256),
|
||||
(1, 576, 7168),
|
||||
(1, 512, 7168),
|
||||
(1, 3072, 1536),
|
||||
# LLaMA-7B
|
||||
(1, 4096, 4096),
|
||||
(32, 4096, 4096),
|
||||
(128, 4096, 4096),
|
||||
(1, 4096, 11008),
|
||||
(32, 4096, 11008),
|
||||
(1, 11008, 4096),
|
||||
(32, 11008, 4096),
|
||||
# LLaMA-70B
|
||||
(1, 8192, 8192),
|
||||
(32, 8192, 8192),
|
||||
(128, 8192, 8192),
|
||||
(1, 8192, 28672),
|
||||
(32, 8192, 28672),
|
||||
(1, 28672, 8192),
|
||||
# GPT-style attention
|
||||
(128, 128, 64),
|
||||
(128, 128, 128),
|
||||
(256, 256, 64),
|
||||
(512, 512, 64),
|
||||
(1024, 1024, 64),
|
||||
(2048, 2048, 64),
|
||||
]
|
||||
for s in llm_shapes:
|
||||
shapes.add(s)
|
||||
|
||||
# --- 13. Non-power-of-2 common sizes ---
|
||||
for m in [48, 96, 192, 384, 576, 768, 1152, 1536, 2304, 3072, 4608, 6144]:
|
||||
shapes.add((m, m, m))
|
||||
shapes.add((m, 4096, 4096))
|
||||
|
||||
sorted_shapes = sorted(shapes)
|
||||
return sorted_shapes
|
||||
|
||||
|
||||
def run_shape_batch(bin_dir, shapes, out_file, warmup=3, repeat=10):
    """Run all kernels against a batch of shapes, writing streaming log output."""
    executables = sorted(Path(bin_dir).glob("benchmark_gemm_universal_fp8_rcr_*"))
    if not executables:
        print(f"ERROR: No executables found in {bin_dir}", file=sys.stderr)
        return 0

    total_benchmarks = 0

    for shape_idx, (m, n, k) in enumerate(shapes):
        out_file.write("\n========================================\n")
        out_file.write(
            f"Shape {shape_idx + 1}: M={m} N={n} K={k} dtype=fp8 layout=rcr\n"
        )
        out_file.write("========================================\n")
        out_file.write(f"Found {len(executables)} kernels\n")
        out_file.flush()

        for exe in executables:
            try:
                result = subprocess.run(
                    [
                        str(exe),
                        f"-m={m}",
                        f"-n={n}",
                        f"-k={k}",
                        f"-warmup={warmup}",
                        f"-repeat={repeat}",
                        "-verify=0",
                    ],
                    capture_output=True,
                    text=True,
                    timeout=60,
                )
                output = result.stdout
                # Extract JSON block from output
                json_start = output.find("{")
                json_end = output.rfind("}") + 1
                if json_start >= 0 and json_end > json_start:
                    json_block = output[json_start:json_end]
                    try:
                        json.loads(json_block)
                        out_file.write(json_block + "\n")
                        total_benchmarks += 1
                    except json.JSONDecodeError:
                        pass
            except Exception:  # includes subprocess.TimeoutExpired
                pass

        out_file.flush()
        elapsed_kernels = len(executables)
        print(
            f"  Shape {shape_idx + 1}/{len(shapes)}: M={m} N={n} K={k} "
            f"({elapsed_kernels} kernels)",
            file=sys.stderr,
            flush=True,
        )

    return total_benchmarks
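The brace-scan above, which pulls the JSON result block out of mixed benchmark stdout, can be sketched in isolation (hypothetical helper name; it assumes the output contains at most one well-formed JSON object):

```python
import json

def extract_json_block(output):
    """Return the outermost {...} span of `output` if it parses as JSON, else None."""
    start = output.find("{")
    end = output.rfind("}") + 1
    if start >= 0 and end > start:
        block = output[start:end]
        try:
            json.loads(block)
            return block
        except json.JSONDecodeError:
            return None
    return None

print(extract_json_block('log noise {"tflops": 12.5} trailer'))  # → {"tflops": 12.5}
```

Note the validating `json.loads` call: because the scan takes the outermost brace pair, stray braces in surrounding log text would otherwise be written to the data file as a corrupt record.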


def main():
    parser = argparse.ArgumentParser(
        description="Generate wide-coverage benchmark data"
    )
    parser.add_argument(
        "--bin_dir",
        default="/workspace/ck_tile/bin",
        help="Directory with benchmark executables",
    )
    parser.add_argument("--out_dir", default="data", help="Output directory")
    parser.add_argument(
        "--batch_size", type=int, default=25, help="Shapes per output file"
    )
    parser.add_argument("--warmup", type=int, default=3)
    parser.add_argument("--repeat", type=int, default=10)
    parser.add_argument(
        "--max_shapes", type=int, default=None, help="Limit total shapes (for testing)"
    )
    args = parser.parse_args()

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    shapes = generate_shape_list()
    if args.max_shapes:
        shapes = shapes[: args.max_shapes]

    print(f"Generated {len(shapes)} unique shapes", file=sys.stderr, flush=True)
    print(f"Bin dir: {args.bin_dir}", file=sys.stderr, flush=True)
    print(f"Output dir: {args.out_dir}", file=sys.stderr, flush=True)
    print(f"Batch size: {args.batch_size}", file=sys.stderr, flush=True)

    total = 0
    batch_idx = 0
    for i in range(0, len(shapes), args.batch_size):
        batch = shapes[i : i + args.batch_size]
        batch_idx += 1
        out_path = out_dir / f"wide_coverage_batch_{batch_idx:03d}.log"

        print(
            f"\nBatch {batch_idx}: shapes {i + 1}-{i + len(batch)} -> {out_path}",
            file=sys.stderr,
            flush=True,
        )

        with open(out_path, "w") as f:
            f.write(f"CK Tile Wide Coverage Benchmark Batch {batch_idx}\n")
            f.write("GPU ID: 0\n")
            f.write("Implementation: gemm_universal\n\n")
            count = run_shape_batch(
                args.bin_dir, batch, f, warmup=args.warmup, repeat=args.repeat
            )
        total += count

        print(
            f"  Batch {batch_idx} complete: {count} benchmarks",
            file=sys.stderr,
            flush=True,
        )

    print(
        f"\nTotal: {total} benchmarks across {len(shapes)} shapes",
        file=sys.stderr,
        flush=True,
    )


if __name__ == "__main__":
    main()
867	dispatcher/heuristics/ml_heuristic_sweep.py	Normal file
@@ -0,0 +1,867 @@
#!/usr/bin/env python3

# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT

"""
ML Heuristic Sweep: Comprehensive GEMM Performance Evaluation

Sweeps across diverse problem shapes with ML-based kernel selection to measure
TFLOPS performance. Supports multiple dtypes (fp16, bf16, fp8) and validates
ML model predictions by executing kernels on GPU.

Shape Constraints (fp16/bf16 on gfx950):
  - M >= 1 (any M is valid)
  - N % 8 == 0 AND N >= 64
  - K % 2 == 0 AND K >= 32

Usage:
  python ml_heuristic_sweep.py --dtypes fp16 --num_shapes 256
  python ml_heuristic_sweep.py --dtypes fp16 bf16 --output sweep_results.csv
  python ml_heuristic_sweep.py --dtypes fp16 --dry_run  # Prediction only, no GPU execution
"""

import sys
import argparse
import time
import csv
from pathlib import Path
from dataclasses import dataclass
from typing import List, Tuple

# Add parent directories to path for imports
sys.path.insert(0, str(Path(__file__).parent.parent / "python"))

import numpy as np

from ctypes_utils import (
    KernelConfig,
    setup_gemm_dispatcher,
    cleanup_gemm,
)

try:
    from predict import Predictor
    # from feature_engine import GemmUniversalFeatureEngine

    HAS_ML = True
except ImportError:
    HAS_ML = False
    print("WARNING: ML heuristic modules not available. Will use first-fit selection.")


@dataclass
class KernelSpec:
    """Kernel specification for the ML heuristic."""

    name: str
    tile_m: int
    tile_n: int
    tile_k: int
    pipeline: str = "compv3"
    scheduler: str = "intrawave"
    wave_m: int = 2
    wave_n: int = 2
    wave_k: int = 1
    warp_m: int = 32
    warp_n: int = 32
    warp_k: int = 16


# Comprehensive kernel pool covering diverse tile sizes and configurations
KERNEL_POOL = [
    # Small tiles (64x64)
    KernelSpec("s_64x64_k32_v3", 64, 64, 32, "compv3", "intrawave", 2, 2, 1, 16, 16, 16),
    KernelSpec("s_64x64_k64_v3", 64, 64, 64, "compv3", "intrawave", 2, 2, 1, 16, 16, 16),
    KernelSpec("s_64x64_k128_v3", 64, 64, 128, "compv3", "intrawave", 2, 2, 1, 16, 16, 16),
    KernelSpec("s_64x64_k64_v4", 64, 64, 64, "compv4", "intrawave", 2, 2, 1, 16, 16, 16),
    KernelSpec("s_64x64_k64_mem", 64, 64, 64, "mem", "intrawave", 2, 2, 1, 16, 16, 16),
    KernelSpec("s_64x64_k128_mem", 64, 64, 128, "mem", "intrawave", 2, 2, 1, 16, 16, 16),
    # Medium tiles (128x128)
    KernelSpec("m_128x128_k32_v3", 128, 128, 32, "compv3", "intrawave"),
    KernelSpec("m_128x128_k64_v3", 128, 128, 64, "compv3", "intrawave"),
    KernelSpec("m_128x128_k128_v3", 128, 128, 128, "compv3", "intrawave"),
    KernelSpec("m_128x128_k64_v4", 128, 128, 64, "compv4", "intrawave"),
    KernelSpec("m_128x128_k128_v4", 128, 128, 128, "compv4", "intrawave"),
    KernelSpec("m_128x128_k64_mem", 128, 128, 64, "mem", "intrawave"),
    KernelSpec("m_128x128_k128_mem", 128, 128, 128, "mem", "intrawave"),
    # Rectangular medium (M != N)
    KernelSpec("r_64x128_k32_v3", 64, 128, 32, "compv3", "intrawave", 2, 2, 1, 16, 32, 16),
    KernelSpec("r_128x64_k32_v3", 128, 64, 32, "compv3", "intrawave", 2, 2, 1, 32, 16, 16),
    KernelSpec("r_64x128_k64_v3", 64, 128, 64, "compv3", "intrawave", 2, 2, 1, 16, 32, 16),
    KernelSpec("r_128x64_k64_v3", 128, 64, 64, "compv3", "intrawave", 2, 2, 1, 32, 16, 16),
    KernelSpec("r_64x256_k32_v3", 64, 256, 32, "compv3", "intrawave", 2, 2, 1, 16, 32, 16),
    KernelSpec("r_256x64_k32_v3", 256, 64, 32, "compv3", "intrawave", 2, 2, 1, 32, 16, 16),
    # Large tiles (256x256)
    KernelSpec("l_256x128_k32_v3", 256, 128, 32, "compv3", "intrawave"),
    KernelSpec("l_128x256_k32_v3", 128, 256, 32, "compv3", "intrawave"),
    KernelSpec("l_256x256_k32_v3", 256, 256, 32, "compv3", "intrawave"),
    KernelSpec("l_256x256_k64_v3", 256, 256, 64, "compv3", "intrawave"),
    KernelSpec("l_256x256_k64_v4", 256, 256, 64, "compv4", "intrawave"),
    # Interwave variants
    KernelSpec("m_128x128_k64_iw_v3", 128, 128, 64, "compv3", "interwave"),
    KernelSpec("m_128x128_k128_iw_v3", 128, 128, 128, "compv3", "interwave"),
    KernelSpec("l_256x256_k32_iw_v3", 256, 256, 32, "compv3", "interwave"),
]


def generate_problem_shapes(num_shapes: int = 1024) -> List[Tuple[int, int, int]]:
    """
    Generate diverse problem shapes under the hardware constraints:
      - M >= 1 (any M is valid, including tiny M for inference)
      - N % 8 == 0 AND N >= 64 (hardware alignment requirement)
      - K % 2 == 0 AND K >= 32 (fp16 requirement)

    Covers:
      - Powers of 2 (square and rectangular)
      - ML workloads (LLM attention, MLP, batch inference)
      - Non-power-of-2 dimensions (aligned to constraints)
      - Edge cases (tiny M, very large matrices, extreme aspect ratios)
    """
    shapes = []

    # 1. Powers of 2 - square (64 to 8192) with K variations
    for p in range(6, 14):  # 2^6 = 64 to 2^13 = 8192
        dim = 2**p
        shapes.append((dim, dim, dim))
        if dim >= 128:
            # K variations (must be even and >= 32)
            shapes.append((dim, dim, dim // 2))
            shapes.append((dim, dim, dim * 2))
            shapes.append((dim, dim, max(32, dim // 4)))

    # 2. Small-batch inference (batch 1-256, common hidden dims)
    # N must be a multiple of 8 and >= 64
    hidden_dims = [768, 1024, 2048, 3072, 4096, 5120, 8192, 11008, 12288, 16384]
    batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256]

    for hidden in hidden_dims:
        for batch in batch_sizes[:8]:
            shapes.append((batch, hidden, hidden))
            if hidden >= 4096:
                # LLM MLP projections (ensure K is even)
                k_mlp = hidden * 3 // 4
                if k_mlp % 2 == 1:
                    k_mlp += 1  # Make even
                if k_mlp >= 32:
                    shapes.append((batch, hidden, k_mlp))
                    shapes.append((batch, k_mlp, hidden))

    # 3. Attention patterns (seq_len x head_dim)
    # seq_len can be any value >= 1; total_dim must be a multiple of 8
    seq_lens = [128, 256, 512, 1024, 2048, 4096, 8192]
    head_dims = [64, 80, 96, 128, 256]
    num_heads = [8, 12, 16, 32, 40, 64]

    for seq in seq_lens:
        for head_dim in head_dims:
            for nh in num_heads[:4]:
                total_dim = nh * head_dim
                # total_dim should be a multiple of 8 (naturally satisfied in most cases)
                if total_dim % 8 == 0 and total_dim >= 64:
                    # head_dim must be even and >= 32 when used as K
                    if head_dim % 2 == 0 and head_dim >= 32:
                        shapes.append((seq, total_dim, head_dim))
                        shapes.append((seq, head_dim, total_dim))

    # 4. Rectangular matrices (extreme aspect ratios); all dims must satisfy constraints
    dims_m = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096]
    dims_n = [64, 128, 256, 512, 1024, 2048, 4096, 8192]  # N >= 64, N % 8 == 0
    dims_k = [32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384]  # K >= 32, K % 2 == 0

    # Sample to avoid combinatorial explosion
    for i, m in enumerate(dims_m):
        for j, n in enumerate(dims_n):
            for _l, k in enumerate(dims_k):
                if (i + j + _l) % 3 == 0:  # Stratified sampling
                    shapes.append((m, n, k))

    # 5. Non-power-of-2 dimensions (aligned to constraints)
    # N values: multiples of 8, >= 64
    non_pow2_n = [
        72, 80, 88, 96, 104, 112, 120, 136, 144, 152, 160, 176, 184, 192, 200,
        224, 240, 272, 288, 304, 320, 336, 352, 368, 384, 400, 416, 448, 480,
        544, 576, 640, 672, 704, 736, 768, 800, 832, 896, 960, 1088, 1152,
        1216, 1280, 1344, 1408, 1472, 1536, 1600, 1664, 1728, 1792, 1856, 1920,
        2176, 2304, 2432, 2560, 2688, 2816, 2944, 3072, 3200, 3328, 3456, 3584,
        3712, 3840, 3968, 4224, 4352, 4480, 4608, 4736, 4864, 4992,
    ]

    # K values: even numbers >= 32
    non_pow2_k = [
        34, 36, 38, 40, 42, 44, 48, 50, 52, 56, 60, 66, 68, 72, 76, 80, 88,
        96, 100, 112, 120, 136, 144, 160, 176, 192, 224, 240, 272, 288, 320,
        352, 384, 416, 448, 480, 544, 576, 640, 672, 704, 768, 800, 832, 896,
        960, 1088, 1152, 1280, 1344, 1408, 1536, 1600, 1664, 1792, 1920,
    ]

    # M values: any value >= 1
    non_pow2_m = [
        1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 23, 27, 31, 33, 37, 41, 47, 51,
        57, 63, 65, 71, 79, 87, 95, 97, 111, 119, 127, 129, 143, 159, 175,
        191, 193, 223, 239, 255, 257, 287, 319, 351, 383, 385, 447, 479, 511,
        513, 575, 639, 703, 767, 769, 895, 959, 1023, 1025,
    ]

    # Sample non-power-of-2 shapes
    for i, m in enumerate(non_pow2_m[:30]):
        for j, n in enumerate(non_pow2_n[:20]):
            for _l, k in enumerate(non_pow2_k[:15]):
                if (i + j + _l) % 4 == 0:  # Stratified sampling
                    shapes.append((m, n, k))

    # 6. Very tall K (memory-bound); N % 8 == 0, K % 2 == 0
    for mn in [64, 128, 256, 512, 1024]:
        for k in [4096, 8192, 16384]:
            shapes.append((mn, mn, k))

    # 7. Very short K (compute-bound); K >= 32, K % 2 == 0
    for mn in [512, 1024, 2048, 4096]:
        for k in [32, 64, 128]:
            shapes.append((mn, mn, k))

    # 8. Tiny M (edge cases for batch-1 inference)
    for m in [1, 2, 4, 8, 16, 32]:
        for n in [64, 128, 256, 512, 1024, 2048]:  # N >= 64, N % 8 == 0
            for k in [32, 64, 128, 256, 512]:  # K >= 32, K % 2 == 0
                shapes.append((m, n, k))

    # 9. Stress-test sizes (aligned to constraints)
    stress_sizes = [
        (10000, 10000, 10000),
        (1000, 10000, 1000),
        (1000, 1000, 10000),
        (5000, 5000, 5000),
        (7168, 7168, 7168),  # Common LLM hidden dim
        (8192, 11008, 8192),  # LLaMA MLP dimensions
    ]
    shapes.extend(stress_sizes)

    # Remove duplicates while preserving order
    seen = set()
    unique_shapes = []
    for s in shapes:
        if s not in seen:
            seen.add(s)
            unique_shapes.append(s)

    # Filter to ensure all shapes meet the constraints
    valid_shapes = []
    for m, n, k in unique_shapes:
        if m >= 1 and n >= 64 and n % 8 == 0 and k >= 32 and k % 2 == 0:
            valid_shapes.append((m, n, k))

    # Sample down to the target count if there are too many
    if len(valid_shapes) > num_shapes:
        # Stride-based sampling to preserve diversity
        step = len(valid_shapes) / num_shapes
        valid_shapes = [valid_shapes[int(i * step)] for i in range(num_shapes)]

    return valid_shapes
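The final downsampling step above can be isolated as a small helper (hypothetical name `downsample`): it keeps every len/target-th entry, so each shape family contributes roughly proportionally rather than the list being truncated from one end.

```python
def downsample(items, target):
    """Pick `target` evenly spaced items, preserving order and diversity."""
    if len(items) <= target:
        return list(items)
    step = len(items) / target
    return [items[int(i * step)] for i in range(target)]

print(len(downsample(list(range(1000)), 256)))  # → 256
```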


def spec_to_feature_dict(spec: KernelSpec, dtype: str, layout: str) -> dict:
    """Convert a KernelSpec to the feature dict expected by the ML predictor."""
    return {
        "kernel_name": spec.name,
        "tile_m": spec.tile_m,
        "tile_n": spec.tile_n,
        "tile_k": spec.tile_k,
        "warp_m": spec.wave_m,
        "warp_n": spec.wave_n,
        "warp_k": spec.wave_k,
        "warp_tile_m": spec.warp_m,
        "warp_tile_n": spec.warp_n,
        "warp_tile_k": spec.warp_k,
        "pipeline": spec.pipeline,
        "scheduler": spec.scheduler,
        "epilogue": "cshuffle",
        "pad_m": True,  # Enable padding to support arbitrary M dimensions
        "pad_n": True,  # Enable padding to support arbitrary N dimensions
        "pad_k": True,  # Enable padding to support arbitrary K dimensions
        "persistent": False,
        "dtype": dtype,
        "layout": layout,
    }


def spec_to_kernel_config(
    spec: KernelSpec, dtype: str, arch: str, dtype_acc: str = "fp32"
) -> KernelConfig:
    """Convert a KernelSpec to a KernelConfig for the dispatcher."""
    return KernelConfig(
        dtype_a=dtype,
        dtype_b=dtype,
        dtype_c=dtype,
        dtype_acc=dtype_acc,
        layout_a="row",
        layout_b="col",
        layout_c="row",
        tile_m=spec.tile_m,
        tile_n=spec.tile_n,
        tile_k=spec.tile_k,
        wave_m=spec.wave_m,
        wave_n=spec.wave_n,
        wave_k=spec.wave_k,
        warp_m=spec.warp_m,
        warp_n=spec.warp_n,
        warp_k=spec.warp_k,
        pipeline=spec.pipeline,
        scheduler=spec.scheduler,
        epilogue="cshuffle",
        gfx_arch=arch,
    )


def ml_select_kernel(
    predictor, pool: List[KernelSpec], M: int, N: int, K: int, dtype: str, layout: str
) -> Tuple[KernelSpec, float]:
    """Use the ML model to select the best kernel; fall back to first-fit."""
    if not HAS_ML or predictor is None:
        # Fallback: select the first kernel in the pool
        return pool[0], 0.0

    problem = {"m": M, "n": N, "k": K, "dtype": dtype, "layout": layout, "split_k": 1}
    kernel_dicts = [spec_to_feature_dict(s, dtype, layout) for s in pool]

    ranked = predictor.rank_kernels(problem, kernel_dicts)
    if not ranked:
        return pool[0], 0.0

    best_name, best_tflops = ranked[0]
    best_spec = next((s for s in pool if s.name == best_name), pool[0])
    return best_spec, best_tflops


def run_single_gemm(
    M: int,
    N: int,
    K: int,
    dtype: str,
    arch: str,
    predictor,
    dry_run: bool = False,
    dtype_acc: str = "fp32",
) -> dict:
    """Run a single GEMM with ML-heuristic kernel selection."""

    # Select kernel via the ML heuristic
    t0 = time.time()
    best_spec, pred_tflops = ml_select_kernel(
        predictor, KERNEL_POOL, M, N, K, dtype, "rcr"
    )
    select_time_ms = (time.time() - t0) * 1000

    result = {
        "M": M,
        "N": N,
        "K": K,
        "dtype": dtype,
        "selected_kernel": best_spec.name,
        "predicted_tflops": pred_tflops,
        "selection_time_ms": select_time_ms,
        "actual_time_ms": 0,
        "actual_tflops": 0,
        "status": "SKIP" if dry_run else "PENDING",
        "error": None,
    }

    if dry_run:
        return result

    # Build and run the kernel
    config = spec_to_kernel_config(best_spec, dtype, arch, dtype_acc)

    try:
        setup = setup_gemm_dispatcher(
            config=config,
            registry_name=f"sweep_{dtype}_{best_spec.name}",
            verbose=False,
            auto_rebuild=True,
        )

        if not setup.success:
            result["status"] = "BUILD_FAIL"
            result["error"] = "Failed to build kernel"
            cleanup_gemm()
            return result

        dispatcher = setup.dispatcher
        if not dispatcher.is_supported(M, N, K):
            result["status"] = "UNSUPPORTED"
            result["error"] = "Problem size not supported by kernel"
            cleanup_gemm()
            return result

        # Create input data (note: bf16 and fp8 paths reuse fp16 host buffers here)
        np_dtype = {"fp16": np.float16, "bf16": np.float16, "fp8": np.float16}[dtype]
        np.random.seed(42)
        A = (np.random.randn(M, K) * 0.1).astype(np_dtype)
        B = (np.random.randn(K, N) * 0.1).astype(np_dtype)

        # Run GEMM
        exec_result = dispatcher.run(A, B, M, N, K)

        if exec_result.success:
            result["actual_time_ms"] = exec_result.time_ms
            result["actual_tflops"] = exec_result.tflops
            result["status"] = "SUCCESS"
        else:
            # Decode the status code into a clearer error message
            status_messages = {
                0: "Success",
                -1: "GPU/HIP error (check permissions, memory, or kernel validity)",
                -2: "No suitable kernel found for this problem size",
            }
            error_msg = status_messages.get(
                exec_result.status, f"Unknown error (status={exec_result.status})"
            )
            result["status"] = "RUN_FAIL"
            result["error"] = f"{error_msg} (status_code={exec_result.status})"

            # Print detailed error for debugging
            print(f"  ERROR: {error_msg}")
            print(f"    Status code: {exec_result.status}")
            print(f"    Time returned: {exec_result.time_ms}")
            print(f"    Kernel: {exec_result.kernel_name}")

        cleanup_gemm()

    except Exception as e:
        result["status"] = "ERROR"
        result["error"] = str(e)[:200]
        cleanup_gemm()

    return result
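For reference, the `actual_tflops` figure the dispatcher reports is assumed to follow the usual GEMM convention of 2*M*N*K floating-point operations (one multiply plus one add per inner-product term). A sketch of that arithmetic (hypothetical helper, not part of the dispatcher API):

```python
def gemm_tflops(m, n, k, time_ms):
    """Conventional GEMM throughput: 2*M*N*K flops over the measured runtime."""
    flops = 2.0 * m * n * k
    return flops / (time_ms * 1e-3) / 1e12

# A 4096^3 GEMM finishing in 1 ms sustains ~137.4 TFLOPS.
print(round(gemm_tflops(4096, 4096, 4096, 1.0), 1))  # → 137.4
```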


def main():
    parser = argparse.ArgumentParser(
        description="ML Heuristic Sweep: Test GEMM across many shapes and dtypes"
    )
    parser.add_argument(
        "--dtypes",
        nargs="+",
        default=["fp16"],
        choices=["fp16", "bf16", "fp8"],
        help="Data types to test (default: fp16)",
    )
    parser.add_argument(
        "--arch", default="gfx950", help="GPU architecture (default: gfx950)"
    )
    parser.add_argument(
        "--dtype_acc",
        default="fp32",
        choices=["fp16", "fp32"],
        help="Accumulator data type (default: fp32)",
    )
    parser.add_argument(
        "--model_dir",
        default=None,
        help="Path to model directory (auto-detect if not specified)",
    )
    parser.add_argument(
        "--num_shapes",
        type=int,
        default=256,
        help="Number of problem shapes to test (default: 256)",
    )
    parser.add_argument(
        "--output",
        default="ml_heuristic_sweep_results.csv",
        help="Output CSV file path",
    )
    parser.add_argument(
        "--dry_run",
        action="store_true",
        help="Only predict, do not run kernels (fast validation)",
    )

    args = parser.parse_args()

    # Set up the ML predictor
    predictor = None
    if HAS_ML:
        if args.model_dir is None:
            # Auto-detect the model directory based on the first dtype
            first_dtype = args.dtypes[0]
            heuristics_dir = Path(__file__).parent
            model_candidates = [
                heuristics_dir / "models" / f"gemm_universal_{first_dtype}_{args.arch}",
            ]
            for model_dir in model_candidates:
                if model_dir.exists():
                    args.model_dir = str(model_dir)
                    break

        if args.model_dir and Path(args.model_dir).exists():
            try:
                predictor = Predictor(args.model_dir)
                print(f"✓ Loaded ML model from: {args.model_dir}")
            except Exception as e:
                print(f"⚠ Failed to load ML model: {e}")
                print("  Will use first-fit selection instead")
        else:
            print(f"⚠ Model directory not found: {args.model_dir}")
            print("  Will use first-fit selection instead")

    # Generate problem shapes
    print(f"\nGenerating {args.num_shapes} problem shapes...")
    shapes = generate_problem_shapes(args.num_shapes)
    print(
        f"✓ Generated {len(shapes)} valid shapes (M>=1, N%8==0, N>=64, K%2==0, K>=32)"
    )

    # Validate that all shapes meet the constraints
    invalid = [
        (m, n, k)
        for m, n, k in shapes
        if not (m >= 1 and n >= 64 and n % 8 == 0 and k >= 32 and k % 2 == 0)
    ]
    if invalid:
        print(f"⚠ WARNING: {len(invalid)} shapes violate constraints!")
        print(f"  First few: {invalid[:5]}")

    # Print configuration
    print("\n" + "=" * 80)
    print("  ML Heuristic Sweep Configuration")
    print("=" * 80)
    print(
        f"  Model:          {args.model_dir if args.model_dir else 'first-fit (no ML)'}"
    )
    print(f"  Data types:     {', '.join(args.dtypes)}")
    print(f"  Accumulator:    {args.dtype_acc}")
    print(f"  Architecture:   {args.arch}")
    print(f"  Kernel pool:    {len(KERNEL_POOL)} kernels")
    print(f"  Problem shapes: {len(shapes)}")
    print(f"  Total tests:    {len(shapes) * len(args.dtypes)}")
    print(
        f"  Mode:           {'DRY RUN (prediction only)' if args.dry_run else 'FULL RUN (execute kernels)'}"
    )
    print(f"  Output:         {args.output}")
    print("=" * 80)

    # Open the output CSV
    csv_file = open(args.output, "w", newline="")
    csv_writer = csv.DictWriter(
        csv_file,
        fieldnames=[
            "dtype",
            "M",
            "N",
            "K",
            "selected_kernel",
            "predicted_tflops",
            "selection_time_ms",
            "actual_time_ms",
            "actual_tflops",
            "status",
            "error",
        ],
    )
    csv_writer.writeheader()

    # Run the sweep
    total_tests = len(shapes) * len(args.dtypes)
    completed = 0
    start_time = time.time()

    print("\nStarting sweep... (Ctrl+C to stop and save partial results)\n")

    try:
        for dtype in args.dtypes:
            print(f"\n{'=' * 80}")
            print(f"  Testing dtype: {dtype.upper()}")
            print(f"{'=' * 80}\n")

            for i, (M, N, K) in enumerate(shapes):
                result = run_single_gemm(
                    M, N, K, dtype, args.arch, predictor, args.dry_run, args.dtype_acc
                )

                # Write to CSV
                csv_writer.writerow(result)
                csv_file.flush()

                completed += 1

                # Progress update
                if completed % 10 == 0 or result["status"] != "SUCCESS":
                    elapsed = time.time() - start_time
                    rate = completed / elapsed if elapsed > 0 else 0
                    eta = (total_tests - completed) / rate if rate > 0 else 0

                    status_emoji = {
                        "SUCCESS": "✓",
                        "SKIP": "→",
                        "BUILD_FAIL": "✗",
                        "UNSUPPORTED": "○",
                        "RUN_FAIL": "✗",
                        "ERROR": "✗",
                    }.get(result["status"], "?")

                    print(
                        f"  [{completed:4d}/{total_tests}] {status_emoji} "
                        f"{dtype:4s} {M:5d}x{N:5d}x{K:5d} → "
                        f"{result['selected_kernel']:20s} "
                        f"pred={result['predicted_tflops']:6.1f} "
                        f"actual={result['actual_tflops']:6.1f} TFLOPS "
                        f"[{rate:.1f} tests/s, ETA {eta / 60:.1f}m]"
                    )

    except KeyboardInterrupt:
        print(f"\n\n⚠ Interrupted! Saving partial results to {args.output}...")

    finally:
        csv_file.close()

    # Summary
    print("\n" + "=" * 80)
    print("  SWEEP COMPLETE")
    print("=" * 80)

    # Read back results and compute statistics
    results = []
    with open(args.output, "r") as f:
        reader = csv.DictReader(f)
        results = list(reader)

    print(f"\n  Total tests: {len(results)}")
    print(f"  Output file: {args.output}")

    if not args.dry_run:
        success = [r for r in results if r["status"] == "SUCCESS"]
        print(
            f"  Successful:  {len(success)} ({100 * len(success) / len(results):.1f}%)"
        )

        if success:
            avg_tflops = np.mean([float(r["actual_tflops"]) for r in success])
            max_tflops = max(float(r["actual_tflops"]) for r in success)
            print(f"  Avg TFLOPS:  {avg_tflops:.2f}")
            print(f"  Max TFLOPS:  {max_tflops:.2f}")

            # Per-dtype breakdown
            for dtype in args.dtypes:
                dtype_results = [r for r in success if r["dtype"] == dtype]
                if dtype_results:
                    avg = np.mean([float(r["actual_tflops"]) for r in dtype_results])
                    print(f"    {dtype:4s}: {avg:.2f} TFLOPS (n={len(dtype_results)})")

    print("=" * 80)
    print()

    return 0


if __name__ == "__main__":
    sys.exit(main())
@@ -0,0 +1,113 @@
{
  "op_type": "gemm_universal",
  "dtype": "fp16",
  "arch": "gfx950",
  "feature_names": [
    "M", "N", "K", "split_k",
    "log2_M", "log2_N", "log2_K", "log2_MNK",
    "arithmetic_intensity",
    "aspect_ratio_mn", "aspect_ratio_mk", "aspect_ratio_nk",
    "layout",
    "tile_m", "tile_n", "tile_k",
    "warp_m", "warp_n", "warp_k",
    "warp_tile_m", "warp_tile_n", "warp_tile_k",
    "pipeline", "scheduler", "epilogue",
    "pad_m", "pad_n", "pad_k",
    "persistent",
    "num_warps", "tile_volume", "tile_mn",
    "lds_usage_estimate", "lds_usage_ratio",
    "num_tiles_m", "num_tiles_n", "num_tiles_k", "total_output_tiles",
    "tile_eff_m", "tile_eff_n", "tile_eff_k", "overall_tile_efficiency",
    "cu_utilization",
    "ratio_M_to_tile_m", "ratio_N_to_tile_n", "ratio_K_to_tile_k",
    "problem_smaller_than_tile_m", "problem_smaller_than_tile_n",
    "problem_smaller_than_tile_k", "any_dim_too_small",
    "needs_padding_m", "needs_padding_n", "needs_padding_k",
    "has_padding_when_needed_m", "has_padding_when_needed_n",
    "has_padding_when_needed_k",
    "missing_required_padding_m", "missing_required_padding_n",
    "missing_required_padding_k", "missing_any_required_padding",
    "hw_num_cus", "hw_simds_per_cu", "hw_total_simds", "hw_shader_engines",
    "hw_max_clock_mhz", "hw_max_waves_per_cu", "hw_wavefront_size",
    "hw_lds_capacity", "hw_l1_cache_kb", "hw_l2_cache_kb", "hw_l3_cache_kb",
    "hw_num_xcd"
  ],
  "categorical_features": ["layout", "pipeline", "scheduler", "epilogue"],
  "targets": ["tflops", "latency", "bandwidth"],
  "log_targets": ["bandwidth", "tflops"],
  "params": {
    "objective": "regression",
    "metric": ["rmse", "mae"],
    "num_leaves": 255,
    "max_depth": 15,
    "n_estimators": 2000,
    "learning_rate": 0.02,
    "min_child_samples": 10,
    "subsample": 0.85,
    "colsample_bytree": 0.85,
    "reg_alpha": 0.05,
    "reg_lambda": 0.5,
    "verbose": -1,
    "n_jobs": 8,
    "seed": 42
  }
}
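The tile-efficiency features named in the config above (`tile_eff_m/n/k`, `overall_tile_efficiency`) are assumed to measure the fraction of the tile-padded extent that holds real data; a sketch under that assumption (the exact formula lives in `feature_engine.py`):

```python
import math

def tile_efficiency(dim, tile):
    """Fraction of the tile-padded extent covered by the actual dimension."""
    padded = math.ceil(dim / tile) * tile
    return dim / padded

print(tile_efficiency(96, 64))   # 96 padded to 128 → 0.75
print(tile_efficiency(128, 64))  # exact multiple → 1.0
```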
Binary file not shown.
@@ -0,0 +1,10 @@
{
  "warm_start_from": null,
  "prev_n_estimators": 0,
  "new_n_estimators": 2000,
  "total_n_estimators": 2000,
  "data_rows": 25600,
  "valid_rows": 21920,
  "unique_shapes": 25,
  "timestamp": "2026-03-20T05:00:55"
}
@@ -0,0 +1,113 @@
{
  "op_type": "gemm_universal",
  "dtype": "fp8",
  "arch": "gfx950",
  "feature_names": [
    "M", "N", "K", "split_k",
    "log2_M", "log2_N", "log2_K", "log2_MNK",
    "arithmetic_intensity",
    "aspect_ratio_mn", "aspect_ratio_mk", "aspect_ratio_nk",
    "layout",
    "tile_m", "tile_n", "tile_k",
    "warp_m", "warp_n", "warp_k",
    "warp_tile_m", "warp_tile_n", "warp_tile_k",
    "pipeline", "scheduler", "epilogue",
    "pad_m", "pad_n", "pad_k",
    "persistent",
    "num_warps", "tile_volume", "tile_mn",
    "lds_usage_estimate", "lds_usage_ratio",
    "num_tiles_m", "num_tiles_n", "num_tiles_k", "total_output_tiles",
    "tile_eff_m", "tile_eff_n", "tile_eff_k", "overall_tile_efficiency",
    "cu_utilization",
    "ratio_M_to_tile_m", "ratio_N_to_tile_n", "ratio_K_to_tile_k",
    "problem_smaller_than_tile_m", "problem_smaller_than_tile_n",
    "problem_smaller_than_tile_k", "any_dim_too_small",
    "needs_padding_m", "needs_padding_n", "needs_padding_k",
    "has_padding_when_needed_m", "has_padding_when_needed_n",
    "has_padding_when_needed_k",
    "missing_required_padding_m", "missing_required_padding_n",
    "missing_required_padding_k", "missing_any_required_padding",
    "hw_num_cus", "hw_simds_per_cu", "hw_total_simds", "hw_shader_engines",
    "hw_max_clock_mhz", "hw_max_waves_per_cu", "hw_wavefront_size",
    "hw_lds_capacity", "hw_l1_cache_kb", "hw_l2_cache_kb", "hw_l3_cache_kb",
    "hw_num_xcd"
  ],
  "categorical_features": ["layout", "pipeline", "scheduler", "epilogue"],
  "targets": ["tflops", "latency", "bandwidth"],
  "log_targets": ["bandwidth", "tflops"],
  "params": {
    "objective": "regression",
    "metric": ["rmse", "mae"],
    "num_leaves": 255,
    "max_depth": 15,
    "n_estimators": 2000,
    "learning_rate": 0.02,
    "min_child_samples": 10,
    "subsample": 0.85,
    "colsample_bytree": 0.85,
    "reg_alpha": 0.05,
    "reg_lambda": 0.5,
    "verbose": -1,
    "n_jobs": 8,
    "seed": 42
  }
}
|
||||
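The feature spec above pins the trained models to a fixed feature ordering and names the targets trained in log-space. A minimal sketch (the `validate_spec` helper is hypothetical, not part of this PR) of sanity-checking such a spec before training, plus the `log1p`/`expm1` round-trip used for log-space targets:

```python
import math


def validate_spec(spec: dict) -> list[str]:
    """Return a list of problems found in a feature spec; empty means OK."""
    errors = []
    names = spec.get("feature_names", [])
    if len(set(names)) != len(names):
        errors.append("duplicate feature names")
    for cat in spec.get("categorical_features", []):
        if cat not in names:
            errors.append(f"categorical feature {cat!r} not in feature_names")
    for t in spec.get("log_targets", []):
        if t not in spec.get("targets", []):
            errors.append(f"log target {t!r} not in targets")
    return errors


spec = {
    "feature_names": ["M", "N", "K", "layout", "pipeline"],
    "categorical_features": ["layout", "pipeline"],
    "targets": ["tflops", "latency", "bandwidth"],
    "log_targets": ["bandwidth", "tflops"],
}
assert validate_spec(spec) == []

# log1p at training time, expm1 at prediction time, recovers the target.
x = 505.0
assert abs(math.expm1(math.log1p(x)) - x) < 1e-9
```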
Binary file not shown.
@@ -0,0 +1,10 @@
{
    "warm_start_from": null,
    "prev_n_estimators": 0,
    "new_n_estimators": 2000,
    "total_n_estimators": 2000,
    "data_rows": 1296528,
    "valid_rows": 1253076,
    "unique_shapes": 168,
    "timestamp": "2026-03-19T06:10:29"
}
243
dispatcher/heuristics/predict.py
Normal file
@@ -0,0 +1,243 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT

"""
Predictor for CK Tile kernel performance.

Loads trained LightGBM models and provides:
- predict_tflops(): predicted TFLOPS for a single (problem, kernel) pair
- predict_latency(): predicted latency in ms
- predict_bandwidth(): predicted bandwidth in GB/s
- predict_all(): all three predictions at once
- rank_kernels(): rank all candidate kernels by predicted TFLOPS
- select_best(): return the best kernel ID

Usage:
    predictor = Predictor("models/gemm_universal_fp8_gfx950")
    best_kernel = predictor.select_best(
        problem={"m": 128, "n": 1536, "k": 7168, "dtype": "fp8", "layout": "rcr"},
        kernel_configs=[...],
    )
"""

import gzip
import json
from pathlib import Path
from typing import Optional

import lightgbm as lgb
import numpy as np
import pandas as pd

from feature_engine import GemmUniversalFeatureEngine


class Predictor:
    """Loads trained models and feature spec for kernel performance prediction.

    Parameters
    ----------
    model_dir : str or Path
        Directory containing model artifacts:
        - model_tflops.lgbm (required)
        - model_latency.lgbm (optional)
        - model_bandwidth.lgbm (optional)
        - feature_spec.json (required)

    feature_engine : FeatureEngine, optional
        Override the feature engine. If None, constructs one from feature_spec.json.
    """

    def __init__(self, model_dir: str | Path, feature_engine=None):
        self._model_dir = Path(model_dir)
        self._models: dict[str, lgb.Booster] = {}

        spec_path = self._model_dir / "feature_spec.json"
        if spec_path.exists():
            with open(spec_path) as f:
                self._spec = json.load(f)
        else:
            self._spec = {}

        self._log_targets = set(self._spec.get("log_targets", []))

        if feature_engine is not None:
            self._feature_engine = feature_engine
        else:
            self._feature_engine = GemmUniversalFeatureEngine()

    def _load_model(self, target: str) -> Optional[lgb.Booster]:
        """Lazy-load a model for the given target.

        Automatically decompresses .lgbm.gz files if the .lgbm file doesn't exist.
        The decompressed file is cached to disk for subsequent loads.
        """
        if target in self._models:
            return self._models[target]

        path = self._model_dir / f"model_{target}.lgbm"
        gz_path = self._model_dir / f"model_{target}.lgbm.gz"

        # Auto-decompress if needed
        if not path.exists() and gz_path.exists():
            with gzip.open(gz_path, 'rb') as f_in:
                with open(path, 'wb') as f_out:
                    f_out.write(f_in.read())

        if not path.exists():
            return None

        model = lgb.Booster(model_file=str(path))
        self._models[target] = model
        return model

    def _predict_single(self, target: str, problem: dict, kernel_config: dict) -> float:
        """Predict a single target value, applying inverse log transform if needed."""
        model = self._load_model(target)
        if model is None:
            raise FileNotFoundError(f"No model_{target}.lgbm in {self._model_dir}")
        features = self._feature_engine.extract(problem, kernel_config)
        raw = float(model.predict(features.reshape(1, -1))[0])
        if target in self._log_targets:
            return float(np.expm1(raw))
        # Clamp to non-negative even for non-log models
        return float(max(0.0, raw))

    def predict_tflops(self, problem: dict, kernel_config: dict) -> float:
        """Predict TFLOPS for a single (problem, kernel) pair.

        Returns a real TFLOPS estimate (interpretable, usable as DE surrogate).
        If the model was trained in log-space, the inverse transform is applied
        automatically.
        """
        return self._predict_single("tflops", problem, kernel_config)

    def predict_latency(self, problem: dict, kernel_config: dict) -> float:
        """Predict latency in milliseconds for a single (problem, kernel) pair."""
        return self._predict_single("latency", problem, kernel_config)

    def predict_bandwidth(self, problem: dict, kernel_config: dict) -> float:
        """Predict bandwidth in GB/s for a single (problem, kernel) pair."""
        return self._predict_single("bandwidth", problem, kernel_config)

    def predict_all(self, problem: dict, kernel_config: dict) -> dict[str, float]:
        """Predict all available targets for a single (problem, kernel) pair.

        Returns dict with keys 'tflops', 'latency_ms', 'bandwidth_gb_s' (if models exist).

        Note: Applies inverse log transform for targets in log_targets and clamps
        negatives to 0.0, consistent with _predict_single().
        """
        features = self._feature_engine.extract(problem, kernel_config).reshape(1, -1)
        result = {}
        for target, key in [
            ("tflops", "tflops"),
            ("latency", "latency_ms"),
            ("bandwidth", "bandwidth_gb_s"),
        ]:
            model = self._load_model(target)
            if model is not None:
                raw = float(model.predict(features)[0])
                # Apply inverse log transform if model was trained in log-space
                if target in self._log_targets:
                    result[key] = float(np.expm1(raw))
                else:
                    # Clamp to non-negative even for non-log models
                    result[key] = float(max(0.0, raw))
        return result

    def rank_kernels(
        self, problem: dict, kernel_configs: list[dict]
    ) -> list[tuple[str, float]]:
        """Rank candidate kernels by predicted TFLOPS (descending).

        Parameters
        ----------
        problem : dict
            Problem specification with keys: m, n, k, dtype, layout, split_k.
        kernel_configs : list of dict
            Each dict must have a 'kernel_name' key plus kernel parameters.

        Returns
        -------
        list of (kernel_name, predicted_tflops) tuples, sorted descending.
        """
        if not kernel_configs:
            return []

        model = self._load_model("tflops")
        if model is None:
            raise FileNotFoundError(f"No model_tflops.lgbm in {self._model_dir}")

        rows = []
        for kc in kernel_configs:
            merged = {**problem, **kc}
            rows.append(merged)

        df = pd.DataFrame(rows)
        X = self._feature_engine.extract_batch(df)
        preds = model.predict(X)
        if "tflops" in self._log_targets:
            preds = np.expm1(preds)

        results = []
        for i, kc in enumerate(kernel_configs):
            name = kc.get("kernel_name", f"kernel_{i}")
            results.append((name, float(preds[i])))

        results.sort(key=lambda x: -x[1])
        return results

    def select_best(self, problem: dict, kernel_configs: list[dict]) -> str:
        """Return the kernel_name of the best predicted kernel."""
        ranked = self.rank_kernels(problem, kernel_configs)
        if not ranked:
            raise ValueError("No kernel configs provided")
        return ranked[0][0]


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Predict kernel performance")
    parser.add_argument(
        "--model_dir", required=True, help="Directory with trained models"
    )
    parser.add_argument("--m", type=int, required=True)
    parser.add_argument("--n", type=int, required=True)
    parser.add_argument("--k", type=int, required=True)
    parser.add_argument("--layout", default="rcr")
    parser.add_argument("--dtype", default="fp8")
    args = parser.parse_args()

    predictor = Predictor(args.model_dir)
    problem = {
        "m": args.m,
        "n": args.n,
        "k": args.k,
        "dtype": args.dtype,
        "layout": args.layout,
        "split_k": 1,
    }

    print(f"Loading models from {args.model_dir}...")
    print(
        f"Problem: M={args.m} N={args.n} K={args.k} dtype={args.dtype} layout={args.layout}"
    )

    data_dir = Path(args.model_dir).parent.parent / "data"
    if data_dir.exists():
        for pq in data_dir.glob("*.parquet"):
            df = pd.read_parquet(pq)
            kernel_names = df["kernel_name"].unique()
            configs = []
            for kn in kernel_names[:10]:
                row = df[df["kernel_name"] == kn].iloc[0]
                configs.append(row.to_dict())
            if configs:
                ranked = predictor.rank_kernels(problem, configs)
                print(f"\nTop 5 kernels (from {len(configs)} candidates):")
                for name, tflops in ranked[:5]:
                    print(f"  {tflops:8.2f} TFLOPS  {name}")
            break
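`_predict_single` above inverts the training-time log transform with `np.expm1`. The reason for training in log-space is visible in the sample logs, where measured TFLOPS span several orders of magnitude; a stdlib-only sketch of the transform pair (helper names are illustrative, not from this PR):

```python
import math


def to_model_space(tflops: float) -> float:
    # Targets like TFLOPS span orders of magnitude (~0.002 to ~505 in the
    # sample logs), so training on log1p(y) compresses the dynamic range
    # and keeps small-shape rows from being drowned out by large ones.
    return math.log1p(tflops)


def from_model_space(raw: float) -> float:
    # Inverse transform applied at prediction time, mirroring what
    # Predictor._predict_single does for targets listed in log_targets.
    return math.expm1(raw)


for y in (0.002, 8.81, 505.0):
    assert abs(from_model_space(to_model_space(y)) - y) <= 1e-9 * max(1.0, y)
```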
272
dispatcher/heuristics/search.py
Normal file
@@ -0,0 +1,272 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT

"""
Surrogate search for CK Tile kernel configuration optimization.

Uses a trained LGBMRegressor as a cheap surrogate function to search the
discrete kernel parameter space (tile sizes, warp config, pipeline, etc.)
without running actual GPU benchmarks.

Strategies:
- 'random': Sample N random valid configs, score all, return top-K.
- 'de': Discrete Differential Evolution with mutation over valid parameter choices.

Usage:
    from search import SurrogateSearch
    from predict import Predictor

    predictor = Predictor("models/gemm_universal_fp8_gfx950")
    searcher = SurrogateSearch(predictor, strategy='random')
    results = searcher.search(
        problem={"m": 128, "n": 1536, "k": 7168, "dtype": "fp8", "layout": "rcr"},
        budget=500,
    )
    # results: [(config_dict, predicted_tflops), ...] sorted descending
"""

import random
from typing import Optional

import numpy as np

from feature_engine import GemmUniversalFeatureEngine
from predict import Predictor


class SurrogateSearch:
    """Search kernel parameter space using ML regressor as surrogate objective.

    Parameters
    ----------
    predictor : Predictor
        Trained predictor with a TFLOPS model.
    feature_engine : GemmUniversalFeatureEngine, optional
        Feature engine for parameter space and validation. If None, uses default.
    strategy : str
        Search strategy: 'random' or 'de' (Discrete Differential Evolution).
    seed : int
        Random seed for reproducibility.
    """

    def __init__(
        self,
        predictor: Predictor,
        feature_engine: Optional[GemmUniversalFeatureEngine] = None,
        strategy: str = "random",
        seed: int = 42,
    ):
        self._predictor = predictor
        self._fe = feature_engine or GemmUniversalFeatureEngine()
        self._strategy = strategy
        self._rng = random.Random(seed)
        self._np_rng = np.random.RandomState(seed)
        self._param_space = self._fe.get_parameter_space()

    def _sample_random_config(self) -> dict:
        """Sample a single random config from the parameter space."""
        config = {}
        for param, values in self._param_space.items():
            config[param] = self._rng.choice(values)
        return config

    def _sample_valid_config(self, max_attempts: int = 50) -> Optional[dict]:
        """Sample a random config that passes all validation constraints."""
        for _ in range(max_attempts):
            config = self._sample_random_config()
            if self._fe.validate_config(config):
                return config
        return None

    def _score_config(self, problem: dict, config: dict) -> float:
        """Score a config using the predictor."""
        return self._predictor.predict_tflops(problem, config)

    def _search_random(
        self, problem: dict, budget: int, top_k: int
    ) -> list[tuple[dict, float]]:
        """Random search: sample valid configs, score all, return top-K."""
        configs = []
        for _ in range(budget):
            cfg = self._sample_valid_config()
            if cfg is not None:
                configs.append(cfg)

        if not configs:
            return []

        scored = []
        for cfg in configs:
            try:
                score = self._score_config(problem, cfg)
                scored.append((cfg, score))
            except Exception:
                continue

        scored.sort(key=lambda x: -x[1])
        return scored[:top_k]

    def _search_de(
        self,
        problem: dict,
        budget: int,
        top_k: int,
        pop_size: int = 20,
        mutation_rate: float = 0.3,
        crossover_rate: float = 0.7,
    ) -> list[tuple[dict, float]]:
        """Discrete Differential Evolution.

        Uses discrete mutation: randomly swap parameters to other valid values
        from the parameter space (no continuous relaxation + snap).

        Each generation:
        1. For each member of the population, create a trial vector by:
           - Selecting 3 random other members (a, b, c)
           - For each parameter, with probability mutation_rate, take the value
             from a, b, or c (uniform choice among the three donors)
           - With probability crossover_rate, take the trial value; otherwise keep original
        2. Validate the trial; if invalid, resample that parameter from the space
        3. Score the trial; if better than parent, replace
        """
        param_names = list(self._param_space.keys())

        population = []
        for _ in range(pop_size):
            cfg = self._sample_valid_config()
            if cfg is not None:
                score = self._score_config(problem, cfg)
                population.append((cfg, score))

        if len(population) < 4:
            return self._search_random(problem, budget, top_k)

        evals_used = len(population)
        max_gens = (budget - evals_used) // pop_size

        for gen in range(max_gens):
            new_pop = []
            for i, (parent, parent_score) in enumerate(population):
                candidates = [j for j in range(len(population)) if j != i]
                if len(candidates) < 3:
                    new_pop.append((parent, parent_score))
                    continue

                a_idx, b_idx, c_idx = self._rng.sample(candidates, 3)
                a, b, c = (
                    population[a_idx][0],
                    population[b_idx][0],
                    population[c_idx][0],
                )

                trial = dict(parent)
                for param in param_names:
                    if self._rng.random() < mutation_rate:
                        donor = self._rng.choice([a, b, c])
                        trial[param] = donor.get(param, parent.get(param))

                    if self._rng.random() > crossover_rate:
                        trial[param] = parent.get(param)

                if not self._fe.validate_config(trial):
                    for param in param_names:
                        if param in trial and trial[param] not in self._param_space.get(
                            param, [trial[param]]
                        ):
                            trial[param] = self._rng.choice(self._param_space[param])
                    if not self._fe.validate_config(trial):
                        new_pop.append((parent, parent_score))
                        continue

                try:
                    trial_score = self._score_config(problem, trial)
                    evals_used += 1
                except Exception:
                    new_pop.append((parent, parent_score))
                    continue

                if trial_score > parent_score:
                    new_pop.append((trial, trial_score))
                else:
                    new_pop.append((parent, parent_score))

            population = new_pop

        population.sort(key=lambda x: -x[1])
        return population[:top_k]

    def search(
        self,
        problem: dict,
        budget: int = 500,
        top_k: int = 10,
        **kwargs,
    ) -> list[tuple[dict, float]]:
        """Search the kernel parameter space for the best configuration.

        Parameters
        ----------
        problem : dict
            Problem specification: m, n, k, dtype, layout, split_k.
        budget : int
            Maximum number of surrogate evaluations.
        top_k : int
            Number of top configurations to return.
        **kwargs
            Strategy-specific parameters (pop_size, mutation_rate, etc.).

        Returns
        -------
        list of (config_dict, predicted_tflops), sorted descending by TFLOPS.
        """
        if self._strategy == "random":
            return self._search_random(problem, budget, top_k)
        elif self._strategy == "de":
            return self._search_de(problem, budget, top_k, **kwargs)
        else:
            raise ValueError(f"Unknown strategy: {self._strategy}")


if __name__ == "__main__":
    import argparse
    import time

    parser = argparse.ArgumentParser(
        description="Surrogate search for optimal kernel config"
    )
    parser.add_argument("--model_dir", required=True)
    parser.add_argument("--m", type=int, required=True)
    parser.add_argument("--n", type=int, required=True)
    parser.add_argument("--k", type=int, required=True)
    parser.add_argument("--dtype", default="fp8")
    parser.add_argument("--layout", default="rcr")
    parser.add_argument("--strategy", default="random", choices=["random", "de"])
    parser.add_argument("--budget", type=int, default=500)
    parser.add_argument("--top_k", type=int, default=10)
    args = parser.parse_args()

    predictor = Predictor(args.model_dir)
    searcher = SurrogateSearch(predictor, strategy=args.strategy)
    problem = {
        "m": args.m,
        "n": args.n,
        "k": args.k,
        "dtype": args.dtype,
        "layout": args.layout,
        "split_k": 1,
    }

    print(f"Searching with strategy={args.strategy}, budget={args.budget}...")
    t0 = time.time()
    results = searcher.search(problem, budget=args.budget, top_k=args.top_k)
    elapsed = time.time() - t0

    print(f"\nTop {len(results)} configs found in {elapsed * 1000:.1f}ms:")
    for i, (cfg, tflops) in enumerate(results):
        tile_str = f"{cfg.get('tile_m', '?')}x{cfg.get('tile_n', '?')}x{cfg.get('tile_k', '?')}"
        warp_str = f"{cfg.get('warp_m', '?')}x{cfg.get('warp_n', '?')}x{cfg.get('warp_k', '?')}"
        print(
            f"  #{i + 1}: {tflops:8.2f} TFLOPS  tile={tile_str} warp={warp_str} "
            f"pipeline={cfg.get('pipeline', '?')} scheduler={cfg.get('scheduler', '?')}"
        )
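The mutation/crossover step documented in `_search_de` can be isolated into a small self-contained sketch. The toy parameter space and `make_trial` helper below are illustrative (they stand in for `GemmUniversalFeatureEngine.get_parameter_space()` and the inner loop of `_search_de`), showing why the discrete variant needs no continuous relaxation: every trial value is drawn from an existing member, so it stays inside the discrete space by construction.

```python
import random

# Toy discrete parameter space, standing in for the feature engine's
# get_parameter_space(); real spaces also carry warp and pad choices.
PARAM_SPACE = {
    "tile_m": [64, 128, 256],
    "tile_n": [64, 128, 256],
    "pipeline": ["compv3", "compv4", "mem"],
}


def make_trial(parent, donors, rng, mutation_rate=0.3, crossover_rate=0.7):
    """One discrete-DE trial vector: mutate from donors, then crossover."""
    trial = dict(parent)
    for param in PARAM_SPACE:
        # Mutation: with some probability, copy the value from a random donor.
        if rng.random() < mutation_rate:
            trial[param] = rng.choice(donors)[param]
        # Crossover: with probability 1 - crossover_rate, revert to the parent.
        if rng.random() > crossover_rate:
            trial[param] = parent[param]
    return trial


rng = random.Random(0)
pop = [{p: rng.choice(v) for p, v in PARAM_SPACE.items()} for _ in range(4)]
trial = make_trial(pop[0], pop[1:], rng)
# Every trial value stays inside the discrete space (no snap-to-grid needed).
assert all(trial[p] in PARAM_SPACE[p] for p in PARAM_SPACE)
```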
2
dispatcher/heuristics/tests/__init__.py
Normal file
@@ -0,0 +1,2 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
368
dispatcher/heuristics/tests/test_data_pipeline.py
Normal file
@@ -0,0 +1,368 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Tests for data_pipeline.py.
|
||||
|
||||
Covers: kernel name parsing, layout derivation, streaming log parsing,
|
||||
schema validation, and corner cases (empty logs, malformed JSON, single-shape).
|
||||
"""
|
||||
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
import pytest
|
||||
|
||||
import sys
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
||||
|
||||
from data_pipeline import (
|
||||
parse_kernel_name,
|
||||
_layout_from_problem,
|
||||
parse_streaming_log,
|
||||
save_parquet,
|
||||
load_parquet,
|
||||
CANONICAL_COLUMNS,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parse_kernel_name
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParseKernelName:
|
||||
def test_standard_name(self):
|
||||
name = "gemm_universal_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_False_128x128x128_1x4x1_16x16x128"
|
||||
result = parse_kernel_name(name)
|
||||
assert result["dtype"] == "fp8"
|
||||
assert result["layout"] == "rcr"
|
||||
assert result["pipeline"] == "compv3"
|
||||
assert result["epilogue"] == "cshuffle"
|
||||
assert result["scheduler"] == "intrawave"
|
||||
assert result["pad_m"] is False
|
||||
assert result["pad_n"] is False
|
||||
assert result["pad_k"] is False
|
||||
assert result["persistent"] is False
|
||||
assert result["tile_m"] == 128
|
||||
assert result["tile_n"] == 128
|
||||
assert result["tile_k"] == 128
|
||||
assert result["warp_m"] == 1
|
||||
assert result["warp_n"] == 4
|
||||
assert result["warp_k"] == 1
|
||||
assert result["warp_tile_m"] == 16
|
||||
assert result["warp_tile_n"] == 16
|
||||
assert result["warp_tile_k"] == 128
|
||||
|
||||
def test_with_padding_and_persistent(self):
|
||||
name = "gemm_universal_fp16_rrr_compv4_default_interwave_True_True_True_True_256x256x64_2x2x1_32x32x16"
|
||||
result = parse_kernel_name(name)
|
||||
assert result["dtype"] == "fp16"
|
||||
assert result["layout"] == "rrr"
|
||||
assert result["pad_m"] is True
|
||||
assert result["pad_n"] is True
|
||||
assert result["pad_k"] is True
|
||||
assert result["persistent"] is True
|
||||
assert result["tile_m"] == 256
|
||||
|
||||
def test_empty_name(self):
|
||||
assert parse_kernel_name("") == {}
|
||||
|
||||
def test_malformed_name(self):
|
||||
assert parse_kernel_name("not_a_kernel_name") == {}
|
||||
|
||||
def test_partial_name(self):
|
||||
result = parse_kernel_name("gemm_universal_fp8_rcr_compv3")
|
||||
assert result.get("dtype") == "fp8"
|
||||
assert result.get("layout") == "rcr"
|
||||
assert "tile_m" not in result # not enough parts
|
||||
|
||||
def test_all_layouts(self):
|
||||
for layout in ["rcr", "rrr", "crr", "ccr"]:
|
||||
name = f"gemm_universal_fp8_{layout}_compv3_cshuffle_intrawave_False_False_False_False_128x128x128_1x4x1_16x16x128"
|
||||
result = parse_kernel_name(name)
|
||||
assert result["layout"] == layout
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _layout_from_problem
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestLayoutFromProblem:
|
||||
def test_rcr(self):
|
||||
assert (
|
||||
_layout_from_problem(
|
||||
{
|
||||
"layout_a": "RowMajor",
|
||||
"layout_b": "ColumnMajor",
|
||||
"layout_c": "RowMajor",
|
||||
}
|
||||
)
|
||||
== "rcr"
|
||||
)
|
||||
|
||||
def test_rrr(self):
|
||||
assert (
|
||||
_layout_from_problem(
|
||||
{"layout_a": "RowMajor", "layout_b": "RowMajor", "layout_c": "RowMajor"}
|
||||
)
|
||||
== "rrr"
|
||||
)
|
||||
|
||||
def test_empty(self):
|
||||
assert _layout_from_problem({}) == "???"
|
||||
|
||||
def test_case_insensitive(self):
|
||||
assert (
|
||||
_layout_from_problem(
|
||||
{
|
||||
"layout_a": "rowmajor",
|
||||
"layout_b": "COLUMNMAJOR",
|
||||
"layout_c": "RowMajor",
|
||||
}
|
||||
)
|
||||
== "rcr"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parse_streaming_log
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
SAMPLE_LOG = """\
|
||||
================================================================================
|
||||
LOG FILE: test.log
|
||||
================================================================================
|
||||
CK Tile Profiling Run
|
||||
GPU ID: 0
|
||||
|
||||
--- Running CK Tile benchmarks on GPU 0 ---
|
||||
|
||||
========================================
|
||||
Shape 1: M=16 N=1536 K=7168 dtype=fp8 layout=rcr
|
||||
========================================
|
||||
Found 2 kernels
|
||||
{
|
||||
"name": "gemm_universal_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_False_128x128x128_1x4x1_16x16x128",
|
||||
"problem": {
|
||||
"split_k":1,
|
||||
"m":16,
|
||||
"n":1536,
|
||||
"k":7168,
|
||||
"stride_a":7168,
|
||||
"stride_b":7168,
|
||||
"stride_c":1536,
|
||||
"dtype_a":"fp8",
|
||||
"dtype_b":"fp8",
|
||||
"dtype_acc":"fp32",
|
||||
"dtype_c":"fp16",
|
||||
"layout_a":"RowMajor",
|
||||
"layout_b":"ColumnMajor",
|
||||
"layout_c":"RowMajor",
|
||||
"structured_sparsity":false
|
||||
},
|
||||
"perf_result": {
|
||||
"latency(ms)": 0.04,
|
||||
"tflops(TFlops)": 8.81,
|
||||
"bandwidth(GB/s)": 279.51
|
||||
}
|
||||
}
|
||||
{
|
||||
"name": "gemm_universal_fp8_rcr_compv4_default_intrawave_False_False_False_False_128x128x64_2x2x1_32x32x16",
|
||||
"problem": {
|
||||
"split_k":1,
|
||||
"m":16,
|
||||
"n":1536,
|
||||
"k":7168,
|
||||
"stride_a":7168,
|
||||
"stride_b":7168,
|
||||
"stride_c":1536,
|
||||
"dtype_a":"fp8",
|
||||
"dtype_b":"fp8",
|
||||
"dtype_acc":"fp32",
|
||||
"dtype_c":"fp16",
|
||||
"layout_a":"RowMajor",
|
||||
"layout_b":"ColumnMajor",
|
||||
"layout_c":"RowMajor",
|
||||
"structured_sparsity":false
|
||||
},
|
||||
"perf_result": {
|
||||
"latency(ms)": 0.05,
|
||||
"tflops(TFlops)": 7.22,
|
||||
"bandwidth(GB/s)": 228.85
|
||||
}
|
||||
}
|
||||
|
||||
========================================
|
||||
Shape 2: M=20480 N=7168 K=256 dtype=fp8 layout=rcr
|
||||
========================================
|
||||
Found 1 kernels
|
||||
{
|
||||
"name": "gemm_universal_fp8_rcr_mem_cshuffle_intrawave_False_False_False_True_64x64x128_1x4x1_16x16x32",
|
||||
"problem": {
|
||||
"split_k":1,
|
||||
"m":20480,
|
||||
"n":7168,
|
||||
"k":256,
|
||||
"stride_a":256,
|
||||
"stride_b":256,
|
||||
"stride_c":7168,
|
||||
"dtype_a":"fp8",
|
||||
"dtype_b":"fp8",
|
||||
"dtype_acc":"fp32",
|
||||
"dtype_c":"fp16",
|
||||
"layout_a":"RowMajor",
|
||||
"layout_b":"ColumnMajor",
|
||||
"layout_c":"RowMajor",
|
||||
"structured_sparsity":false
|
||||
},
|
||||
"perf_result": {
|
||||
"latency(ms)": 0.15,
|
||||
"tflops(TFlops)": 505.00,
|
||||
"bandwidth(GB/s)": 1200.50
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
class TestParseStreamingLog:
|
||||
def _write_log(self, content: str) -> Path:
|
||||
f = tempfile.NamedTemporaryFile(mode="w", suffix=".log", delete=False)
|
||||
f.write(content)
|
||||
f.close()
|
||||
return Path(f.name)
|
||||
|
||||
def test_basic_parse(self):
|
||||
path = self._write_log(SAMPLE_LOG)
|
||||
df = parse_streaming_log(path, arch="gfx950")
|
||||
assert len(df) == 3
|
||||
assert df["arch"].iloc[0] == "gfx950"
|
||||
assert df["m"].tolist() == [16, 16, 20480]
|
||||
assert df["n"].tolist() == [1536, 1536, 7168]
|
||||
assert df["k"].tolist() == [7168, 7168, 256]
|
||||
|
||||
def test_tflops_values(self):
|
||||
path = self._write_log(SAMPLE_LOG)
|
||||
df = parse_streaming_log(path)
|
||||
assert df["measured_tflops"].tolist() == pytest.approx([8.81, 7.22, 505.0])
|
||||
|
||||
def test_kernel_config_parsed(self):
|
||||
path = self._write_log(SAMPLE_LOG)
|
||||
df = parse_streaming_log(path)
|
||||
assert df["tile_m"].iloc[0] == 128
|
||||
assert df["pipeline"].iloc[0] == "compv3"
|
||||
assert df["pipeline"].iloc[1] == "compv4"
|
||||
|
||||
def test_layout_derived_from_json(self):
|
||||
path = self._write_log(SAMPLE_LOG)
|
||||
df = parse_streaming_log(path)
|
||||
assert all(df["layout"] == "rcr")
|
||||
|
||||
def test_empty_log(self):
|
||||
path = self._write_log("No shapes here\nJust noise\n")
|
||||
df = parse_streaming_log(path)
|
||||
assert len(df) == 0
|
||||
for col in CANONICAL_COLUMNS:
|
||||
assert col in df.columns
|
||||
|
||||
def test_single_kernel(self):
|
||||
log = """\
|
||||
Shape 1: M=1 N=1 K=1 dtype=fp8 layout=rcr
|
||||
{
|
||||
"name": "gemm_universal_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_False_128x128x128_1x4x1_16x16x128",
|
||||
"problem": {"split_k":1, "m":1, "n":1, "k":1, "dtype_a":"fp8", "dtype_b":"fp8", "layout_a":"RowMajor", "layout_b":"ColumnMajor", "layout_c":"RowMajor"},
|
||||
"perf_result": {"latency(ms)": 0.001, "tflops(TFlops)": 0.002, "bandwidth(GB/s)": 0.01}
|
||||
}
|
||||
"""
|
||||
path = self._write_log(log)
|
||||
df = parse_streaming_log(path)
|
||||
assert len(df) == 1
|
||||
        assert df["m"].iloc[0] == 1
        assert bool(df["is_valid"].iloc[0]) is True

    def test_zero_tflops_marked_invalid(self):
        log = """\
Shape 1: M=16 N=16 K=16 dtype=fp8 layout=rcr
{
"name": "test_kernel",
"problem": {"split_k":1, "m":16, "n":16, "k":16, "dtype_a":"fp8"},
"perf_result": {"latency(ms)": 0.0, "tflops(TFlops)": 0.0, "bandwidth(GB/s)": 0.0}
}
"""
        path = self._write_log(log)
        df = parse_streaming_log(path)
        assert len(df) == 1
        assert bool(df["is_valid"].iloc[0]) is False

    def test_malformed_json_skipped(self):
        log = """\
Shape 1: M=16 N=16 K=16 dtype=fp8 layout=rcr
{
"name": "good_kernel",
"problem": {"split_k":1, "m":16, "n":16, "k":16, "dtype_a":"fp8"},
"perf_result": {"latency(ms)": 0.01, "tflops(TFlops)": 1.0, "bandwidth(GB/s)": 10.0}
}
{ this is not valid json }
{
"name": "another_good",
"problem": {"split_k":1, "m":16, "n":16, "k":16, "dtype_a":"fp8"},
"perf_result": {"latency(ms)": 0.02, "tflops(TFlops)": 2.0, "bandwidth(GB/s)": 20.0}
}
"""
        path = self._write_log(log)
        df = parse_streaming_log(path)
        assert len(df) == 2

    def test_extreme_shapes(self):
        """Tiny M=1 (single token) and very large M=20480."""
        path = self._write_log(SAMPLE_LOG)
        df = parse_streaming_log(path)
        assert 1 not in df["m"].values  # sample has M=16, M=20480
        assert 16 in df["m"].values
        assert 20480 in df["m"].values

    def test_run_id_assigned(self):
        path = self._write_log(SAMPLE_LOG)
        df = parse_streaming_log(path, run_id="test_run_123")
        assert all(df["run_id"] == "test_run_123")

    def test_op_type_assigned(self):
        path = self._write_log(SAMPLE_LOG)
        df = parse_streaming_log(path, op_type="gemm_streamk")
        assert all(df["op_type"] == "gemm_streamk")


# ---------------------------------------------------------------------------
# Parquet round-trip
# ---------------------------------------------------------------------------


class TestParquetIO:
    def test_round_trip(self, tmp_path):
        df = pd.DataFrame(
            {
                "m": [16, 32],
                "n": [1536, 1536],
                "k": [7168, 7168],
                "measured_tflops": [8.81, 15.0],
            }
        )
        path = tmp_path / "test.parquet"
        save_parquet(df, path)
        loaded = load_parquet(path)
        assert len(loaded) == 2
        assert loaded["m"].tolist() == [16, 32]

    def test_creates_parent_dirs(self, tmp_path):
        path = tmp_path / "sub" / "dir" / "test.parquet"
        df = pd.DataFrame({"x": [1]})
        save_parquet(df, path)
        assert path.exists()


if __name__ == "__main__":
    pytest.main([__file__, "-v"])
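For context, the streaming-log format these tests exercise interleaves `Shape N: ...` header lines with one multi-line JSON record per benchmarked kernel. The PR's actual `parse_streaming_log` is not shown in this chunk; the following is a hypothetical minimal sketch of a parser for the same format (the function name `parse_log_text` and its row schema are illustrative, not the PR's API):

```python
import json
import re

# Matches the shape header lines shown in the tests, e.g.
# "Shape 1: M=16 N=16 K=16 dtype=fp8 layout=rcr"
SHAPE_RE = re.compile(r"Shape \d+: M=(\d+) N=(\d+) K=(\d+) dtype=(\w+) layout=(\w+)")


def parse_log_text(text):
    """Parse a streaming benchmark log into a list of row dicts.

    A shape header sets the current problem; each brace-balanced {...} block
    is one kernel record. Malformed JSON blocks are skipped; zero-TFLOPS
    records are kept but flagged invalid, mirroring the test expectations.
    """
    rows = []
    shape = {}
    buf = []
    depth = 0
    for line in text.splitlines():
        header = SHAPE_RE.match(line)
        if header:
            m, n, k, dtype, layout = header.groups()
            shape = {"m": int(m), "n": int(n), "k": int(k),
                     "dtype": dtype, "layout": layout}
            continue
        if not line.strip():
            continue
        buf.append(line)
        depth += line.count("{") - line.count("}")
        if depth == 0:
            blob, buf = "\n".join(buf), []
            try:
                rec = json.loads(blob)
            except json.JSONDecodeError:
                continue  # skip malformed JSON blocks
            tflops = rec.get("perf_result", {}).get("tflops(TFlops)", 0.0)
            rows.append({**shape, "name": rec.get("name"),
                         "measured_tflops": tflops, "is_valid": tflops > 0.0})
    return rows
```

Returning plain dicts keeps the sketch dependency-free; the real pipeline builds a pandas DataFrame with additional columns such as `run_id` and `op_type`.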
264
dispatcher/heuristics/tests/test_dispatcher_integration.py
Normal file
@@ -0,0 +1,264 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT

"""
Tests for dispatcher_integration.py.

Covers: kernel name parsing to feature dict, feature dict to dispatcher config
(name mapping inversion), MLKernelSpec creation, binary pool loading, and
the ML heuristic function.
"""

import json
import sys
from pathlib import Path

import lightgbm as lgb
import numpy as np
import pytest

sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

from dispatcher_integration import (
    kernel_config_to_feature_dict,
    feature_dict_to_dispatcher_config,
    feature_dict_to_ml_spec,
    ml_spec_to_dispatcher_config,
    create_ml_heuristic,
    load_kernel_pool_from_binaries,
    MLKernelSpec,
    LAYOUT_TO_DISPATCHER,
)
from feature_engine import GemmUniversalFeatureEngine


SAMPLE_KERNEL_NAME = (
    "gemm_universal_fp8_rcr_compv3_cshuffle_intrawave"
    "_False_False_False_False_128x128x128_1x4x1_16x16x128"
)


# ---------------------------------------------------------------------------
# kernel_config_to_feature_dict
# ---------------------------------------------------------------------------


class TestKernelConfigToFeatureDict:
    def test_parses_standard_name(self):
        feat = kernel_config_to_feature_dict(SAMPLE_KERNEL_NAME)
        assert feat["tile_m"] == 128
        assert feat["tile_n"] == 128
        assert feat["tile_k"] == 128
        assert feat["warp_m"] == 1  # warps per block
        assert feat["warp_n"] == 4
        assert feat["warp_k"] == 1
        assert feat["warp_tile_m"] == 16
        assert feat["warp_tile_n"] == 16
        assert feat["warp_tile_k"] == 128
        assert feat["pipeline"] == "compv3"
        assert feat["scheduler"] == "intrawave"
        assert feat["epilogue"] == "cshuffle"
        assert feat["kernel_name"] == SAMPLE_KERNEL_NAME

    def test_empty_name_returns_empty(self):
        assert kernel_config_to_feature_dict("") == {}

    def test_invalid_name_returns_empty(self):
        assert kernel_config_to_feature_dict("not_a_kernel") == {}


# ---------------------------------------------------------------------------
# Name mapping: feature dict <-> dispatcher config
# ---------------------------------------------------------------------------


class TestNameMapping:
    """The critical inversion: feature engine warp_m/n/k (warps per block)
    maps to dispatcher wave_m/n/k, and feature engine warp_tile_m/n/k
    maps to dispatcher warp_m/n/k."""

    def test_warp_to_wave_mapping(self):
        feat = kernel_config_to_feature_dict(SAMPLE_KERNEL_NAME)
        disp = feature_dict_to_dispatcher_config(feat)
        assert disp["wave_m"] == feat["warp_m"]  # 1
        assert disp["wave_n"] == feat["warp_n"]  # 4
        assert disp["wave_k"] == feat["warp_k"]  # 1

    def test_warp_tile_to_warp_mapping(self):
        feat = kernel_config_to_feature_dict(SAMPLE_KERNEL_NAME)
        disp = feature_dict_to_dispatcher_config(feat)
        assert disp["warp_m"] == feat["warp_tile_m"]  # 16
        assert disp["warp_n"] == feat["warp_tile_n"]  # 16
        assert disp["warp_k"] == feat["warp_tile_k"]  # 128

    def test_tile_dims_pass_through(self):
        feat = kernel_config_to_feature_dict(SAMPLE_KERNEL_NAME)
        disp = feature_dict_to_dispatcher_config(feat)
        assert disp["tile_m"] == 128
        assert disp["tile_n"] == 128
        assert disp["tile_k"] == 128

    def test_pipeline_passes_through(self):
        feat = kernel_config_to_feature_dict(SAMPLE_KERNEL_NAME)
        disp = feature_dict_to_dispatcher_config(feat)
        assert disp["pipeline"] == "compv3"
        assert disp["scheduler"] == "intrawave"
        assert disp["epilogue"] == "cshuffle"

    def test_rcr_layout_mapping(self):
        feat = kernel_config_to_feature_dict(SAMPLE_KERNEL_NAME)
        disp = feature_dict_to_dispatcher_config(feat, dtype="fp8")
        assert disp["layout_a"] == "row"
        assert disp["layout_b"] == "col"
        assert disp["layout_c"] == "row"

    def test_all_layouts(self):
        for layout, (la, lb, lc) in LAYOUT_TO_DISPATCHER.items():
            feat = {"layout": layout, "tile_m": 128}
            disp = feature_dict_to_dispatcher_config(feat)
            assert disp["layout_a"] == la
            assert disp["layout_b"] == lb
            assert disp["layout_c"] == lc


# ---------------------------------------------------------------------------
# MLKernelSpec
# ---------------------------------------------------------------------------


class TestMLKernelSpec:
    def test_from_feature_dict(self):
        feat = kernel_config_to_feature_dict(SAMPLE_KERNEL_NAME)
        spec = feature_dict_to_ml_spec(feat, predicted_tflops=123.4)
        assert spec.kernel_name == SAMPLE_KERNEL_NAME
        assert spec.predicted_tflops == 123.4
        assert spec.tile_m == 128
        assert spec.wave_m == 1  # was warp_m in feature space
        assert spec.warp_m == 16  # was warp_tile_m in feature space

    def test_spec_to_dispatcher_config(self):
        feat = kernel_config_to_feature_dict(SAMPLE_KERNEL_NAME)
        spec = feature_dict_to_ml_spec(feat, 100.0)
        disp = ml_spec_to_dispatcher_config(spec, dtype="fp8", arch="gfx950")
        assert disp["tile_m"] == 128
        assert disp["wave_m"] == 1
        assert disp["warp_m"] == 16
        assert disp["gfx_arch"] == "gfx950"
        assert disp["dtype_a"] == "fp8"

    def test_roundtrip_preserves_values(self):
        """feature_dict -> MLKernelSpec -> dispatcher_config should be consistent."""
        feat = kernel_config_to_feature_dict(SAMPLE_KERNEL_NAME)
        spec = feature_dict_to_ml_spec(feat, 0.0)
        disp_from_spec = ml_spec_to_dispatcher_config(spec)
        disp_from_feat = feature_dict_to_dispatcher_config(feat)
        for key in [
            "tile_m",
            "tile_n",
            "tile_k",
            "wave_m",
            "wave_n",
            "wave_k",
            "warp_m",
            "warp_n",
            "warp_k",
            "pipeline",
            "scheduler",
            "epilogue",
        ]:
            assert disp_from_spec[key] == disp_from_feat[key], f"Mismatch on {key}"


# ---------------------------------------------------------------------------
# Binary pool loading
# ---------------------------------------------------------------------------


class TestLoadKernelPool:
    def test_loads_from_real_bin_dir(self):
        bin_dir = Path("/workspace/ck_tile/bin")
        if not bin_dir.exists():
            pytest.skip("No /workspace/ck_tile/bin")
        pool = load_kernel_pool_from_binaries(bin_dir)
        assert len(pool) > 0
        assert "tile_m" in pool[0]
        assert "kernel_name" in pool[0]

    def test_empty_dir_returns_empty(self, tmp_path):
        pool = load_kernel_pool_from_binaries(tmp_path)
        assert pool == []


# ---------------------------------------------------------------------------
# ML heuristic function
# ---------------------------------------------------------------------------


class TestCreateMLHeuristic:
    @pytest.fixture
    def mock_model_dir(self, tmp_path):
        """Create a minimal model for testing the heuristic flow."""
        fe = GemmUniversalFeatureEngine()
        n_features = len(fe.get_feature_names())
        np.random.seed(42)
        X = np.random.rand(100, n_features)
        y = np.random.rand(100) * 500
        model = lgb.LGBMRegressor(n_estimators=5, verbose=-1)
        model.fit(X, y)
        model.booster_.save_model(str(tmp_path / "model_tflops.lgbm"))
        spec = {
            "feature_names": fe.get_feature_names(),
            "categorical_features": fe.get_categorical_features(),
        }
        with open(tmp_path / "feature_spec.json", "w") as f:
            json.dump(spec, f)
        return tmp_path

    def _make_pool(self):
        """Create a small synthetic kernel pool."""
        names = [
            "gemm_universal_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_False_128x128x128_1x4x1_16x16x128",
            "gemm_universal_fp8_rcr_compv4_default_intrawave_False_False_False_False_128x128x64_2x2x1_32x32x16",
            "gemm_universal_fp8_rcr_mem_cshuffle_interwave_False_False_False_False_64x64x128_1x4x1_16x16x32",
        ]
        return [kernel_config_to_feature_dict(n) for n in names]

    def test_returns_ml_kernel_spec(self, mock_model_dir):
        pool = self._make_pool()
        heuristic = create_ml_heuristic(mock_model_dir, kernel_pool=pool)
        result = heuristic(1024, 1024, 1024)
        assert isinstance(result, MLKernelSpec)
        assert result.tile_m > 0
        assert isinstance(result.predicted_tflops, float)

    def test_returns_valid_kernel_from_pool(self, mock_model_dir):
        pool = self._make_pool()
        pool_names = {p["kernel_name"] for p in pool}
        heuristic = create_ml_heuristic(mock_model_dir, kernel_pool=pool)
        result = heuristic(1024, 1024, 1024)
        assert result.kernel_name in pool_names

    def test_different_shapes_may_select_different_kernels(self, mock_model_dir):
        pool = self._make_pool()
        heuristic = create_ml_heuristic(mock_model_dir, kernel_pool=pool)
        r1 = heuristic(16, 1536, 7168)
        r2 = heuristic(8192, 8192, 256)
        # At minimum both should return valid specs
        assert r1.tile_m > 0
        assert r2.tile_m > 0

    def test_m1_corner_case(self, mock_model_dir):
        pool = self._make_pool()
        heuristic = create_ml_heuristic(mock_model_dir, kernel_pool=pool)
        result = heuristic(1, 4096, 4096)
        assert isinstance(result, MLKernelSpec)
        assert np.isfinite(result.predicted_tflops)

    def test_empty_pool_raises(self, mock_model_dir):
        with pytest.raises(ValueError, match="No kernel configs"):
            create_ml_heuristic(mock_model_dir, kernel_pool=[])


if __name__ == "__main__":
    pytest.main([__file__, "-v"])
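The kernel names above follow a fixed grammar: `gemm_universal_{dtype}_{layout}_{pipeline}_{epilogue}_{scheduler}` followed by four boolean flags and three `MxNxK` triples (block tile, warps per block, warp tile). The tests do not show how the PR's `kernel_config_to_feature_dict` is implemented; below is a hypothetical regex-based sketch of a parser for that grammar. The mapping of the four booleans to `pad_m`/`pad_n`/`pad_k`/`persistent` is an assumption inferred from the kernel-config helpers elsewhere in this PR, not confirmed by these tests:

```python
import re

# Grammar sketch for universal-GEMM kernel names. The order of the four
# boolean flags (assumed: pad_m, pad_n, pad_k, persistent) is a guess.
KERNEL_RE = re.compile(
    r"gemm_universal"
    r"_(?P<dtype>[a-z0-9]+)_(?P<layout>[a-z0-9]+)"
    r"_(?P<pipeline>[a-z0-9]+)_(?P<epilogue>[a-z0-9]+)_(?P<scheduler>[a-z0-9]+)"
    r"_(?P<pad_m>True|False)_(?P<pad_n>True|False)"
    r"_(?P<pad_k>True|False)_(?P<persistent>True|False)"
    r"_(?P<tile_m>\d+)x(?P<tile_n>\d+)x(?P<tile_k>\d+)"
    r"_(?P<warp_m>\d+)x(?P<warp_n>\d+)x(?P<warp_k>\d+)"
    r"_(?P<warp_tile_m>\d+)x(?P<warp_tile_n>\d+)x(?P<warp_tile_k>\d+)$"
)


def parse_kernel_name(name):
    """Return a feature dict for a kernel name, or {} if it does not match."""
    m = KERNEL_RE.match(name)
    if not m:
        return {}
    feat = {"kernel_name": name}
    for key, val in m.groupdict().items():
        if val.isdigit():
            feat[key] = int(val)        # tile / warp / warp-tile dimensions
        elif val in ("True", "False"):
            feat[key] = val == "True"   # boolean flags
        else:
            feat[key] = val             # dtype, layout, pipeline, etc.
    return feat
```

Returning `{}` for non-matching names mirrors the `test_empty_name_returns_empty` and `test_invalid_name_returns_empty` behavior above.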
55
dispatcher/heuristics/tests/test_evaluate.py
Normal file
@@ -0,0 +1,55 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT

"""
Tests for evaluate.py.

Covers: shape family classification, K-depth regime classification,
and basic evaluation metric checks.
"""

import sys
from pathlib import Path

import pytest

sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

from evaluate import classify_shape_family, classify_k_regime


class TestClassifyShapeFamily:
    def test_tiny_m(self):
        assert classify_shape_family(1, 4096, 4096) == "tiny_m"
        assert classify_shape_family(16, 1536, 7168) == "tiny_m"

    def test_small_m(self):
        assert classify_shape_family(32, 1536, 7168) == "small_m"
        assert classify_shape_family(128, 4096, 4096) == "small_m"

    def test_medium_m(self):
        assert classify_shape_family(256, 1024, 1024) == "medium_m"
        assert classify_shape_family(2048, 2048, 2048) == "medium_m"

    def test_large_m(self):
        assert classify_shape_family(4096, 4096, 4096) == "large_m"
        assert classify_shape_family(20480, 7168, 256) == "large_m"


class TestClassifyKRegime:
    def test_shallow(self):
        assert classify_k_regime(256) == "shallow_k"
        assert classify_k_regime(32) == "shallow_k"

    def test_medium(self):
        assert classify_k_regime(1024) == "medium_k"
        assert classify_k_regime(2048) == "medium_k"

    def test_deep(self):
        assert classify_k_regime(4096) == "deep_k"
        assert classify_k_regime(7168) == "deep_k"


if __name__ == "__main__":
    pytest.main([__file__, "-v"])
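The classification functions under test here bucket GEMM problems by M (roughly, token count) and by K depth. The actual thresholds live in `evaluate.py`, which is not shown in this chunk; the sketch below uses the loosest thresholds consistent with every assertion in the tests above (e.g. M=16 is tiny while M=32 is small, K=2048 is medium while K=4096 is deep), so the real boundaries may differ within those gaps:

```python
def classify_shape_family(m, n, k):
    """Bucket a GEMM problem by M. Thresholds are inferred from the tests:
    1, 16 -> tiny; 32, 128 -> small; 256, 2048 -> medium; 4096+ -> large."""
    if m < 32:
        return "tiny_m"
    if m < 256:
        return "small_m"
    if m < 4096:
        return "medium_m"
    return "large_m"


def classify_k_regime(k):
    """Bucket by K depth. Shallow K is memory-bound, deep K compute-bound;
    boundaries inferred from the tests (1024 is medium, 4096 is deep)."""
    if k < 1024:
        return "shallow_k"
    if k < 4096:
        return "medium_k"
    return "deep_k"
```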
409
dispatcher/heuristics/tests/test_feature_engine.py
Normal file
@@ -0,0 +1,409 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT

"""
Tests for feature_engine.py.

Covers: feature count consistency, formula correctness (tile efficiency, LDS,
arithmetic intensity), corner-case shapes (M=1, huge M, square, skinny-K),
parameter space validity, config validation, and batch vs single extraction parity.
"""

import sys
from pathlib import Path

import numpy as np
import pandas as pd
import pytest

sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

from feature_engine import (
    GemmUniversalFeatureEngine,
)


@pytest.fixture
def fe():
    """Default feature engine with MI355X-like hardware."""
    return GemmUniversalFeatureEngine(
        num_cus=256,
        lds_capacity=65536,
        max_clock_mhz=2400,
        simds_per_cu=4,
        shader_engines=32,
        max_waves_per_cu=32,
        wavefront_size=64,
        l1_cache_kb=32,
        l2_cache_kb=4096,
        l3_cache_kb=262144,
        num_xcd=8,
    )


def _make_problem(m=1024, n=1024, k=1024, dtype="fp8", layout="rcr", split_k=1):
    return {
        "m": m,
        "n": n,
        "k": k,
        "dtype": dtype,
        "layout": layout,
        "split_k": split_k,
    }


def _make_kernel(
    tile_m=128,
    tile_n=128,
    tile_k=64,
    warp_m=2,
    warp_n=2,
    warp_k=1,
    warp_tile_m=32,
    warp_tile_n=32,
    warp_tile_k=16,
    pipeline="compv3",
    scheduler="intrawave",
    epilogue="cshuffle",
    pad_m=False,
    pad_n=False,
    pad_k=False,
    persistent=False,
):
    return {
        "tile_m": tile_m,
        "tile_n": tile_n,
        "tile_k": tile_k,
        "warp_m": warp_m,
        "warp_n": warp_n,
        "warp_k": warp_k,
        "warp_tile_m": warp_tile_m,
        "warp_tile_n": warp_tile_n,
        "warp_tile_k": warp_tile_k,
        "pipeline": pipeline,
        "scheduler": scheduler,
        "epilogue": epilogue,
        "pad_m": pad_m,
        "pad_n": pad_n,
        "pad_k": pad_k,
        "persistent": persistent,
    }


# ---------------------------------------------------------------------------
# Basic consistency
# ---------------------------------------------------------------------------


class TestFeatureConsistency:
    def test_feature_count_matches_names(self, fe):
        prob = _make_problem()
        kern = _make_kernel()
        vec = fe.extract(prob, kern)
        assert len(vec) == len(fe.get_feature_names())

    def test_feature_count_is_72(self, fe):
        assert len(fe.get_feature_names()) == 72

    def test_no_nan_in_output(self, fe):
        prob = _make_problem()
        kern = _make_kernel()
        vec = fe.extract(prob, kern)
        assert not np.any(np.isnan(vec))

    def test_no_inf_in_output(self, fe):
        prob = _make_problem()
        kern = _make_kernel()
        vec = fe.extract(prob, kern)
        assert not np.any(np.isinf(vec))

    def test_categorical_features_in_names(self, fe):
        names = fe.get_feature_names()
        for cat in fe.get_categorical_features():
            assert cat in names


# ---------------------------------------------------------------------------
# Formula correctness
# ---------------------------------------------------------------------------


class TestTileEfficiency:
    """Tile efficiency: fraction of the last tile that is useful work."""

    def test_perfectly_divisible(self, fe):
        prob = _make_problem(m=256, n=256, k=128)
        kern = _make_kernel(tile_m=128, tile_n=128, tile_k=64)
        vec = fe.extract(prob, kern)
        names = fe.get_feature_names()
        assert vec[names.index("tile_eff_m")] == 1.0
        assert vec[names.index("tile_eff_n")] == 1.0
        assert vec[names.index("tile_eff_k")] == 1.0
        assert vec[names.index("overall_tile_efficiency")] == 1.0

    def test_not_divisible(self, fe):
        prob = _make_problem(m=100, n=100, k=100)
        kern = _make_kernel(tile_m=128, tile_n=128, tile_k=64)
        vec = fe.extract(prob, kern)
        names = fe.get_feature_names()
        assert vec[names.index("tile_eff_m")] == pytest.approx(100 / 128)
        assert vec[names.index("tile_eff_n")] == pytest.approx(100 / 128)
        assert vec[names.index("tile_eff_k")] == pytest.approx(36 / 64)

    def test_m_equals_1(self, fe):
        """Single-token inference: M=1, tile_m=128 => eff = 1/128."""
        prob = _make_problem(m=1)
        kern = _make_kernel(tile_m=128)
        vec = fe.extract(prob, kern)
        names = fe.get_feature_names()
        assert vec[names.index("tile_eff_m")] == pytest.approx(1.0 / 128.0)


class TestLDSUsage:
    def test_lds_formula(self, fe):
        prob = _make_problem(dtype="fp8")
        kern = _make_kernel(tile_m=128, tile_n=128, tile_k=64)
        vec = fe.extract(prob, kern)
        names = fe.get_feature_names()
        expected = (128 * 64 + 128 * 64) * 1.0  # fp8 = 1 byte
        assert vec[names.index("lds_usage_estimate")] == expected

    def test_lds_ratio_compv4(self, fe):
        """compv4 has 32KB LDS limit, not 64KB."""
        prob = _make_problem(dtype="fp8")
        kern = _make_kernel(tile_m=128, tile_n=128, tile_k=64, pipeline="compv4")
        vec = fe.extract(prob, kern)
        names = fe.get_feature_names()
        lds_est = (128 * 64 + 128 * 64) * 1.0
        assert vec[names.index("lds_usage_ratio")] == pytest.approx(lds_est / 32768)

    def test_lds_fp16_doubles(self, fe):
        prob = _make_problem(dtype="fp16")
        kern = _make_kernel(tile_m=128, tile_n=128, tile_k=64)
        vec = fe.extract(prob, kern)
        names = fe.get_feature_names()
        expected = (128 * 64 + 128 * 64) * 2.0  # fp16 = 2 bytes
        assert vec[names.index("lds_usage_estimate")] == expected


class TestArithmeticIntensity:
    def test_square_shape(self, fe):
        M, N, K = 1024, 1024, 1024
        prob = _make_problem(m=M, n=N, k=K, dtype="fp8")
        kern = _make_kernel()
        vec = fe.extract(prob, kern)
        names = fe.get_feature_names()
        mem = (M * K + K * N + M * N) * 1.0
        expected = (2.0 * M * N * K) / mem
        assert vec[names.index("arithmetic_intensity")] == pytest.approx(expected)

    def test_skinny_k(self, fe):
        """Small K => low arithmetic intensity (memory-bound)."""
        prob = _make_problem(m=8192, n=8192, k=32, dtype="fp8")
        kern = _make_kernel()
        vec = fe.extract(prob, kern)
        names = fe.get_feature_names()
        assert vec[names.index("arithmetic_intensity")] < 100

    def test_deep_k(self, fe):
        """Large K => high arithmetic intensity (compute-bound)."""
        prob = _make_problem(m=256, n=256, k=8192, dtype="fp8")
        kern = _make_kernel()
        vec = fe.extract(prob, kern)
        names = fe.get_feature_names()
        assert vec[names.index("arithmetic_intensity")] > 100


# ---------------------------------------------------------------------------
# Corner-case shapes
# ---------------------------------------------------------------------------


class TestCornerCaseShapes:
    def test_m1_single_token(self, fe):
        vec = fe.extract(_make_problem(m=1, n=4096, k=4096), _make_kernel())
        assert not np.any(np.isnan(vec))

    def test_m1_n1_k1_minimum(self, fe):
        vec = fe.extract(_make_problem(m=1, n=1, k=1), _make_kernel())
        assert not np.any(np.isnan(vec))
        assert not np.any(np.isinf(vec))

    def test_very_large_m(self, fe):
        vec = fe.extract(_make_problem(m=20480, n=7168, k=7168), _make_kernel())
        assert not np.any(np.isnan(vec))

    def test_non_power_of_2(self, fe):
        vec = fe.extract(_make_problem(m=1536, n=7168, k=2304), _make_kernel())
        assert not np.any(np.isnan(vec))

    def test_prime_dimensions(self, fe):
        vec = fe.extract(_make_problem(m=17, n=31, k=127), _make_kernel())
        assert not np.any(np.isnan(vec))

    def test_tall_matrix(self, fe):
        """M >> N (tall matrix)."""
        prob = _make_problem(m=16384, n=64, k=1024)
        vec = fe.extract(prob, _make_kernel())
        names = fe.get_feature_names()
        assert vec[names.index("aspect_ratio_mn")] > 100

    def test_wide_matrix(self, fe):
        """N >> M (wide matrix)."""
        prob = _make_problem(m=64, n=16384, k=1024)
        vec = fe.extract(prob, _make_kernel())
        names = fe.get_feature_names()
        assert vec[names.index("aspect_ratio_mn")] < 0.01


# ---------------------------------------------------------------------------
# Batch vs single extraction parity
# ---------------------------------------------------------------------------


class TestBatchParity:
    def test_batch_matches_single(self, fe):
        """Vectorized batch should produce identical results to row-by-row."""
        rows = [
            {
                "m": 16,
                "n": 1536,
                "k": 7168,
                "split_k": 1,
                "dtype": "fp8",
                "layout": "rcr",
                "tile_m": 128,
                "tile_n": 128,
                "tile_k": 128,
                "warp_m": 1,
                "warp_n": 4,
                "warp_k": 1,
                "warp_tile_m": 16,
                "warp_tile_n": 16,
                "warp_tile_k": 128,
                "pipeline": "compv3",
                "scheduler": "intrawave",
                "epilogue": "cshuffle",
                "pad_m": False,
                "pad_n": False,
                "pad_k": False,
                "persistent": False,
            },
            {
                "m": 20480,
                "n": 7168,
                "k": 256,
                "split_k": 1,
                "dtype": "fp8",
                "layout": "rcr",
                "tile_m": 64,
                "tile_n": 64,
                "tile_k": 128,
                "warp_m": 2,
                "warp_n": 2,
                "warp_k": 1,
                "warp_tile_m": 32,
                "warp_tile_n": 32,
                "warp_tile_k": 16,
                "pipeline": "mem",
                "scheduler": "interwave",
                "epilogue": "default",
                "pad_m": True,
                "pad_n": True,
                "pad_k": True,
                "persistent": True,
            },
        ]
        df = pd.DataFrame(rows)
        batch_result = fe.extract_batch(df)

        for i, row_dict in enumerate(rows):
            single_result = fe.extract(row_dict, row_dict)
            np.testing.assert_allclose(
                batch_result[i],
                single_result,
                rtol=1e-5,
                atol=1e-5,
                err_msg=f"Mismatch at row {i}",
            )


# ---------------------------------------------------------------------------
# Parameter space and validation
# ---------------------------------------------------------------------------


class TestParameterSpace:
    def test_parameter_space_non_empty(self, fe):
        ps = fe.get_parameter_space()
        assert len(ps) > 0
        assert "tile_m" in ps
        assert "pipeline" in ps

    def test_valid_config_passes(self, fe):
        config = {
            "tile_m": 128,
            "tile_n": 128,
            "tile_k": 64,
            "warp_m": 2,
            "warp_n": 2,
            "warp_k": 1,
            "pipeline": "compv3",
            "scheduler": "intrawave",
            "epilogue": "cshuffle",
            "pad_m": False,
            "pad_n": False,
            "pad_k": False,
            "persistent": False,
        }
        assert fe.validate_config(config) is True

    def test_invalid_tile_rejected(self, fe):
        config = {"tile_m": 999}
        assert fe.validate_config(config) is False

    def test_lds_constraint_rejects_huge_tile(self, fe):
        config = {
            "tile_m": 256,
            "tile_n": 256,
            "tile_k": 256,
            "warp_m": 2,
            "warp_n": 2,
            "warp_k": 1,
            "pipeline": "compv4",
        }
        assert fe.validate_config(config) is False

    def test_project_to_valid_snaps(self, fe):
        config = {"tile_m": 100, "tile_n": 200, "pipeline": "compv3"}
        projected = fe.project_to_valid(config)
        assert projected["tile_m"] == 128
        assert projected["tile_n"] == 192
        assert projected["pipeline"] == "compv3"


# ---------------------------------------------------------------------------
# Hardware features
# ---------------------------------------------------------------------------


class TestHardwareFeatures:
    def test_hardware_values_propagated(self, fe):
        vec = fe.extract(_make_problem(), _make_kernel())
        names = fe.get_feature_names()
        assert vec[names.index("hw_num_cus")] == 256
        assert vec[names.index("hw_max_clock_mhz")] == 2400
        assert vec[names.index("hw_total_simds")] == 256 * 4
        assert vec[names.index("hw_num_xcd")] == 8

    def test_different_hardware(self):
        fe_small = GemmUniversalFeatureEngine(num_cus=120, max_clock_mhz=1800)
        vec = fe_small.extract(_make_problem(), _make_kernel())
        names = fe_small.get_feature_names()
        assert vec[names.index("hw_num_cus")] == 120
        assert vec[names.index("hw_max_clock_mhz")] == 1800


if __name__ == "__main__":
    pytest.main([__file__, "-v"])
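The two formulas these tests pin down can be stated standalone. Tile efficiency is the useful fraction of the last (partial) tile along a dimension, and the LDS estimate is one A tile plus one B tile times bytes per element; both appear again in `_compute_features_manually` below. A minimal restatement:

```python
def tile_efficiency(dim, tile):
    """Fraction of the last tile along one dimension that is useful work.
    E.g. dim=100, tile=64 -> tiles cover [0,64) fully and [64,128) only up
    to 100, so the last tile is 36/64 useful. Exact multiples give 1.0."""
    if tile <= 0:
        return 1.0
    r = dim % tile
    return r / tile if r > 0 else 1.0


def lds_usage_bytes(tile_m, tile_n, tile_k, bytes_per_elem):
    """LDS footprint estimate: one A tile (tile_m x tile_k) plus one
    B tile (tile_n x tile_k), scaled by element size (fp8=1, fp16=2)."""
    return (tile_m * tile_k + tile_n * tile_k) * bytes_per_elem
```

Note that the feature engine's `overall_tile_efficiency` is simply the product of the three per-dimension efficiencies, and the `lds_usage_ratio` divides this estimate by the pipeline's capacity (32768 for compv4, the hardware LDS capacity otherwise).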
357
dispatcher/heuristics/tests/test_feature_parity.py
Normal file
@@ -0,0 +1,357 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT

"""
Test that the C++ extract_features() in ml_heuristic.hpp produces identical
values to the Python GemmUniversalFeatureEngine.extract().

This test uses ctypes to call the C++ feature extraction compiled into a
small shared library, then compares against Python output. If compilation
fails (no HIP/ROCm), it falls back to verifying the Python feature engine
against manually computed expected values for specific test cases.
"""

import math
import sys
from pathlib import Path

import numpy as np
import pytest

sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from feature_engine import (
    GemmUniversalFeatureEngine,
    PIPELINE_MAP,
    SCHEDULER_MAP,
    EPILOGUE_MAP,
    LAYOUT_MAP,
)


def _compute_features_manually(
    M,
    N,
    K,
    split_k,
    dtype,
    layout,
    tile_m,
    tile_n,
    tile_k,
    warp_m,
    warp_n,
    warp_k,
    warp_tile_m,
    warp_tile_n,
    warp_tile_k,
    pipeline,
    scheduler,
    epilogue,
    pad_m,
    pad_n,
    pad_k,
    persistent,
    hw,
):
    """Recompute features independently to verify the Python engine."""
    bpe_map = {"fp8": 1.0, "fp16": 2.0, "bf16": 2.0, "fp32": 4.0}
    bpe = bpe_map.get(dtype, 1.0)

    log2_M = math.log2(max(M, 1))
    log2_N = math.log2(max(N, 1))
    log2_K = math.log2(max(K, 1))
    log2_MNK = math.log2(max(M * N * K, 1))
    mem = (M * K + K * N + M * N) * bpe
    ai = (2.0 * M * N * K) / max(mem, 1)

    lds_est = (tile_m * tile_k + tile_n * tile_k) * bpe
    lds_cap = 32768 if pipeline == "compv4" else hw["lds_capacity"]

    ntm = math.ceil(M / max(tile_m, 1))
    ntn = math.ceil(N / max(tile_n, 1))
    ntk = math.ceil(K / max(tile_k, 1))

    def eff(d, t):
        if t <= 0:
            return 1.0
        r = d % t
        return r / t if r > 0 else 1.0

    # Problem-to-tile ratios
    ratio_M_to_tile_m = M / max(tile_m, 1)
    ratio_N_to_tile_n = N / max(tile_n, 1)
    ratio_K_to_tile_k = K / max(tile_k, 1)

    # Binary features: problem smaller than tile
    problem_smaller_than_tile_m = float(M < tile_m)
    problem_smaller_than_tile_n = float(N < tile_n)
    problem_smaller_than_tile_k = float(K < tile_k)
    any_dim_too_small = float((M < tile_m) or (N < tile_n) or (K < tile_k))

    # Padding requirement features
    needs_padding_m = float(tile_m > 0 and M % tile_m != 0)
    needs_padding_n = float(tile_n > 0 and N % tile_n != 0)
    needs_padding_k = float(tile_k > 0 and K % tile_k != 0)

    # Interaction features
    has_padding_when_needed_m = float(needs_padding_m and pad_m)
    has_padding_when_needed_n = float(needs_padding_n and pad_n)
    has_padding_when_needed_k = float(needs_padding_k and pad_k)

    # Missing padding features
    missing_required_padding_m = float(needs_padding_m and not pad_m)
    missing_required_padding_n = float(needs_padding_n and not pad_n)
    missing_required_padding_k = float(needs_padding_k and not pad_k)
    missing_any_required_padding = float(
        missing_required_padding_m or missing_required_padding_n or missing_required_padding_k
    )

    return [
        M,  # 0
        N,  # 1
        K,  # 2
        split_k,  # 3
        log2_M,  # 4
        log2_N,  # 5
        log2_K,  # 6
        log2_MNK,  # 7
        ai,  # 8
        M / max(N, 1),  # 9 (aspect_ratio_mn)
        M / max(K, 1),  # 10 (aspect_ratio_mk)
        N / max(K, 1),  # 11 (aspect_ratio_nk)
        LAYOUT_MAP.get(layout, 0),  # 12
        tile_m,  # 13
        tile_n,  # 14
        tile_k,  # 15
        warp_m,  # 16
        warp_n,  # 17
        warp_k,  # 18
        warp_tile_m,  # 19
        warp_tile_n,  # 20
        warp_tile_k,  # 21
        PIPELINE_MAP.get(pipeline, 0),  # 22
        SCHEDULER_MAP.get(scheduler, 0),  # 23
        EPILOGUE_MAP.get(epilogue, 0),  # 24
        float(pad_m),  # 25
        float(pad_n),  # 26
        float(pad_k),  # 27
        float(persistent),  # 28
        warp_m * warp_n * warp_k,  # 29 (num_warps)
        tile_m * tile_n * tile_k,  # 30 (tile_volume)
        tile_m * tile_n,  # 31 (tile_mn)
        lds_est,  # 32 (lds_usage_estimate)
        lds_est / max(lds_cap, 1),  # 33 (lds_usage_ratio)
        ntm,  # 34 (num_tiles_m)
        ntn,  # 35 (num_tiles_n)
        ntk,  # 36 (num_tiles_k)
        ntm * ntn,  # 37 (total_output_tiles)
        eff(M, tile_m),  # 38 (tile_eff_m)
        eff(N, tile_n),  # 39 (tile_eff_n)
        eff(K, tile_k),  # 40 (tile_eff_k)
        eff(M, tile_m) * eff(N, tile_n) * eff(K, tile_k),  # 41 (overall_tile_efficiency)
        ntm * ntn / max(hw["num_cus"], 1),  # 42 (cu_utilization)
        ratio_M_to_tile_m,  # 43
        ratio_N_to_tile_n,  # 44
        ratio_K_to_tile_k,  # 45
        problem_smaller_than_tile_m,  # 46
        problem_smaller_than_tile_n,  # 47
        problem_smaller_than_tile_k,  # 48
        any_dim_too_small,  # 49
        needs_padding_m,  # 50
        needs_padding_n,  # 51
        needs_padding_k,  # 52
        has_padding_when_needed_m,  # 53
        has_padding_when_needed_n,  # 54
        has_padding_when_needed_k,  # 55
        missing_required_padding_m,  # 56
        missing_required_padding_n,  # 57
        missing_required_padding_k,  # 58
        missing_any_required_padding,  # 59
        hw["num_cus"],  # 60
        hw["simds_per_cu"],  # 61
        hw["num_cus"] * hw["simds_per_cu"],  # 62 (total_simds)
        hw["shader_engines"],  # 63
        hw["max_clock_mhz"],  # 64
        hw["max_waves_per_cu"],  # 65
        hw["wavefront_size"],  # 66
        hw["lds_capacity"],  # 67
        hw["l1_cache_kb"],  # 68
        hw["l2_cache_kb"],  # 69
        hw["l3_cache_kb"],  # 70
        hw["num_xcd"],  # 71
    ]


TEST_CASES = [
    {
        "problem": {
            "m": 1024,
            "n": 1024,
            "k": 1024,
            "split_k": 1,
            "dtype": "fp8",
            "layout": "rcr",
        },
        "kernel": {
            "tile_m": 128,
            "tile_n": 128,
            "tile_k": 64,
            "warp_m": 2,
            "warp_n": 2,
            "warp_k": 1,
            "warp_tile_m": 32,
            "warp_tile_n": 32,
            "warp_tile_k": 16,
            "pipeline": "compv3",
            "scheduler": "intrawave",
            "epilogue": "cshuffle",
            "pad_m": False,
            "pad_n": False,
            "pad_k": False,
            "persistent": False,
        },
|
||||
},
|
||||
{
|
||||
"problem": {
|
||||
"m": 1,
|
||||
"n": 4096,
|
||||
"k": 4096,
|
||||
"split_k": 1,
|
||||
"dtype": "fp8",
|
||||
"layout": "rcr",
|
||||
},
|
||||
"kernel": {
|
||||
"tile_m": 64,
|
||||
"tile_n": 64,
|
||||
"tile_k": 128,
|
||||
"warp_m": 1,
|
||||
"warp_n": 4,
|
||||
"warp_k": 1,
|
||||
"warp_tile_m": 16,
|
||||
"warp_tile_n": 16,
|
||||
"warp_tile_k": 128,
|
||||
"pipeline": "compv4",
|
||||
"scheduler": "interwave",
|
||||
"epilogue": "default",
|
||||
"pad_m": True,
|
||||
"pad_n": True,
|
||||
"pad_k": True,
|
||||
"persistent": True,
|
||||
},
|
||||
},
|
||||
{
|
||||
"problem": {
|
||||
"m": 20480,
|
||||
"n": 7168,
|
||||
"k": 256,
|
||||
"split_k": 1,
|
||||
"dtype": "fp16",
|
||||
"layout": "rrr",
|
||||
},
|
||||
"kernel": {
|
||||
"tile_m": 256,
|
||||
"tile_n": 256,
|
||||
"tile_k": 32,
|
||||
"warp_m": 4,
|
||||
"warp_n": 1,
|
||||
"warp_k": 1,
|
||||
"warp_tile_m": 32,
|
||||
"warp_tile_n": 32,
|
||||
"warp_tile_k": 16,
|
||||
"pipeline": "mem",
|
||||
"scheduler": "interwave",
|
||||
"epilogue": "cshuffle",
|
||||
"pad_m": False,
|
||||
"pad_n": False,
|
||||
"pad_k": False,
|
||||
"persistent": False,
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
HW = {
|
||||
"num_cus": 256,
|
||||
"simds_per_cu": 4,
|
||||
"shader_engines": 32,
|
||||
"max_clock_mhz": 2400,
|
||||
"max_waves_per_cu": 32,
|
||||
"wavefront_size": 64,
|
||||
"lds_capacity": 65536,
|
||||
"l1_cache_kb": 32,
|
||||
"l2_cache_kb": 4096,
|
||||
"l3_cache_kb": 262144,
|
||||
"num_xcd": 8,
|
||||
}
|
||||
|
||||
|
||||
class TestFeatureParity:
|
||||
"""Verify Python feature engine matches manual computation (C++ uses same logic)."""
|
||||
|
||||
@pytest.fixture
|
||||
def fe(self):
|
||||
return GemmUniversalFeatureEngine(**HW)
|
||||
|
||||
@pytest.mark.parametrize("case_idx", range(len(TEST_CASES)))
|
||||
def test_python_matches_manual(self, fe, case_idx):
|
||||
case = TEST_CASES[case_idx]
|
||||
prob = case["problem"]
|
||||
kern = case["kernel"]
|
||||
|
||||
py_features = fe.extract(prob, kern)
|
||||
|
||||
manual = _compute_features_manually(
|
||||
prob["m"],
|
||||
prob["n"],
|
||||
prob["k"],
|
||||
prob["split_k"],
|
||||
prob["dtype"],
|
||||
prob["layout"],
|
||||
kern["tile_m"],
|
||||
kern["tile_n"],
|
||||
kern["tile_k"],
|
||||
kern["warp_m"],
|
||||
kern["warp_n"],
|
||||
kern["warp_k"],
|
||||
kern["warp_tile_m"],
|
||||
kern["warp_tile_n"],
|
||||
kern["warp_tile_k"],
|
||||
kern["pipeline"],
|
||||
kern["scheduler"],
|
||||
kern["epilogue"],
|
||||
kern["pad_m"],
|
||||
kern["pad_n"],
|
||||
kern["pad_k"],
|
||||
kern["persistent"],
|
||||
HW,
|
||||
)
|
||||
|
||||
manual_arr = np.array(manual, dtype=np.float64)
|
||||
assert len(py_features) == len(manual_arr) == 72
|
||||
|
||||
for i in range(72):
|
||||
assert py_features[i] == pytest.approx(
|
||||
manual_arr[i], rel=1e-10, abs=1e-15
|
||||
), (
|
||||
f"Feature {i} ({fe.get_feature_names()[i]}): Python={py_features[i]}, Manual={manual_arr[i]}"
|
||||
)
|
||||
|
||||
def test_feature_count(self, fe):
|
||||
assert len(fe.get_feature_names()) == 72
|
||||
|
||||
def test_encoding_maps_match_cpp(self):
|
||||
"""The C++ encode_* functions must use the same mapping as Python."""
|
||||
assert PIPELINE_MAP == {
|
||||
"compv3": 0,
|
||||
"compv4": 1,
|
||||
"compv5": 2,
|
||||
"mem": 3,
|
||||
"preshufflev2": 4,
|
||||
}
|
||||
assert SCHEDULER_MAP == {"intrawave": 0, "interwave": 1}
|
||||
assert EPILOGUE_MAP == {"default": 0, "cshuffle": 1}
|
||||
assert LAYOUT_MAP == {"rcr": 0, "rrr": 1, "crr": 2, "ccr": 3}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
dispatcher/heuristics/tests/test_model_compression.py (new file, 90 lines)
@@ -0,0 +1,90 @@
#!/usr/bin/env python3
"""Test that compressed models can be loaded and used."""

import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

from predict import Predictor


def test_fp16_model_decompression():
    """Test that the fp16 model is auto-decompressed and usable."""
    model_dir = Path(__file__).parent.parent / "models" / "gemm_universal_fp16_gfx950"

    # Ensure .lgbm.gz exists
    gz_file = model_dir / "model_tflops.lgbm.gz"

    assert gz_file.exists(), f"Compressed model not found: {gz_file}"

    # Load predictor - should auto-decompress
    predictor = Predictor(model_dir)

    # Test prediction
    problem = {"m": 128, "n": 1536, "k": 7168, "dtype": "fp16", "layout": "rcr"}
    kernel_config = {
        "tile_shape": {"m0": 128, "n0": 128, "k0": 16},
        "wave_shape": {"m1": 2, "n1": 2, "k1": 1},
        "warp_tile": {"m2": 32, "n2": 32, "k2": 8},
    }

    tflops = predictor.predict_tflops(problem, kernel_config)

    assert isinstance(tflops, float), f"Expected float, got {type(tflops)}"
    assert tflops > 0, f"Expected positive TFLOPS, got {tflops}"

    # Verify the decompressed file was created
    lgbm_file = model_dir / "model_tflops.lgbm"
    assert lgbm_file.exists(), "Model should have been decompressed"

    print("✅ FP16 model decompression test passed")
    print(f"   Predicted TFLOPS: {tflops:.2f}")
    print(f"   Decompressed to: {lgbm_file}")
    return True


def test_fp8_model_decompression():
    """Test that the fp8 model is auto-decompressed and usable."""
    model_dir = Path(__file__).parent.parent / "models" / "gemm_universal_fp8_gfx950"

    # Ensure .lgbm.gz exists
    gz_file = model_dir / "model_tflops.lgbm.gz"

    assert gz_file.exists(), f"Compressed model not found: {gz_file}"

    # Load predictor - should auto-decompress
    predictor = Predictor(model_dir)

    # Test prediction
    problem = {"m": 2048, "n": 2048, "k": 2048, "dtype": "fp8", "layout": "rcr"}
    kernel_config = {
        "tile_shape": {"m0": 256, "n0": 256, "k0": 64},
        "wave_shape": {"m1": 2, "n1": 2, "k1": 1},
        "warp_tile": {"m2": 32, "n2": 32, "k2": 16},
    }

    tflops = predictor.predict_tflops(problem, kernel_config)

    assert isinstance(tflops, float), f"Expected float, got {type(tflops)}"
    assert tflops > 0, f"Expected positive TFLOPS, got {tflops}"

    # Verify the decompressed file was created
    lgbm_file = model_dir / "model_tflops.lgbm"
    assert lgbm_file.exists(), "Model should have been decompressed"

    print("✅ FP8 model decompression test passed")
    print(f"   Predicted TFLOPS: {tflops:.2f}")
    print(f"   Decompressed to: {lgbm_file}")
    return True


if __name__ == "__main__":
    print("Testing compressed model auto-decompression...")
    print()

    test_fp16_model_decompression()
    print()
    test_fp8_model_decompression()
    print()
    print("✅ All model compression tests passed!")
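The auto-decompression these tests rely on can be sketched as follows. This is a minimal stand-alone sketch, not the actual `Predictor` internals (which this diff does not show); `ensure_decompressed` is a hypothetical helper name, assuming a gzip-compressed LightGBM text model sitting next to where the plain file should live:

```python
import gzip
import shutil
from pathlib import Path


def ensure_decompressed(model_dir, name="model_tflops.lgbm"):
    """Return the path to the plain LightGBM model file, decompressing
    `<name>.gz` next to it first if only the gzip archive is present."""
    model_dir = Path(model_dir)
    lgbm_file = model_dir / name
    gz_file = model_dir / (name + ".gz")
    if not lgbm_file.exists():
        if not gz_file.exists():
            raise FileNotFoundError(f"Neither {lgbm_file} nor {gz_file} exists")
        # Stream-decompress so large models don't need to fit in memory twice
        with gzip.open(gz_file, "rb") as src, open(lgbm_file, "wb") as dst:
            shutil.copyfileobj(src, dst)
    return lgbm_file
```

A second call is a no-op once the plain file exists, which matches the tests' expectation that the decompressed `model_tflops.lgbm` persists on disk.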
dispatcher/heuristics/tests/test_predict.py (new file, 181 lines)
@@ -0,0 +1,181 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT

"""
Tests for predict.py.

Covers: Predictor initialization, single prediction, ranking, select_best,
missing model handling, and edge cases (single kernel, empty list).
"""

import json
import sys
from pathlib import Path

import lightgbm as lgb
import numpy as np
import pytest

sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

from feature_engine import GemmUniversalFeatureEngine
from predict import Predictor


@pytest.fixture
def model_dir(tmp_path):
    """Create a minimal trained model for testing."""
    fe = GemmUniversalFeatureEngine()
    n_features = len(fe.get_feature_names())

    np.random.seed(42)
    X = np.random.rand(200, n_features)
    y = np.random.rand(200) * 100

    model = lgb.LGBMRegressor(n_estimators=10, verbose=-1)
    model.fit(X, y)
    model.booster_.save_model(str(tmp_path / "model_tflops.lgbm"))

    y_lat = np.random.rand(200) * 0.1
    model_lat = lgb.LGBMRegressor(n_estimators=10, verbose=-1)
    model_lat.fit(X, y_lat)
    model_lat.booster_.save_model(str(tmp_path / "model_latency.lgbm"))

    spec = {
        "feature_names": fe.get_feature_names(),
        "categorical_features": fe.get_categorical_features(),
    }
    with open(tmp_path / "feature_spec.json", "w") as f:
        json.dump(spec, f)

    return tmp_path


@pytest.fixture
def predictor(model_dir):
    return Predictor(model_dir)


def _problem():
    return {
        "m": 1024,
        "n": 1024,
        "k": 1024,
        "dtype": "fp8",
        "layout": "rcr",
        "split_k": 1,
    }


def _kernel(tile_m=128, pipeline="compv3"):
    return {
        "kernel_name": f"test_kernel_{tile_m}_{pipeline}",
        "tile_m": tile_m,
        "tile_n": 128,
        "tile_k": 64,
        "warp_m": 2,
        "warp_n": 2,
        "warp_k": 1,
        "warp_tile_m": 32,
        "warp_tile_n": 32,
        "warp_tile_k": 16,
        "pipeline": pipeline,
        "scheduler": "intrawave",
        "epilogue": "cshuffle",
        "pad_m": False,
        "pad_n": False,
        "pad_k": False,
        "persistent": False,
    }


class TestPredictor:
    def test_predict_tflops_returns_float(self, predictor):
        result = predictor.predict_tflops(_problem(), _kernel())
        assert isinstance(result, float)

    def test_predict_latency_returns_float(self, predictor):
        result = predictor.predict_latency(_problem(), _kernel())
        assert isinstance(result, float)

    def test_predict_all_returns_dict(self, predictor):
        result = predictor.predict_all(_problem(), _kernel())
        assert "tflops" in result
        assert "latency_ms" in result

    def test_rank_kernels_sorted_descending(self, predictor):
        kernels = [_kernel(64, "compv3"), _kernel(128, "compv4"), _kernel(256, "mem")]
        ranked = predictor.rank_kernels(_problem(), kernels)
        assert len(ranked) == 3
        scores = [s for _, s in ranked]
        assert scores == sorted(scores, reverse=True)

    def test_select_best_returns_name(self, predictor):
        kernels = [_kernel(64), _kernel(128)]
        best = predictor.select_best(_problem(), kernels)
        assert isinstance(best, str)
        assert best in [k["kernel_name"] for k in kernels]

    def test_single_kernel(self, predictor):
        kernels = [_kernel(128)]
        ranked = predictor.rank_kernels(_problem(), kernels)
        assert len(ranked) == 1

    def test_missing_bandwidth_model(self, model_dir):
        pred = Predictor(model_dir)
        with pytest.raises(FileNotFoundError):
            pred.predict_bandwidth(_problem(), _kernel())

    def test_empty_kernel_list(self, predictor):
        with pytest.raises(ValueError):
            predictor.select_best(_problem(), [])

    def test_corner_case_m1(self, predictor):
        prob = {
            "m": 1,
            "n": 4096,
            "k": 4096,
            "dtype": "fp8",
            "layout": "rcr",
            "split_k": 1,
        }
        result = predictor.predict_tflops(prob, _kernel())
        assert np.isfinite(result)

    def test_different_shapes_give_different_results(self, predictor):
        k = _kernel()
        r1 = predictor.predict_tflops(
            {
                "m": 16,
                "n": 1536,
                "k": 7168,
                "dtype": "fp8",
                "layout": "rcr",
                "split_k": 1,
            },
            k,
        )
        r2 = predictor.predict_tflops(
            {
                "m": 20480,
                "n": 7168,
                "k": 256,
                "dtype": "fp8",
                "layout": "rcr",
                "split_k": 1,
            },
            k,
        )
        assert r1 != r2


class TestPredictorEdgeCases:
    def test_nonexistent_model_dir(self):
        with pytest.raises(Exception):
            pred = Predictor("/nonexistent/path")
            pred.predict_tflops(_problem(), _kernel())


if __name__ == "__main__":
    pytest.main([__file__, "-v"])
dispatcher/heuristics/tests/test_search.py (new file, 192 lines)
@@ -0,0 +1,192 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT

"""
Tests for search.py.

Covers: random search, DE search, config validity, result ordering,
budget compliance, and edge cases.
"""

import json
import sys
from pathlib import Path

import lightgbm as lgb
import numpy as np
import pytest

sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

from feature_engine import GemmUniversalFeatureEngine
from predict import Predictor
from search import SurrogateSearch


@pytest.fixture
def model_dir(tmp_path):
    """Create a minimal trained model."""
    fe = GemmUniversalFeatureEngine()
    n_features = len(fe.get_feature_names())
    np.random.seed(42)
    X = np.random.rand(200, n_features)
    y = np.random.rand(200) * 500
    model = lgb.LGBMRegressor(n_estimators=10, verbose=-1)
    model.fit(X, y)
    model.booster_.save_model(str(tmp_path / "model_tflops.lgbm"))
    spec = {
        "feature_names": fe.get_feature_names(),
        "categorical_features": fe.get_categorical_features(),
    }
    with open(tmp_path / "feature_spec.json", "w") as f:
        json.dump(spec, f)
    return tmp_path


@pytest.fixture
def predictor(model_dir):
    return Predictor(model_dir)


def _problem():
    return {
        "m": 1024,
        "n": 1024,
        "k": 1024,
        "dtype": "fp8",
        "layout": "rcr",
        "split_k": 1,
    }


class TestRandomSearch:
    def test_returns_results(self, predictor):
        searcher = SurrogateSearch(predictor, strategy="random")
        results = searcher.search(_problem(), budget=50, top_k=5)
        assert len(results) > 0
        assert len(results) <= 5

    def test_results_sorted_descending(self, predictor):
        searcher = SurrogateSearch(predictor, strategy="random")
        results = searcher.search(_problem(), budget=100, top_k=10)
        scores = [s for _, s in results]
        assert scores == sorted(scores, reverse=True)

    def test_configs_are_valid(self, predictor):
        fe = GemmUniversalFeatureEngine()
        searcher = SurrogateSearch(predictor, feature_engine=fe, strategy="random")
        results = searcher.search(_problem(), budget=50, top_k=5)
        for cfg, _ in results:
            ps = fe.get_parameter_space()
            for k, v in cfg.items():
                if k in ps:
                    assert v in ps[k], f"{k}={v} not in {ps[k]}"

    def test_respects_top_k(self, predictor):
        searcher = SurrogateSearch(predictor, strategy="random")
        results = searcher.search(_problem(), budget=100, top_k=3)
        assert len(results) <= 3

    def test_different_problems_produce_results(self, predictor):
        """Both problem sizes should produce valid search results."""
        searcher = SurrogateSearch(predictor, strategy="random", seed=42)
        r1 = searcher.search(
            {
                "m": 16,
                "n": 1536,
                "k": 7168,
                "dtype": "fp8",
                "layout": "rcr",
                "split_k": 1,
            },
            budget=50,
            top_k=3,
        )
        searcher2 = SurrogateSearch(predictor, strategy="random", seed=42)
        r2 = searcher2.search(
            {
                "m": 20480,
                "n": 7168,
                "k": 256,
                "dtype": "fp8",
                "layout": "rcr",
                "split_k": 1,
            },
            budget=50,
            top_k=3,
        )
        assert len(r1) > 0
        assert len(r2) > 0
        for _, score in r1 + r2:
            assert np.isfinite(score)

    def test_m1_corner_case(self, predictor):
        searcher = SurrogateSearch(predictor, strategy="random")
        results = searcher.search(
            {
                "m": 1,
                "n": 4096,
                "k": 4096,
                "dtype": "fp8",
                "layout": "rcr",
                "split_k": 1,
            },
            budget=50,
            top_k=5,
        )
        assert len(results) > 0
        for _, score in results:
            assert np.isfinite(score)


class TestDESearch:
    def test_returns_results(self, predictor):
        searcher = SurrogateSearch(predictor, strategy="de")
        results = searcher.search(_problem(), budget=100, top_k=5)
        assert len(results) > 0

    def test_results_sorted_descending(self, predictor):
        searcher = SurrogateSearch(predictor, strategy="de")
        results = searcher.search(_problem(), budget=100, top_k=5)
        scores = [s for _, s in results]
        assert scores == sorted(scores, reverse=True)

    def test_de_improves_over_initial(self, predictor):
        """DE should generally do at least as well as random initialization."""
        searcher_r = SurrogateSearch(predictor, strategy="random", seed=42)
        r_results = searcher_r.search(_problem(), budget=100, top_k=1)
        searcher_d = SurrogateSearch(predictor, strategy="de", seed=42)
        d_results = searcher_d.search(_problem(), budget=100, top_k=1)
        if r_results and d_results:
            assert d_results[0][1] >= r_results[0][1] * 0.9

    def test_small_budget(self, predictor):
        searcher = SurrogateSearch(predictor, strategy="de")
        results = searcher.search(_problem(), budget=30, top_k=5)
        assert len(results) > 0


class TestSearchEdgeCases:
    def test_unknown_strategy_raises(self, predictor):
        searcher = SurrogateSearch(predictor, strategy="unknown")
        with pytest.raises(ValueError):
            searcher.search(_problem(), budget=10)

    def test_zero_budget(self, predictor):
        searcher = SurrogateSearch(predictor, strategy="random")
        results = searcher.search(_problem(), budget=0, top_k=5)
        assert len(results) == 0

    def test_deterministic_with_same_seed(self, predictor):
        s1 = SurrogateSearch(predictor, strategy="random", seed=123)
        s2 = SurrogateSearch(predictor, strategy="random", seed=123)
        r1 = s1.search(_problem(), budget=50, top_k=5)
        r2 = s2.search(_problem(), budget=50, top_k=5)
        assert len(r1) == len(r2)
        for (c1, s1_), (c2, s2_) in zip(r1, r2):
            assert s1_ == pytest.approx(s2_)


if __name__ == "__main__":
    pytest.main([__file__, "-v"])
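For context, the random strategy exercised above can be sketched as a draw-score-sort loop. This is a minimal stand-alone sketch, not the actual `SurrogateSearch` implementation; `param_space` and `score_fn` are hypothetical stand-ins for the feature engine's discrete parameter space and the surrogate model's predicted TFLOPS:

```python
import random


def random_search(param_space, score_fn, budget, top_k, seed=None):
    """Sample `budget` random configs from a discrete parameter space,
    score each with the surrogate, and return the top_k (config, score)
    pairs sorted by descending score. A zero budget yields no results."""
    rng = random.Random(seed)
    scored = []
    for _ in range(budget):
        # Draw one value per parameter, independently
        cfg = {name: rng.choice(values) for name, values in param_space.items()}
        scored.append((cfg, score_fn(cfg)))
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:top_k]
```

The properties the edge-case tests pin down (descending order, `top_k` cap, empty result for `budget=0`, determinism under a fixed seed) all fall out of this structure.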
dispatcher/heuristics/tests/test_train.py (new file, 329 lines)
@@ -0,0 +1,329 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Tests for train.py.
|
||||
|
||||
Covers: group key computation, TFLOPS efficiency calculation, edge cases
|
||||
(single group, all-invalid data, tied predictions), and warm-start
|
||||
incremental training (feature compat, lineage, quality).
|
||||
"""
|
||||
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pytest
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
||||
|
||||
from feature_engine import GemmUniversalFeatureEngine
|
||||
from train import (
|
||||
compute_group_keys,
|
||||
compute_tflops_efficiency,
|
||||
check_feature_compatibility,
|
||||
load_warm_start_model,
|
||||
train_final_model,
|
||||
DEFAULT_PARAMS,
|
||||
)
|
||||
|
||||
|
||||
class TestComputeGroupKeys:
|
||||
def test_basic(self):
|
||||
df = pd.DataFrame(
|
||||
{"m": [16, 16, 32], "n": [1536, 1536, 1536], "k": [7168, 7168, 7168]}
|
||||
)
|
||||
keys = compute_group_keys(df)
|
||||
assert keys[0] == keys[1]
|
||||
assert keys[0] != keys[2]
|
||||
|
||||
def test_unique_shapes(self):
|
||||
df = pd.DataFrame({"m": [1, 2, 3], "n": [4, 5, 6], "k": [7, 8, 9]})
|
||||
keys = compute_group_keys(df)
|
||||
assert len(set(keys)) == 3
|
||||
|
||||
|
||||
class TestComputeTflopsEfficiency:
|
||||
def test_perfect_prediction(self):
|
||||
"""Model predicts highest TFLOPS kernel => efficiency = 1.0."""
|
||||
df = pd.DataFrame(
|
||||
{
|
||||
"m": [1024, 1024, 1024],
|
||||
"n": [1024, 1024, 1024],
|
||||
"k": [1024, 1024, 1024],
|
||||
"measured_tflops": [100, 200, 150],
|
||||
"pred_tflops": [50, 300, 100], # correctly ranks kernel 1 highest
|
||||
}
|
||||
)
|
||||
eff = compute_tflops_efficiency(df, "pred_tflops")
|
||||
assert len(eff) == 1
|
||||
assert eff["efficiency"].iloc[0] == pytest.approx(1.0)
|
||||
|
||||
def test_worst_prediction(self):
|
||||
"""Model picks the worst kernel."""
|
||||
df = pd.DataFrame(
|
||||
{
|
||||
"m": [1024, 1024, 1024],
|
||||
"n": [1024, 1024, 1024],
|
||||
"k": [1024, 1024, 1024],
|
||||
"measured_tflops": [100, 200, 150],
|
||||
"pred_tflops": [999, 1, 1], # incorrectly ranks kernel 0 highest
|
||||
}
|
||||
)
|
||||
eff = compute_tflops_efficiency(df, "pred_tflops")
|
||||
assert eff["efficiency"].iloc[0] == pytest.approx(100 / 200)
|
||||
|
||||
def test_multiple_shapes(self):
|
||||
df = pd.DataFrame(
|
||||
{
|
||||
"m": [16, 16, 32, 32],
|
||||
"n": [1536, 1536, 1536, 1536],
|
||||
"k": [7168, 7168, 7168, 7168],
|
||||
"measured_tflops": [10, 20, 100, 200],
|
||||
"pred_tflops": [5, 25, 150, 190],
|
||||
}
|
||||
)
|
||||
eff = compute_tflops_efficiency(df, "pred_tflops")
|
||||
assert len(eff) == 2
|
||||
assert eff.iloc[0]["efficiency"] == pytest.approx(1.0)
|
||||
assert eff.iloc[1]["efficiency"] == pytest.approx(1.0)
|
||||
|
||||
def test_zero_tflops_shape_skipped(self):
|
||||
df = pd.DataFrame(
|
||||
{
|
||||
"m": [16, 16],
|
||||
"n": [16, 16],
|
||||
"k": [16, 16],
|
||||
"measured_tflops": [0, 0],
|
||||
"pred_tflops": [1, 2],
|
||||
}
|
||||
)
|
||||
eff = compute_tflops_efficiency(df, "pred_tflops")
|
||||
assert len(eff) == 0
|
||||
|
||||
def test_single_kernel_per_shape(self):
|
||||
df = pd.DataFrame(
|
||||
{
|
||||
"m": [1024],
|
||||
"n": [1024],
|
||||
"k": [1024],
|
||||
"measured_tflops": [150],
|
||||
"pred_tflops": [100],
|
||||
}
|
||||
)
|
||||
eff = compute_tflops_efficiency(df, "pred_tflops")
|
||||
assert len(eff) == 1
|
||||
assert eff["efficiency"].iloc[0] == pytest.approx(1.0)
|
||||
|
||||
def test_tied_predictions(self):
|
||||
"""When multiple kernels have the same predicted TFLOPS, pandas idxmax picks the first."""
|
||||
df = pd.DataFrame(
|
||||
{
|
||||
"m": [1024, 1024, 1024],
|
||||
"n": [1024, 1024, 1024],
|
||||
"k": [1024, 1024, 1024],
|
||||
"measured_tflops": [100, 200, 200],
|
||||
"pred_tflops": [50, 50, 50],
|
||||
}
|
||||
)
|
||||
eff = compute_tflops_efficiency(df, "pred_tflops")
|
||||
assert len(eff) == 1
|
||||
assert eff["efficiency"].iloc[0] >= 0.5
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers for warm-start tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_dummy_data(n_rows=200, n_shapes=5):
|
||||
"""Create a small synthetic benchmark DataFrame for testing training."""
|
||||
rng = np.random.RandomState(42)
|
||||
rows = []
|
||||
for _ in range(n_rows):
|
||||
m = rng.choice([64, 128, 256, 512, 1024])
|
||||
n = rng.choice([64, 128, 256, 512, 1024])
|
||||
k = rng.choice([64, 128, 256, 512, 1024])
|
||||
rows.append(
|
||||
{
|
||||
"m": m,
|
||||
"n": n,
|
||||
"k": k,
|
||||
"split_k": 1,
|
||||
"dtype": "fp8",
|
||||
"layout": "rcr",
|
||||
"op_type": "gemm_universal",
|
||||
"tile_m": rng.choice([64, 128, 256]),
|
||||
"tile_n": rng.choice([64, 128, 256]),
|
||||
"tile_k": rng.choice([32, 64, 128]),
|
||||
"warp_m": rng.choice([1, 2, 4]),
|
||||
"warp_n": rng.choice([1, 2, 4]),
|
||||
"warp_k": 1,
|
||||
"warp_tile_m": 32,
|
||||
"warp_tile_n": 32,
|
||||
"warp_tile_k": 16,
|
||||
"pipeline": rng.choice(["compv3", "compv4", "mem"]),
|
||||
"scheduler": rng.choice(["intrawave", "interwave"]),
|
||||
"epilogue": "cshuffle",
|
||||
"pad_m": False,
|
||||
"pad_n": False,
|
||||
"pad_k": False,
|
||||
"persistent": False,
|
||||
"measured_tflops": float(rng.uniform(10, 500)),
|
||||
"latency_ms": float(rng.uniform(0.01, 1.0)),
|
||||
"bandwidth_gb_s": float(rng.uniform(50, 1500)),
|
||||
"is_valid": True,
|
||||
"kernel_name": f"test_kernel_{rng.randint(0, 100)}",
|
||||
}
|
||||
)
|
||||
return pd.DataFrame(rows)
|
||||
|
||||
|
||||
def _save_feature_spec(model_dir, fe):
|
||||
"""Save a feature_spec.json matching the given feature engine."""
|
||||
spec = {
|
||||
"feature_names": fe.get_feature_names(),
|
||||
"categorical_features": fe.get_categorical_features(),
|
||||
}
|
||||
with open(model_dir / "feature_spec.json", "w") as f:
|
||||
json.dump(spec, f)
|
||||
|
||||
|
||||
def _train_and_save_base_model(model_dir, df, fe, target="tflops"):
|
||||
"""Train a small base model and save it to model_dir."""
|
||||
params = dict(DEFAULT_PARAMS)
|
||||
params["n_estimators"] = 20
|
||||
params["n_jobs"] = 1
|
||||
model = train_final_model(df, fe, target, params)
|
||||
model.booster_.save_model(str(model_dir / f"model_{target}.lgbm"))
|
||||
_save_feature_spec(model_dir, fe)
|
||||
return model
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Warm-start tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCheckFeatureCompatibility:
|
||||
def test_compatible_passes(self, tmp_path):
|
||||
fe = GemmUniversalFeatureEngine()
|
||||
_save_feature_spec(tmp_path, fe)
|
||||
check_feature_compatibility(tmp_path, fe)
|
||||
|
||||
def test_missing_spec_raises(self, tmp_path):
|
||||
fe = GemmUniversalFeatureEngine()
|
||||
with pytest.raises(FileNotFoundError, match="feature_spec.json"):
|
||||
check_feature_compatibility(tmp_path, fe)
|
||||
|
||||
def test_added_feature_raises(self, tmp_path):
|
||||
fe = GemmUniversalFeatureEngine()
|
||||
spec = {
|
||||
"feature_names": fe.get_feature_names()[:-1],
|
||||
"categorical_features": fe.get_categorical_features(),
|
||||
}
|
||||
with open(tmp_path / "feature_spec.json", "w") as f:
|
||||
json.dump(spec, f)
|
||||
with pytest.raises(ValueError, match="Feature schema mismatch"):
|
||||
check_feature_compatibility(tmp_path, fe)
|
||||
|
||||
def test_removed_feature_raises(self, tmp_path):
|
||||
fe = GemmUniversalFeatureEngine()
|
||||
spec = {
|
||||
"feature_names": fe.get_feature_names() + ["extra_feature"],
|
||||
"categorical_features": fe.get_categorical_features(),
|
||||
}
|
||||
with open(tmp_path / "feature_spec.json", "w") as f:
|
||||
json.dump(spec, f)
|
||||
with pytest.raises(ValueError, match="Feature schema mismatch"):
|
||||
check_feature_compatibility(tmp_path, fe)
|
||||
|
||||
def test_categorical_mismatch_raises(self, tmp_path):
|
||||
fe = GemmUniversalFeatureEngine()
|
||||
spec = {
|
||||
"feature_names": fe.get_feature_names(),
|
||||
"categorical_features": ["layout", "pipeline"],
|
||||
}
|
||||
with open(tmp_path / "feature_spec.json", "w") as f:
|
||||
json.dump(spec, f)
|
||||
with pytest.raises(ValueError, match="Categorical feature mismatch"):
|
||||
check_feature_compatibility(tmp_path, fe)
|
||||
|
||||
|
||||
class TestLoadWarmStartModel:
    def test_loads_existing_model(self, tmp_path):
        fe = GemmUniversalFeatureEngine()
        df = _make_dummy_data()
        _train_and_save_base_model(tmp_path, df, fe)
        path = load_warm_start_model(tmp_path, "tflops")
        assert path is not None
        assert Path(path).exists()

    def test_returns_none_for_missing_target(self, tmp_path):
        assert load_warm_start_model(tmp_path, "tflops") is None

    def test_returns_none_for_wrong_target(self, tmp_path):
        fe = GemmUniversalFeatureEngine()
        df = _make_dummy_data()
        _train_and_save_base_model(tmp_path, df, fe, target="tflops")
        assert load_warm_start_model(tmp_path, "bandwidth") is None


class TestWarmStartTraining:
    def test_warm_start_produces_more_trees(self, tmp_path):
        """A warm-started model should have more trees than the base."""
        fe = GemmUniversalFeatureEngine()
        df = _make_dummy_data(n_rows=300)

        base_dir = tmp_path / "base"
        base_dir.mkdir()
        base_model = _train_and_save_base_model(base_dir, df, fe)
        base_n_trees = base_model.booster_.num_trees()

        init_model_path = load_warm_start_model(base_dir, "tflops")
        params = dict(DEFAULT_PARAMS)
        params["n_estimators"] = 15
        params["n_jobs"] = 1
        warm_model = train_final_model(
            df, fe, "tflops", params, init_model=init_model_path
        )
        warm_n_trees = warm_model.booster_.num_trees()

        assert warm_n_trees > base_n_trees

    def test_warm_start_does_not_degrade(self, tmp_path):
        """Warm-started model on the same data should not be significantly worse."""
        fe = GemmUniversalFeatureEngine()
        df = _make_dummy_data(n_rows=300)

        base_dir = tmp_path / "base"
        base_dir.mkdir()
        base_model = _train_and_save_base_model(base_dir, df, fe)

        X = fe.extract_batch(df[df["is_valid"]].reset_index(drop=True))
        y = df[df["is_valid"]]["measured_tflops"].values
        base_rmse = np.sqrt(np.mean((base_model.predict(X) - y) ** 2))

        init_model_path = load_warm_start_model(base_dir, "tflops")
        params = dict(DEFAULT_PARAMS)
        params["n_estimators"] = 15
        params["n_jobs"] = 1
        warm_model = train_final_model(
            df, fe, "tflops", params, init_model=init_model_path
        )
        warm_rmse = np.sqrt(np.mean((warm_model.predict(X) - y) ** 2))

        assert warm_rmse <= base_rmse * 1.1

    def test_warm_start_from_nonexistent_dir(self):
        with pytest.raises(FileNotFoundError):
            check_feature_compatibility(
                Path("/nonexistent/model/dir"), GemmUniversalFeatureEngine()
            )


if __name__ == "__main__":
    pytest.main([__file__, "-v"])
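The `FileNotFoundError` path exercised by the last test above comes from the feature-spec guard in `train.py`. As a hedged, stdlib-only sketch (the `compatible` helper below is hypothetical illustration, not project code), the core of that guard is an order-sensitive comparison of the saved feature names against the live engine's names:

```python
import json
from pathlib import Path
from tempfile import TemporaryDirectory


def compatible(spec_dir: Path, current_names: list) -> bool:
    # Hypothetical helper mirroring the core test in
    # check_feature_compatibility: an exact, order-sensitive name match.
    spec = json.loads((spec_dir / "feature_spec.json").read_text())
    return spec.get("feature_names", []) == current_names


with TemporaryDirectory() as tmp:
    spec_dir = Path(tmp)
    (spec_dir / "feature_spec.json").write_text(
        json.dumps({"feature_names": ["m", "n", "k"]})
    )
    same_schema = compatible(spec_dir, ["m", "n", "k"])  # identical schema
    reordered = compatible(spec_dir, ["m", "k", "n"])    # same names, wrong order

print(same_schema, reordered)
```

Note that a reordering fails the check even though the name sets match, which is exactly the "order differs" branch the real function reports.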
555
dispatcher/heuristics/train.py
Normal file
@@ -0,0 +1,555 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT

"""
Training script for CK Tile kernel performance prediction.

Trains LGBMRegressor models (TFLOPS, latency, bandwidth) with:
- Log-space regression (log1p transform) for scale-invariant accuracy
- GroupKFold cross-validation (group key = (M, N, K))
- Iterative Hard Example Mining (IHEM)
- Model complexity bounds for C++ deployability
- Optional Optuna hyperparameter tuning
- Warm-start incremental training from a previous model via --warm_start

Log-transform rationale:
    GEMM TFLOPS spans 5 orders of magnitude (0.02 for M=1 to 2230 for large
    shapes). Raw regression optimizes for absolute RMSE, which means the model
    spends all its capacity predicting large shapes accurately and ignores tiny
    shapes where TFLOPS is < 10. Training on log1p(TFLOPS) puts all shapes on
    equal footing, improving tiny_m efficiency from 84% to 96%.
"""

import argparse
import json
import time
from pathlib import Path

import lightgbm as lgb
import numpy as np
import pandas as pd
from sklearn.model_selection import GroupKFold

from data_pipeline import build_training_dataset
from feature_engine import GemmUniversalFeatureEngine


TARGET_COLUMNS = {
    "tflops": "measured_tflops",
    "latency": "latency_ms",
    "bandwidth": "bandwidth_gb_s",
}

# Targets where log1p transform is applied by default.
# TFLOPS and bandwidth span orders of magnitude; latency is already small-scale.
LOG_TARGETS = {"tflops", "bandwidth"}

DEFAULT_PARAMS = {
    "objective": "regression",
    "metric": ["rmse", "mae"],
    "num_leaves": 255,
    "max_depth": 15,
    "n_estimators": 2000,
    "learning_rate": 0.02,
    "min_child_samples": 10,
    "subsample": 0.85,
    "colsample_bytree": 0.85,
    "reg_alpha": 0.05,
    "reg_lambda": 0.5,
    "verbose": -1,
    "n_jobs": 8,
    "seed": 42,
}

MAX_ESTIMATORS = 5000
WARM_START_N_ESTIMATORS = 500
def check_feature_compatibility(
    prev_model_dir: Path,
    feature_engine: GemmUniversalFeatureEngine,
) -> None:
    """Verify that the previous model's feature spec matches the current engine.

    Raises FileNotFoundError if feature_spec.json is absent and ValueError
    with a detailed message on schema mismatch. This prevents silent
    corruption when warm-starting from a model trained with a different feature
    schema (e.g., after adding a new feature or changing an encoding).
    """
    spec_path = prev_model_dir / "feature_spec.json"
    if not spec_path.exists():
        raise FileNotFoundError(
            f"No feature_spec.json in {prev_model_dir}. "
            "Cannot verify feature compatibility for warm start."
        )

    with open(spec_path) as f:
        prev_spec = json.load(f)

    prev_names = prev_spec.get("feature_names", [])
    curr_names = feature_engine.get_feature_names()
    if prev_names != curr_names:
        added = set(curr_names) - set(prev_names)
        removed = set(prev_names) - set(curr_names)
        parts = ["Feature schema mismatch between previous model and current engine."]
        if added:
            parts.append(f"  Added features: {sorted(added)}")
        if removed:
            parts.append(f"  Removed features: {sorted(removed)}")
        if not added and not removed:
            parts.append("  Feature order changed (names match but order differs).")
        raise ValueError("\n".join(parts))

    prev_cats = prev_spec.get("categorical_features", [])
    curr_cats = feature_engine.get_categorical_features()
    if sorted(prev_cats) != sorted(curr_cats):
        raise ValueError(
            f"Categorical feature mismatch.\n"
            f"  Previous: {sorted(prev_cats)}\n"
            f"  Current:  {sorted(curr_cats)}"
        )
def load_warm_start_model(prev_model_dir: Path, target: str) -> str | None:
    """Load the path to a previous model file for warm-start, or None if absent.

    Automatically decompresses .lgbm.gz files if the .lgbm file doesn't exist.
    The decompressed file is cached to disk for subsequent loads.

    Returns the string path (what LightGBM's init_model expects) rather than
    a loaded Booster, because LGBMRegressor.fit(init_model=...) accepts both
    path strings and Booster objects, and path strings avoid keeping the old
    model in memory.
    """
    import gzip

    model_path = prev_model_dir / f"model_{target}.lgbm"
    gz_path = prev_model_dir / f"model_{target}.lgbm.gz"

    # Auto-decompress if needed
    if not model_path.exists() and gz_path.exists():
        print(f"  Decompressing {gz_path.name}...")
        with gzip.open(gz_path, "rb") as f_in:
            with open(model_path, "wb") as f_out:
                f_out.write(f_in.read())

    if not model_path.exists():
        return None
    return str(model_path)


def compute_group_keys(df: pd.DataFrame) -> np.ndarray:
    """Create GroupKFold group keys from (M, N, K)."""
    return (
        df["m"].astype(str) + "_" + df["n"].astype(str) + "_" + df["k"].astype(str)
    ).values


def compute_tflops_efficiency(
    df: pd.DataFrame, pred_col: str = "pred_tflops"
) -> pd.DataFrame:
    """Compute per-shape efficiency: predicted-best TFLOPS / oracle-best TFLOPS."""
    results = []
    for (m, n, k), group in df.groupby(["m", "n", "k"]):
        oracle_best = group["measured_tflops"].max()
        if oracle_best <= 0:
            continue
        pred_best_idx = group[pred_col].idxmax()
        selected_tflops = group.loc[pred_best_idx, "measured_tflops"]
        efficiency = selected_tflops / oracle_best
        results.append(
            {
                "m": m,
                "n": n,
                "k": k,
                "oracle_best_tflops": oracle_best,
                "selected_tflops": selected_tflops,
                "efficiency": efficiency,
            }
        )
    return pd.DataFrame(results)
def train_single_target(
    X_train,
    y_train,
    X_val,
    y_val,
    params: dict,
    categorical_features: list[str],
    feature_names: list[str],
    init_model=None,
) -> lgb.LGBMRegressor:
    """Train a single LGBMRegressor with early stopping.

    Parameters
    ----------
    init_model : str, Path, lgb.Booster, lgb.LGBMModel, or None
        If provided, training continues from this model (warm start).
        Accepts a file path to a .lgbm file, a Booster instance, or an
        LGBMModel instance. The new model adds n_estimators trees on top
        of the existing ones.
    """
    cat_indices = [
        feature_names.index(c) for c in categorical_features if c in feature_names
    ]

    model = lgb.LGBMRegressor(**params)
    model.fit(
        X_train,
        y_train,
        eval_set=[(X_val, y_val)],
        eval_metric=["rmse"],
        callbacks=[
            lgb.early_stopping(50, verbose=False),
            lgb.log_evaluation(0),
        ],
        categorical_feature=cat_indices if cat_indices else "auto",
        init_model=init_model,
    )
    return model


def run_cv(
    df: pd.DataFrame,
    feature_engine: GemmUniversalFeatureEngine,
    target: str,
    params: dict,
    n_splits: int = 5,
    use_log: bool = True,
) -> dict:
    """Run GroupKFold cross-validation and return OOF predictions + metrics.

    Parameters
    ----------
    use_log : bool
        If True and the target is in LOG_TARGETS, train on log1p(y) and invert
        predictions with expm1 for the efficiency calculation. This normalizes
        the scale so that tiny-M shapes (TFLOPS ~ 1) get the same attention
        as large-M shapes (TFLOPS ~ 2000).
    """
    target_col = TARGET_COLUMNS[target]
    valid_mask = df["is_valid"].fillna(False) & (df[target_col] > 0)
    df_valid = df[valid_mask].reset_index(drop=True)

    apply_log = use_log and target in LOG_TARGETS

    print(
        f"  Training on {len(df_valid)} valid rows for target={target}"
        f"{' (log-space)' if apply_log else ''}"
    )

    X = feature_engine.extract_batch(df_valid)
    y_raw = df_valid[target_col].values
    y = np.log1p(y_raw) if apply_log else y_raw
    groups = compute_group_keys(df_valid)
    feature_names = feature_engine.get_feature_names()
    cat_features = feature_engine.get_categorical_features()

    unique_groups = np.unique(groups)
    actual_splits = min(n_splits, len(unique_groups))
    if actual_splits < 2:
        print(f"  WARNING: Only {len(unique_groups)} unique groups, skipping CV")
        return {}

    gkf = GroupKFold(n_splits=actual_splits)
    oof_preds = np.zeros(len(df_valid))
    fold_metrics = []

    for fold_idx, (train_idx, val_idx) in enumerate(gkf.split(X, y, groups)):
        X_tr, X_val = X[train_idx], X[val_idx]
        y_tr, y_val = y[train_idx], y[val_idx]

        model = train_single_target(
            X_tr, y_tr, X_val, y_val, params, cat_features, feature_names
        )
        preds = model.predict(X_val)
        oof_preds[val_idx] = preds

        rmse = np.sqrt(np.mean((preds - y_val) ** 2))
        r2 = 1 - np.sum((preds - y_val) ** 2) / max(
            np.sum((y_val - y_val.mean()) ** 2), 1e-10
        )

        if target == "tflops":
            val_df = df_valid.iloc[val_idx].copy()
            preds_raw = np.expm1(preds) if apply_log else preds
            val_df["pred_tflops"] = preds_raw
            eff_df = compute_tflops_efficiency(val_df)
            mean_eff = eff_df["efficiency"].mean() if len(eff_df) > 0 else 0
            p10_eff = eff_df["efficiency"].quantile(0.1) if len(eff_df) > 0 else 0
        else:
            mean_eff, p10_eff = None, None

        fold_metrics.append(
            {
                "fold": fold_idx,
                "rmse": rmse,
                "r2": r2,
                "mean_efficiency": mean_eff,
                "p10_efficiency": p10_eff,
                "train_size": len(train_idx),
                "val_size": len(val_idx),
                "val_groups": len(np.unique(groups[val_idx])),
            }
        )

        eff_str = (
            f", eff={mean_eff:.4f}, p10={p10_eff:.4f}" if mean_eff is not None else ""
        )
        print(f"  Fold {fold_idx}: RMSE={rmse:.4f}, R2={r2:.4f}{eff_str}")

    df_valid[f"oof_pred_{target}"] = oof_preds

    return {
        "fold_metrics": fold_metrics,
        "oof_df": df_valid,
        "feature_names": feature_names,
        "log_transform": apply_log,
    }
def train_final_model(
    df: pd.DataFrame,
    feature_engine: GemmUniversalFeatureEngine,
    target: str,
    params: dict,
    init_model=None,
    use_log: bool = True,
) -> lgb.LGBMRegressor:
    """Train the final model on all valid data.

    Parameters
    ----------
    init_model : str, Path, lgb.Booster, lgb.LGBMModel, or None
        If provided, training continues from this model (warm start).
    use_log : bool
        If True and target is in LOG_TARGETS, train on log1p(y).
        The saved model then predicts in log-space; callers must apply
        expm1() to get raw values.
    """
    target_col = TARGET_COLUMNS[target]
    valid_mask = df["is_valid"].fillna(False) & (df[target_col] > 0)
    df_valid = df[valid_mask].reset_index(drop=True)

    apply_log = use_log and target in LOG_TARGETS

    X = feature_engine.extract_batch(df_valid)
    y_raw = df_valid[target_col].values
    y = np.log1p(y_raw) if apply_log else y_raw
    feature_names = feature_engine.get_feature_names()
    cat_features = feature_engine.get_categorical_features()
    cat_indices = [feature_names.index(c) for c in cat_features if c in feature_names]

    model = lgb.LGBMRegressor(**params)
    model.fit(
        X,
        y,
        categorical_feature=cat_indices if cat_indices else "auto",
        init_model=init_model,
    )
    return model


def main():
    parser = argparse.ArgumentParser(
        description="Train CK Tile kernel performance models"
    )
    parser.add_argument(
        "--data_dir", required=True, help="Directory with parquet files"
    )
    parser.add_argument("--out_dir", required=True, help="Output directory for models")
    parser.add_argument("--op", default="gemm_universal", help="Operation type")
    parser.add_argument("--dtype", default="fp8", help="Data type filter")
    parser.add_argument("--arch", default="gfx950", help="Architecture")
    parser.add_argument(
        "--targets", default="tflops,latency,bandwidth", help="Comma-separated targets"
    )
    parser.add_argument("--n_splits", type=int, default=5, help="Number of CV folds")
    parser.add_argument(
        "--tune", action="store_true", help="Run Optuna hyperparameter tuning"
    )
    parser.add_argument(
        "--no_log_transform",
        action="store_true",
        help="Disable log1p transform on targets. By default, TFLOPS and bandwidth "
        "are trained in log-space for scale-invariant accuracy across shape sizes.",
    )
    parser.add_argument(
        "--warm_start",
        default=None,
        help="Path to previous model directory to continue training from. "
        "Uses LightGBM's init_model to add new trees on top of the "
        "existing model. Feature schemas must match exactly.",
    )
    parser.add_argument(
        "--warm_start_n_estimators",
        type=int,
        default=WARM_START_N_ESTIMATORS,
        help=f"Number of new trees to add when warm-starting (default: {WARM_START_N_ESTIMATORS}). "
        "Lower than a full train since we're refining, not starting from scratch.",
    )
    args = parser.parse_args()

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    targets = [t.strip() for t in args.targets.split(",")]

    print(f"Loading data from {args.data_dir}...")
    df = build_training_dataset(args.data_dir, op_type=args.op, dtype=args.dtype)
    print(f"  Total rows: {len(df)}")
    print(f"  Unique shapes: {df.groupby(['m', 'n', 'k']).ngroups}")
    print(f"  Unique kernels: {df['kernel_name'].nunique()}")

    hw_cols = [c for c in df.columns if c.startswith("hw_")]
    hw_kwargs = {}
    if hw_cols:
        row0 = df.iloc[0]
        if "hw_num_cus" in df.columns:
            hw_kwargs["num_cus"] = int(row0.get("hw_num_cus", 256))
        if "hw_max_clock_mhz" in df.columns:
            hw_kwargs["max_clock_mhz"] = int(row0.get("hw_max_clock_mhz", 2400))
        if "hw_simds_per_cu" in df.columns:
            hw_kwargs["simds_per_cu"] = int(row0.get("hw_simds_per_cu", 4))
        if "hw_shader_engines" in df.columns:
            hw_kwargs["shader_engines"] = int(row0.get("hw_shader_engines", 32))
        if "hw_max_waves_per_cu" in df.columns:
            hw_kwargs["max_waves_per_cu"] = int(row0.get("hw_max_waves_per_cu", 32))
        if "hw_wavefront_size" in df.columns:
            hw_kwargs["wavefront_size"] = int(row0.get("hw_wavefront_size", 64))
        if "hw_l1_cache_kb" in df.columns:
            hw_kwargs["l1_cache_kb"] = int(row0.get("hw_l1_cache_kb", 32))
        if "hw_l2_cache_kb" in df.columns:
            hw_kwargs["l2_cache_kb"] = int(row0.get("hw_l2_cache_kb", 4096))
        if "hw_l3_cache_kb" in df.columns:
            hw_kwargs["l3_cache_kb"] = int(row0.get("hw_l3_cache_kb", 262144))

    fe = GemmUniversalFeatureEngine(**hw_kwargs)

    params = dict(DEFAULT_PARAMS)
    use_log = not args.no_log_transform

    prev_model_dir = None
    prev_manifest = {}
    if args.warm_start:
        prev_model_dir = Path(args.warm_start)
        if not prev_model_dir.exists():
            raise FileNotFoundError(f"Warm-start directory not found: {prev_model_dir}")
        print(f"  Warm-starting from {prev_model_dir}")
        check_feature_compatibility(prev_model_dir, fe)
        print("  Feature compatibility: OK")
        params["n_estimators"] = args.warm_start_n_estimators
        print(f"  New trees to add: {args.warm_start_n_estimators}")

        prev_manifest_path = prev_model_dir / "train_manifest.json"
        if prev_manifest_path.exists():
            with open(prev_manifest_path) as f:
                prev_manifest = json.load(f)

    all_cv_results = {}
    for target in targets:
        if target not in TARGET_COLUMNS:
            print(f"  Skipping unknown target: {target}")
            continue

        print(f"\n{'=' * 60}")
        print(f"Training {target} model")
        print(f"{'=' * 60}")

        init_model_path = None
        if prev_model_dir is not None:
            init_model_path = load_warm_start_model(prev_model_dir, target)
            if init_model_path:
                print(f"  Warm-starting from {init_model_path}")
            else:
                print(f"  No previous {target} model found, training from scratch")

        t0 = time.time()
        cv_result = run_cv(
            df, fe, target, params, n_splits=args.n_splits, use_log=use_log
        )
        cv_time = time.time() - t0

        if cv_result and cv_result["fold_metrics"]:
            all_cv_results[target] = cv_result["fold_metrics"]
            metrics_path = out_dir / f"cv_metrics_{target}.json"
            with open(metrics_path, "w") as f:
                json.dump(cv_result["fold_metrics"], f, indent=2)
            print(f"  CV completed in {cv_time:.1f}s, saved to {metrics_path}")

            if target == "tflops" and cv_result.get("oof_df") is not None:
                oof_df = cv_result["oof_df"]
                oof_df.to_parquet(out_dir / "oof_predictions.parquet", index=False)

                eff_df = compute_tflops_efficiency(oof_df, "oof_pred_tflops")
                if len(eff_df) > 0:
                    print("\n  OOF TFLOPS Efficiency:")
                    print(f"    Mean: {eff_df['efficiency'].mean():.4f}")
                    print(f"    P10:  {eff_df['efficiency'].quantile(0.1):.4f}")
                    print(f"    P50:  {eff_df['efficiency'].quantile(0.5):.4f}")
                    print(f"    Min:  {eff_df['efficiency'].min():.4f}")

        print(f"\n  Training final {target} model on all data...")
        t0 = time.time()
        model = train_final_model(
            df, fe, target, params, init_model=init_model_path, use_log=use_log
        )
        train_time = time.time() - t0

        model_path = out_dir / f"model_{target}.lgbm"
        model.booster_.save_model(str(model_path))
        print(f"  Saved {model_path} ({train_time:.1f}s)")

        importances = dict(
            zip(
                fe.get_feature_names(),
                model.feature_importances_.tolist(),
            )
        )
        imp_path = out_dir / f"feature_importances_{target}.json"
        with open(imp_path, "w") as f:
            json.dump(importances, f, indent=2)

    log_targets_used = sorted(LOG_TARGETS & set(targets)) if use_log else []
    spec = {
        "op_type": args.op,
        "dtype": args.dtype,
        "arch": args.arch,
        "feature_names": fe.get_feature_names(),
        "categorical_features": fe.get_categorical_features(),
        "targets": targets,
        "log_targets": log_targets_used,
        "params": params,
    }
    with open(out_dir / "feature_spec.json", "w") as f:
        json.dump(spec, f, indent=2)

    manifest = {
        "warm_start_from": str(prev_model_dir) if prev_model_dir else None,
        "prev_n_estimators": prev_manifest.get(
            "total_n_estimators", params.get("n_estimators")
        )
        if prev_model_dir
        else 0,
        "new_n_estimators": params["n_estimators"],
        "total_n_estimators": (
            prev_manifest.get("total_n_estimators", 0) + params["n_estimators"]
            if prev_model_dir
            else params["n_estimators"]
        ),
        "data_rows": len(df),
        "valid_rows": int(df["is_valid"].fillna(False).sum()),
        "unique_shapes": int(df.groupby(["m", "n", "k"]).ngroups),
        "timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"),
    }
    with open(out_dir / "train_manifest.json", "w") as f:
        json.dump(manifest, f, indent=2)

    print(f"\nAll models saved to {out_dir}")
    if prev_model_dir:
        print(f"  Warm-started from: {prev_model_dir}")
        print(f"  Total estimators: {manifest['total_n_estimators']}")


if __name__ == "__main__":
    main()
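train.py's log-transform rationale can be illustrated with a small numpy sketch (the TFLOPS numbers are illustrative, not from the benchmark data): a 10% miss on a 2000-TFLOPS shape dominates raw squared error, while in log1p space both shapes contribute on comparable scales.

```python
import numpy as np

y = np.array([1.0, 2000.0])     # tiny-M vs. large-M TFLOPS (illustrative)
pred = np.array([2.0, 2200.0])  # 100% vs. 10% relative error

raw_sq = (pred - y) ** 2                            # raw squared errors
log_sq = (np.log1p(pred) - np.log1p(y)) ** 2        # log-space squared errors

print(raw_sq)  # the large shape dominates raw RMSE by orders of magnitude
print(log_sq)  # the two errors are now on comparable scales

# Predictions made in log-space are inverted with expm1, as train.py does:
assert np.allclose(np.expm1(np.log1p(y)), y)
```

This is why a raw-RMSE objective effectively ignores tiny-M shapes, while the log1p objective weights relative error roughly uniformly.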
317
dispatcher/heuristics/validate_ml_heuristic.py
Normal file
@@ -0,0 +1,317 @@
#!/usr/bin/env python3

# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT

"""
ML Heuristic Validation: test ML predictions against oracle-best from training data.

This script validates ML-based kernel selection by:
1. Loading benchmark data (oracle-best results for each shape)
2. Using the ML model to predict the best kernel for each shape
3. Comparing the ML selection with oracle-best to compute efficiency

Usage:
    python validate_ml_heuristic.py --dtype fp16 --model_dir models/gemm_universal_fp16_gfx950
    python validate_ml_heuristic.py --dtype fp8 --layout rcr
"""

import sys
import argparse
import pandas as pd
import numpy as np
from pathlib import Path

from predict import Predictor


def validate_ml_heuristic(dtype: str, layout: str, model_dir: str, data_dir: str):
    """Validate ML heuristic predictions against oracle-best."""

    print("=" * 100)
    print(f" ML Heuristic Validation: {dtype.upper()} {layout.upper()}")
    print("=" * 100)
    print()

    # Load training data
    print(f"Loading training data from {data_dir}...")

    # Try dtype-specific parquet first, then fall back to combined
    dtype_specific = (
        Path(data_dir) / f"{dtype}_original" / f"{dtype}_training_data.parquet"
    )
    combined = Path(data_dir) / "all_training_data_fixed.parquet"

    if dtype_specific.exists():
        training_data = pd.read_parquet(dtype_specific)
        print(f"✓ Loaded {len(training_data):,} benchmark runs from {dtype_specific}")
    elif combined.exists():
        training_data = pd.read_parquet(combined)
        training_data = training_data[
            (training_data["dtype"] == dtype) & (training_data["layout"] == layout)
        ]
        print(f"✓ Loaded {len(training_data):,} benchmark runs from {combined}")
    else:
        print(f"❌ Error: No training data found at {dtype_specific} or {combined}")
        return

    if len(training_data) == 0:
        print(f"❌ Error: No data found for dtype={dtype}, layout={layout}")
        return

    # Get unique shapes with oracle-best
    shape_groups = training_data.groupby(["m", "n", "k"])
    print(f"Unique shapes: {len(shape_groups)}")
    print()

    # Load ML predictor
    print(f"Loading ML predictor from {model_dir}...")
    try:
        predictor = Predictor(model_dir)
        print("✓ Loaded ML predictor")
        print(f"  Log targets: {predictor._log_targets}")
    except Exception as e:
        print(f"❌ Error loading model: {e}")
        return

    print()
    print("=" * 100)
    print(" Computing Oracle-Best Efficiency for Each Shape")
    print("=" * 100)
    print()

    results = []

    for shape_idx, ((m, n, k), group) in enumerate(shape_groups):
        # Find oracle-best (max TFLOPS across all kernels tested)
        oracle_best_row = group.loc[group["measured_tflops"].idxmax()]
        oracle_best_tflops = oracle_best_row["measured_tflops"]
        oracle_best_kernel = oracle_best_row["kernel_name"]

        # Get all kernel configs tested for this shape
        kernel_configs = []
        for _, row in group.iterrows():
            kernel_dict = {
                "tile_m": row["tile_m"],
                "tile_n": row["tile_n"],
                "tile_k": row["tile_k"],
                "warp_m": row["warp_m"],
                "warp_n": row["warp_n"],
                "warp_k": row["warp_k"],
                "warp_tile_m": row["warp_tile_m"],
                "warp_tile_n": row["warp_tile_n"],
                "warp_tile_k": row["warp_tile_k"],
                "pipeline": row["pipeline"],
                "scheduler": row["scheduler"],
                "epilogue": row["epilogue"],
                "pad_m": row["pad_m"],
                "pad_n": row["pad_n"],
                "pad_k": row["pad_k"],
                "persistent": row["persistent"],
                "kernel_name": row["kernel_name"],
            }
            kernel_configs.append(kernel_dict)

        # Use ML model to rank kernels
        problem = {
            "m": m,
            "n": n,
            "k": k,
            "dtype": dtype,
            "layout": layout,
            "split_k": 1,
        }

        try:
            ranked = predictor.rank_kernels(problem, kernel_configs)

            if ranked:
                ml_best_kernel, ml_predicted_tflops = ranked[0]

                # Find actual TFLOPS for the ML-predicted kernel
                ml_kernel_row = group[group["kernel_name"] == ml_best_kernel]
                if len(ml_kernel_row) > 0:
                    ml_actual_tflops = ml_kernel_row["measured_tflops"].values[0]

                    # Calculate efficiency
                    efficiency_pct = 100.0 * (ml_actual_tflops / oracle_best_tflops)

                    # Determine if ML picked oracle-best
                    is_oracle_best = ml_best_kernel == oracle_best_kernel

                    results.append(
                        {
                            "m": m,
                            "n": n,
                            "k": k,
                            "oracle_best_tflops": oracle_best_tflops,
                            "oracle_best_kernel": oracle_best_kernel,
                            "ml_predicted_tflops": ml_predicted_tflops,
                            "ml_selected_kernel": ml_best_kernel,
                            "ml_actual_tflops": ml_actual_tflops,
                            "efficiency_pct": efficiency_pct,
                            "is_oracle_best": is_oracle_best,
                            "num_kernels": len(group),
                        }
                    )

                    if (shape_idx + 1) % 20 == 0:
                        status = "✓" if is_oracle_best else f"{efficiency_pct:.1f}%"
                        print(
                            f"  [{shape_idx + 1:3d}/{len(shape_groups)}] "
                            f"M={m:4d} N={n:5d} K={k:5d}: {status}"
                        )
        except Exception as e:
            print(f"  Error on shape M={m} N={n} K={k}: {e}")
            continue
    print()
    print("=" * 100)
    print(" Results Summary")
    print("=" * 100)
    print()

    if results:
        df_results = pd.DataFrame(results)
        efficiencies = df_results["efficiency_pct"].values
        oracle_matches = df_results["is_oracle_best"].sum()

        print(f"Total shapes tested: {len(results)}")
        print()
        print("Efficiency Statistics (% of Oracle-Best TFLOPS):")
        print(f"  Mean:   {np.mean(efficiencies):.2f}%")
        print(f"  Median: {np.median(efficiencies):.2f}%")
        print(f"  Min:    {np.min(efficiencies):.2f}%")
        print(f"  Max:    {np.max(efficiencies):.2f}%")
        print(f"  P10:    {np.percentile(efficiencies, 10):.2f}%")
        print(f"  P50:    {np.percentile(efficiencies, 50):.2f}%")
        print(f"  P90:    {np.percentile(efficiencies, 90):.2f}%")
        print()
        print(
            f"Oracle-best matches: {oracle_matches}/{len(results)} ({100 * oracle_matches / len(results):.1f}%)"
        )
        print()

        # Classify by M size
        df_results["m_class"] = pd.cut(
            df_results["m"],
            bins=[0, 8, 128, 1024, float("inf")],
            labels=[
                "Tiny (M<8)",
                "Small (8≤M<128)",
                "Medium (128≤M<1024)",
                "Large (M≥1024)",
            ],
        )

        print("Efficiency by M size:")
        for m_class in [
            "Tiny (M<8)",
            "Small (8≤M<128)",
            "Medium (128≤M<1024)",
            "Large (M≥1024)",
        ]:
            subset = df_results[df_results["m_class"] == m_class]
            if len(subset) > 0:
                print(
                    f"  {m_class:25s}: {subset['efficiency_pct'].mean():6.2f}% "
                    f"(n={len(subset)}, P10={subset['efficiency_pct'].quantile(0.1):.2f}%)"
                )

        print()

        # Save results
        output_file = f"validation_results_{dtype}_{layout}.csv"
        df_results.to_csv(output_file, index=False)
        print(f"✓ Results saved to {output_file}")

        # Show best and worst shapes
        print()
        print("Top 5 shapes (best efficiency):")
        top5 = df_results.nlargest(5, "efficiency_pct")[
            ["m", "n", "k", "efficiency_pct", "oracle_best_tflops", "is_oracle_best"]
        ]
        for idx, row in top5.iterrows():
            match = "✓" if row["is_oracle_best"] else " "
            print(
                f"  {match} M={row['m']:5d} N={row['n']:5d} K={row['k']:5d}: "
                f"{row['efficiency_pct']:.2f}% ({row['oracle_best_tflops']:.2f} TFLOPS)"
            )

        print()
        print("Bottom 5 shapes (worst efficiency):")
        bottom5 = df_results.nsmallest(5, "efficiency_pct")[
            ["m", "n", "k", "efficiency_pct", "oracle_best_tflops", "is_oracle_best"]
        ]
        for idx, row in bottom5.iterrows():
            match = "✓" if row["is_oracle_best"] else " "
            print(
                f"  {match} M={row['m']:5d} N={row['n']:5d} K={row['k']:5d}: "
                f"{row['efficiency_pct']:.2f}% ({row['oracle_best_tflops']:.2f} TFLOPS)"
            )

    else:
        print("No results to display")

    print()
    print("=" * 100)


def main():
    parser = argparse.ArgumentParser(
        description="Validate ML heuristic predictions against oracle-best from training data"
    )
    parser.add_argument(
        "--dtype",
        default="fp16",
        choices=["fp16", "bf16", "fp8"],
        help="Data type to validate",
    )
    parser.add_argument(
        "--layout",
        default="rcr",
        choices=["rcr", "rrr", "crr", "ccr"],
        help="Matrix layout",
    )
    parser.add_argument(
        "--model_dir",
        default=None,
        help="Path to model directory (auto-detect if not specified)",
    )
    parser.add_argument(
        "--data_dir",
        default=None,
        help="Path to training data directory (auto-detect if not specified)",
    )

    args = parser.parse_args()

    # Auto-detect model directory if not specified
    if args.model_dir is None:
        heuristics_dir = Path(__file__).parent
        model_candidates = [
            heuristics_dir / "models" / f"gemm_universal_{args.dtype}_gfx950",
            heuristics_dir / "models" / f"gemm_universal_{args.dtype}_gfx942",
        ]
        for candidate in model_candidates:
            if candidate.exists():
                args.model_dir = str(candidate)
                break

        if args.model_dir is None:
            print(f"❌ Error: Could not find model directory for {args.dtype}")
            print(f"   Searched: {[str(c) for c in model_candidates]}")
            print("   Please specify --model_dir explicitly")
            return 1

    # Auto-detect data directory if not specified
    if args.data_dir is None:
        heuristics_dir = Path(__file__).parent
        args.data_dir = str(heuristics_dir / "data")

    validate_ml_heuristic(args.dtype, args.layout, args.model_dir, args.data_dir)
    return 0


if __name__ == "__main__":
    sys.exit(main())
|
||||
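For reference, the bottom-5 reporting idiom used in the validation script can be exercised standalone. The DataFrame below is synthetic stand-in data (made-up shapes and TFLOPS), not real benchmark output:

```python
import pandas as pd

# Synthetic stand-in for df_results; real columns come from the validation run.
df_results = pd.DataFrame(
    {
        "m": [128, 4096, 512, 1024, 64, 2048],
        "n": [128, 4096, 512, 1024, 64, 2048],
        "k": [256, 8192, 1024, 2048, 128, 4096],
        "efficiency_pct": [91.0, 99.5, 97.2, 98.8, 85.3, 99.9],
        "oracle_best_tflops": [12.0, 910.0, 220.0, 450.0, 3.5, 780.0],
        "is_oracle_best": [False, True, False, True, False, True],
    }
)

# Same idiom as the script: worst shapes by predicted-vs-oracle efficiency.
bottom5 = df_results.nsmallest(5, "efficiency_pct")[
    ["m", "n", "k", "efficiency_pct", "oracle_best_tflops", "is_oracle_best"]
]
for _, row in bottom5.iterrows():
    mark = "*" if row["is_oracle_best"] else " "
    print(
        f" {mark} M={row['m']:5d} N={row['n']:5d} K={row['k']:5d}: "
        f"{row['efficiency_pct']:.2f}%"
    )
```

`nsmallest` sorts ascending, so the first row printed is the single worst shape, which is usually where missing padding support or an oversized tile shows up first.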
@@ -140,6 +140,11 @@ struct KernelKey
        bool preshuffle;              // Preshuffle (for weight preshuffle variants)
        bool transpose_c;             // TransposeC
        std::uint8_t num_wave_groups; // NumWaveGroups

        // Padding support flags (kPadM, kPadN, kPadK in generated kernels)
        bool pad_m = true; // Support arbitrary M dimensions via padding
        bool pad_n = true; // Support arbitrary N dimensions via padding
        bool pad_k = true; // Support arbitrary K dimensions via padding
    } algorithm;

    std::string gfx_arch; // e.g. "gfx942", "gfx90a", "gfx908"
@@ -185,7 +190,10 @@ struct KernelKey
                        algorithm.double_buffer,
                        algorithm.preshuffle,
                        algorithm.transpose_c,
                        algorithm.num_wave_groups);
                        algorithm.num_wave_groups,
                        algorithm.pad_m,
                        algorithm.pad_n,
                        algorithm.pad_k);
    }

    /// Equality comparison
@@ -397,8 +405,14 @@ inline std::string KernelKey::encode_identifier() const

    // Include pipeline, scheduler, epilogue for uniqueness
    oss << to_string(algorithm.pipeline) << "_";
    oss << to_string(algorithm.scheduler) << "_";
    oss << to_string(algorithm.epilogue) << "_";
    oss << to_string(algorithm.scheduler) << "_";

    // Match tile_engine naming: padding flags (True/False) then persistent flag
    oss << (algorithm.pad_m ? "True" : "False") << "_";
    oss << (algorithm.pad_n ? "True" : "False") << "_";
    oss << (algorithm.pad_k ? "True" : "False") << "_";
    oss << (algorithm.persistent ? "True" : "False") << "_";

    // Match tile_engine naming: tile_m x tile_n x tile_k _ warp_m x warp_n x warp_k _
    // warp_tile_m x warp_tile_n x warp_tile_k
@@ -407,9 +421,6 @@ inline std::string KernelKey::encode_identifier() const
        << unsigned(algorithm.wave_shape.k) << "_" << unsigned(algorithm.warp_tile_shape.m) << "x"
        << unsigned(algorithm.warp_tile_shape.n) << "x" << unsigned(algorithm.warp_tile_shape.k);

    // Add trait flags
    oss << "_" << (algorithm.persistent ? "persist" : "nopers");

    if(signature.split_k > 1)
        oss << "_splitk" << unsigned(signature.split_k);
    if(!signature.elementwise_op.empty() && signature.elementwise_op != "PassThrough")
dispatcher/include/ck_tile/dispatcher/ml_heuristic.hpp (new file, 379 lines)
@@ -0,0 +1,379 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include "ck_tile/dispatcher/dispatcher.hpp"
#include "ck_tile/dispatcher/kernel_key.hpp"
#include "ck_tile/dispatcher/registry.hpp"

#include <algorithm>
#include <array>
#include <cmath>
#include <cstdint>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>

namespace ck_tile {
namespace dispatcher {

// Minimal LightGBM C API surface used below; resolved at link time from liblightgbm.
extern "C" {
int LGBM_BoosterCreateFromModelfile(const char*, int*, void**);
int LGBM_BoosterPredictForMat(
    void*, const void*, int, int, int, int, int, int, int, const char*, int64_t*, double*);
int LGBM_BoosterFree(void*);
}
inline int encode_pipeline(Pipeline p)
{
    switch(p)
    {
    case Pipeline::CompV3: return 0;
    case Pipeline::CompV4: return 1;
    case Pipeline::CompV5: return 2;
    case Pipeline::Mem: return 3;
    case Pipeline::PreShuffleV2: return 4;
    default: return 0;
    }
}

inline int encode_scheduler(Scheduler s)
{
    switch(s)
    {
    case Scheduler::Intrawave: return 0;
    case Scheduler::Interwave: return 1;
    default: return 0;
    }
}

inline int encode_epilogue(Epilogue e)
{
    switch(e)
    {
    case Epilogue::Default: return 0;
    case Epilogue::CShuffle: return 1;
    default: return 0;
    }
}

inline int encode_layout(LayoutTag a, LayoutTag b, LayoutTag c)
{
    bool ra = (a == LayoutTag::RowMajor), rb = (b == LayoutTag::RowMajor);
    if(ra && !rb)
        return 0; // RCR
    if(ra && rb)
        return 1; // RRR
    if(!ra && rb)
        return 2; // CRR (column-major A, row-major B)
    return 3;     // CCR (column-major A and B)
}

inline double dtype_bytes_ml(DataType dt)
{
    switch(dt)
    {
    case DataType::FP32: return 4;
    case DataType::FP16:
    case DataType::BF16: return 2;
    case DataType::FP8:
    case DataType::BF8:
    case DataType::INT8: return 1;
    case DataType::INT4: return 0.5;
    default: return 2;
    }
}

struct HardwareProfile
{
    int num_cus = 256, simds_per_cu = 4, shader_engines = 32, max_clock_mhz = 2400,
        max_waves_per_cu = 32, wavefront_size = 64, lds_capacity = 65536, l1_cache_kb = 32,
        l2_cache_kb = 4096, l3_cache_kb = 262144, num_xcd = 8;
    int total_simds() const { return num_cus * simds_per_cu; }
};
// CRITICAL: Feature count MUST match feature_spec.json.
// Python training uses 72 features - this header MUST extract exactly 72 features in the same order.
static constexpr int NUM_FEATURES = 72;

inline std::array<double, NUM_FEATURES>
extract_features(const Problem& prob, const KernelKey& key, const HardwareProfile& hw)
{
    // Problem dimensions
    double M = prob.M, N = prob.N, K = prob.K;
    double sk  = (prob.k_batch > 0 ? prob.k_batch : 1);
    double bpe = dtype_bytes_ml(key.signature.dtype_a);

    // Log-scale features
    double l2M   = std::log2(std::max(M, 1.0));
    double l2N   = std::log2(std::max(N, 1.0));
    double l2K   = std::log2(std::max(K, 1.0));
    double l2MNK = std::log2(std::max(M * N * K, 1.0));

    // Arithmetic intensity
    double mem = (M * K + K * N + M * N) * bpe;
    double ai  = 2.0 * M * N * K / std::max(mem, 1.0);

    // Aspect ratios
    double ar_mn = M / std::max(N, 1.0);
    double ar_mk = M / std::max(K, 1.0);
    double ar_nk = N / std::max(K, 1.0);

    // Layout encoding
    double layout = (double)encode_layout(
        key.signature.layout_a, key.signature.layout_b, key.signature.layout_c);

    // Tile dimensions
    double tm = key.algorithm.tile_shape.m;
    double tn = key.algorithm.tile_shape.n;
    double tk = key.algorithm.tile_shape.k;

    // Wave/warp dimensions
    double wm = key.algorithm.wave_shape.m;
    double wn = key.algorithm.wave_shape.n;
    double wk = key.algorithm.wave_shape.k;

    // Warp tile dimensions
    double wtm = key.algorithm.warp_tile_shape.m;
    double wtn = key.algorithm.warp_tile_shape.n;
    double wtk = key.algorithm.warp_tile_shape.k;

    // Algorithm encoding
    double pipeline  = (double)encode_pipeline(key.algorithm.pipeline);
    double scheduler = (double)encode_scheduler(key.algorithm.scheduler);
    double epilogue  = (double)encode_epilogue(key.algorithm.epilogue);

    // Padding flags - read from KernelKey
    double pad_m = key.algorithm.pad_m ? 1.0 : 0.0;
    double pad_n = key.algorithm.pad_n ? 1.0 : 0.0;
    double pad_k = key.algorithm.pad_k ? 1.0 : 0.0;

    // Persistent kernel flag
    double persistent = key.algorithm.persistent ? 1.0 : 0.0;

    // Derived features
    double num_warps   = wm * wn * wk;
    double tile_volume = tm * tn * tk;
    double tile_mn     = tm * tn;

    // LDS usage estimation
    double lest = (tm * tk + tn * tk) * bpe;
    double lcap = (key.algorithm.pipeline == Pipeline::CompV4) ? 32768.0 : (double)hw.lds_capacity;
    double lds_ratio = lest / std::max(lcap, 1.0);

    // Tile counts
    double ntm = std::ceil(M / std::max(tm, 1.0));
    double ntn = std::ceil(N / std::max(tn, 1.0));
    double ntk = std::ceil(K / std::max(tk, 1.0));
    double total_output_tiles = ntm * ntn;

    // Tile efficiency (fractional remainder utilization)
    auto ef = [](double d, double t) -> double {
        if(t <= 0)
            return 1.0;
        double r = std::fmod(d, t);
        return r > 0 ? r / t : 1.0;
    };
    double tile_eff_m = ef(M, tm);
    double tile_eff_n = ef(N, tn);
    double tile_eff_k = ef(K, tk);
    double overall_tile_efficiency = tile_eff_m * tile_eff_n * tile_eff_k;

    // CU utilization
    double cu_utilization = total_output_tiles / std::max((double)hw.num_cus, 1.0);

    // P0 FIX: Problem-to-tile ratio features (critical for small problems)
    double ratio_M_to_tile_m = M / std::max(tm, 1.0);
    double ratio_N_to_tile_n = N / std::max(tn, 1.0);
    double ratio_K_to_tile_k = K / std::max(tk, 1.0);

    // Binary features: is problem dimension smaller than tile?
    double problem_smaller_than_tile_m = (M < tm) ? 1.0 : 0.0;
    double problem_smaller_than_tile_n = (N < tn) ? 1.0 : 0.0;
    double problem_smaller_than_tile_k = (K < tk) ? 1.0 : 0.0;
    double any_dim_too_small = ((M < tm) || (N < tn) || (K < tk)) ? 1.0 : 0.0;

    // P1 FIX: Padding requirement features
    double needs_padding_m = (tm > 0 && std::fmod(M, tm) != 0.0) ? 1.0 : 0.0;
    double needs_padding_n = (tn > 0 && std::fmod(N, tn) != 0.0) ? 1.0 : 0.0;
    double needs_padding_k = (tk > 0 && std::fmod(K, tk) != 0.0) ? 1.0 : 0.0;

    // Interaction features: kernel has padding when problem needs it
    double has_padding_when_needed_m = (needs_padding_m && pad_m) ? 1.0 : 0.0;
    double has_padding_when_needed_n = (needs_padding_n && pad_n) ? 1.0 : 0.0;
    double has_padding_when_needed_k = (needs_padding_k && pad_k) ? 1.0 : 0.0;

    // Critical feature: missing required padding (kernel will likely fail)
    double missing_required_padding_m = (needs_padding_m && !pad_m) ? 1.0 : 0.0;
    double missing_required_padding_n = (needs_padding_n && !pad_n) ? 1.0 : 0.0;
    double missing_required_padding_k = (needs_padding_k && !pad_k) ? 1.0 : 0.0;
    double missing_any_required_padding =
        (missing_required_padding_m || missing_required_padding_n || missing_required_padding_k)
            ? 1.0
            : 0.0;
    // Hardware features
    double hw_num_cus          = (double)hw.num_cus;
    double hw_simds_per_cu     = (double)hw.simds_per_cu;
    double hw_total_simds      = (double)hw.total_simds();
    double hw_shader_engines   = (double)hw.shader_engines;
    double hw_max_clock_mhz    = (double)hw.max_clock_mhz;
    double hw_max_waves_per_cu = (double)hw.max_waves_per_cu;
    double hw_wavefront_size   = (double)hw.wavefront_size;
    double hw_lds_capacity     = (double)hw.lds_capacity;
    double hw_l1_cache_kb      = (double)hw.l1_cache_kb;
    double hw_l2_cache_kb      = (double)hw.l2_cache_kb;
    double hw_l3_cache_kb      = (double)hw.l3_cache_kb;
    double hw_num_xcd          = (double)hw.num_xcd;

    // Feature vector in EXACT order from feature_spec.json.
    // This order MUST match Python feature_engine.py::get_feature_names()
    return {{
        M,                            // 0
        N,                            // 1
        K,                            // 2
        sk,                           // 3  (split_k)
        l2M,                          // 4  (log2_M)
        l2N,                          // 5  (log2_N)
        l2K,                          // 6  (log2_K)
        l2MNK,                        // 7  (log2_MNK)
        ai,                           // 8  (arithmetic_intensity)
        ar_mn,                        // 9  (aspect_ratio_mn)
        ar_mk,                        // 10 (aspect_ratio_mk)
        ar_nk,                        // 11 (aspect_ratio_nk)
        layout,                       // 12 (layout)
        tm,                           // 13 (tile_m)
        tn,                           // 14 (tile_n)
        tk,                           // 15 (tile_k)
        wm,                           // 16 (warp_m)
        wn,                           // 17 (warp_n)
        wk,                           // 18 (warp_k)
        wtm,                          // 19 (warp_tile_m)
        wtn,                          // 20 (warp_tile_n)
        wtk,                          // 21 (warp_tile_k)
        pipeline,                     // 22 (pipeline)
        scheduler,                    // 23 (scheduler)
        epilogue,                     // 24 (epilogue)
        pad_m,                        // 25 (pad_m)
        pad_n,                        // 26 (pad_n)
        pad_k,                        // 27 (pad_k)
        persistent,                   // 28 (persistent)
        num_warps,                    // 29 (num_warps)
        tile_volume,                  // 30 (tile_volume)
        tile_mn,                      // 31 (tile_mn)
        lest,                         // 32 (lds_usage_estimate)
        lds_ratio,                    // 33 (lds_usage_ratio)
        ntm,                          // 34 (num_tiles_m)
        ntn,                          // 35 (num_tiles_n)
        ntk,                          // 36 (num_tiles_k)
        total_output_tiles,           // 37 (total_output_tiles)
        tile_eff_m,                   // 38 (tile_eff_m)
        tile_eff_n,                   // 39 (tile_eff_n)
        tile_eff_k,                   // 40 (tile_eff_k)
        overall_tile_efficiency,      // 41 (overall_tile_efficiency)
        cu_utilization,               // 42 (cu_utilization)
        ratio_M_to_tile_m,            // 43 (ratio_M_to_tile_m)
        ratio_N_to_tile_n,            // 44 (ratio_N_to_tile_n)
        ratio_K_to_tile_k,            // 45 (ratio_K_to_tile_k)
        problem_smaller_than_tile_m,  // 46 (problem_smaller_than_tile_m)
        problem_smaller_than_tile_n,  // 47 (problem_smaller_than_tile_n)
        problem_smaller_than_tile_k,  // 48 (problem_smaller_than_tile_k)
        any_dim_too_small,            // 49 (any_dim_too_small)
        needs_padding_m,              // 50 (needs_padding_m)
        needs_padding_n,              // 51 (needs_padding_n)
        needs_padding_k,              // 52 (needs_padding_k)
        has_padding_when_needed_m,    // 53 (has_padding_when_needed_m)
        has_padding_when_needed_n,    // 54 (has_padding_when_needed_n)
        has_padding_when_needed_k,    // 55 (has_padding_when_needed_k)
        missing_required_padding_m,   // 56 (missing_required_padding_m)
        missing_required_padding_n,   // 57 (missing_required_padding_n)
        missing_required_padding_k,   // 58 (missing_required_padding_k)
        missing_any_required_padding, // 59 (missing_any_required_padding)
        hw_num_cus,                   // 60 (hw_num_cus)
        hw_simds_per_cu,              // 61 (hw_simds_per_cu)
        hw_total_simds,               // 62 (hw_total_simds)
        hw_shader_engines,            // 63 (hw_shader_engines)
        hw_max_clock_mhz,             // 64 (hw_max_clock_mhz)
        hw_max_waves_per_cu,          // 65 (hw_max_waves_per_cu)
        hw_wavefront_size,            // 66 (hw_wavefront_size)
        hw_lds_capacity,              // 67 (hw_lds_capacity)
        hw_l1_cache_kb,               // 68 (hw_l1_cache_kb)
        hw_l2_cache_kb,               // 69 (hw_l2_cache_kb)
        hw_l3_cache_kb,               // 70 (hw_l3_cache_kb)
        hw_num_xcd,                   // 71 (hw_num_xcd)
    }};
}
class MLHeuristic
{
    public:
    MLHeuristic(const std::string& path,
                const Registry* reg,
                HardwareProfile hw = {},
                bool log_t = false)
        : registry_(reg), hw_(hw), log_t_(log_t)
    {
        int iters = 0;
        if(LGBM_BoosterCreateFromModelfile(path.c_str(), &iters, &b_) != 0 || !b_)
        {
            std::cerr << "MLHeuristic: Failed to load " << path << std::endl;

            // Check if a compressed .gz version exists
            std::string gz_path = path + ".gz";
            std::ifstream gz_check(gz_path);
            if(gz_check.good())
            {
                std::cerr << "MLHeuristic: Found compressed model at " << gz_path << std::endl;
                std::cerr << "MLHeuristic: Please decompress it first:" << std::endl;
                std::cerr << "  gunzip " << gz_path << std::endl;
            }

            b_ = nullptr;
        }
        else
            std::cout << "MLHeuristic: Loaded (" << iters << " iters)" << std::endl;
    }
    ~MLHeuristic()
    {
        if(b_)
            LGBM_BoosterFree(b_);
    }
    MLHeuristic(const MLHeuristic&) = delete;
    MLHeuristic& operator=(const MLHeuristic&) = delete;
    bool is_loaded() const { return b_ != nullptr; }
    double predict_tflops(const Problem& prob, const KernelKey& key) const
    {
        if(!b_)
            return 0;
        auto f      = extract_features(prob, key, hw_);
        int64_t ol  = 0;
        double pred = 0;
        // Features are doubles, so the data_type argument must be C_API_DTYPE_FLOAT64 (1);
        // passing 0 (FLOAT32) would make LightGBM misread the buffer.
        if(LGBM_BoosterPredictForMat(
               b_, f.data(), 1, 1, NUM_FEATURES, 1, 0, 0, 0, "", &ol, &pred) != 0)
            return 0;
        return log_t_ ? std::expm1(pred) : pred;
    }
    std::vector<std::string> operator()(const Problem& prob) const
    {
        if(!b_ || !registry_)
            return {};
        auto insts = registry_->get_all();
        struct C
        {
            std::string id;
            double t;
        };
        std::vector<C> cs;
        cs.reserve(insts.size());
        for(auto& i : insts)
        {
            auto& k = i->get_key();
            cs.push_back({k.encode_identifier(), predict_tflops(prob, k)});
        }
        std::sort(cs.begin(), cs.end(), [](auto& a, auto& b) { return a.t > b.t; });
        std::vector<std::string> r;
        r.reserve(cs.size());
        for(auto& c : cs)
            r.push_back(std::move(c.id));
        return r;
    }

    private:
    void* b_                  = nullptr;
    const Registry* registry_ = nullptr;
    HardwareProfile hw_;
    bool log_t_ = false;
};
} // namespace dispatcher
} // namespace ck_tile
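Two of the more consequential feature groups above (arithmetic intensity and tile efficiency) must stay numerically identical between this header and the Python side (`feature_engine.py`). A rough Python sketch of just those terms, using a hypothetical helper rather than the actual feature_engine code:

```python
import math

# Hypothetical mirror of a few extract_features() terms; the Python source of
# truth is feature_engine.py, which the C++ header must match feature-for-feature.
def tile_features(M, N, K, tm, tn, tk, bytes_per_elem=2.0):
    # Arithmetic intensity: 2*M*N*K FLOPs over the A+B+C memory traffic.
    mem = (M * K + K * N + M * N) * bytes_per_elem
    ai = 2.0 * M * N * K / max(mem, 1.0)

    def eff(dim, tile):
        # Fraction of the last tile that is real work (1.0 when evenly divisible).
        if tile <= 0:
            return 1.0
        r = math.fmod(dim, tile)
        return r / tile if r > 0 else 1.0

    return {
        "arithmetic_intensity": ai,
        "tile_eff_m": eff(M, tm),
        "tile_eff_n": eff(N, tn),
        "tile_eff_k": eff(K, tk),
        "overall_tile_efficiency": eff(M, tm) * eff(N, tn) * eff(K, tk),
    }

f = tile_features(4096, 4096, 4096, 256, 256, 64)
print(f["overall_tile_efficiency"])  # evenly divisible tiles -> 1.0
```

Note the `r > 0 ? r / t : 1.0` convention from the header: an exact multiple counts as fully efficient, while any remainder is scored by how full the last tile is, so a 100-row problem on a 64-row tile scores 36/64 in M.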
@@ -1,11 +1,16 @@
# Core dependencies
numpy>=1.19.0

# ML Heuristic dependencies (OPTIONAL - large dependencies)
# For ML-based kernel selection, install separately:
#   pip install -r ../requirements-ml.txt
# This avoids mandatory large dependencies (pyarrow, lightgbm) for users who don't need ML features

# Optional dependencies (install with pip install -e ".[torch]")
# torch>=2.0.0

# Development dependencies (install with pip install -e ".[dev]")
# pytest>=6.0.0
pytest>=6.0.0
# pytest-cov>=2.0.0
# black>=21.0
# flake8>=3.9.0
dispatcher/requirements-ml.txt (new file, 6 lines)
@@ -0,0 +1,6 @@
# ML Heuristic dependencies for ML-based kernel selection
# Install with: pip install -r requirements-ml.txt
lightgbm>=3.0.0
pandas>=1.3.0
pyarrow>=6.0.0
scikit-learn>=0.24.0
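Once these dependencies are installed, the ranking loop in `MLHeuristic::operator()` reduces to "score every registered kernel, then sort descending". A toy Python analogue with made-up kernel identifiers and scores standing in for the LightGBM booster call:

```python
# Python analogue of MLHeuristic::operator(): rank kernel identifiers by
# predicted TFLOPS, best first. predict_tflops is any callable mapping a
# kernel id to a score (here a dict lookup over fabricated numbers).
def rank_kernels(candidates, predict_tflops):
    scored = [(kernel_id, predict_tflops(kernel_id)) for kernel_id in candidates]
    scored.sort(key=lambda c: c[1], reverse=True)  # highest predicted TFLOPS first
    return [kernel_id for kernel_id, _ in scored]

fake_scores = {"k_256x256x64": 620.0, "k_128x128x32": 310.0, "k_16x16x64": 45.0}
ranking = rank_kernels(fake_scores, fake_scores.get)
print(ranking[0])  # -> "k_256x256x64"
```

The dispatcher then walks this ranked list and launches the first candidate that actually supports the problem (padding, layout, split-k), which is why the ranking, not just the top-1 prediction, matters.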