From 005f9fc5828422b54089e04e63c480e8ec8fa03b Mon Sep 17 00:00:00 2001 From: arai713 <67439843+arai713@users.noreply.github.com> Date: Fri, 20 Mar 2026 02:27:44 -0700 Subject: [PATCH] [CK_TILE] Rename Stream-K grid function (#4795) ## Motivation This PR renames the get_grid() function in the Stream-K TilePartitioner to avoid confusion with the similarly named grid_size() method. In the Stream-K TilePartitioner, get_grid() returns num_cu * occupancy, whereas grid_size() returns the grid size used to launch the kernel. This PR renames get_grid() to get_max_active_wgs() to better reflect what the function returns and to avoid confusing it with grid_size(). ## Technical Details Previously, the Stream-K TilePartitioner exposed get_grid(), which returned grid_. We rename get_grid() to get_max_active_wgs() and the grid_ member to max_active_wgs_, while keeping grid_size() unchanged. The grid constructor parameter of the Stream-K TilePartitioner classes and of StreamKKernelArgs is renamed accordingly (to max_active_wgs, or max_number_wgs in two of the header declarations), and the GRID constant in the test configs becomes MAX_ACTIVE_WGS. ## Test Plan Validated using the existing test suite. ## Test Result All tests passed. ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. 
--- .../streamk_gemm/streamk_gemm_kernel.hpp | 11 +-- .../streamk_gemm_tile_partitioner.hpp | 12 +-- .../streamk_gemm_tile_partitioner_impl.hpp | 36 ++++---- .../test_streamk_tile_partitioner.cpp | 70 ++++++++-------- .../test_streamk_tile_partitioner_common.hpp | 84 +++++++++---------- 5 files changed, 108 insertions(+), 105 deletions(-) diff --git a/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_kernel.hpp index ac83babeb6..8ee6d3689c 100644 --- a/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_kernel.hpp @@ -119,7 +119,7 @@ struct StreamKKernel struct StreamKKernelArgs : ck_tile::UniversalGemmKernelArgs<> { - StreamKKernelArgs(const StreamKHostArgs& host_args, index_t grid) + StreamKKernelArgs(const StreamKHostArgs& host_args, index_t max_active_wgs) : UniversalGemmKernelArgs{host_args.as_ptr, host_args.bs_ptr, host_args.ds_ptr, @@ -135,7 +135,8 @@ struct StreamKKernel // The workspace pointer is set to nullptr because we must first // instantiate the TilePartitioner to get the necessary size workspace_ptr{nullptr}, - tile_partitioner{TilePartitioner{host_args.M, host_args.N, host_args.K, grid}} + tile_partitioner{ + TilePartitioner{host_args.M, host_args.N, host_args.K, max_active_wgs}} { } @@ -206,9 +207,9 @@ struct StreamKKernel int num_cu = NumCU(), int occupancy = Occupancy()) { - const index_t grid = num_cu * occupancy; + const index_t max_active_wgs = num_cu * occupancy; - return StreamKKernelArgs{host_args, grid}; + return StreamKKernelArgs{host_args, max_active_wgs}; } template @@ -790,7 +791,7 @@ struct StreamKKernel // Data-parallel section for(index_t tile_idx = block_idx; tile_idx < kargs.tile_partitioner.get_dp_tiles(); - tile_idx += kargs.tile_partitioner.get_grid()) + tile_idx += kargs.tile_partitioner.get_max_active_wgs()) { BaseGemm(kargs, tile_idx, dp_num_loop, 0, 0, kargs.K, 
smem_ptr_0); block_sync_lds(); diff --git a/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner.hpp b/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner.hpp index f028ba0c62..15311f4eec 100644 --- a/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner.hpp +++ b/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner.hpp @@ -31,7 +31,7 @@ struct StreamKTilePartitionerBase ? memory_operation_enum::atomic_add : memory_operation_enum::set; - StreamKTilePartitionerBase(index_t m, index_t n, index_t k, index_t grid); + StreamKTilePartitionerBase(index_t m, index_t n, index_t k, index_t max_number_wgs); /** * @brief Calculates the total space needed for the partials buffer. @@ -156,7 +156,7 @@ struct StreamKTilePartitionerBase * @brief Returns the maximum number of active workgroups; this is assumed to be number of CUs * * occupancy. */ - CK_TILE_HOST_DEVICE index_t get_grid() const noexcept; + CK_TILE_HOST_DEVICE index_t get_max_active_wgs() const noexcept; /** * @brief Returns the number of tiles in the C tensor that will use the data-parallel (DP) @@ -215,7 +215,7 @@ struct StreamKTilePartitionerBase protected: index_t num_tiles_; - index_t grid_; + index_t max_active_wgs_; index_t dp_tiles_; private: @@ -270,7 +270,7 @@ struct StreamKTilePartitioner StreamKTilePartitioner(ck_tile::index_t m, ck_tile::index_t n, ck_tile::index_t k, - ck_tile::index_t grid); + ck_tile::index_t max_active_wgs); public: static constexpr bool PERSISTENT = true; @@ -290,7 +290,7 @@ struct StreamKTilePartitioner /** * @brief Returns the total number of DP tiles left over when `dp_tiles_` is not evenly - * divisible by `grid_`. + * divisible by `max_active_wgs_`. 
*/ CK_TILE_HOST_DEVICE index_t get_extra_dp_tiles() const noexcept; @@ -317,7 +317,7 @@ struct StreamKTilePartitioner StreamKTilePartitioner(ck_tile::index_t m, ck_tile::index_t n, ck_tile::index_t k, - ck_tile::index_t grid); + ck_tile::index_t max_number_wgs); public: static constexpr bool PERSISTENT = false; diff --git a/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner_impl.hpp b/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner_impl.hpp index 52cfea5872..229eefc1db 100644 --- a/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner_impl.hpp +++ b/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner_impl.hpp @@ -7,24 +7,25 @@ namespace ck_tile { template StreamKTilePartitionerBase::StreamKTilePartitionerBase( - index_t m, index_t n, index_t k, index_t grid) - : grid_{grid}, n_{n} + index_t m, index_t n, index_t k, index_t max_active_wgs) + : max_active_wgs_{max_active_wgs}, n_{n} { iters_per_tile_ = integer_divide_ceil(k, KPerBlock); num_tiles_ = integer_divide_ceil(m, MPerBlock) * integer_divide_ceil(n_, NPerBlock); - bool big_enough = num_tiles_ > grid_; - index_t remainder_tiles = num_tiles_ % grid_; + bool big_enough = num_tiles_ > max_active_wgs_; + index_t remainder_tiles = num_tiles_ % max_active_wgs_; if(remainder_tiles) { - sk_tiles_ = big_enough ? full_tiles_ * grid_ + (num_tiles_ % grid_) : num_tiles_; - sk_tiles_ = min(num_tiles_, sk_tiles_); - sk_ctas_ = grid_; + sk_tiles_ = big_enough ? full_tiles_ * max_active_wgs_ + (num_tiles_ % max_active_wgs_) + : num_tiles_; + sk_tiles_ = min(num_tiles_, sk_tiles_); + sk_ctas_ = max_active_wgs_; total_sk_iters_ = sk_tiles_ * iters_per_tile_; // If there still isn't enough work to saturate all CUs, then just revert to DP only. 
- if(total_sk_iters_ < grid_) + if(total_sk_iters_ < max_active_wgs_) { sk_tiles_ = 0; sk_ctas_ = 0; @@ -175,9 +176,10 @@ StreamKTilePartitionerBase::get_num_t template CK_TILE_HOST_DEVICE index_t -StreamKTilePartitionerBase::get_grid() const noexcept +StreamKTilePartitionerBase::get_max_active_wgs() + const noexcept { - return grid_; + return max_active_wgs_; } template @@ -287,11 +289,11 @@ struct StreamKTilePartitioner; // child class for Persistent Tile Partitioner template StreamKTilePartitioner::StreamKTilePartitioner( - ck_tile::index_t m, ck_tile::index_t n, ck_tile::index_t k, ck_tile::index_t grid) - : StreamKTilePartitionerBase(m, n, k, grid) + ck_tile::index_t m, ck_tile::index_t n, ck_tile::index_t k, ck_tile::index_t max_active_wgs) + : StreamKTilePartitionerBase(m, n, k, max_active_wgs) { // inherit from base constructor - dp_tiles_per_cta_ = this->dp_tiles_ / this->grid_; - extra_dp_tiles_ = this->dp_tiles_ % this->grid_; + dp_tiles_per_cta_ = this->dp_tiles_ / this->max_active_wgs_; + extra_dp_tiles_ = this->dp_tiles_ % this->max_active_wgs_; } template @@ -301,7 +303,7 @@ StreamKTilePartitioner::grid_si { if(extra_dp_tiles_ == 0) { - return dim3(this->grid_, 1, 1); + return dim3(this->max_active_wgs_, 1, 1); } else { @@ -328,8 +330,8 @@ StreamKTilePartitioner::get_ext // child class for Non-Persistent Tile Partitioner template StreamKTilePartitioner::StreamKTilePartitioner( - ck_tile::index_t m, ck_tile::index_t n, ck_tile::index_t k, ck_tile::index_t grid) - : StreamKTilePartitionerBase(m, n, k, grid) + ck_tile::index_t m, ck_tile::index_t n, ck_tile::index_t k, ck_tile::index_t max_active_wgs) + : StreamKTilePartitionerBase(m, n, k, max_active_wgs) { // inherit from base constructor dp_ctas_ = this->dp_tiles_; dp_start_block_idx_ = 0; diff --git a/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner.cpp b/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner.cpp index 75c3e0b4fb..c71656cf6b 100644 --- 
a/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner.cpp +++ b/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner.cpp @@ -8,10 +8,10 @@ TEST(StreamKTilePartitionerBaseConstructor, SKOnly) using Config = StreamKTilePartitionerBaseConfigSKOnly; ck_tile::StreamKTilePartitionerBase tile_partitioner{ - Config::M, Config::N, Config::K, Config::GRID}; + Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS}; StreamKTilePartitionerBaseExpected expected_values{ - 2, 0, 3, 4, 1, 2, 1, 0, 2, Config::GRID, Config::N}; + 2, 0, 3, 4, 1, 2, 1, 0, 2, Config::MAX_ACTIVE_WGS, Config::N}; validate_streamk_base_constructor(expected_values, tile_partitioner); } @@ -20,10 +20,10 @@ TEST(StreamKTilePartitionerBaseConstructor, DPOnly) using Config = StreamKTilePartitionerBaseConfigDPOnly; ck_tile::StreamKTilePartitionerBase tile_partitioner{ - Config::M, Config::N, Config::K, Config::GRID}; + Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS}; StreamKTilePartitionerBaseExpected expected_values{ - 0, 6, 0, 0, 0, 2, 0, 12, 6, Config::GRID, Config::N}; + 0, 6, 0, 0, 0, 2, 0, 12, 6, Config::MAX_ACTIVE_WGS, Config::N}; validate_streamk_base_constructor(expected_values, tile_partitioner); } @@ -32,10 +32,10 @@ TEST(StreamKTilePartitionerBaseConstructor, DP2TileSK) using Config = StreamKTilePartitionerBaseConfigDP2TileSK; ck_tile::StreamKTilePartitionerBase tile_partitioner{ - Config::M, Config::N, Config::K, Config::GRID}; + Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS}; StreamKTilePartitionerBaseExpected expected_values{ - 4, 3, 3, 8, 2, 2, 2, 6, 7, Config::GRID, Config::N}; + 4, 3, 3, 8, 2, 2, 2, 6, 7, Config::MAX_ACTIVE_WGS, Config::N}; validate_streamk_base_constructor(expected_values, tile_partitioner); } @@ -44,10 +44,10 @@ TEST(StreamKTilePartitionerBaseConstructor, EdgeCase) using Config = StreamKTilePartitionerBaseConfigEdgeCase; ck_tile::StreamKTilePartitionerBase tile_partitioner{ - Config::M, Config::N, Config::K, Config::GRID}; + Config::M, Config::N, 
Config::K, Config::MAX_ACTIVE_WGS}; StreamKTilePartitionerBaseExpected expected_values{ - 0, 1, 0, 0, 0, 2, 0, 2, 1, Config::GRID, Config::N}; + 0, 1, 0, 0, 0, 2, 0, 2, 1, Config::MAX_ACTIVE_WGS, Config::N}; validate_streamk_base_constructor(expected_values, tile_partitioner); } @@ -57,7 +57,7 @@ TEST(StreamKTilePartitionerBaseGetFlagsBufferSize, FlagsLessThan128Bytes) ck_tile::StreamKTilePartitionerBase - tile_partitioner{Config::M, Config::N, Config::K, Config::GRID}; + tile_partitioner{Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS}; EXPECT_EQ(tile_partitioner.get_flags_buffer_size(), 128); } @@ -68,7 +68,7 @@ TEST(StreamKTilePartitionerBaseGetFlagsBufferSize, FlagsEqual128Bytes) ck_tile::StreamKTilePartitionerBase - tile_partitioner{Config::M, Config::N, Config::K, Config::GRID}; + tile_partitioner{Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS}; EXPECT_EQ(tile_partitioner.get_flags_buffer_size(), 128); } @@ -79,7 +79,7 @@ TEST(StreamKTilePartitionerBaseGetFlagsBufferSize, FlagsGreaterThan128Bytes) ck_tile::StreamKTilePartitionerBase - tile_partitioner{Config::M, Config::N, Config::K, Config::GRID}; + tile_partitioner{Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS}; EXPECT_EQ(tile_partitioner.get_flags_buffer_size(), 256); } @@ -89,7 +89,7 @@ TEST(StreamKTilePartitionerBaseGetWorkSpaceSize, AtomicStrategy) using Config = StreamKTilePartitionerBaseConfigDP2TileSK; ck_tile::StreamKTilePartitionerBase tile_partitioner{ - Config::M, Config::N, Config::K, Config::GRID}; + Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS}; EXPECT_EQ(tile_partitioner.get_workspace_size(sizeof(float)), 0); } @@ -100,12 +100,12 @@ TEST(StreamKTilePartitionerBaseGetWorkSpaceSize, ReductionStrategy) ck_tile::StreamKTilePartitionerBase - tile_partitioner{Config::M, Config::N, Config::K, Config::GRID}; + tile_partitioner{Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS}; ck_tile::index_t expected_partials_size = - sizeof(float) * Config::M_TILE * 
Config::N_TILE * Config::GRID; - // Since GRID is 3, the final padded flags array must be 128B to ensure the total byte size of - // the flags array is 128B-aligned. + sizeof(float) * Config::M_TILE * Config::N_TILE * Config::MAX_ACTIVE_WGS; + // Since MAX_ACTIVE_WGS is 3, the final padded flags array must be 128B to ensure the total byte + // size of the flags array is 128B-aligned. ck_tile::index_t expected_flags_size = 128; EXPECT_EQ(tile_partitioner.get_workspace_size(sizeof(float)), @@ -117,7 +117,7 @@ TEST(StreamKTilePartitionerBaseEstimateNumWgsPerTile, EstimateNumWgsPerTileLower using Config = StreamKTilePartitionerBaseConfigDP2TileSK; ck_tile::StreamKTilePartitionerBase tile_partitioner{ - Config::M, Config::N, Config::K, Config::GRID}; + Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS}; EXPECT_EQ(tile_partitioner.estimate_num_wgs_per_tile(), 2); } @@ -127,7 +127,7 @@ TEST(StreamKTilePartitionerBaseEstimateNumWgsPerTile, EstimateNumWgsPerTileEqual using Config = StreamKTilePartitionerBaseConfigSKOnlyWith2WgsPerSKTile; ck_tile::StreamKTilePartitionerBase tile_partitioner{ - Config::M, Config::N, Config::K, Config::GRID}; + Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS}; EXPECT_EQ(tile_partitioner.estimate_num_wgs_per_tile(), 2); } @@ -232,7 +232,7 @@ TEST(StreamKTilePartitionerBaseGetTileBoundaries, GetTileBoundaries) // Test parameters ck_tile::StreamKTilePartitionerBase tile_partitioner{ - Config::M, Config::N, Config::K, Config::GRID}; + Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS}; ck_tile::DeviceMem tile_iter_start_dev(sizeof(ck_tile::index_t)); ck_tile::DeviceMem tile_iter_end_dev(sizeof(ck_tile::index_t)); ck_tile::index_t tile_idx = 1; @@ -267,7 +267,7 @@ TEST(StreamKTilePartitionerBaseGetTileIndex, GetTileIndex) // Test parameters ck_tile::StreamKTilePartitionerBase tile_partitioner{ - Config::M, Config::N, Config::K, Config::GRID}; + Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS}; ck_tile::DeviceMem 
tile_idx_dev(sizeof(ck_tile::index_t)); ck_tile::index_t iter_start = 8; @@ -299,7 +299,7 @@ TEST(StreamKTilePartitionerBaseGetIterBoundaries, ZeroExtraItersBeforeMe) // Test parameters ck_tile::StreamKTilePartitionerBase tile_partitioner{ - Config::M, Config::N, Config::K, Config::GRID}; + Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS}; ck_tile::DeviceMem iter_start_dev(sizeof(ck_tile::index_t)); ck_tile::DeviceMem iter_end_dev(sizeof(ck_tile::index_t)); ck_tile::index_t cta_idx = 0; @@ -333,7 +333,7 @@ TEST(StreamKTilePartitionerBaseGetIterBoundaries, NonZeroExtraItersBeforeMe) // Test parameters ck_tile::StreamKTilePartitionerBase tile_partitioner{ - Config::M, Config::N, Config::K, Config::GRID}; + Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS}; ck_tile::DeviceMem iter_start_dev(sizeof(ck_tile::index_t)); ck_tile::DeviceMem iter_end_dev(sizeof(ck_tile::index_t)); ck_tile::index_t cta_idx = 1; @@ -367,7 +367,7 @@ TEST(StreamKTilePartitionerBaseGetIterBoundaries, MinIsExtraIters) // Test parameters ck_tile::StreamKTilePartitionerBase tile_partitioner{ - Config::M, Config::N, Config::K, Config::GRID}; + Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS}; ck_tile::DeviceMem iter_start_dev(sizeof(ck_tile::index_t)); ck_tile::DeviceMem iter_end_dev(sizeof(ck_tile::index_t)); ck_tile::index_t cta_idx = 2; @@ -493,7 +493,7 @@ TEST(StreamKTilePartitioner_PersistentConstructor, SKOnly) ck_tile:: StreamKTilePartitioner - tile_partitioner{Config::M, Config::N, Config::K, Config::GRID}; + tile_partitioner{Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS}; StreamKTilePartitionerV2PersistentExpected expected_values{0, 0, 3}; validate_streamk_persistent(expected_values, tile_partitioner); @@ -506,7 +506,7 @@ TEST(StreamKTilePartitioner_PersistentConstructor, DPOnly) ck_tile::StreamKTilePartitioner - tile_partitioner{Config::M, Config::N, Config::K, Config::GRID}; + tile_partitioner{Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS}; 
StreamKTilePartitionerV2PersistentExpected expected_values{2, 0, 3}; validate_streamk_persistent(expected_values, tile_partitioner); @@ -519,7 +519,7 @@ TEST(StreamKTilePartitioner_PersistentConstructor, DP2TileSK) ck_tile::StreamKTilePartitioner - tile_partitioner{Config::M, Config::N, Config::K, Config::GRID}; + tile_partitioner{Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS}; StreamKTilePartitionerV2PersistentExpected expected_values{1, 0, 3}; validate_streamk_persistent(expected_values, tile_partitioner); @@ -532,7 +532,7 @@ TEST(StreamKTilePartitioner_PersistentConstructor, EdgeCase) ck_tile::StreamKTilePartitioner - tile_partitioner{Config::M, Config::N, Config::K, Config::GRID}; + tile_partitioner{Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS}; StreamKTilePartitionerV2PersistentExpected expected_values{0, 1, 4}; validate_streamk_persistent(expected_values, tile_partitioner); @@ -545,10 +545,10 @@ TEST(StreamKTilePartitioner_GridSize_Persistent, SKOnly) ck_tile::StreamKTilePartitioner - tile_partitioner{Config::M, Config::N, Config::K, Config::GRID}; + tile_partitioner{Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS}; const auto g = tile_partitioner.grid_size(); - EXPECT_EQ(g.x, Config::GRID); + EXPECT_EQ(g.x, Config::MAX_ACTIVE_WGS); } TEST(StreamKTilePartitioner_GridSize_Persistent, EdgeCase) @@ -558,7 +558,7 @@ TEST(StreamKTilePartitioner_GridSize_Persistent, EdgeCase) ck_tile::StreamKTilePartitioner - tile_partitioner{Config::M, Config::N, Config::K, Config::GRID}; + tile_partitioner{Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS}; const auto g = tile_partitioner.grid_size(); EXPECT_EQ(g.x, 1); @@ -571,7 +571,7 @@ TEST(StreamKTilePartitioner_NonPersistentConstructor, SKOnly) ck_tile:: StreamKTilePartitioner - tile_partitioner{Config::M, Config::N, Config::K, Config::GRID}; + tile_partitioner{Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS}; StreamKTilePartitionerV2NonPersistentExpected expected_values{0, 0, 0, 
3}; validate_streamk_nonpersistent(expected_values, tile_partitioner); @@ -584,7 +584,7 @@ TEST(StreamKTilePartitioner_NonPersistentConstructor, DPOnly) ck_tile::StreamKTilePartitioner - tile_partitioner{Config::M, Config::N, Config::K, Config::GRID}; + tile_partitioner{Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS}; StreamKTilePartitionerV2NonPersistentExpected expected_values{6, 0, 6, 3}; validate_streamk_nonpersistent(expected_values, tile_partitioner); @@ -597,7 +597,7 @@ TEST(StreamKTilePartitioner_NonPersistentConstructor, DP2TileSK) ck_tile::StreamKTilePartitioner - tile_partitioner{Config::M, Config::N, Config::K, Config::GRID}; + tile_partitioner{Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS}; StreamKTilePartitionerV2NonPersistentExpected expected_values{3, 0, 3, 3}; validate_streamk_nonpersistent(expected_values, tile_partitioner); @@ -610,7 +610,7 @@ TEST(StreamKTilePartitioner_NonPersistentConstructor, EdgeCase) ck_tile::StreamKTilePartitioner - tile_partitioner{Config::M, Config::N, Config::K, Config::GRID}; + tile_partitioner{Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS}; StreamKTilePartitionerV2NonPersistentExpected expected_values{1, 0, 1, 4}; validate_streamk_nonpersistent(expected_values, tile_partitioner); @@ -623,7 +623,7 @@ TEST(StreamKTilePartitioner_GridSize_NonPersistent, DP2TileSK) ck_tile::StreamKTilePartitioner - tile_partitioner{Config::M, Config::N, Config::K, Config::GRID}; + tile_partitioner{Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS}; const auto g = tile_partitioner.grid_size(); EXPECT_EQ(g.x, 6); diff --git a/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner_common.hpp b/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner_common.hpp index 31217ba101..6aecd49a3c 100644 --- a/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner_common.hpp +++ b/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner_common.hpp @@ -165,7 +165,7 @@ struct StreamKTilePartitionerBaseExpected 
ck_tile::index_t extra_iters_; ck_tile::index_t total_dp_iters_; ck_tile::index_t num_tiles_; - ck_tile::index_t grid_; + ck_tile::index_t max_active_wgs_; ck_tile::index_t n_; }; @@ -183,7 +183,7 @@ void validate_streamk_base_constructor( EXPECT_EQ(tile_partitioner.get_iters_per_tile(), expected_values.iters_per_tile_); EXPECT_EQ(tile_partitioner.get_total_dp_iters(), expected_values.total_dp_iters_); EXPECT_EQ(tile_partitioner.get_num_tiles(), expected_values.num_tiles_); - EXPECT_EQ(tile_partitioner.get_grid(), expected_values.grid_); + EXPECT_EQ(tile_partitioner.get_max_active_wgs(), expected_values.max_active_wgs_); EXPECT_EQ(tile_partitioner.get_n(), expected_values.n_); } @@ -201,9 +201,9 @@ struct StreamKTilePartitionerBaseConfigDP2TileSK : public StreamKTilePartitioner static constexpr ck_tile::index_t M = 28; static constexpr ck_tile::index_t N = 4; static constexpr ck_tile::index_t K = 16; - // The minimum number of bytes needed for the flags array is GRID * 4B = 3 * 4B = 12B. To ensure - // the total byte size of the array is 128B-aligned, the flags array must be 128B. - static constexpr ck_tile::index_t GRID = 3; + // The minimum number of bytes needed for the flags array is MAX_ACTIVE_WGS * 4B = 3 * 4B = 12B. + // To ensure the total byte size of the array is 128B-aligned, the flags array must be 128B. + static constexpr ck_tile::index_t MAX_ACTIVE_WGS = 3; static constexpr ck_tile::index_t M_TILE = 4; static constexpr ck_tile::index_t N_TILE = 4; @@ -220,9 +220,9 @@ struct StreamKTilePartitionerBaseConfigFlagsSizeEqual128Bytes static constexpr ck_tile::index_t M = 28; static constexpr ck_tile::index_t N = 4; static constexpr ck_tile::index_t K = 32; - // The minimum number of bytes needed for the flags array is GRID * 4B = 32 * 4B = 128B. So, the - // number of bytes for the flags array should be 128B. 
- static constexpr ck_tile::index_t GRID = 32; + // The minimum number of bytes needed for the flags array is MAX_ACTIVE_WGS * 4B = 32 * 4B = + // 128B. So, the number of bytes for the flags array should be 128B. + static constexpr ck_tile::index_t MAX_ACTIVE_WGS = 32; static constexpr ck_tile::index_t M_TILE = 4; static constexpr ck_tile::index_t N_TILE = 4; @@ -239,10 +239,10 @@ struct StreamKTilePartitionerBaseConfigFlagsSizeGreaterThan128Bytes static constexpr ck_tile::index_t M = 28; static constexpr ck_tile::index_t N = 4; static constexpr ck_tile::index_t K = 33; - // The minimum number of bytes needed for the flags array is GRID * 4B = 33 * 4B = 132B. So, the - // number of bytes for the flags array should be 2 * 128B = 256B to ensure the total byte size - // of the array is 128B-aligned. - static constexpr ck_tile::index_t GRID = 33; + // The minimum number of bytes needed for the flags array is MAX_ACTIVE_WGS * 4B = 33 * 4B = + // 132B. So, the number of bytes for the flags array should be 2 * 128B = 256B to ensure the + // total byte size of the array is 128B-aligned. 
+ static constexpr ck_tile::index_t MAX_ACTIVE_WGS = 33; static constexpr ck_tile::index_t M_TILE = 4; static constexpr ck_tile::index_t N_TILE = 4; @@ -256,10 +256,10 @@ struct StreamKTilePartitionerBaseConfigFlagsSizeGreaterThan128Bytes struct StreamKTilePartitionerBaseConfigSKOnlyWith2WgsPerSKTile : public StreamKTilePartitionerBaseConfig { - static constexpr ck_tile::index_t M = 16; - static constexpr ck_tile::index_t N = 4; - static constexpr ck_tile::index_t K = 16; - static constexpr ck_tile::index_t GRID = 8; + static constexpr ck_tile::index_t M = 16; + static constexpr ck_tile::index_t N = 4; + static constexpr ck_tile::index_t K = 16; + static constexpr ck_tile::index_t MAX_ACTIVE_WGS = 8; static constexpr ck_tile::index_t M_TILE = 4; static constexpr ck_tile::index_t N_TILE = 4; @@ -272,10 +272,10 @@ struct StreamKTilePartitionerBaseConfigSKOnlyWith2WgsPerSKTile struct StreamKTilePartitionerBaseConfigDPOnly : public StreamKTilePartitionerBaseConfig { - static constexpr ck_tile::index_t M = 12; - static constexpr ck_tile::index_t N = 4; - static constexpr ck_tile::index_t K = 16; - static constexpr ck_tile::index_t GRID = 3; + static constexpr ck_tile::index_t M = 12; + static constexpr ck_tile::index_t N = 4; + static constexpr ck_tile::index_t K = 16; + static constexpr ck_tile::index_t MAX_ACTIVE_WGS = 3; static constexpr ck_tile::index_t M_TILE = 4; static constexpr ck_tile::index_t N_TILE = 2; @@ -288,10 +288,10 @@ struct StreamKTilePartitionerBaseConfigDPOnly : public StreamKTilePartitionerBas struct StreamKTilePartitionerBaseConfigSKOnly : public StreamKTilePartitionerBaseConfig { - static constexpr ck_tile::index_t M = 4; - static constexpr ck_tile::index_t N = 4; - static constexpr ck_tile::index_t K = 16; - static constexpr ck_tile::index_t GRID = 3; + static constexpr ck_tile::index_t M = 4; + static constexpr ck_tile::index_t N = 4; + static constexpr ck_tile::index_t K = 16; + static constexpr ck_tile::index_t MAX_ACTIVE_WGS = 3; static 
constexpr ck_tile::index_t M_TILE = 4; static constexpr ck_tile::index_t N_TILE = 2; @@ -304,10 +304,10 @@ struct StreamKTilePartitionerBaseConfigSKOnly : public StreamKTilePartitionerBas struct StreamKTilePartitionerBaseConfigSKOnlyLargeK : public StreamKTilePartitionerBaseConfig { - static constexpr ck_tile::index_t M = 8; - static constexpr ck_tile::index_t N = 2; - static constexpr ck_tile::index_t K = 10; - static constexpr ck_tile::index_t GRID = 5; + static constexpr ck_tile::index_t M = 8; + static constexpr ck_tile::index_t N = 2; + static constexpr ck_tile::index_t K = 10; + static constexpr ck_tile::index_t MAX_ACTIVE_WGS = 5; static constexpr ck_tile::index_t M_TILE = 4; static constexpr ck_tile::index_t N_TILE = 2; @@ -321,10 +321,10 @@ struct StreamKTilePartitionerBaseConfigSKOnlyLargeK : public StreamKTilePartitio struct StreamKTilePartitionerBaseConfigEdgeCase : public StreamKTilePartitionerBaseConfig { - static constexpr ck_tile::index_t M = 4; - static constexpr ck_tile::index_t N = 4; - static constexpr ck_tile::index_t K = 16; - static constexpr ck_tile::index_t GRID = 4; + static constexpr ck_tile::index_t M = 4; + static constexpr ck_tile::index_t N = 4; + static constexpr ck_tile::index_t K = 16; + static constexpr ck_tile::index_t MAX_ACTIVE_WGS = 4; static constexpr ck_tile::index_t M_TILE = 4; static constexpr ck_tile::index_t N_TILE = 4; @@ -340,10 +340,10 @@ struct StreamKTilePartitionerBaseConfigLargerCTensor : public StreamKTilePartiti // This config has 3 macro tiles in the M dimension and 4 macro tiles in the N dimension. // This facilitates testing the get_output_tile_index method. 
- static constexpr ck_tile::index_t M = 12; - static constexpr ck_tile::index_t N = 16; - static constexpr ck_tile::index_t K = 16; - static constexpr ck_tile::index_t GRID = 4; + static constexpr ck_tile::index_t M = 12; + static constexpr ck_tile::index_t N = 16; + static constexpr ck_tile::index_t K = 16; + static constexpr ck_tile::index_t MAX_ACTIVE_WGS = 4; static constexpr ck_tile::index_t M_TILE = 4; static constexpr ck_tile::index_t N_TILE = 4; @@ -366,7 +366,7 @@ void test_get_output_tile_index(ck_tile::index_t tile_idx, // Test parameters ck_tile::StreamKTilePartitionerBase tile_partitioner{ - Config::M, Config::N, Config::K, Config::GRID}; + Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS}; ck_tile::DeviceMem im_dev(sizeof(ck_tile::index_t)); ck_tile::DeviceMem in_dev(sizeof(ck_tile::index_t)); @@ -402,7 +402,7 @@ void test_get_tile_local_cta_idx(ck_tile::index_t tile_iter_start, // Test parameters ck_tile::StreamKTilePartitionerBase tile_partitioner{ - Config::M, Config::N, Config::K, Config::GRID}; + Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS}; ck_tile::DeviceMem tile_local_cta_idx_dev(sizeof(ck_tile::index_t)); // Launch kernel @@ -426,7 +426,7 @@ struct StreamKTilePartitionerV2PersistentExpected { ck_tile::index_t dp_tiles_per_cta_; ck_tile::index_t extra_dp_tiles_; - ck_tile::index_t grid_; + ck_tile::index_t max_active_wgs_; }; struct StreamKTilePartitionerV2NonPersistentExpected @@ -434,7 +434,7 @@ struct StreamKTilePartitionerV2NonPersistentExpected ck_tile::index_t dp_ctas_; ck_tile::index_t dp_start_block_idx_; ck_tile::index_t sk_start_block_idx_; - ck_tile::index_t grid_; + ck_tile::index_t max_active_wgs_; }; // Persistent @@ -446,7 +446,7 @@ void validate_streamk_persistent( { EXPECT_EQ(tile_partitioner.get_dp_tiles_per_cta(), expected_values.dp_tiles_per_cta_); EXPECT_EQ(tile_partitioner.get_extra_dp_tiles(), expected_values.extra_dp_tiles_); - EXPECT_EQ(tile_partitioner.get_grid(), expected_values.grid_); + 
EXPECT_EQ(tile_partitioner.get_max_active_wgs(), expected_values.max_active_wgs_); } // Non-Persistent @@ -459,5 +459,5 @@ void validate_streamk_nonpersistent( EXPECT_EQ(tile_partitioner.get_dp_ctas(), expected_values.dp_ctas_); EXPECT_EQ(tile_partitioner.get_dp_start_block_idx(), expected_values.dp_start_block_idx_); EXPECT_EQ(tile_partitioner.get_sk_start_block_idx(), expected_values.sk_start_block_idx_); - EXPECT_EQ(tile_partitioner.get_grid(), expected_values.grid_); + EXPECT_EQ(tile_partitioner.get_max_active_wgs(), expected_values.max_active_wgs_); }