mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 04:07:56 +00:00
Use RunGemmDesc instead of custom RunGemm in BatchedContractionKernel
This commit is contained in:
@@ -671,18 +671,28 @@ struct BatchedContractionKernel
             i_splitk);
 
         // Apply K-split offsets and run descriptor-based RunGemm
-        const ADataType* a_ptr_split = a_ptr + splitk_batch_offset.as_k_split_offset[0];
-        const BDataType* b_ptr_split = b_ptr + splitk_batch_offset.bs_k_split_offset[0];
+        const std::array<const ADataType*, number<1>{}> a_ptr_split = {
+            a_ptr + splitk_batch_offset.as_k_split_offset[0]};
+        const std::array<const BDataType*, number<1>{}> b_ptr_split = {
+            b_ptr + splitk_batch_offset.bs_k_split_offset[0]};
 
-        RunGemm(a_ptr_split,
-                b_ptr_split,
-                ds_batch_ptr,
-                e_ptr,
-                smem_ptr,
-                kargs,
-                splitk_batch_offset.splitted_k,
-                i_m,
-                i_n);
+        const std::array<typename KernelArgs::AGridDesc_M_K_, number<1>{}> a_grid_desc = {
+            kargs.a_grid_desc_m_k};
+        const std::array<typename KernelArgs::BGridDesc_N_K_, number<1>{}> b_grid_desc = {
+            kargs.b_grid_desc_n_k};
+
+        UniversalGemmKernel::RunGemmDesc(a_ptr_split,
+                                         b_ptr_split,
+                                         ds_batch_ptr,
+                                         e_ptr,
+                                         smem_ptr,
+                                         splitk_batch_offset,
+                                         i_m,
+                                         i_n,
+                                         a_grid_desc,
+                                         b_grid_desc,
+                                         kargs.ds_grid_desc_m_n,
+                                         kargs.e_grid_desc_m_n);
     }
 };
 
Reference in New Issue
Block a user