From 1b7c5502e2116a829ca80e9c6c839eba68a32b48 Mon Sep 17 00:00:00 2001 From: "assistant-librarian[bot]" Date: Thu, 16 Oct 2025 16:13:55 +0000 Subject: [PATCH] Merge commit 'd7278cc664c20613e0b7c45f249f6e7613550ca2' into develop --- .github/workflows/therock-ci-linux.yml | 4 +- .../ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 2 +- include/ck_tile/ops/gemm.hpp | 1 + .../kernel/streamk_gemm_tile_partitioner.hpp | 327 ++++++++++++ .../streamk_gemm_tile_partitioner_impl.hpp | 312 +++++++++++ test/ck_tile/gemm_streamk/CMakeLists.txt | 3 +- .../test_streamk_tile_partitioner.cpp | 498 ++++++++++++++++++ .../test_streamk_tile_partitioner_common.hpp | 339 ++++++++++++ 8 files changed, 1482 insertions(+), 4 deletions(-) create mode 100644 include/ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner.hpp create mode 100644 include/ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner_impl.hpp create mode 100644 test/ck_tile/gemm_streamk/test_streamk_tile_partitioner.cpp create mode 100644 test/ck_tile/gemm_streamk/test_streamk_tile_partitioner_common.hpp diff --git a/.github/workflows/therock-ci-linux.yml b/.github/workflows/therock-ci-linux.yml index 271c6376ca..beaabbe763 100644 --- a/.github/workflows/therock-ci-linux.yml +++ b/.github/workflows/therock-ci-linux.yml @@ -20,7 +20,7 @@ jobs: permissions: id-token: write container: - image: ghcr.io/rocm/therock_build_manylinux_x86_64@sha256:044b113562629f4bd2ec5d2e64b32eee11562d48fb1a75d7493daec9dd8d8292 + image: ghcr.io/rocm/therock_build_manylinux_x86_64@sha256:2f3ebd0beb04c449fdb36933e54bdc69483b914fb9005594d3fc9444c206b54b options: -v /runner/config:/home/awsconfig/ env: AMDGPU_FAMILIES: ${{ inputs.amdgpu_families }} @@ -44,7 +44,7 @@ jobs: uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: repository: "ROCm/TheRock" - ref: dc05d637054ad197c84b00e24b6262af0ec797c6 # 10-03-2025 commit + ref: c2921b151b8285a1d29942aceb33cfe0fea77ac9 # 10-15-2025 commit path: "TheRock" - name: Setup ccache diff --git 
a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 533f7f2f23..f898d5f7b2 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -608,7 +608,7 @@ class KernelComponentFactory: else: pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) - if (hdim, hdim_v) in [(64, 64), (128, 128)] and logits == "f" and bias == "no" and dropout == "f" and skip == "f": + if (hdim, hdim_v) in [(64, 64), (128, 128)] and logits == "f" and bias == "no" and dropout == "f" and lse == "f" and skip == "f": pipelines.append(FmhaFwdPipeline('qr_async_trload', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 't')) pipelines.append(FmhaFwdPipeline('qr_async_trload', 'row', 'f', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 't')) if receipt == 1 and bias != "bias": diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index 2a4f9d21e3..6b587f81d5 100644 --- a/include/ck_tile/ops/gemm.hpp +++ b/include/ck_tile/ops/gemm.hpp @@ -33,6 +33,7 @@ #include "ck_tile/ops/gemm/kernel/gemm_multi_abd_kernel.hpp" #include "ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp" #include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp" +#include "ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner.hpp" #include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp" #include "ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp" #include "ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp" diff --git a/include/ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner.hpp b/include/ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner.hpp new file mode 100644 index 0000000000..1962f3518a --- /dev/null +++ 
b/include/ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner.hpp @@ -0,0 +1,327 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" + +namespace ck_tile { + +/** + * @brief Stream-K tile partitioner base class. + * + * This partitioner is responsible for mapping workgroups to tiles in the C tensor + * for the Stream-K algorithm. + * + * @tparam BlockGemmShapeType A class providing basic GEMM parameters. + * @tparam ReductionStrategyType An enum that defines the reduction strategy for the results in + * the C Tensor. + */ +template +struct StreamKTilePartitionerBase +{ + using BlockGemmShape = BlockGemmShapeType; + + static constexpr index_t MPerBlock = BlockGemmShape::kM; + static constexpr index_t NPerBlock = BlockGemmShape::kN; + static constexpr index_t KPerBlock = BlockGemmShape::kK; + static constexpr StreamKReductionStrategy ReductionStrategy = ReductionStrategyType; + + StreamKTilePartitionerBase(index_t m, index_t n, index_t k, index_t grid); + + private: + /** + * @brief Calculates the total space needed for the partials buffer. + * + * @param acc_element_bytes The number of bytes for the accumulator data type used in the GEMM. + * @return index_t The number of bytes needed for the partials buffer. + */ + CK_TILE_HOST index_t get_partials_buffer_size(index_t acc_element_bytes) const noexcept; + + /** + * @brief Calculates the total space needed for the flags buffer. + * + * @return index_t The number of bytes needed for the flags buffer. + */ + CK_TILE_HOST index_t get_flags_buffer_size() const noexcept; + + public: + /** + * @brief Calculates the start and end iteration given the cta_idx. + * + * @param iter_start Reference to an index_t; will be set to the starting iteration by the + * function. + * @param iter_end Reference to an index_t; will be set to the non-inclusive end iteration by + * the function. 
+ * @param cta_idx The current Stream-K workgroup's index. + * @note It is assumed that the first Stream-K workgroup has a `cta_idx` of zero. If a + * non-persistent DP section is used, then a Stream-K workgroup's `cta_idx` should be something + * like `blockIdx.x` minus number of DP workgroups. + */ + CK_TILE_DEVICE void + get_iter_boundaries(index_t& iter_start, index_t& iter_end, index_t cta_idx) const noexcept; + + /** + * @brief Calculates the 1D tile index in the C tensor for a workgroup. + * + * @param iter_start The starting iteration. + * @return index_t The 1D tile index. + */ + CK_TILE_DEVICE index_t get_tile_index(index_t iter_start) const noexcept; + + /** + * @brief Calculates the starting and ending tile boundaries for the given 1D tile index. + * + * @param tile_iter_start Reference to an index_t; will be set to the tile's start iteration by + * the function. + * @param tile_iter_end Reference to an index_t; will be set to the non-inclusive tile's end + * iteration by the function. + * @param tile_idx The 1D C tensor tile index for the workgroup. + */ + CK_TILE_DEVICE void get_tile_boundaries(index_t& tile_iter_start, + index_t& tile_iter_end, + index_t tile_idx) const noexcept; + + /** + * @brief Calculates the workgroup's starting iteration that is local to a tile. + * + * @param iter_start The starting iteration. + * @param tile_iter_start The starting iteration of the tile (i.e., the tile's starting + * boundary). + * @return index_t The local starting iteration. The value is in range [0, `iters_per_tile_`). + * @note Assumes `iter_start` >= `tile_iter_start`. + */ + CK_TILE_DEVICE static index_t get_local_iter(index_t iter_start, + index_t tile_iter_start) noexcept; + + /** + * @brief Calculates the workgroup's non-inclusive end iteration that is local to a tile. + * + * @param tile_iter_start The starting tile iteration. + * @param iter_end The non-inclusive end iteration. + * @param tile_iter_end The non-inclusive end iteration of the tile. 
+ * @return index_t The local non-inclusive end iteration. + * @note Assumes `iter_end` >= `tile_iter_start` and `tile_iter_end` >= `tile_iter_start`. + */ + CK_TILE_DEVICE static index_t + get_local_iter_end(index_t tile_iter_start, index_t iter_end, index_t tile_iter_end) noexcept; + + /** + * @brief Calculates the workgroup's 2D tile index in the C tensor given the 1D tile index. + * + * @param tile_idx The 1D tile index in the C tensor for the workgroup. + * @return tuple The corresponding 2D tile index in the C tensor for the workgroup. + */ + CK_TILE_DEVICE auto + get_output_tile_index(index_t tile_idx) const noexcept -> tuple; + + /** + * @brief Calculates the total space needed for the partials and flags buffers. + * + * @param acc_element_bytes The number of bytes for the accumulator data type used in the GEMM. + * @return index_t The number of bytes needed for the partials and flags buffers. + */ + CK_TILE_HOST index_t get_workspace_size(index_t acc_element_bytes) const noexcept; + + /** + * @brief Returns the number of macro tiles in the C tensor. + */ + CK_TILE_HOST_DEVICE index_t get_num_tiles() const noexcept; + + /** + * @brief Returns the maximum number of active workgroups; this is assumed to be number of CUs * + * occupancy. + */ + CK_TILE_HOST_DEVICE index_t get_grid() const noexcept; + + /** + * @brief Returns the number of tiles in the C tensor that will use the data-parallel (DP) + * approach. + */ + CK_TILE_HOST_DEVICE index_t get_dp_tiles() const noexcept; + + /** + * @brief Returns the number of tiles in the C tensor that will use the Stream-K approach. + */ + CK_TILE_HOST_DEVICE index_t get_sk_tiles() const noexcept; + + /** + * @brief Returns the number of workgroups that will participate in Stream-K in the `sk_tiles_`. + */ + CK_TILE_HOST_DEVICE index_t get_sk_ctas() const noexcept; + + /** + * @brief Returns the total number of Stream-K iterations.
+ */ + CK_TILE_HOST_DEVICE index_t get_total_sk_iters() const noexcept; + + /** + * @brief Returns the total number of iterations per tile in the C tensor. In other words, this + * is the total number of macro tiles along the K dimension of A and B. + */ + CK_TILE_HOST_DEVICE index_t get_iters_per_tile() const noexcept; + + /** + * @brief Returns the total number of Stream-K iterations for each `sk_cta`. This is the lower + * bound (i.e., all `sk_ctas_` are guaranteed to perform at least this many iterations). + */ + CK_TILE_HOST_DEVICE index_t get_iters_per_sk_cta() const noexcept; + + /** + * @brief Returns the remainder resulting from `total_sk_iters_` divided by `sk_ctas_`. When + * this is non-zero, the first `extra_iters_` `sk_ctas_` will get one additional iteration + * assigned to them; such work groups will perform (`iters_per_sk_cta_` + 1) iterations. + */ + CK_TILE_HOST_DEVICE index_t get_extra_iters() const noexcept; + + /** + * @brief Returns the total number of DP iterations. + */ + CK_TILE_HOST_DEVICE index_t get_total_dp_iters() const noexcept; + + /** + * @brief Returns the n dimension for the GEMM problem. + */ + CK_TILE_HOST_DEVICE index_t get_n() const noexcept; + + protected: + index_t num_tiles_; + index_t grid_; + index_t dp_tiles_; + + private: + /** + * @brief The number of full tiles assigned to each `sk_cta` when performing DP + 2 Tile + * Stream-K. + */ + index_t full_tiles_ = 1; + index_t sk_tiles_; + index_t sk_ctas_; + index_t total_sk_iters_; + index_t iters_per_tile_; + index_t iters_per_sk_cta_; + index_t extra_iters_; + index_t total_dp_iters_; + index_t n_; +}; + +/** + * @brief Template for the Stream-K tile partitioner derived struct. + * + * This partitioner is responsible for mapping workgroups to tiles in the C tensor + * for the Stream-K algorithm. This struct is derived from + * StreamKTilePartitionerBase. Behavior of the + * StreamKTilePartitioner based on persistency will be in the template specializations. 
+ * + * @tparam BlockGemmShapeType A class providing basic GEMM parameters. + * @tparam ReductionStrategyType An enum that defines the reduction strategy for the results in + * the C Tensor. + * @tparam Persistent A bool that indicates whether to use a Persistent approach + */ +template +struct StreamKTilePartitioner_v2; + +/** + * @brief Persistent Stream-K tile partitioner derived struct. + * + * This partitioner is responsible for mapping workgroups to tiles in the C tensor + * for the Stream-K algorithm when using a Persistent approach where no extra workgroups + * are allocated for data parallel. + * + * @tparam BlockGemmShapeType A class providing basic GEMM parameters. + * @tparam ReductionStrategyType An enum that defines the reduction strategy for the results in + * the C Tensor. + */ +template +struct StreamKTilePartitioner_v2 + : StreamKTilePartitionerBase +{ + StreamKTilePartitioner_v2(ck_tile::index_t m, + ck_tile::index_t n, + ck_tile::index_t k, + ck_tile::index_t grid); + + public: + /** + * @brief Calculates the launching grid size for the Stream-K kernel. In the Persistent + * case, no extra workgroups are allocated for the data parallel section, making the grid + * size num_cu * occupancy. + * + * @return dim3 The launching grid size for the kernel. + */ + CK_TILE_HOST auto grid_size() const noexcept -> dim3; + + /** + * @brief Returns the total number of DP tiles per workgroup. + */ + CK_TILE_HOST_DEVICE index_t get_dp_tiles_per_cta() const noexcept; + + /** + * @brief Returns the total number of DP tiles left over when `dp_tiles_` is not evenly + * divisible by `grid_`. + */ + CK_TILE_HOST_DEVICE index_t get_extra_dp_tiles() const noexcept; + + protected: + index_t dp_tiles_per_cta_; + index_t extra_dp_tiles_; +}; + +/** + * @brief Non-Persistent Stream-K tile partitioner derived struct.
+ * + * This partitioner is responsible for mapping workgroups to tiles in the C tensor + * for the Stream-K algorithm when using a Non-Persistent approach where extra workgroups + * are allocated for the data parallel section. + * + * @tparam BlockGemmShapeType A class providing basic GEMM parameters. + * @tparam ReductionStrategyType An enum that defines the reduction strategy for the results in + * the C Tensor. + */ +template +struct StreamKTilePartitioner_v2 + : StreamKTilePartitionerBase +{ + StreamKTilePartitioner_v2(ck_tile::index_t m, + ck_tile::index_t n, + ck_tile::index_t k, + ck_tile::index_t grid); + + public: + /** + * @brief Calculates the launching grid size for the Stream-K kernel. In the Non-Persistent + * case, extra workgroups are allocated for the data parallel section, making the grid + * size the total number of Stream-K and data parallel workgroups. + * + * @return dim3 The launching grid size for the kernel. + */ + CK_TILE_HOST auto grid_size() const noexcept -> dim3; + + /** + * @brief Returns the total number of DP workgroups. + */ + CK_TILE_HOST_DEVICE index_t get_dp_ctas() const noexcept; + + /** + * @brief Returns the starting DP workgroup index. It is always zero. + */ + CK_TILE_HOST_DEVICE index_t get_dp_start_block_idx() const noexcept; + + /** + * @brief The index that starts the Stream-K workgroups. It is set to the number of `dp_tiles_`.
+ */ + CK_TILE_HOST_DEVICE index_t get_sk_start_block_idx() const noexcept; + + protected: + index_t dp_ctas_; + index_t dp_start_block_idx_; + index_t sk_start_block_idx_; +}; + +} // namespace ck_tile + +#include "streamk_gemm_tile_partitioner_impl.hpp" diff --git a/include/ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner_impl.hpp b/include/ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner_impl.hpp new file mode 100644 index 0000000000..0dba775182 --- /dev/null +++ b/include/ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner_impl.hpp @@ -0,0 +1,312 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +#pragma once +#include "streamk_gemm_tile_partitioner.hpp" +namespace ck_tile { + +template +StreamKTilePartitionerBase::StreamKTilePartitionerBase( + index_t m, index_t n, index_t k, index_t grid) + : grid_{grid}, n_{n} +{ + iters_per_tile_ = integer_divide_ceil(k, KPerBlock); + num_tiles_ = integer_divide_ceil(m, MPerBlock) * integer_divide_ceil(n_, NPerBlock); + + bool big_enough = num_tiles_ > grid_; + index_t remainder_tiles = num_tiles_ % grid_; + + if(remainder_tiles) + { + sk_tiles_ = big_enough ? full_tiles_ * grid_ + (num_tiles_ % grid_) : num_tiles_; + sk_tiles_ = min(num_tiles_, sk_tiles_); + sk_ctas_ = grid_; + total_sk_iters_ = sk_tiles_ * iters_per_tile_; + + // If there still isn't enough work to saturate all CUs, then just revert to DP only. + if(total_sk_iters_ < grid_) + { + sk_tiles_ = 0; + sk_ctas_ = 0; + total_sk_iters_ = 0; + } + } + else // Full DP (i.e., no Stream-K) + { + sk_tiles_ = 0; + sk_ctas_ = 0; + total_sk_iters_ = 0; + } + + iters_per_sk_cta_ = sk_ctas_ ? total_sk_iters_ / sk_ctas_ : 0; + extra_iters_ = sk_ctas_ ? 
total_sk_iters_ % sk_ctas_ : 0; + + dp_tiles_ = num_tiles_ - sk_tiles_; + total_dp_iters_ = dp_tiles_ * iters_per_tile_; +} + +template +CK_TILE_HOST index_t +StreamKTilePartitionerBase::get_partials_buffer_size( + index_t acc_element_bytes) const noexcept +{ + return MPerBlock * NPerBlock * acc_element_bytes * sk_ctas_; +} + +template +CK_TILE_HOST index_t +StreamKTilePartitionerBase::get_flags_buffer_size() + const noexcept +{ + return sizeof(index_t) * sk_ctas_; +} + +template +CK_TILE_DEVICE void +StreamKTilePartitionerBase::get_iter_boundaries( + index_t& iter, index_t& iter_end, index_t cta_idx) const noexcept +{ + index_t extra_iters_before_me = ck_tile::min(cta_idx, extra_iters_); + iter = total_dp_iters_ + cta_idx * iters_per_sk_cta_ + extra_iters_before_me; + iter_end = iter + iters_per_sk_cta_ + (cta_idx < extra_iters_); +} + +template +CK_TILE_DEVICE index_t +StreamKTilePartitionerBase::get_tile_index( + index_t iter) const noexcept +{ + return iter / iters_per_tile_; +} + +template +CK_TILE_DEVICE void +StreamKTilePartitionerBase::get_tile_boundaries( + index_t& tile_iter, index_t& tile_iter_end, index_t tile_idx) const noexcept +{ + tile_iter = tile_idx * iters_per_tile_; + tile_iter_end = tile_iter + iters_per_tile_; +} + +template +CK_TILE_DEVICE /* static */ index_t +StreamKTilePartitionerBase::get_local_iter( + index_t iter, index_t tile_iter) noexcept +{ + return iter - tile_iter; +} + +template +CK_TILE_DEVICE /* static */ index_t +StreamKTilePartitionerBase::get_local_iter_end( + index_t tile_iter, index_t iter_end, index_t tile_iter_end) noexcept +{ + return ck_tile::min(iter_end, tile_iter_end) - tile_iter; +} + +template +CK_TILE_DEVICE auto +StreamKTilePartitionerBase::get_output_tile_index( + index_t tile_idx) const noexcept -> tuple +{ + const index_t n_macro_tiles = integer_divide_ceil(n_, NPerBlock); + + const index_t im = amd_wave_read_first_lane(tile_idx / n_macro_tiles); + const index_t in = amd_wave_read_first_lane(tile_idx - im * 
n_macro_tiles); + return make_tuple(im, in); +} + +template +CK_TILE_HOST index_t +StreamKTilePartitionerBase::get_workspace_size( + index_t acc_element_bytes) const noexcept +{ + if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction) + { + + return get_partials_buffer_size(acc_element_bytes) + get_flags_buffer_size(); + } + else // ReductionStrategy is Atomics + { + return 0; + } +} + +template +CK_TILE_HOST_DEVICE index_t +StreamKTilePartitionerBase::get_num_tiles() + const noexcept +{ + return num_tiles_; +} + +template +CK_TILE_HOST_DEVICE index_t +StreamKTilePartitionerBase::get_grid() const noexcept +{ + return grid_; +} + +template +CK_TILE_HOST_DEVICE index_t +StreamKTilePartitionerBase::get_dp_tiles() const noexcept +{ + return dp_tiles_; +} + +template +CK_TILE_HOST_DEVICE index_t +StreamKTilePartitionerBase::get_sk_tiles() const noexcept +{ + return sk_tiles_; +} + +template +CK_TILE_HOST_DEVICE index_t +StreamKTilePartitionerBase::get_sk_ctas() const noexcept +{ + return sk_ctas_; +} + +template +CK_TILE_HOST_DEVICE index_t +StreamKTilePartitionerBase::get_total_sk_iters() + const noexcept +{ + return total_sk_iters_; +} + +template +CK_TILE_HOST_DEVICE index_t +StreamKTilePartitionerBase::get_iters_per_tile() + const noexcept +{ + return iters_per_tile_; +} + +template +CK_TILE_HOST_DEVICE index_t +StreamKTilePartitionerBase::get_iters_per_sk_cta() + const noexcept +{ + return iters_per_sk_cta_; +} + +template +CK_TILE_HOST_DEVICE index_t +StreamKTilePartitionerBase::get_extra_iters() + const noexcept +{ + return extra_iters_; +} + +template +CK_TILE_HOST_DEVICE index_t +StreamKTilePartitionerBase::get_total_dp_iters() + const noexcept +{ + return total_dp_iters_; +} + +template +CK_TILE_HOST_DEVICE index_t +StreamKTilePartitionerBase::get_n() const noexcept +{ + return n_; +} + +template +struct StreamKTilePartitioner_v2; + +// child class for Persistent Tile Partitioner +template +StreamKTilePartitioner_v2:: + 
StreamKTilePartitioner_v2(ck_tile::index_t m, + ck_tile::index_t n, + ck_tile::index_t k, + ck_tile::index_t grid) + : StreamKTilePartitionerBase(m, n, k, grid) +{ // inherit from base constructor + dp_tiles_per_cta_ = this->dp_tiles_ / this->grid_; + extra_dp_tiles_ = this->dp_tiles_ % this->grid_; +} + +template +CK_TILE_HOST auto +StreamKTilePartitioner_v2::grid_size() + const noexcept -> dim3 +{ + if(extra_dp_tiles_ == 0) + { + return dim3(this->grid_, 1, 1); + } + else + { + return dim3(this->num_tiles_, 1, 1); + } +} + +template +CK_TILE_HOST_DEVICE index_t +StreamKTilePartitioner_v2::get_dp_tiles_per_cta() + const noexcept +{ + return dp_tiles_per_cta_; +} + +template +CK_TILE_HOST_DEVICE index_t +StreamKTilePartitioner_v2::get_extra_dp_tiles() + const noexcept +{ + return extra_dp_tiles_; +} + +// child class for Non-Persistent Tile Partitioner +template +StreamKTilePartitioner_v2:: + StreamKTilePartitioner_v2(ck_tile::index_t m, + ck_tile::index_t n, + ck_tile::index_t k, + ck_tile::index_t grid) + : StreamKTilePartitionerBase(m, n, k, grid) +{ // inherit from base constructor + dp_ctas_ = this->dp_tiles_; + dp_start_block_idx_ = 0; + sk_start_block_idx_ = this->dp_tiles_; +} + +template +CK_TILE_HOST auto +StreamKTilePartitioner_v2::grid_size() + const noexcept -> dim3 +{ + return dim3(dp_ctas_ + this->get_sk_ctas(), 1, 1); +} + +template +CK_TILE_HOST_DEVICE index_t +StreamKTilePartitioner_v2::get_dp_ctas() + const noexcept +{ + return dp_ctas_; +} + +template +CK_TILE_HOST_DEVICE index_t +StreamKTilePartitioner_v2:: + get_dp_start_block_idx() const noexcept +{ + return dp_start_block_idx_; +} + +template +CK_TILE_HOST_DEVICE index_t +StreamKTilePartitioner_v2:: + get_sk_start_block_idx() const noexcept +{ + return sk_start_block_idx_; +} + +} // namespace ck_tile diff --git a/test/ck_tile/gemm_streamk/CMakeLists.txt b/test/ck_tile/gemm_streamk/CMakeLists.txt index ec5d56d46d..331118da59 100644 --- a/test/ck_tile/gemm_streamk/CMakeLists.txt +++ 
b/test/ck_tile/gemm_streamk/CMakeLists.txt @@ -4,7 +4,7 @@ if(GPU_TARGETS MATCHES "gfx9") include_directories(BEFORE ${CMAKE_CURRENT_SOURCE_DIR}) #TODO: support all arches - #TODO: current stream-k c-shuffle only supports C layout as R + #TODO: current c-shuffle only supports C layout as R add_gtest_executable(test_ck_tile_streamk_smoke ${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f16_rrr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp #${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f16_rrc_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp @@ -116,6 +116,7 @@ if(GPU_TARGETS MATCHES "gfx9") # ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/mem/bf16_ccr_mem_128x128x32_2x2x1_32x32x16_NonPersistent.cpp # #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/mem/bf16_ccc_mem_128x128x32_2x2x1_32x32x16_NonPersistent.cpp # ) + add_gtest_executable(test_ck_tile_streamk_tile_partitioner test_streamk_tile_partitioner.cpp) else() message(DEBUG "Skipping test_ck_tile_streamk tests for current target") endif() diff --git a/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner.cpp b/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner.cpp new file mode 100644 index 0000000000..89d72d844b --- /dev/null +++ b/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner.cpp @@ -0,0 +1,498 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "test_streamk_tile_partitioner_common.hpp" + +TEST(StreamKTilePartitionerBaseConstructor, SKOnly) +{ + using Config = StreamKTilePartitionerBaseConfigSKOnly; + + ck_tile::StreamKTilePartitionerBase tile_partitioner{ + Config::M, Config::N, Config::K, Config::GRID}; + + StreamKTilePartitionerBaseExpected expected_values{ + 2, 0, 3, 4, 1, 2, 1, 0, 2, Config::GRID, Config::N}; + validate_streamk_base_constructor(expected_values, tile_partitioner); +} + +TEST(StreamKTilePartitionerBaseConstructor, DPOnly) +{ + using Config = StreamKTilePartitionerBaseConfigDPOnly; + + ck_tile::StreamKTilePartitionerBase tile_partitioner{ + Config::M, Config::N, Config::K, Config::GRID}; + + StreamKTilePartitionerBaseExpected expected_values{ + 0, 6, 0, 0, 0, 2, 0, 12, 6, Config::GRID, Config::N}; + validate_streamk_base_constructor(expected_values, tile_partitioner); +} + +TEST(StreamKTilePartitionerBaseConstructor, DP2TileSK) +{ + using Config = StreamKTilePartitionerBaseConfigDP2TileSK; + + ck_tile::StreamKTilePartitionerBase tile_partitioner{ + Config::M, Config::N, Config::K, Config::GRID}; + + StreamKTilePartitionerBaseExpected expected_values{ + 4, 3, 3, 8, 2, 2, 2, 6, 7, Config::GRID, Config::N}; + validate_streamk_base_constructor(expected_values, tile_partitioner); +} + +TEST(StreamKTilePartitionerBaseConstructor, EdgeCase) +{ + using Config = StreamKTilePartitionerBaseConfigEdgeCase; + + ck_tile::StreamKTilePartitionerBase tile_partitioner{ + Config::M, Config::N, Config::K, Config::GRID}; + + StreamKTilePartitionerBaseExpected expected_values{ + 0, 1, 0, 0, 0, 2, 0, 2, 1, Config::GRID, Config::N}; + validate_streamk_base_constructor(expected_values, tile_partitioner); +} + +TEST(StreamKTilePartitionerBaseGetWorkSpaceSize, AtomicStrategy) +{ + using Config = StreamKTilePartitionerBaseConfigDP2TileSK; + + ck_tile::StreamKTilePartitionerBase tile_partitioner{ + Config::M, Config::N, Config::K, Config::GRID}; + + 
EXPECT_EQ(tile_partitioner.get_workspace_size(sizeof(float)), 0); +} + +TEST(StreamKTilePartitionerBaseGetWorkSpaceSize, ReductionStrategy) +{ + using Config = StreamKTilePartitionerBaseConfigDP2TileSK; + + ck_tile::StreamKTilePartitionerBase + tile_partitioner{Config::M, Config::N, Config::K, Config::GRID}; + + ck_tile::index_t expected_partials_size = + sizeof(float) * Config::M_TILE * Config::N_TILE * Config::GRID; + ck_tile::index_t expected_flags_size = sizeof(ck_tile::index_t) * Config::GRID; + + EXPECT_EQ(tile_partitioner.get_workspace_size(sizeof(float)), + expected_partials_size + expected_flags_size); +} + +TEST(StreamKTilePartitionerBaseGetLocalIter, GetLocalIter) +{ + // Types + using Config = StreamKTilePartitionerBaseConfigSKOnly; + using TilePartitioner = ck_tile::StreamKTilePartitionerBase; + using Kernel = KernelWrapperSpecialized; + + // Test parameters + ck_tile::DeviceMem local_iter_start_dev(sizeof(ck_tile::index_t)); + ck_tile::index_t iter_start = 3; + ck_tile::index_t tile_iter_start = 2; + + // Launch kernel + auto kargs = Kernel::MakeKernelArgs(iter_start, + tile_iter_start, + Config::UNUSED, + local_iter_start_dev.GetDeviceBuffer(), + nullptr, + Config::UNUSED); + ck_tile::launch_kernel(ck_tile::stream_config{nullptr, false, 0, 0, 1}, + ck_tile::make_kernel<1>(Kernel{}, 1, 1, 0, kargs)); + + // Validate result + ck_tile::index_t local_iter_start; + local_iter_start_dev.FromDevice(&local_iter_start); + EXPECT_EQ(local_iter_start, iter_start - tile_iter_start); +} + +TEST(StreamKTilePartitionerBaseGetLocalIterEnd, MinIsTileIterEnd) +{ + // Types + using Config = StreamKTilePartitionerBaseConfigDP2TileSK; + using TilePartitioner = ck_tile::StreamKTilePartitionerBase; + using Kernel = KernelWrapperSpecialized; + // Test parameters + ck_tile::DeviceMem local_iter_end_dev(sizeof(ck_tile::index_t)); + ck_tile::index_t tile_iter_start = 6; + ck_tile::index_t iter_end = 9; + ck_tile::index_t tile_iter_end = 8; + + // Launch kernel + auto kargs = 
Kernel::MakeKernelArgs(tile_iter_start, + iter_end, + tile_iter_end, + local_iter_end_dev.GetDeviceBuffer(), + nullptr, + Config::UNUSED); + ck_tile::launch_kernel(ck_tile::stream_config{nullptr, false, 0, 0, 1}, + ck_tile::make_kernel<1>(Kernel{}, 1, 1, 0, kargs)); + + // Validate results + ck_tile::index_t local_iter_end; + local_iter_end_dev.FromDevice(&local_iter_end); + EXPECT_EQ(local_iter_end, tile_iter_end - tile_iter_start); +} + +TEST(StreamKTilePartitionerBaseGetLocalIterEnd, MinIsIterEnd) +{ + // Types + // Note: For this test, the Config is used for types only, the function get_local_iter_end is + // static; thus, the test parameters are independent of the Config in this case. + using Config = StreamKTilePartitionerBaseConfigDP2TileSK; + using TilePartitioner = ck_tile::StreamKTilePartitionerBase; + using Kernel = KernelWrapperSpecialized; + // Test parameters + ck_tile::DeviceMem local_iter_end_dev(sizeof(ck_tile::index_t)); + ck_tile::index_t tile_iter_start = 12; + ck_tile::index_t iter_end = 13; + ck_tile::index_t tile_iter_end = 14; + + // Launch kernel + auto kargs = Kernel::MakeKernelArgs(tile_iter_start, + iter_end, + tile_iter_end, + local_iter_end_dev.GetDeviceBuffer(), + nullptr, + Config::UNUSED); + ck_tile::launch_kernel(ck_tile::stream_config{nullptr, false, 0, 0, 1}, + ck_tile::make_kernel<1>(Kernel{}, 1, 1, 0, kargs)); + + // Validate results + ck_tile::index_t local_iter_end; + local_iter_end_dev.FromDevice(&local_iter_end); + EXPECT_EQ(local_iter_end, iter_end - tile_iter_start); +} + +TEST(StreamKTilePartitionerBaseGetTileBoundaries, GetTileBoundaries) +{ + // Types + using Config = StreamKTilePartitionerBaseConfigSKOnly; + using TilePartitioner = ck_tile::StreamKTilePartitionerBase; + using Kernel = + KernelWrapperSpecialized; + + // Test parameters + ck_tile::StreamKTilePartitionerBase tile_partitioner{ + Config::M, Config::N, Config::K, Config::GRID}; + ck_tile::DeviceMem tile_iter_start_dev(sizeof(ck_tile::index_t)); + 
ck_tile::DeviceMem tile_iter_end_dev(sizeof(ck_tile::index_t)); + ck_tile::index_t tile_idx = 1; + + // Launch kernel + auto kargs = Kernel::MakeKernelArgs(Config::PLACEHOLDER, + Config::PLACEHOLDER, + tile_idx, + tile_iter_start_dev.GetDeviceBuffer(), + tile_iter_end_dev.GetDeviceBuffer(), + tile_partitioner); + ck_tile::launch_kernel(ck_tile::stream_config{nullptr, false, 0, 0, 1}, + ck_tile::make_kernel<1>(Kernel{}, 1, 1, 0, kargs)); + + // Validate results + ck_tile::index_t tile_iter_start, tile_iter_end; + tile_iter_start_dev.FromDevice(&tile_iter_start); + tile_iter_end_dev.FromDevice(&tile_iter_end); + // There are 2 iters per tile. Thus, for tile_idx 1, we expect 2 and 4 to be the start and end, + // respectively. + EXPECT_EQ(tile_iter_start, 2); + EXPECT_EQ(tile_iter_end, 4); +} + +TEST(StreamKTilePartitionerBaseGetTileIndex, GetTileIndex) +{ + // Types + using Config = StreamKTilePartitionerBaseConfigDP2TileSK; + using TilePartitioner = ck_tile::StreamKTilePartitionerBase; + using Kernel = KernelWrapperSpecialized; + + // Test parameters + ck_tile::StreamKTilePartitionerBase tile_partitioner{ + Config::M, Config::N, Config::K, Config::GRID}; + ck_tile::DeviceMem tile_idx_dev(sizeof(ck_tile::index_t)); + ck_tile::index_t iter_start = 8; + + // Launch kernel + auto kargs = Kernel::MakeKernelArgs(iter_start, + Config::UNUSED, + Config::UNUSED, + tile_idx_dev.GetDeviceBuffer(), + nullptr, + tile_partitioner); + ck_tile::launch_kernel(ck_tile::stream_config{nullptr, false, 0, 0, 1}, + ck_tile::make_kernel<1>(Kernel{}, 1, 1, 0, kargs)); + + // Validate results + ck_tile::index_t tile_idx; + tile_idx_dev.FromDevice(&tile_idx); + // Since there are 2 iters per tile, iter 8 maps to tile_idx 4. 
+ EXPECT_EQ(tile_idx, 4); +} + +TEST(StreamKTilePartitionerBaseGetIterBoundaries, ZeroExtraItersBeforeMe) +{ + // Types + using Config = StreamKTilePartitionerBaseConfigDP2TileSK; + using TilePartitioner = ck_tile::StreamKTilePartitionerBase; + using Kernel = + KernelWrapperSpecialized; + + // Test parameters + ck_tile::StreamKTilePartitionerBase tile_partitioner{ + Config::M, Config::N, Config::K, Config::GRID}; + ck_tile::DeviceMem iter_start_dev(sizeof(ck_tile::index_t)); + ck_tile::DeviceMem iter_end_dev(sizeof(ck_tile::index_t)); + ck_tile::index_t cta_idx = 0; + + // Launch kernel + auto kargs = Kernel::MakeKernelArgs(Config::PLACEHOLDER, + Config::PLACEHOLDER, + cta_idx, + iter_start_dev.GetDeviceBuffer(), + iter_end_dev.GetDeviceBuffer(), + tile_partitioner); + ck_tile::launch_kernel(ck_tile::stream_config{nullptr, false, 0, 0, 1}, + ck_tile::make_kernel<1>(Kernel{}, 1, 1, 0, kargs)); + + // Validate results + ck_tile::index_t iter_start, iter_end; + iter_start_dev.FromDevice(&iter_start); + iter_end_dev.FromDevice(&iter_end); + EXPECT_EQ(iter_start, 6); + EXPECT_EQ(iter_end, 9); +} + +TEST(StreamKTilePartitionerBaseGetIterBoundaries, NonZeroExtraItersBeforeMe) +{ + // Types + using Config = StreamKTilePartitionerBaseConfigDP2TileSK; + using TilePartitioner = ck_tile::StreamKTilePartitionerBase; + using Kernel = + KernelWrapperSpecialized; + + // Test parameters + ck_tile::StreamKTilePartitionerBase tile_partitioner{ + Config::M, Config::N, Config::K, Config::GRID}; + ck_tile::DeviceMem iter_start_dev(sizeof(ck_tile::index_t)); + ck_tile::DeviceMem iter_end_dev(sizeof(ck_tile::index_t)); + ck_tile::index_t cta_idx = 1; + + // Launch kernel + auto kargs = Kernel::MakeKernelArgs(Config::PLACEHOLDER, + Config::PLACEHOLDER, + cta_idx, + iter_start_dev.GetDeviceBuffer(), + iter_end_dev.GetDeviceBuffer(), + tile_partitioner); + ck_tile::launch_kernel(ck_tile::stream_config{nullptr, false, 0, 0, 1}, + ck_tile::make_kernel<1>(Kernel{}, 1, 1, 0, kargs)); + + // 
Validate results + ck_tile::index_t iter_start, iter_end; + iter_start_dev.FromDevice(&iter_start); + iter_end_dev.FromDevice(&iter_end); + EXPECT_EQ(iter_start, 9); + EXPECT_EQ(iter_end, 12); +} + +TEST(StreamKTilePartitionerBaseGetIterBoundaries, MinIsExtraIters) +{ + // Types + using Config = StreamKTilePartitionerBaseConfigDP2TileSK; + using TilePartitioner = ck_tile::StreamKTilePartitionerBase; + using Kernel = + KernelWrapperSpecialized; + + // Test parameters + ck_tile::StreamKTilePartitionerBase tile_partitioner{ + Config::M, Config::N, Config::K, Config::GRID}; + ck_tile::DeviceMem iter_start_dev(sizeof(ck_tile::index_t)); + ck_tile::DeviceMem iter_end_dev(sizeof(ck_tile::index_t)); + ck_tile::index_t cta_idx = 2; + + // Launch kernel + auto kargs = Kernel::MakeKernelArgs(Config::PLACEHOLDER, + Config::PLACEHOLDER, + cta_idx, + iter_start_dev.GetDeviceBuffer(), + iter_end_dev.GetDeviceBuffer(), + tile_partitioner); + ck_tile::launch_kernel(ck_tile::stream_config{nullptr, false, 0, 0, 1}, + ck_tile::make_kernel<1>(Kernel{}, 1, 1, 0, kargs)); + + // Validate results + ck_tile::index_t iter_start, iter_end; + iter_start_dev.FromDevice(&iter_start); + iter_end_dev.FromDevice(&iter_end); + EXPECT_EQ(iter_start, 12); + EXPECT_EQ(iter_end, 14); +} + +TEST(StreamKTilePartitionerBaseGetOutputTileIndex, TestAllMappings) +{ + using Config = StreamKTilePartitionerBaseConfigLargerCTensor; + ck_tile::index_t m_macro_tiles = Config::M / Config::M_TILE; + ck_tile::index_t n_macro_tiles = Config::N / Config::N_TILE; + ck_tile::index_t tile_idx = 0; + + for(ck_tile::index_t row = 0; row < m_macro_tiles; ++row) + { + for(ck_tile::index_t col = 0; col < n_macro_tiles; ++col) + { + test_get_output_tile_index(tile_idx, ck_tile::make_tuple(row, col)); + ++tile_idx; + } + } +} + +// Persistent +TEST(StreamKTilePartitioner_v2_PersistentConstructor, SKOnly) +{ + using Config = StreamKTilePartitionerBaseConfigSKOnly; + + ck_tile::StreamKTilePartitioner_v2 + 
tile_partitioner{Config::M, Config::N, Config::K, Config::GRID}; + + StreamKTilePartitionerV2PersistentExpected expected_values{0, 0, 3}; + validate_streamk_v2_persistent(expected_values, tile_partitioner); +} + +TEST(StreamKTilePartitioner_v2_PersistentConstructor, DPOnly) +{ + using Config = StreamKTilePartitionerBaseConfigDPOnly; + + ck_tile::StreamKTilePartitioner_v2 + tile_partitioner{Config::M, Config::N, Config::K, Config::GRID}; + + StreamKTilePartitionerV2PersistentExpected expected_values{2, 0, 3}; + validate_streamk_v2_persistent(expected_values, tile_partitioner); +} + +TEST(StreamKTilePartitioner_v2_PersistentConstructor, DP2TileSK) +{ + using Config = StreamKTilePartitionerBaseConfigDP2TileSK; + + ck_tile::StreamKTilePartitioner_v2 + tile_partitioner{Config::M, Config::N, Config::K, Config::GRID}; + + StreamKTilePartitionerV2PersistentExpected expected_values{1, 0, 3}; + validate_streamk_v2_persistent(expected_values, tile_partitioner); +} + +TEST(StreamKTilePartitioner_v2_PersistentConstructor, EdgeCase) +{ + using Config = StreamKTilePartitionerBaseConfigEdgeCase; + + ck_tile::StreamKTilePartitioner_v2 + tile_partitioner{Config::M, Config::N, Config::K, Config::GRID}; + + StreamKTilePartitionerV2PersistentExpected expected_values{0, 1, 4}; + validate_streamk_v2_persistent(expected_values, tile_partitioner); +} + +TEST(StreamKTilePartitioner_v2_GridSize_Persistent, SKOnly) +{ + using Config = StreamKTilePartitionerBaseConfigSKOnly; + + ck_tile::StreamKTilePartitioner_v2 + tile_partitioner{Config::M, Config::N, Config::K, Config::GRID}; + + const auto g = tile_partitioner.grid_size(); + EXPECT_EQ(g.x, Config::GRID); +} + +TEST(StreamKTilePartitioner_v2_GridSize_Persistent, EdgeCase) +{ + using Config = StreamKTilePartitionerBaseConfigEdgeCase; + + ck_tile::StreamKTilePartitioner_v2 + tile_partitioner{Config::M, Config::N, Config::K, Config::GRID}; + + const auto g = tile_partitioner.grid_size(); + EXPECT_EQ(g.x, 1); +} + +// Non-Persistent Tests 
+TEST(StreamKTilePartitioner_v2_NonPersistentConstructor, SKOnly) +{ + using Config = StreamKTilePartitionerBaseConfigSKOnly; + + ck_tile::StreamKTilePartitioner_v2 + tile_partitioner{Config::M, Config::N, Config::K, Config::GRID}; + + StreamKTilePartitionerV2NonPersistentExpected expected_values{0, 0, 0, 3}; + validate_streamk_v2_nonpersistent(expected_values, tile_partitioner); +} + +TEST(StreamKTilePartitioner_v2_NonPersistentConstructor, DPOnly) +{ + using Config = StreamKTilePartitionerBaseConfigDPOnly; + + ck_tile::StreamKTilePartitioner_v2 + tile_partitioner{Config::M, Config::N, Config::K, Config::GRID}; + + StreamKTilePartitionerV2NonPersistentExpected expected_values{6, 0, 6, 3}; + validate_streamk_v2_nonpersistent(expected_values, tile_partitioner); +} + +TEST(StreamKTilePartitioner_v2_NonPersistentConstructor, DP2TileSK) +{ + using Config = StreamKTilePartitionerBaseConfigDP2TileSK; + + ck_tile::StreamKTilePartitioner_v2 + tile_partitioner{Config::M, Config::N, Config::K, Config::GRID}; + + StreamKTilePartitionerV2NonPersistentExpected expected_values{3, 0, 3, 3}; + validate_streamk_v2_nonpersistent(expected_values, tile_partitioner); +} + +TEST(StreamKTilePartitioner_v2_NonPersistentConstructor, EdgeCase) +{ + using Config = StreamKTilePartitionerBaseConfigEdgeCase; + + ck_tile::StreamKTilePartitioner_v2 + tile_partitioner{Config::M, Config::N, Config::K, Config::GRID}; + + StreamKTilePartitionerV2NonPersistentExpected expected_values{1, 0, 1, 4}; + validate_streamk_v2_nonpersistent(expected_values, tile_partitioner); +} + +TEST(StreamKTilePartitioner_v2_GridSize_NonPersistent, DP2TileSK) +{ + using Config = StreamKTilePartitionerBaseConfigDP2TileSK; + + ck_tile::StreamKTilePartitioner_v2 + tile_partitioner{Config::M, Config::N, Config::K, Config::GRID}; + + const auto g = tile_partitioner.grid_size(); + EXPECT_EQ(g.x, 6); +} diff --git a/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner_common.hpp 
b/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner_common.hpp new file mode 100644 index 0000000000..4fc654a7ea --- /dev/null +++ b/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner_common.hpp @@ -0,0 +1,339 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "gtest/gtest.h" + +enum StreamKTilePartitionerBaseMethodId +{ + GET_LOCAL_ITER, + GET_LOCAL_ITER_END, + GET_TILE_BOUNDARIES, + GET_TILE_INDEX, + GET_ITER_BOUNDARIES, + GET_OUTPUT_TILE_INDEX +}; + +// Base kernel wrapper class to facilitate testing class device functions. +template +struct KernelWrapper +{ + static constexpr ck_tile::index_t kBlockSize = 1; + + struct KernelArgs + { + ck_tile::index_t arg1; + ck_tile::index_t arg2; + ck_tile::index_t arg3; + void* result1; + void* result2; + T tile_partitioner; + }; + + CK_TILE_HOST static KernelArgs MakeKernelArgs(ck_tile::index_t arg1, + ck_tile::index_t arg2, + ck_tile::index_t arg3, + void* result1, + void* result2, + T tile_partitioner) + { + return KernelArgs{arg1, arg2, arg3, result1, result2, tile_partitioner}; + } +}; + +// Specialized derived class to support unique operator() functions. There is one template +// specialization per member in the StreamKTilePartitionerBaseMethodId enum. 
+template +struct KernelWrapperSpecialized; + +template +struct KernelWrapperSpecialized + : public KernelWrapper<> +{ + using Base = KernelWrapper<>; + + CK_TILE_DEVICE void operator()(Base::KernelArgs kargs) + { + *(static_cast(kargs.result1)) = + TilePartitioner::get_local_iter(kargs.arg1, kargs.arg2); + } +}; + +template +struct KernelWrapperSpecialized + : public KernelWrapper +{ + + using Base = KernelWrapper; + + CK_TILE_DEVICE void operator()(typename Base::KernelArgs kargs) + { + kargs.tile_partitioner.get_tile_boundaries(kargs.arg1, kargs.arg2, kargs.arg3); + *(static_cast(kargs.result1)) = kargs.arg1; + *(static_cast(kargs.result2)) = kargs.arg2; + } +}; + +template +struct KernelWrapperSpecialized + : public KernelWrapper +{ + + using Base = KernelWrapper; + + CK_TILE_DEVICE void operator()(typename Base::KernelArgs kargs) + { + kargs.tile_partitioner.get_iter_boundaries(kargs.arg1, kargs.arg2, kargs.arg3); + *(static_cast(kargs.result1)) = kargs.arg1; + *(static_cast(kargs.result2)) = kargs.arg2; + } +}; + +template +struct KernelWrapperSpecialized + : public KernelWrapper<> +{ + + using Base = KernelWrapper<>; + CK_TILE_DEVICE void operator()(Base::KernelArgs kargs) + { + *(static_cast(kargs.result1)) = + TilePartitioner::get_local_iter_end(kargs.arg1, kargs.arg2, kargs.arg3); + } +}; + +template +struct KernelWrapperSpecialized + : public KernelWrapper +{ + + using Base = KernelWrapper; + + CK_TILE_DEVICE void operator()(typename Base::KernelArgs kargs) + { + *(static_cast(kargs.result1)) = + kargs.tile_partitioner.get_tile_index(kargs.arg1); + } +}; + +template +struct KernelWrapperSpecialized + : public KernelWrapper +{ + + using Base = KernelWrapper; + + CK_TILE_DEVICE void operator()(typename Base::KernelArgs kargs) + { + auto [im, in] = kargs.tile_partitioner.get_output_tile_index(kargs.arg1); + *(static_cast(kargs.result1)) = im; + *(static_cast(kargs.result2)) = in; + } +}; + +struct StreamKTilePartitionerBaseExpected +{ + ck_tile::index_t 
sk_tiles_; + ck_tile::index_t dp_tiles_; + ck_tile::index_t sk_ctas_; + ck_tile::index_t total_sk_iters_; + ck_tile::index_t iters_per_sk_cta_; + ck_tile::index_t iters_per_tile_; + ck_tile::index_t extra_iters_; + ck_tile::index_t total_dp_iters_; + ck_tile::index_t num_tiles_; + ck_tile::index_t grid_; + ck_tile::index_t n_; +}; + +template +void validate_streamk_base_constructor( + StreamKTilePartitionerBaseExpected& expected_values, + ck_tile::StreamKTilePartitionerBase& tile_partitioner) +{ + EXPECT_EQ(tile_partitioner.get_sk_tiles(), expected_values.sk_tiles_); + EXPECT_EQ(tile_partitioner.get_dp_tiles(), expected_values.dp_tiles_); + EXPECT_EQ(tile_partitioner.get_sk_ctas(), expected_values.sk_ctas_); + EXPECT_EQ(tile_partitioner.get_total_sk_iters(), expected_values.total_sk_iters_); + EXPECT_EQ(tile_partitioner.get_iters_per_sk_cta(), expected_values.iters_per_sk_cta_); + EXPECT_EQ(tile_partitioner.get_extra_iters(), expected_values.extra_iters_); + EXPECT_EQ(tile_partitioner.get_iters_per_tile(), expected_values.iters_per_tile_); + EXPECT_EQ(tile_partitioner.get_total_dp_iters(), expected_values.total_dp_iters_); + EXPECT_EQ(tile_partitioner.get_num_tiles(), expected_values.num_tiles_); + EXPECT_EQ(tile_partitioner.get_grid(), expected_values.grid_); + EXPECT_EQ(tile_partitioner.get_n(), expected_values.n_); +} + +struct StreamKTilePartitionerBaseConfig +{ + static constexpr ck_tile::index_t PLACEHOLDER = -1; + static constexpr ck_tile::index_t UNUSED = -1; +}; + +// Note: for the configs below, we only use BlockTiles in the TileGemmShape. We do not use +// BlockWarps or WarpTile. 
+ +struct StreamKTilePartitionerBaseConfigDP2TileSK : public StreamKTilePartitionerBaseConfig +{ + static constexpr ck_tile::index_t M = 28; + static constexpr ck_tile::index_t N = 4; + static constexpr ck_tile::index_t K = 16; + static constexpr ck_tile::index_t GRID = 3; + + static constexpr ck_tile::index_t M_TILE = 4; + static constexpr ck_tile::index_t N_TILE = 4; + static constexpr ck_tile::index_t K_TILE = 8; + + using GemmShape = ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence>; +}; + +struct StreamKTilePartitionerBaseConfigDPOnly : public StreamKTilePartitionerBaseConfig +{ + static constexpr ck_tile::index_t M = 12; + static constexpr ck_tile::index_t N = 4; + static constexpr ck_tile::index_t K = 16; + static constexpr ck_tile::index_t GRID = 3; + + static constexpr ck_tile::index_t M_TILE = 4; + static constexpr ck_tile::index_t N_TILE = 2; + static constexpr ck_tile::index_t K_TILE = 8; + + using GemmShape = ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence>; +}; + +struct StreamKTilePartitionerBaseConfigSKOnly : public StreamKTilePartitionerBaseConfig +{ + static constexpr ck_tile::index_t M = 4; + static constexpr ck_tile::index_t N = 4; + static constexpr ck_tile::index_t K = 16; + static constexpr ck_tile::index_t GRID = 3; + + static constexpr ck_tile::index_t M_TILE = 4; + static constexpr ck_tile::index_t N_TILE = 2; + static constexpr ck_tile::index_t K_TILE = 8; + + using GemmShape = ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence>; +}; + +struct StreamKTilePartitionerBaseConfigEdgeCase : public StreamKTilePartitionerBaseConfig +{ + + static constexpr ck_tile::index_t M = 4; + static constexpr ck_tile::index_t N = 4; + static constexpr ck_tile::index_t K = 16; + static constexpr ck_tile::index_t GRID = 4; + + static constexpr ck_tile::index_t M_TILE = 4; + static constexpr ck_tile::index_t N_TILE = 4; + static constexpr ck_tile::index_t K_TILE = 8; + + using GemmShape = ck_tile::TileGemmShape, + 
ck_tile::sequence, + ck_tile::sequence>; +}; + +struct StreamKTilePartitionerBaseConfigLargerCTensor : public StreamKTilePartitionerBaseConfig +{ + // This config has 3 macro tiles in the M dimension and 4 macro tiles in the N dimension. + // This facilitates testing the get_output_tile_index method. + + static constexpr ck_tile::index_t M = 12; + static constexpr ck_tile::index_t N = 16; + static constexpr ck_tile::index_t K = 16; + static constexpr ck_tile::index_t GRID = 4; + + static constexpr ck_tile::index_t M_TILE = 4; + static constexpr ck_tile::index_t N_TILE = 4; + static constexpr ck_tile::index_t K_TILE = 8; + + using GemmShape = ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence>; +}; + +void test_get_output_tile_index(ck_tile::index_t tile_idx, + ck_tile::tuple expected_2d_idx) +{ + // Types + using Config = StreamKTilePartitionerBaseConfigLargerCTensor; + using TilePartitioner = ck_tile::StreamKTilePartitionerBase; + using Kernel = + KernelWrapperSpecialized; + + // Test parameters + ck_tile::StreamKTilePartitionerBase tile_partitioner{ + Config::M, Config::N, Config::K, Config::GRID}; + ck_tile::DeviceMem im_dev(sizeof(ck_tile::index_t)); + ck_tile::DeviceMem in_dev(sizeof(ck_tile::index_t)); + + // Launch kernel + auto kargs = Kernel::MakeKernelArgs(tile_idx, + Config::UNUSED, + Config::UNUSED, + im_dev.GetDeviceBuffer(), + in_dev.GetDeviceBuffer(), + tile_partitioner); + ck_tile::launch_kernel(ck_tile::stream_config{nullptr, false, 0, 0, 1}, + ck_tile::make_kernel<1>(Kernel{}, 1, 1, 0, kargs)); + + // Validate results + const auto [im_expected, in_expected] = expected_2d_idx; + ck_tile::index_t im, in; + im_dev.FromDevice(&im); + in_dev.FromDevice(&in); + EXPECT_EQ(im, im_expected); + EXPECT_EQ(in, in_expected); +}; + +// Configs for TilePartitioner Child structs +struct StreamKTilePartitionerV2PersistentExpected +{ + ck_tile::index_t dp_tiles_per_cta_; + ck_tile::index_t extra_dp_tiles_; + ck_tile::index_t grid_; +}; + +struct 
StreamKTilePartitionerV2NonPersistentExpected +{ + ck_tile::index_t dp_ctas_; + ck_tile::index_t dp_start_block_idx_; + ck_tile::index_t sk_start_block_idx_; + ck_tile::index_t grid_; +}; + +// Persistent +template +void validate_streamk_v2_persistent( + StreamKTilePartitionerV2PersistentExpected& expected_values, + ck_tile::StreamKTilePartitioner_v2& + tile_partitioner) +{ + EXPECT_EQ(tile_partitioner.get_dp_tiles_per_cta(), expected_values.dp_tiles_per_cta_); + EXPECT_EQ(tile_partitioner.get_extra_dp_tiles(), expected_values.extra_dp_tiles_); + EXPECT_EQ(tile_partitioner.get_grid(), expected_values.grid_); +} + +// Non-Persistent +template +void validate_streamk_v2_nonpersistent( + StreamKTilePartitionerV2NonPersistentExpected& expected_values, + ck_tile::StreamKTilePartitioner_v2& + tile_partitioner) +{ + EXPECT_EQ(tile_partitioner.get_dp_ctas(), expected_values.dp_ctas_); + EXPECT_EQ(tile_partitioner.get_dp_start_block_idx(), expected_values.dp_start_block_idx_); + EXPECT_EQ(tile_partitioner.get_sk_start_block_idx(), expected_values.sk_start_block_idx_); + EXPECT_EQ(tile_partitioner.get_grid(), expected_values.grid_); +}