mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-20 21:09:08 +00:00
* WIP POC of dispatcher
* Dispatcher python workflow setup.
* Dispatcher cleanup and updates.
Further dispatcher cleanup and updates.
Build fixes
Improvements and python to CK example
Improvements to readme
* Fixes to python paths
* Cleaning up code
* Improving dispatcher support for different arch
Fixing typos
* Fix formatting errors
* Cleaning up examples
* Improving code generation
* Improving and fixing C++ examples
* Adding conv functionality (fwd,bwd,bwdw) and examples.
* Fixes based on feedback.
* Further fixes based on feedback.
* Adding stress test for autogeneration and autocorrection, and fixing preshuffle bug.
* Another round of improvements based on feedback.
* Trimming out unnecessary code.
* Fixing the multi-D implementation.
* Using GPU verification for GEMMs and fixing the convolution TFLOPS calculation.
* Fix counter usage issue and arch filtering per ops.
* Adding changelog and other fixes.
* Improve examples and resolve critical bugs.
* Reduce build time for python examples.
* Fixing minor bug.
* Fix compilation error.
* Improve installation instructions for dispatcher.
* Add docker based installation instructions for dispatcher.
* Fixing arch-based filtering to match tile engine.
* Remove dead code and fix arch filtering.
* Minor bugfix.
* Updates after rebase.
* Trimming code.
* Fix copyright headers.
* Consolidate examples, cut down code.
* Minor fixes.
* Improving python examples.
* Update readmes.
* Remove conv functionality.
* Cleanup following conv removal.
[ROCm/composable_kernel commit: 9e049a32a1]
69 lines
2.9 KiB
C++
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include "ck_tile/dispatcher/kernel_key.hpp"
#include "ck_tile/dispatcher/problem.hpp"
#include <memory>
#include <string>

namespace ck_tile {
namespace dispatcher {

/// KernelInstance: Uniform interface for kernel execution
/// Abstracts away implementation details (CK Library vs CK Tile vs future JIT)
/// Enables type-erased storage in the registry while backends perform type-safe casts
class KernelInstance
{
public:
    virtual ~KernelInstance() = default;

    /// Get the kernel's configuration metadata
    [[nodiscard]] virtual const KernelKey& get_key() const = 0;

    /// Check if this kernel supports the given problem
    /// Returns false if problem dimensions don't meet kernel requirements
    /// (e.g., divisibility constraints, resource limits)
    [[nodiscard]] virtual bool supports(const Problem& problem) const = 0;

    /// Get human-readable kernel name for logging and debugging
    [[nodiscard]] virtual std::string get_name() const = 0;

    /// Execute the kernel with the given problem and data pointers
    /// @param a_ptr Pointer to matrix A (device memory)
    /// @param b_ptr Pointer to matrix B (device memory)
    /// @param c_ptr Pointer to matrix C (device memory, input/output)
    /// @param d_ptrs Array of pointers to additional D tensors for fusion (device memory)
    /// @param problem Problem configuration
    /// @param stream HIP stream for kernel launch (nullptr = default stream)
    /// @return Kernel execution time in milliseconds (0 if timing not available)
    [[nodiscard]] virtual float run(const void* a_ptr,
                                    const void* b_ptr,
                                    void* c_ptr,
                                    const void** d_ptrs,
                                    const Problem& problem,
                                    void* stream = nullptr) const = 0;

    /// Validate kernel output against a reference implementation
    /// @param a_ptr Pointer to matrix A (device memory)
    /// @param b_ptr Pointer to matrix B (device memory)
    /// @param c_ptr Pointer to matrix C (device memory, kernel output)
    /// @param d_ptrs Array of pointers to additional D tensors (device memory)
    /// @param problem Problem configuration
    /// @param tolerance Relative error tolerance for validation
    /// @return true if validation passes, false otherwise
    [[nodiscard]] virtual bool validate(const void* a_ptr,
                                        const void* b_ptr,
                                        const void* c_ptr,
                                        const void** d_ptrs,
                                        const Problem& problem,
                                        float tolerance = 1e-3f) const = 0;
};

/// Shared pointer type for kernel instances
using KernelInstancePtr = std::shared_ptr<KernelInstance>;

} // namespace dispatcher
} // namespace ck_tile
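As a rough illustration of how a backend might satisfy this interface, here is a minimal sketch that implements `KernelInstance` with a naive GEMM. It is not the dispatcher's real backend. `Problem` and `KernelKey` are hypothetical stand-ins (the real definitions live in `kernel_key.hpp` and `problem.hpp`, which are not shown here), the interface is re-declared without its namespaces so the snippet compiles on its own, and `HostGemmKernel` and `demo_host_gemm` are hypothetical names. To stay runnable without a GPU, the sketch treats the pointers as row-major `float` host memory, whereas the real interface expects device memory and a HIP stream. `supports` models a divisibility constraint, and `validate` recomputes the product in double precision and applies `tolerance` as a relative error bound (with the denominator floored at 1.0, an assumption of this sketch).

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <memory>
#include <string>

// Hypothetical stand-ins for the real types from kernel_key.hpp and problem.hpp.
struct Problem { int M, N, K; };
struct KernelKey { int tile_m, tile_n, tile_k; };

// The KernelInstance interface, re-declared without namespaces so the snippet is standalone.
class KernelInstance
{
public:
    virtual ~KernelInstance() = default;
    [[nodiscard]] virtual const KernelKey& get_key() const = 0;
    [[nodiscard]] virtual bool supports(const Problem& problem) const = 0;
    [[nodiscard]] virtual std::string get_name() const = 0;
    [[nodiscard]] virtual float run(const void* a_ptr, const void* b_ptr, void* c_ptr,
                                    const void** d_ptrs, const Problem& problem,
                                    void* stream = nullptr) const = 0;
    [[nodiscard]] virtual bool validate(const void* a_ptr, const void* b_ptr, const void* c_ptr,
                                        const void** d_ptrs, const Problem& problem,
                                        float tolerance = 1e-3f) const = 0;
};

// Hypothetical backend: naive row-major float GEMM (C = A * B) on host memory.
class HostGemmKernel final : public KernelInstance
{
public:
    explicit HostGemmKernel(KernelKey key) : key_(key) {}
    const KernelKey& get_key() const override { return key_; }
    // Divisibility constraint: each dimension must be a positive multiple of its tile size.
    bool supports(const Problem& p) const override
    {
        return p.M > 0 && p.N > 0 && p.K > 0 && p.M % key_.tile_m == 0 &&
               p.N % key_.tile_n == 0 && p.K % key_.tile_k == 0;
    }

    std::string get_name() const override
    {
        return "host_gemm_" + std::to_string(key_.tile_m) + "x" + std::to_string(key_.tile_n) +
               "x" + std::to_string(key_.tile_k);
    }
    // d_ptrs (fusion inputs) and stream are ignored in this host-only sketch.
    float run(const void* a_ptr, const void* b_ptr, void* c_ptr, const void**, const Problem& p,
              void*) const override
    {
        const auto* a = static_cast<const float*>(a_ptr);
        const auto* b = static_cast<const float*>(b_ptr);
        auto* c       = static_cast<float*>(c_ptr);
        for(int m = 0; m < p.M; ++m)
            for(int n = 0; n < p.N; ++n)
            {
                float acc = 0.0f;
                for(int k = 0; k < p.K; ++k)
                    acc += a[m * p.K + k] * b[k * p.N + n];
                c[m * p.N + n] = acc;
            }
        return 0.0f; // 0 means timing is not available
    }

    bool validate(const void* a_ptr, const void* b_ptr, const void* c_ptr, const void**,
                  const Problem& p, float tolerance) const override
    {
        const auto* a = static_cast<const float*>(a_ptr);
        const auto* b = static_cast<const float*>(b_ptr);
        const auto* c = static_cast<const float*>(c_ptr);
        for(int m = 0; m < p.M; ++m)
            for(int n = 0; n < p.N; ++n)
            {
                double ref = 0.0; // reference result, accumulated in double precision
                for(int k = 0; k < p.K; ++k)
                    ref += static_cast<double>(a[m * p.K + k]) * b[k * p.N + n];
                // Relative error; the denominator is floored at 1.0 (assumption of this sketch).
                const double err = std::fabs(c[m * p.N + n] - ref) / std::max(std::fabs(ref), 1.0);
                if(err > tolerance)
                    return false;
            }
        return true;
    }

private:
    KernelKey key_;
};

// Usage: the kernel is driven only through the type-erased base interface.
inline bool demo_host_gemm()
{
    std::shared_ptr<KernelInstance> kernel = std::make_shared<HostGemmKernel>(KernelKey{2, 2, 2});
    const Problem p{2, 2, 2};
    const float a[] = {1, 2, 3, 4}; // [[1, 2], [3, 4]]
    const float b[] = {5, 6, 7, 8}; // [[5, 6], [7, 8]]
    float c[4]      = {};
    if(!kernel->supports(p)) return false;
    (void)kernel->run(a, b, c, nullptr, p);       // stream defaults to nullptr
    return kernel->validate(a, b, c, nullptr, p); // tolerance defaults to 1e-3f
}
```

Because `demo_host_gemm` calls the kernel through a `std::shared_ptr<KernelInstance>`, the base-class default arguments apply (`stream = nullptr`, `tolerance = 1e-3f`). Default arguments are chosen by the static type and are not inherited by overrides, so calling `run` directly on a `HostGemmKernel` object would require all six arguments.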
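The header's note on type-erased storage can be sketched in the same spirit. Assuming a registry that is simply a `std::vector<KernelInstancePtr>` (the real dispatcher registry is not shown here), the hypothetical `select_kernel` helper returns the first kernel whose `supports` accepts the problem, and `std::dynamic_pointer_cast` recovers the concrete backend type when backend-specific details are needed. `TileKernel`, its `backend_tile_m` accessor, and the stand-in `Problem`/`KernelKey` types are illustrative assumptions; `run` and `validate` are stubs because only selection is exercised.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <vector>

// Hypothetical stand-ins for the real Problem / KernelKey types.
struct Problem { int M, N, K; };
struct KernelKey { int tile_m, tile_n, tile_k; };

// The KernelInstance interface, re-declared without namespaces so the snippet is standalone.
class KernelInstance
{
public:
    virtual ~KernelInstance() = default;
    [[nodiscard]] virtual const KernelKey& get_key() const = 0;
    [[nodiscard]] virtual bool supports(const Problem& problem) const = 0;
    [[nodiscard]] virtual std::string get_name() const = 0;
    [[nodiscard]] virtual float run(const void* a_ptr, const void* b_ptr, void* c_ptr,
                                    const void** d_ptrs, const Problem& problem,
                                    void* stream = nullptr) const = 0;
    [[nodiscard]] virtual bool validate(const void* a_ptr, const void* b_ptr, const void* c_ptr,
                                        const void** d_ptrs, const Problem& problem,
                                        float tolerance = 1e-3f) const = 0;
};

using KernelInstancePtr = std::shared_ptr<KernelInstance>;

// Hypothetical backend; run/validate are stubs because only selection is shown.
class TileKernel final : public KernelInstance
{
public:
    explicit TileKernel(KernelKey key) : key_(key) {}
    const KernelKey& get_key() const override { return key_; }

    // Accept only problems whose dimensions are multiples of the tile sizes.
    bool supports(const Problem& p) const override
    {
        return p.M % key_.tile_m == 0 && p.N % key_.tile_n == 0 && p.K % key_.tile_k == 0;
    }

    std::string get_name() const override { return "tile_" + std::to_string(key_.tile_m); }

    float run(const void*, const void*, void*, const void**, const Problem&, void*) const override
    {
        return 0.0f;
    }

    bool validate(const void*, const void*, const void*, const void**, const Problem&,
                  float) const override
    {
        return true;
    }

    // Backend-specific accessor that is not part of the generic interface.
    int backend_tile_m() const { return key_.tile_m; }

private:
    KernelKey key_;
};

// Hypothetical helper: first registered kernel that supports the problem, else nullptr.
KernelInstancePtr select_kernel(const std::vector<KernelInstancePtr>& registry, const Problem& p)
{
    for(const auto& kernel : registry)
        if(kernel->supports(p))
            return kernel;
    return nullptr;
}
```

First-match selection makes registration order act as a priority list: with a 256×128×64 kernel registered before a 64×64×32 one, a 512×256×128 problem picks the larger tile, 192×192×96 falls through to the smaller one, and 100×100×100 matches neither. A real dispatcher may rank candidates differently; this sketch does not attempt that.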