From b8b56d5cc678185db9460e9e7c5212c7237001b6 Mon Sep 17 00:00:00 2001
From: Mohsen Saffari
Date: Mon, 20 Oct 2025 13:15:39 +0000
Subject: [PATCH] Add multi-dimensional non-contiguous stride support to
 batched contraction, num_d = 0

---
 .../kernel/batched_contraction_kernel.hpp    | 126 ++++++++++++++----
 1 file changed, 97 insertions(+), 29 deletions(-)

diff --git a/include/ck_tile/ops/batched_contraction/kernel/batched_contraction_kernel.hpp b/include/ck_tile/ops/batched_contraction/kernel/batched_contraction_kernel.hpp
index e2a47e299a..0d2ec7a624 100644
--- a/include/ck_tile/ops/batched_contraction/kernel/batched_contraction_kernel.hpp
+++ b/include/ck_tile/ops/batched_contraction/kernel/batched_contraction_kernel.hpp
@@ -352,6 +352,63 @@ struct BatchedContractionKernel
             TilePartitioner::GridSize(kargs.M_total, kargs.N_total), kargs.G_total, kargs.k_batch);
     }
 
+    /// @brief Custom RunGemm implementation using tensor descriptors for multi-dimensional
+    /// non-contiguous support
+    /// @details Creates tensor views from the grid descriptors and runs the GEMM pipeline,
+    /// similar to UniversalGemmKernel::RunGemm but with descriptor-based tensor views.
+    CK_TILE_DEVICE static void
+    RunGemm(const ADataType* a_ptr,
+            const BDataType* b_ptr,
+            [[maybe_unused]] const std::array<const void*, NumDTensor>& ds_ptr,
+            EDataType* e_ptr,
+            void* smem_ptr,
+            const KernelArgs& kargs,
+            const index_t i_m,
+            const index_t i_n)
+    {
+        // Create tensor views from descriptors (handles multi-dimensional strides)
+        auto a_tensor_view =
+            make_tensor_view<address_space_enum::global>(a_ptr, kargs.a_grid_desc_m_k);
+        auto b_tensor_view =
+            make_tensor_view<address_space_enum::global>(b_ptr, kargs.b_grid_desc_n_k);
+        auto e_tensor_view =
+            make_tensor_view<address_space_enum::global>(e_ptr, kargs.e_grid_desc_m_n);
+
+        // Create tile windows for this block's work
+        auto a_block_window = make_tile_window(
+            a_tensor_view,
+            make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
+            {i_m, 0});
+
+        auto b_block_window = make_tile_window(
+            b_tensor_view,
+            make_tuple(number<TilePartitioner::NPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
+            {i_n, 0});
+
+        auto e_block_window = make_tile_window(
+            e_tensor_view,
+            make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
+            {i_m, i_n});
+
+        // Calculate the number of K loops
+        const index_t num_loop =
+            __builtin_amdgcn_readfirstlane(TilePartitioner::GetLoopNum(kargs.K_total));
+
+        // Run the GEMM pipeline (same as UniversalGemmKernel, but with descriptor-based windows)
+        using AElementWise = remove_cvref_t<typename GemmPipeline::AElementWise>;
+        using BElementWise = remove_cvref_t<typename GemmPipeline::BElementWise>;
+
+        const auto& c_block_tile = GemmPipeline{}(
+            a_block_window, AElementWise{}, b_block_window, BElementWise{}, num_loop, smem_ptr);
+
+        // Create an empty D windows tuple (for the NumDTensor == 0 case)
+        // TODO: Create actual D windows from descriptors when NumDTensor > 0
+        auto ds_block_windows = make_tuple();
+
+        // Run the epilogue pipeline (same as UniversalGemmKernel)
+        EpiloguePipeline{}(e_block_window, c_block_tile, ds_block_windows, smem_ptr);
+    }
+
     CK_TILE_HOST static constexpr KernelArgs
     MakeKernelArgs(const BatchedContractionHostArgs& host_args)
     {
@@ -503,8 +560,8 @@ struct BatchedContractionKernel
         const ck_tile::index_t i_n =
             __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
 
-        const auto i_batch_flat = __builtin_amdgcn_readfirstlane(blockIdx.y);
-        const auto i_splitk     = __builtin_amdgcn_readfirstlane(blockIdx.z);
+        const auto i_batch_flat              = __builtin_amdgcn_readfirstlane(blockIdx.y);
+        [[maybe_unused]] const auto i_splitk = __builtin_amdgcn_readfirstlane(blockIdx.z);
 
         // Calculate batch offsets for each tensor
         const auto batch_offset_A = i_batch_flat * kargs.batch_stride_A;
@@ -522,35 +579,46 @@ struct BatchedContractionKernel
             ds_batch_ptr[i] = static_cast(kargs.ds_ptr[i]) + batch_offset_D;
         });
 
-        typename UniversalGemmKernel::KernelArgs gemm_kargs{{a_ptr},
-                                                            {b_ptr},
-                                                            ds_batch_ptr,
-                                                            e_ptr,
-                                                            kargs.M_total,
-                                                            kargs.N_total,
-                                                            kargs.K_total,
-                                                            {kargs.stride_A},
-                                                            {kargs.stride_B},
-                                                            kargs.stride_Ds,
-                                                            kargs.stride_E,
-                                                            kargs.k_batch};
-
-        const typename UniversalGemmKernel::SplitKBatchOffset splitk_batch_offset(gemm_kargs,
-                                                                                  i_splitk);
-
-        const ADataType* a_ptr_final = a_ptr + splitk_batch_offset.as_k_split_offset[0];
-        const BDataType* b_ptr_final = b_ptr + splitk_batch_offset.bs_k_split_offset[0];
+        // Allocate shared memory
         __shared__ char smem_ptr[GetSmemSize()];
 
-        UniversalGemmKernel::RunGemm({a_ptr_final},
-                                     {b_ptr_final},
-                                     ds_batch_ptr,
-                                     e_ptr,
-                                     smem_ptr,
-                                     gemm_kargs,
-                                     splitk_batch_offset,
-                                     i_m,
-                                     i_n);
+        if constexpr(NumDTensor > 0)
+        {
+            // OLD PATH: Use UniversalGemmKernel for cases with D tensors
+            typename UniversalGemmKernel::KernelArgs gemm_kargs{{a_ptr},
+                                                                {b_ptr},
+                                                                ds_batch_ptr,
+                                                                e_ptr,
+                                                                kargs.M_total,
+                                                                kargs.N_total,
+                                                                kargs.K_total,
+                                                                {kargs.stride_A},
+                                                                {kargs.stride_B},
+                                                                kargs.stride_Ds,
+                                                                kargs.stride_E,
+                                                                kargs.k_batch};
+
+            const typename UniversalGemmKernel::SplitKBatchOffset splitk_batch_offset(gemm_kargs,
+                                                                                      i_splitk);
+
+            const ADataType* a_ptr_final = a_ptr + splitk_batch_offset.as_k_split_offset[0];
+            const BDataType* b_ptr_final = b_ptr + splitk_batch_offset.bs_k_split_offset[0];
+
+            UniversalGemmKernel::RunGemm({a_ptr_final},
+                                         {b_ptr_final},
+                                         ds_batch_ptr,
+                                         e_ptr,
+                                         smem_ptr,
+                                         gemm_kargs,
+                                         splitk_batch_offset,
+                                         i_m,
+                                         i_n);
+        }
+        else
+        {
+            // NEW PATH: Use descriptor-based RunGemm for num_d=0
+            RunGemm(a_ptr, b_ptr, ds_batch_ptr, e_ptr, smem_ptr, kargs, i_m, i_n);
+        }
     }
 };
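
Reviewer note (illustration, not part of the patch): the reason a single stride_A/stride_B/stride_E per tensor is not enough is that the flattened M/K/N indices may reach memory through one stride per original contraction dimension. The following small, self-contained C++ sketch (all names hypothetical, no ck_tile code) contrasts the flat single-stride addressing implied by the old arguments with the per-dimension addressing a grid descriptor can encode, using a 4-D A tensor whose innermost dimension is padded:

    #include <array>
    #include <cstdint>
    #include <iostream>

    // Offset of A[m][k] when M and K are each flattened from two original dimensions
    // (M = M0*M1, K = K0*K1), each dimension carrying its own stride. This is the kind
    // of mapping a grid descriptor encodes; a single stride per tensor cannot.
    std::int64_t offset_multi_dim(std::int64_t m, std::int64_t k,
                                  std::array<std::int64_t, 2> m_lengths, // {M0, M1}
                                  std::array<std::int64_t, 2> m_strides, // strides of M0, M1
                                  std::array<std::int64_t, 2> k_lengths, // {K0, K1}
                                  std::array<std::int64_t, 2> k_strides) // strides of K0, K1
    {
        const std::int64_t m0 = m / m_lengths[1], m1 = m % m_lengths[1];
        const std::int64_t k0 = k / k_lengths[1], k1 = k % k_lengths[1];
        return m0 * m_strides[0] + m1 * m_strides[1] + k0 * k_strides[0] + k1 * k_strides[1];
    }

    // Flat addressing implied by a single row stride: assumes K is contiguous within a row.
    std::int64_t offset_flat(std::int64_t m, std::int64_t k, std::int64_t stride)
    {
        return m * stride + k;
    }

    int main()
    {
        // Hypothetical A with lengths {M0, M1, K0, K1} = {2, 3, 4, 5}; the innermost K1 is
        // padded to a pitch of 6, so the flattened K dimension is non-contiguous in memory.
        const std::array<std::int64_t, 2> m_len{2, 3}, k_len{4, 5};
        const std::array<std::int64_t, 2> k_str{6, 1};             // K1 pitch 6 instead of 5
        const std::array<std::int64_t, 2> m_str{3 * 4 * 6, 4 * 6}; // M strides follow the padding

        std::cout << offset_multi_dim(4, 7, m_len, m_str, k_len, k_str) << '\n'; // 104
        std::cout << offset_flat(4, 7, /*row stride*/ 4 * 6) << '\n';            // 103
        return 0;
    }

The two results disagree as soon as any inner dimension is padded or permuted, which is exactly the case the descriptor-based RunGemm path is meant to handle.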
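The new RunGemm consumes kargs.a_grid_desc_m_k, kargs.b_grid_desc_n_k and kargs.e_grid_desc_m_n, but the hunks above do not show how those descriptors are built (presumably in MakeKernelArgs). As a rough orientation only, a merged non-contiguous M x K descriptor could be assembled along the following lines; the helper names follow the usual CK/ck_tile coordinate-transform style (make_naive_tensor_descriptor, make_merge_transform, transform_tensor_descriptor), but this sketch is an assumption about their use here, not code taken from the patch:

    #include <ck_tile/core.hpp>

    // Sketch (assumption, not from this patch): describe a 4-D A tensor with lengths
    // {M0, M1, K0, K1} and arbitrary element strides, then merge the dimension pairs so the
    // kernel can index it as a 2-D (M_total, K_total) view while memory stays non-contiguous.
    template <typename Index>
    auto make_a_grid_desc_m_k_sketch(Index M0, Index M1, Index K0, Index K1,
                                     Index stride_M0, Index stride_M1,
                                     Index stride_K0, Index stride_K1)
    {
        using namespace ck_tile;

        const auto a_desc_m0_m1_k0_k1 = make_naive_tensor_descriptor(
            make_tuple(M0, M1, K0, K1),
            make_tuple(stride_M0, stride_M1, stride_K0, stride_K1)); // e.g. padded or permuted

        return transform_tensor_descriptor(
            a_desc_m0_m1_k0_k1,
            make_tuple(make_merge_transform(make_tuple(M0, M1)),  // (M0, M1) -> M_total
                       make_merge_transform(make_tuple(K0, K1))), // (K0, K1) -> K_total
            make_tuple(sequence<0, 1>{}, sequence<2, 3>{}),       // input dims each merge reads
            make_tuple(sequence<0>{}, sequence<1>{}));            // output dims: 0 = M, 1 = K
    }

With a descriptor of this shape in the kernel arguments, the device side only needs the make_tensor_view / make_tile_window calls added in the first hunk; no per-tensor flat stride is involved.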