Merge commit '02ab76c2cb47143b82743bcf9d86389c540a608b' into develop

This commit is contained in:
assistant-librarian[bot]
2025-11-22 04:13:58 +00:00
parent 343e40d0e9
commit d7685c394a
3 changed files with 18 additions and 8 deletions

View File

@@ -219,17 +219,27 @@ CK_TILE_HOST index_t
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::estimate_num_wgs_per_tile()
const noexcept
{
// In the case of non-atomic reduction or data-parallel only, there will always be 1 workgroup
// writing final results to a given macro tile in C.
// In the case of non-atomic reduction or data-parallel (DP) only, there will always be 1
// workgroup writing final results to a given macro tile in C.
int num_wgs_per_tile = 1;
// Otherwise, for atomics, multiple workgroups may be writing to the same macro tile in C.
if(sk_ctas_ > 0 && ReductionStrategy == ck_tile::StreamKReductionStrategy::Atomic)
{
ck_tile::index_t iters_per_sk_cta_non_zero = ck_tile::max(iters_per_sk_cta_, 1);
// Estimate the number of workgroups per macro tile.
num_wgs_per_tile = (iters_per_tile_ / iters_per_sk_cta_non_zero) +
((iters_per_tile_ % iters_per_sk_cta_non_zero) != 0);
// If we have DP and SK tiles, this is DP+2TSK which guarantees at most 2 workgroups per
// tile. We only need to check that dp_tiles is greater than zero since we know we have SK
// workgroups.
if(dp_tiles_ > 0)
{
num_wgs_per_tile = 2;
}
else
{
ck_tile::index_t iters_per_sk_cta_non_zero = ck_tile::max(iters_per_sk_cta_, 1);
// Estimate the number of workgroups per macro tile.
num_wgs_per_tile = (iters_per_tile_ / iters_per_sk_cta_non_zero) +
((iters_per_tile_ % iters_per_sk_cta_non_zero) != 0);
}
}
return std::max(num_wgs_per_tile, 1);

View File

@@ -13,7 +13,7 @@ list(APPEND EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS
set(EXAMPLE_GEMM_COMPILE_COMPUTE_ASYNC_OPTIONS ${EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS})
# Currently test_ck_tile_streamk_smoke is only built on gfx9
if(GPU_TARGETS MATCHES "gfx9")
if(GPU_TARGETS MATCHES "gfx90a|gfx942|gfx950")
include_directories(BEFORE ${CMAKE_CURRENT_SOURCE_DIR})

View File

@@ -84,7 +84,7 @@ TEST(StreamKTilePartitionerBaseEstimateNumWgsPerTile, EstimateNumWgsPerTileLower
ck_tile::StreamKTilePartitionerBase<Config::GemmShape> tile_partitioner{
Config::M, Config::N, Config::K, Config::GRID};
EXPECT_EQ(tile_partitioner.estimate_num_wgs_per_tile(), 1);
EXPECT_EQ(tile_partitioner.estimate_num_wgs_per_tile(), 2);
}
TEST(StreamKTilePartitionerBaseEstimateNumWgsPerTile, EstimateNumWgsPerTileEqualValue)