From ccf4558f9a5c055ce415db44fe810cb68aef4919 Mon Sep 17 00:00:00 2001 From: Matti Eskelinen Date: Wed, 17 Dec 2025 13:36:35 +0000 Subject: [PATCH] Use RunGemmDesc instead of custom RunGemm in BatchedContractionKernel --- .../kernel/batched_contraction_kernel.hpp | 32 ++++++++++++------- 1 file changed, 21 insertions(+), 11 deletions(-) diff --git a/include/ck_tile/ops/batched_contraction/kernel/batched_contraction_kernel.hpp b/include/ck_tile/ops/batched_contraction/kernel/batched_contraction_kernel.hpp index 968d5d6ac2..f0f40828d3 100644 --- a/include/ck_tile/ops/batched_contraction/kernel/batched_contraction_kernel.hpp +++ b/include/ck_tile/ops/batched_contraction/kernel/batched_contraction_kernel.hpp @@ -671,18 +671,28 @@ struct BatchedContractionKernel i_splitk); // Apply K-split offsets and run descriptor-based RunGemm - const ADataType* a_ptr_split = a_ptr + splitk_batch_offset.as_k_split_offset[0]; - const BDataType* b_ptr_split = b_ptr + splitk_batch_offset.bs_k_split_offset[0]; + const std::array{}> a_ptr_split = { + a_ptr + splitk_batch_offset.as_k_split_offset[0]}; + const std::array{}> b_ptr_split = { + b_ptr + splitk_batch_offset.bs_k_split_offset[0]}; - RunGemm(a_ptr_split, - b_ptr_split, - ds_batch_ptr, - e_ptr, - smem_ptr, - kargs, - splitk_batch_offset.splitted_k, - i_m, - i_n); + const std::array{}> a_grid_desc = { + kargs.a_grid_desc_m_k}; + const std::array{}> b_grid_desc = { + kargs.b_grid_desc_n_k}; + + UniversalGemmKernel::RunGemmDesc(a_ptr_split, + b_ptr_split, + ds_batch_ptr, + e_ptr, + smem_ptr, + splitk_batch_offset, + i_m, + i_n, + a_grid_desc, + b_grid_desc, + kargs.ds_grid_desc_m_n, + kargs.e_grid_desc_m_n); } };