diff --git a/include/ck_tile/ops/batched_contraction/kernel/batched_contraction_kernel.hpp b/include/ck_tile/ops/batched_contraction/kernel/batched_contraction_kernel.hpp
index a93c6f5064..69423a0f15 100644
--- a/include/ck_tile/ops/batched_contraction/kernel/batched_contraction_kernel.hpp
+++ b/include/ck_tile/ops/batched_contraction/kernel/batched_contraction_kernel.hpp
@@ -185,7 +185,10 @@
 template <ck_tile::index_t NumDimM,
           ck_tile::index_t NumDimN,
           ck_tile::index_t NumDimK,
-          ck_tile::index_t NumDTensor = 0>
+          ck_tile::index_t NumDTensor = 0,
+          ck_tile::index_t VectorSizeA = 1,
+          ck_tile::index_t VectorSizeB = 1,
+          ck_tile::index_t VectorSizeE = 1>
 struct BatchedContractionKernelArgs
 {
     const void* a_ptr; ///< Pointer to input tensor A
@@ -220,14 +223,31 @@ struct BatchedContractionKernelArgs
     ck_tile::index_t stride_E; ///< Leading dimension stride for tensor E (for backward compatibility)
 
-    // Tensor descriptors (encode full multi-dimensional stride information)
-    // These are created on host and passed to device (like old CK)
-    using AGridDesc_M_K_ = decltype(TensorDescriptorUtils<NumDimM, NumDimN, NumDimK>::
-                                        Make_A_GridDescriptor_M_K({}, {}));
-    using BGridDesc_N_K_ = decltype(TensorDescriptorUtils<NumDimM, NumDimN, NumDimK>::
-                                        Make_B_GridDescriptor_N_K({}, {}));
-    using EGridDesc_M_N_ = decltype(TensorDescriptorUtils<NumDimM, NumDimN, NumDimK>::
-                                        Make_E_GridDescriptor_M_N({}, {}));
+    // Tensor descriptors (encode full multi-dimensional stride information with vectorization)
+    using AGridDesc_M_K_ =
+        decltype(TensorDescriptorUtils<NumDimM, NumDimN, NumDimK,
+                                       VectorSizeA, VectorSizeB, VectorSizeE>::Make_A_GridDescriptor_M_K({}, {}));
+    using BGridDesc_N_K_ =
+        decltype(TensorDescriptorUtils<NumDimM, NumDimN, NumDimK,
+                                       VectorSizeA, VectorSizeB, VectorSizeE>::Make_B_GridDescriptor_N_K({}, {}));
+    using EGridDesc_M_N_ =
+        decltype(TensorDescriptorUtils<NumDimM, NumDimN, NumDimK,
+                                       VectorSizeA, VectorSizeB, VectorSizeE>::Make_E_GridDescriptor_M_N({}, {}));
 
     AGridDesc_M_K_ a_grid_desc_m_k; ///< Tensor descriptor for A[M, K] with actual strides
     BGridDesc_N_K_ b_grid_desc_n_k; ///< Tensor descriptor for B[N, K] with actual strides
@@ -293,18 +313,24 @@ struct BatchedContractionKernel
     static constexpr ck_tile::index_t kBlockSize =
         UniversalGemmKernel::kBlockSize; ///< GPU block size inherited from GEMM kernel
 
-    // Tensor descriptor utilities for creating stride-aware descriptors
-    using DescriptorUtils = TensorDescriptorUtils<NumDimM, NumDimN, NumDimK>;
+    // Tensor descriptor utilities with vectorization support
+    using DescriptorUtils =
+        TensorDescriptorUtils<NumDimM, NumDimN, NumDimK, VectorSizeA, VectorSizeB, VectorSizeE>;
 
-    // Tensor descriptor types (created on host, encode all stride information)
-    using AGridDesc_M_K = decltype(DescriptorUtils::Make_A_GridDescriptor_M_K({}, {}));
-    using BGridDesc_N_K = decltype(DescriptorUtils::Make_B_GridDescriptor_N_K({}, {}));
-    using EGridDesc_M_N = decltype(DescriptorUtils::Make_E_GridDescriptor_M_N({}, {}));
-
-    using KernelArgs =
-        BatchedContractionKernelArgs<NumDimM, NumDimN, NumDimK, NumDTensor>; ///< Kernel
-                                                                             ///< argument
-                                                                             ///< structure
+    // Kernel arguments with vectorization support
+    using KernelArgs = BatchedContractionKernelArgs<NumDimM, NumDimN, NumDimK, NumDTensor,
+                                                    VectorSizeA, VectorSizeB, VectorSizeE>;
 
     /// @brief Returns the kernel name for debugging and profiling purposes.
     /// @return Constant string identifier for this kernel
@@ -363,36 +389,76 @@ struct BatchedContractionKernel
                                         EDataType* e_ptr,
                                         void* smem_ptr,
                                         const KernelArgs& kargs,
+                                        const index_t k_size,
                                         const index_t i_m,
                                         const index_t i_n)
     {
-        // Create tensor views from descriptors (handles multi-dimensional strides)
+#if 1 // DESCRIPTOR PATH: Full multi-dimensional stride support
+        // Create tensor views from descriptors
         auto a_tensor_view = make_tensor_view(a_ptr, kargs.a_grid_desc_m_k);
         auto b_tensor_view = make_tensor_view(b_ptr, kargs.b_grid_desc_n_k);
         auto e_tensor_view = make_tensor_view(e_ptr, kargs.e_grid_desc_m_n);
+#else // NAIVE PATH: Simple views assuming contiguous (for performance testing)
+        auto a_tensor_view = make_naive_tensor_view(
+            a_ptr,
+            make_tuple(kargs.M_total, k_size),
+            make_tuple(kargs.K_total, 1),
+            number<VectorSizeA>{},
+            number<1>{});
 
-        // Create tile windows for this block's work
-        auto a_block_window = make_tile_window(
+        auto b_tensor_view = make_naive_tensor_view(
+            b_ptr,
+            make_tuple(kargs.N_total, k_size),
+            make_tuple(kargs.K_total, 1),
+            number<VectorSizeB>{},
+            number<1>{});
+
+        auto e_tensor_view = make_naive_tensor_view(
+            e_ptr,
+            make_tuple(kargs.M_total, kargs.N_total),
+            make_tuple(kargs.N_total, 1),
+            number<VectorSizeE>{},
+            number<1>{});
+#endif
+
+        // Pad views for boundary handling and optimization (like UniversalGemmKernel)
+        auto a_pad_view = pad_tensor_view(
             a_tensor_view,
             make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
+            sequence<GemmPipeline::kPadM, GemmPipeline::kPadK>{});
+
+        auto b_pad_view = pad_tensor_view(
+            b_tensor_view,
+            make_tuple(number<TilePartitioner::NPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
+            sequence<GemmPipeline::kPadN, GemmPipeline::kPadK>{});
+
+        auto e_pad_view = pad_tensor_view(
+            e_tensor_view,
+            make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
+            sequence<GemmPipeline::kPadM, GemmPipeline::kPadN>{});
+
+        // Create tile windows from PADDED views
+        auto a_block_window = make_tile_window(
+            a_pad_view,
+            make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
             {i_m, 0});
 
         auto b_block_window = make_tile_window(
-            b_tensor_view,
+            b_pad_view,
             make_tuple(number<TilePartitioner::NPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
             {i_n, 0});
 
         auto e_block_window = make_tile_window(
-            e_tensor_view,
+            e_pad_view,
             make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
             {i_m, i_n});
 
         // Calculate number of K loops
         const index_t num_loop =
-            __builtin_amdgcn_readfirstlane(TilePartitioner::GetLoopNum(kargs.K_total));
+            __builtin_amdgcn_readfirstlane(TilePartitioner::GetLoopNum(k_size));
 
         // Run GEMM Pipeline (same as UniversalGemmKernel, but with descriptor-based windows)
         using AElementWise = remove_cvref_t;
@@ -631,8 +697,10 @@ struct BatchedContractionKernel
                 i_m,
                 i_n);
 #else // NEW PATH: Descriptor-based RunGemm
-        // custom descriptor-based RunGemm with full multi-dimensional stride support
-        RunGemm(a_ptr, b_ptr, ds_batch_ptr, e_ptr, smem_ptr, kargs, i_m, i_n);
+        // Custom descriptor-based RunGemm with full multi-dimensional stride support
+        // For now, use K_total (split-K to be properly implemented later)
+        const index_t k_size = kargs.K_total;
+        RunGemm(a_ptr, b_ptr, ds_batch_ptr, e_ptr, smem_ptr, kargs, k_size, i_m, i_n);
 #endif
     }
 };
diff --git a/include/ck_tile/ops/batched_contraction/utils/tensor_descriptor_utils.hpp b/include/ck_tile/ops/batched_contraction/utils/tensor_descriptor_utils.hpp
index 6d3286ce09..a3c818d570 100644
--- a/include/ck_tile/ops/batched_contraction/utils/tensor_descriptor_utils.hpp
+++ b/include/ck_tile/ops/batched_contraction/utils/tensor_descriptor_utils.hpp
@@ -30,7 +30,10 @@ namespace ck_tile {
 template <ck_tile::index_t NumDimM,
           ck_tile::index_t NumDimN,
-          ck_tile::index_t NumDimK>
+          ck_tile::index_t NumDimK,
+          ck_tile::index_t VectorSizeA,
+          ck_tile::index_t VectorSizeB,
+          ck_tile::index_t VectorSizeE>
 struct TensorDescriptorUtils
 {
     /// @brief Creates a tensor descriptor for input tensor A with batch dimensions removed.
@@ -62,9 +65,9 @@ struct TensorDescriptorUtils
         const auto dims_M = get_container_subset(A_dims_M_K, A_dims_M_ids);
         const auto dims_K = get_container_subset(A_dims_M_K, A_dims_K_ids);
 
-        // naive tensor A[M0, M1, M2, ..., K0, K1, K2...] Discriptor
-        const auto A_grid_desc_Ms_Ks =
-            ck_tile::make_naive_tensor_descriptor(A_dims_M_K, A_strides_M_K);
+        // naive tensor A[M0, M1, M2, ..., K0, K1, K2...] Descriptor with vector size
+        const auto A_grid_desc_Ms_Ks = ck_tile::make_naive_tensor_descriptor(
+            A_dims_M_K, A_strides_M_K, number<VectorSizeA>{}, number<1>{});
 
         // transformed tensor to flatten M and K dimensions [M_total = M0 * M1 * M2 * ... , K_total
         // = K0 * K1 * K2 * ...]
@@ -106,9 +109,9 @@ struct TensorDescriptorUtils
         const auto dims_N = get_container_subset(B_dims_N_K, B_dims_N_ids);
         const auto dims_K = get_container_subset(B_dims_N_K, B_dims_K_ids);
 
-        // naive tensor B[N0, N1, N2, ..., K0, K1, K2...] Discriptor
-        const auto B_grid_desc_Ns_Ks =
-            ck_tile::make_naive_tensor_descriptor(B_dims_N_K, B_strides_N_K);
+        // naive tensor B[N0, N1, N2, ..., K0, K1, K2...] Descriptor with vector size
+        const auto B_grid_desc_Ns_Ks = ck_tile::make_naive_tensor_descriptor(
+            B_dims_N_K, B_strides_N_K, number<VectorSizeB>{}, number<1>{});
 
         // transformed tensor to flatten N and K dimensions [N_total = N0 * N1 * N2 * ... , K_total
         // = K0 * K1 * K2 * ...]
@@ -150,9 +153,9 @@ struct TensorDescriptorUtils
         const auto dims_M = get_container_subset(E_dims_M_N, E_dims_M_ids);
         const auto dims_N = get_container_subset(E_dims_M_N, E_dims_N_ids);
 
-        // naive tensor E[M0, M1, M2, ..., N0, N1, N2...] Discriptor
-        const auto E_grid_desc_Ms_Ns =
-            ck_tile::make_naive_tensor_descriptor(E_dims_M_N, E_strides_M_N);
+        // naive tensor E[M0, M1, M2, ..., N0, N1, N2...] Descriptor with vector size
+        const auto E_grid_desc_Ms_Ns = ck_tile::make_naive_tensor_descriptor(
+            E_dims_M_N, E_strides_M_N, number<VectorSizeE>{}, number<1>{});
 
         // transformed tensor to flatten M and N dimensions [M_total = M0 * M1 * M2 * ... ,
         // N_total = N0 * N1 * N2 * ...]
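
Note on the new VectorSizeA/VectorSizeB/VectorSizeE parameters: they are compile-time promises about the innermost (stride-1) dimension that the descriptors forward as the guaranteed last-dimension vector length, so the host side has to validate them against the runtime dims/strides before instantiating the kernel. Below is a minimal, standalone sketch of such a host-side check; the helper name pick_vector_size and the candidate set {8, 4, 2, 1} are illustrative and not part of this patch.

    #include <cstdint>
    #include <iostream>

    // Illustrative helper (not part of the patch): pick the largest vector size,
    // out of {8, 4, 2, 1}, that evenly divides the innermost extent. A vector
    // size > 1 is only safe when the innermost stride is 1, i.e. the fastest
    // dimension is contiguous in memory.
    std::int64_t pick_vector_size(std::int64_t inner_length, std::int64_t inner_stride)
    {
        if(inner_stride != 1)
            return 1;
        for(std::int64_t v : {8, 4, 2, 1})
            if(inner_length % v == 0)
                return v;
        return 1;
    }

    int main()
    {
        // A[M0, M1, K0, K1] with K1 innermost and contiguous: K1 = 32 -> vector size 8.
        std::cout << pick_vector_size(/*K1 length*/ 32, /*K1 stride*/ 1) << '\n'; // 8
        // Strided innermost dimension: vectorized access is not guaranteed -> 1.
        std::cout << pick_vector_size(32, 4) << '\n'; // 1
        return 0;
    }

Since the vector sizes are template arguments, the value chosen at runtime would in practice be dispatched onto a small set of pre-built instantiations rather than passed directly.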
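Note on the padded views and the K loop count: padding each view up to whole block tiles before make_tile_window means every window is full-sized and only the loop count depends on the runtime problem size; with split-K deferred, k_size is simply K_total. A standalone sketch of that tile arithmetic follows; the MPerBlock/NPerBlock/KPerBlock values and the ceil_div helper are made-up stand-ins for the real TilePartitioner constants.

    #include <cassert>
    #include <cstdint>
    #include <iostream>

    // Illustrative stand-ins for TilePartitioner::MPerBlock / NPerBlock / KPerBlock.
    constexpr std::int64_t MPerBlock = 128;
    constexpr std::int64_t NPerBlock = 128;
    constexpr std::int64_t KPerBlock = 32;

    constexpr std::int64_t ceil_div(std::int64_t a, std::int64_t b) { return (a + b - 1) / b; }

    int main()
    {
        const std::int64_t M_total = 960, N_total = 1000, K_total = 72;

        // Extents seen through the padded views: each dimension is rounded up to a
        // multiple of the block tile, so every tile window covers a full tile.
        const std::int64_t M_pad = ceil_div(M_total, MPerBlock) * MPerBlock; // 1024
        const std::int64_t N_pad = ceil_div(N_total, NPerBlock) * NPerBlock; // 1024

        // Number of main-loop iterations, mirroring TilePartitioner::GetLoopNum(k_size)
        // with k_size == K_total until split-K is implemented.
        const std::int64_t num_loop = ceil_div(K_total, KPerBlock); // 3

        std::cout << M_pad << ' ' << N_pad << ' ' << num_loop << '\n';
        assert(num_loop * KPerBlock >= K_total);
        return 0;
    }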