From c254f3d7b4cccae5c884b419842a01eec4ed74fc Mon Sep 17 00:00:00 2001 From: linqunAMD Date: Wed, 10 Sep 2025 08:29:20 +0800 Subject: [PATCH] [CK_TILE] Refine Generic2dBlockShape to fix ck_tile example 2,10,11,14 on rdna3 and 4 (#2795) BlockWarps, WarpTile in Generic2dBlockShape are wave size dependent, it causes mangled name mismatch between host and device side. Solution: Replace them with ThreadPerBlock and move BlockWarps, WarpTile calculation into Generic2dBlockShape --- example/ck_tile/02_layernorm2d/generate.py | 41 +--------- .../10_rmsnorm2d/example_rmsnorm2d_fwd.cpp | 10 +-- example/ck_tile/10_rmsnorm2d/generate.py | 43 +--------- .../add_rmsnorm2d_rdquant_fwd.hpp | 48 ++--------- .../example_add_rmsnorm2d_rdquant_fwd.cpp | 9 +-- .../12_smoothquant/example_smoothquant.cpp | 9 +-- .../ck_tile/12_smoothquant/smoothquant.hpp | 46 +---------- .../14_moe_smoothquant/moe_smoothquant.hpp | 45 +---------- .../ops/common/generic_2d_block_shape.hpp | 80 ++++++++++++------- .../add_rmsnorm2d_rdquant_fwd.hpp | 48 ++--------- test/ck_tile/layernorm2d/generate.py | 43 +--------- .../moe_smoothquant/moe_smoothquant.hpp | 46 +---------- test/ck_tile/rmsnorm2d/generate.py | 43 +--------- test/ck_tile/smoothquant/smoothquant.hpp | 45 +---------- 14 files changed, 103 insertions(+), 453 deletions(-) diff --git a/example/ck_tile/02_layernorm2d/generate.py b/example/ck_tile/02_layernorm2d/generate.py index c4366f6662..b7512b2999 100644 --- a/example/ck_tile/02_layernorm2d/generate.py +++ b/example/ck_tile/02_layernorm2d/generate.py @@ -75,54 +75,17 @@ struct layernorm2d_fwd_traits_ using SmoothScaleDataType = ck_tile::remove_cvref_t; using YScaleDataType = ck_tile::remove_cvref_t; - static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= ck_tile::get_warp_size(); - static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % ck_tile::get_warp_size() == 0); - static constexpr ck_tile::index_t total_warps = - (ThreadPerBlock_M_ * ThreadPerBlock_N_) / ck_tile::get_warp_size(); - - // num of warps along m - static constexpr ck_tile::index_t BlockWarps_M = []() { - if constexpr(is_warp_per_row) - { - static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0); - return total_warps * (ck_tile::get_warp_size() / ThreadPerBlock_N_); - } - else - { - // static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0); - return total_warps / (ThreadPerBlock_N_ / ck_tile::get_warp_size()); - } - }(); - - // num of warps along n - static constexpr ck_tile::index_t BlockWarps_N = []() { - if constexpr(is_warp_per_row) - { - static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0); - return 1; - } - else - { - static_assert(ThreadPerBlock_N_ % ck_tile::get_warp_size() == 0); - return ThreadPerBlock_N_ / ck_tile::get_warp_size(); - } - }(); - static constexpr ck_tile::index_t Repeat_M = Repeat_M_; static constexpr ck_tile::index_t Repeat_N = Repeat_N_; static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_; static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_; - static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M; - static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_; - using BlockTile = ck_tile::sequence; - using BlockWarps = ck_tile::sequence; - using WarpTile = ck_tile::sequence; using Vector = ck_tile::sequence<1, Vector_N_>; + using ThreadPerBlock = ck_tile::sequence; - using Shape = ck_tile::Generic2dBlockShape; + using Shape = ck_tile::Generic2dBlockShape; static constexpr bool kPadN = kPadN_; static constexpr bool kSaveMeanInvStd = kSaveMeanInvStd_; diff --git a/example/ck_tile/10_rmsnorm2d/example_rmsnorm2d_fwd.cpp b/example/ck_tile/10_rmsnorm2d/example_rmsnorm2d_fwd.cpp index 511efeeaec..2ca5157eda 100644 --- a/example/ck_tile/10_rmsnorm2d/example_rmsnorm2d_fwd.cpp +++ b/example/ck_tile/10_rmsnorm2d/example_rmsnorm2d_fwd.cpp @@ -71,11 +71,11 @@ bool run(const ck_tile::ArgParser& arg_parser) constexpr bool kTwoPass = true; - using BlockWarps = ck_tile::sequence<2, 2>; - using BlockTile = ck_tile::sequence<2, 128>; - using WarpTile = ck_tile::sequence<1, 64>; - using Vector = ck_tile::sequence<1, 1>; - using Shape = ck_tile::Generic2dBlockShape; + using BlockTile = ck_tile::sequence<2, 128>; + using Vector = ck_tile::sequence<1, 1>; + using ThreadPerBlock = ck_tile::sequence<2, 128>; + + using Shape = ck_tile::Generic2dBlockShape; using PipelineTraits = ck_tile::Rmsnorm2dFwdTraits; using UnquantYDataType = ck_tile::remove_cvref_t; - static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= ck_tile::get_warp_size(); - static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % ck_tile::get_warp_size() == 0); - static constexpr ck_tile::index_t total_warps = - (ThreadPerBlock_M_ * ThreadPerBlock_N_) / ck_tile::get_warp_size(); - - // num of warps along m - static constexpr ck_tile::index_t BlockWarps_M = []() { - if constexpr(is_warp_per_row) - { - static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0); - return total_warps * (ck_tile::get_warp_size() / ThreadPerBlock_N_); - } - else - { - // static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0); - return total_warps / (ThreadPerBlock_N_ / ck_tile::get_warp_size()); - } - }(); - - // num of warps along n - static constexpr ck_tile::index_t BlockWarps_N = []() { - if constexpr(is_warp_per_row) - { - static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0); - return 1; - } - else - { - static_assert(ThreadPerBlock_N_ % ck_tile::get_warp_size() == 0); - return ThreadPerBlock_N_ / ck_tile::get_warp_size(); - } - }(); - static constexpr ck_tile::index_t Repeat_M = Repeat_M_; static constexpr ck_tile::index_t Repeat_N = Repeat_N_; static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_; static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_; - static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M; - static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_; - using BlockTile = ck_tile::sequence; - using BlockWarps = ck_tile::sequence; - using WarpTile = ck_tile::sequence; using Vector = ck_tile::sequence<1, Vector_N_>; - - using Shape = ck_tile::Generic2dBlockShape; + using ThreadPerBlock = ck_tile::sequence; + + using Shape = ck_tile::Generic2dBlockShape; static constexpr bool kPadN = kPadN_; static constexpr bool kSaveInvRms = kSaveInvRms_; diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.hpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.hpp index faa134e5c4..b7bd7ac7df 100644 --- a/example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.hpp +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -80,55 +80,17 @@ struct add_rmsnorm2d_rdquant_fwd_traits_ using InputDataType = ck_tile::remove_cvref_t; using QuantizedDataType = ck_tile::remove_cvref_t; - static constexpr auto WarpSize = ck_tile::get_warp_size(); - static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= WarpSize; - static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % WarpSize == 0); - static constexpr ck_tile::index_t total_warps = - (ThreadPerBlock_M_ * ThreadPerBlock_N_) / WarpSize; - - // num of warps along m - static constexpr ck_tile::index_t BlockWarps_M = []() { - if constexpr(is_warp_per_row) - { - static_assert(WarpSize % ThreadPerBlock_N_ == 0); - return total_warps * (WarpSize / ThreadPerBlock_N_); - } - else - { - // static_assert(WarpSize % ThreadPerBlock_M_ == 0); - return total_warps / (ThreadPerBlock_N_ / WarpSize); - } - }(); - - // num of warps along n - static constexpr ck_tile::index_t BlockWarps_N = []() { - if constexpr(is_warp_per_row) - { - static_assert(WarpSize % ThreadPerBlock_N_ == 0); - return 1; - } - else - { - static_assert(ThreadPerBlock_N_ % WarpSize == 0); - return ThreadPerBlock_N_ / WarpSize; - } - }(); - static constexpr ck_tile::index_t Repeat_M = Repeat_M_; static constexpr ck_tile::index_t Repeat_N = Repeat_N_; static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_; static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_; - static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M; - static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_; + using BlockTile = ck_tile::sequence; + using Vector = ck_tile::sequence<1, Vector_N_>; + using ThreadPerBlock = ck_tile::sequence; - using BlockTile = ck_tile::sequence; - using BlockWarps = ck_tile::sequence; - using WarpTile = ck_tile::sequence; - using Vector = ck_tile::sequence<1, Vector_N_>; - - using Shape = ck_tile::Generic2dBlockShape; + using Shape = ck_tile::Generic2dBlockShape; static constexpr bool kPadN = kPadN_; static constexpr bool kSaveX = kSaveX_; diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/example_add_rmsnorm2d_rdquant_fwd.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/example_add_rmsnorm2d_rdquant_fwd.cpp index ace5fe0c4f..ca94bc1b71 100644 --- a/example/ck_tile/11_add_rmsnorm2d_rdquant/example_add_rmsnorm2d_rdquant_fwd.cpp +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/example_add_rmsnorm2d_rdquant_fwd.cpp @@ -99,12 +99,11 @@ bool run(const ck_tile::ArgParser& arg_parser) constexpr bool kThreePass = true; - using BlockWarps = ck_tile::sequence<4, 1>; - using BlockTile = ck_tile::sequence<4, 128>; - using WarpTile = ck_tile::sequence<1, 64>; - using Vector = ck_tile::sequence<1, 1>; + using BlockTile = ck_tile::sequence<4, 128>; + using Vector = ck_tile::sequence<1, 1>; + using ThreadPerBlock = ck_tile::sequence<4, 64>; - using Shape = ck_tile::Generic2dBlockShape; + using Shape = ck_tile::Generic2dBlockShape; using Problem = ck_tile::AddRmsnorm2dRdquantFwdPipelineProblem; - using BlockTile = ck_tile::sequence<2, 128>; - using WarpTile = ck_tile::sequence<1, 64>; - using Vector = ck_tile::sequence<1, 1>; + using BlockTile = ck_tile::sequence<2, 128>; + using Vector = ck_tile::sequence<1, 1>; + using ThreadPerBlock = ck_tile::sequence<2, 128>; - using Shape = ck_tile::Generic2dBlockShape; + using Shape = ck_tile::Generic2dBlockShape; using Problem = ck_tile::SmoothquantPipelineProblem; - static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= ck_tile::get_warp_size(); - static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % ck_tile::get_warp_size() == 0); - static constexpr ck_tile::index_t total_warps = - (ThreadPerBlock_M_ * ThreadPerBlock_N_) / ck_tile::get_warp_size(); - - // num of warps along m - static constexpr ck_tile::index_t BlockWarps_M = []() { - if constexpr(is_warp_per_row) - { - static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0); - return total_warps * (ck_tile::get_warp_size() / ThreadPerBlock_N_); - } - else - { - // static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0); - return total_warps / (ThreadPerBlock_N_ / ck_tile::get_warp_size()); - } - }(); - - // num of warps along n - static constexpr ck_tile::index_t BlockWarps_N = []() { - if constexpr(is_warp_per_row) - { - static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0); - return 1; - } - else - { - static_assert(ThreadPerBlock_N_ % ck_tile::get_warp_size() == 0); - return ThreadPerBlock_N_ / ck_tile::get_warp_size(); - } - }(); - static constexpr ck_tile::index_t Repeat_M = Repeat_M_; static constexpr ck_tile::index_t Repeat_N = Repeat_N_; static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_; static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_; - static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M; - static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_; - - using BlockTile = ck_tile::sequence; - using BlockWarps = ck_tile::sequence; - using WarpTile = ck_tile::sequence; - using Vector = ck_tile::sequence<1, Vector_N_>; - - using Shape = ck_tile::Generic2dBlockShape; + using BlockTile = ck_tile::sequence; + using Vector = ck_tile::sequence<1, Vector_N_>; + using ThreadPerBlock = ck_tile::sequence; + using Shape = ck_tile::Generic2dBlockShape; static constexpr bool kPadN = kPadN_; static constexpr bool kTwoPass = kTwoPass_; diff --git a/example/ck_tile/14_moe_smoothquant/moe_smoothquant.hpp b/example/ck_tile/14_moe_smoothquant/moe_smoothquant.hpp index 36cf477a42..eff4bef025 100644 --- a/example/ck_tile/14_moe_smoothquant/moe_smoothquant.hpp +++ b/example/ck_tile/14_moe_smoothquant/moe_smoothquant.hpp @@ -38,54 +38,17 @@ struct moe_smoothquant_traits_ using InputType = ck_tile::remove_cvref_t; using OutputType = ck_tile::remove_cvref_t; - static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= ck_tile::get_warp_size(); - static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % ck_tile::get_warp_size() == 0); - static constexpr ck_tile::index_t total_warps = - (ThreadPerBlock_M_ * ThreadPerBlock_N_) / ck_tile::get_warp_size(); - - // num of warps along m - static constexpr ck_tile::index_t BlockWarps_M = []() { - if constexpr(is_warp_per_row) - { - static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0); - return total_warps * (ck_tile::get_warp_size() / ThreadPerBlock_N_); - } - else - { - // static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0); - return total_warps / (ThreadPerBlock_N_ / ck_tile::get_warp_size()); - } - }(); - - // num of warps along n - static constexpr ck_tile::index_t BlockWarps_N = []() { - if constexpr(is_warp_per_row) - { - static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0); - return 1; - } - else - { - static_assert(ThreadPerBlock_N_ % ck_tile::get_warp_size() == 0); - return ThreadPerBlock_N_ / ck_tile::get_warp_size(); - } - }(); - static constexpr ck_tile::index_t Repeat_M = Repeat_M_; static constexpr ck_tile::index_t Repeat_N = Repeat_N_; static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_; static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_; - static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M; - static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_; + using BlockTile = ck_tile::sequence; + using Vector = ck_tile::sequence<1, Vector_N_>; + using ThreadPerBlock = ck_tile::sequence; - using BlockTile = ck_tile::sequence; - using BlockWarps = ck_tile::sequence; - using WarpTile = ck_tile::sequence; - using Vector = ck_tile::sequence<1, Vector_N_>; - - using Shape = ck_tile::Generic2dBlockShape; + using Shape = ck_tile::Generic2dBlockShape; static constexpr bool kPadN = kPadN_; static constexpr bool kTwoPass = kTwoPass_; diff --git a/include/ck_tile/ops/common/generic_2d_block_shape.hpp b/include/ck_tile/ops/common/generic_2d_block_shape.hpp index c0bfd93198..333762e5d7 100644 --- a/include/ck_tile/ops/common/generic_2d_block_shape.hpp +++ b/include/ck_tile/ops/common/generic_2d_block_shape.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -35,43 +35,69 @@ namespace ck_tile { +-----------+-----------+-----------+-----------+-----------+ // clang-format on */ -template - typename WarpPerBlock_, // num warps along seq - typename WarpTile_, // warp size, seq - typename Vector_> // contiguous pixels(vector size) along seq)> +template + typename ThreadPerBlock_, // num threads along seq + typename Vector_> // contiguous pixels(vector size) along seq)> struct Generic2dBlockShape { // block size - static constexpr index_t Block_M = BlockTile_::at(number<0>{}); - static constexpr index_t Block_N = BlockTile_::at(number<1>{}); - - // num warps along seq, within each block - static constexpr index_t WarpPerBlock_M = WarpPerBlock_::at(number<0>{}); - static constexpr index_t WarpPerBlock_N = WarpPerBlock_::at(number<1>{}); - - // warp size - static constexpr index_t Warp_M = WarpTile_::at(number<0>{}); - static constexpr index_t Warp_N = WarpTile_::at(number<1>{}); - - static_assert(Block_M % (WarpPerBlock_M * Warp_M) == 0); - static_assert(Block_N % (WarpPerBlock_N * Warp_N) == 0); - // repeat of each thread along seq - static constexpr index_t Repeat_M = Block_M / (WarpPerBlock_M * Warp_M); - static constexpr index_t Repeat_N = Block_N / (WarpPerBlock_N * Warp_N); + static constexpr index_t Block_M = BlockTile_::at(number<0>{}); + static constexpr index_t Block_N = BlockTile_::at(number<1>{}); + static constexpr index_t ThreadPerBlock_M = ThreadPerBlock_::at(number<0>{}); + static constexpr index_t ThreadPerBlock_N = ThreadPerBlock_::at(number<1>{}); + static constexpr index_t BlockSize = ThreadPerBlock_M * ThreadPerBlock_N; // vector size along seq static constexpr index_t Vector_M = Vector_::at(number<0>{}); static constexpr index_t Vector_N = Vector_::at(number<1>{}); + static constexpr bool is_warp_per_row = ThreadPerBlock_N <= get_warp_size(); + static_assert((ThreadPerBlock_M * ThreadPerBlock_N) % get_warp_size() == 0); + static constexpr index_t total_warps = (ThreadPerBlock_M * ThreadPerBlock_N) / get_warp_size(); + + // num warps along seq, within each block + static constexpr index_t WarpPerBlock_M = []() { + if constexpr(is_warp_per_row) + { + static_assert(get_warp_size() % ThreadPerBlock_N == 0); + return total_warps * (get_warp_size() / ThreadPerBlock_N); + } + else + { + // static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0); + return total_warps / (ThreadPerBlock_N / get_warp_size()); + } + }(); + + // num of warps along n + static constexpr index_t WarpPerBlock_N = []() { + if constexpr(is_warp_per_row) + { + static_assert(get_warp_size() % ThreadPerBlock_N == 0); + return 1; + } + else + { + static_assert(ThreadPerBlock_N % get_warp_size() == 0); + return ThreadPerBlock_N / get_warp_size(); + } + }(); + + // warp size + static constexpr index_t Warp_M = ThreadPerBlock_M / WarpPerBlock_M * Vector_M; + static constexpr index_t Warp_N = ThreadPerBlock_N / WarpPerBlock_N * Vector_N; static_assert(Warp_M % Vector_M == 0); static_assert(Warp_N % Vector_N == 0); - // num of threads along seq, within each warp - static constexpr index_t ThreadPerWarp_M = Warp_M / Vector_M; - static constexpr index_t ThreadPerWarp_N = Warp_N / Vector_N; - static constexpr index_t ThreadPerBlock_M = Block_M / Repeat_M / Vector_M; - static constexpr index_t ThreadPerBlock_N = Block_N / Repeat_N / Vector_N; + static_assert(Block_M % (WarpPerBlock_M * Warp_M) == 0); + static_assert(Block_N % (WarpPerBlock_N * Warp_N) == 0); - static constexpr index_t BlockSize = ThreadPerBlock_M * ThreadPerBlock_N; + // repeat of each thread along seq + static constexpr index_t Repeat_M = Block_M / (WarpPerBlock_M * Warp_M); + static constexpr index_t Repeat_N = Block_N / (WarpPerBlock_N * Warp_N); + + // num of threads along seq, within each warp + static constexpr index_t ThreadPerWarp_M = Warp_M / Vector_M; + static constexpr index_t ThreadPerWarp_N = Warp_N / Vector_N; }; } // namespace ck_tile diff --git a/test/ck_tile/add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.hpp b/test/ck_tile/add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.hpp index faa134e5c4..b7bd7ac7df 100644 --- a/test/ck_tile/add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.hpp +++ b/test/ck_tile/add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -80,55 +80,17 @@ struct add_rmsnorm2d_rdquant_fwd_traits_ using InputDataType = ck_tile::remove_cvref_t; using QuantizedDataType = ck_tile::remove_cvref_t; - static constexpr auto WarpSize = ck_tile::get_warp_size(); - static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= WarpSize; - static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % WarpSize == 0); - static constexpr ck_tile::index_t total_warps = - (ThreadPerBlock_M_ * ThreadPerBlock_N_) / WarpSize; - - // num of warps along m - static constexpr ck_tile::index_t BlockWarps_M = []() { - if constexpr(is_warp_per_row) - { - static_assert(WarpSize % ThreadPerBlock_N_ == 0); - return total_warps * (WarpSize / ThreadPerBlock_N_); - } - else - { - // static_assert(WarpSize % ThreadPerBlock_M_ == 0); - return total_warps / (ThreadPerBlock_N_ / WarpSize); - } - }(); - - // num of warps along n - static constexpr ck_tile::index_t BlockWarps_N = []() { - if constexpr(is_warp_per_row) - { - static_assert(WarpSize % ThreadPerBlock_N_ == 0); - return 1; - } - else - { - static_assert(ThreadPerBlock_N_ % WarpSize == 0); - return ThreadPerBlock_N_ / WarpSize; - } - }(); - static constexpr ck_tile::index_t Repeat_M = Repeat_M_; static constexpr ck_tile::index_t Repeat_N = Repeat_N_; static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_; static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_; - static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M; - static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_; + using BlockTile = ck_tile::sequence; + using Vector = ck_tile::sequence<1, Vector_N_>; + using ThreadPerBlock = ck_tile::sequence; - using BlockTile = ck_tile::sequence; - using BlockWarps = ck_tile::sequence; - using WarpTile = ck_tile::sequence; - using Vector = ck_tile::sequence<1, Vector_N_>; - - using Shape = ck_tile::Generic2dBlockShape; + using Shape = ck_tile::Generic2dBlockShape; static constexpr bool kPadN = kPadN_; static constexpr bool kSaveX = kSaveX_; diff --git a/test/ck_tile/layernorm2d/generate.py b/test/ck_tile/layernorm2d/generate.py index c4366f6662..f7446c0148 100644 --- a/test/ck_tile/layernorm2d/generate.py +++ b/test/ck_tile/layernorm2d/generate.py @@ -75,54 +75,17 @@ struct layernorm2d_fwd_traits_ using SmoothScaleDataType = ck_tile::remove_cvref_t; using YScaleDataType = ck_tile::remove_cvref_t; - static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= ck_tile::get_warp_size(); - static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % ck_tile::get_warp_size() == 0); - static constexpr ck_tile::index_t total_warps = - (ThreadPerBlock_M_ * ThreadPerBlock_N_) / ck_tile::get_warp_size(); - - // num of warps along m - static constexpr ck_tile::index_t BlockWarps_M = []() { - if constexpr(is_warp_per_row) - { - static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0); - return total_warps * (ck_tile::get_warp_size() / ThreadPerBlock_N_); - } - else - { - // static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0); - return total_warps / (ThreadPerBlock_N_ / ck_tile::get_warp_size()); - } - }(); - - // num of warps along n - static constexpr ck_tile::index_t BlockWarps_N = []() { - if constexpr(is_warp_per_row) - { - static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0); - return 1; - } - else - { - static_assert(ThreadPerBlock_N_ % ck_tile::get_warp_size() == 0); - return ThreadPerBlock_N_ / ck_tile::get_warp_size(); - } - }(); - static constexpr ck_tile::index_t Repeat_M = Repeat_M_; static constexpr ck_tile::index_t Repeat_N = Repeat_N_; static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_; static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_; - static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M; - static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_; - using BlockTile = ck_tile::sequence; - using BlockWarps = ck_tile::sequence; - using WarpTile = ck_tile::sequence; using Vector = ck_tile::sequence<1, Vector_N_>; - - using Shape = ck_tile::Generic2dBlockShape; + using ThreadPerBlock = ck_tile::sequence; + + using Shape = ck_tile::Generic2dBlockShape; static constexpr bool kPadN = kPadN_; static constexpr bool kSaveMeanInvStd = kSaveMeanInvStd_; diff --git a/test/ck_tile/moe_smoothquant/moe_smoothquant.hpp b/test/ck_tile/moe_smoothquant/moe_smoothquant.hpp index ced9b4ef3d..40b09dca00 100644 --- a/test/ck_tile/moe_smoothquant/moe_smoothquant.hpp +++ b/test/ck_tile/moe_smoothquant/moe_smoothquant.hpp @@ -38,54 +38,16 @@ struct moe_smoothquant_traits_ using InputType = ck_tile::remove_cvref_t; using OutputType = ck_tile::remove_cvref_t; - static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= ck_tile::get_warp_size(); - static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % ck_tile::get_warp_size() == 0); - static constexpr ck_tile::index_t total_warps = - (ThreadPerBlock_M_ * ThreadPerBlock_N_) / ck_tile::get_warp_size(); - - // num of warps along m - static constexpr ck_tile::index_t BlockWarps_M = []() { - if constexpr(is_warp_per_row) - { - static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0); - return total_warps * (ck_tile::get_warp_size() / ThreadPerBlock_N_); - } - else - { - // static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0); - return total_warps / (ThreadPerBlock_N_ / ck_tile::get_warp_size()); - } - }(); - - // num of warps along n - static constexpr ck_tile::index_t BlockWarps_N = []() { - if constexpr(is_warp_per_row) - { - static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0); - return 1; - } - else - { - static_assert(ThreadPerBlock_N_ % ck_tile::get_warp_size() == 0); - return ThreadPerBlock_N_ / ck_tile::get_warp_size(); - } - }(); - static constexpr ck_tile::index_t Repeat_M = Repeat_M_; static constexpr ck_tile::index_t Repeat_N = Repeat_N_; static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_; static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_; - static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M; - static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_; - - using BlockTile = ck_tile::sequence; - using BlockWarps = ck_tile::sequence; - using WarpTile = ck_tile::sequence; - using Vector = ck_tile::sequence<1, Vector_N_>; - - using Shape = ck_tile::Generic2dBlockShape; + using BlockTile = ck_tile::sequence; + using Vector = ck_tile::sequence<1, Vector_N_>; + using ThreadPerBlock = ck_tile::sequence; + using Shape = ck_tile::Generic2dBlockShape; static constexpr bool kPadN = kPadN_; static constexpr bool kTwoPass = kTwoPass_; diff --git a/test/ck_tile/rmsnorm2d/generate.py b/test/ck_tile/rmsnorm2d/generate.py index 1a1c842b3c..5eded8b310 100644 --- a/test/ck_tile/rmsnorm2d/generate.py +++ b/test/ck_tile/rmsnorm2d/generate.py @@ -74,54 +74,17 @@ struct rmsnorm2d_fwd_traits_ using YScaleDataType = ck_tile::remove_cvref_t; using UnquantYDataType = ck_tile::remove_cvref_t; - static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= ck_tile::get_warp_size(); - static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % ck_tile::get_warp_size() == 0); - static constexpr ck_tile::index_t total_warps = - (ThreadPerBlock_M_ * ThreadPerBlock_N_) / ck_tile::get_warp_size(); - - // num of warps along m - static constexpr ck_tile::index_t BlockWarps_M = []() { - if constexpr(is_warp_per_row) - { - static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0); - return total_warps * (ck_tile::get_warp_size() / ThreadPerBlock_N_); - } - else - { - // static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0); - return total_warps / (ThreadPerBlock_N_ / ck_tile::get_warp_size()); - } - }(); - - // num of warps along n - static constexpr ck_tile::index_t BlockWarps_N = []() { - if constexpr(is_warp_per_row) - { - static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0); - return 1; - } - else - { - static_assert(ThreadPerBlock_N_ % ck_tile::get_warp_size() == 0); - return ThreadPerBlock_N_ / ck_tile::get_warp_size(); - } - }(); - static constexpr ck_tile::index_t Repeat_M = Repeat_M_; static constexpr ck_tile::index_t Repeat_N = Repeat_N_; static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_; static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_; - static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M; - static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_; - using BlockTile = ck_tile::sequence; - using BlockWarps = ck_tile::sequence; - using WarpTile = ck_tile::sequence; using Vector = ck_tile::sequence<1, Vector_N_>; - - using Shape = ck_tile::Generic2dBlockShape; + using ThreadPerBlock = ck_tile::sequence; + + using Shape = ck_tile::Generic2dBlockShape; static constexpr bool kPadN = kPadN_; static constexpr bool kSaveInvRms = kSaveInvRms_; diff --git a/test/ck_tile/smoothquant/smoothquant.hpp b/test/ck_tile/smoothquant/smoothquant.hpp index b1d5dae3d3..ef0c36aa98 100644 --- a/test/ck_tile/smoothquant/smoothquant.hpp +++ b/test/ck_tile/smoothquant/smoothquant.hpp @@ -49,54 +49,17 @@ struct smoothquant_traits_ { using DataType = ck_tile::remove_cvref_t; - static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= ck_tile::get_warp_size(); - static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % ck_tile::get_warp_size() == 0); - static constexpr ck_tile::index_t total_warps = - (ThreadPerBlock_M_ * ThreadPerBlock_N_) / ck_tile::get_warp_size(); - - // num of warps along m - static constexpr ck_tile::index_t BlockWarps_M = []() { - if constexpr(is_warp_per_row) - { - static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0); - return total_warps * (ck_tile::get_warp_size() / ThreadPerBlock_N_); - } - else - { - // static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0); - return total_warps / (ThreadPerBlock_N_ / ck_tile::get_warp_size()); - } - }(); - - // num of warps along n - static constexpr ck_tile::index_t BlockWarps_N = []() { - if constexpr(is_warp_per_row) - { - static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0); - return 1; - } - else - { - static_assert(ThreadPerBlock_N_ % ck_tile::get_warp_size() == 0); - return ThreadPerBlock_N_ / ck_tile::get_warp_size(); - } - }(); - static constexpr ck_tile::index_t Repeat_M = Repeat_M_; static constexpr ck_tile::index_t Repeat_N = Repeat_N_; static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_; static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_; - static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M; - static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_; + using BlockTile = ck_tile::sequence; + using Vector = ck_tile::sequence<1, Vector_N_>; + using ThreadPerBlock = ck_tile::sequence; - using BlockTile = ck_tile::sequence; - using BlockWarps = ck_tile::sequence; - using WarpTile = ck_tile::sequence; - using Vector = ck_tile::sequence<1, Vector_N_>; - - using Shape = ck_tile::Generic2dBlockShape; + using Shape = ck_tile::Generic2dBlockShape; static constexpr bool kPadN = kPadN_; static constexpr bool kTwoPass = kTwoPass_;