mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 20:51:23 +00:00
[CK TILE] Stream-K tile partitioner (#2708)
* initial commit for skeleton code * replaced skeleton code with old streamk b2c map functions from old CK, still need to clean up the code * fixed up code to match CK Tile convention: data type changes, naming changes, etc. * change for num_sk_blocks data type * formatting fix * minor fixes * moved reduction argument to template * resolved comments from PR review: standardizing naming, pruning unneeded code * resolve errors from merge of device op PR: moved enum to common file * switching to uint32_t due to implementation constraints: divmod only takes uint32_t and mixing signed and unsigned types causes problems * unsigned type fix * add const qualifier * added documentation for template parameters * documentation edit
This commit is contained in:
@@ -9,6 +9,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -364,4 +365,454 @@ struct GemmSpatiallyLocalTilePartitioner
|
||||
index_t N;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Stream-K tile partitioner that dynamically balances work across workgroups
|
||||
*
|
||||
* This partitioner is responsible for mapping workgroups to tiles in the C tensor
|
||||
* for the Stream-K algorithm which decomposes the GEMM problem
|
||||
* into smaller work units and distributes them more evenly across available blocks,
|
||||
* improving load balancing especially for cases where the K dimension is large.
|
||||
*
|
||||
* @tparam BlockGemmShapeType A class providing basic GEMM parameters.
|
||||
* @tparam ReductionStrategy A class that defines the reduction strategy for the results in
|
||||
* the C Tensor.
|
||||
* @tparam TileSwizzleSubM A value that defines the size of the swizzle group along the m
|
||||
* dimension, where the swizzle group denotes consecutive tiles down a column. For instance a
|
||||
* swizzle group of 8 denotes tiles 0, 1, ..., 7, map to tiles [0,0], [1,0], ..., [7,0] in the C
|
||||
* tensor.
|
||||
*/
|
||||
template <typename BlockGemmShapeType,
|
||||
StreamKReductionStrategy ReductionStrategy = ck_tile::StreamKReductionStrategy::Atomic,
|
||||
uint32_t TileSwizzleSubM = 8>
|
||||
struct StreamKTilePartitioner
|
||||
{
|
||||
using BlockGemmShape = BlockGemmShapeType;
|
||||
|
||||
static constexpr uint32_t MPerBlock = BlockGemmShape::kM;
|
||||
static constexpr uint32_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr uint32_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
CK_TILE_HOST_DEVICE StreamKTilePartitioner() noexcept = delete;
|
||||
|
||||
/**
|
||||
* @brief Construct Stream-K tile partitioner with problem dimensions
|
||||
*/
|
||||
CK_TILE_HOST_DEVICE StreamKTilePartitioner(uint32_t M,
|
||||
uint32_t N,
|
||||
uint32_t K,
|
||||
uint32_t num_cu,
|
||||
uint32_t occupancy,
|
||||
uint32_t sk_blocks = 0xffffffff) noexcept
|
||||
: M_(M), N_(N), K_(K)
|
||||
{
|
||||
num_tile_m_ = integer_divide_ceil(M, MPerBlock);
|
||||
num_tile_n_ = integer_divide_ceil(N, NPerBlock);
|
||||
num_tile_k_ = integer_divide_ceil(K, KPerBlock);
|
||||
|
||||
constexpr uint32_t min_k_iters_per_sk_block = 2;
|
||||
uint32_t num_tiles = num_tile_m_ * num_tile_n_;
|
||||
k_iters_per_tile = mdiv(num_tile_k_);
|
||||
|
||||
// one cu can hold one wg at one time, from the whole cZ's point of view
|
||||
// if number of wg is same as num_cu, we call it 1 dispatch
|
||||
// if number of wg is 2x num_cu, we call it 2 dispatches.
|
||||
// one dispatch can deliver wg same as num_cu (full dispatch), or less than num_cu (partial
|
||||
// dispatch)
|
||||
//
|
||||
const uint32_t full_dispatches = num_tiles / num_cu;
|
||||
const uint32_t full_dispatch_tiles = full_dispatches * num_cu;
|
||||
const uint32_t partial_dispatch_tiles = num_tiles - full_dispatch_tiles;
|
||||
|
||||
uint32_t sk_occupancy = occupancy;
|
||||
uint32_t dp_tiles = full_dispatch_tiles;
|
||||
uint32_t sk_tiles = partial_dispatch_tiles;
|
||||
|
||||
if(full_dispatches < occupancy)
|
||||
{
|
||||
// in this case, we allocate all blocks as sk blocks
|
||||
// sk_occupancy = occupancy - full_dispatches;
|
||||
sk_occupancy = 1;
|
||||
dp_tiles = full_dispatch_tiles;
|
||||
sk_tiles = partial_dispatch_tiles;
|
||||
}
|
||||
else if((occupancy > 1) && (full_dispatches % occupancy == occupancy - 1))
|
||||
{
|
||||
// e.g. occupancy = 2, full_dispatches = 3, 5, 7 ...
|
||||
// occupancy = 3, full_dispatches = 5, 8, 11 ...
|
||||
// occupancy = 4, full_dispatches = 7, 11 ...
|
||||
sk_occupancy = 1; // left 1 slot for sk occupancy
|
||||
dp_tiles = full_dispatch_tiles;
|
||||
sk_tiles = partial_dispatch_tiles;
|
||||
}
|
||||
else
|
||||
{
|
||||
// otherwise, we reduce 1 dispatch from dp, together with partial dispatch,
|
||||
// to construct sk dispatch
|
||||
sk_occupancy = occupancy - ((full_dispatches - 1) % occupancy);
|
||||
dp_tiles = full_dispatch_tiles - num_cu;
|
||||
sk_tiles = partial_dispatch_tiles + num_cu;
|
||||
}
|
||||
|
||||
// uint32_t dp_iters_per_block = k_iters_per_tile.get();
|
||||
uint32_t sk_total_iters = k_iters_per_tile.get() * sk_tiles;
|
||||
uint32_t dp_num_blocks = 0;
|
||||
|
||||
{
|
||||
const uint32_t min_sk_tiles = (sk_tiles >= num_cu) ? num_cu : (sk_tiles + 1);
|
||||
const uint32_t max_sk_tiles =
|
||||
(sk_tiles >= num_cu) ? num_cu * sk_occupancy
|
||||
: min(num_cu, sk_total_iters / min_k_iters_per_sk_block);
|
||||
|
||||
// if use dp for sk-block, how many iters do we need
|
||||
const uint32_t dp_for_sk_iters = k_iters_per_tile.get();
|
||||
|
||||
uint32_t best_sk_score =
|
||||
std::numeric_limits<int>::max(); // we need to find the smallest sk iters
|
||||
for(uint32_t tentative_sk_blocks = min_sk_tiles; tentative_sk_blocks < max_sk_tiles;
|
||||
tentative_sk_blocks++)
|
||||
{
|
||||
const uint32_t tentative_sk_iters_per_block =
|
||||
(sk_total_iters + tentative_sk_blocks - 1) / tentative_sk_blocks;
|
||||
const uint32_t tentative_sk_iters = tentative_sk_iters_per_block;
|
||||
const uint32_t sk_blocks_per_tile = (tentative_sk_blocks + sk_tiles - 1) / sk_tiles;
|
||||
|
||||
// the more sk_blocks_per_tile, the worse the overhead
|
||||
uint32_t cross_sk_blocks_overhead = sk_blocks_per_tile;
|
||||
if(tentative_sk_blocks % sk_tiles != 0)
|
||||
{
|
||||
// penalty for uneven divide
|
||||
cross_sk_blocks_overhead +=
|
||||
sk_blocks_per_tile * tentative_sk_iters_per_block / 50;
|
||||
}
|
||||
|
||||
const uint32_t tentative_sk_score = tentative_sk_iters + cross_sk_blocks_overhead;
|
||||
|
||||
if(tentative_sk_score < best_sk_score)
|
||||
{
|
||||
best_sk_score = tentative_sk_score;
|
||||
sk_num_blocks = tentative_sk_blocks;
|
||||
}
|
||||
}
|
||||
|
||||
if(best_sk_score >= dp_for_sk_iters)
|
||||
{
|
||||
sk_num_blocks = 0;
|
||||
}
|
||||
|
||||
// give a chance to control num of sk blocks
|
||||
sk_num_blocks = sk_blocks != 0xffffffff ? sk_blocks : sk_num_blocks;
|
||||
|
||||
if(sk_num_blocks == 0)
|
||||
{
|
||||
sk_num_big_blocks = 0;
|
||||
k_iters_per_big_block = 0;
|
||||
|
||||
dp_num_blocks = num_tiles; // all tile to be dp block
|
||||
dp_start_block_idx = 0;
|
||||
sk_total_iters = 0; // clear this tiles
|
||||
}
|
||||
else
|
||||
{
|
||||
// k_iters_per_sk_block is the floor of avg each ck block loop over tiles.
|
||||
// we need to decide how many iters for each sk block
|
||||
// let m = k_iters_per_sk_block
|
||||
// some of the sk block (little) will cover m iters, some (big) will cover m+1
|
||||
// we have
|
||||
// 1) l + b = sk_blocks
|
||||
// 2) l * m + b * (m + 1) = sk_total_iters
|
||||
// => (l + b) * m + b = sk_total_iters
|
||||
// => sk_blocks * m + b = sk_total_iters
|
||||
// => b = sk_total_iters - m * sk_blocks
|
||||
// NOTE: big could be zero
|
||||
const uint32_t k_iters_per_sk_block = sk_total_iters / sk_num_blocks;
|
||||
sk_num_big_blocks = sk_total_iters - k_iters_per_sk_block * sk_num_blocks;
|
||||
k_iters_per_big_block = k_iters_per_sk_block + 1;
|
||||
|
||||
dp_num_blocks = dp_tiles;
|
||||
dp_start_block_idx = (sk_num_blocks + num_cu - 1) / num_cu * num_cu;
|
||||
}
|
||||
}
|
||||
n_tiles = mdiv2(num_tile_n_);
|
||||
reduction_start_block_idx = dp_start_block_idx + dp_num_blocks;
|
||||
|
||||
if constexpr(ReductionStrategy == ck_tile::StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
const uint32_t upper_big = lcm(k_iters_per_big_block, k_iters_per_tile.get());
|
||||
const uint32_t upper_little = lcm(k_iters_per_big_block - 1, k_iters_per_tile.get());
|
||||
equiv_tiles_big = mdiv(upper_big / k_iters_per_tile.get());
|
||||
equiv_tiles_little = mdiv(upper_little / k_iters_per_tile.get());
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Calculate optimal grid size for Stream-K
|
||||
*/
|
||||
CK_TILE_HOST auto GridSize() const noexcept -> dim3
|
||||
{
|
||||
if constexpr(ReductionStrategy == ck_tile::StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
return dim3(reduction_start_block_idx + GetSkTiles(), 1, 1);
|
||||
}
|
||||
else
|
||||
return dim3(reduction_start_block_idx, 1, 1);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Calculate number of loop iterations over K dimension for given work unit
|
||||
*/
|
||||
CK_TILE_HOST_DEVICE static auto GetLoopNum(uint32_t K) noexcept -> uint32_t
|
||||
{
|
||||
return integer_divide_ceil(K, KPerBlock); // Stream-K processes one K-slice at a time
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Get output tile index for standard 2D mapping (compatibility)
|
||||
*/
|
||||
CK_TILE_DEVICE auto
|
||||
GetOutputTileIndex(uint32_t tile_idx) const noexcept -> tuple<uint32_t, uint32_t>
|
||||
{
|
||||
uint32_t m_tile_idx, n_tile_idx;
|
||||
n_tiles.divmod(tile_idx, num_tile_n_, m_tile_idx, n_tile_idx);
|
||||
|
||||
// swizzle tile
|
||||
|
||||
uint32_t tile_swizzle_sub_m_rem = num_tile_m_ % TileSwizzleSubM;
|
||||
|
||||
const auto sub_m_adapt = (m_tile_idx < (num_tile_m_ - tile_swizzle_sub_m_rem))
|
||||
? TileSwizzleSubM
|
||||
: tile_swizzle_sub_m_rem;
|
||||
|
||||
uint32_t m_tile_idx_sub0, m_tile_idx_sub1;
|
||||
m_tile_idx_sub0 = m_tile_idx / TileSwizzleSubM;
|
||||
m_tile_idx_sub1 = m_tile_idx % TileSwizzleSubM;
|
||||
|
||||
uint32_t tile_idx_local = n_tile_idx + m_tile_idx_sub1 * num_tile_n_;
|
||||
|
||||
uint32_t m_tile_idx_with_adapt, n_tile_idx_with_adapt;
|
||||
|
||||
n_tile_idx_with_adapt = tile_idx_local / sub_m_adapt;
|
||||
m_tile_idx_with_adapt = tile_idx_local % sub_m_adapt;
|
||||
return make_tuple(m_tile_idx_with_adapt + m_tile_idx_sub0 * TileSwizzleSubM,
|
||||
n_tile_idx_with_adapt);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Get work range for a given block ID
|
||||
*/
|
||||
CK_TILE_DEVICE void
|
||||
GetBlockItr(uint32_t block_idx, uint32_t& iter_start, uint32_t& iter_end) const noexcept
|
||||
{
|
||||
if(block_idx < sk_num_big_blocks)
|
||||
{
|
||||
iter_start = block_idx * k_iters_per_big_block;
|
||||
iter_end = iter_start + k_iters_per_big_block;
|
||||
}
|
||||
else if(block_idx < sk_num_blocks)
|
||||
{
|
||||
iter_start = (sk_num_big_blocks * k_iters_per_big_block) +
|
||||
(block_idx - sk_num_big_blocks) * (k_iters_per_big_block - 1);
|
||||
iter_end = iter_start + (k_iters_per_big_block - 1);
|
||||
}
|
||||
else if(block_idx >= dp_start_block_idx)
|
||||
{
|
||||
uint32_t sk_total_iters = GetSkTotalIters();
|
||||
uint32_t dp_iters_per_block = k_iters_per_tile.get();
|
||||
iter_start = sk_total_iters + (block_idx - dp_start_block_idx) * dp_iters_per_block;
|
||||
iter_end = iter_start + dp_iters_per_block;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Get total number of iterations for sk tiles
|
||||
*/
|
||||
CK_TILE_HOST_DEVICE uint32_t GetSkTotalIters() const noexcept
|
||||
{
|
||||
uint32_t sk_total_iters = sk_num_big_blocks * k_iters_per_big_block +
|
||||
(sk_num_blocks - sk_num_big_blocks) * (k_iters_per_big_block - 1);
|
||||
return sk_total_iters;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Get total number of sk tiles
|
||||
*/
|
||||
CK_TILE_HOST_DEVICE uint32_t GetSkTiles() const noexcept
|
||||
{
|
||||
// tiles for sk
|
||||
uint32_t sk_total_iters = GetSkTotalIters();
|
||||
return k_iters_per_tile.div(sk_total_iters);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Get length of loop iterations for stream-k loop
|
||||
*/
|
||||
CK_TILE_DEVICE uint32_t GetCurrentIterLength(uint32_t iter_start,
|
||||
uint32_t iter_end,
|
||||
uint32_t total_iter_length) const noexcept
|
||||
{
|
||||
uint32_t iter_length_mod, iter_length_quo /*unused*/;
|
||||
k_iters_per_tile.divmod(iter_end, iter_length_quo, iter_length_mod);
|
||||
uint32_t total_iter_length_val = static_cast<uint32_t>(total_iter_length);
|
||||
uint32_t current_iter_length =
|
||||
min(iter_length_mod == 0 ? (iter_end - iter_start) : iter_length_mod,
|
||||
total_iter_length_val);
|
||||
return current_iter_length;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Get index of tile during a specified iteration
|
||||
*/
|
||||
CK_TILE_DEVICE uint32_t GetTileIdx(uint32_t iter) const noexcept
|
||||
{
|
||||
return k_iters_per_tile.div(iter);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Get index of tile during a specified iteration
|
||||
*/
|
||||
CK_TILE_DEVICE void
|
||||
GetTileIdxWithOffset(uint32_t iter, uint32_t& tile_idx, uint32_t& iter_offset) const noexcept
|
||||
{
|
||||
uint32_t tile_idx_val = static_cast<uint32_t>(tile_idx);
|
||||
uint32_t iter_offset_val = static_cast<uint32_t>(iter_offset);
|
||||
k_iters_per_tile.divmod(iter, tile_idx_val, iter_offset_val);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Calculates the buffer space needed for accumulation
|
||||
*/
|
||||
CK_TILE_HOST_DEVICE uint32_t GetWorkSpaceSizeForAcc(uint32_t acc_element_bytes) const noexcept
|
||||
{
|
||||
static constexpr uint32_t alignment = 128;
|
||||
uint32_t acc_buffer_bytes =
|
||||
MPerBlock * NPerBlock * GetTotalAccBuffers() * acc_element_bytes;
|
||||
return (acc_buffer_bytes + alignment - 1) / alignment * alignment;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Calculates the buffer space needed for the semaphore
|
||||
*/
|
||||
CK_TILE_HOST_DEVICE uint32_t GetWorkSpaceSizeForSemaphore() const noexcept
|
||||
{
|
||||
return GetSkTiles() * sizeof(uint32_t);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Calculates the total buffer space needed for accumulation and the semaphore
|
||||
*/
|
||||
CK_TILE_HOST_DEVICE uint32_t GetWorkSpaceSize(uint32_t acc_element_bytes) const noexcept
|
||||
{
|
||||
return GetWorkSpaceSizeForAcc(acc_element_bytes) + GetWorkSpaceSizeForSemaphore();
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Get location of intersection of tiles for reduction
|
||||
*/
|
||||
CK_TILE_HOST_DEVICE uint32_t GetTileIntersections(uint32_t tiles_,
|
||||
const mdiv& equiv_tiles_) const noexcept
|
||||
{
|
||||
uint32_t tile_idx_ = tiles_ == 0 ? 0 : (tiles_ - 1);
|
||||
uint32_t max_equiv_tiles_ = equiv_tiles_.get() - 1;
|
||||
uint32_t quo_, rem_;
|
||||
equiv_tiles_.divmod(tile_idx_, quo_, rem_);
|
||||
return quo_ * max_equiv_tiles_ + rem_;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Calculate the number of tiles needed for the number of sk blocks
|
||||
*/
|
||||
CK_TILE_HOST_DEVICE uint32_t GetTilesCoverSkBlock(uint32_t num_sk_blocks_,
|
||||
uint32_t iters_per_sk_block_) const noexcept
|
||||
{
|
||||
return k_iters_per_tile.div(num_sk_blocks_ * iters_per_sk_block_ + k_iters_per_tile.get() -
|
||||
1);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Calculate the amount of total accumulation buffers required for stream-k
|
||||
*/
|
||||
CK_TILE_HOST_DEVICE uint32_t GetTotalAccBuffers() const noexcept
|
||||
{
|
||||
uint32_t tiles_cover_big_blocks =
|
||||
GetTilesCoverSkBlock(sk_num_big_blocks, k_iters_per_big_block);
|
||||
uint32_t tiles_cover_little_blocks =
|
||||
GetTilesCoverSkBlock(sk_num_blocks - sk_num_big_blocks, k_iters_per_big_block - 1);
|
||||
|
||||
uint32_t total_intersec_big = GetTileIntersections(tiles_cover_big_blocks, equiv_tiles_big);
|
||||
uint32_t total_intersec_little =
|
||||
GetTileIntersections(tiles_cover_little_blocks, equiv_tiles_little);
|
||||
|
||||
return sk_num_blocks + total_intersec_big + total_intersec_little;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Calculate offset based on tile index for big/little tiles
|
||||
*/
|
||||
CK_TILE_DEVICE uint32_t GetAccBufferOffsetFromTile(uint32_t tile_idx_) const noexcept
|
||||
{
|
||||
uint32_t tiles_cover_big_blocks =
|
||||
GetTilesCoverSkBlock(sk_num_big_blocks, k_iters_per_big_block);
|
||||
if(tile_idx_ < tiles_cover_big_blocks)
|
||||
{
|
||||
uint32_t touched_sk_blocks =
|
||||
(tile_idx_ * k_iters_per_tile.get() + k_iters_per_big_block - 1) /
|
||||
k_iters_per_big_block;
|
||||
uint32_t current_intersec = GetTileIntersections(tile_idx_, equiv_tiles_big);
|
||||
return touched_sk_blocks + current_intersec;
|
||||
}
|
||||
else
|
||||
{
|
||||
uint32_t iters_per_little_sk_block = k_iters_per_big_block - 1;
|
||||
uint32_t tile_idx_little_reverse = GetSkTiles() - tile_idx_;
|
||||
uint32_t touched_sk_blocks =
|
||||
(tile_idx_little_reverse * k_iters_per_tile.get() + iters_per_little_sk_block - 1) /
|
||||
iters_per_little_sk_block;
|
||||
uint32_t current_intersec =
|
||||
GetTileIntersections(tile_idx_little_reverse, equiv_tiles_little);
|
||||
return GetTotalAccBuffers() - (touched_sk_blocks + current_intersec);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Calculate offset based on block_idx index for big/little streamk blocks
|
||||
*/
|
||||
CK_TILE_DEVICE uint32_t GetAccBufferOffsetFromBlock(uint32_t block_idx_) const noexcept
|
||||
{
|
||||
uint32_t iters_per_big_sk_block = k_iters_per_big_block;
|
||||
uint32_t iters_per_little_sk_block = k_iters_per_big_block - 1;
|
||||
if(block_idx_ < sk_num_big_blocks)
|
||||
{
|
||||
uint32_t touched_tiles = k_iters_per_tile.div(block_idx_ * iters_per_big_sk_block +
|
||||
k_iters_per_tile.get() - 1);
|
||||
uint32_t current_intersec = GetTileIntersections(touched_tiles, equiv_tiles_big);
|
||||
return block_idx_ + current_intersec;
|
||||
}
|
||||
else
|
||||
{
|
||||
uint32_t block_idx_little_reverse = sk_num_blocks - block_idx_;
|
||||
uint32_t touched_tiles = k_iters_per_tile.div(
|
||||
block_idx_little_reverse * iters_per_little_sk_block + k_iters_per_tile.get() - 1);
|
||||
uint32_t current_intersec = GetTileIntersections(touched_tiles, equiv_tiles_little);
|
||||
return GetTotalAccBuffers() - (block_idx_little_reverse + current_intersec);
|
||||
}
|
||||
}
|
||||
|
||||
// Getters for problem dimensions
|
||||
CK_TILE_HOST_DEVICE uint32_t GetNumTileM() const noexcept { return num_tile_m_; }
|
||||
CK_TILE_HOST_DEVICE uint32_t GetNumTileN() const noexcept { return num_tile_n_; }
|
||||
CK_TILE_HOST_DEVICE uint32_t GetNumTileK() const noexcept { return num_tile_k_; }
|
||||
|
||||
uint32_t sk_num_blocks;
|
||||
uint32_t sk_num_big_blocks;
|
||||
uint32_t dp_start_block_idx;
|
||||
uint32_t reduction_start_block_idx;
|
||||
uint32_t k_iters_per_big_block;
|
||||
mdiv2 n_tiles;
|
||||
mdiv k_iters_per_tile;
|
||||
mdiv equiv_tiles_big; // for reduction
|
||||
mdiv equiv_tiles_little; // for reduction
|
||||
|
||||
private:
|
||||
uint32_t M_, N_, K_;
|
||||
uint32_t num_tile_m_, num_tile_n_, num_tile_k_;
|
||||
};
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -9,15 +9,6 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/// @brief Strategy used by Stream-K workgroups to combine their partial results
/// into the final C tensor.
enum StreamKReductionStrategy : uint32_t
{
    /// @brief Workgroups atomically add their results to the C tensor
    Atomic = 0u,
    /// @brief For a given tile in the C tensor, one workgroup accumulates results of other
    /// contributing workgroups
    Reduction = 1u
};
|
||||
|
||||
/// @brief The Stream K GEMM kernel host arguments.
|
||||
///
|
||||
/// @par Overview
|
||||
@@ -37,7 +28,7 @@ struct StreamKHostArgs : public ck_tile::UniversalGemmHostArgs<>
|
||||
index_t stride_B_,
|
||||
index_t stride_C_,
|
||||
StreamKReductionStrategy reduction_strategy_,
|
||||
index_t num_sk_blocks_ = -1)
|
||||
uint32_t num_sk_blocks_ = 0xffffffff)
|
||||
: UniversalGemmHostArgs<>({a_ptr_},
|
||||
{b_ptr_},
|
||||
{/*ds_ptr*/},
|
||||
@@ -56,7 +47,7 @@ struct StreamKHostArgs : public ck_tile::UniversalGemmHostArgs<>
|
||||
}
|
||||
|
||||
ck_tile::StreamKReductionStrategy reduction_strategy;
|
||||
index_t num_sk_blocks;
|
||||
uint32_t num_sk_blocks;
|
||||
};
|
||||
|
||||
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
|
||||
@@ -103,7 +94,7 @@ struct StreamKKernel
|
||||
/// @brief The strategy used by work groups to compute final results in C tensor.
|
||||
StreamKReductionStrategy reduction_strategy;
|
||||
/// @brief The number of stream k blocks.
|
||||
index_t num_sk_blocks;
|
||||
uint32_t num_sk_blocks;
|
||||
/// @brief A pointer to a buffer in device memory for accumulating partial results via the
|
||||
/// strategy.
|
||||
void* workspace_ptr;
|
||||
@@ -152,29 +143,32 @@ struct StreamKKernel
|
||||
|
||||
CK_TILE_HOST static StreamKKernelArgs MakeKernelArgs(const StreamKHostArgs& host_args)
|
||||
{
|
||||
index_t occupancy = static_cast<index_t>(Occupancy());
|
||||
index_t num_cu = static_cast<index_t>(NumCU());
|
||||
uint32_t occupancy = static_cast<uint32_t>(Occupancy());
|
||||
uint32_t num_cu = static_cast<uint32_t>(NumCU());
|
||||
|
||||
return StreamKKernelArgs{
|
||||
{host_args.as_ptr,
|
||||
host_args.bs_ptr,
|
||||
host_args.ds_ptr,
|
||||
host_args.e_ptr,
|
||||
host_args.M,
|
||||
host_args.N,
|
||||
host_args.K,
|
||||
host_args.stride_As,
|
||||
host_args.stride_Bs,
|
||||
host_args.stride_Ds,
|
||||
host_args.stride_E,
|
||||
host_args.k_batch},
|
||||
host_args.reduction_strategy,
|
||||
host_args.num_sk_blocks,
|
||||
// The workspace pointer is set to nullptr because we must first
|
||||
// instantiate the TilePartitioner to get the necessary size
|
||||
/*workspace_ptr =*/nullptr,
|
||||
TilePartitioner{
|
||||
host_args.M, host_args.N, host_args.K, num_cu, occupancy, host_args.num_sk_blocks}};
|
||||
return StreamKKernelArgs{{host_args.as_ptr,
|
||||
host_args.bs_ptr,
|
||||
host_args.ds_ptr,
|
||||
host_args.e_ptr,
|
||||
host_args.M,
|
||||
host_args.N,
|
||||
host_args.K,
|
||||
host_args.stride_As,
|
||||
host_args.stride_Bs,
|
||||
host_args.stride_Ds,
|
||||
host_args.stride_E,
|
||||
host_args.k_batch},
|
||||
host_args.reduction_strategy,
|
||||
host_args.num_sk_blocks,
|
||||
// The workspace pointer is set to nullptr because we must first
|
||||
// instantiate the TilePartitioner to get the necessary size
|
||||
/*workspace_ptr =*/nullptr,
|
||||
TilePartitioner{static_cast<uint32_t>(host_args.M),
|
||||
static_cast<uint32_t>(host_args.N),
|
||||
static_cast<uint32_t>(host_args.K),
|
||||
num_cu,
|
||||
occupancy,
|
||||
host_args.num_sk_blocks}};
|
||||
}
|
||||
|
||||
CK_TILE_HOST static bool
|
||||
|
||||
Reference in New Issue
Block a user