Use RunGemmDesc instead of custom RunGemm in BatchedContractionKernel

This commit is contained in:
Matti Eskelinen
2025-12-17 13:36:35 +00:00
parent 96820bf5a8
commit ccf4558f9a

View File

@@ -671,18 +671,28 @@ struct BatchedContractionKernel
i_splitk);
// Apply K-split offsets, then run the descriptor-based UniversalGemmKernel::RunGemmDesc
const ADataType* a_ptr_split = a_ptr + splitk_batch_offset.as_k_split_offset[0];
const BDataType* b_ptr_split = b_ptr + splitk_batch_offset.bs_k_split_offset[0];
const std::array<const ADataType*, number<1>{}> a_ptr_split = {
a_ptr + splitk_batch_offset.as_k_split_offset[0]};
const std::array<const BDataType*, number<1>{}> b_ptr_split = {
b_ptr + splitk_batch_offset.bs_k_split_offset[0]};
RunGemm(a_ptr_split,
b_ptr_split,
ds_batch_ptr,
e_ptr,
smem_ptr,
kargs,
splitk_batch_offset.splitted_k,
i_m,
i_n);
const std::array<typename KernelArgs::AGridDesc_M_K_, number<1>{}> a_grid_desc = {
kargs.a_grid_desc_m_k};
const std::array<typename KernelArgs::BGridDesc_N_K_, number<1>{}> b_grid_desc = {
kargs.b_grid_desc_n_k};
UniversalGemmKernel::RunGemmDesc(a_ptr_split,
b_ptr_split,
ds_batch_ptr,
e_ptr,
smem_ptr,
splitk_batch_offset,
i_m,
i_n,
a_grid_desc,
b_grid_desc,
kargs.ds_grid_desc_m_n,
kargs.e_grid_desc_m_n);
}
};