From 1e9b1826b5dd61323db88eea5f7ee167afb54907 Mon Sep 17 00:00:00 2001 From: linqunAMD Date: Tue, 16 Sep 2025 23:47:55 +0800 Subject: [PATCH] =?UTF-8?q?[CK=5FTILE][REGRESSION]=20Correct=20blockSize?= =?UTF-8?q?=20in=20Generic2dBlockShape=20(c254f=E2=80=A6=20(#2837)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [CK_TILE][REGRESSION] Correct blockSize in Generic2dBlockShape (5b17f135b7e2c ) WarpPerBlock_M * WarpPerBlock_N is not necessarily equal to ThreadPerBlock_M * ThreadPerBlock_N / warpSize, so we should calculate BlockSize from WarpPerBlock_M * WarpPerBlock_N. To be compatible with wave32, the function GetBlockSize is added to calculate the correct size on the host side. * fix blocksize for all kernels related to Generic2dBlockShape * remove constexpr for blocks [ROCm/composable_kernel commit: b7a806f2442ed04db9e835e3e4e14aaebe3db9b4] --- .../add_rmsnorm2d_rdquant_fwd_kernel.hpp | 6 ++- .../ops/common/generic_2d_block_shape.hpp | 51 ++++++++++++------- .../kernel/layernorm2d_fwd_kernel.hpp | 6 ++- .../rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp | 6 ++- .../kernel/moe_smoothquant_kernel.hpp | 6 ++- .../smoothquant/kernel/smoothquant_kernel.hpp | 6 ++- ..._rmsnorm2d_rdquant_fwd_instance_common.hpp | 2 +- .../moe_smoothquant_instance_common.hpp | 2 +- test/ck_tile/rmsnorm2d/generate.py | 2 +- .../instances/smoothquant_instance_common.hpp | 2 +- 10 files changed, 63 insertions(+), 26 deletions(-) diff --git a/include/ck_tile/ops/add_rmsnorm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_kernel.hpp b/include/ck_tile/ops/add_rmsnorm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_kernel.hpp index c7717f08cd..b6eac45285 100644 --- a/include/ck_tile/ops/add_rmsnorm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_kernel.hpp +++ b/include/ck_tile/ops/add_rmsnorm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_kernel.hpp @@ -95,7 +95,11 @@ struct AddRmsnorm2dRdquantFwd return dim3(integer_divide_ceil(hargs.m, Block_M)); } - CK_TILE_HOST static constexpr auto BlockSize() {
return Problem::BlockShape::BlockSize; } + CK_TILE_HOST static constexpr auto BlockSize() + { + return is_wave32() ? Problem::BlockShape::template GetBlockSize() + : Problem::BlockShape::template GetBlockSize(); + } // clang-format off template struct t2s; diff --git a/include/ck_tile/ops/common/generic_2d_block_shape.hpp b/include/ck_tile/ops/common/generic_2d_block_shape.hpp index 333762e5d7..9c5d99efc3 100644 --- a/include/ck_tile/ops/common/generic_2d_block_shape.hpp +++ b/include/ck_tile/ops/common/generic_2d_block_shape.hpp @@ -45,47 +45,57 @@ struct Generic2dBlockShape static constexpr index_t Block_N = BlockTile_::at(number<1>{}); static constexpr index_t ThreadPerBlock_M = ThreadPerBlock_::at(number<0>{}); static constexpr index_t ThreadPerBlock_N = ThreadPerBlock_::at(number<1>{}); - static constexpr index_t BlockSize = ThreadPerBlock_M * ThreadPerBlock_N; // vector size along seq static constexpr index_t Vector_M = Vector_::at(number<0>{}); static constexpr index_t Vector_N = Vector_::at(number<1>{}); - static constexpr bool is_warp_per_row = ThreadPerBlock_N <= get_warp_size(); - static_assert((ThreadPerBlock_M * ThreadPerBlock_N) % get_warp_size() == 0); - static constexpr index_t total_warps = (ThreadPerBlock_M * ThreadPerBlock_N) / get_warp_size(); - // num warps along seq, within each block - static constexpr index_t WarpPerBlock_M = []() { + template + static constexpr index_t GetWarpPerBlock_M() + { + constexpr index_t warp_size = isHostWave32 ? 
32 : get_warp_size(); + constexpr bool is_warp_per_row = ThreadPerBlock_N <= warp_size; + static_assert((ThreadPerBlock_M * ThreadPerBlock_N) % warp_size == 0); + constexpr index_t total_warps = (ThreadPerBlock_M * ThreadPerBlock_N) / warp_size; + if constexpr(is_warp_per_row) { - static_assert(get_warp_size() % ThreadPerBlock_N == 0); - return total_warps * (get_warp_size() / ThreadPerBlock_N); + static_assert(warp_size % ThreadPerBlock_N == 0); + return total_warps * (warp_size / ThreadPerBlock_N); } else { // static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0); - return total_warps / (ThreadPerBlock_N / get_warp_size()); + return total_warps / (ThreadPerBlock_N / warp_size); } - }(); + }; // num of warps along n - static constexpr index_t WarpPerBlock_N = []() { + template + static constexpr index_t GetWarpPerBlock_N() + { + constexpr index_t warp_size = isHostWave32 ? 32 : get_warp_size(); + constexpr bool is_warp_per_row = ThreadPerBlock_N <= warp_size; if constexpr(is_warp_per_row) { - static_assert(get_warp_size() % ThreadPerBlock_N == 0); + static_assert(warp_size % ThreadPerBlock_N == 0); return 1; } else { - static_assert(ThreadPerBlock_N % get_warp_size() == 0); - return ThreadPerBlock_N / get_warp_size(); + static_assert(ThreadPerBlock_N % warp_size == 0); + return ThreadPerBlock_N / warp_size; } - }(); + } + + static constexpr index_t WarpPerBlock_M = GetWarpPerBlock_M(); + static constexpr index_t WarpPerBlock_N = GetWarpPerBlock_N(); // warp size - static constexpr index_t Warp_M = ThreadPerBlock_M / WarpPerBlock_M * Vector_M; - static constexpr index_t Warp_N = ThreadPerBlock_N / WarpPerBlock_N * Vector_N; + static constexpr index_t BlockSize = WarpPerBlock_M * WarpPerBlock_N * get_warp_size(); + static constexpr index_t Warp_M = ThreadPerBlock_M / WarpPerBlock_M * Vector_M; + static constexpr index_t Warp_N = ThreadPerBlock_N / WarpPerBlock_N * Vector_N; static_assert(Warp_M % Vector_M == 0); static_assert(Warp_N % Vector_N == 0); 
static_assert(Block_M % (WarpPerBlock_M * Warp_M) == 0); @@ -98,6 +108,13 @@ struct Generic2dBlockShape // num of threads along seq, within each warp static constexpr index_t ThreadPerWarp_M = Warp_M / Vector_M; static constexpr index_t ThreadPerWarp_N = Warp_N / Vector_N; + + template + static constexpr index_t GetBlockSize() + { + constexpr index_t warp_size = isHostWave32 ? 32 : get_warp_size(); + return GetWarpPerBlock_M() * GetWarpPerBlock_N() * warp_size; + } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp b/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp index 6998b358d8..0181a3291f 100644 --- a/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp +++ b/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp @@ -134,7 +134,11 @@ struct Layernorm2dFwd return dim3(integer_divide_ceil(hargs.m, Block_M)); } - CK_TILE_HOST static constexpr auto BlockSize() { return Problem::BlockShape::BlockSize; } + CK_TILE_HOST static constexpr auto BlockSize() + { + return is_wave32() ? Problem::BlockShape::template GetBlockSize() + : Problem::BlockShape::template GetBlockSize(); + } // clang-format off template struct t2s; diff --git a/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp b/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp index e7f4ce0ba8..32586a6343 100644 --- a/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp +++ b/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp @@ -124,7 +124,11 @@ struct Rmsnorm2dFwd return dim3(integer_divide_ceil(hargs.m, Block_M)); } - CK_TILE_HOST static constexpr auto BlockSize() { return Problem::BlockShape::BlockSize; } + CK_TILE_HOST static constexpr auto BlockSize() + { + return is_wave32() ? 
Problem::BlockShape::template GetBlockSize() + : Problem::BlockShape::template GetBlockSize(); + } // clang-format off template struct t2s; diff --git a/include/ck_tile/ops/smoothquant/kernel/moe_smoothquant_kernel.hpp b/include/ck_tile/ops/smoothquant/kernel/moe_smoothquant_kernel.hpp index b70e996617..2553b19fd8 100644 --- a/include/ck_tile/ops/smoothquant/kernel/moe_smoothquant_kernel.hpp +++ b/include/ck_tile/ops/smoothquant/kernel/moe_smoothquant_kernel.hpp @@ -93,7 +93,11 @@ struct MoeSmoothquant return dim3(hargs.topk, integer_divide_ceil(hargs.tokens, Block_M), 1); } - CK_TILE_HOST static constexpr auto BlockSize() { return Problem::BlockShape::BlockSize; } + CK_TILE_HOST static constexpr auto BlockSize() + { + return is_wave32() ? Problem::BlockShape::template GetBlockSize() + : Problem::BlockShape::template GetBlockSize(); + } // clang-format off template struct t2s; diff --git a/include/ck_tile/ops/smoothquant/kernel/smoothquant_kernel.hpp b/include/ck_tile/ops/smoothquant/kernel/smoothquant_kernel.hpp index 7dc913901e..e0ea9692c5 100644 --- a/include/ck_tile/ops/smoothquant/kernel/smoothquant_kernel.hpp +++ b/include/ck_tile/ops/smoothquant/kernel/smoothquant_kernel.hpp @@ -82,7 +82,11 @@ struct Smoothquant return dim3(integer_divide_ceil(hargs.m, Block_M)); } - CK_TILE_HOST static constexpr auto BlockSize() { return Problem::BlockShape::BlockSize; } + CK_TILE_HOST static constexpr auto BlockSize() + { + return is_wave32() ? 
Problem::BlockShape::template GetBlockSize() + : Problem::BlockShape::template GetBlockSize(); + } // clang-format off template struct t2s; diff --git a/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_instance_common.hpp b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_instance_common.hpp index dd90034064..d997596414 100644 --- a/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_instance_common.hpp +++ b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_instance_common.hpp @@ -58,7 +58,7 @@ float add_rmsnorm2d_rdquant_fwd_(const S& s, A a) using Kernel = ck_tile::AddRmsnorm2dRdquantFwd; const dim3 grids = Kernel::GridSize(a); - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 blocks = Kernel::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = 1; auto kargs = Kernel::MakeKargs(a); diff --git a/test/ck_tile/moe_smoothquant/instances/moe_smoothquant_instance_common.hpp b/test/ck_tile/moe_smoothquant/instances/moe_smoothquant_instance_common.hpp index f2875c72c8..c6ef822f64 100644 --- a/test/ck_tile/moe_smoothquant/instances/moe_smoothquant_instance_common.hpp +++ b/test/ck_tile/moe_smoothquant/instances/moe_smoothquant_instance_common.hpp @@ -53,7 +53,7 @@ float moe_smoothquant_(const S& s, A a) using Kernel = ck_tile::MoeSmoothquant; const dim3 grids = Kernel::GridSize(a); - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 blocks = Kernel::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = 1; auto kargs = Kernel::MakeKargs(a); diff --git a/test/ck_tile/rmsnorm2d/generate.py b/test/ck_tile/rmsnorm2d/generate.py index 5eded8b310..3bcc427e83 100644 --- a/test/ck_tile/rmsnorm2d/generate.py +++ b/test/ck_tile/rmsnorm2d/generate.py @@ -201,7 +201,7 @@ float rmsnorm2d_fwd_(const S& s, A a) using Kernel = ck_tile::Rmsnorm2dFwd; const dim3 grids = Kernel::GridSize(a); - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 blocks = Kernel::BlockSize(); 
constexpr ck_tile::index_t kBlockPerCu = 1; auto kargs = Kernel::MakeKargs(a); diff --git a/test/ck_tile/smoothquant/instances/smoothquant_instance_common.hpp b/test/ck_tile/smoothquant/instances/smoothquant_instance_common.hpp index 8929289cdb..138afcffaf 100644 --- a/test/ck_tile/smoothquant/instances/smoothquant_instance_common.hpp +++ b/test/ck_tile/smoothquant/instances/smoothquant_instance_common.hpp @@ -49,7 +49,7 @@ float smoothquant_(const S& s, A a) using Kernel = ck_tile::Smoothquant; const dim3 grids = Kernel::GridSize(a); - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 blocks = Kernel::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = 1; auto kargs = Kernel::MakeKargs(a);