composable_kernel/dispatcher/include/ck_tile/dispatcher/kernel_instance.hpp
Vidyasagar Ananthan 8763bbf6cf Adding dispatcher architecture (#3300)
* WIP POC of dispatcher

* Dispatcher python workflow setup.

* Dispatcher cleanup and updates.

Further dispatcher cleanup and updates.

Build fixes

Improvements and python to CK example

Improvements to readme

* Fixes to python paths

* Cleaning up code

* Improving dispatcher support for different arch

Fixing typos

* Fix formatting errors

* Cleaning up examples

* Improving code generation

* Improving and fixing C++ examples

* Adding conv functionality (fwd,bwd,bwdw) and examples.

* Fixes based on feedback.

* Further fixes based on feedback.

* Adding stress test for autogeneration and autocorrection, and fixing preshuffle bug.

* Another round of improvements based on feedback.

* Trimming out unnecessary code.

* Fixing the multi-D implementation.

* Using gpu verification for gemms and fixing convolutions tflops calculation.

* Fix counter usage issue and arch filtering per ops.

* Adding changelog and other fixes.

* Improve examples and resolve critical bugs.

* Reduce build time for python examples.

* Fixing minor bug.

* Fix compilation error.

* Improve installation instructions for dispatcher.

* Add Docker-based installation instructions for dispatcher.

* Fixing arch-based filtering to match tile engine.

* Remove dead code and fix arch filtering.

* Minor bugfix.

* Updates after rebase.

* Trimming code.

* Fix copyright headers.

* Consolidate examples, cut down code.

* Minor fixes.

* Improving python examples.

* Update readmes.

* Remove conv functionality.

* Cleanup following conv removal.

[ROCm/composable_kernel commit: 9e049a32a1]
2026-01-22 09:34:33 -08:00

69 lines
2.9 KiB
C++

// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/dispatcher/kernel_key.hpp"
#include "ck_tile/dispatcher/problem.hpp"
#include <memory>
#include <string>
namespace ck_tile {
namespace dispatcher {
/// KernelInstance: Uniform interface for kernel execution
/// Abstracts away implementation details (CK Library vs CK Tile vs future JIT)
/// Enables type-erased storage in registry while backends perform type-safe casts
class KernelInstance
{
    public:
    virtual ~KernelInstance() = default;

    /// Get the kernel's configuration metadata
    [[nodiscard]] virtual const KernelKey& get_key() const = 0;

    /// Check if this kernel supports the given problem
    /// Returns false if problem dimensions don't meet kernel requirements
    /// (e.g., divisibility constraints, resource limits)
    [[nodiscard]] virtual bool supports(const Problem& problem) const = 0;

    /// Get human-readable kernel name for logging and debugging
    [[nodiscard]] virtual std::string get_name() const = 0;

    /// Execute the kernel with given problem and data pointers
    /// @param a_ptr Pointer to matrix A (device memory)
    /// @param b_ptr Pointer to matrix B (device memory)
    /// @param c_ptr Pointer to matrix C (device memory, input/output)
    /// @param d_ptrs Array of pointers to additional D tensors for fusion (device memory)
    /// @param problem Problem configuration
    /// @param stream HIP stream for kernel launch (nullptr = default stream)
    /// @return Kernel execution time in milliseconds (0 if timing not available)
    [[nodiscard]] virtual float run(const void* a_ptr,
                                    const void* b_ptr,
                                    void* c_ptr,
                                    const void** d_ptrs,
                                    const Problem& problem,
                                    void* stream = nullptr) const = 0;

    /// Validate kernel output against reference implementation
    /// @param a_ptr Pointer to matrix A (device memory)
    /// @param b_ptr Pointer to matrix B (device memory)
    /// @param c_ptr Pointer to matrix C (device memory, kernel output)
    /// @param d_ptrs Array of pointers to additional D tensors (device memory)
    /// @param problem Problem configuration
    /// @param tolerance Relative error tolerance for validation
    /// @return true if validation passes, false otherwise
    [[nodiscard]] virtual bool validate(const void* a_ptr,
                                        const void* b_ptr,
                                        const void* c_ptr,
                                        const void** d_ptrs,
                                        const Problem& problem,
                                        float tolerance = 1e-3f) const = 0;
};
/// Shared pointer type for kernel instances
using KernelInstancePtr = std::shared_ptr<KernelInstance>;
} // namespace dispatcher
} // namespace ck_tile
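
The interface above can be exercised without a GPU. The following is a minimal sketch, not the dispatcher's actual backend or registry: `KernelKey` and `Problem` are hypothetical stand-ins (the real definitions live in kernel_key.hpp and problem.hpp and are not shown here), `TiledStubKernel`, `make_stub`, and `select_kernel` are illustrative names, and the interface is reproduced at global scope only so that the block compiles on its own. It assumes a backend whose `supports()` check is a tile-divisibility test, one of the constraints the header comments mention.

```cpp
#include <cassert>
#include <cstdint>
#include <memory>
#include <string>
#include <vector>

// Hypothetical stand-ins; the real KernelKey and Problem will differ.
struct KernelKey { int tile_m; int tile_n; int tile_k; };
struct Problem { std::int64_t M; std::int64_t N; std::int64_t K; };

// Same virtual interface as the header, repeated so the sketch is self-contained.
class KernelInstance
{
    public:
    virtual ~KernelInstance() = default;
    [[nodiscard]] virtual const KernelKey& get_key() const = 0;
    [[nodiscard]] virtual bool supports(const Problem& problem) const = 0;
    [[nodiscard]] virtual std::string get_name() const = 0;
    [[nodiscard]] virtual float run(const void* a_ptr, const void* b_ptr, void* c_ptr,
                                    const void** d_ptrs, const Problem& problem,
                                    void* stream = nullptr) const = 0;
    [[nodiscard]] virtual bool validate(const void* a_ptr, const void* b_ptr,
                                        const void* c_ptr, const void** d_ptrs,
                                        const Problem& problem,
                                        float tolerance = 1e-3f) const = 0;
};
using KernelInstancePtr = std::shared_ptr<KernelInstance>;

// Hypothetical backend: launches nothing, only models the support check.
class TiledStubKernel final : public KernelInstance
{
    public:
    TiledStubKernel(int tile_m, int tile_n, int tile_k) : key_{tile_m, tile_n, tile_k}
    {
        assert(tile_m > 0 && tile_n > 0 && tile_k > 0); // guards the modulo below
    }

    const KernelKey& get_key() const override { return key_; }

    // Supported only when every problem dimension is a positive multiple of the tile.
    bool supports(const Problem& p) const override
    {
        return p.M > 0 && p.N > 0 && p.K > 0 && p.M % key_.tile_m == 0 &&
               p.N % key_.tile_n == 0 && p.K % key_.tile_k == 0;
    }

    std::string get_name() const override
    {
        return "stub_" + std::to_string(key_.tile_m) + "x" + std::to_string(key_.tile_n) +
               "x" + std::to_string(key_.tile_k);
    }

    // No launch happens here; 0 follows the documented "timing not available" value.
    float run(const void*, const void*, void*, const void**, const Problem&,
              void*) const override
    {
        return 0.0f;
    }

    // The stub has no output to compare, so it reports success unconditionally.
    bool validate(const void*, const void*, const void*, const void**, const Problem&,
                  float) const override
    {
        return true;
    }

    private:
    KernelKey key_;
};

// Convenience factory returning the type-erased pointer.
KernelInstancePtr make_stub(int tile_m, int tile_n, int tile_k)
{
    return std::make_shared<TiledStubKernel>(tile_m, tile_n, tile_k);
}

// Return the first candidate that supports the problem, or nullptr if none does.
KernelInstancePtr select_kernel(const std::vector<KernelInstancePtr>& candidates,
                                const Problem& problem)
{
    for(const auto& kernel : candidates)
    {
        if(kernel && kernel->supports(problem))
            return kernel;
    }
    return nullptr;
}
```

With candidates `make_stub(128, 128, 32)` and `make_stub(64, 64, 32)`, a problem of size 192 x 192 x 64 skips the first kernel (192 is not a multiple of 128) and selects the second. A first-match scan is the simplest possible policy; a real dispatcher would more likely rank candidates by heuristics or measured time. The default arguments (`stream`, `tolerance`) are declared only on the base class, so they apply when calling through `KernelInstancePtr`.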