[CK_TILE][REGRESSION] Correct blockSize in Generic2dBlockShape (c254f… (#2837)

* [CK_TILE][REGRESSION] Correct blockSize in Generic2dBlockShape (c254f3d7b4 )

WarpPerBlock_M * WarpPerBlock_N are not equal with ThreadPerBlock_M * ThreadPerBlock_N /warpSize. we should calculate BlockSize from WarpPerBlock_M * WarpPerBlock_N

To compatible with wave32, function GetBlockSize is added to calculate correct size in host side.

* fix blocksize for all kernel related with generic2dblockshap

* remove constexpr for blocks
This commit is contained in:
linqunAMD
2025-09-16 23:47:55 +08:00
committed by GitHub
parent 671adb59c5
commit b7a806f244
10 changed files with 63 additions and 26 deletions

View File

@@ -45,47 +45,57 @@ struct Generic2dBlockShape
static constexpr index_t Block_N = BlockTile_::at(number<1>{});
static constexpr index_t ThreadPerBlock_M = ThreadPerBlock_::at(number<0>{});
static constexpr index_t ThreadPerBlock_N = ThreadPerBlock_::at(number<1>{});
static constexpr index_t BlockSize = ThreadPerBlock_M * ThreadPerBlock_N;
// vector size along seq<M, N>
static constexpr index_t Vector_M = Vector_::at(number<0>{});
static constexpr index_t Vector_N = Vector_::at(number<1>{});
static constexpr bool is_warp_per_row = ThreadPerBlock_N <= get_warp_size();
static_assert((ThreadPerBlock_M * ThreadPerBlock_N) % get_warp_size() == 0);
static constexpr index_t total_warps = (ThreadPerBlock_M * ThreadPerBlock_N) / get_warp_size();
// num warps along seq<M, N>, within each block
static constexpr index_t WarpPerBlock_M = []() {
template <bool isHostWave32>
static constexpr index_t GetWarpPerBlock_M()
{
constexpr index_t warp_size = isHostWave32 ? 32 : get_warp_size();
constexpr bool is_warp_per_row = ThreadPerBlock_N <= warp_size;
static_assert((ThreadPerBlock_M * ThreadPerBlock_N) % warp_size == 0);
constexpr index_t total_warps = (ThreadPerBlock_M * ThreadPerBlock_N) / warp_size;
if constexpr(is_warp_per_row)
{
static_assert(get_warp_size() % ThreadPerBlock_N == 0);
return total_warps * (get_warp_size() / ThreadPerBlock_N);
static_assert(warp_size % ThreadPerBlock_N == 0);
return total_warps * (warp_size / ThreadPerBlock_N);
}
else
{
// static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0);
return total_warps / (ThreadPerBlock_N / get_warp_size());
return total_warps / (ThreadPerBlock_N / warp_size);
}
}();
};
// num of warps along n
static constexpr index_t WarpPerBlock_N = []() {
template <bool isHostWave32>
static constexpr index_t GetWarpPerBlock_N()
{
constexpr index_t warp_size = isHostWave32 ? 32 : get_warp_size();
constexpr bool is_warp_per_row = ThreadPerBlock_N <= warp_size;
if constexpr(is_warp_per_row)
{
static_assert(get_warp_size() % ThreadPerBlock_N == 0);
static_assert(warp_size % ThreadPerBlock_N == 0);
return 1;
}
else
{
static_assert(ThreadPerBlock_N % get_warp_size() == 0);
return ThreadPerBlock_N / get_warp_size();
static_assert(ThreadPerBlock_N % warp_size == 0);
return ThreadPerBlock_N / warp_size;
}
}();
}
static constexpr index_t WarpPerBlock_M = GetWarpPerBlock_M<false>();
static constexpr index_t WarpPerBlock_N = GetWarpPerBlock_N<false>();
// warp size
static constexpr index_t Warp_M = ThreadPerBlock_M / WarpPerBlock_M * Vector_M;
static constexpr index_t Warp_N = ThreadPerBlock_N / WarpPerBlock_N * Vector_N;
static constexpr index_t BlockSize = WarpPerBlock_M * WarpPerBlock_N * get_warp_size();
static constexpr index_t Warp_M = ThreadPerBlock_M / WarpPerBlock_M * Vector_M;
static constexpr index_t Warp_N = ThreadPerBlock_N / WarpPerBlock_N * Vector_N;
static_assert(Warp_M % Vector_M == 0);
static_assert(Warp_N % Vector_N == 0);
static_assert(Block_M % (WarpPerBlock_M * Warp_M) == 0);
@@ -98,6 +108,13 @@ struct Generic2dBlockShape
// num of threads along seq<M, N>, within each warp
static constexpr index_t ThreadPerWarp_M = Warp_M / Vector_M;
static constexpr index_t ThreadPerWarp_N = Warp_N / Vector_N;
template <bool isHostWave32>
static constexpr index_t GetBlockSize()
{
constexpr index_t warp_size = isHostWave32 ? 32 : get_warp_size();
return GetWarpPerBlock_M<isHostWave32>() * GetWarpPerBlock_N<isHostWave32>() * warp_size;
}
};
} // namespace ck_tile