fix(grouped_gemm): pipeline selection when tail_num varies per group and leads to numerical error (#2863)

* fix(grouped_gemm): numerical errors on gfx950 by correctly calculating the tail num

* WIP: add temp config to stress test numerical error correction

* refactor: remove comments
This commit is contained in:
Aviral Goel
2025-09-16 21:43:19 -04:00
committed by GitHub
parent f97b2a3f5d
commit db79fad16f
4 changed files with 33 additions and 35 deletions

View File

@@ -356,6 +356,8 @@ int main(int argc, char* argv[])
#if CK_TILE_USE_WMMA
return !run_grouped_gemm_example<GemmConfigComputeV4_Wmma>(argc, argv);
#else
return !run_grouped_gemm_example<GemmConfigComputeV4>(argc, argv);
return !run_grouped_gemm_example<GemmConfigComputeV4>(argc, argv) ||
!run_grouped_gemm_example<GemmConfigComputeV3_2>(argc, argv) ||
!run_grouped_gemm_example<GemmConfigComputeV4_V2>(argc, argv);
#endif
}