mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 05:01:25 +00:00
[CK_TILE][REGRESSION] Correct blockSize in Generic2dBlockShape (c254f… (#2837)
* [CK_TILE][REGRESSION] Correct blockSize in Generic2dBlockShape (c254f3d7b4 )
WarpPerBlock_M * WarpPerBlock_N are not equal with ThreadPerBlock_M * ThreadPerBlock_N /warpSize. we should calculate BlockSize from WarpPerBlock_M * WarpPerBlock_N
To compatible with wave32, function GetBlockSize is added to calculate correct size in host side.
* fix blocksize for all kernel related with generic2dblockshap
* remove constexpr for blocks
This commit is contained in:
@@ -95,7 +95,11 @@ struct AddRmsnorm2dRdquantFwd
|
||||
return dim3(integer_divide_ceil(hargs.m, Block_M));
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize() { return Problem::BlockShape::BlockSize; }
|
||||
CK_TILE_HOST static constexpr auto BlockSize()
|
||||
{
|
||||
return is_wave32() ? Problem::BlockShape::template GetBlockSize<true>()
|
||||
: Problem::BlockShape::template GetBlockSize<false>();
|
||||
}
|
||||
|
||||
// clang-format off
|
||||
template <typename T> struct t2s;
|
||||
|
||||
@@ -45,47 +45,57 @@ struct Generic2dBlockShape
|
||||
static constexpr index_t Block_N = BlockTile_::at(number<1>{});
|
||||
static constexpr index_t ThreadPerBlock_M = ThreadPerBlock_::at(number<0>{});
|
||||
static constexpr index_t ThreadPerBlock_N = ThreadPerBlock_::at(number<1>{});
|
||||
static constexpr index_t BlockSize = ThreadPerBlock_M * ThreadPerBlock_N;
|
||||
|
||||
// vector size along seq<M, N>
|
||||
static constexpr index_t Vector_M = Vector_::at(number<0>{});
|
||||
static constexpr index_t Vector_N = Vector_::at(number<1>{});
|
||||
|
||||
static constexpr bool is_warp_per_row = ThreadPerBlock_N <= get_warp_size();
|
||||
static_assert((ThreadPerBlock_M * ThreadPerBlock_N) % get_warp_size() == 0);
|
||||
static constexpr index_t total_warps = (ThreadPerBlock_M * ThreadPerBlock_N) / get_warp_size();
|
||||
|
||||
// num warps along seq<M, N>, within each block
|
||||
static constexpr index_t WarpPerBlock_M = []() {
|
||||
template <bool isHostWave32>
|
||||
static constexpr index_t GetWarpPerBlock_M()
|
||||
{
|
||||
constexpr index_t warp_size = isHostWave32 ? 32 : get_warp_size();
|
||||
constexpr bool is_warp_per_row = ThreadPerBlock_N <= warp_size;
|
||||
static_assert((ThreadPerBlock_M * ThreadPerBlock_N) % warp_size == 0);
|
||||
constexpr index_t total_warps = (ThreadPerBlock_M * ThreadPerBlock_N) / warp_size;
|
||||
|
||||
if constexpr(is_warp_per_row)
|
||||
{
|
||||
static_assert(get_warp_size() % ThreadPerBlock_N == 0);
|
||||
return total_warps * (get_warp_size() / ThreadPerBlock_N);
|
||||
static_assert(warp_size % ThreadPerBlock_N == 0);
|
||||
return total_warps * (warp_size / ThreadPerBlock_N);
|
||||
}
|
||||
else
|
||||
{
|
||||
// static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0);
|
||||
return total_warps / (ThreadPerBlock_N / get_warp_size());
|
||||
return total_warps / (ThreadPerBlock_N / warp_size);
|
||||
}
|
||||
}();
|
||||
};
|
||||
|
||||
// num of warps along n
|
||||
static constexpr index_t WarpPerBlock_N = []() {
|
||||
template <bool isHostWave32>
|
||||
static constexpr index_t GetWarpPerBlock_N()
|
||||
{
|
||||
constexpr index_t warp_size = isHostWave32 ? 32 : get_warp_size();
|
||||
constexpr bool is_warp_per_row = ThreadPerBlock_N <= warp_size;
|
||||
if constexpr(is_warp_per_row)
|
||||
{
|
||||
static_assert(get_warp_size() % ThreadPerBlock_N == 0);
|
||||
static_assert(warp_size % ThreadPerBlock_N == 0);
|
||||
return 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(ThreadPerBlock_N % get_warp_size() == 0);
|
||||
return ThreadPerBlock_N / get_warp_size();
|
||||
static_assert(ThreadPerBlock_N % warp_size == 0);
|
||||
return ThreadPerBlock_N / warp_size;
|
||||
}
|
||||
}();
|
||||
}
|
||||
|
||||
static constexpr index_t WarpPerBlock_M = GetWarpPerBlock_M<false>();
|
||||
static constexpr index_t WarpPerBlock_N = GetWarpPerBlock_N<false>();
|
||||
|
||||
// warp size
|
||||
static constexpr index_t Warp_M = ThreadPerBlock_M / WarpPerBlock_M * Vector_M;
|
||||
static constexpr index_t Warp_N = ThreadPerBlock_N / WarpPerBlock_N * Vector_N;
|
||||
static constexpr index_t BlockSize = WarpPerBlock_M * WarpPerBlock_N * get_warp_size();
|
||||
static constexpr index_t Warp_M = ThreadPerBlock_M / WarpPerBlock_M * Vector_M;
|
||||
static constexpr index_t Warp_N = ThreadPerBlock_N / WarpPerBlock_N * Vector_N;
|
||||
static_assert(Warp_M % Vector_M == 0);
|
||||
static_assert(Warp_N % Vector_N == 0);
|
||||
static_assert(Block_M % (WarpPerBlock_M * Warp_M) == 0);
|
||||
@@ -98,6 +108,13 @@ struct Generic2dBlockShape
|
||||
// num of threads along seq<M, N>, within each warp
|
||||
static constexpr index_t ThreadPerWarp_M = Warp_M / Vector_M;
|
||||
static constexpr index_t ThreadPerWarp_N = Warp_N / Vector_N;
|
||||
|
||||
template <bool isHostWave32>
|
||||
static constexpr index_t GetBlockSize()
|
||||
{
|
||||
constexpr index_t warp_size = isHostWave32 ? 32 : get_warp_size();
|
||||
return GetWarpPerBlock_M<isHostWave32>() * GetWarpPerBlock_N<isHostWave32>() * warp_size;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -134,7 +134,11 @@ struct Layernorm2dFwd
|
||||
return dim3(integer_divide_ceil(hargs.m, Block_M));
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize() { return Problem::BlockShape::BlockSize; }
|
||||
CK_TILE_HOST static constexpr auto BlockSize()
|
||||
{
|
||||
return is_wave32() ? Problem::BlockShape::template GetBlockSize<true>()
|
||||
: Problem::BlockShape::template GetBlockSize<false>();
|
||||
}
|
||||
|
||||
// clang-format off
|
||||
template <typename T> struct t2s;
|
||||
|
||||
@@ -124,7 +124,11 @@ struct Rmsnorm2dFwd
|
||||
return dim3(integer_divide_ceil(hargs.m, Block_M));
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize() { return Problem::BlockShape::BlockSize; }
|
||||
CK_TILE_HOST static constexpr auto BlockSize()
|
||||
{
|
||||
return is_wave32() ? Problem::BlockShape::template GetBlockSize<true>()
|
||||
: Problem::BlockShape::template GetBlockSize<false>();
|
||||
}
|
||||
|
||||
// clang-format off
|
||||
template <typename T> struct t2s;
|
||||
|
||||
@@ -93,7 +93,11 @@ struct MoeSmoothquant
|
||||
return dim3(hargs.topk, integer_divide_ceil(hargs.tokens, Block_M), 1);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize() { return Problem::BlockShape::BlockSize; }
|
||||
CK_TILE_HOST static constexpr auto BlockSize()
|
||||
{
|
||||
return is_wave32() ? Problem::BlockShape::template GetBlockSize<true>()
|
||||
: Problem::BlockShape::template GetBlockSize<false>();
|
||||
}
|
||||
|
||||
// clang-format off
|
||||
template <typename T> struct t2s;
|
||||
|
||||
@@ -82,7 +82,11 @@ struct Smoothquant
|
||||
return dim3(integer_divide_ceil(hargs.m, Block_M));
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize() { return Problem::BlockShape::BlockSize; }
|
||||
CK_TILE_HOST static constexpr auto BlockSize()
|
||||
{
|
||||
return is_wave32() ? Problem::BlockShape::template GetBlockSize<true>()
|
||||
: Problem::BlockShape::template GetBlockSize<false>();
|
||||
}
|
||||
|
||||
// clang-format off
|
||||
template <typename T> struct t2s;
|
||||
|
||||
Reference in New Issue
Block a user