mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 14:11:29 +00:00
Universal streamk with atomics (#1360)
* Universal stream-k with atomics, with ckprofiler support. grid_size and the stream-k strategy are tunable; a grid_size of -1 leads to #WGs = maximum occupancy x num_CUs. The implementation supports several stream-k policies: 1-tile, 2-tile, 3-tile, and 4-tile; a stream-k strategy of -1 selects the default policy (4-tile). * Updated README.md * Fixed clang-format issues * Removed conflicts in struct members between stream-k and universal stream-k * Corrected argument parsing for stream-k and universal stream-k * Added stream-k policies for 3-tile and 4-tile * Fixed an argument-type issue when parsing command-line arguments * Applied the changes suggested in PR review — removed stale comments and corrected the copyright header * Updated file permissions * Added default-value support: grid_size and stream-k policy selection default to -1 * Added print messages for the arguments
This commit is contained in:
committed by
GitHub
parent
eaa870a1ab
commit
75e622f02f
@@ -1404,4 +1404,326 @@ struct BlockToCTileMap_GemmStreamK
|
||||
}
|
||||
};
|
||||
|
||||
// Block-to-C-tile mapping for "universal" stream-k GEMM (v2).
//
// Partitions the flattened K-iteration space of a GEMM across a fixed grid:
// the first sk_num_blocks workgroups are stream-k (sk) blocks, each owning a
// contiguous slice of K iterations that may cross C-tile boundaries (partial
// results are then combined via atomics or a separate reduction pass,
// depending on ReductionStrategy); the remaining workgroups are plain
// data-parallel (DP) blocks that each own one whole C tile. The constructor
// selects between a pure-DP layout and 1/2/3/4-tile stream-k policies.
//
// NOTE(review): template parameters GroupNum and M01_ are not referenced
// anywhere in this struct — presumably kept for interface parity with the
// other block-to-C-tile maps; confirm before removing.
template <uint32_t MPerBlock_,
          uint32_t NPerBlock_,
          uint32_t KPerBlock_,
          StreamKReductionStrategy ReductionStrategy_ = StreamKReductionStrategy::Atomic,
          uint32_t TileSwizzleSubM_ = 8,
          index_t GroupNum = 8,
          index_t M01_ = 4>
struct BlockToCTileMap_GemmStreamK_v2
{
    static constexpr uint32_t min_k_iters_per_sk_block = 2;
    static constexpr uint32_t MPerBlock                = MPerBlock_;
    static constexpr uint32_t NPerBlock                = NPerBlock_;
    static constexpr uint32_t KPerBlock                = KPerBlock_;
    static constexpr StreamKReductionStrategy ReductionStrategy = ReductionStrategy_;
    static constexpr uint32_t tile_swizzle_sub_m                = TileSwizzleSubM_;

    //--------------------------------------
    // pass to device
    mutable uint32_t sk_num_blocks;     // total stream-k blocks (big + little)
    uint32_t sk_num_big_blocks;         // sk blocks that run one extra K iteration
    uint32_t dp_start_block_idx;        // first DP block index (== sk_num_blocks)
    uint32_t reduction_start_block_idx; // first reduction block (Reduction strategy only)
    uint32_t k_iters_per_big_block;     // iters per big sk block; little blocks run one less
    MDiv2 n_tiles;                      // fast-division helper over the N-tile count
    MDiv k_iters_per_tile;              // K iterations needed to complete one C tile
    MDiv equiv_tiles_big;               // for reduction
    MDiv equiv_tiles_little;            // for reduction

    // prefer construct on host
    //
    // m, n, k     : GEMM problem dimensions
    // grid_size   : number of workgroups available for the launch
    // streamk_sel : 0 = pure DP GEMM; 1..4 = 1/2/3/4-tile stream-k policy
    __host__ __device__ BlockToCTileMap_GemmStreamK_v2(
        uint32_t m, uint32_t n, uint32_t k, uint32_t grid_size = 1, uint32_t streamk_sel = 1)
    {
        // total output tiles
        uint32_t num_tiles =
            math::integer_divide_ceil(m, MPerBlock) * math::integer_divide_ceil(n, NPerBlock);
        k_iters_per_tile = MDiv(math::integer_divide_ceil(k, KPerBlock));

        uint32_t dp_tiles, dp_num_blocks, sk_total_iters;

        // check if there's enough work for DP + stream-k
        bool bigEnough = num_tiles > grid_size;

        // How many C tiles go to stream-k blocks (0 means pure DP GEMM).
        uint32_t sk_tiles = 0;
        if(streamk_sel != 0)
        {
            if(streamk_sel == 1) // 1-tile stream-k: remainder tiles only
            {
                sk_tiles = bigEnough ? (num_tiles % grid_size) : num_tiles;
            }
            else if(streamk_sel == 2) // 2-tile stream-k: one full wave + remainder
            {
                sk_tiles = bigEnough ? (grid_size + num_tiles % grid_size) : num_tiles;
            }
            else if(streamk_sel == 3) // 3-tile stream-k: two full waves + remainder
            {
                sk_tiles = (num_tiles > (2 * grid_size)) ? (2 * grid_size + num_tiles % grid_size)
                                                         : num_tiles;
            }
            else if(streamk_sel == 4) // 4-tile stream-k: three full waves + remainder
            {
                sk_tiles = (num_tiles > (3 * grid_size)) ? (3 * grid_size + num_tiles % grid_size)
                                                         : num_tiles;
            }
        }

        if(sk_tiles == 0)
        {
            // Regular DP GEMM: every tile is a DP block, no stream-k work.
            // BUGFIX: besides streamk_sel == 0, this branch now also covers
            //   - the 1-tile policy when num_tiles is an exact multiple of
            //     grid_size (remainder 0), and
            //   - an unrecognized streamk_sel (> 4),
            // both of which previously left sk_num_blocks == 0 and hit a 0/0
            // integer division below (undefined behavior).
            sk_num_blocks         = 0;
            sk_num_big_blocks     = 0;
            k_iters_per_big_block = 0;

            dp_tiles           = num_tiles;
            dp_num_blocks      = num_tiles; // all tiles become DP blocks
            dp_start_block_idx = 0;
            sk_total_iters     = 0;
        }
        else
        {
            sk_num_blocks = sk_tiles;
            // remaining tiles are DP tiles (none when the problem is too small)
            dp_tiles = bigEnough ? (num_tiles - sk_tiles) : 0;

            sk_total_iters = k_iters_per_tile.get() * sk_tiles;

            // Distribute sk_total_iters over the sk blocks.
            // Let m = floor(sk_total_iters / sk_blocks); "little" blocks cover
            // m iters and "big" blocks cover m + 1. From
            //   1) l + b = sk_blocks
            //   2) l * m + b * (m + 1) = sk_total_iters
            // => sk_blocks * m + b = sk_total_iters
            // => b = sk_total_iters - m * sk_blocks
            // NOTE: big could be zero
            uint32_t k_iters_per_sk_block = sk_total_iters / sk_num_blocks;
            sk_num_big_blocks     = sk_total_iters - k_iters_per_sk_block * sk_num_blocks;
            k_iters_per_big_block = k_iters_per_sk_block + 1;

            dp_num_blocks      = dp_tiles;
            dp_start_block_idx = sk_num_blocks;
        }

        n_tiles = MDiv2(math::integer_divide_ceil(n, NPerBlock));
        // reduction blocks (if any) are appended after the DP blocks
        reduction_start_block_idx = dp_start_block_idx + dp_num_blocks;

        if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction)
        {
            if(sk_num_blocks != 0)
            {
                uint32_t upper_big = math::lcm(k_iters_per_big_block, k_iters_per_tile.get());
                uint32_t upper_little =
                    math::lcm(k_iters_per_big_block - 1, k_iters_per_tile.get());
                equiv_tiles_big    = MDiv(upper_big / k_iters_per_tile.get());
                equiv_tiles_little = MDiv(upper_little / k_iters_per_tile.get());
            }
            else
            {
                // BUGFIX: with no sk blocks, k_iters_per_big_block is 0, so the
                // lcm above would see 0 and the unsigned (0 - 1) underflows.
                // There is nothing to reduce; keep the dividers at a benign 1.
                equiv_tiles_big    = MDiv(1);
                equiv_tiles_little = MDiv(1);
            }
        }
    }

    // Upper bound on the grid: one block per C tile.
    __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
    {
        const auto M0 = math::integer_divide_ceil(M, MPerBlock);
        const auto N0 = math::integer_divide_ceil(N, NPerBlock);

        return M0 * N0;
    }

    // Total K iterations covered by all stream-k blocks (big + little).
    __host__ __device__ uint32_t get_sk_total_iters() const
    {
        uint32_t sk_total_iters = sk_num_big_blocks * k_iters_per_big_block +
                                  (sk_num_blocks - sk_num_big_blocks) * (k_iters_per_big_block - 1);
        return sk_total_iters;
    }

    // Number of C tiles covered by stream-k blocks.
    __host__ __device__ uint32_t get_sk_tiles() const
    {
        // tiles for sk
        uint32_t sk_total_iters = get_sk_total_iters();
        return k_iters_per_tile.div(sk_total_iters);
    }

    // 1D grid size: sk + DP blocks, plus one reduction block per sk tile
    // when the Reduction strategy is selected.
    __host__ __device__ index_t get_grid_dims() const
    {
        if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction)
        {
            return reduction_start_block_idx + get_sk_tiles();
        }
        else
            return reduction_start_block_idx;
    }

    __device__ uint32_t get_block_idx() const
    {
        // TODO: swizzle block index for better locality
        return __builtin_amdgcn_readfirstlane(blockIdx.x);
    }

    // Map a block index onto its [iter_start, iter_end) slice of the flattened
    // K-iteration space: big sk blocks first, then little sk blocks, then DP
    // blocks (one full tile each). Reduction blocks receive no slice (the
    // output parameters are left untouched for them).
    __device__ void
    get_block_itr(uint32_t block_idx, uint32_t& iter_start, uint32_t& iter_end) const
    {
        if(block_idx < sk_num_big_blocks)
        {
            iter_start = block_idx * k_iters_per_big_block;
            iter_end   = iter_start + k_iters_per_big_block;
        }
        else if(block_idx < sk_num_blocks)
        {
            iter_start = (sk_num_big_blocks * k_iters_per_big_block) +
                         (block_idx - sk_num_big_blocks) * (k_iters_per_big_block - 1);
            iter_end = iter_start + (k_iters_per_big_block - 1);
        }
        else if(block_idx >= dp_start_block_idx)
        {
            uint32_t sk_total_iters     = get_sk_total_iters();
            uint32_t dp_iters_per_block = k_iters_per_tile.get();
            iter_start = sk_total_iters + (block_idx - dp_start_block_idx) * dp_iters_per_block;
            iter_end   = iter_start + dp_iters_per_block;
        }
    }

    // Iterations to run for the current tile: stop at the tile boundary or at
    // the block's remaining work (total_iter_length), whichever comes first.
    __device__ uint32_t get_current_iter_length(uint32_t iter_start,
                                                uint32_t iter_end,
                                                uint32_t total_iter_length) const
    {
        uint32_t iter_length_mod, iter_length_quo /*unused*/;
        k_iters_per_tile.divmod(iter_end, iter_length_quo, iter_length_mod);
        uint32_t current_iter_length = math::min(
            iter_length_mod == 0 ? (iter_end - iter_start) : iter_length_mod, total_iter_length);
        return current_iter_length;
    }

    // Tile index a given flattened K iteration belongs to.
    __device__ uint32_t get_tile_idx(uint32_t iter) const { return k_iters_per_tile.div(iter); }

    // Tile index plus the iteration offset inside that tile.
    __device__ void
    get_tile_idx_with_offset(uint32_t iter, uint32_t& tile_idx, uint32_t& iter_offset) const
    {
        k_iters_per_tile.divmod(iter, tile_idx, iter_offset);
    }

    // Convert a linear tile index into a swizzled (m_tile, n_tile) coordinate.
    // M-tiles are regrouped in chunks of tile_swizzle_sub_m; the tail chunk
    // (m_tiles % tile_swizzle_sub_m rows) uses the smaller remainder factor.
    __device__ auto tile_to_spatial(uint32_t tile_idx, uint32_t m, uint32_t n) const
    {
        uint32_t m_tile_idx, n_tile_idx;
        uint32_t n_tiles_value = math::integer_divide_ceil(n, NPerBlock);
        n_tiles.divmod(tile_idx, n_tiles_value, m_tile_idx, n_tile_idx);

        // swizzle tile
        uint32_t m_tiles = math::integer_divide_ceil(m, MPerBlock);

        uint32_t tile_swizzle_sub_m_rem = m_tiles % tile_swizzle_sub_m;

        const auto sub_m_adapt = (m_tile_idx < (m_tiles - tile_swizzle_sub_m_rem))
                                     ? tile_swizzle_sub_m
                                     : tile_swizzle_sub_m_rem;

        uint32_t m_tile_idx_sub0, m_tile_idx_sub1;
        m_tile_idx_sub0 = m_tile_idx / tile_swizzle_sub_m;
        m_tile_idx_sub1 = m_tile_idx % tile_swizzle_sub_m;

        uint32_t tile_idx_local = n_tile_idx + m_tile_idx_sub1 * n_tiles_value;

        uint32_t m_tile_idx_with_adapt, n_tile_idx_with_adapt;

        n_tile_idx_with_adapt = tile_idx_local / sub_m_adapt;
        m_tile_idx_with_adapt = tile_idx_local % sub_m_adapt;
        return make_tuple(m_tile_idx_with_adapt + m_tile_idx_sub0 * tile_swizzle_sub_m,
                          n_tile_idx_with_adapt);
    }

    // Workspace bytes for partial-accumulator buffers, rounded up to a
    // 128-byte boundary.
    __host__ __device__ uint32_t get_workspace_size_for_acc(uint32_t acc_element_bytes) const
    {
        static constexpr uint32_t alignment = 128;
        uint32_t acc_buffer_bytes =
            MPerBlock * NPerBlock * get_total_acc_buffers() * acc_element_bytes;
        return (acc_buffer_bytes + alignment - 1) / alignment * alignment;
    }

    // One 32-bit semaphore per stream-k tile.
    __host__ __device__ uint32_t get_workspace_size_for_semaphore() const
    {
        return get_sk_tiles() * sizeof(uint32_t);
    }

    // Total workspace: accumulator buffers followed by semaphores.
    __host__ __device__ uint32_t get_workspace_size(uint32_t acc_element_bytes) const
    {
        return get_workspace_size_for_acc(acc_element_bytes) + get_workspace_size_for_semaphore();
    }

    // Count of sk-block/tile boundary intersections within the first `tiles_`
    // tiles, exploiting the periodicity captured by `equiv_tiles_`.
    __host__ __device__ uint32_t get_tile_intersections(uint32_t tiles_,
                                                        const MDiv& equiv_tiles_) const
    {
        uint32_t tile_idx_        = tiles_ == 0 ? 0 : (tiles_ - 1);
        uint32_t max_equiv_tiles_ = equiv_tiles_.get() - 1;
        uint32_t quo_, rem_;
        equiv_tiles_.divmod(tile_idx_, quo_, rem_);
        return quo_ * max_equiv_tiles_ + rem_;
    }

    // Tiles (fully or partially) covered by `num_sk_blocks_` blocks of
    // `iters_per_sk_block_` iterations each — a ceiling division by tile size.
    __host__ __device__ uint32_t get_tiles_cover_sk_block(uint32_t num_sk_blocks_,
                                                          uint32_t iters_per_sk_block_) const
    {
        return k_iters_per_tile.div(num_sk_blocks_ * iters_per_sk_block_ + k_iters_per_tile.get() -
                                    1);
    }

    // Total partial-accumulator buffers: one per sk block plus one per tile
    // intersection for both the big and the little block populations.
    __host__ __device__ uint32_t get_total_acc_buffers() const
    {
        uint32_t tiles_cover_big_blocks =
            get_tiles_cover_sk_block(sk_num_big_blocks, k_iters_per_big_block);
        uint32_t tiles_cover_little_blocks =
            get_tiles_cover_sk_block(sk_num_blocks - sk_num_big_blocks, k_iters_per_big_block - 1);

        uint32_t total_intersec_big =
            get_tile_intersections(tiles_cover_big_blocks, equiv_tiles_big);
        uint32_t total_intersec_little =
            get_tile_intersections(tiles_cover_little_blocks, equiv_tiles_little);

        return sk_num_blocks + total_intersec_big + total_intersec_little;
    }

    // Accumulator-buffer offset for a given sk tile. Tiles touched by big
    // blocks are indexed from the front; tiles touched by little blocks are
    // indexed from the back (mirrored).
    __device__ uint32_t get_acc_buffer_offset_from_tile(uint32_t tile_idx_) const
    {
        // TODO: from big to little
        uint32_t tiles_cover_big_blocks =
            get_tiles_cover_sk_block(sk_num_big_blocks, k_iters_per_big_block);
        if(tile_idx_ < tiles_cover_big_blocks)
        {
            uint32_t touched_sk_blocks =
                (tile_idx_ * k_iters_per_tile.get() + k_iters_per_big_block - 1) /
                k_iters_per_big_block;
            uint32_t current_intersec = get_tile_intersections(tile_idx_, equiv_tiles_big);
            return touched_sk_blocks + current_intersec;
        }
        else
        {
            uint32_t iters_per_little_sk_block = k_iters_per_big_block - 1;
            uint32_t tile_idx_little_reverse   = get_sk_tiles() - tile_idx_;
            uint32_t touched_sk_blocks =
                (tile_idx_little_reverse * k_iters_per_tile.get() + iters_per_little_sk_block - 1) /
                iters_per_little_sk_block;
            uint32_t current_intersec =
                get_tile_intersections(tile_idx_little_reverse, equiv_tiles_little);
            return get_total_acc_buffers() - (touched_sk_blocks + current_intersec);
        }
    }

    // Accumulator-buffer offset for a given sk block: big blocks count tiles
    // from the front, little blocks mirror from the back — same scheme as
    // get_acc_buffer_offset_from_tile.
    __device__ uint32_t get_acc_buffer_offset_from_block(uint32_t block_idx_) const
    {
        uint32_t iters_per_big_sk_block    = k_iters_per_big_block;
        uint32_t iters_per_little_sk_block = k_iters_per_big_block - 1;
        if(block_idx_ < sk_num_big_blocks)
        {
            uint32_t touched_tiles = k_iters_per_tile.div(block_idx_ * iters_per_big_sk_block +
                                                          k_iters_per_tile.get() - 1);
            uint32_t current_intersec = get_tile_intersections(touched_tiles, equiv_tiles_big);
            return block_idx_ + current_intersec;
        }
        else
        {
            uint32_t block_idx_little_reverse = sk_num_blocks - block_idx_;
            uint32_t touched_tiles            = k_iters_per_tile.div(
                block_idx_little_reverse * iters_per_little_sk_block + k_iters_per_tile.get() - 1);
            uint32_t current_intersec = get_tile_intersections(touched_tiles, equiv_tiles_little);
            return get_total_acc_buffers() - (block_idx_little_reverse + current_intersec);
        }
    }
};
|
||||
|
||||
} // namespace ck
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user