Files
composable_kernel/dispatcher/src/dispatcher.cpp
Vidyasagar Ananthan 8763bbf6cf Adding dispatcher architecture (#3300)
* WIP POC of dispatcher

* Dispatcher python workflow setup.

* Dispatcher cleanup and updates.

Further dispatcher cleanup and updates.

Build fixes

Improvements and python to CK example

Improvements to readme

* Fixes to python paths

* Cleaning up code

* Improving dispatcher support for different arch

Fixing typos

* Fix formatting errors

* Cleaning up examples

* Improving code generation

* Improving and fixing C++ examples

* Adding conv functionality (fwd,bwd,bwdw) and examples.

* Fixes based on feedback.

* Further fixes based on feedback.

* Adding stress test for autogeneration and autocorrection, and fixing preshuffle bug.

* Another round of improvements based on feedback.

* Trimming out unnecessary code.

* Fixing the multi-D implementation.

* Using gpu verification for gemms and fixing convolutions tflops calculation.

* Fix counter usage issue and arch filtering per ops.

* Adding changelog and other fixes.

* Improve examples and resolve critical bugs.

* Reduce build time for python examples.

* Fixing minor bug.

* Fix compilation error.

* Improve installation instructions for dispatcher.

* Add Docker-based installation instructions for dispatcher.

* Fixing arch-based filtering to match tile engine.

* Remove dead code and fix arch filtering.

* Minor bugfix.

* Updates after rebase.

* Trimming code.

* Fix copyright headers.

* Consolidate examples, cut down code.

* Minor fixes.

* Improving python examples.

* Update readmes.

* Remove conv functionality.

* Cleanup following conv removal.

[ROCm/composable_kernel commit: 9e049a32a1]
2026-01-22 09:34:33 -08:00

153 lines
4.1 KiB
C++

// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck_tile/dispatcher/dispatcher.hpp"
#include <stdexcept>
#include <sstream>
#include <iostream>
#include <utility>
namespace ck_tile {
namespace dispatcher {
// Binds the dispatcher to a kernel registry. Passing a null `registry` makes
// the dispatcher use the shared Registry::instance() instead. No heuristic is
// installed at construction, so selection starts out in first-fit mode.
Dispatcher::Dispatcher(Registry* registry)
    : registry_(registry == nullptr ? &Registry::instance() : registry),
      heuristic_(nullptr),
      strategy_(SelectionStrategy::FirstFit)
{
}
// Installs the ranking heuristic used for kernel selection.
//
// If the supplied heuristic is non-empty, the selection strategy is switched
// to SelectionStrategy::Heuristic. An empty heuristic is stored but leaves the
// current strategy untouched; select_heuristic() falls back to first-fit
// when no heuristic is installed.
//
// @param heuristic  Callable mapping a Problem to an ordered list of kernel
//                   identifiers (consumed; ownership moves into the dispatcher).
void Dispatcher::set_heuristic(HeuristicFunction heuristic)
{
    // `heuristic` is a by-value sink parameter. Move it into place rather than
    // copying it a second time. If HeuristicFunction is a std::function, a copy
    // may heap-allocate the captured state.
    heuristic_ = std::move(heuristic);
    if(heuristic_)
    {
        strategy_ = SelectionStrategy::Heuristic;
    }
}
// Overrides the strategy that select_kernel() uses to choose a kernel.
void Dispatcher::set_strategy(SelectionStrategy strategy)
{
    strategy_ = strategy;
}
// Chooses a kernel for `problem` according to the active selection strategy.
//
// Returns nullptr in any of these cases:
// - the problem is invalid;
// - the chosen strategy finds no kernel;
// - the strategy value is not one this function handles.
KernelInstancePtr Dispatcher::select_kernel(const Problem& problem) const
{
    if(problem.is_valid())
    {
        if(strategy_ == SelectionStrategy::FirstFit)
        {
            return select_first_fit(problem);
        }
        if(strategy_ == SelectionStrategy::Heuristic)
        {
            return select_heuristic(problem);
        }
    }
    // Invalid problem, or an unrecognised strategy: nothing to dispatch to.
    return nullptr;
}
// Runs the problem without any auxiliary D inputs. This simply forwards to
// run_fused() with d_ptrs == nullptr.
// Returns whatever the selected kernel's run() reports (presumably a timing;
// confirm against KernelInstance). Propagates the std::runtime_error thrown by
// run_fused() when no kernel fits.
float Dispatcher::run(
    const void* a_ptr, const void* b_ptr, void* c_ptr, const Problem& problem, void* stream) const
{
    const void** const no_d_inputs = nullptr;
    return run_fused(a_ptr, b_ptr, c_ptr, no_d_inputs, problem, stream);
}
// Selects a kernel for `problem` via select_kernel() and launches it with the
// A/B/C operands plus optional D inputs (`d_ptrs` may be nullptr).
//
// @return the value produced by the selected kernel's run().
// @throws std::runtime_error if select_kernel() yields no kernel. This covers
//         both the no-matching-kernel case and an invalid problem.
float Dispatcher::run_fused(const void* a_ptr,
                            const void* b_ptr,
                            void* c_ptr,
                            const void** d_ptrs,
                            const Problem& problem,
                            void* stream) const
{
    const auto selected = select_kernel(problem);
    if(selected)
    {
        return selected->run(a_ptr, b_ptr, c_ptr, d_ptrs, problem, stream);
    }

    std::ostringstream message;
    message << "No suitable kernel found for problem: M=" << problem.M;
    message << " N=" << problem.N;
    message << " K=" << problem.K;
    throw std::runtime_error(message.str());
}
// Launches the kernel registered under `kernel_id`, bypassing automatic
// selection.
//
// @return the value produced by that kernel's run().
// @throws std::runtime_error if no kernel is registered under `kernel_id`.
// @throws std::runtime_error if that kernel's supports() rejects `problem`.
float Dispatcher::run_explicit(const std::string& kernel_id,
                               const void* a_ptr,
                               const void* b_ptr,
                               void* c_ptr,
                               const void** d_ptrs,
                               const Problem& problem,
                               void* stream) const
{
    const auto target = registry_->lookup(kernel_id);
    if(!target)
    {
        throw std::runtime_error("Kernel not found: " + kernel_id);
    }

    if(target->supports(problem))
    {
        return target->run(a_ptr, b_ptr, c_ptr, d_ptrs, problem, stream);
    }

    std::ostringstream message;
    message << "Kernel " << kernel_id << " does not support problem: M=" << problem.M;
    message << " N=" << problem.N;
    message << " K=" << problem.K;
    throw std::runtime_error(message.str());
}
// Delegates validation of an existing result to the kernel that
// select_kernel() would choose for `problem`.
// Returns false when no kernel is selected. Otherwise returns that kernel's
// validate() verdict for the given buffers and tolerance.
bool Dispatcher::validate(const void* a_ptr,
                          const void* b_ptr,
                          const void* c_ptr,
                          const void** d_ptrs,
                          const Problem& problem,
                          float tolerance) const
{
    const auto candidate = select_kernel(problem);
    // Short-circuit: validate() is only reached when a kernel was found.
    return candidate && candidate->validate(a_ptr, b_ptr, c_ptr, d_ptrs, problem, tolerance);
}
// First-fit selection: walks the registry's kernels in the order get_all()
// returns them. Returns the first one whose supports() accepts `problem`,
// or nullptr if none does.
KernelInstancePtr Dispatcher::select_first_fit(const Problem& problem) const
{
    const auto registered = registry_->get_all();
    for(auto it = registered.begin(); it != registered.end(); ++it)
    {
        if((*it)->supports(problem))
        {
            return *it;
        }
    }
    return nullptr;
}
// Heuristic selection. Asks the installed heuristic for a ranked list of
// kernel identifiers and returns the first one that is both registered and
// supports `problem`. Falls back to select_first_fit() in two cases:
// - no heuristic is installed;
// - none of the suggested identifiers is usable.
KernelInstancePtr Dispatcher::select_heuristic(const Problem& problem) const
{
    if(heuristic_)
    {
        // Identifiers are tried in the order the heuristic ranked them.
        const auto ranked_ids = heuristic_(problem);
        for(const auto& candidate_id : ranked_ids)
        {
            const auto found = registry_->lookup(candidate_id);
            if(found && found->supports(problem))
            {
                return found;
            }
        }
    }
    // No heuristic, or no usable suggestion: use plain first-fit instead.
    return select_first_fit(problem);
}
} // namespace dispatcher
} // namespace ck_tile