diff --git a/include/ck_tile/ops/common.hpp b/include/ck_tile/ops/common.hpp index 027e2fdd94..7c6adc3ec2 100644 --- a/include/ck_tile/ops/common.hpp +++ b/include/ck_tile/ops/common.hpp @@ -6,3 +6,4 @@ #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" +#include "ck_tile/ops/common/streamk_common.hpp" diff --git a/include/ck_tile/ops/common/streamk_common.hpp b/include/ck_tile/ops/common/streamk_common.hpp new file mode 100644 index 0000000000..5dbe6223c4 --- /dev/null +++ b/include/ck_tile/ops/common/streamk_common.hpp @@ -0,0 +1,14 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { +enum StreamKReductionStrategy : uint32_t +{ + Atomic = 0u, + Reduction = 1u +}; +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp b/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp index b621468e92..92ae6411a5 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp @@ -9,6 +9,7 @@ #pragma once #include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" namespace ck_tile { @@ -364,4 +365,454 @@ struct GemmSpatiallyLocalTilePartitioner index_t N; }; +/** + * @brief Stream-K tile partitioner that dynamically balances work across workgroups + * + * This partitioner is responsible for mapping workgroups to tiles in the C tensor + * for the Stream-K algorithm which decomposes the GEMM problem + * into smaller work units and distributes them more evenly across available blocks, + * improving load balancing especially for cases where the K dimension is large. + * + * @tparam BlockGemmShapeType A class providing basic GEMM parameters. 
+ * @tparam ReductionStrategy A value that defines the reduction strategy for the results in
+ * the C Tensor.
+ * @tparam TileSwizzleSubM A value that defines the size of the swizzle group along the m
+ * dimension, where the swizzle group denotes consecutive tiles down a column. For instance a
+ * swizzle group of 8 denotes tiles 0, 1, ..., 7, map to tiles [0,0], [1,0], ..., [7,0] in the C
+ * tensor.
+ */
+template <typename BlockGemmShapeType,
+          StreamKReductionStrategy ReductionStrategy,
+          uint32_t TileSwizzleSubM>
+struct StreamKTilePartitioner
+{
+    using BlockGemmShape = BlockGemmShapeType;
+
+    static constexpr uint32_t MPerBlock = BlockGemmShape::kM;
+    static constexpr uint32_t NPerBlock = BlockGemmShape::kN;
+    static constexpr uint32_t KPerBlock = BlockGemmShape::kK;
+
+    CK_TILE_HOST_DEVICE StreamKTilePartitioner() noexcept = delete;
+
+    /**
+     * @brief Construct Stream-K tile partitioner with problem dimensions
+     */
+    CK_TILE_HOST_DEVICE StreamKTilePartitioner(uint32_t M,
+                                               uint32_t N,
+                                               uint32_t K,
+                                               uint32_t num_cu,
+                                               uint32_t occupancy,
+                                               uint32_t sk_blocks = 0xffffffff) noexcept
+        : M_(M), N_(N), K_(K)
+    {
+        num_tile_m_ = integer_divide_ceil(M, MPerBlock);
+        num_tile_n_ = integer_divide_ceil(N, NPerBlock);
+        num_tile_k_ = integer_divide_ceil(K, KPerBlock);
+
+        constexpr uint32_t min_k_iters_per_sk_block = 2;
+        uint32_t num_tiles = num_tile_m_ * num_tile_n_;
+        k_iters_per_tile = mdiv(num_tile_k_);
+
+        // one cu can hold one wg at one time, from the whole chip's point of view
+        // if number of wg is same as num_cu, we call it 1 dispatch
+        // if number of wg is 2x num_cu, we call it 2 dispatches.
+ // one dispatch can deliver wg same as num_cu (full dispatch), or less than num_cu (partial + // dispatch) + // + const uint32_t full_dispatches = num_tiles / num_cu; + const uint32_t full_dispatch_tiles = full_dispatches * num_cu; + const uint32_t partial_dispatch_tiles = num_tiles - full_dispatch_tiles; + + uint32_t sk_occupancy = occupancy; + uint32_t dp_tiles = full_dispatch_tiles; + uint32_t sk_tiles = partial_dispatch_tiles; + + if(full_dispatches < occupancy) + { + // in this case, we allocate all blocks as sk blocks + // sk_occupancy = occupancy - full_dispatches; + sk_occupancy = 1; + dp_tiles = full_dispatch_tiles; + sk_tiles = partial_dispatch_tiles; + } + else if((occupancy > 1) && (full_dispatches % occupancy == occupancy - 1)) + { + // e.g. occupancy = 2, full_dispatches = 3, 5, 7 ... + // occupancy = 3, full_dispatches = 5, 8, 11 ... + // occupancy = 4, full_dispatches = 7, 11 ... + sk_occupancy = 1; // left 1 slot for sk occupancy + dp_tiles = full_dispatch_tiles; + sk_tiles = partial_dispatch_tiles; + } + else + { + // otherwise, we reduce 1 dispatch from dp, together with partial dispatch, + // to construct sk dispatch + sk_occupancy = occupancy - ((full_dispatches - 1) % occupancy); + dp_tiles = full_dispatch_tiles - num_cu; + sk_tiles = partial_dispatch_tiles + num_cu; + } + + // uint32_t dp_iters_per_block = k_iters_per_tile.get(); + uint32_t sk_total_iters = k_iters_per_tile.get() * sk_tiles; + uint32_t dp_num_blocks = 0; + + { + const uint32_t min_sk_tiles = (sk_tiles >= num_cu) ? num_cu : (sk_tiles + 1); + const uint32_t max_sk_tiles = + (sk_tiles >= num_cu) ? 
num_cu * sk_occupancy
+                    : min(num_cu, sk_total_iters / min_k_iters_per_sk_block);
+
+            // if use dp for sk-block, how many iters do we need
+            const uint32_t dp_for_sk_iters = k_iters_per_tile.get();
+
+            uint32_t best_sk_score =
+                std::numeric_limits<uint32_t>::max(); // we need to find the smallest sk iters
+            for(uint32_t tentative_sk_blocks = min_sk_tiles; tentative_sk_blocks < max_sk_tiles;
+                tentative_sk_blocks++)
+            {
+                const uint32_t tentative_sk_iters_per_block =
+                    (sk_total_iters + tentative_sk_blocks - 1) / tentative_sk_blocks;
+                const uint32_t tentative_sk_iters = tentative_sk_iters_per_block;
+                const uint32_t sk_blocks_per_tile = (tentative_sk_blocks + sk_tiles - 1) / sk_tiles;
+
+                // the more sk_blocks_per_tile, the worse the overhead
+                uint32_t cross_sk_blocks_overhead = sk_blocks_per_tile;
+                if(tentative_sk_blocks % sk_tiles != 0)
+                {
+                    // penalty for uneven divide
+                    cross_sk_blocks_overhead +=
+                        sk_blocks_per_tile * tentative_sk_iters_per_block / 50;
+                }
+
+                const uint32_t tentative_sk_score = tentative_sk_iters + cross_sk_blocks_overhead;
+
+                if(tentative_sk_score < best_sk_score)
+                {
+                    best_sk_score = tentative_sk_score;
+                    sk_num_blocks = tentative_sk_blocks;
+                }
+            }
+
+            if(best_sk_score >= dp_for_sk_iters)
+            {
+                sk_num_blocks = 0;
+            }
+
+            // give a chance to control num of sk blocks
+            sk_num_blocks = sk_blocks != 0xffffffff ? sk_blocks : sk_num_blocks;
+
+            if(sk_num_blocks == 0)
+            {
+                sk_num_big_blocks = 0;
+                k_iters_per_big_block = 0;
+
+                dp_num_blocks = num_tiles; // all tile to be dp block
+                dp_start_block_idx = 0;
+                sk_total_iters = 0; // clear this tiles
+            }
+            else
+            {
+                // k_iters_per_sk_block is the floor of avg each sk block loop over tiles.
+                // we need to decide how many iters for each sk block
+                // let m = k_iters_per_sk_block
+                // some of the sk block (little) will cover m iters, some (big) will cover m+1
+                // we have
+                // 1) l + b = sk_blocks
+                // 2) l * m + b * (m + 1) = sk_total_iters
+                // => (l + b) * m + b = sk_total_iters
+                // => sk_blocks * m + b = sk_total_iters
+                // => b = sk_total_iters - m * sk_blocks
+                // NOTE: big could be zero
+                const uint32_t k_iters_per_sk_block = sk_total_iters / sk_num_blocks;
+                sk_num_big_blocks = sk_total_iters - k_iters_per_sk_block * sk_num_blocks;
+                k_iters_per_big_block = k_iters_per_sk_block + 1;
+
+                dp_num_blocks = dp_tiles;
+                dp_start_block_idx = (sk_num_blocks + num_cu - 1) / num_cu * num_cu;
+            }
+        }
+        n_tiles = mdiv2(num_tile_n_);
+        reduction_start_block_idx = dp_start_block_idx + dp_num_blocks;
+
+        if constexpr(ReductionStrategy == ck_tile::StreamKReductionStrategy::Reduction)
+        {
+            const uint32_t upper_big = lcm(k_iters_per_big_block, k_iters_per_tile.get());
+            const uint32_t upper_little = lcm(k_iters_per_big_block - 1, k_iters_per_tile.get());
+            equiv_tiles_big = mdiv(upper_big / k_iters_per_tile.get());
+            equiv_tiles_little = mdiv(upper_little / k_iters_per_tile.get());
+        }
+    }
+
+    /**
+     * @brief Calculate optimal grid size for Stream-K
+     */
+    CK_TILE_HOST auto GridSize() const noexcept -> dim3
+    {
+        if constexpr(ReductionStrategy == ck_tile::StreamKReductionStrategy::Reduction)
+        {
+            return dim3(reduction_start_block_idx + GetSkTiles(), 1, 1);
+        }
+        else
+            return dim3(reduction_start_block_idx, 1, 1);
+    }
+
+    /**
+     * @brief Calculate number of loop iterations over K dimension for given work unit
+     */
+    CK_TILE_HOST_DEVICE static auto GetLoopNum(uint32_t K) noexcept -> uint32_t
+    {
+        return integer_divide_ceil(K, KPerBlock); // Stream-K processes one K-slice at a time
+    }
+
+    /**
+     * @brief Get output tile index for standard 2D mapping (compatibility)
+     */
+    CK_TILE_DEVICE auto
+    GetOutputTileIndex(uint32_t tile_idx) const noexcept -> tuple<uint32_t, uint32_t>
+    {
+        
uint32_t m_tile_idx, n_tile_idx; + n_tiles.divmod(tile_idx, num_tile_n_, m_tile_idx, n_tile_idx); + + // swizzle tile + + uint32_t tile_swizzle_sub_m_rem = num_tile_m_ % TileSwizzleSubM; + + const auto sub_m_adapt = (m_tile_idx < (num_tile_m_ - tile_swizzle_sub_m_rem)) + ? TileSwizzleSubM + : tile_swizzle_sub_m_rem; + + uint32_t m_tile_idx_sub0, m_tile_idx_sub1; + m_tile_idx_sub0 = m_tile_idx / TileSwizzleSubM; + m_tile_idx_sub1 = m_tile_idx % TileSwizzleSubM; + + uint32_t tile_idx_local = n_tile_idx + m_tile_idx_sub1 * num_tile_n_; + + uint32_t m_tile_idx_with_adapt, n_tile_idx_with_adapt; + + n_tile_idx_with_adapt = tile_idx_local / sub_m_adapt; + m_tile_idx_with_adapt = tile_idx_local % sub_m_adapt; + return make_tuple(m_tile_idx_with_adapt + m_tile_idx_sub0 * TileSwizzleSubM, + n_tile_idx_with_adapt); + } + + /** + * @brief Get work range for a given block ID + */ + CK_TILE_DEVICE void + GetBlockItr(uint32_t block_idx, uint32_t& iter_start, uint32_t& iter_end) const noexcept + { + if(block_idx < sk_num_big_blocks) + { + iter_start = block_idx * k_iters_per_big_block; + iter_end = iter_start + k_iters_per_big_block; + } + else if(block_idx < sk_num_blocks) + { + iter_start = (sk_num_big_blocks * k_iters_per_big_block) + + (block_idx - sk_num_big_blocks) * (k_iters_per_big_block - 1); + iter_end = iter_start + (k_iters_per_big_block - 1); + } + else if(block_idx >= dp_start_block_idx) + { + uint32_t sk_total_iters = GetSkTotalIters(); + uint32_t dp_iters_per_block = k_iters_per_tile.get(); + iter_start = sk_total_iters + (block_idx - dp_start_block_idx) * dp_iters_per_block; + iter_end = iter_start + dp_iters_per_block; + } + } + + /** + * @brief Get total number of iterations for sk tiles + */ + CK_TILE_HOST_DEVICE uint32_t GetSkTotalIters() const noexcept + { + uint32_t sk_total_iters = sk_num_big_blocks * k_iters_per_big_block + + (sk_num_blocks - sk_num_big_blocks) * (k_iters_per_big_block - 1); + return sk_total_iters; + } + + /** + * @brief Get total number 
of sk tiles
+     */
+    CK_TILE_HOST_DEVICE uint32_t GetSkTiles() const noexcept
+    {
+        // tiles for sk
+        uint32_t sk_total_iters = GetSkTotalIters();
+        return k_iters_per_tile.div(sk_total_iters);
+    }
+
+    /**
+     * @brief Get length of loop iterations for stream-k loop
+     */
+    CK_TILE_DEVICE uint32_t GetCurrentIterLength(uint32_t iter_start,
+                                                 uint32_t iter_end,
+                                                 uint32_t total_iter_length) const noexcept
+    {
+        uint32_t iter_length_mod, iter_length_quo /*unused*/;
+        k_iters_per_tile.divmod(iter_end, iter_length_quo, iter_length_mod);
+        uint32_t total_iter_length_val = static_cast<uint32_t>(total_iter_length);
+        uint32_t current_iter_length =
+            min(iter_length_mod == 0 ? (iter_end - iter_start) : iter_length_mod,
+                total_iter_length_val);
+        return current_iter_length;
+    }
+
+    /**
+     * @brief Get index of tile during a specified iteration
+     */
+    CK_TILE_DEVICE uint32_t GetTileIdx(uint32_t iter) const noexcept
+    {
+        return k_iters_per_tile.div(iter);
+    }
+
+    /**
+     * @brief Get index of tile during a specified iteration
+     */
+    CK_TILE_DEVICE void
+    GetTileIdxWithOffset(uint32_t iter, uint32_t& tile_idx, uint32_t& iter_offset) const noexcept
+    {
+        uint32_t tile_idx_val = static_cast<uint32_t>(tile_idx);
+        uint32_t iter_offset_val = static_cast<uint32_t>(iter_offset);
+        k_iters_per_tile.divmod(iter, tile_idx_val, iter_offset_val);
+    }
+
+    /**
+     * @brief Calculates the buffer space needed for accumulation
+     */
+    CK_TILE_HOST_DEVICE uint32_t GetWorkSpaceSizeForAcc(uint32_t acc_element_bytes) const noexcept
+    {
+        static constexpr uint32_t alignment = 128;
+        uint32_t acc_buffer_bytes =
+            MPerBlock * NPerBlock * GetTotalAccBuffers() * acc_element_bytes;
+        return (acc_buffer_bytes + alignment - 1) / alignment * alignment;
+    }
+
+    /**
+     * @brief Calculates the buffer space needed for the semaphore
+     */
+    CK_TILE_HOST_DEVICE uint32_t GetWorkSpaceSizeForSemaphore() const noexcept
+    {
+        return GetSkTiles() * sizeof(uint32_t);
+    }
+
+    /**
+     * @brief Calculates the total buffer space needed for accumulation and the 
semaphore + */ + CK_TILE_HOST_DEVICE uint32_t GetWorkSpaceSize(uint32_t acc_element_bytes) const noexcept + { + return GetWorkSpaceSizeForAcc(acc_element_bytes) + GetWorkSpaceSizeForSemaphore(); + } + + /** + * @brief Get location of intersection of tiles for reduction + */ + CK_TILE_HOST_DEVICE uint32_t GetTileIntersections(uint32_t tiles_, + const mdiv& equiv_tiles_) const noexcept + { + uint32_t tile_idx_ = tiles_ == 0 ? 0 : (tiles_ - 1); + uint32_t max_equiv_tiles_ = equiv_tiles_.get() - 1; + uint32_t quo_, rem_; + equiv_tiles_.divmod(tile_idx_, quo_, rem_); + return quo_ * max_equiv_tiles_ + rem_; + } + + /** + * @brief Calculate the number of tiles needed for the number of sk blocks + */ + CK_TILE_HOST_DEVICE uint32_t GetTilesCoverSkBlock(uint32_t num_sk_blocks_, + uint32_t iters_per_sk_block_) const noexcept + { + return k_iters_per_tile.div(num_sk_blocks_ * iters_per_sk_block_ + k_iters_per_tile.get() - + 1); + } + + /** + * @brief Calculate the amount of total accumulation buffers required for stream-k + */ + CK_TILE_HOST_DEVICE uint32_t GetTotalAccBuffers() const noexcept + { + uint32_t tiles_cover_big_blocks = + GetTilesCoverSkBlock(sk_num_big_blocks, k_iters_per_big_block); + uint32_t tiles_cover_little_blocks = + GetTilesCoverSkBlock(sk_num_blocks - sk_num_big_blocks, k_iters_per_big_block - 1); + + uint32_t total_intersec_big = GetTileIntersections(tiles_cover_big_blocks, equiv_tiles_big); + uint32_t total_intersec_little = + GetTileIntersections(tiles_cover_little_blocks, equiv_tiles_little); + + return sk_num_blocks + total_intersec_big + total_intersec_little; + } + + /** + * @brief Calculate offset based on tile index for big/little tiles + */ + CK_TILE_DEVICE uint32_t GetAccBufferOffsetFromTile(uint32_t tile_idx_) const noexcept + { + uint32_t tiles_cover_big_blocks = + GetTilesCoverSkBlock(sk_num_big_blocks, k_iters_per_big_block); + if(tile_idx_ < tiles_cover_big_blocks) + { + uint32_t touched_sk_blocks = + (tile_idx_ * k_iters_per_tile.get() + 
k_iters_per_big_block - 1) / + k_iters_per_big_block; + uint32_t current_intersec = GetTileIntersections(tile_idx_, equiv_tiles_big); + return touched_sk_blocks + current_intersec; + } + else + { + uint32_t iters_per_little_sk_block = k_iters_per_big_block - 1; + uint32_t tile_idx_little_reverse = GetSkTiles() - tile_idx_; + uint32_t touched_sk_blocks = + (tile_idx_little_reverse * k_iters_per_tile.get() + iters_per_little_sk_block - 1) / + iters_per_little_sk_block; + uint32_t current_intersec = + GetTileIntersections(tile_idx_little_reverse, equiv_tiles_little); + return GetTotalAccBuffers() - (touched_sk_blocks + current_intersec); + } + } + + /** + * @brief Calculate offset based on block_idx index for big/little streamk blocks + */ + CK_TILE_DEVICE uint32_t GetAccBufferOffsetFromBlock(uint32_t block_idx_) const noexcept + { + uint32_t iters_per_big_sk_block = k_iters_per_big_block; + uint32_t iters_per_little_sk_block = k_iters_per_big_block - 1; + if(block_idx_ < sk_num_big_blocks) + { + uint32_t touched_tiles = k_iters_per_tile.div(block_idx_ * iters_per_big_sk_block + + k_iters_per_tile.get() - 1); + uint32_t current_intersec = GetTileIntersections(touched_tiles, equiv_tiles_big); + return block_idx_ + current_intersec; + } + else + { + uint32_t block_idx_little_reverse = sk_num_blocks - block_idx_; + uint32_t touched_tiles = k_iters_per_tile.div( + block_idx_little_reverse * iters_per_little_sk_block + k_iters_per_tile.get() - 1); + uint32_t current_intersec = GetTileIntersections(touched_tiles, equiv_tiles_little); + return GetTotalAccBuffers() - (block_idx_little_reverse + current_intersec); + } + } + + // Getters for problem dimensions + CK_TILE_HOST_DEVICE uint32_t GetNumTileM() const noexcept { return num_tile_m_; } + CK_TILE_HOST_DEVICE uint32_t GetNumTileN() const noexcept { return num_tile_n_; } + CK_TILE_HOST_DEVICE uint32_t GetNumTileK() const noexcept { return num_tile_k_; } + + uint32_t sk_num_blocks; + uint32_t sk_num_big_blocks; + uint32_t 
dp_start_block_idx; + uint32_t reduction_start_block_idx; + uint32_t k_iters_per_big_block; + mdiv2 n_tiles; + mdiv k_iters_per_tile; + mdiv equiv_tiles_big; // for reduction + mdiv equiv_tiles_little; // for reduction + + private: + uint32_t M_, N_, K_; + uint32_t num_tile_m_, num_tile_n_, num_tile_k_; +}; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp index a05e7b2ad0..77c431e49c 100644 --- a/include/ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp @@ -9,15 +9,6 @@ namespace ck_tile { -enum StreamKReductionStrategy : uint32_t -{ - /// @brief Workgroups atomically add their results to the C tensor - Atomic = 0u, - /// @brief For a given tile in the C tensor, one workgroup accumulates results of other - /// contributing workgroups - Reduction = 1u -}; - /// @brief The Stream K GEMM kernel host arguments. /// /// @par Overview @@ -37,7 +28,7 @@ struct StreamKHostArgs : public ck_tile::UniversalGemmHostArgs<> index_t stride_B_, index_t stride_C_, StreamKReductionStrategy reduction_strategy_, - index_t num_sk_blocks_ = -1) + uint32_t num_sk_blocks_ = 0xffffffff) : UniversalGemmHostArgs<>({a_ptr_}, {b_ptr_}, {/*ds_ptr*/}, @@ -56,7 +47,7 @@ struct StreamKHostArgs : public ck_tile::UniversalGemmHostArgs<> } ck_tile::StreamKReductionStrategy reduction_strategy; - index_t num_sk_blocks; + uint32_t num_sk_blocks; }; template @@ -103,7 +94,7 @@ struct StreamKKernel /// @brief The strategy used by work groups to compute final results in C tensor. StreamKReductionStrategy reduction_strategy; /// @brief The number of stream k blocks. - index_t num_sk_blocks; + uint32_t num_sk_blocks; /// @brief A pointer to a buffer in device memory for accumulating partial via reduction /// strategy. 
void* workspace_ptr;
@@ -152,29 +143,32 @@ struct StreamKKernel
 
     CK_TILE_HOST static StreamKKernelArgs MakeKernelArgs(const StreamKHostArgs& host_args)
     {
-        index_t occupancy = static_cast<index_t>(Occupancy());
-        index_t num_cu = static_cast<index_t>(NumCU());
+        uint32_t occupancy = static_cast<uint32_t>(Occupancy());
+        uint32_t num_cu = static_cast<uint32_t>(NumCU());
 
-        return StreamKKernelArgs{
-            {host_args.as_ptr,
-             host_args.bs_ptr,
-             host_args.ds_ptr,
-             host_args.e_ptr,
-             host_args.M,
-             host_args.N,
-             host_args.K,
-             host_args.stride_As,
-             host_args.stride_Bs,
-             host_args.stride_Ds,
-             host_args.stride_E,
-             host_args.k_batch},
-            host_args.reduction_strategy,
-            host_args.num_sk_blocks,
-            // The workspace pointer is set to nullptr because we must first
-            // instantiate the TilePartitioner to get the necessary size
-            /*workspace_ptr =*/nullptr,
-            TilePartitioner{
-                host_args.M, host_args.N, host_args.K, num_cu, occupancy, host_args.num_sk_blocks}};
+        return StreamKKernelArgs{{host_args.as_ptr,
+                                  host_args.bs_ptr,
+                                  host_args.ds_ptr,
+                                  host_args.e_ptr,
+                                  host_args.M,
+                                  host_args.N,
+                                  host_args.K,
+                                  host_args.stride_As,
+                                  host_args.stride_Bs,
+                                  host_args.stride_Ds,
+                                  host_args.stride_E,
+                                  host_args.k_batch},
+                                 host_args.reduction_strategy,
+                                 host_args.num_sk_blocks,
+                                 // The workspace pointer is set to nullptr because we must first
+                                 // instantiate the TilePartitioner to get the necessary size
+                                 /*workspace_ptr =*/nullptr,
+                                 TilePartitioner{static_cast<uint32_t>(host_args.M),
+                                                 static_cast<uint32_t>(host_args.N),
+                                                 static_cast<uint32_t>(host_args.K),
+                                                 num_cu,
+                                                 occupancy,
+                                                 host_args.num_sk_blocks}};
     }
 
     CK_TILE_HOST static bool