Files
composable_kernel/dispatcher/examples/gemm/cpp/04_heuristics.cpp
Vidyasagar Ananthan 9e049a32a1 Adding dispatcher architecture (#3300)
* WIP POC of dispatcher

* Dispatcher python workflow setup.

* Dispatcher cleanup and updates.

Further dispatcher cleanup and updates.

Build fixes

Improvements and python to CK example

Improvements to readme

* Fixes to python paths

* Cleaning up code

* Improving dispatcher support for different arch

Fixing typos

* Fix formatting errors

* Cleaning up examples

* Improving codegeneration

* Improving and fixing C++ examples

* Adding conv functionality (fwd,bwd,bwdw) and examples.

* Fixes based on feedback.

* Further fixes based on feedback.

* Adding stress test for autogeneration and autocorrection, and fixing preshuffle bug.

* Another round of improvements based on feedback.

* Trimming out unnecessary code.

* Fixing the multi-D implementation.

* Using gpu verification for gemms and fixing convolutions tflops calculation.

* Fix counter usage issue and arch filtering per ops.

* Adding changelog and other fixes.

* Improve examples and resolve critical bugs.

* Reduce build time for python examples.

* Fixing minor bug.

* Fix compilation error.

* Improve installation instructions for dispatcher.

* Add Docker-based installation instructions for dispatcher.

* Fixing arch-based filtering to match tile engine.

* Remove dead code and fix arch filtering.

* Minor bugfix.

* Updates after rebase.

* Trimming code.

* Fix copyright headers.

* Consolidate examples, cut down code.

* Minor fixes.

* Improving python examples.

* Update readmes.

* Remove conv functionality.

* Cleanup following conv removal.
2026-01-22 09:34:33 -08:00

169 lines
5.6 KiB
C++

// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
/**
* Example 04: Custom Heuristics
*
* Demonstrates custom kernel selection heuristics for different workloads.
*
* Build: cd dispatcher/build && cmake .. && make gemm_04_heuristics
*/
#include <hip/hip_runtime.h>
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <iomanip>
#include <iostream>
#include <string>
#include <tuple>
#include <vector>
#include "ck_tile/dispatcher.hpp"
#include "ck_tile/dispatcher/kernel_decl.hpp"
#include "ck_tile/dispatcher/example_args.hpp"
using namespace ck_tile::dispatcher;
using namespace ck_tile::dispatcher::utils;
using Signature = decl::Signature;
using Algorithm = decl::Algorithm;
// =============================================================================
// KERNEL SET: Multiple tile sizes for heuristic-based selection
// =============================================================================
// Declares the kernel set named `heuristics_kernels` with two entries.
// Both entries use dtype "fp16", layout "rcr" and the arch string "gfx942",
// and share the same wave/warp/pipeline/scheduler/epilogue settings; they
// differ only in the tile() arguments (64,64,32 vs 128,128,64).
// NOTE(review): "rcr" presumably encodes the A/B/C matrix layouts
// (row/column/row) and tile() presumably takes (M, N, K) extents - confirm
// against the dispatcher's kernel_decl documentation.
DECL_KERNEL_SET(heuristics_kernels,
// Small tile - low latency
.add(Signature().dtype("fp16").layout("rcr"),
Algorithm()
.tile(64, 64, 32)
.wave(2, 2, 1)
.warp(32, 32, 16)
.pipeline("compv3")
.scheduler("intrawave")
.epilogue("cshuffle"),
"gfx942")
// Medium tile - balanced
// Same configuration as the entry above except for the larger tile().
.add(Signature().dtype("fp16").layout("rcr"),
Algorithm()
.tile(128, 128, 64)
.wave(2, 2, 1)
.warp(32, 32, 16)
.pipeline("compv3")
.scheduler("intrawave")
.epilogue("cshuffle"),
"gfx942"));
// =============================================================================
// Custom Heuristic
// =============================================================================
/**
 * Size-based kernel ranking used as the dispatcher heuristic in this example.
 *
 * Orders two kernel names by the number of output elements (M * N): when the
 * product is below the threshold the "gemm_64x64" name is listed first,
 * otherwise "gemm_128x128" is listed first. Both names are always returned;
 * only their order changes.
 *
 * @param problem GEMM problem description; only its M and N members are read.
 * @return Kernel names ordered from most to least preferred.
 *
 * NOTE(review): the returned names are assumed to match the identifiers under
 * which the generated kernels are registered - confirm against the registry.
 */
std::vector<std::string> size_based_heuristic(const Problem& problem)
{
    // Output sizes with fewer elements than this are treated as "small".
    constexpr int64_t small_problem_elements = 100000;

    // Widen each operand before multiplying so the product is evaluated in
    // 64 bits even if Problem stores M and N in a narrower integer type.
    const int64_t total_elements =
        static_cast<int64_t>(problem.M) * static_cast<int64_t>(problem.N);

    if(total_elements < small_problem_elements)
    {
        return {"gemm_64x64", "gemm_128x128"};
    }
    return {"gemm_128x128", "gemm_64x64"};
}
// =============================================================================
// MAIN
// =============================================================================
int main(int argc, char* argv[])
{
ExampleArgs args("Example 04: Custom Heuristics",
"Demonstrates custom kernel selection heuristics");
args.add_option("--arch", "gfx942", "GPU architecture");
if(!args.parse(argc, argv))
return 0;
print_header("Example 04: Custom Heuristics");
std::string gfx_arch = args.get("--arch", "gfx942");
// =========================================================================
// Setup Registry and Dispatcher
// =========================================================================
Registry registry;
REGISTER_GENERATED_KERNELS(registry, gfx_arch);
Dispatcher dispatcher(&registry);
dispatcher.set_strategy(Dispatcher::SelectionStrategy::Heuristic);
dispatcher.set_heuristic(size_based_heuristic);
std::cout << "\nSetup:\n";
std::cout << " Registry: " << registry.size() << " kernel(s)\n";
std::cout << " Strategy: Heuristic (size-based)\n";
// =========================================================================
// Test Different Problem Sizes
// =========================================================================
std::cout << "\nTesting heuristic selection:\n";
print_separator();
using DataType = ck_tile::fp16_t;
std::vector<std::tuple<int, int, int>> sizes = {
{128, 128, 64},
{512, 512, 256},
{2048, 2048, 1024},
};
bool all_passed = true;
for(const auto& [M, N, K] : sizes)
{
Problem problem(M, N, K);
auto selected = dispatcher.select_kernel(problem);
std::cout << "Problem " << M << "x" << N << "x" << K << ":\n";
if(selected)
{
std::cout << " Selected: " << selected->get_name() << "\n";
}
GpuBuffer<DataType> a_dev(M * K);
GpuBuffer<DataType> b_dev(K * N);
GpuBuffer<DataType> c_dev(M * N);
std::vector<DataType> a_host(M * K, DataType(1.0f));
std::vector<DataType> b_host(K * N, DataType(1.0f));
a_dev.copy_from_host(a_host.data());
b_dev.copy_from_host(b_host.data());
c_dev.zero();
float time_ms = dispatcher.run(a_dev.get(), b_dev.get(), c_dev.get(), problem, nullptr);
double tflops = calculate_tflops(M, N, K, time_ms);
std::cout << " Time: " << std::fixed << std::setprecision(4) << time_ms << " ms\n";
std::cout << " TFLOPS: " << std::setprecision(2) << tflops << "\n";
// Verify
std::vector<DataType> c_host(M * N);
c_dev.copy_to_host(c_host.data());
float expected = static_cast<float>(K);
int errors = 0;
for(int i = 0; i < M * N; ++i)
{
float actual = static_cast<float>(c_host[i]);
if(std::abs(actual - expected) > 0.01f * expected + 1.0f)
++errors;
}
bool pass = (errors == 0);
std::cout << " Verify: " << (pass ? "PASS" : "FAIL") << "\n";
if(!pass)
all_passed = false;
print_separator();
}
std::cout << "Overall: " << (all_passed ? "ALL PASSED" : "SOME FAILED") << "\n";
return all_passed ? 0 : 1;
}