Files
composable_kernel/dispatcher/include/ck_tile/dispatcher/problem.hpp
Vidyasagar Ananthan 7ce0127e8f Adding dispatcher architecture (#3300)
* WIP POC of dispatcher

* Dispatcher python workflow setup.

* Dispatcher cleanup and updates.

Further dispatcher cleanup and updates.

Build fixes

Improvements and python to CK example

Improvements to readme

* Fixes to python paths

* Cleaning up code

* Improving dispatcher support for different arch

Fixing typos

* Fix formatting errors

* Cleaning up examples

* Improving codegeneration

* Improving and fixing C++ examples

* Adding conv functionality (fwd,bwd,bwdw) and examples.

* Fixes based on feedback.

* Further fixes based on feedback.

* Adding stress test for autogeneration and autocorrection, and fixing preshuffle bug.

* Another round of improvements  based on feedback.

* Trimming out unnecessary code.

* Fixing the multi-D implementation.

* Using gpu verification for gemms and fixing convolutions tflops calculation.

* Fix counter usage issue and arch filtering per ops.

* Adding changelog and other fixes.

* Improve examples and resolve critical bugs.

* Reduce build time for python examples.

* Fixing minor bug.

* Fix compilation error.

* Improve installation instructions for dispatcher.

* Add docker based  installation instructions for dispatcher.

* Fixing arch-based filtering to match tile engine.

* Remove dead code and fix arch filtering.

* Minor bugfix.

* Updates after rebase.

* Trimming code.

* Fix copyright headers.

* Consolidate examples, cut down code.

* Minor fixes.

* Improving python examples.

* Update readmes.

* Remove conv functionality.

* Cleanup following conv removable.

[ROCm/composable_kernel commit: 9e049a32a1]
2026-01-22 09:34:33 -08:00

312 lines
10 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <cstdint>
#include <stdexcept>
#include <string>
namespace ck_tile {
namespace dispatcher {
// =============================================================================
// Tensor Information for Automatic MNK Inference
// =============================================================================
/// TensorShape: Describes tensor dimensions for automatic MNK inference
struct TensorShape
{
std::int64_t rows; // First dimension
std::int64_t cols; // Second dimension
bool is_transposed; // Whether the tensor is transposed (column-major)
TensorShape() : rows(0), cols(0), is_transposed(false) {}
TensorShape(std::int64_t r, std::int64_t c, bool trans = false)
: rows(r), cols(c), is_transposed(trans)
{
}
/// Get logical M (rows when not transposed)
[[nodiscard]] std::int64_t logical_rows() const { return is_transposed ? cols : rows; }
/// Get logical N (cols when not transposed)
[[nodiscard]] std::int64_t logical_cols() const { return is_transposed ? rows : cols; }
};
// =============================================================================
// Problem: Runtime Parameters
// =============================================================================
/// Problem: Runtime parameters for kernel invocation
/// Captures problem dimensions and resource constraints that vary between invocations
/// even when using the same kernel
struct Problem
{
// Problem dimensions
std::int64_t M; // Number of rows in A and C
std::int64_t N; // Number of columns in B and C
std::int64_t K; // Shared dimension (columns of A, rows of B)
// Batch configuration
std::int32_t k_batch; // Number of K-dimension splits for split-K GEMM
// Resource preferences
std::int32_t smem_budget; // Shared memory budget in bytes (0 = no constraint)
bool prefer_persistent; // Prefer persistent kernel variants
// Validation control
bool enable_validation; // Enable output validation against reference
/// Default constructor with sensible defaults
Problem()
: M(0),
N(0),
K(0),
k_batch(1),
smem_budget(0),
prefer_persistent(false),
enable_validation(false)
{
}
/// Constructor with problem dimensions
Problem(std::int64_t m, std::int64_t n, std::int64_t k)
: M(m),
N(n),
K(k),
k_batch(1),
smem_budget(0),
prefer_persistent(false),
enable_validation(false)
{
}
/// Check if problem dimensions are valid
[[nodiscard]] bool is_valid() const { return M > 0 && N > 0 && K > 0 && k_batch > 0; }
/// Get total number of operations (for performance metrics)
[[nodiscard]] std::int64_t num_ops() const
{
return 2 * M * N * K; // Multiply-add counts as 2 ops
}
// =========================================================================
// Factory Methods for Automatic MNK Inference
// =========================================================================
/**
* Create Problem by inferring MNK from tensor shapes.
*
* For GEMM: C[M,N] = A[M,K] × B[K,N]
*
* @param a_shape Shape of matrix A (M x K, or K x M if transposed)
* @param b_shape Shape of matrix B (K x N, or N x K if transposed)
* @param c_shape Shape of matrix C (M x N) - used for validation
* @throws std::invalid_argument if dimensions are inconsistent
*
* Example:
* // A is 512x256, B is 256x1024, C is 512x1024
* auto problem = Problem::from_shapes({512, 256}, {256, 1024}, {512, 1024});
* // Infers: M=512, N=1024, K=256
*/
[[nodiscard]] static Problem
from_shapes(TensorShape a_shape, TensorShape b_shape, TensorShape c_shape)
{
// For C = A × B:
// A: [M, K] (or [K, M] if transposed)
// B: [K, N] (or [N, K] if transposed)
// C: [M, N]
std::int64_t M_from_A = a_shape.logical_rows();
std::int64_t K_from_A = a_shape.logical_cols();
std::int64_t K_from_B = b_shape.logical_rows();
std::int64_t N_from_B = b_shape.logical_cols();
std::int64_t M_from_C = c_shape.logical_rows();
std::int64_t N_from_C = c_shape.logical_cols();
// Validate K dimension matches between A and B
if(K_from_A != K_from_B)
{
throw std::invalid_argument(
"K dimension mismatch: A has K=" + std::to_string(K_from_A) +
", B has K=" + std::to_string(K_from_B));
}
// Validate M dimension matches between A and C
if(M_from_A != M_from_C)
{
throw std::invalid_argument(
"M dimension mismatch: A has M=" + std::to_string(M_from_A) +
", C has M=" + std::to_string(M_from_C));
}
// Validate N dimension matches between B and C
if(N_from_B != N_from_C)
{
throw std::invalid_argument(
"N dimension mismatch: B has N=" + std::to_string(N_from_B) +
", C has N=" + std::to_string(N_from_C));
}
return Problem(M_from_A, N_from_B, K_from_A);
}
/**
* Create Problem from tensor dimensions (simple version without transpose).
*
* @param a_rows Rows of matrix A (= M)
* @param a_cols Columns of matrix A (= K)
* @param b_rows Rows of matrix B (= K)
* @param b_cols Columns of matrix B (= N)
* @param c_rows Rows of matrix C (= M) - for validation
* @param c_cols Columns of matrix C (= N) - for validation
* @throws std::invalid_argument if dimensions are inconsistent
*
* Example:
* // A[512,256] × B[256,1024] = C[512,1024]
* auto problem = Problem::from_dimensions(512, 256, 256, 1024, 512, 1024);
*/
[[nodiscard]] static Problem from_dimensions(std::int64_t a_rows,
std::int64_t a_cols,
std::int64_t b_rows,
std::int64_t b_cols,
std::int64_t c_rows,
std::int64_t c_cols)
{
return from_shapes(
TensorShape(a_rows, a_cols), TensorShape(b_rows, b_cols), TensorShape(c_rows, c_cols));
}
/**
* Create Problem from A and B dimensions only (C is inferred).
*
* @param a_rows Rows of matrix A (= M)
* @param a_cols Columns of matrix A (= K)
* @param b_rows Rows of matrix B (= K) - validated
* @param b_cols Columns of matrix B (= N)
* @throws std::invalid_argument if K dimensions don't match
*
* Example:
* // A[512,256] × B[256,1024] = C[512,1024]
* auto problem = Problem::from_ab(512, 256, 256, 1024);
*/
[[nodiscard]] static Problem
from_ab(std::int64_t a_rows, std::int64_t a_cols, std::int64_t b_rows, std::int64_t b_cols)
{
if(a_cols != b_rows)
{
throw std::invalid_argument("K dimension mismatch: A.cols=" + std::to_string(a_cols) +
", B.rows=" + std::to_string(b_rows));
}
return Problem(a_rows, b_cols, a_cols);
}
/**
* Validate that tensor pointers have consistent sizes.
* Call this before kernel execution to catch dimension errors early.
*
* @param a_size Total elements in A tensor
* @param b_size Total elements in B tensor
* @param c_size Total elements in C tensor
* @throws std::invalid_argument if sizes don't match expected dimensions
*/
void validate_sizes(std::int64_t a_size, std::int64_t b_size, std::int64_t c_size) const
{
std::int64_t expected_a = M * K;
std::int64_t expected_b = K * N;
std::int64_t expected_c = M * N;
if(a_size != expected_a)
{
throw std::invalid_argument("A tensor size mismatch: got " + std::to_string(a_size) +
", expected " + std::to_string(expected_a) + " (M*K = " +
std::to_string(M) + "*" + std::to_string(K) + ")");
}
if(b_size != expected_b)
{
throw std::invalid_argument("B tensor size mismatch: got " + std::to_string(b_size) +
", expected " + std::to_string(expected_b) + " (K*N = " +
std::to_string(K) + "*" + std::to_string(N) + ")");
}
if(c_size != expected_c)
{
throw std::invalid_argument("C tensor size mismatch: got " + std::to_string(c_size) +
", expected " + std::to_string(expected_c) + " (M*N = " +
std::to_string(M) + "*" + std::to_string(N) + ")");
}
}
};
// =============================================================================
// Convenience Builders
// =============================================================================
/// Builder pattern for Problem configuration
class ProblemBuilder
{
public:
ProblemBuilder() = default;
/// Set dimensions from A and B shapes
ProblemBuilder&
from_ab(std::int64_t a_rows, std::int64_t a_cols, std::int64_t b_rows, std::int64_t b_cols)
{
problem_ = Problem::from_ab(a_rows, a_cols, b_rows, b_cols);
return *this;
}
/// Set MNK directly
ProblemBuilder& dimensions(std::int64_t m, std::int64_t n, std::int64_t k)
{
problem_.M = m;
problem_.N = n;
problem_.K = k;
return *this;
}
/// Set split-K batch count
ProblemBuilder& split_k(std::int32_t k_batch)
{
problem_.k_batch = k_batch;
return *this;
}
/// Set shared memory budget
ProblemBuilder& smem_budget(std::int32_t budget)
{
problem_.smem_budget = budget;
return *this;
}
/// Prefer persistent kernels
ProblemBuilder& persistent(bool prefer = true)
{
problem_.prefer_persistent = prefer;
return *this;
}
/// Enable validation
ProblemBuilder& validate(bool enable = true)
{
problem_.enable_validation = enable;
return *this;
}
/// Build the Problem
[[nodiscard]] Problem build() const
{
if(!problem_.is_valid())
{
throw std::invalid_argument("Invalid problem dimensions");
}
return problem_;
}
private:
Problem problem_;
};
} // namespace dispatcher
} // namespace ck_tile