diff --git a/include/ck_tile/host/reference/reference_batched_contraction.hpp b/include/ck_tile/host/reference/reference_batched_contraction.hpp index 83476bd4ae..1ccff293eb 100644 --- a/include/ck_tile/host/reference/reference_batched_contraction.hpp +++ b/include/ck_tile/host/reference/reference_batched_contraction.hpp @@ -4,8 +4,6 @@ #pragma once #include -#include -#include #include #include "ck_tile/core.hpp" diff --git a/include/ck_tile/ops/batched_contraction/kernel/batched_contraction_kernel.hpp b/include/ck_tile/ops/batched_contraction/kernel/batched_contraction_kernel.hpp index 21d638dcb4..7dc98a7544 100644 --- a/include/ck_tile/ops/batched_contraction/kernel/batched_contraction_kernel.hpp +++ b/include/ck_tile/ops/batched_contraction/kernel/batched_contraction_kernel.hpp @@ -61,23 +61,19 @@ * Rather than implementing tensor contraction from scratch, this kernel leverages the highly * optimized `UniversalGemmKernel` as its computational backend. * - * @subsection current_limitations Current Kernel Limitations + * @subsection implementation_features Implementation Features * - * **Layout Restrictions:** - * - **Row-Major Only**: All tensors must use row-major memory layout - * - **Packed Tensors**: Only contiguous/packed tensor layouts supported - * - **Hardcoded Strides**: stride_A = K_total, stride_B = K_total, stride_E = N_total - * - **D Tensor Layout**: All D tensors must match E tensor layout (stride_Ds = N_total) + * **Stride Support:** + * - Supports arbitrary multi-dimensional stride patterns + * - Handles non-contiguous and padded tensor layouts + * - Independent strides for each auxiliary D tensor + * - Descriptor-based architecture with vectorization * - * **Implementation Constraints:** - * - **Fixed Stride Calculation**: Strides are automatically calculated and cannot be customized - * - **No Column-Major**: Column-major or custom stride patterns not supported - * - **No Strided Access**: Non-contiguous tensor slicing not supported - * - * **Future Enhancements:** - * - Support for arbitrary stride patterns - * - Column-major and mixed layout support - * - Non-contiguous tensor operation support + * **Architecture:** + * - Uses TensorDescriptorUtils for stride-aware descriptor creation + * - Custom RunGemm implementation with descriptor-based tensor views + * - Reuses GemmPipeline and EpiloguePipeline for computation + * - Split-K support via UniversalGemmKernel utilities */ namespace ck_tile { @@ -379,10 +375,23 @@ struct BatchedContractionKernel TilePartitioner::GridSize(kargs.M_total, kargs.N_total), kargs.G_total, kargs.k_batch); } - /// @brief Custom RunGemm implementation using tensor descriptors for multi-dimensional - /// non-contiguous support - /// @details This function creates tensor views from descriptors and runs GEMM pipeline, - /// similar to UniversalGemmKernel::RunGemm but with descriptor-based tensor views + /// @brief Executes GEMM computation with descriptor-based tensor views for arbitrary stride + /// support + /// + /// @details This function performs the core GEMM computation using tensor descriptors to handle + /// arbitrary multi-dimensional stride patterns. It creates tensor views from + /// pre-computed descriptors (stored in kargs), applies padding, creates tile windows, + /// and executes the GemmPipeline and EpiloguePipeline. + /// + /// @param a_ptr Pointer to input tensor A data (after batch and split-K offsets applied) + /// @param b_ptr Pointer to input tensor B data (after batch and split-K offsets applied) + /// @param ds_ptr Array of pointers to auxiliary D tensor data + /// @param e_ptr Pointer to output tensor E data (after batch offset applied) + /// @param smem_ptr Pointer to shared memory for tile operations + /// @param kargs Kernel arguments containing tensor descriptors and dimension information + /// @param k_size Size of K dimension for this split (for split-K support) + /// @param i_m Starting M index for this block's tile + /// @param i_n Starting N index for this block's tile CK_TILE_DEVICE static void RunGemm(const ADataType* a_ptr, const BDataType* b_ptr, const std::array& ds_ptr, @@ -393,36 +402,13 @@ struct BatchedContractionKernel const index_t i_m, const index_t i_n) { -#if 1 // DESCRIPTOR PATH: Full multi-dimensional stride support - // Create tensor views from descriptors + // Create tensor views from descriptors (supports arbitrary stride patterns) auto a_tensor_view = make_tensor_view(a_ptr, kargs.a_grid_desc_m_k); auto b_tensor_view = make_tensor_view(b_ptr, kargs.b_grid_desc_n_k); auto e_tensor_view = make_tensor_view(e_ptr, kargs.e_grid_desc_m_n); -#else // NAIVE PATH: Simple views assuming contiguous (for performance testing) - auto a_tensor_view = make_naive_tensor_view( - a_ptr, - make_tuple(kargs.M_total, k_size), - make_tuple(kargs.K_total, 1), - number{}, - number<1>{}); - - auto b_tensor_view = make_naive_tensor_view( - b_ptr, - make_tuple(kargs.N_total, k_size), - make_tuple(kargs.K_total, 1), - number{}, - number<1>{}); - - auto e_tensor_view = make_naive_tensor_view( - e_ptr, - make_tuple(kargs.M_total, kargs.N_total), - make_tuple(kargs.N_total, 1), - number{}, - number<1>{}); -#endif // Pad views for boundary handling and optimization (like UniversalGemmKernel) auto a_pad_view = pad_tensor_view( diff --git a/include/ck_tile/ops/batched_contraction/utils/tensor_descriptor_utils.hpp b/include/ck_tile/ops/batched_contraction/utils/tensor_descriptor_utils.hpp index a3c818d570..07b8e25c0a 100644 --- a/include/ck_tile/ops/batched_contraction/utils/tensor_descriptor_utils.hpp +++ b/include/ck_tile/ops/batched_contraction/utils/tensor_descriptor_utils.hpp @@ -13,8 +13,8 @@ * dimensions for GEMM operations. These functions transform multi-dimensional tensors into * 2D matrix descriptors by removing batch dimensions and flattening the remaining dimensions. * - * These utilities are currently not used in the main batched contraction kernel but are preserved - * for future implementations that may require explicit tensor descriptor creation. + * These utilities are used by BatchedContractionKernel to create stride-aware descriptors + * that support arbitrary multi-dimensional non-contiguous tensor layouts. */ namespace ck_tile {