From ede105dd9123c90da7ac4e5e732eb38d445d9d45 Mon Sep 17 00:00:00 2001 From: Emily Martins <65371150+ecamartins@users.noreply.github.com> Date: Fri, 21 Nov 2025 20:29:47 -0700 Subject: [PATCH] Fix CK Tile DP + 2 Tile Stream-K Validation Errors (#3269) When there are multiple workgroups contributing to a tile, when using atomics, there may be round off error in cases where the accumulator type is not the same as the C type. To compute an error tolerance for test validation, the Stream-K Tile Partitioner has a function called estimate_num_wgs_per_tile to estimate the number of workgroups per tile. That said, this function only provides an estimate. In some cases for DP+2TSK, the function returns 1 rather than the more accurate value of 2. Thus, this change updates the estimate_num_wgs_per_tile function to explicitly return the value of 2 in cases for DP+2TSK to ensure that we have a better error tolerance to avoid test failures due to round-off error. [ROCm/composable_kernel commit: 02ab76c2cb47143b82743bcf9d86389c540a608b] --- .../streamk_gemm_tile_partitioner_impl.hpp | 22 ++++++++++++++----- test/ck_tile/gemm_streamk/CMakeLists.txt | 2 +- .../test_streamk_tile_partitioner.cpp | 2 +- 3 files changed, 18 insertions(+), 8 deletions(-) diff --git a/include/ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner_impl.hpp b/include/ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner_impl.hpp index 9116e0448c..626f440119 100644 --- a/include/ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner_impl.hpp +++ b/include/ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner_impl.hpp @@ -219,17 +219,27 @@ CK_TILE_HOST index_t StreamKTilePartitionerBase::estimate_num_wgs_per_tile() const noexcept { - // In the case of non-atomic reduction or data-parallel only, there will always be 1 workgroup - // writing final results to a given macro tile in C. 
+ // In the case of non-atomic reduction or data-parallel (DP) only, there will always be 1 + // workgroup writing final results to a given macro tile in C. int num_wgs_per_tile = 1; // Otherwise, for atomics, multiple workgroups may be writing to the same macro tile in C. if(sk_ctas_ > 0 && ReductionStrategy == ck_tile::StreamKReductionStrategy::Atomic) { - ck_tile::index_t iters_per_sk_cta_non_zero = ck_tile::max(iters_per_sk_cta_, 1); - // Estimate the number of workgroups per macro tile. - num_wgs_per_tile = (iters_per_tile_ / iters_per_sk_cta_non_zero) + - ((iters_per_tile_ % iters_per_sk_cta_non_zero) != 0); + // If we have DP and SK tiles, this is DP+2TSK which guarantees at most 2 workgroups per + // tile. We only need to check that dp_tiles is greater than zero since we know we have SK + // workgroups. + if(dp_tiles_ > 0) + { + num_wgs_per_tile = 2; + } + else + { + ck_tile::index_t iters_per_sk_cta_non_zero = ck_tile::max(iters_per_sk_cta_, 1); + // Estimate the number of workgroups per macro tile. 
+ num_wgs_per_tile = (iters_per_tile_ / iters_per_sk_cta_non_zero) + + ((iters_per_tile_ % iters_per_sk_cta_non_zero) != 0); + } } return std::max(num_wgs_per_tile, 1); diff --git a/test/ck_tile/gemm_streamk/CMakeLists.txt b/test/ck_tile/gemm_streamk/CMakeLists.txt index 90aa7771fe..7b1bc6f4f2 100644 --- a/test/ck_tile/gemm_streamk/CMakeLists.txt +++ b/test/ck_tile/gemm_streamk/CMakeLists.txt @@ -13,7 +13,7 @@ list(APPEND EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS set(EXAMPLE_GEMM_COMPILE_COMPUTE_ASYNC_OPTIONS ${EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS}) # Currently test_ck_tile_streamk_smoke is only built on gfx9 -if(GPU_TARGETS MATCHES "gfx9") +if(GPU_TARGETS MATCHES "gfx90a|gfx942|gfx950") include_directories(BEFORE ${CMAKE_CURRENT_SOURCE_DIR}) diff --git a/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner.cpp b/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner.cpp index 525817641a..dd74efc27a 100644 --- a/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner.cpp +++ b/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner.cpp @@ -84,7 +84,7 @@ TEST(StreamKTilePartitionerBaseEstimateNumWgsPerTile, EstimateNumWgsPerTileLower ck_tile::StreamKTilePartitionerBase tile_partitioner{ Config::M, Config::N, Config::K, Config::GRID}; - EXPECT_EQ(tile_partitioner.estimate_num_wgs_per_tile(), 1); + EXPECT_EQ(tile_partitioner.estimate_num_wgs_per_tile(), 2); } TEST(StreamKTilePartitionerBaseEstimateNumWgsPerTile, EstimateNumWgsPerTileEqualValue)