diff --git a/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp b/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp index 08a8f85df3..673f5abc34 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp @@ -10,8 +10,6 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/common.hpp" -#include -#include namespace ck_tile { @@ -812,5 +810,4 @@ struct StreamKTilePartitioner uint32_t M_, N_, K_; uint32_t num_tile_m_, num_tile_n_, num_tile_k_; }; - } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner.hpp b/include/ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner.hpp index faab4cd55c..1962f3518a 100644 --- a/include/ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner.hpp +++ b/include/ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner.hpp @@ -14,20 +14,20 @@ namespace ck_tile { * This partitioner is responsible for mapping workgroups to tiles in the C tensor * for the Stream-K algorithm. * - * @tparam BlockGemmShapeType A class providing basic GEMM parameters. - * @tparam ReductionStrategy An enum that defines the reduction strategy for the results in the C - * Tensor. + * @tparam BlockGemmShapeType A class providing basic GEMM parameters. + * @tparam ReductionStrategyType An enum that defines the reduction strategy for the results in + * the C Tensor. */ template + StreamKReductionStrategy ReductionStrategyType = StreamKReductionStrategy::Atomic> struct StreamKTilePartitionerBase { using BlockGemmShape = BlockGemmShapeType; - static constexpr index_t MPerBlock = BlockGemmShape::kM; - static constexpr index_t NPerBlock = BlockGemmShape::kN; - static constexpr index_t KPerBlock = BlockGemmShape::kK; - static constexpr StreamKReductionStrategy StreamKReductionStrategy = ReductionStrategy; + static constexpr index_t MPerBlock = BlockGemmShape::kM; + static constexpr index_t NPerBlock = BlockGemmShape::kN; + static constexpr index_t KPerBlock = BlockGemmShape::kK; + static constexpr StreamKReductionStrategy ReductionStrategy = ReductionStrategyType; StreamKTilePartitionerBase(index_t m, index_t n, index_t k, index_t grid); @@ -51,59 +51,62 @@ struct StreamKTilePartitionerBase /** * @brief Calculates the start and end iteration given the cta_idx. * - * @param iter Reference to an index_t; will be set to the starting iteration by the + * @param iter_start Reference to an index_t; will be set to the starting iteration by the * function. - * @param iter_end Reference to an index_t; will be set to the non-inclusive end iteration by + * @param iter_end Reference to an index_t; will be set to the non-inclusive end iteration by * the function. - * @param cta_idx The current Stream-K workgroup's index. + * @param cta_idx The current Stream-K workgroup's index. * @note It is assumed that the first Stream-K workgroup has a `cta_idx` of zero. If a * non-persistent DP section is used, then a Stream-K workgroup's `cta_idx` should be something * like `blockIdx.x` minus number of DP workgroups. */ CK_TILE_DEVICE void - get_iter_boundaries(index_t& iter, index_t& iter_end, index_t cta_idx) const noexcept; + get_iter_boundaries(index_t& iter_start, index_t& iter_end, index_t cta_idx) const noexcept; /** * @brief Calculates the 1D tile index in the C tensor for a workgroup. * - * @param iter The starting iteration. - * @return index_t The 1D tile index. + * @param iter_start The starting iteration. + * @return index_t The 1D tile index. */ - CK_TILE_DEVICE index_t get_tile_index(index_t iter) const noexcept; + CK_TILE_DEVICE index_t get_tile_index(index_t iter_start) const noexcept; /** * @brief Calculates the starting and ending tile boundaries for the given 1D tile index. * - * @param tile_iter Reference to an index_t; will be set to the tile's start iteration by + * @param tile_iter_start Reference to an index_t; will be set to the tile's start iteration by * the function. - * @param tile_iter_end Reference to an index_t; will be set to the non-inclusive tile's end + * @param tile_iter_end Reference to an index_t; will be set to the non-inclusive tile's end * iteration by the function. * @param tile_idx The 1D C tensor tile index for the workgroup. */ - CK_TILE_DEVICE void get_tile_boundaries(index_t& tile_iter, + CK_TILE_DEVICE void get_tile_boundaries(index_t& tile_iter_start, index_t& tile_iter_end, index_t tile_idx) const noexcept; /** * @brief Calculates the workgroup's starting iteration that is local to a tile. * - * @param iter The starting iteration. + * @param iter_start The starting iteration. + * @param tile_iter_start The starting iteration of the tile (i.e., the tile's starting + * boundary). * @return index_t The local starting iteration. The value is in range [0, `iters_per_tile_`). - * @note Assumes `iter` >= `tile_iter`. + * @note Assumes `iter_start` >= `tile_iter_start`. */ - CK_TILE_DEVICE static index_t get_local_iter(index_t iter, index_t tile_iter) noexcept; + CK_TILE_DEVICE static index_t get_local_iter(index_t iter_start, + index_t tile_iter_start) noexcept; /** * @brief Calculates the workgroup's non-inclusive end iteration that is local to a tile. * - * @param tile_iter The starting tile iteration. - * @param iter_end The non-inclusive end iteration. - * @param tile_iter_end The non-inclusive end iteration of the tile. - * @return index_t The local non-inclusive end iteration. - * @note Assumes `iter_end` >= `tile_iter` and `tile_iter_end` >= `tile_iter`. + * @param tile_iter_start The starting tile iteration. + * @param iter_end The non-inclusive end iteration. + * @param tile_iter_end The non-inclusive end iteration of the tile. + * @return index_t The local non-inclusive end iteration. + * @note Assumes `iter_end` >= `tile_iter_start` and `tile_iter_end` >= `tile_iter_start`. */ CK_TILE_DEVICE static index_t - get_local_iter_end(index_t tile_iter, index_t iter_end, index_t tile_iter_end) noexcept; + get_local_iter_end(index_t tile_iter_start, index_t iter_end, index_t tile_iter_end) noexcept; /** * @brief Calculates the workgroups 2D tile index in the C tensor given the 1D tile index. @@ -122,83 +125,85 @@ struct StreamKTilePartitionerBase */ CK_TILE_HOST index_t get_workspace_size(index_t acc_element_bytes) const noexcept; + /** + * @brief Returns the number of macro tiles in the C tensor. + */ CK_TILE_HOST_DEVICE index_t get_num_tiles() const noexcept; + /** + * @brief Returns the maximum number of active workgroups; this is assumed to be number of CUs * + * occupancy. + */ CK_TILE_HOST_DEVICE index_t get_grid() const noexcept; + /** + * @brief Returns the number of tiles in the C tensor that will use the data-parallel (DP) + * approach. + */ CK_TILE_HOST_DEVICE index_t get_dp_tiles() const noexcept; + /** + * @brief Returns the number of tiles in the C tensor that will use the Stream-K approach. + */ CK_TILE_HOST_DEVICE index_t get_sk_tiles() const noexcept; + /** + * @brief Returns the number of workgroups that will participate in Stream-K in the `sk_tiles_`. + */ CK_TILE_HOST_DEVICE index_t get_sk_ctas() const noexcept; + /** + * @brief Returns the total number of Stream-K iterations. + */ CK_TILE_HOST_DEVICE index_t get_total_sk_iters() const noexcept; + /** + * @brief Returns the total number of iterations per tile in the C tensor. In other words, this + * is the total number of macro tiles along the K dimension of A and B. + */ CK_TILE_HOST_DEVICE index_t get_iters_per_tile() const noexcept; + /** + * @brief Returns the total number of Stream-K iterations for each `sk_cta`. This is the lower + * bound (i.e., all `sk_ctas_` are guaranteed to perform at least this many iterations). + */ CK_TILE_HOST_DEVICE index_t get_iters_per_sk_cta() const noexcept; + /** + * @brief Returns the remainder resulting from `total_sk_iters_` divided by `sk_ctas_`. When + * this is non-zero, the first `extra_iters_` `sk_ctas_` will get one additional iteration + * assigned to them; such work groups will perform (`iters_per_sk_cta_` + 1) iterations. + */ CK_TILE_HOST_DEVICE index_t get_extra_iters() const noexcept; + /** + * @brief Returns the total number of DP iterations. + */ CK_TILE_HOST_DEVICE index_t get_total_dp_iters() const noexcept; + /** + * @brief Returns the n dimension for the GEMM problem. + */ CK_TILE_HOST_DEVICE index_t get_n() const noexcept; protected: - /** - * @brief The number of macro tiles in the C tensor. - */ index_t num_tiles_; - /** - * @brief The maximum number of active workgroups; this is assumed to be number of CUs * - * occupancy. - */ index_t grid_; - /** - * @brief The number of tiles in the C tensor that will use the data-parallel (DP) approach. - */ index_t dp_tiles_; private: /** - * @brief The number of full tiles assigned to each `sk_cta` when performing DP + 2 Tile SK. + * @brief The number of full tiles assigned to each `sk_cta` when performing DP + 2 Tile + * Stream-K. */ index_t full_tiles_ = 1; - /** - * @brief The number of tiles in the C tensor that will use the Stream-K approach. - */ index_t sk_tiles_; - /** - * @brief The number of workgroups that will participate in Stream-K in the `sk_tiles_`. - */ index_t sk_ctas_; - /** - * @brief The total number of Stream-K iterations. - */ index_t total_sk_iters_; - /** - * @brief The total number of iterations per tile in the C tensor. In other words, this is the - * total number of macro tiles along the K dimension of A and B. - */ index_t iters_per_tile_; - /** - * @brief The total number of Stream-K iterations for each `sk_cta`. This is the lower bound - * (i.e., all `sk_ctas_` are guaranteed to perform at least this many iterations). - */ index_t iters_per_sk_cta_; - /** - * @brief The remainder resulting from `total_sk_iters_` divided by `sk_ctas_`. When this is - * non-zero, the first `extra_iters_` `sk_ctas_` will get one additional iteration assigned to - * them; such work groups will perform (`iters_per_sk_cta_` + 1) iterations. - */ index_t extra_iters_; - /** - * @brief The total number of DP iterations. - */ index_t total_dp_iters_; - /** - * @brief The n dimension for the GEMM problem. - */ index_t n_; }; @@ -207,15 +212,17 @@ struct StreamKTilePartitionerBase * * This partitioner is responsible for mapping workgroups to tiles in the C tensor * for the Stream-K algorithm. This struct is derived from - * StreamKTilePartitionerBase. Behavior of the + * StreamKTilePartitionerBase. Behavior of the * StreamKTilePartitioner based on persistency will be in the template specializations. * - * @tparam BlockGemmShapeType A class providing basic GEMM parameters. - * @tparam ReductionStrategy An enum that defines the reduction strategy for the results in the C - * Tensor. + * @tparam BlockGemmShapeType A class providing basic GEMM parameters. + * @tparam ReductionStrategyType An enum that defines the reduction strategy for the results in + * the C Tensor. * @tparam Persistent A bool that indicates whether to use a Persistent approach */ -template +template struct StreamKTilePartitioner_v2; /** @@ -225,13 +232,13 @@ struct StreamKTilePartitioner_v2; * for the Stream-K algorithm when using a Persistent approach where no extra workgroups * are allocated for data parallel. * - * @tparam BlockGemmShapeType A class providing basic GEMM parameters. - * @tparam ReductionStrategy An enum that defines the reduction strategy for the results in the C - * Tensor. + * @tparam BlockGemmShapeType A class providing basic GEMM parameters. + * @tparam ReductionStrategyType An enum that defines the reduction strategy for the results in + * the C Tensor. */ -template -struct StreamKTilePartitioner_v2 - : StreamKTilePartitionerBase +template +struct StreamKTilePartitioner_v2 + : StreamKTilePartitionerBase { StreamKTilePartitioner_v2(ck_tile::index_t m, ck_tile::index_t n, @@ -248,19 +255,20 @@ struct StreamKTilePartitioner_v2 */ CK_TILE_HOST auto grid_size() const noexcept -> dim3; + /** + * @brief Returns the total number of DP tiles per workgroup. + */ CK_TILE_HOST_DEVICE index_t get_dp_tiles_per_cta() const noexcept; + + /** + * @brief Returns the total number of DP tiles left over when `dp_tiles_` is not evenly + * divisible by `grid_`. + */ CK_TILE_HOST_DEVICE index_t get_extra_dp_tiles() const noexcept; protected: - /** - * @brief The total number of DP tiles per workgroup. - */ - int dp_tiles_per_cta_; - - /** - * @brief The total number of DP tiles left over when dp_tiles is not evenly divisible by grid. - */ - int extra_dp_tiles_; + index_t dp_tiles_per_cta_; + index_t extra_dp_tiles_; }; /** @@ -271,12 +279,12 @@ struct StreamKTilePartitioner_v2 * are allocated for the data parallel section. * * @tparam BlockGemmShapeType A class providing basic GEMM parameters. - * @tparam ReductionStrategy An enum that defines the reduction strategy for the results in the C - * Tensor. + * @tparam ReductionStrategyType An enum that defines the reduction strategy for the results in + * the C Tensor. */ -template -struct StreamKTilePartitioner_v2 - : StreamKTilePartitionerBase +template +struct StreamKTilePartitioner_v2 + : StreamKTilePartitionerBase { StreamKTilePartitioner_v2(ck_tile::index_t m, ck_tile::index_t n, @@ -292,25 +300,26 @@ struct StreamKTilePartitioner_v2 * @return dim_3 The launching grid size for the kernel. */ CK_TILE_HOST auto grid_size() const noexcept -> dim3; + + /** + * @brief Returns the total number of DP workgroups. + */ CK_TILE_HOST_DEVICE index_t get_dp_ctas() const noexcept; + + /** + * @brief Returns starting DP workgroup index. It is always zero. + */ CK_TILE_HOST_DEVICE index_t get_dp_start_block_idx() const noexcept; + + /** + * @brief The index that starts the Stream-K workgroups. It is set to the number of `dp_tiles_`. + */ CK_TILE_HOST_DEVICE index_t get_sk_start_block_idx() const noexcept; protected: - /** - * @brief The total number of DP workgroups. - */ - int dp_ctas_; - - /** - * @brief The index that starts the DP workgroups, always 0 in our implementation. - */ - int dp_start_block_idx_; - - /** - * @brief The index that starts the Stream-K workgroups, set to the number of dp_tiles. - */ - int sk_start_block_idx_; + index_t dp_ctas_; + index_t dp_start_block_idx_; + index_t sk_start_block_idx_; }; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner_impl.hpp b/include/ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner_impl.hpp index cb31839546..0dba775182 100644 --- a/include/ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner_impl.hpp +++ b/include/ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner_impl.hpp @@ -1,10 +1,11 @@ // Copyright © Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT - +#pragma once +#include "streamk_gemm_tile_partitioner.hpp" namespace ck_tile { -template -StreamKTilePartitionerBase::StreamKTilePartitionerBase( +template +StreamKTilePartitionerBase::StreamKTilePartitionerBase( index_t m, index_t n, index_t k, index_t grid) : grid_{grid}, n_{n} { @@ -43,68 +44,68 @@ StreamKTilePartitionerBase::StreamKTilePa total_dp_iters_ = dp_tiles_ * iters_per_tile_; } -template +template CK_TILE_HOST index_t -StreamKTilePartitionerBase::get_partials_buffer_size( +StreamKTilePartitionerBase::get_partials_buffer_size( index_t acc_element_bytes) const noexcept { return MPerBlock * NPerBlock * acc_element_bytes * sk_ctas_; } -template +template CK_TILE_HOST index_t -StreamKTilePartitionerBase::get_flags_buffer_size() +StreamKTilePartitionerBase::get_flags_buffer_size() const noexcept { return sizeof(index_t) * sk_ctas_; } -template +template CK_TILE_DEVICE void -StreamKTilePartitionerBase::get_iter_boundaries( +StreamKTilePartitionerBase::get_iter_boundaries( index_t& iter, index_t& iter_end, index_t cta_idx) const noexcept { - index_t extra_iters__before_me = ck_tile::min(cta_idx, extra_iters_); - iter = total_dp_iters_ + cta_idx * iters_per_sk_cta_ + extra_iters__before_me; + index_t extra_iters_before_me = ck_tile::min(cta_idx, extra_iters_); + iter = total_dp_iters_ + cta_idx * iters_per_sk_cta_ + extra_iters_before_me; iter_end = iter + iters_per_sk_cta_ + (cta_idx < extra_iters_); } -template +template CK_TILE_DEVICE index_t -StreamKTilePartitionerBase::get_tile_index( +StreamKTilePartitionerBase::get_tile_index( index_t iter) const noexcept { return iter / iters_per_tile_; } -template +template CK_TILE_DEVICE void -StreamKTilePartitionerBase::get_tile_boundaries( +StreamKTilePartitionerBase::get_tile_boundaries( index_t& tile_iter, index_t& tile_iter_end, index_t tile_idx) const noexcept { tile_iter = tile_idx * iters_per_tile_; tile_iter_end = tile_iter + iters_per_tile_; } -template +template CK_TILE_DEVICE /* static */ index_t -StreamKTilePartitionerBase::get_local_iter( +StreamKTilePartitionerBase::get_local_iter( index_t iter, index_t tile_iter) noexcept { return iter - tile_iter; } -template +template CK_TILE_DEVICE /* static */ index_t -StreamKTilePartitionerBase::get_local_iter_end( +StreamKTilePartitionerBase::get_local_iter_end( index_t tile_iter, index_t iter_end, index_t tile_iter_end) noexcept { return ck_tile::min(iter_end, tile_iter_end) - tile_iter; } -template +template CK_TILE_DEVICE auto -StreamKTilePartitionerBase::get_output_tile_index( +StreamKTilePartitionerBase::get_output_tile_index( index_t tile_idx) const noexcept -> tuple { const index_t n_macro_tiles = integer_divide_ceil(n_, NPerBlock); @@ -114,12 +115,12 @@ StreamKTilePartitionerBase::get_output_ti return make_tuple(im, in); } -template +template CK_TILE_HOST index_t -StreamKTilePartitionerBase::get_workspace_size( +StreamKTilePartitionerBase::get_workspace_size( index_t acc_element_bytes) const noexcept { - if constexpr(StreamKReductionStrategy == StreamKReductionStrategy::Reduction) + if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction) { return get_partials_buffer_size(acc_element_bytes) + get_flags_buffer_size(); @@ -130,104 +131,111 @@ StreamKTilePartitionerBase::get_workspace } } -template +template CK_TILE_HOST_DEVICE index_t -StreamKTilePartitionerBase::get_num_tiles() const noexcept +StreamKTilePartitionerBase::get_num_tiles() + const noexcept { return num_tiles_; } -template +template CK_TILE_HOST_DEVICE index_t -StreamKTilePartitionerBase::get_grid() const noexcept +StreamKTilePartitionerBase::get_grid() const noexcept { return grid_; } -template +template CK_TILE_HOST_DEVICE index_t -StreamKTilePartitionerBase::get_dp_tiles() const noexcept +StreamKTilePartitionerBase::get_dp_tiles() const noexcept { return dp_tiles_; } -template +template CK_TILE_HOST_DEVICE index_t -StreamKTilePartitionerBase::get_sk_tiles() const noexcept +StreamKTilePartitionerBase::get_sk_tiles() const noexcept { return sk_tiles_; } -template +template CK_TILE_HOST_DEVICE index_t -StreamKTilePartitionerBase::get_sk_ctas() const noexcept +StreamKTilePartitionerBase::get_sk_ctas() const noexcept { return sk_ctas_; } -template +template CK_TILE_HOST_DEVICE index_t -StreamKTilePartitionerBase::get_total_sk_iters() +StreamKTilePartitionerBase::get_total_sk_iters() const noexcept { return total_sk_iters_; } -template +template CK_TILE_HOST_DEVICE index_t -StreamKTilePartitionerBase::get_iters_per_tile() +StreamKTilePartitionerBase::get_iters_per_tile() const noexcept { return iters_per_tile_; } -template +template CK_TILE_HOST_DEVICE index_t -StreamKTilePartitionerBase::get_iters_per_sk_cta() +StreamKTilePartitionerBase::get_iters_per_sk_cta() const noexcept { return iters_per_sk_cta_; } -template +template CK_TILE_HOST_DEVICE index_t -StreamKTilePartitionerBase::get_extra_iters() const noexcept +StreamKTilePartitionerBase::get_extra_iters() + const noexcept { return extra_iters_; } -template +template CK_TILE_HOST_DEVICE index_t -StreamKTilePartitionerBase::get_total_dp_iters() +StreamKTilePartitionerBase::get_total_dp_iters() const noexcept { return total_dp_iters_; } -template +template CK_TILE_HOST_DEVICE index_t -StreamKTilePartitionerBase::get_n() const noexcept +StreamKTilePartitionerBase::get_n() const noexcept { return n_; } -template +template struct StreamKTilePartitioner_v2; // child class for Persistent Tile Partitioner -template -StreamKTilePartitioner_v2::StreamKTilePartitioner_v2( - ck_tile::index_t m, ck_tile::index_t n, ck_tile::index_t k, ck_tile::index_t grid) - : StreamKTilePartitionerBase(m, n, k, grid) +template +StreamKTilePartitioner_v2:: + StreamKTilePartitioner_v2(ck_tile::index_t m, + ck_tile::index_t n, + ck_tile::index_t k, + ck_tile::index_t grid) + : StreamKTilePartitionerBase(m, n, k, grid) { // inherit from base constructor dp_tiles_per_cta_ = this->dp_tiles_ / this->grid_; extra_dp_tiles_ = this->dp_tiles_ % this->grid_; } -template +template CK_TILE_HOST auto -StreamKTilePartitioner_v2::grid_size() const noexcept - -> dim3 +StreamKTilePartitioner_v2::grid_size() + const noexcept -> dim3 { if(extra_dp_tiles_ == 0) { @@ -239,61 +247,64 @@ StreamKTilePartitioner_v2::grid_siz } } -template +template CK_TILE_HOST_DEVICE index_t -StreamKTilePartitioner_v2::get_dp_tiles_per_cta() +StreamKTilePartitioner_v2::get_dp_tiles_per_cta() const noexcept { return dp_tiles_per_cta_; } -template +template CK_TILE_HOST_DEVICE index_t -StreamKTilePartitioner_v2::get_extra_dp_tiles() +StreamKTilePartitioner_v2::get_extra_dp_tiles() const noexcept { return extra_dp_tiles_; } // child class for Non-Persistent Tile Partitioner -template -StreamKTilePartitioner_v2::StreamKTilePartitioner_v2( - ck_tile::index_t m, ck_tile::index_t n, ck_tile::index_t k, ck_tile::index_t grid) - : StreamKTilePartitionerBase(m, n, k, grid) +template +StreamKTilePartitioner_v2:: + StreamKTilePartitioner_v2(ck_tile::index_t m, + ck_tile::index_t n, + ck_tile::index_t k, + ck_tile::index_t grid) + : StreamKTilePartitionerBase(m, n, k, grid) { // inherit from base constructor dp_ctas_ = this->dp_tiles_; dp_start_block_idx_ = 0; sk_start_block_idx_ = this->dp_tiles_; } -template +template CK_TILE_HOST auto -StreamKTilePartitioner_v2::grid_size() const noexcept - -> dim3 +StreamKTilePartitioner_v2::grid_size() + const noexcept -> dim3 { return dim3(dp_ctas_ + this->get_sk_ctas(), 1, 1); } -template +template CK_TILE_HOST_DEVICE index_t -StreamKTilePartitioner_v2::get_dp_ctas() +StreamKTilePartitioner_v2::get_dp_ctas() const noexcept { return dp_ctas_; } -template +template CK_TILE_HOST_DEVICE index_t -StreamKTilePartitioner_v2::get_dp_start_block_idx() - const noexcept +StreamKTilePartitioner_v2:: + get_dp_start_block_idx() const noexcept { return dp_start_block_idx_; } -template +template CK_TILE_HOST_DEVICE index_t -StreamKTilePartitioner_v2::get_sk_start_block_idx() - const noexcept +StreamKTilePartitioner_v2:: + get_sk_start_block_idx() const noexcept { return sk_start_block_idx_; } diff --git a/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner.cpp b/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner.cpp index 968fadda51..89d72d844b 100644 --- a/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner.cpp +++ b/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner.cpp @@ -86,20 +86,24 @@ TEST(StreamKTilePartitionerBaseGetLocalIter, GetLocalIter) StreamKTilePartitionerBaseMethodId::GET_LOCAL_ITER>; // Test parameters - ck_tile::DeviceMem local_iter_dev(sizeof(ck_tile::index_t)); - ck_tile::index_t iter = 3; - ck_tile::index_t tile_iter = 2; + ck_tile::DeviceMem local_iter_start_dev(sizeof(ck_tile::index_t)); + ck_tile::index_t iter_start = 3; + ck_tile::index_t tile_iter_start = 2; // Launch kernel - auto kargs = Kernel::MakeKernelArgs( - iter, tile_iter, Config::UNUSED, local_iter_dev.GetDeviceBuffer(), nullptr, Config::UNUSED); + auto kargs = Kernel::MakeKernelArgs(iter_start, + tile_iter_start, + Config::UNUSED, + local_iter_start_dev.GetDeviceBuffer(), + nullptr, + Config::UNUSED); ck_tile::launch_kernel(ck_tile::stream_config{nullptr, false, 0, 0, 1}, ck_tile::make_kernel<1>(Kernel{}, 1, 1, 0, kargs)); // Validate result - ck_tile::index_t local_iter; - local_iter_dev.FromDevice(&local_iter); - EXPECT_EQ(local_iter, iter - tile_iter); + ck_tile::index_t local_iter_start; + local_iter_start_dev.FromDevice(&local_iter_start); + EXPECT_EQ(local_iter_start, iter_start - tile_iter_start); } TEST(StreamKTilePartitionerBaseGetLocalIterEnd, MinIsTileIterEnd) @@ -111,12 +115,12 @@ TEST(StreamKTilePartitionerBaseGetLocalIterEnd, MinIsTileIterEnd) StreamKTilePartitionerBaseMethodId::GET_LOCAL_ITER_END>; // Test parameters ck_tile::DeviceMem local_iter_end_dev(sizeof(ck_tile::index_t)); - ck_tile::index_t tile_iter = 6; - ck_tile::index_t iter_end = 9; - ck_tile::index_t tile_iter_end = 8; + ck_tile::index_t tile_iter_start = 6; + ck_tile::index_t iter_end = 9; + ck_tile::index_t tile_iter_end = 8; // Launch kernel - auto kargs = Kernel::MakeKernelArgs(tile_iter, + auto kargs = Kernel::MakeKernelArgs(tile_iter_start, iter_end, tile_iter_end, local_iter_end_dev.GetDeviceBuffer(), @@ -128,13 +132,13 @@ TEST(StreamKTilePartitionerBaseGetLocalIterEnd, MinIsTileIterEnd) // Validate results ck_tile::index_t local_iter_end; local_iter_end_dev.FromDevice(&local_iter_end); - EXPECT_EQ(local_iter_end, tile_iter_end - tile_iter); + EXPECT_EQ(local_iter_end, tile_iter_end - tile_iter_start); } TEST(StreamKTilePartitionerBaseGetLocalIterEnd, MinIsIterEnd) { // Types - // Note: For this test, the Config is used for types only, the function get_locatl_iter_end is + // Note: For this test, the Config is used for types only, the function get_local_iter_end is // static; thus, the test parameters are independent of the Config in this case. using Config = StreamKTilePartitionerBaseConfigDP2TileSK; using TilePartitioner = ck_tile::StreamKTilePartitionerBase; @@ -142,12 +146,12 @@ TEST(StreamKTilePartitionerBaseGetLocalIterEnd, MinIsIterEnd) StreamKTilePartitionerBaseMethodId::GET_LOCAL_ITER_END>; // Test parameters ck_tile::DeviceMem local_iter_end_dev(sizeof(ck_tile::index_t)); - ck_tile::index_t tile_iter = 12; - ck_tile::index_t iter_end = 13; - ck_tile::index_t tile_iter_end = 14; + ck_tile::index_t tile_iter_start = 12; + ck_tile::index_t iter_end = 13; + ck_tile::index_t tile_iter_end = 14; // Launch kernel - auto kargs = Kernel::MakeKernelArgs(tile_iter, + auto kargs = Kernel::MakeKernelArgs(tile_iter_start, iter_end, tile_iter_end, local_iter_end_dev.GetDeviceBuffer(), @@ -159,7 +163,7 @@ TEST(StreamKTilePartitionerBaseGetLocalIterEnd, MinIsIterEnd) // Validate results ck_tile::index_t local_iter_end; local_iter_end_dev.FromDevice(&local_iter_end); - EXPECT_EQ(local_iter_end, iter_end - tile_iter); + EXPECT_EQ(local_iter_end, iter_end - tile_iter_start); } TEST(StreamKTilePartitionerBaseGetTileBoundaries, GetTileBoundaries) @@ -174,7 +178,7 @@ TEST(StreamKTilePartitionerBaseGetTileBoundaries, GetTileBoundaries) // Test parameters ck_tile::StreamKTilePartitionerBase tile_partitioner{ Config::M, Config::N, Config::K, Config::GRID}; - ck_tile::DeviceMem tile_iter_dev(sizeof(ck_tile::index_t)); + ck_tile::DeviceMem tile_iter_start_dev(sizeof(ck_tile::index_t)); ck_tile::DeviceMem tile_iter_end_dev(sizeof(ck_tile::index_t)); ck_tile::index_t tile_idx = 1; @@ -182,19 +186,19 @@ TEST(StreamKTilePartitionerBaseGetTileBoundaries, GetTileBoundaries) auto kargs = Kernel::MakeKernelArgs(Config::PLACEHOLDER, Config::PLACEHOLDER, tile_idx, - tile_iter_dev.GetDeviceBuffer(), + tile_iter_start_dev.GetDeviceBuffer(), tile_iter_end_dev.GetDeviceBuffer(), tile_partitioner); ck_tile::launch_kernel(ck_tile::stream_config{nullptr, false, 0, 0, 1}, ck_tile::make_kernel<1>(Kernel{}, 1, 1, 0, kargs)); // Validate results - ck_tile::index_t tile_iter, tile_iter_end; - tile_iter_dev.FromDevice(&tile_iter); + ck_tile::index_t tile_iter_start, tile_iter_end; + tile_iter_start_dev.FromDevice(&tile_iter_start); tile_iter_end_dev.FromDevice(&tile_iter_end); // There are 2 iters per tile. Thus, for tile_idx 1, we expect 2 and 4 to be the start and end, // respectively. - EXPECT_EQ(tile_iter, 2); + EXPECT_EQ(tile_iter_start, 2); EXPECT_EQ(tile_iter_end, 4); } @@ -210,10 +214,10 @@ TEST(StreamKTilePartitionerBaseGetTileIndex, GetTileIndex) ck_tile::StreamKTilePartitionerBase tile_partitioner{ Config::M, Config::N, Config::K, Config::GRID}; ck_tile::DeviceMem tile_idx_dev(sizeof(ck_tile::index_t)); - ck_tile::index_t iter = 8; + ck_tile::index_t iter_start = 8; // Launch kernel - auto kargs = Kernel::MakeKernelArgs(iter, + auto kargs = Kernel::MakeKernelArgs(iter_start, Config::UNUSED, Config::UNUSED, tile_idx_dev.GetDeviceBuffer(), @@ -241,7 +245,7 @@ TEST(StreamKTilePartitionerBaseGetIterBoundaries, ZeroExtraItersBeforeMe) // Test parameters ck_tile::StreamKTilePartitionerBase tile_partitioner{ Config::M, Config::N, Config::K, Config::GRID}; - ck_tile::DeviceMem iter_dev(sizeof(ck_tile::index_t)); + ck_tile::DeviceMem iter_start_dev(sizeof(ck_tile::index_t)); ck_tile::DeviceMem iter_end_dev(sizeof(ck_tile::index_t)); ck_tile::index_t cta_idx = 0; @@ -249,17 +253,17 @@ TEST(StreamKTilePartitionerBaseGetIterBoundaries, ZeroExtraItersBeforeMe) auto kargs = Kernel::MakeKernelArgs(Config::PLACEHOLDER, Config::PLACEHOLDER, cta_idx, - iter_dev.GetDeviceBuffer(), + iter_start_dev.GetDeviceBuffer(), iter_end_dev.GetDeviceBuffer(), tile_partitioner); ck_tile::launch_kernel(ck_tile::stream_config{nullptr, false, 0, 0, 1}, ck_tile::make_kernel<1>(Kernel{}, 1, 1, 0, kargs)); // Validate results - ck_tile::index_t iter, iter_end; - iter_dev.FromDevice(&iter); + ck_tile::index_t iter_start, iter_end; + iter_start_dev.FromDevice(&iter_start); iter_end_dev.FromDevice(&iter_end); - EXPECT_EQ(iter, 6); + EXPECT_EQ(iter_start, 6); EXPECT_EQ(iter_end, 9); } @@ -275,7 +279,7 @@ TEST(StreamKTilePartitionerBaseGetIterBoundaries, NonZeroExtraItersBeforeMe) // Test parameters ck_tile::StreamKTilePartitionerBase tile_partitioner{ Config::M, Config::N, Config::K, Config::GRID}; - ck_tile::DeviceMem iter_dev(sizeof(ck_tile::index_t)); + ck_tile::DeviceMem iter_start_dev(sizeof(ck_tile::index_t)); ck_tile::DeviceMem iter_end_dev(sizeof(ck_tile::index_t)); ck_tile::index_t cta_idx = 1; @@ -283,17 +287,17 @@ TEST(StreamKTilePartitionerBaseGetIterBoundaries, NonZeroExtraItersBeforeMe) auto kargs = Kernel::MakeKernelArgs(Config::PLACEHOLDER, Config::PLACEHOLDER, cta_idx, - iter_dev.GetDeviceBuffer(), + iter_start_dev.GetDeviceBuffer(), iter_end_dev.GetDeviceBuffer(), tile_partitioner); ck_tile::launch_kernel(ck_tile::stream_config{nullptr, false, 0, 0, 1}, ck_tile::make_kernel<1>(Kernel{}, 1, 1, 0, kargs)); // Validate results - ck_tile::index_t iter, iter_end; - iter_dev.FromDevice(&iter); + ck_tile::index_t iter_start, iter_end; + iter_start_dev.FromDevice(&iter_start); iter_end_dev.FromDevice(&iter_end); - EXPECT_EQ(iter, 9); + EXPECT_EQ(iter_start, 9); EXPECT_EQ(iter_end, 12); } @@ -309,7 +313,7 @@ TEST(StreamKTilePartitionerBaseGetIterBoundaries, MinIsExtraIters) // Test parameters ck_tile::StreamKTilePartitionerBase tile_partitioner{ Config::M, Config::N, Config::K, Config::GRID}; - ck_tile::DeviceMem iter_dev(sizeof(ck_tile::index_t)); + ck_tile::DeviceMem iter_start_dev(sizeof(ck_tile::index_t)); ck_tile::DeviceMem iter_end_dev(sizeof(ck_tile::index_t)); ck_tile::index_t cta_idx = 2; @@ -317,17 +321,17 @@ TEST(StreamKTilePartitionerBaseGetIterBoundaries, MinIsExtraIters) auto kargs = Kernel::MakeKernelArgs(Config::PLACEHOLDER, Config::PLACEHOLDER, cta_idx, - iter_dev.GetDeviceBuffer(), + iter_start_dev.GetDeviceBuffer(), iter_end_dev.GetDeviceBuffer(), tile_partitioner); ck_tile::launch_kernel(ck_tile::stream_config{nullptr, false, 0, 0, 1}, ck_tile::make_kernel<1>(Kernel{}, 1, 1, 0, kargs)); // Validate results - ck_tile::index_t iter, iter_end; - iter_dev.FromDevice(&iter); + ck_tile::index_t iter_start, iter_end; + iter_start_dev.FromDevice(&iter_start); iter_end_dev.FromDevice(&iter_end); - EXPECT_EQ(iter, 12); + EXPECT_EQ(iter_start, 12); EXPECT_EQ(iter_end, 14); } diff --git a/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner_common.hpp b/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner_common.hpp index 03f149f6b6..4fc654a7ea 100644 --- a/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner_common.hpp +++ b/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner_common.hpp @@ -45,9 +45,7 @@ struct KernelWrapper // Specialized derived class to support unique operator() functions. There is one template // specialization per member in the StreamKTilePartitionerBaseMethodId enum. template -struct KernelWrapperSpecialized : public KernelWrapper<> -{ -}; +struct KernelWrapperSpecialized; template struct KernelWrapperSpecialized