diff --git a/include/ck_tile/ops/batched_contraction/kernel/batched_contraction_kernel.hpp b/include/ck_tile/ops/batched_contraction/kernel/batched_contraction_kernel.hpp index 6d8f9f3f0e..e2a47e299a 100644 --- a/include/ck_tile/ops/batched_contraction/kernel/batched_contraction_kernel.hpp +++ b/include/ck_tile/ops/batched_contraction/kernel/batched_contraction_kernel.hpp @@ -5,6 +5,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/batched_contraction/pipeline/batched_contraction_problem.hpp" +#include "ck_tile/ops/batched_contraction/utils/tensor_descriptor_utils.hpp" #include "ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp" /** @@ -210,11 +211,28 @@ struct BatchedContractionKernelArgs ck_tile::index_t N_total; ///< Total N dimension: N0 * N1 * ... * N_{NumDimN-1} ck_tile::index_t K_total; ///< Total K dimension: K0 * K1 * ... * K_{NumDimK-1} - ck_tile::index_t stride_A; ///< Leading dimension stride for tensor A (row-major: K_total) - ck_tile::index_t stride_B; ///< Leading dimension stride for tensor B (row-major: K_total) + ck_tile::index_t + stride_A; ///< Leading dimension stride for tensor A (for backward compatibility) + ck_tile::index_t + stride_B; ///< Leading dimension stride for tensor B (for backward compatibility) std::array - stride_Ds; ///< Leading dimension strides for D tensors (row-major: N_total) - ck_tile::index_t stride_E; ///< Leading dimension stride for tensor E (row-major: N_total) + stride_Ds; ///< Leading dimension strides for D tensors (for backward compatibility) + ck_tile::index_t + stride_E; ///< Leading dimension stride for tensor E (for backward compatibility) + + // Tensor descriptors (encode full multi-dimensional stride information) + // These are created on host and passed to device (like old CK) + using AGridDesc_M_K_ = decltype(TensorDescriptorUtils:: + Make_A_GridDescriptor_M_K({}, {})); + using BGridDesc_N_K_ = decltype(TensorDescriptorUtils:: + Make_B_GridDescriptor_N_K({}, {})); + using EGridDesc_M_N_ = decltype(TensorDescriptorUtils:: + Make_E_GridDescriptor_M_N({}, {})); + + AGridDesc_M_K_ a_grid_desc_m_k; ///< Tensor descriptor for A[M, K] with actual strides + BGridDesc_N_K_ b_grid_desc_n_k; ///< Tensor descriptor for B[N, K] with actual strides + EGridDesc_M_N_ e_grid_desc_m_n; ///< Tensor descriptor for E[M, N] with actual strides + // TODO: Add D descriptors array when needed }; /// @brief GPU kernel for batched tensor contraction operations. @@ -274,6 +292,14 @@ struct BatchedContractionKernel static constexpr ck_tile::index_t kBlockSize = UniversalGemmKernel::kBlockSize; ///< GPU block size inherited from GEMM kernel + // Tensor descriptor utilities for creating stride-aware descriptors + using DescriptorUtils = TensorDescriptorUtils; + + // Tensor descriptor types (created on host, encode all stride information) + using AGridDesc_M_K = decltype(DescriptorUtils::Make_A_GridDescriptor_M_K({}, {})); + using BGridDesc_N_K = decltype(DescriptorUtils::Make_B_GridDescriptor_N_K({}, {})); + using EGridDesc_M_N = decltype(DescriptorUtils::Make_E_GridDescriptor_M_N({}, {})); + using KernelArgs = BatchedContractionKernelArgs; ///< Kernel ///< argument @@ -435,6 +461,15 @@ struct BatchedContractionKernel kargs.K_total *= kargs.K_dims[i]; } + // Create tensor descriptors on host using actual dims and strides + kargs.a_grid_desc_m_k = + DescriptorUtils::Make_A_GridDescriptor_M_K(host_args.A_dims, host_args.A_strides); + kargs.b_grid_desc_n_k = + DescriptorUtils::Make_B_GridDescriptor_N_K(host_args.B_dims, host_args.B_strides); + kargs.e_grid_desc_m_n = + DescriptorUtils::Make_E_GridDescriptor_M_N(host_args.E_dims, host_args.E_strides); + + // Keep simple strides for backward compatibility kargs.stride_A = kargs.K_total; kargs.stride_B = kargs.K_total; kargs.stride_E = kargs.N_total;