diff --git a/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp b/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp index 673f5abc34..08a8f85df3 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp @@ -10,6 +10,8 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/common.hpp" +#include +#include namespace ck_tile { @@ -810,4 +812,5 @@ struct StreamKTilePartitioner uint32_t M_, N_, K_; uint32_t num_tile_m_, num_tile_n_, num_tile_k_; }; + } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner.hpp b/include/ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner.hpp index 201684adc5..faab4cd55c 100644 --- a/include/ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner.hpp +++ b/include/ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner.hpp @@ -202,6 +202,117 @@ struct StreamKTilePartitionerBase index_t n_; }; +/** + * @brief Template for the Stream-K tile partitioner derived struct. + * + * This partitioner is responsible for mapping workgroups to tiles in the C tensor + * for the Stream-K algorithm. This struct is derived from + * StreamKTilePartitionerBase. Behavior of the + * StreamKTilePartitioner based on persistency will be in the template specializations. + * + * @tparam BlockGemmShapeType A class providing basic GEMM parameters. + * @tparam ReductionStrategy An enum that defines the reduction strategy for the results in the C + * Tensor. + * @tparam Persistent A bool that indicates whether to use a Persistent approach + */ +template +struct StreamKTilePartitioner_v2; + +/** + * @brief Persistent Stream-K tile partitioner derived struct. + * + * This partitioner is responsible for mapping workgroups to tiles in the C tensor + * for the Stream-K algorithm when using a Persistent approach where no extra workgroups + * are allocated for data parallel. + * + * @tparam BlockGemmShapeType A class providing basic GEMM parameters. + * @tparam ReductionStrategy An enum that defines the reduction strategy for the results in the C + * Tensor. + */ +template +struct StreamKTilePartitioner_v2 + : StreamKTilePartitionerBase +{ + StreamKTilePartitioner_v2(ck_tile::index_t m, + ck_tile::index_t n, + ck_tile::index_t k, + ck_tile::index_t grid); + + public: + /** + * @brief Calculates the launching grid size for the Stream-K kernel. In the Persistent + * case, no extra workgroups are allocated for the data parallel section, making the grid + * size num_cu * occupancy. + * + * @return dim_3 The launching grid size for the kernel. + */ + CK_TILE_HOST auto grid_size() const noexcept -> dim3; + + CK_TILE_HOST_DEVICE index_t get_dp_tiles_per_cta() const noexcept; + CK_TILE_HOST_DEVICE index_t get_extra_dp_tiles() const noexcept; + + protected: + /** + * @brief The total number of DP tiles per workgroup. + */ + int dp_tiles_per_cta_; + + /** + * @brief The total number of DP tiles left over when dp_tiles is not evenly divisible by grid. + */ + int extra_dp_tiles_; +}; + +/** + * @brief Non-Persistent Stream-K tile partitioner derived struct. + * + * This partitioner is responsible for mapping workgroups to tiles in the C tensor + * for the Stream-K algorithm when using a Non-Persistent approach where extra workgroups + * are allocated for the data parallel section. + * + * @tparam BlockGemmShapeType A class providing basic GEMM parameters. + * @tparam ReductionStrategy An enum that defines the reduction strategy for the results in the C + * Tensor. + */ +template +struct StreamKTilePartitioner_v2 + : StreamKTilePartitionerBase +{ + StreamKTilePartitioner_v2(ck_tile::index_t m, + ck_tile::index_t n, + ck_tile::index_t k, + ck_tile::index_t grid); + + public: + /** + * @brief Calculates the launching grid size for the Stream-K kernel. In the Non-Persistent + * case, extra workgroups are allocated for the data parallel section, making the grid + * size the total number of Stream-K and data parallel workgroups. + * + * @return dim_3 The launching grid size for the kernel. + */ + CK_TILE_HOST auto grid_size() const noexcept -> dim3; + CK_TILE_HOST_DEVICE index_t get_dp_ctas() const noexcept; + CK_TILE_HOST_DEVICE index_t get_dp_start_block_idx() const noexcept; + CK_TILE_HOST_DEVICE index_t get_sk_start_block_idx() const noexcept; + + protected: + /** + * @brief The total number of DP workgroups. + */ + int dp_ctas_; + + /** + * @brief The index that starts the DP workgroups, always 0 in our implementation. + */ + int dp_start_block_idx_; + + /** + * @brief The index that starts the Stream-K workgroups, set to the number of dp_tiles. + */ + int sk_start_block_idx_; +}; + } // namespace ck_tile #include "streamk_gemm_tile_partitioner_impl.hpp" diff --git a/include/ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner_impl.hpp b/include/ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner_impl.hpp index 12bc110cc2..cb31839546 100644 --- a/include/ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner_impl.hpp +++ b/include/ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner_impl.hpp @@ -211,4 +211,91 @@ StreamKTilePartitionerBase::get_n() const return n_; } +template +struct StreamKTilePartitioner_v2; + +// child class for Persistent Tile Partitioner +template +StreamKTilePartitioner_v2::StreamKTilePartitioner_v2( + ck_tile::index_t m, ck_tile::index_t n, ck_tile::index_t k, ck_tile::index_t grid) + : StreamKTilePartitionerBase(m, n, k, grid) +{ // inherit from base constructor + dp_tiles_per_cta_ = this->dp_tiles_ / this->grid_; + extra_dp_tiles_ = this->dp_tiles_ % this->grid_; +} + +template +CK_TILE_HOST auto +StreamKTilePartitioner_v2::grid_size() const noexcept + -> dim3 +{ + if(extra_dp_tiles_ == 0) + { + return dim3(this->grid_, 1, 1); + } + else + { + return dim3(this->num_tiles_, 1, 1); + } +} + +template +CK_TILE_HOST_DEVICE index_t +StreamKTilePartitioner_v2::get_dp_tiles_per_cta() + const noexcept +{ + return dp_tiles_per_cta_; +} + +template +CK_TILE_HOST_DEVICE index_t +StreamKTilePartitioner_v2::get_extra_dp_tiles() + const noexcept +{ + return extra_dp_tiles_; +} + +// child class for Non-Persistent Tile Partitioner +template +StreamKTilePartitioner_v2::StreamKTilePartitioner_v2( + ck_tile::index_t m, ck_tile::index_t n, ck_tile::index_t k, ck_tile::index_t grid) + : StreamKTilePartitionerBase(m, n, k, grid) +{ // inherit from base constructor + dp_ctas_ = this->dp_tiles_; + dp_start_block_idx_ = 0; + sk_start_block_idx_ = this->dp_tiles_; +} + +template +CK_TILE_HOST auto +StreamKTilePartitioner_v2::grid_size() const noexcept + -> dim3 +{ + return dim3(dp_ctas_ + this->get_sk_ctas(), 1, 1); +} + +template +CK_TILE_HOST_DEVICE index_t +StreamKTilePartitioner_v2::get_dp_ctas() + const noexcept +{ + return dp_ctas_; +} + +template +CK_TILE_HOST_DEVICE index_t +StreamKTilePartitioner_v2::get_dp_start_block_idx() + const noexcept +{ + return dp_start_block_idx_; +} + +template +CK_TILE_HOST_DEVICE index_t +StreamKTilePartitioner_v2::get_sk_start_block_idx() + const noexcept +{ + return sk_start_block_idx_; +} + } // namespace ck_tile diff --git a/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner.cpp b/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner.cpp index e89fe14773..968fadda51 100644 --- a/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner.cpp +++ b/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner.cpp @@ -347,3 +347,148 @@ TEST(StreamKTilePartitionerBaseGetOutputTileIndex, TestAllMappings) } } } + +// Persistent +TEST(StreamKTilePartitioner_v2_PersistentConstructor, SKOnly) +{ + using Config = StreamKTilePartitionerBaseConfigSKOnly; + + ck_tile::StreamKTilePartitioner_v2 + tile_partitioner{Config::M, Config::N, Config::K, Config::GRID}; + + StreamKTilePartitionerV2PersistentExpected expected_values{0, 0, 3}; + validate_streamk_v2_persistent(expected_values, tile_partitioner); +} + +TEST(StreamKTilePartitioner_v2_PersistentConstructor, DPOnly) +{ + using Config = StreamKTilePartitionerBaseConfigDPOnly; + + ck_tile::StreamKTilePartitioner_v2 + tile_partitioner{Config::M, Config::N, Config::K, Config::GRID}; + + StreamKTilePartitionerV2PersistentExpected expected_values{2, 0, 3}; + validate_streamk_v2_persistent(expected_values, tile_partitioner); +} + +TEST(StreamKTilePartitioner_v2_PersistentConstructor, DP2TileSK) +{ + using Config = StreamKTilePartitionerBaseConfigDP2TileSK; + + ck_tile::StreamKTilePartitioner_v2 + tile_partitioner{Config::M, Config::N, Config::K, Config::GRID}; + + StreamKTilePartitionerV2PersistentExpected expected_values{1, 0, 3}; + validate_streamk_v2_persistent(expected_values, tile_partitioner); +} + +TEST(StreamKTilePartitioner_v2_PersistentConstructor, EdgeCase) +{ + using Config = StreamKTilePartitionerBaseConfigEdgeCase; + + ck_tile::StreamKTilePartitioner_v2 + tile_partitioner{Config::M, Config::N, Config::K, Config::GRID}; + + StreamKTilePartitionerV2PersistentExpected expected_values{0, 1, 4}; + validate_streamk_v2_persistent(expected_values, tile_partitioner); +} + +TEST(StreamKTilePartitioner_v2_GridSize_Persistent, SKOnly) +{ + using Config = StreamKTilePartitionerBaseConfigSKOnly; + + ck_tile::StreamKTilePartitioner_v2 + tile_partitioner{Config::M, Config::N, Config::K, Config::GRID}; + + const auto g = tile_partitioner.grid_size(); + EXPECT_EQ(g.x, Config::GRID); +} + +TEST(StreamKTilePartitioner_v2_GridSize_Persistent, EdgeCase) +{ + using Config = StreamKTilePartitionerBaseConfigEdgeCase; + + ck_tile::StreamKTilePartitioner_v2 + tile_partitioner{Config::M, Config::N, Config::K, Config::GRID}; + + const auto g = tile_partitioner.grid_size(); + EXPECT_EQ(g.x, 1); +} + +// Non-Persistent Tests +TEST(StreamKTilePartitioner_v2_NonPersistentConstructor, SKOnly) +{ + using Config = StreamKTilePartitionerBaseConfigSKOnly; + + ck_tile::StreamKTilePartitioner_v2 + tile_partitioner{Config::M, Config::N, Config::K, Config::GRID}; + + StreamKTilePartitionerV2NonPersistentExpected expected_values{0, 0, 0, 3}; + validate_streamk_v2_nonpersistent(expected_values, tile_partitioner); +} + +TEST(StreamKTilePartitioner_v2_NonPersistentConstructor, DPOnly) +{ + using Config = StreamKTilePartitionerBaseConfigDPOnly; + + ck_tile::StreamKTilePartitioner_v2 + tile_partitioner{Config::M, Config::N, Config::K, Config::GRID}; + + StreamKTilePartitionerV2NonPersistentExpected expected_values{6, 0, 6, 3}; + validate_streamk_v2_nonpersistent(expected_values, tile_partitioner); +} + +TEST(StreamKTilePartitioner_v2_NonPersistentConstructor, DP2TileSK) +{ + using Config = StreamKTilePartitionerBaseConfigDP2TileSK; + + ck_tile::StreamKTilePartitioner_v2 + tile_partitioner{Config::M, Config::N, Config::K, Config::GRID}; + + StreamKTilePartitionerV2NonPersistentExpected expected_values{3, 0, 3, 3}; + validate_streamk_v2_nonpersistent(expected_values, tile_partitioner); +} + +TEST(StreamKTilePartitioner_v2_NonPersistentConstructor, EdgeCase) +{ + using Config = StreamKTilePartitionerBaseConfigEdgeCase; + + ck_tile::StreamKTilePartitioner_v2 + tile_partitioner{Config::M, Config::N, Config::K, Config::GRID}; + + StreamKTilePartitionerV2NonPersistentExpected expected_values{1, 0, 1, 4}; + validate_streamk_v2_nonpersistent(expected_values, tile_partitioner); +} + +TEST(StreamKTilePartitioner_v2_GridSize_NonPersistent, DP2TileSK) +{ + using Config = StreamKTilePartitionerBaseConfigDP2TileSK; + + ck_tile::StreamKTilePartitioner_v2 + tile_partitioner{Config::M, Config::N, Config::K, Config::GRID}; + + const auto g = tile_partitioner.grid_size(); + EXPECT_EQ(g.x, 6); +} diff --git a/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner_common.hpp b/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner_common.hpp index f88c92e0e4..03f149f6b6 100644 --- a/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner_common.hpp +++ b/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner_common.hpp @@ -297,4 +297,45 @@ void test_get_output_tile_index(ck_tile::index_t tile_idx, in_dev.FromDevice(&in); EXPECT_EQ(im, im_expected); EXPECT_EQ(in, in_expected); +}; + +// Configs for TilePartitioner Child structs +struct StreamKTilePartitionerV2PersistentExpected +{ + ck_tile::index_t dp_tiles_per_cta_; + ck_tile::index_t extra_dp_tiles_; + ck_tile::index_t grid_; +}; + +struct StreamKTilePartitionerV2NonPersistentExpected +{ + ck_tile::index_t dp_ctas_; + ck_tile::index_t dp_start_block_idx_; + ck_tile::index_t sk_start_block_idx_; + ck_tile::index_t grid_; +}; + +// Persistent +template +void validate_streamk_v2_persistent( + StreamKTilePartitionerV2PersistentExpected& expected_values, + ck_tile::StreamKTilePartitioner_v2& + tile_partitioner) +{ + EXPECT_EQ(tile_partitioner.get_dp_tiles_per_cta(), expected_values.dp_tiles_per_cta_); + EXPECT_EQ(tile_partitioner.get_extra_dp_tiles(), expected_values.extra_dp_tiles_); + EXPECT_EQ(tile_partitioner.get_grid(), expected_values.grid_); +} + +// Non-Persistent +template +void validate_streamk_v2_nonpersistent( + StreamKTilePartitionerV2NonPersistentExpected& expected_values, + ck_tile::StreamKTilePartitioner_v2& + tile_partitioner) +{ + EXPECT_EQ(tile_partitioner.get_dp_ctas(), expected_values.dp_ctas_); + EXPECT_EQ(tile_partitioner.get_dp_start_block_idx(), expected_values.dp_start_block_idx_); + EXPECT_EQ(tile_partitioner.get_sk_start_block_idx(), expected_values.sk_start_block_idx_); + EXPECT_EQ(tile_partitioner.get_grid(), expected_values.grid_); }