Clean up batched contraction: remove old UniversalGemmKernel path

This commit is contained in:
Mohsen Saffari
2025-10-27 15:14:47 +00:00
parent 6144f5c490
commit 48838830f9

View File

@@ -667,7 +667,7 @@ struct BatchedContractionKernel
// Allocate shared memory
__shared__ char smem_ptr[GetSmemSize()];
#if 0 // OLD PATH: UniversalGemmKernel (kept for reference, can be deleted after full validation)
// Use UniversalGemmKernel's SplitKBatchOffset for split-K calculation
typename UniversalGemmKernel::KernelArgs gemm_kargs{{a_ptr},
{b_ptr},
ds_batch_ptr,
@@ -684,24 +684,19 @@ struct BatchedContractionKernel
const typename UniversalGemmKernel::SplitKBatchOffset splitk_batch_offset(gemm_kargs,
i_splitk);
const ADataType* a_ptr_final = a_ptr + splitk_batch_offset.as_k_split_offset[0];
const BDataType* b_ptr_final = b_ptr + splitk_batch_offset.bs_k_split_offset[0];
// Apply K-split offsets and run descriptor-based RunGemm
const ADataType* a_ptr_split = a_ptr + splitk_batch_offset.as_k_split_offset[0];
const BDataType* b_ptr_split = b_ptr + splitk_batch_offset.bs_k_split_offset[0];
UniversalGemmKernel::RunGemm({a_ptr_final},
{b_ptr_final},
ds_batch_ptr,
e_ptr,
smem_ptr,
gemm_kargs,
splitk_batch_offset,
i_m,
i_n);
#else // NEW PATH: Descriptor-based RunGemm
// Custom descriptor-based RunGemm with full multi-dimensional stride support
// For now, use K_total (split-K to be properly implemented later)
const index_t k_size = kargs.K_total;
RunGemm(a_ptr, b_ptr, ds_batch_ptr, e_ptr, smem_ptr, kargs, k_size, i_m, i_n);
#endif
RunGemm(a_ptr_split,
b_ptr_split,
ds_batch_ptr,
e_ptr,
smem_ptr,
kargs,
splitk_batch_offset.splitted_k,
i_m,
i_n);
}
};