From 48838830f92436b86f3d76afab0f138e37ebcf96 Mon Sep 17 00:00:00 2001 From: Mohsen Saffari Date: Mon, 27 Oct 2025 15:14:47 +0000 Subject: [PATCH] Clean up batched contraction: remove old UniversalGemmKernel path --- .../kernel/batched_contraction_kernel.hpp | 31 ++++++++----------- 1 file changed, 13 insertions(+), 18 deletions(-) diff --git a/include/ck_tile/ops/batched_contraction/kernel/batched_contraction_kernel.hpp b/include/ck_tile/ops/batched_contraction/kernel/batched_contraction_kernel.hpp index 69423a0f15..21d638dcb4 100644 --- a/include/ck_tile/ops/batched_contraction/kernel/batched_contraction_kernel.hpp +++ b/include/ck_tile/ops/batched_contraction/kernel/batched_contraction_kernel.hpp @@ -667,7 +667,7 @@ struct BatchedContractionKernel // Allocate shared memory __shared__ char smem_ptr[GetSmemSize()]; -#if 0 // OLD PATH: UniversalGemmKernel (kept for reference, can be deleted after full validation) + // Use UniversalGemmKernel's SplitKBatchOffset for split-K calculation typename UniversalGemmKernel::KernelArgs gemm_kargs{{a_ptr}, {b_ptr}, ds_batch_ptr, @@ -684,24 +684,19 @@ struct BatchedContractionKernel const typename UniversalGemmKernel::SplitKBatchOffset splitk_batch_offset(gemm_kargs, i_splitk); - const ADataType* a_ptr_final = a_ptr + splitk_batch_offset.as_k_split_offset[0]; - const BDataType* b_ptr_final = b_ptr + splitk_batch_offset.bs_k_split_offset[0]; + // Apply K-split offsets and run descriptor-based RunGemm + const ADataType* a_ptr_split = a_ptr + splitk_batch_offset.as_k_split_offset[0]; + const BDataType* b_ptr_split = b_ptr + splitk_batch_offset.bs_k_split_offset[0]; - UniversalGemmKernel::RunGemm({a_ptr_final}, - {b_ptr_final}, - ds_batch_ptr, - e_ptr, - smem_ptr, - gemm_kargs, - splitk_batch_offset, - i_m, - i_n); -#else // NEW PATH: Descriptor-based RunGemm - // Custom descriptor-based RunGemm with full multi-dimensional stride support - // For now, use K_total (split-K to be properly implemented later) - const index_t k_size = kargs.K_total; - RunGemm(a_ptr, b_ptr, ds_batch_ptr, e_ptr, smem_ptr, kargs, k_size, i_m, i_n); -#endif + RunGemm(a_ptr_split, + b_ptr_split, + ds_batch_ptr, + e_ptr, + smem_ptr, + kargs, + splitk_batch_offset.splitted_k, + i_m, + i_n); } };