Stream-K Tile Partitioner Base Class with Tests

To better align with the original Stream-K paper, this change implements
a new Stream-K tile partitioner base class. This class will handle the
Stream-K setup that is common to both a persistent and non-persistent DP
section. A later change will implement derived classes to handle the
differences between persistent and non-persistent DP.

This change also includes unit tests for the base tile partitioner.
This commit is contained in:
Emily Martins
2025-10-08 15:53:19 +00:00
committed by Emily Martins
parent 2d1c9e28e2
commit f87f768d16
6 changed files with 1073 additions and 1 deletions

View File

@@ -0,0 +1,207 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
namespace ck_tile {
/**
* @brief Stream-K tile partitioner base class.
*
* This partitioner is responsible for mapping workgroups to tiles in the C tensor
* for the Stream-K algorithm.
*
* @tparam BlockGemmShapeType A class providing basic GEMM parameters.
* @tparam ReductionStrategy An enum that defines the reduction strategy for the results in the C
* Tensor.
*/
template <typename BlockGemmShapeType,
StreamKReductionStrategy ReductionStrategy = StreamKReductionStrategy::Atomic>
struct StreamKTilePartitionerBase
{
using BlockGemmShape = BlockGemmShapeType;
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr StreamKReductionStrategy StreamKReductionStrategy = ReductionStrategy;
StreamKTilePartitionerBase(index_t m, index_t n, index_t k, index_t grid);
private:
/**
* @brief Calculates the total space needed for the partials buffer.
*
* @param acc_element_bytes The number of bytes for the accumulator data type used in the GEMM.
* @return index_t The number of bytes needed for the partials buffer.
*/
CK_TILE_HOST index_t get_partials_buffer_size(index_t acc_element_bytes) const noexcept;
/**
* @brief Calculates the total space needed for the flags buffer.
*
* @return index_t The number of bytes needed for the flags buffer.
*/
CK_TILE_HOST index_t get_flags_buffer_size() const noexcept;
public:
/**
* @brief Calculates the start and end iteration given the cta_idx.
*
* @param iter Reference to an index_t; will be set to the starting iteration by the
* function.
* @param iter_end Reference to an index_t; will be set to the non-inclusive end iteration by
* the function.
* @param cta_idx The current Stream-K workgroup's index.
* @note It is assumed that the first Stream-K workgroup has a `cta_idx` of zero. If a
* non-persistent DP section is used, then a Stream-K workgroup's `cta_idx` should be something
* like `blockIdx.x` minus number of DP workgroups.
*/
CK_TILE_DEVICE void
get_iter_boundaries(index_t& iter, index_t& iter_end, index_t cta_idx) const noexcept;
/**
* @brief Calculates the 1D tile index in the C tensor for a workgroup.
*
* @param iter The starting iteration.
* @return index_t The 1D tile index.
*/
CK_TILE_DEVICE index_t get_tile_index(index_t iter) const noexcept;
/**
* @brief Calculates the starting and ending tile boundaries for the given 1D tile index.
*
* @param tile_iter Reference to an index_t; will be set to the tile's start iteration by
* the function.
* @param tile_iter_end Reference to an index_t; will be set to the non-inclusive tile's end
* iteration by the function.
* @param tile_idx The 1D C tensor tile index for the workgroup.
*/
CK_TILE_DEVICE void get_tile_boundaries(index_t& tile_iter,
index_t& tile_iter_end,
index_t tile_idx) const noexcept;
/**
* @brief Calculates the workgroup's starting iteration that is local to a tile.
*
* @param iter The starting iteration.
* @return index_t The local starting iteration. The value is in range [0, `iters_per_tile_`).
* @note Assumes `iter` >= `tile_iter`.
*/
CK_TILE_DEVICE static index_t get_local_iter(index_t iter, index_t tile_iter) noexcept;
/**
* @brief Calculates the workgroup's non-inclusive end iteration that is local to a tile.
*
* @param tile_iter The starting tile iteration.
* @param iter_end The non-inclusive end iteration.
* @param tile_iter_end The non-inclusive end iteration of the tile.
* @return index_t The local non-inclusive end iteration.
* @note Assumes `iter_end` >= `tile_iter` and `tile_iter_end` >= `tile_iter`.
*/
CK_TILE_DEVICE static index_t
get_local_iter_end(index_t tile_iter, index_t iter_end, index_t tile_iter_end) noexcept;
/**
* @brief Calculates the workgroups 2D tile index in the C tensor given the 1D tile index.
*
* @param tile_idx The 1D tile index in the C tensor for the workgroup.
* @return index_t The corresponding 2D tile index in the C tensor for the workgroup.
*/
CK_TILE_DEVICE auto
get_output_tile_index(index_t tile_idx) const noexcept -> tuple<index_t, index_t>;
/**
* @brief Calculates the total space needed for the partials and flags buffers.
*
* @param acc_element_bytes The number of bytes for the accumulator data type used in the GEMM.
* @return index_t The number of bytes needed for the partials and flags buffers.
*/
CK_TILE_HOST index_t get_workspace_size(index_t acc_element_bytes) const noexcept;
CK_TILE_HOST_DEVICE index_t get_num_tiles() const noexcept;
CK_TILE_HOST_DEVICE index_t get_grid() const noexcept;
CK_TILE_HOST_DEVICE index_t get_dp_tiles() const noexcept;
CK_TILE_HOST_DEVICE index_t get_sk_tiles() const noexcept;
CK_TILE_HOST_DEVICE index_t get_sk_ctas() const noexcept;
CK_TILE_HOST_DEVICE index_t get_total_sk_iters() const noexcept;
CK_TILE_HOST_DEVICE index_t get_iters_per_tile() const noexcept;
CK_TILE_HOST_DEVICE index_t get_iters_per_sk_cta() const noexcept;
CK_TILE_HOST_DEVICE index_t get_extra_iters() const noexcept;
CK_TILE_HOST_DEVICE index_t get_total_dp_iters() const noexcept;
CK_TILE_HOST_DEVICE index_t get_n() const noexcept;
protected:
/**
* @brief The number of macro tiles in the C tensor.
*/
index_t num_tiles_;
/**
* @brief The maximum number of active workgroups; this is assumed to be number of CUs *
* occupancy.
*/
index_t grid_;
/**
* @brief The number of tiles in the C tensor that will use the data-parallel (DP) approach.
*/
index_t dp_tiles_;
private:
/**
* @brief The number of full tiles assigned to each `sk_cta` when performing DP + 2 Tile SK.
*/
index_t full_tiles_ = 1;
/**
* @brief The number of tiles in the C tensor that will use the Stream-K approach.
*/
index_t sk_tiles_;
/**
* @brief The number of workgroups that will participate in Stream-K in the `sk_tiles_`.
*/
index_t sk_ctas_;
/**
* @brief The total number of Stream-K iterations.
*/
index_t total_sk_iters_;
/**
* @brief The total number of iterations per tile in the C tensor. In other words, this is the
* total number of macro tiles along the K dimension of A and B.
*/
index_t iters_per_tile_;
/**
* @brief The total number of Stream-K iterations for each `sk_cta`. This is the lower bound
* (i.e., all `sk_ctas_` are guaranteed to perform at least this many iterations).
*/
index_t iters_per_sk_cta_;
/**
* @brief The remainder resulting from `total_sk_iters_` divided by `sk_ctas_`. When this is
* non-zero, the first `extra_iters_` `sk_ctas_` will get one additional iteration assigned to
* them; such work groups will perform (`iters_per_sk_cta_` + 1) iterations.
*/
index_t extra_iters_;
/**
* @brief The total number of DP iterations.
*/
index_t total_dp_iters_;
/**
* @brief The n dimension for the GEMM problem.
*/
index_t n_;
};
} // namespace ck_tile
#include "streamk_gemm_tile_partitioner_impl.hpp"

View File

@@ -0,0 +1,214 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
namespace ck_tile {
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategy>
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategy>::StreamKTilePartitionerBase(
index_t m, index_t n, index_t k, index_t grid)
: grid_{grid}, n_{n}
{
iters_per_tile_ = integer_divide_ceil(k, KPerBlock);
num_tiles_ = integer_divide_ceil(m, MPerBlock) * integer_divide_ceil(n_, NPerBlock);
bool big_enough = num_tiles_ > grid_;
index_t remainder_tiles = num_tiles_ % grid_;
if(remainder_tiles)
{
sk_tiles_ = big_enough ? full_tiles_ * grid_ + (num_tiles_ % grid_) : num_tiles_;
sk_tiles_ = min(num_tiles_, sk_tiles_);
sk_ctas_ = grid_;
total_sk_iters_ = sk_tiles_ * iters_per_tile_;
// If there still isn't enough work to saturate all CUs, then just revert to DP only.
if(total_sk_iters_ < grid_)
{
sk_tiles_ = 0;
sk_ctas_ = 0;
total_sk_iters_ = 0;
}
}
else // Full DP (i.e., no Stream-K)
{
sk_tiles_ = 0;
sk_ctas_ = 0;
total_sk_iters_ = 0;
}
iters_per_sk_cta_ = sk_ctas_ ? total_sk_iters_ / sk_ctas_ : 0;
extra_iters_ = sk_ctas_ ? total_sk_iters_ % sk_ctas_ : 0;
dp_tiles_ = num_tiles_ - sk_tiles_;
total_dp_iters_ = dp_tiles_ * iters_per_tile_;
}
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategy>
CK_TILE_HOST index_t
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategy>::get_partials_buffer_size(
index_t acc_element_bytes) const noexcept
{
return MPerBlock * NPerBlock * acc_element_bytes * sk_ctas_;
}
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategy>
CK_TILE_HOST index_t
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategy>::get_flags_buffer_size()
const noexcept
{
return sizeof(index_t) * sk_ctas_;
}
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategy>
CK_TILE_DEVICE void
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategy>::get_iter_boundaries(
index_t& iter, index_t& iter_end, index_t cta_idx) const noexcept
{
index_t extra_iters__before_me = ck_tile::min(cta_idx, extra_iters_);
iter = total_dp_iters_ + cta_idx * iters_per_sk_cta_ + extra_iters__before_me;
iter_end = iter + iters_per_sk_cta_ + (cta_idx < extra_iters_);
}
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategy>
CK_TILE_DEVICE index_t
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategy>::get_tile_index(
index_t iter) const noexcept
{
return iter / iters_per_tile_;
}
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategy>
CK_TILE_DEVICE void
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategy>::get_tile_boundaries(
index_t& tile_iter, index_t& tile_iter_end, index_t tile_idx) const noexcept
{
tile_iter = tile_idx * iters_per_tile_;
tile_iter_end = tile_iter + iters_per_tile_;
}
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategy>
CK_TILE_DEVICE /* static */ index_t
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategy>::get_local_iter(
index_t iter, index_t tile_iter) noexcept
{
return iter - tile_iter;
}
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategy>
CK_TILE_DEVICE /* static */ index_t
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategy>::get_local_iter_end(
index_t tile_iter, index_t iter_end, index_t tile_iter_end) noexcept
{
return ck_tile::min(iter_end, tile_iter_end) - tile_iter;
}
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategy>
CK_TILE_DEVICE auto
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategy>::get_output_tile_index(
index_t tile_idx) const noexcept -> tuple<index_t, index_t>
{
const index_t n_macro_tiles = integer_divide_ceil(n_, NPerBlock);
const index_t im = amd_wave_read_first_lane(tile_idx / n_macro_tiles);
const index_t in = amd_wave_read_first_lane(tile_idx - im * n_macro_tiles);
return make_tuple(im, in);
}
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategy>
CK_TILE_HOST index_t
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategy>::get_workspace_size(
index_t acc_element_bytes) const noexcept
{
if constexpr(StreamKReductionStrategy == StreamKReductionStrategy::Reduction)
{
return get_partials_buffer_size(acc_element_bytes) + get_flags_buffer_size();
}
else // ReductionStrategy is Atomics
{
return 0;
}
}
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategy>
CK_TILE_HOST_DEVICE index_t
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategy>::get_num_tiles() const noexcept
{
return num_tiles_;
}
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategy>
CK_TILE_HOST_DEVICE index_t
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategy>::get_grid() const noexcept
{
return grid_;
}
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategy>
CK_TILE_HOST_DEVICE index_t
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategy>::get_dp_tiles() const noexcept
{
return dp_tiles_;
}
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategy>
CK_TILE_HOST_DEVICE index_t
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategy>::get_sk_tiles() const noexcept
{
return sk_tiles_;
}
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategy>
CK_TILE_HOST_DEVICE index_t
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategy>::get_sk_ctas() const noexcept
{
return sk_ctas_;
}
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategy>
CK_TILE_HOST_DEVICE index_t
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategy>::get_total_sk_iters()
const noexcept
{
return total_sk_iters_;
}
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategy>
CK_TILE_HOST_DEVICE index_t
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategy>::get_iters_per_tile()
const noexcept
{
return iters_per_tile_;
}
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategy>
CK_TILE_HOST_DEVICE index_t
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategy>::get_iters_per_sk_cta()
const noexcept
{
return iters_per_sk_cta_;
}
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategy>
CK_TILE_HOST_DEVICE index_t
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategy>::get_extra_iters() const noexcept
{
return extra_iters_;
}
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategy>
CK_TILE_HOST_DEVICE index_t
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategy>::get_total_dp_iters()
const noexcept
{
return total_dp_iters_;
}
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategy>
CK_TILE_HOST_DEVICE index_t
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategy>::get_n() const noexcept
{
return n_;
}
} // namespace ck_tile