mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
[CK_TILE] fix example reduces, permute and elementwise on gfx11 & gfx12 (#2810)
1. Refine Reduce2dShape to support both wave32 and wave64 2. Fix example reduce, permute and elementwise on gfx11 and gfx12 --------- Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
This commit is contained in:
@@ -26,6 +26,10 @@ struct Reduce
|
||||
using YDataType = ck_tile::remove_cvref_t<typename Problem::YDataType>;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize;
|
||||
CK_TILE_HOST static constexpr auto BlockSize()
|
||||
{
|
||||
return is_wave32() ? kBlockSize / 2 : kBlockSize;
|
||||
}
|
||||
|
||||
private:
|
||||
// Helper function to calculate optimal vector size for input tensor
|
||||
|
||||
@@ -25,11 +25,18 @@ struct Reduce2dShape
|
||||
static constexpr index_t WarpPerBlock_M = BlockWarps::at(number<0>{});
|
||||
static constexpr index_t WarpPerBlock_N = BlockWarps::at(number<1>{});
|
||||
|
||||
static constexpr index_t ThreadPerWarp_M = Warp_M / ThreadTile_M;
|
||||
static constexpr index_t ThreadPerWarp_N = Warp_N / ThreadTile_N;
|
||||
static constexpr index_t RepeatInWarp =
|
||||
Warp_M * Warp_N / ThreadTile_M / ThreadTile_N / ck_tile::get_warp_size();
|
||||
static constexpr index_t RepeatInWarp_M =
|
||||
(Warp_M / ThreadTile_M > Warp_N / ThreadTile_N) ? RepeatInWarp : 1;
|
||||
static constexpr index_t RepeatInWarp_N =
|
||||
(Warp_M / ThreadTile_M > Warp_N / ThreadTile_N) ? 1 : RepeatInWarp;
|
||||
|
||||
static constexpr index_t Repeat_M = Block_M / (WarpPerBlock_M * Warp_M);
|
||||
static constexpr index_t Repeat_N = Block_N / (WarpPerBlock_N * Warp_N);
|
||||
static constexpr index_t ThreadPerWarp_M = Warp_M / ThreadTile_M / RepeatInWarp_M;
|
||||
static constexpr index_t ThreadPerWarp_N = Warp_N / ThreadTile_N / RepeatInWarp_N;
|
||||
|
||||
static constexpr index_t Repeat_M = Block_M * RepeatInWarp_M / (WarpPerBlock_M * Warp_M);
|
||||
static constexpr index_t Repeat_N = Block_N * RepeatInWarp_N / (WarpPerBlock_N * Warp_N);
|
||||
|
||||
static constexpr index_t BlockSize =
|
||||
ck_tile::get_warp_size() * reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{});
|
||||
|
||||
Reference in New Issue
Block a user