mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[CK_TILE][REGRESSION] Correct blockSize in Generic2dBlockShape (c254f… (#2837)
* [CK_TILE][REGRESSION] Correct blockSize in Generic2dBlockShape (c254f3d7b4 )
WarpPerBlock_M * WarpPerBlock_N are not equal with ThreadPerBlock_M * ThreadPerBlock_N /warpSize. we should calculate BlockSize from WarpPerBlock_M * WarpPerBlock_N
To compatible with wave32, function GetBlockSize is added to calculate correct size in host side.
* fix blocksize for all kernel related with generic2dblockshap
* remove constexpr for blocks
This commit is contained in:
@@ -45,47 +45,57 @@ struct Generic2dBlockShape
|
||||
static constexpr index_t Block_N = BlockTile_::at(number<1>{});
|
||||
static constexpr index_t ThreadPerBlock_M = ThreadPerBlock_::at(number<0>{});
|
||||
static constexpr index_t ThreadPerBlock_N = ThreadPerBlock_::at(number<1>{});
|
||||
static constexpr index_t BlockSize = ThreadPerBlock_M * ThreadPerBlock_N;
|
||||
|
||||
// vector size along seq<M, N>
|
||||
static constexpr index_t Vector_M = Vector_::at(number<0>{});
|
||||
static constexpr index_t Vector_N = Vector_::at(number<1>{});
|
||||
|
||||
static constexpr bool is_warp_per_row = ThreadPerBlock_N <= get_warp_size();
|
||||
static_assert((ThreadPerBlock_M * ThreadPerBlock_N) % get_warp_size() == 0);
|
||||
static constexpr index_t total_warps = (ThreadPerBlock_M * ThreadPerBlock_N) / get_warp_size();
|
||||
|
||||
// num warps along seq<M, N>, within each block
|
||||
static constexpr index_t WarpPerBlock_M = []() {
|
||||
template <bool isHostWave32>
|
||||
static constexpr index_t GetWarpPerBlock_M()
|
||||
{
|
||||
constexpr index_t warp_size = isHostWave32 ? 32 : get_warp_size();
|
||||
constexpr bool is_warp_per_row = ThreadPerBlock_N <= warp_size;
|
||||
static_assert((ThreadPerBlock_M * ThreadPerBlock_N) % warp_size == 0);
|
||||
constexpr index_t total_warps = (ThreadPerBlock_M * ThreadPerBlock_N) / warp_size;
|
||||
|
||||
if constexpr(is_warp_per_row)
|
||||
{
|
||||
static_assert(get_warp_size() % ThreadPerBlock_N == 0);
|
||||
return total_warps * (get_warp_size() / ThreadPerBlock_N);
|
||||
static_assert(warp_size % ThreadPerBlock_N == 0);
|
||||
return total_warps * (warp_size / ThreadPerBlock_N);
|
||||
}
|
||||
else
|
||||
{
|
||||
// static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0);
|
||||
return total_warps / (ThreadPerBlock_N / get_warp_size());
|
||||
return total_warps / (ThreadPerBlock_N / warp_size);
|
||||
}
|
||||
}();
|
||||
};
|
||||
|
||||
// num of warps along n
|
||||
static constexpr index_t WarpPerBlock_N = []() {
|
||||
template <bool isHostWave32>
|
||||
static constexpr index_t GetWarpPerBlock_N()
|
||||
{
|
||||
constexpr index_t warp_size = isHostWave32 ? 32 : get_warp_size();
|
||||
constexpr bool is_warp_per_row = ThreadPerBlock_N <= warp_size;
|
||||
if constexpr(is_warp_per_row)
|
||||
{
|
||||
static_assert(get_warp_size() % ThreadPerBlock_N == 0);
|
||||
static_assert(warp_size % ThreadPerBlock_N == 0);
|
||||
return 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(ThreadPerBlock_N % get_warp_size() == 0);
|
||||
return ThreadPerBlock_N / get_warp_size();
|
||||
static_assert(ThreadPerBlock_N % warp_size == 0);
|
||||
return ThreadPerBlock_N / warp_size;
|
||||
}
|
||||
}();
|
||||
}
|
||||
|
||||
static constexpr index_t WarpPerBlock_M = GetWarpPerBlock_M<false>();
|
||||
static constexpr index_t WarpPerBlock_N = GetWarpPerBlock_N<false>();
|
||||
|
||||
// warp size
|
||||
static constexpr index_t Warp_M = ThreadPerBlock_M / WarpPerBlock_M * Vector_M;
|
||||
static constexpr index_t Warp_N = ThreadPerBlock_N / WarpPerBlock_N * Vector_N;
|
||||
static constexpr index_t BlockSize = WarpPerBlock_M * WarpPerBlock_N * get_warp_size();
|
||||
static constexpr index_t Warp_M = ThreadPerBlock_M / WarpPerBlock_M * Vector_M;
|
||||
static constexpr index_t Warp_N = ThreadPerBlock_N / WarpPerBlock_N * Vector_N;
|
||||
static_assert(Warp_M % Vector_M == 0);
|
||||
static_assert(Warp_N % Vector_N == 0);
|
||||
static_assert(Block_M % (WarpPerBlock_M * Warp_M) == 0);
|
||||
@@ -98,6 +108,13 @@ struct Generic2dBlockShape
|
||||
// num of threads along seq<M, N>, within each warp
|
||||
static constexpr index_t ThreadPerWarp_M = Warp_M / Vector_M;
|
||||
static constexpr index_t ThreadPerWarp_N = Warp_N / Vector_N;
|
||||
|
||||
template <bool isHostWave32>
|
||||
static constexpr index_t GetBlockSize()
|
||||
{
|
||||
constexpr index_t warp_size = isHostWave32 ? 32 : get_warp_size();
|
||||
return GetWarpPerBlock_M<isHostWave32>() * GetWarpPerBlock_N<isHostWave32>() * warp_size;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user