mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 18:17:44 +00:00
* WIP POC of dispatcher
* Dispatcher python workflow setup.
* Dispatcher cleanup and updates.
Further dispatcher cleanup and updates.
Build fixes
Improvements and python to CK example
Improvements to readme
* Fixes to python paths
* Cleaning up code
* Improving dispatcher support for different arch
Fixing typos
* Fix formatting errors
* Cleaning up examples
* Improving codegeneration
* Improving and fixing C++ examples
* Adding conv functionality (fwd,bwd,bwdw) and examples.
* Fixes based on feedback.
* Further fixes based on feedback.
* Adding stress test for autogeneration and autocorrection, and fixing preshuffle bug.
* Another round of improvements based on feedback.
* Trimming out unnecessary code.
* Fixing the multi-D implementation.
* Using GPU verification for GEMMs and fixing the convolution TFLOPS calculation.
* Fix counter usage issue and arch filtering per ops.
* Adding changelog and other fixes.
* Improve examples and resolve critical bugs.
* Reduce build time for python examples.
* Fixing minor bug.
* Fix compilation error.
* Improve installation instructions for dispatcher.
* Add docker based installation instructions for dispatcher.
* Fixing arch-based filtering to match tile engine.
* Remove dead code and fix arch filtering.
* Minor bugfix.
* Updates after rebase.
* Trimming code.
* Fix copyright headers.
* Consolidate examples, cut down code.
* Minor fixes.
* Improving python examples.
* Update readmes.
* Remove conv functionality.
* Cleanup following conv removal.
[ROCm/composable_kernel commit: 9e049a32a1]
153 lines
4.1 KiB
C++
153 lines
4.1 KiB
C++
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
// SPDX-License-Identifier: MIT
|
|
|
|
#include "ck_tile/dispatcher/dispatcher.hpp"

#include <iostream>
#include <sstream>
#include <stdexcept>
#include <utility>
|
|
|
|
namespace ck_tile {
|
|
namespace dispatcher {
|
|
|
|
Dispatcher::Dispatcher(Registry* registry)
|
|
: registry_(registry ? registry : &Registry::instance()),
|
|
heuristic_(nullptr),
|
|
strategy_(SelectionStrategy::FirstFit)
|
|
{
|
|
}
|
|
|
|
/// Install (or clear) the kernel-ranking heuristic.
///
/// @param heuristic Callable invoked with a Problem that returns a ranked list
///                  of kernel identifiers (see select_heuristic()). An empty
///                  heuristic clears the currently installed one.
///
/// Installing a non-empty heuristic also switches the selection strategy to
/// SelectionStrategy::Heuristic. Clearing it leaves the strategy unchanged.
/// That is safe because select_heuristic() falls back to first-fit when no
/// heuristic is installed.
void Dispatcher::set_heuristic(HeuristicFunction heuristic)
{
    // The parameter is already our own copy, so move it into the member rather
    // than copying it a second time. This matters if HeuristicFunction is a
    // std::function-like type that owns captured state.
    heuristic_ = std::move(heuristic);

    if(heuristic_)
    {
        strategy_ = SelectionStrategy::Heuristic;
    }
}
|
|
|
|
/// Select how select_kernel() chooses among registered kernels.
///
/// @param strategy The strategy consulted by subsequent select_kernel() calls.
void Dispatcher::set_strategy(SelectionStrategy strategy)
{
    strategy_ = strategy;
}
|
|
|
|
/// Choose a kernel for @p problem using the currently configured strategy.
///
/// @param problem Problem description; rejected up front if !is_valid().
/// @return The chosen kernel, or nullptr if the problem is invalid, no kernel
///         qualifies, or the strategy value is not one handled here.
KernelInstancePtr Dispatcher::select_kernel(const Problem& problem) const
{
    // Invalid problems never reach a selection strategy.
    if(!problem.is_valid())
    {
        return nullptr;
    }

    if(strategy_ == SelectionStrategy::FirstFit)
    {
        return select_first_fit(problem);
    }
    if(strategy_ == SelectionStrategy::Heuristic)
    {
        return select_heuristic(problem);
    }

    // Unrecognised strategy.
    return nullptr;
}
|
|
|
|
float Dispatcher::run(
|
|
const void* a_ptr, const void* b_ptr, void* c_ptr, const Problem& problem, void* stream) const
|
|
{
|
|
return run_fused(a_ptr, b_ptr, c_ptr, nullptr, problem, stream);
|
|
}
|
|
|
|
float Dispatcher::run_fused(const void* a_ptr,
|
|
const void* b_ptr,
|
|
void* c_ptr,
|
|
const void** d_ptrs,
|
|
const Problem& problem,
|
|
void* stream) const
|
|
{
|
|
auto kernel = select_kernel(problem);
|
|
if(!kernel)
|
|
{
|
|
std::ostringstream oss;
|
|
oss << "No suitable kernel found for problem: M=" << problem.M << " N=" << problem.N
|
|
<< " K=" << problem.K;
|
|
throw std::runtime_error(oss.str());
|
|
}
|
|
|
|
return kernel->run(a_ptr, b_ptr, c_ptr, d_ptrs, problem, stream);
|
|
}
|
|
|
|
float Dispatcher::run_explicit(const std::string& kernel_id,
|
|
const void* a_ptr,
|
|
const void* b_ptr,
|
|
void* c_ptr,
|
|
const void** d_ptrs,
|
|
const Problem& problem,
|
|
void* stream) const
|
|
{
|
|
auto kernel = registry_->lookup(kernel_id);
|
|
if(!kernel)
|
|
{
|
|
throw std::runtime_error("Kernel not found: " + kernel_id);
|
|
}
|
|
|
|
if(!kernel->supports(problem))
|
|
{
|
|
std::ostringstream oss;
|
|
oss << "Kernel " << kernel_id << " does not support problem: M=" << problem.M
|
|
<< " N=" << problem.N << " K=" << problem.K;
|
|
throw std::runtime_error(oss.str());
|
|
}
|
|
|
|
return kernel->run(a_ptr, b_ptr, c_ptr, d_ptrs, problem, stream);
|
|
}
|
|
|
|
bool Dispatcher::validate(const void* a_ptr,
|
|
const void* b_ptr,
|
|
const void* c_ptr,
|
|
const void** d_ptrs,
|
|
const Problem& problem,
|
|
float tolerance) const
|
|
{
|
|
auto kernel = select_kernel(problem);
|
|
if(!kernel)
|
|
{
|
|
return false;
|
|
}
|
|
|
|
return kernel->validate(a_ptr, b_ptr, c_ptr, d_ptrs, problem, tolerance);
|
|
}
|
|
|
|
/// Return the first kernel, in the order Registry::get_all() yields them,
/// whose supports() accepts @p problem.
///
/// @return The matching kernel, or nullptr if no registered kernel qualifies.
KernelInstancePtr Dispatcher::select_first_fit(const Problem& problem) const
{
    auto registered = registry_->get_all();

    for(auto it = registered.begin(); it != registered.end(); ++it)
    {
        if((*it)->supports(problem))
        {
            return *it;
        }
    }

    return nullptr;
}
|
|
|
|
/// Choose a kernel using the installed heuristic's ranking.
///
/// The heuristic returns kernel identifiers in preference order. The first
/// identifier that is both present in the registry and supports @p problem is
/// chosen. With no heuristic installed, or when no ranked candidate qualifies,
/// this falls back to select_first_fit().
///
/// @return The chosen kernel, or nullptr if even first-fit finds nothing.
KernelInstancePtr Dispatcher::select_heuristic(const Problem& problem) const
{
    if(heuristic_)
    {
        const auto ranked_ids = heuristic_(problem);

        for(const auto& id : ranked_ids)
        {
            // Unknown identifiers yield a null lookup and are skipped.
            auto candidate = registry_->lookup(id);
            if(candidate && candidate->supports(problem))
            {
                return candidate;
            }
        }
    }

    // No heuristic installed, or none of its candidates is usable.
    return select_first_fit(problem);
}
|
|
|
|
} // namespace dispatcher
|
|
} // namespace ck_tile
|