fix(grouped_gemm): correct pipeline selection when tail_num varies per group, which led to numerical errors (#2863)

* fix(grouped_gemm): resolve numerical errors on gfx950 by correctly calculating the tail num

* WIP: add temp config to stress-test the numerical error correction

* refactor: remove comments
This commit is contained in:
Aviral Goel
2025-09-16 21:43:19 -04:00
committed by GitHub
parent f97b2a3f5d
commit db79fad16f
4 changed files with 33 additions and 35 deletions

View File

@@ -292,34 +292,8 @@ struct GroupedGemmKernel
{
__shared__ char smem_ptr_1[GetSmemSize()];
if constexpr(UsePersistentKernel || GemmPipeline::Preshuffle)
{
RunGemmWithPipelineSelection2LDS(a_ptr,
b_ptr,
c_ptr,
smem_ptr_0,
smem_ptr_1,
kargs,
splitk_batch_offset,
i_m,
i_n);
return;
}
else
{
Base::RunGemm2LDS({a_ptr},
{b_ptr},
{/*ds_ptr*/},
c_ptr,
smem_ptr_0,
smem_ptr_1,
kargs,
splitk_batch_offset,
i_m,
i_n);
}
RunGemmWithPipelineSelection2LDS(
a_ptr, b_ptr, c_ptr, smem_ptr_0, smem_ptr_1, kargs, splitk_batch_offset, i_m, i_n);
}
else // SingleSmemBuffer
{