From 0daa4023fc316d2d576ef14517e652494c0a3c53 Mon Sep 17 00:00:00 2001 From: "assistant-librarian[bot]" Date: Tue, 16 Sep 2025 16:13:22 +0000 Subject: [PATCH] Merge commit 'b7a806f2442ed04db9e835e3e4e14aaebe3db9b4' into develop --- ...d_contraction_multiple_d_wmma_cshuffle.hpp | 8 ++- ...ise_batched_gemm_gemm_wmma_cshuffle_v3.hpp | 51 ++++++++++------ ...atched_gemm_softmax_gemm_wmma_cshuffle.hpp | 59 ++++++++++++------- .../gpu/grid/gridwise_fpAintB_gemm_wmma.hpp | 32 ++++++---- ...gridwise_gemm_multiple_d_wmma_cshuffle.hpp | 28 +++++++-- .../gpu/grid/gridwise_gemm_wmma.hpp | 32 ++++++---- .../add_rmsnorm2d_rdquant_fwd_kernel.hpp | 6 +- .../ops/common/generic_2d_block_shape.hpp | 51 ++++++++++------ .../kernel/layernorm2d_fwd_kernel.hpp | 6 +- .../rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp | 6 +- .../kernel/moe_smoothquant_kernel.hpp | 6 +- .../smoothquant/kernel/smoothquant_kernel.hpp | 6 +- ..._rmsnorm2d_rdquant_fwd_instance_common.hpp | 2 +- .../moe_smoothquant_instance_common.hpp | 2 +- test/ck_tile/rmsnorm2d/generate.py | 2 +- .../instances/smoothquant_instance_common.hpp | 2 +- 16 files changed, 203 insertions(+), 96 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp index ab3f3856aa..537e6dab28 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp @@ -1,11 +1,12 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include #include +#include "ck/utility/env.hpp" #include "ck/utility/common_header.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" @@ -853,7 +854,10 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle arg.e_grid_desc_m_n_, arg.block_2_ctile_map_)) { - printf("GridwiseOp: Validity check failure\n"); + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + printf("GridwiseOp: Validity check failure\n"); + } return false; } diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp index b61c7a09eb..fa7eb4faaa 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp @@ -398,41 +398,54 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3 if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1))) { - print("GridwiseOp: M/N Length err, A_M/N = %d, %d | C_M/N = %d, %d\n", - M, - N, - c_grid_desc_m_n.GetLength(I0), - c_grid_desc_m_n.GetLength(I1)); + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + print("GridwiseOp: M/N Length err, A_M/N = %d, %d | C_M/N = %d, %d\n", + M, + N, + c_grid_desc_m_n.GetLength(I0), + c_grid_desc_m_n.GetLength(I1)); + } return false; } if(!(M % MPerBlock == 0 && L % LPerBlock == 0 && K % KPerBlock == 0 && N % NPerBlock == 0)) { - print("GridwiseOp: M/L/K/N Division err, M/L/K/N = %d, %d, %d, %d | M/L/K/NPerBlock = " - "%d, %d, %d, %d\n", - M, - L, - K, - N, - MPerBlock, - LPerBlock, - KPerBlock, - NPerBlock); + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + print("GridwiseOp: M/L/K/N Division err, M/L/K/N = %d, %d, %d, %d | " + "M/L/K/NPerBlock = " + "%d, %d, %d, %d\n", + M, + L, + K, + N, + MPerBlock, + LPerBlock, + KPerBlock, + NPerBlock); + } return false; } // check gemm1 gridwise gemm pipeline if(!(LPerBlock % LTilePerBlock == 0)) { - print("GridwiseOp: inner loop division, L/LTilePerblock: %d, %d\n", - LPerBlock, - LTilePerBlock); + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + print("GridwiseOp: inner loop division, L/LTilePerblock: %d, %d\n", + LPerBlock, + LTilePerBlock); + } return false; } if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n)) { - print("GridwiseOp: invalid block_2_ctile_map\n"); + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + print("GridwiseOp: invalid block_2_ctile_map\n"); + } return false; } diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp index 1754e07e6a..502c449ef1 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp @@ -1,8 +1,9 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once +#include "ck/utility/env.hpp" #include "ck/utility/common_header.hpp" #include "ck/tensor_description/multi_index_transform_helper.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" @@ -569,26 +570,33 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1))) { - printf("GridwiseOp: M/N Length err, A_M/N = %d, %d | C_M/N = %d, %d\n", - M, - N, - c_grid_desc_m_n.GetLength(I0), - c_grid_desc_m_n.GetLength(I1)); + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + printf("GridwiseOp: M/N Length err, A_M/N = %d, %d | C_M/N = %d, %d\n", + M, + N, + c_grid_desc_m_n.GetLength(I0), + c_grid_desc_m_n.GetLength(I1)); + } return false; } if(!(M % MPerBlock == 0 && L % LPerBlock == 0 && K % KPerBlock == 0 && N % NPerBlock == 0)) { - printf("GridwiseOp: M/L/K/N Division err, M/L/K/N = %d, %d, %d, %d | M/L/K/NPerBlock = " - "%d, %d, %d, %d\n", - M, - L, - K, - N, - MPerBlock, - LPerBlock, - KPerBlock, - NPerBlock); + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + printf("GridwiseOp: M/L/K/N Division err, M/L/K/N = %d, %d, %d, %d | " + "M/L/K/NPerBlock = " + "%d, %d, %d, %d\n", + M, + L, + K, + N, + MPerBlock, + LPerBlock, + KPerBlock, + NPerBlock); + } return false; } @@ -596,23 +604,32 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma const auto num_gemm0_k_loop = K / KPerBlock; if(!GridwiseGemmPipe::IsSupported(num_gemm0_k_loop)) { - printf("GridwiseOp: outer loop unsupport\n"); + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + printf("GridwiseOp: outer loop unsupport\n"); + } return false; } // check gemm1 gridwise gemm pipeline if(!(LPerBlock % LTilePerBlock == 0)) { - printf("GridwiseOp: inner loop division, L/LTilePerblock: %d, %d\n", - LPerBlock, - LTilePerBlock); + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + printf("GridwiseOp: inner loop division, L/LTilePerblock: %d, %d\n", + LPerBlock, + LTilePerBlock); + } return false; } const auto num_gemm1_k_inner_loop = LPerBlock / LTilePerBlock; if(!GridwiseGemmPipe::IsSupported(num_gemm1_k_inner_loop)) { - printf("GridwiseOp: inner loop unsupport\n"); + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + printf("GridwiseOp: inner loop unsupport\n"); + } return false; } diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp index 8011fa56d3..c8b154228f 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp @@ -1,8 +1,9 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once +#include "ck/utility/env.hpp" #include "ck/utility/common_header.hpp" #include "ck/tensor_description/multi_index_transform_helper.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" @@ -466,20 +467,26 @@ struct GridwiseFpAintBGemm_Wmma if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) && K == GetBProblemsizeNK()[I1])) { - printf("A: MxK = %d x %d, B: NxK = %d x %d, C: MxN = %d x %d\n", - GetAProblemsizeMK()[I0], - GetAProblemsizeMK()[I1], - GetBProblemsizeNK()[I0], - GetBProblemsizeNK()[I1], - c_grid_desc_m_n.GetLength(I0), - c_grid_desc_m_n.GetLength(I1)); - printf("GridwiseOp err: ProblemSize check"); + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + printf("A: MxK = %d x %d, B: NxK = %d x %d, C: MxN = %d x %d\n", + GetAProblemsizeMK()[I0], + GetAProblemsizeMK()[I1], + GetBProblemsizeNK()[I0], + GetBProblemsizeNK()[I1], + c_grid_desc_m_n.GetLength(I0), + c_grid_desc_m_n.GetLength(I1)); + printf("GridwiseOp err: ProblemSize check"); + } return false; } if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) { - printf("GridwiseOp err: ProblemSize division"); + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + printf("GridwiseOp err: ProblemSize division"); + } return false; } @@ -488,7 +495,10 @@ struct GridwiseFpAintBGemm_Wmma if(!GridwiseGemmPipe::IsSupported(num_k_loop)) { - printf("GridwiseOp err: Pipeline not support this k_loop"); + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + printf("GridwiseOp err: Pipeline not support this k_loop"); + } return false; } diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp index 46979a5620..7d68d64ed8 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp @@ -1,8 +1,9 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once +#include "ck/utility/env.hpp" #include "ck/utility/common_header.hpp" #include "ck/tensor_description/multi_index_transform_helper.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" @@ -653,13 +654,19 @@ struct GridwiseGemmMultipleD_Wmma if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) && K == GetBProblemsizeNK()[I1])) { - printf("GridwiseOp: ABE descriptor dimension cross check failure\n"); + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + printf("GridwiseOp: ABE descriptor dimension cross check failure\n"); + } return false; } if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) { - printf("GridwiseOp: Problemsize descriptor dimension check failure\n"); + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + printf("GridwiseOp: Problemsize descriptor dimension check failure\n"); + } return false; } @@ -747,20 +754,29 @@ struct GridwiseGemmMultipleD_Wmma if(!valid) { - printf("GridwiseOp: D descriptor dimension check failure\n"); + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + printf("GridwiseOp: D descriptor dimension check failure\n"); + } return false; } if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) && K == GetBProblemsizeNK()[I1])) { - printf("GridwiseOp: ABE descriptor dimension cross check failure\n"); + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + printf("GridwiseOp: ABE descriptor dimension cross check failure\n"); + } return false; } if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) { - printf("GridwiseOp: Problemsize descriptor dimension check failure\n"); + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + printf("GridwiseOp: Problemsize descriptor dimension check failure\n"); + } return false; } diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp index 4a15958adb..65f74de3cf 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp @@ -1,8 +1,9 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once +#include "ck/utility/env.hpp" #include "ck/utility/common_header.hpp" #include "ck/tensor_description/multi_index_transform_helper.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" @@ -458,20 +459,26 @@ struct GridwiseGemm_Wmma if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) && K == GetBProblemsizeNK()[I1])) { - printf("A: MxK = %d x %d, B: NxK = %d x %d, C: MxN = %d x %d\n", - GetAProblemsizeMK()[I0], - GetAProblemsizeMK()[I1], - GetBProblemsizeNK()[I0], - GetBProblemsizeNK()[I1], - c_grid_desc_m_n.GetLength(I0), - c_grid_desc_m_n.GetLength(I1)); - printf("GridwiseOp err: ProblemSize check"); + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + printf("A: MxK = %d x %d, B: NxK = %d x %d, C: MxN = %d x %d\n", + GetAProblemsizeMK()[I0], + GetAProblemsizeMK()[I1], + GetBProblemsizeNK()[I0], + GetBProblemsizeNK()[I1], + c_grid_desc_m_n.GetLength(I0), + c_grid_desc_m_n.GetLength(I1)); + printf("GridwiseOp err: ProblemSize check"); + } return false; } if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) { - printf("GridwiseOp err: ProblemSize division"); + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + printf("GridwiseOp err: ProblemSize division"); + } return false; } @@ -480,7 +487,10 @@ struct GridwiseGemm_Wmma if(!GridwiseGemmPipe::IsSupported(num_k_loop)) { - printf("GridwiseOp err: Pipeline not support this k_loop"); + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + printf("GridwiseOp err: Pipeline not support this k_loop"); + } return false; } diff --git a/include/ck_tile/ops/add_rmsnorm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_kernel.hpp b/include/ck_tile/ops/add_rmsnorm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_kernel.hpp index c7717f08cd..b6eac45285 100644 --- a/include/ck_tile/ops/add_rmsnorm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_kernel.hpp +++ b/include/ck_tile/ops/add_rmsnorm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_kernel.hpp @@ -95,7 +95,11 @@ struct AddRmsnorm2dRdquantFwd return dim3(integer_divide_ceil(hargs.m, Block_M)); } - CK_TILE_HOST static constexpr auto BlockSize() { return Problem::BlockShape::BlockSize; } + CK_TILE_HOST static constexpr auto BlockSize() + { + return is_wave32() ? Problem::BlockShape::template GetBlockSize() + : Problem::BlockShape::template GetBlockSize(); + } // clang-format off template struct t2s; diff --git a/include/ck_tile/ops/common/generic_2d_block_shape.hpp b/include/ck_tile/ops/common/generic_2d_block_shape.hpp index 333762e5d7..9c5d99efc3 100644 --- a/include/ck_tile/ops/common/generic_2d_block_shape.hpp +++ b/include/ck_tile/ops/common/generic_2d_block_shape.hpp @@ -45,47 +45,57 @@ struct Generic2dBlockShape static constexpr index_t Block_N = BlockTile_::at(number<1>{}); static constexpr index_t ThreadPerBlock_M = ThreadPerBlock_::at(number<0>{}); static constexpr index_t ThreadPerBlock_N = ThreadPerBlock_::at(number<1>{}); - static constexpr index_t BlockSize = ThreadPerBlock_M * ThreadPerBlock_N; // vector size along seq static constexpr index_t Vector_M = Vector_::at(number<0>{}); static constexpr index_t Vector_N = Vector_::at(number<1>{}); - static constexpr bool is_warp_per_row = ThreadPerBlock_N <= get_warp_size(); - static_assert((ThreadPerBlock_M * ThreadPerBlock_N) % get_warp_size() == 0); - static constexpr index_t total_warps = (ThreadPerBlock_M * ThreadPerBlock_N) / get_warp_size(); - // num warps along seq, within each block - static constexpr index_t WarpPerBlock_M = []() { + template + static constexpr index_t GetWarpPerBlock_M() + { + constexpr index_t warp_size = isHostWave32 ? 32 : get_warp_size(); + constexpr bool is_warp_per_row = ThreadPerBlock_N <= warp_size; + static_assert((ThreadPerBlock_M * ThreadPerBlock_N) % warp_size == 0); + constexpr index_t total_warps = (ThreadPerBlock_M * ThreadPerBlock_N) / warp_size; + if constexpr(is_warp_per_row) { - static_assert(get_warp_size() % ThreadPerBlock_N == 0); - return total_warps * (get_warp_size() / ThreadPerBlock_N); + static_assert(warp_size % ThreadPerBlock_N == 0); + return total_warps * (warp_size / ThreadPerBlock_N); } else { // static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0); - return total_warps / (ThreadPerBlock_N / get_warp_size()); + return total_warps / (ThreadPerBlock_N / warp_size); } - }(); + }; // num of warps along n - static constexpr index_t WarpPerBlock_N = []() { + template + static constexpr index_t GetWarpPerBlock_N() + { + constexpr index_t warp_size = isHostWave32 ? 32 : get_warp_size(); + constexpr bool is_warp_per_row = ThreadPerBlock_N <= warp_size; if constexpr(is_warp_per_row) { - static_assert(get_warp_size() % ThreadPerBlock_N == 0); + static_assert(warp_size % ThreadPerBlock_N == 0); return 1; } else { - static_assert(ThreadPerBlock_N % get_warp_size() == 0); - return ThreadPerBlock_N / get_warp_size(); + static_assert(ThreadPerBlock_N % warp_size == 0); + return ThreadPerBlock_N / warp_size; } - }(); + } + + static constexpr index_t WarpPerBlock_M = GetWarpPerBlock_M(); + static constexpr index_t WarpPerBlock_N = GetWarpPerBlock_N(); // warp size - static constexpr index_t Warp_M = ThreadPerBlock_M / WarpPerBlock_M * Vector_M; - static constexpr index_t Warp_N = ThreadPerBlock_N / WarpPerBlock_N * Vector_N; + static constexpr index_t BlockSize = WarpPerBlock_M * WarpPerBlock_N * get_warp_size(); + static constexpr index_t Warp_M = ThreadPerBlock_M / WarpPerBlock_M * Vector_M; + static constexpr index_t Warp_N = ThreadPerBlock_N / WarpPerBlock_N * Vector_N; static_assert(Warp_M % Vector_M == 0); static_assert(Warp_N % Vector_N == 0); static_assert(Block_M % (WarpPerBlock_M * Warp_M) == 0); @@ -98,6 +108,13 @@ struct Generic2dBlockShape // num of threads along seq, within each warp static constexpr index_t ThreadPerWarp_M = Warp_M / Vector_M; static constexpr index_t ThreadPerWarp_N = Warp_N / Vector_N; + + template + static constexpr index_t GetBlockSize() + { + constexpr index_t warp_size = isHostWave32 ? 32 : get_warp_size(); + return GetWarpPerBlock_M() * GetWarpPerBlock_N() * warp_size; + } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp b/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp index 6998b358d8..0181a3291f 100644 --- a/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp +++ b/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp @@ -134,7 +134,11 @@ struct Layernorm2dFwd return dim3(integer_divide_ceil(hargs.m, Block_M)); } - CK_TILE_HOST static constexpr auto BlockSize() { return Problem::BlockShape::BlockSize; } + CK_TILE_HOST static constexpr auto BlockSize() + { + return is_wave32() ? Problem::BlockShape::template GetBlockSize() + : Problem::BlockShape::template GetBlockSize(); + } // clang-format off template struct t2s; diff --git a/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp b/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp index e7f4ce0ba8..32586a6343 100644 --- a/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp +++ b/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp @@ -124,7 +124,11 @@ struct Rmsnorm2dFwd return dim3(integer_divide_ceil(hargs.m, Block_M)); } - CK_TILE_HOST static constexpr auto BlockSize() { return Problem::BlockShape::BlockSize; } + CK_TILE_HOST static constexpr auto BlockSize() + { + return is_wave32() ? Problem::BlockShape::template GetBlockSize() + : Problem::BlockShape::template GetBlockSize(); + } // clang-format off template struct t2s; diff --git a/include/ck_tile/ops/smoothquant/kernel/moe_smoothquant_kernel.hpp b/include/ck_tile/ops/smoothquant/kernel/moe_smoothquant_kernel.hpp index b70e996617..2553b19fd8 100644 --- a/include/ck_tile/ops/smoothquant/kernel/moe_smoothquant_kernel.hpp +++ b/include/ck_tile/ops/smoothquant/kernel/moe_smoothquant_kernel.hpp @@ -93,7 +93,11 @@ struct MoeSmoothquant return dim3(hargs.topk, integer_divide_ceil(hargs.tokens, Block_M), 1); } - CK_TILE_HOST static constexpr auto BlockSize() { return Problem::BlockShape::BlockSize; } + CK_TILE_HOST static constexpr auto BlockSize() + { + return is_wave32() ? Problem::BlockShape::template GetBlockSize() + : Problem::BlockShape::template GetBlockSize(); + } // clang-format off template struct t2s; diff --git a/include/ck_tile/ops/smoothquant/kernel/smoothquant_kernel.hpp b/include/ck_tile/ops/smoothquant/kernel/smoothquant_kernel.hpp index 7dc913901e..e0ea9692c5 100644 --- a/include/ck_tile/ops/smoothquant/kernel/smoothquant_kernel.hpp +++ b/include/ck_tile/ops/smoothquant/kernel/smoothquant_kernel.hpp @@ -82,7 +82,11 @@ struct Smoothquant return dim3(integer_divide_ceil(hargs.m, Block_M)); } - CK_TILE_HOST static constexpr auto BlockSize() { return Problem::BlockShape::BlockSize; } + CK_TILE_HOST static constexpr auto BlockSize() + { + return is_wave32() ? Problem::BlockShape::template GetBlockSize() + : Problem::BlockShape::template GetBlockSize(); + } // clang-format off template struct t2s; diff --git a/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_instance_common.hpp b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_instance_common.hpp index dd90034064..d997596414 100644 --- a/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_instance_common.hpp +++ b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_instance_common.hpp @@ -58,7 +58,7 @@ float add_rmsnorm2d_rdquant_fwd_(const S& s, A a) using Kernel = ck_tile::AddRmsnorm2dRdquantFwd; const dim3 grids = Kernel::GridSize(a); - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 blocks = Kernel::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = 1; auto kargs = Kernel::MakeKargs(a); diff --git a/test/ck_tile/moe_smoothquant/instances/moe_smoothquant_instance_common.hpp b/test/ck_tile/moe_smoothquant/instances/moe_smoothquant_instance_common.hpp index f2875c72c8..c6ef822f64 100644 --- a/test/ck_tile/moe_smoothquant/instances/moe_smoothquant_instance_common.hpp +++ b/test/ck_tile/moe_smoothquant/instances/moe_smoothquant_instance_common.hpp @@ -53,7 +53,7 @@ float moe_smoothquant_(const S& s, A a) using Kernel = ck_tile::MoeSmoothquant; const dim3 grids = Kernel::GridSize(a); - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 blocks = Kernel::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = 1; auto kargs = Kernel::MakeKargs(a); diff --git a/test/ck_tile/rmsnorm2d/generate.py b/test/ck_tile/rmsnorm2d/generate.py index 5eded8b310..3bcc427e83 100644 --- a/test/ck_tile/rmsnorm2d/generate.py +++ b/test/ck_tile/rmsnorm2d/generate.py @@ -201,7 +201,7 @@ float rmsnorm2d_fwd_(const S& s, A a) using Kernel = ck_tile::Rmsnorm2dFwd; const dim3 grids = Kernel::GridSize(a); - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 blocks = Kernel::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = 1; auto kargs = Kernel::MakeKargs(a); diff --git a/test/ck_tile/smoothquant/instances/smoothquant_instance_common.hpp b/test/ck_tile/smoothquant/instances/smoothquant_instance_common.hpp index 8929289cdb..138afcffaf 100644 --- a/test/ck_tile/smoothquant/instances/smoothquant_instance_common.hpp +++ b/test/ck_tile/smoothquant/instances/smoothquant_instance_common.hpp @@ -49,7 +49,7 @@ float smoothquant_(const S& s, A a) using Kernel = ck_tile::Smoothquant; const dim3 grids = Kernel::GridSize(a); - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 blocks = Kernel::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = 1; auto kargs = Kernel::MakeKargs(a);