[CK Tile] batched contraction kernel generalizing (#3126)

* Add help for example

* Refactor the reference computation for batched contraction to handle stride-aware calculation, plus some code cleanup

* Add stride-aware reference for batched contraction with independent D tensor layouts

* Add -num_d argument for runtime D tensor count selection in batched contraction

* Add stride vector arguments in example code for testing non-contiguous batched contraction inputs

* Add descriptor-based architecture for batched contraction multi-dimensional stride support

* Add multi-dimensional non-contiguous stride support to batched contraction, num_d = 0

* Add complete multi-dimensional stride support via descriptors

* Enable vectorization in descriptor-based batched contraction. Add pad_tensor_view to local RunGemm

* Clean up batched contraction: remove old UniversalGemmKernel path

* Clean up batched contraction: remove legacy paths and finalize docs

* Optimize batched contraction example: pass dimension sizes not vectors

* Correct the reference calculation: use int instead of unsigned int

* Fix batched_contraction C++17 build errors for gfx90a CI
This commit is contained in:
msaffari-amd
2025-12-02 13:30:27 +01:00
committed by GitHub
parent d3f37ebf6c
commit 2d3020e5b0
6 changed files with 694 additions and 313 deletions

View File

@@ -13,8 +13,8 @@
* dimensions for GEMM operations. These functions transform multi-dimensional tensors into
* 2D matrix descriptors by removing batch dimensions and flattening the remaining dimensions.
*
* These utilities are currently not used in the main batched contraction kernel but are preserved
* for future implementations that may require explicit tensor descriptor creation.
* These utilities are used by BatchedContractionKernel to create stride-aware descriptors
* that support arbitrary multi-dimensional non-contiguous tensor layouts.
*/
namespace ck_tile {
@@ -30,7 +30,10 @@ namespace ck_tile {
template <ck_tile::index_t NumDimG,
ck_tile::index_t NumDimM,
ck_tile::index_t NumDimN,
ck_tile::index_t NumDimK>
ck_tile::index_t NumDimK,
ck_tile::index_t VectorSizeA,
ck_tile::index_t VectorSizeB,
ck_tile::index_t VectorSizeE>
struct TensorDescriptorUtils
{
/// @brief Creates a tensor descriptor for input tensor A with batch dimensions removed.
@@ -62,9 +65,9 @@ struct TensorDescriptorUtils
const auto dims_M = get_container_subset(A_dims_M_K, A_dims_M_ids);
const auto dims_K = get_container_subset(A_dims_M_K, A_dims_K_ids);
// naive tensor A[M0, M1, M2, ..., K0, K1, K2...] Discriptor
const auto A_grid_desc_Ms_Ks =
ck_tile::make_naive_tensor_descriptor(A_dims_M_K, A_strides_M_K);
// naive tensor A[M0, M1, M2, ..., K0, K1, K2...] Descriptor with vector size
const auto A_grid_desc_Ms_Ks = ck_tile::make_naive_tensor_descriptor(
A_dims_M_K, A_strides_M_K, number<VectorSizeA>{}, number<1>{});
// transformed tensor to flatten M and K dimensions [M_total = M0 * M1 * M2 * ... , K_total
// = K0 * K1 * K2 * ...]
@@ -106,9 +109,9 @@ struct TensorDescriptorUtils
const auto dims_N = get_container_subset(B_dims_N_K, B_dims_N_ids);
const auto dims_K = get_container_subset(B_dims_N_K, B_dims_K_ids);
// naive tensor B[N0, N1, N2, ..., K0, K1, K2...] Discriptor
const auto B_grid_desc_Ns_Ks =
ck_tile::make_naive_tensor_descriptor(B_dims_N_K, B_strides_N_K);
// naive tensor B[N0, N1, N2, ..., K0, K1, K2...] Descriptor with vector size
const auto B_grid_desc_Ns_Ks = ck_tile::make_naive_tensor_descriptor(
B_dims_N_K, B_strides_N_K, number<VectorSizeB>{}, number<1>{});
// transformed tensor to flatten N and K dimensions [N_total = N0 * N1 * N2 * ... , K_total
// = K0 * K1 * K2 * ...]
@@ -150,9 +153,9 @@ struct TensorDescriptorUtils
const auto dims_M = get_container_subset(E_dims_M_N, E_dims_M_ids);
const auto dims_N = get_container_subset(E_dims_M_N, E_dims_N_ids);
// naive tensor E[M0, M1, M2, ..., N0, N1, N2...] Discriptor
const auto E_grid_desc_Ms_Ns =
ck_tile::make_naive_tensor_descriptor(E_dims_M_N, E_strides_M_N);
// naive tensor E[M0, M1, M2, ..., N0, N1, N2...] Descriptor with vector size
const auto E_grid_desc_Ms_Ns = ck_tile::make_naive_tensor_descriptor(
E_dims_M_N, E_strides_M_N, number<VectorSizeE>{}, number<1>{});
// transformed tensor to flatten M and N dimensions [M_total = M0 * M1 * M2 * ... ,
// N_total = N0 * N1 * N2 * ...]