[CK Tile] batched contraction kernel generalizing (#3126)

* Add help for example

* Refactore the compute reference batched contraction to manage stride-aware calculation and some code cleanings

* Add stride-aware reference for batched contraction with independent D tensor layouts

* Add -num_d argument for runtime D tensor count selection in batched contraction

* Add stride vector arguments in example code for testing non-contiguous batched contraction inputs

* Add descriptor-based architecture for batched contraction multi-dimensional stride support

* Add multi-dimensional non-contiguous stride support to batched contraction, num_d = 0

* Add complete multi-dimensional stride support via descriptors

* Enable vectorization in descriptor-based batched contraction. Add pad_tensor_view to local RunGemm

* Clean up batched contraction: remove old UniversalGemmKernel path

* Clean up batched contraction: remove legacy paths and finalize docs

* Optimize batched contraction example: pass dimension sizes not vectors

* correct the reference calculation, unsigned int to int

* Fix batched_contraction C++17 build errors for gfx90a CI
This commit is contained in:
msaffari-amd
2025-12-02 13:30:27 +01:00
committed by GitHub
parent d3f37ebf6c
commit 2d3020e5b0
6 changed files with 694 additions and 313 deletions

View File

@@ -5,6 +5,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/batched_contraction/pipeline/batched_contraction_problem.hpp"
#include "ck_tile/ops/batched_contraction/utils/tensor_descriptor_utils.hpp"
#include "ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp"
/**
@@ -60,23 +61,19 @@
* Rather than implementing tensor contraction from scratch, this kernel leverages the highly
* optimized `UniversalGemmKernel` as its computational backend.
*
* @subsection current_limitations Current Kernel Limitations
* @subsection implementation_features Implementation Features
*
* **Layout Restrictions:**
* - **Row-Major Only**: All tensors must use row-major memory layout
* - **Packed Tensors**: Only contiguous/packed tensor layouts supported
* - **Hardcoded Strides**: stride_A = K_total, stride_B = K_total, stride_E = N_total
* - **D Tensor Layout**: All D tensors must match E tensor layout (stride_Ds = N_total)
* **Stride Support:**
* - Supports arbitrary multi-dimensional stride patterns
* - Handles non-contiguous and padded tensor layouts
* - Independent strides for each auxiliary D tensor
* - Descriptor-based architecture with vectorization
*
* **Implementation Constraints:**
* - **Fixed Stride Calculation**: Strides are automatically calculated and cannot be customized
* - **No Column-Major**: Column-major or custom stride patterns not supported
* - **No Strided Access**: Non-contiguous tensor slicing not supported
*
* **Future Enhancements:**
* - Support for arbitrary stride patterns
* - Column-major and mixed layout support
* - Non-contiguous tensor operation support
* **Architecture:**
* - Uses TensorDescriptorUtils for stride-aware descriptor creation
* - Custom RunGemm implementation with descriptor-based tensor views
* - Reuses GemmPipeline and EpiloguePipeline for computation
* - Split-K support via UniversalGemmKernel utilities
*/
namespace ck_tile {
@@ -184,7 +181,10 @@ template <ck_tile::index_t NumDimG,
ck_tile::index_t NumDimM,
ck_tile::index_t NumDimN,
ck_tile::index_t NumDimK,
ck_tile::index_t NumDTensor = 0>
ck_tile::index_t NumDTensor = 0,
ck_tile::index_t VectorSizeA = 1,
ck_tile::index_t VectorSizeB = 1,
ck_tile::index_t VectorSizeE = 1>
struct BatchedContractionKernelArgs
{
const void* a_ptr; ///< Pointer to input tensor A
@@ -210,11 +210,46 @@ struct BatchedContractionKernelArgs
ck_tile::index_t N_total; ///< Total N dimension: N0 * N1 * ... * N_{NumDimN-1}
ck_tile::index_t K_total; ///< Total K dimension: K0 * K1 * ... * K_{NumDimK-1}
ck_tile::index_t stride_A; ///< Leading dimension stride for tensor A (row-major: K_total)
ck_tile::index_t stride_B; ///< Leading dimension stride for tensor B (row-major: K_total)
ck_tile::index_t
stride_A; ///< Leading dimension stride for tensor A (for backward compatibility)
ck_tile::index_t
stride_B; ///< Leading dimension stride for tensor B (for backward compatibility)
std::array<ck_tile::index_t, NumDTensor>
stride_Ds; ///< Leading dimension strides for D tensors (row-major: N_total)
ck_tile::index_t stride_E; ///< Leading dimension stride for tensor E (row-major: N_total)
stride_Ds; ///< Leading dimension strides for D tensors (for backward compatibility)
ck_tile::index_t
stride_E; ///< Leading dimension stride for tensor E (for backward compatibility)
// Tensor descriptors (encode full multi-dimensional stride information with vectorization)
using AGridDesc_M_K_ =
decltype(TensorDescriptorUtils<NumDimG,
NumDimM,
NumDimN,
NumDimK,
VectorSizeA,
VectorSizeB,
VectorSizeE>::Make_A_GridDescriptor_M_K({}, {}));
using BGridDesc_N_K_ =
decltype(TensorDescriptorUtils<NumDimG,
NumDimM,
NumDimN,
NumDimK,
VectorSizeA,
VectorSizeB,
VectorSizeE>::Make_B_GridDescriptor_N_K({}, {}));
using EGridDesc_M_N_ =
decltype(TensorDescriptorUtils<NumDimG,
NumDimM,
NumDimN,
NumDimK,
VectorSizeA,
VectorSizeB,
VectorSizeE>::Make_E_GridDescriptor_M_N({}, {}));
AGridDesc_M_K_ a_grid_desc_m_k; ///< Tensor descriptor for A[M, K] with actual strides
BGridDesc_N_K_ b_grid_desc_n_k; ///< Tensor descriptor for B[N, K] with actual strides
EGridDesc_M_N_ e_grid_desc_m_n; ///< Tensor descriptor for E[M, N] with actual strides
std::array<EGridDesc_M_N_, NumDTensor>
ds_grid_desc_m_n; ///< Descriptors for D tensors (same shape as E, independent strides)
};
/// @brief GPU kernel for batched tensor contraction operations.
@@ -274,10 +309,24 @@ struct BatchedContractionKernel
static constexpr ck_tile::index_t kBlockSize =
UniversalGemmKernel::kBlockSize; ///< GPU block size inherited from GEMM kernel
using KernelArgs =
BatchedContractionKernelArgs<NumDimG, NumDimM, NumDimN, NumDimK, NumDTensor>; ///< Kernel
///< argument
///< structure
// Tensor descriptor utilities with vectorization support
using DescriptorUtils = TensorDescriptorUtils<NumDimG,
NumDimM,
NumDimN,
NumDimK,
GemmPipeline::GetVectorSizeA(),
GemmPipeline::GetVectorSizeB(),
EpiloguePipeline::GetVectorSizeC()>;
// Kernel arguments with vectorization support
using KernelArgs = BatchedContractionKernelArgs<NumDimG,
NumDimM,
NumDimN,
NumDimK,
NumDTensor,
GemmPipeline::GetVectorSizeA(),
GemmPipeline::GetVectorSizeB(),
EpiloguePipeline::GetVectorSizeC()>;
/// @brief Returns the kernel name for debugging and profiling purposes.
/// @return Constant string identifier for this kernel
@@ -326,6 +375,104 @@ struct BatchedContractionKernel
TilePartitioner::GridSize(kargs.M_total, kargs.N_total), kargs.G_total, kargs.k_batch);
}
/// @brief Executes GEMM computation with descriptor-based tensor views for arbitrary stride
/// support
///
/// @details This function performs the core GEMM computation using tensor descriptors to handle
/// arbitrary multi-dimensional stride patterns. It creates tensor views from
/// pre-computed descriptors (stored in kargs), applies padding, creates tile windows,
/// and executes the GemmPipeline and EpiloguePipeline.
///
/// @param a_ptr Pointer to input tensor A data (after batch and split-K offsets applied)
/// @param b_ptr Pointer to input tensor B data (after batch and split-K offsets applied)
/// @param ds_ptr Array of pointers to auxiliary D tensor data
/// @param e_ptr Pointer to output tensor E data (after batch offset applied)
/// @param smem_ptr Pointer to shared memory for tile operations
/// @param kargs Kernel arguments containing tensor descriptors and dimension information
/// @param k_size Size of K dimension for this split (for split-K support)
/// @param i_m Starting M index for this block's tile
/// @param i_n Starting N index for this block's tile
CK_TILE_DEVICE static void RunGemm(const ADataType* a_ptr,
const BDataType* b_ptr,
const std::array<const void*, NumDTensor>& ds_ptr,
EDataType* e_ptr,
void* smem_ptr,
const KernelArgs& kargs,
const index_t k_size,
const index_t i_m,
const index_t i_n)
{
// Create tensor views from descriptors (supports arbitrary stride patterns)
auto a_tensor_view =
make_tensor_view<address_space_enum::global>(a_ptr, kargs.a_grid_desc_m_k);
auto b_tensor_view =
make_tensor_view<address_space_enum::global>(b_ptr, kargs.b_grid_desc_n_k);
auto e_tensor_view =
make_tensor_view<address_space_enum::global>(e_ptr, kargs.e_grid_desc_m_n);
// Pad views for boundary handling and optimization (like UniversalGemmKernel)
auto a_pad_view = pad_tensor_view(
a_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
sequence<false, GemmPipeline::kPadK>{});
auto b_pad_view = pad_tensor_view(
b_tensor_view,
make_tuple(number<TilePartitioner::NPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
sequence<false, GemmPipeline::kPadK>{});
auto e_pad_view = pad_tensor_view(
e_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
sequence<false, GemmPipeline::kPadN>{});
// Create tile windows from PADDED views
auto a_block_window = make_tile_window(
a_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
{i_m, 0});
auto b_block_window = make_tile_window(
b_pad_view,
make_tuple(number<TilePartitioner::NPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
{i_n, 0});
auto e_block_window = make_tile_window(
e_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
{i_m, i_n});
// Calculate number of K loops
const index_t num_loop =
__builtin_amdgcn_readfirstlane(TilePartitioner::GetLoopNum(k_size));
// Run GEMM Pipeline (same as UniversalGemmKernel, but with descriptor-based windows)
using AElementWise = remove_cvref_t<typename GemmPipeline::AElementWise>;
using BElementWise = remove_cvref_t<typename GemmPipeline::BElementWise>;
const auto& c_block_tile = GemmPipeline{}(
a_block_window, AElementWise{}, b_block_window, BElementWise{}, num_loop, smem_ptr);
// Create D windows from descriptors (for each D tensor)
auto ds_block_windows = generate_tuple(
[&](auto i) {
using DDataType = remove_cvref_t<std::tuple_element_t<i.value, DsDataType>>;
const DDataType* d_ptr = static_cast<const DDataType*>(ds_ptr[i]);
auto d_tensor_view =
make_tensor_view<address_space_enum::global>(d_ptr, kargs.ds_grid_desc_m_n[i]);
return make_tile_window(d_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
{i_m, i_n});
},
number<NumDTensor>{});
// Run Epilogue Pipeline with descriptor-based D windows
EpiloguePipeline{}(e_block_window, c_block_tile, ds_block_windows, smem_ptr);
}
CK_TILE_HOST static constexpr KernelArgs
MakeKernelArgs(const BatchedContractionHostArgs<NumDTensor>& host_args)
{
@@ -435,6 +582,22 @@ struct BatchedContractionKernel
kargs.K_total *= kargs.K_dims[i];
}
// Create tensor descriptors on host using actual dims and strides
kargs.a_grid_desc_m_k =
DescriptorUtils::Make_A_GridDescriptor_M_K(host_args.A_dims, host_args.A_strides);
kargs.b_grid_desc_n_k =
DescriptorUtils::Make_B_GridDescriptor_N_K(host_args.B_dims, host_args.B_strides);
kargs.e_grid_desc_m_n =
DescriptorUtils::Make_E_GridDescriptor_M_N(host_args.E_dims, host_args.E_strides);
// Create D descriptors with their own strides (same shape as E, independent strides)
for(ck_tile::index_t d = 0; d < NumDTensor; ++d)
{
kargs.ds_grid_desc_m_n[d] = DescriptorUtils::Make_E_GridDescriptor_M_N(
host_args.Ds_dims[d], host_args.Ds_strides[d]);
}
// Keep simple strides for backward compatibility
kargs.stride_A = kargs.K_total;
kargs.stride_B = kargs.K_total;
kargs.stride_E = kargs.N_total;
@@ -468,8 +631,8 @@ struct BatchedContractionKernel
const ck_tile::index_t i_n =
__builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
const auto i_batch_flat = __builtin_amdgcn_readfirstlane(blockIdx.y);
const auto i_splitk = __builtin_amdgcn_readfirstlane(blockIdx.z);
const auto i_batch_flat = __builtin_amdgcn_readfirstlane(blockIdx.y);
[[maybe_unused]] const auto i_splitk = __builtin_amdgcn_readfirstlane(blockIdx.z);
// Calculate batch offsets for each tensor
const auto batch_offset_A = i_batch_flat * kargs.batch_stride_A;
@@ -487,6 +650,10 @@ struct BatchedContractionKernel
ds_batch_ptr[i] = static_cast<const DDataType*>(kargs.ds_ptr[i]) + batch_offset_D;
});
// Allocate shared memory
__shared__ char smem_ptr[GetSmemSize()];
// Use UniversalGemmKernel's SplitKBatchOffset for split-K calculation
typename UniversalGemmKernel::KernelArgs gemm_kargs{{a_ptr},
{b_ptr},
ds_batch_ptr,
@@ -503,19 +670,19 @@ struct BatchedContractionKernel
const typename UniversalGemmKernel::SplitKBatchOffset splitk_batch_offset(gemm_kargs,
i_splitk);
const ADataType* a_ptr_final = a_ptr + splitk_batch_offset.as_k_split_offset[0];
const BDataType* b_ptr_final = b_ptr + splitk_batch_offset.bs_k_split_offset[0];
__shared__ char smem_ptr[GetSmemSize()];
// Apply K-split offsets and run descriptor-based RunGemm
const ADataType* a_ptr_split = a_ptr + splitk_batch_offset.as_k_split_offset[0];
const BDataType* b_ptr_split = b_ptr + splitk_batch_offset.bs_k_split_offset[0];
UniversalGemmKernel::RunGemm({a_ptr_final},
{b_ptr_final},
ds_batch_ptr,
e_ptr,
smem_ptr,
gemm_kargs,
splitk_batch_offset,
i_m,
i_n);
RunGemm(a_ptr_split,
b_ptr_split,
ds_batch_ptr,
e_ptr,
smem_ptr,
kargs,
splitk_batch_offset.splitted_k,
i_m,
i_n);
}
};