mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
Clean up batched contraction: remove old UniversalGemmKernel path
This commit is contained in:
@@ -667,7 +667,7 @@ struct BatchedContractionKernel
|
||||
// Allocate shared memory
|
||||
__shared__ char smem_ptr[GetSmemSize()];
|
||||
|
||||
#if 0 // OLD PATH: UniversalGemmKernel (kept for reference, can be deleted after full validation)
|
||||
// Use UniversalGemmKernel's SplitKBatchOffset for split-K calculation
|
||||
typename UniversalGemmKernel::KernelArgs gemm_kargs{{a_ptr},
|
||||
{b_ptr},
|
||||
ds_batch_ptr,
|
||||
@@ -684,24 +684,19 @@ struct BatchedContractionKernel
|
||||
const typename UniversalGemmKernel::SplitKBatchOffset splitk_batch_offset(gemm_kargs,
|
||||
i_splitk);
|
||||
|
||||
const ADataType* a_ptr_final = a_ptr + splitk_batch_offset.as_k_split_offset[0];
|
||||
const BDataType* b_ptr_final = b_ptr + splitk_batch_offset.bs_k_split_offset[0];
|
||||
// Apply K-split offsets and run descriptor-based RunGemm
|
||||
const ADataType* a_ptr_split = a_ptr + splitk_batch_offset.as_k_split_offset[0];
|
||||
const BDataType* b_ptr_split = b_ptr + splitk_batch_offset.bs_k_split_offset[0];
|
||||
|
||||
UniversalGemmKernel::RunGemm({a_ptr_final},
|
||||
{b_ptr_final},
|
||||
ds_batch_ptr,
|
||||
e_ptr,
|
||||
smem_ptr,
|
||||
gemm_kargs,
|
||||
splitk_batch_offset,
|
||||
i_m,
|
||||
i_n);
|
||||
#else // NEW PATH: Descriptor-based RunGemm
|
||||
// Custom descriptor-based RunGemm with full multi-dimensional stride support
|
||||
// For now, use K_total (split-K to be properly implemented later)
|
||||
const index_t k_size = kargs.K_total;
|
||||
RunGemm(a_ptr, b_ptr, ds_batch_ptr, e_ptr, smem_ptr, kargs, k_size, i_m, i_n);
|
||||
#endif
|
||||
RunGemm(a_ptr_split,
|
||||
b_ptr_split,
|
||||
ds_batch_ptr,
|
||||
e_ptr,
|
||||
smem_ptr,
|
||||
kargs,
|
||||
splitk_batch_offset.splitted_k,
|
||||
i_m,
|
||||
i_n);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user