From 6b08653fb78c040c55d888285e66e565825998e5 Mon Sep 17 00:00:00 2001 From: "assistant-librarian[bot]" Date: Mon, 15 Dec 2025 15:18:13 +0000 Subject: [PATCH] Merge commit '7e93eed8787afd175d3a045303096a4a98638f4b' into develop --- .../run_contraction_bilinear_example.inc | 15 +++- .../run_contraction_scale_example.inc | 15 +++- example/ck_tile/03_gemm/gemm_utils.hpp | 58 ++++----------- .../ck_tile/17_grouped_gemm/grouped_gemm.hpp | 49 +++---------- .../17_grouped_gemm/grouped_gemm_multi_d.hpp | 18 ----- .../17_grouped_gemm/quant_grouped_gemm.hpp | 39 +--------- .../38_block_scale_gemm/gemm_utils.hpp | 48 +++--------- include/ck/host_utility/device_prop.hpp | 9 ++- .../gpu/device/device_base.hpp | 52 +++++++++++++ ...ce_contraction_multiple_d_xdl_cshuffle.hpp | 53 ++++++++++---- include/ck_tile/host/tensor_shuffle_utils.hpp | 62 +++++++++------- .../ops/gemm/pipeline/tile_gemm_shape.hpp | 22 ++++++ .../test_grouped_gemm_preshuffle_util.hpp | 73 +++---------------- .../gemm_preshuffle_common.hpp | 39 ---------- .../gemm_preshuffle_profiler.hpp | 21 ++++-- 15 files changed, 248 insertions(+), 325 deletions(-) diff --git a/example/26_contraction/run_contraction_bilinear_example.inc b/example/26_contraction/run_contraction_bilinear_example.inc index 78135d6296..69eb42defd 100644 --- a/example/26_contraction/run_contraction_bilinear_example.inc +++ b/example/26_contraction/run_contraction_bilinear_example.inc @@ -233,7 +233,20 @@ int run_contraction_bilinear_example(int argc, char* argv[]) } } - return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1; + if(ck::is_gfx11_supported()) + { + return ck::utils::check_err(e_ms_ns_device_result, + e_ms_ns_host_result, + "Error: Incorrect results!", + 1e-4, + 1e-4) + ? 0 + : 1; + } + else + { + return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1; + } } return 0; diff --git a/example/26_contraction/run_contraction_scale_example.inc b/example/26_contraction/run_contraction_scale_example.inc index 67f29dbc36..a7451fab71 100644 --- a/example/26_contraction/run_contraction_scale_example.inc +++ b/example/26_contraction/run_contraction_scale_example.inc @@ -216,7 +216,20 @@ int run_contraction_scale_example(int argc, char* argv[]) } } - return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1; + if(ck::is_gfx11_supported()) + { + return ck::utils::check_err(e_ms_ns_device_result, + e_ms_ns_host_result, + "Error: Incorrect results!", + 1e-4, + 1e-4) + ? 0 + : 1; + } + else + { + return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1; + } } return 0; diff --git a/example/ck_tile/03_gemm/gemm_utils.hpp b/example/ck_tile/03_gemm/gemm_utils.hpp index 47c47334e7..f79494a478 100644 --- a/example/ck_tile/03_gemm/gemm_utils.hpp +++ b/example/ck_tile/03_gemm/gemm_utils.hpp @@ -12,40 +12,6 @@ #include "ck_tile/ops/gemm.hpp" #include "ck_tile/utility/json_dump.hpp" -template -constexpr ck_tile::index_t get_k_warp_tile() -{ -#if defined(CK_GFX950_SUPPORT) - constexpr bool is_8bit_float = - std::is_same_v || std::is_same_v; - if constexpr(M_Warp_Tile == 32) - return is_8bit_float ? 64 : 16; - else - return is_8bit_float ? 128 : 32; -#else - if constexpr(M_Warp_Tile == 32) - return 16; - else - return 32; -#endif -} - -template -constexpr ck_tile::index_t get_k_warp_tile_flatmm() -{ -#if defined(CK_GFX950_SUPPORT) - if constexpr(M_Warp_Tile == 32) - return sizeof(PrecType) == 2 ? 16 : 64; - else - return sizeof(PrecType) == 2 ? 32 : 128; -#else - if constexpr(M_Warp_Tile == 32) - return sizeof(PrecType) == 2 ? 16 : 32; - else - return sizeof(PrecType) == 2 ? 32 : 64; -#endif -} - struct GemmConfigBase { static constexpr bool kPadM = false; @@ -122,7 +88,8 @@ struct GemmConfigComputeV3 : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + static constexpr ck_tile::index_t K_Warp_Tile = + ck_tile::get_k_warp_tile(); static constexpr bool DoubleSmemBuffer = false; static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; @@ -141,7 +108,8 @@ struct GemmConfigComputeV3_1 : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 32; static constexpr ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + static constexpr ck_tile::index_t K_Warp_Tile = + ck_tile::get_k_warp_tile(); static constexpr bool DoubleSmemBuffer = false; static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; @@ -160,7 +128,8 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + static constexpr ck_tile::index_t K_Warp_Tile = + ck_tile::get_k_warp_tile(); static constexpr bool DoubleSmemBuffer = false; static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; @@ -204,7 +173,8 @@ struct GemmConfigComputeV4 : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 32; static constexpr ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + static constexpr ck_tile::index_t K_Warp_Tile = + ck_tile::get_k_warp_tile(); static constexpr bool DoubleSmemBuffer = true; static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4; @@ -223,7 +193,8 @@ struct GemmConfigComputeV4_1 : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 32; static constexpr ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + static constexpr ck_tile::index_t K_Warp_Tile = + ck_tile::get_k_warp_tile(); static constexpr bool DoubleSmemBuffer = true; static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4; @@ -242,7 +213,8 @@ struct GemmConfigComputeV5 : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 32; static constexpr ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + static constexpr ck_tile::index_t K_Warp_Tile = + ck_tile::get_k_warp_tile(); static constexpr bool DoubleSmemBuffer = false; static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V5; @@ -282,7 +254,8 @@ struct GemmConfigPreshuffleDecode : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm(); + static constexpr ck_tile::index_t K_Warp_Tile = + ck_tile::get_k_warp_tile(); static constexpr int kBlockPerCu = 1; static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; @@ -306,7 +279,8 @@ struct GemmConfigPreshufflePrefill : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm(); + static constexpr ck_tile::index_t K_Warp_Tile = + ck_tile::get_k_warp_tile(); static constexpr int kBlockPerCu = 2; static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp index c5a400b4dd..67b411c1f0 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp @@ -11,40 +11,6 @@ #include "ck_tile/ops/gemm.hpp" #include "ck_tile/utility/json_dump.hpp" -template -constexpr ck_tile::index_t get_k_warp_tile() -{ -#if defined(CK_GFX950_SUPPORT) - constexpr bool is_8bit_float = - std::is_same_v || std::is_same_v; - if constexpr(M_Warp_Tile == 32) - return is_8bit_float ? 64 : 16; - else - return is_8bit_float ? 128 : 32; -#else - if constexpr(M_Warp_Tile == 32) - return 16; - else - return 32; -#endif -} - -template -constexpr ck_tile::index_t get_k_warp_tile_flatmm() -{ -#if defined(CK_GFX950_SUPPORT) - if constexpr(M_Warp_Tile == 32) - return sizeof(PrecType) == 2 ? 16 : 64; - else - return sizeof(PrecType) == 2 ? 32 : 128; -#else - if constexpr(M_Warp_Tile == 32) - return sizeof(PrecType) == 2 ? 16 : 32; - else - return sizeof(PrecType) == 2 ? 32 : 64; -#endif -} - template struct GemmTypeConfig; @@ -111,7 +77,8 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 32; static constexpr ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + static constexpr ck_tile::index_t K_Warp_Tile = + ck_tile::get_k_warp_tile(); static constexpr bool DoubleSmemBuffer = false; static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; @@ -134,7 +101,8 @@ struct GemmConfigComputeV4 : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 32; static constexpr ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + static constexpr ck_tile::index_t K_Warp_Tile = + ck_tile::get_k_warp_tile(); static constexpr bool DoubleSmemBuffer = true; static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4; @@ -157,7 +125,8 @@ struct GemmConfigComputeV4_V2 : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + static constexpr ck_tile::index_t K_Warp_Tile = + ck_tile::get_k_warp_tile(); static constexpr bool DoubleSmemBuffer = true; static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4; @@ -178,7 +147,8 @@ struct GemmConfigPreshuffleDecode : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm(); + static constexpr ck_tile::index_t K_Warp_Tile = + ck_tile::get_k_warp_tile(); static constexpr bool kPadK = true; @@ -203,7 +173,8 @@ struct GemmConfigPreshufflePrefill : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm(); + static constexpr ck_tile::index_t K_Warp_Tile = + ck_tile::get_k_warp_tile(); static constexpr int kBlockPerCu = 2; static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.hpp b/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.hpp index 30a25d83d7..2724834bb5 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.hpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.hpp @@ -11,24 +11,6 @@ #include "ck_tile/ops/gemm.hpp" #include "ck_tile/utility/json_dump.hpp" -template -constexpr ck_tile::index_t get_k_warp_tile() -{ -#if defined(CK_GFX950_SUPPORT) - constexpr bool is_8bit_float = - std::is_same_v || std::is_same_v; - if constexpr(M_Warp_Tile == 32) - return is_8bit_float ? 64 : 16; - else - return is_8bit_float ? 128 : 32; -#else - if constexpr(M_Warp_Tile == 32) - return 16; - else - return 32; -#endif -} - struct GemmConfigBase { static constexpr bool kPadM = false; diff --git a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.hpp b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.hpp index 0317685770..1fa8a03087 100644 --- a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.hpp +++ b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.hpp @@ -10,40 +10,6 @@ #include "ck_tile/ops/gemm.hpp" #include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" -template -constexpr ck_tile::index_t get_k_warp_tile() -{ -#if defined(CK_GFX950_SUPPORT) - constexpr bool is_8bit_float = - std::is_same_v || std::is_same_v; - if constexpr(M_Warp_Tile == 32) - return is_8bit_float ? 64 : 16; - else - return is_8bit_float ? 128 : 32; -#else - if constexpr(M_Warp_Tile == 32) - return 16; - else - return 32; -#endif -} - -template -constexpr ck_tile::index_t get_k_from_preshuffled_warp_tile() -{ -#if defined(CK_GFX950_SUPPORT) - if constexpr(M_Warp_Tile == 32) - return sizeof(PrecType) == 2 ? 16 : 64; - else - return sizeof(PrecType) == 2 ? 32 : 128; -#else - if constexpr(M_Warp_Tile == 32) - return sizeof(PrecType) == 2 ? 16 : 32; - else - return sizeof(PrecType) == 2 ? 32 : 64; -#endif -} - template struct GemmTypeConfig; @@ -100,7 +66,8 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 32; static constexpr ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + static constexpr ck_tile::index_t K_Warp_Tile = + ck_tile::get_k_warp_tile(); }; template @@ -117,7 +84,7 @@ struct GemmConfigPreshuffleB_Bquant_prefill : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = - get_k_from_preshuffled_warp_tile(); + ck_tile::get_k_warp_tile(); static constexpr bool PreshuffleB = true; static constexpr bool DoubleSmemBuffer = true; diff --git a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp index 7a4760e1da..37fc998e5b 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp @@ -24,39 +24,6 @@ inline size_t hash_multiple_strings(const std::vector& inputs) return combined_hash; } -template -constexpr ck_tile::index_t get_k_warp_tile() -{ -#if defined(CK_GFX950_SUPPORT) - constexpr bool is_8bit_float = - std::is_same_v || std::is_same_v; - if constexpr(M_Warp_Tile == 32) - return is_8bit_float ? 64 : 16; - else - return is_8bit_float ? 128 : 32; -#else - if constexpr(M_Warp_Tile == 32) - return 16; - else - return 32; -#endif -} -template -constexpr ck_tile::index_t get_k_from_preshuffled_warp_tile() -{ -#if defined(CK_GFX950_SUPPORT) - if constexpr(M_Warp_Tile == 32) - return sizeof(PrecType) == 2 ? 16 : 64; - else - return sizeof(PrecType) == 2 ? 32 : 128; -#else - if constexpr(M_Warp_Tile == 32) - return sizeof(PrecType) == 2 ? 16 : 32; - else - return sizeof(PrecType) == 2 ? 32 : 64; -#endif -} - template static constexpr inline auto is_row_major(Layout layout_) { @@ -124,7 +91,8 @@ struct GemmConfigQuantDecode : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + static constexpr ck_tile::index_t K_Warp_Tile = + ck_tile::get_k_warp_tile(); }; template @@ -140,7 +108,8 @@ struct GemmConfigRowColQuant : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + static constexpr ck_tile::index_t K_Warp_Tile = + ck_tile::get_k_warp_tile(); }; template @@ -157,7 +126,7 @@ struct GemmConfigPreshuffleQuantDecode : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = - get_k_from_preshuffled_warp_tile(); + ck_tile::get_k_warp_tile(); static constexpr bool PreshuffleQuant = true; }; @@ -176,7 +145,7 @@ struct GemmConfigPreshuffleB_BQuant_Decode : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = - get_k_from_preshuffled_warp_tile(); + ck_tile::get_k_warp_tile(); static constexpr bool PreshuffleB = true; static constexpr bool DoubleSmemBuffer = true; @@ -206,7 +175,7 @@ struct GemmConfigPreshuffleB_BQuant_Prefill : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = - get_k_from_preshuffled_warp_tile(); + ck_tile::get_k_warp_tile(); static constexpr bool PreshuffleB = true; static constexpr bool DoubleSmemBuffer = true; @@ -236,7 +205,8 @@ struct GemmConfigQuantPrefill : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + static constexpr ck_tile::index_t K_Warp_Tile = + ck_tile::get_k_warp_tile(); }; template diff --git a/include/ck/host_utility/device_prop.hpp b/include/ck/host_utility/device_prop.hpp index 8739f65740..43e9350f8f 100644 --- a/include/ck/host_utility/device_prop.hpp +++ b/include/ck/host_utility/device_prop.hpp @@ -72,7 +72,12 @@ inline bool is_xdl_supported() is_gfx12_supported() || is_gfx11_supported(); } -template +template inline bool is_xdl_wmma_supported() { if(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || @@ -82,7 +87,7 @@ inline bool is_xdl_wmma_supported() } else if(is_gfx12_supported() || is_gfx11_supported()) { - if constexpr((MPerXDL != 16) || (NPerXDL != 16)) + if constexpr((MPerXDL32 != 16) || (NPerXDL32 != 16)) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/device_base.hpp b/include/ck/tensor_operation/gpu/device/device_base.hpp index 3e37aac86e..9179a279c5 100644 --- a/include/ck/tensor_operation/gpu/device/device_base.hpp +++ b/include/ck/tensor_operation/gpu/device/device_base.hpp @@ -17,6 +17,7 @@ #endif #endif #include "ck/utility/get_id.hpp" +#include "ck/utility/sequence.hpp" namespace ck { namespace tensor_operation { @@ -96,6 +97,57 @@ static constexpr auto GetNXdlPerWave2() IsWave64>(); \ } +template +static constexpr auto GetWarpTileConfig() +{ + constexpr auto MXdlPerWave64 = MXdlPerWave_; + constexpr auto MXdlPerWave32 = MXdlPerWave_ * MPerXDL_ / 16; + constexpr auto CShuffleMXdlPerWavePerShuffle32 = CShuffleMXdlPerWavePerShuffle_ * MPerXDL_ / 16; + + constexpr auto NXdlPerWave = + IsWave64 + ? GetNXdlPerWave2() + : GetNXdlPerWave2(); + + if constexpr(IsWave64 == false && NXdlPerWave != 0) + { + constexpr auto CShuffleNXdlPerWavePerShuffle32 = + NXdlPerWave >= CShuffleNXdlPerWavePerShuffle_ * NPerXDL_ / 16 + ? CShuffleNXdlPerWavePerShuffle_ * NPerXDL_ / 16 + : CShuffleNXdlPerWavePerShuffle_; + static_assert(CShuffleNXdlPerWavePerShuffle32 > 0); + return Sequence<16, + 16, + MXdlPerWave32, + NXdlPerWave, + CShuffleMXdlPerWavePerShuffle32, + CShuffleNXdlPerWavePerShuffle32>{}; + } + else + { + return Sequence{}; + } +} + #define INVOKER_RUN_IMPL \ float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) \ { \ diff --git a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp index 83dbebb8d6..fff435f1c2 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp @@ -166,11 +166,27 @@ struct DeviceContractionMultipleD_Xdl_CShuffle { using DeviceOp = DeviceContractionMultipleD_Xdl_CShuffle; - GET_NXDL_PER_WAVE_IMPL - static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); - static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); - - static constexpr index_t NumDTensor = DsDataType::Size(); + static constexpr auto WarpTileConfig64 = GetWarpTileConfig(); + static constexpr auto WarpTileConfig32 = GetWarpTileConfig(); + static constexpr auto NXdlPerWave64 = WarpTileConfig64.At(3); + static constexpr auto NXdlPerWave32 = WarpTileConfig32.At(3); + static constexpr index_t NumDTensor = DsDataType::Size(); static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -321,7 +337,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N({}, {})); // GridwiseGemm - template + template using GridwiseGemmBase = GridwiseGemmMultipleD_xdl_cshuffle< ADataType, // TODO: distinguish A/B datatype BDataType, @@ -340,10 +356,10 @@ struct DeviceContractionMultipleD_Xdl_CShuffle KPerBlock, AK1, BK1, - MPerXDL, - NPerXDL, - MXdlPerWave, - NXdlPerWave_, + WarpTileConfig::At(0), + WarpTileConfig::At(1), + WarpTileConfig::At(2), + WarpTileConfig::At(3), ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -360,13 +376,13 @@ struct DeviceContractionMultipleD_Xdl_CShuffle BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, - CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle, + WarpTileConfig::At(4), + WarpTileConfig::At(5), CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, LoopSched>; - using GridwiseGemm64 = GridwiseGemmBase; - using GridwiseGemm32 = GridwiseGemmBase; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; // desc for blockwise copy using AGridDesc_AK0_M_AK1 = @@ -588,7 +604,12 @@ struct DeviceContractionMultipleD_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_wmma_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } @@ -783,6 +804,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle << MPerBlock << ", " << NPerBlock << ", " << KPerBlock << ", " + << MPerXDL << ", " + << NPerXDL << ", " << AK1 << ", " << BK1 << ", " << ABlockTransferSrcVectorDim << ", " diff --git a/include/ck_tile/host/tensor_shuffle_utils.hpp b/include/ck_tile/host/tensor_shuffle_utils.hpp index a1edce804f..5c99ae8a1c 100644 --- a/include/ck_tile/host/tensor_shuffle_utils.hpp +++ b/include/ck_tile/host/tensor_shuffle_utils.hpp @@ -68,7 +68,7 @@ auto shuffle_bq(const ck_tile::HostTensor* t, int block_bq_k) } template -auto shuffle_b(const ck_tile::HostTensor& t) +auto shuffle_b(const ck_tile::HostTensor& t, const GemmConfig& gemmConfig) { assert(t.get_lengths().size() == 2); int n_ = t.get_lengths()[1]; @@ -78,10 +78,10 @@ auto shuffle_b(const ck_tile::HostTensor& t) { constexpr int divisor = 2; constexpr int kABK1PerLane = 8; - constexpr int kABK0PerLane = GemmConfig::K_Warp_Tile / divisor / kABK1PerLane; - ck_tile::HostTensor t_view({n_ / GemmConfig::N_Warp_Tile, - GemmConfig::N_Warp_Tile, - k_ / GemmConfig::K_Warp_Tile, + int kABK0PerLane = gemmConfig.K_Warp_Tile / divisor / kABK1PerLane; + ck_tile::HostTensor t_view({n_ / gemmConfig.N_Warp_Tile, + gemmConfig.N_Warp_Tile, + k_ / gemmConfig.K_Warp_Tile, kABK0PerLane, divisor, kABK1PerLane}); @@ -98,18 +98,24 @@ auto shuffle_b(const ck_tile::HostTensor& t) else { assert(is_wave32() == false); - divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4; + divisor = gemmConfig.N_Warp_Tile == 32 ? 2 : 4; } - ck_tile::HostTensor t_view({n_ / GemmConfig::N_Warp_Tile, - GemmConfig::N_Warp_Tile, - k_ / GemmConfig::K_Warp_Tile, + ck_tile::HostTensor t_view({n_ / gemmConfig.N_Warp_Tile, + gemmConfig.N_Warp_Tile, + k_ / gemmConfig.K_Warp_Tile, divisor, - GemmConfig::K_Warp_Tile / divisor}); + gemmConfig.K_Warp_Tile / divisor}); std::copy(t.begin(), t.end(), t_view.begin()); return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); } } +template +auto shuffle_b(const ck_tile::HostTensor& t) +{ + return shuffle_b(t, GemmConfig{}); +} + template auto bq_permuteN(const ck_tile::HostTensor& t, index_t group_n) { @@ -129,22 +135,22 @@ auto bq_permuteN(const ck_tile::HostTensor& t, index_t group_n) } template -auto shuffle_b_permuteN(const ck_tile::HostTensor& t) +auto shuffle_b_permuteN(const ck_tile::HostTensor& t, const GemmConfig& gemmConfig) { assert(t.get_lengths().size() == 2); - int n_ = t.get_lengths()[1]; - int k_ = t.get_lengths()[0]; - constexpr int NRepeat = GemmConfig::N_Tile / GemmConfig::N_Warp_Tile / GemmConfig::N_Warp; + int n_ = t.get_lengths()[1]; + int k_ = t.get_lengths()[0]; + int NRepeat = gemmConfig.N_Tile / gemmConfig.N_Warp_Tile / gemmConfig.N_Warp; if(ck_tile::is_gfx12_supported()) { constexpr int divisor = 2; constexpr int kABK1PerLane = 8; - constexpr int kABK0PerLane = GemmConfig::K_Warp_Tile / divisor / kABK1PerLane; - ck_tile::HostTensor t_view({n_ / GemmConfig::N_Tile, - GemmConfig::N_Warp, - GemmConfig::N_Warp_Tile, + int kABK0PerLane = gemmConfig.K_Warp_Tile / divisor / kABK1PerLane; + ck_tile::HostTensor t_view({n_ / gemmConfig.N_Tile, + gemmConfig.N_Warp, + gemmConfig.N_Warp_Tile, NRepeat, - k_ / GemmConfig::K_Warp_Tile, + k_ / gemmConfig.K_Warp_Tile, kABK0PerLane, divisor, kABK1PerLane}); @@ -161,17 +167,23 @@ auto shuffle_b_permuteN(const ck_tile::HostTensor& t) else { assert(is_wave32() == false); - divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4; + divisor = gemmConfig.N_Warp_Tile == 32 ? 2 : 4; } - ck_tile::HostTensor t_view({n_ / GemmConfig::N_Tile, - GemmConfig::N_Warp, - GemmConfig::N_Warp_Tile, + ck_tile::HostTensor t_view({n_ / gemmConfig.N_Tile, + gemmConfig.N_Warp, + gemmConfig.N_Warp_Tile, NRepeat, - k_ / GemmConfig::K_Warp_Tile, + k_ / gemmConfig.K_Warp_Tile, divisor, - GemmConfig::K_Warp_Tile / divisor}); + gemmConfig.K_Warp_Tile / divisor}); std::copy(t.begin(), t.end(), t_view.begin()); return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 5, 2, 6}); } } + +template +auto shuffle_b_permuteN(const ck_tile::HostTensor& t) +{ + return shuffle_b_permuteN(t, GemmConfig{}); +} } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp b/include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp index 8029f6a2c7..aa8469be4f 100644 --- a/include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp +++ b/include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp @@ -43,4 +43,26 @@ struct TileGemmShape } }; +template +constexpr index_t get_k_warp_tile() +{ +#if CK_TILE_USE_WMMA + return 16; +#else +#if defined(CK_GFX950_SUPPORT) + constexpr bool is_8bit_float = + std::is_same_v || std::is_same_v; + if constexpr(M_Warp_Tile == 32) + return is_8bit_float ? 64 : 16; + else + return is_8bit_float ? 128 : 32; +#else + if constexpr(M_Warp_Tile == 32) + return (sizeof(PrecType) == 2 || IsFlatMM == false) ? 16 : 32; + else + return (sizeof(PrecType) == 2 || IsFlatMM == false) ? 32 : 64; +#endif +#endif +} + } // namespace ck_tile diff --git a/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp b/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp index 5628b6feae..a7189e7865 100644 --- a/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp +++ b/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp @@ -11,26 +11,6 @@ #include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp" #include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" -template -constexpr ck_tile::index_t get_k_warp_tile_flatmm() -{ -#if CK_TILE_USE_WMMA - return 16; -#else -#if defined(CK_GFX950_SUPPORT) - if constexpr(M_Warp_Tile == 32) - return sizeof(PrecType) == 2 ? 16 : 64; - else - return sizeof(PrecType) == 2 ? 32 : 128; -#else - if constexpr(M_Warp_Tile == 32) - return sizeof(PrecType) == 2 ? 16 : 32; - else - return sizeof(PrecType) == 2 ? 32 : 64; -#endif -#endif -} - template class TestCkTileGroupedGemmPreshuffle : public ::testing::Test { @@ -67,7 +47,8 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test static const ck_tile::index_t M_Warp_Tile = 16; static const ck_tile::index_t N_Warp_Tile = 16; - static const ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm(); + static const ck_tile::index_t K_Warp_Tile = + ck_tile::get_k_warp_tile(); static constexpr bool DoubleSmemBuffer = true; // preshuffle v2 uses ping-pong smem static constexpr bool TransposeC = false; // transpose c is not supported @@ -101,46 +82,6 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg<>); } - template - auto shuffle_b(const ck_tile::HostTensor& t) - { - assert(t.get_lengths().size() == 2); - int n_ = t.get_lengths()[1]; - int k_ = t.get_lengths()[0]; - - if(ck_tile::is_gfx12_supported()) - { - constexpr int divisor = 2; - constexpr int kABK1PerLane = 8; - constexpr int kABK0PerLane = K_Warp_Tile / divisor / kABK1PerLane; - ck_tile::HostTensor t_view({n_ / N_Warp_Tile, - N_Warp_Tile, - k_ / K_Warp_Tile, - kABK0PerLane, - divisor, - kABK1PerLane}); - std::copy(t.begin(), t.end(), t_view.begin()); - return ck_tile::reference_permute(t_view, {0, 2, 4, 1, 3, 5}); - } - else - { - int divisor = 1; - if(ck_tile::is_gfx11_supported()) - { - divisor = 1; - } - else - { - assert(is_wave32() == false); - divisor = N_Warp_Tile == 32 ? 2 : 4; - } - ck_tile::HostTensor t_view( - {n_ / N_Warp_Tile, N_Warp_Tile, k_ / K_Warp_Tile, divisor, K_Warp_Tile / divisor}); - std::copy(t.begin(), t.end(), t_view.begin()); - return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); - } - } - template void invoke_grouped_gemm(const std::vector& gemm_descs, const ck_tile::stream_config& s, @@ -340,6 +281,14 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test } } + struct BShuffleGemmConfig + { + static constexpr ck_tile::index_t N_Warp_Tile = + TestCkTileGroupedGemmPreshuffle::N_Warp_Tile; + static constexpr ck_tile::index_t K_Warp_Tile = + TestCkTileGroupedGemmPreshuffle::K_Warp_Tile; + }; + public: void Run(const std::vector& Ms, const std::vector& Ns, @@ -424,7 +373,7 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test ck_tile::FillUniformDistribution{-1.f, 1.f}(b_k_n_tensors[i]); // Host-side preshuffle of B - auto b_shuffle_host = shuffle_b(b_k_n_tensors[i]); + auto b_shuffle_host = ck_tile::shuffle_b(b_k_n_tensors[i]); a_m_k_dev_buf.push_back(std::make_unique( a_m_k_tensors[i].get_element_space_size_in_bytes())); diff --git a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_common.hpp b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_common.hpp index bb0b8090fa..8c0c5f78d4 100644 --- a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_common.hpp +++ b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_common.hpp @@ -81,42 +81,3 @@ inline KernelTraits extract_traits_from_name(const std::string& kernel_name) return traits; } - -template -auto shuffle_b(const ck_tile::HostTensor& t, - ck_tile::index_t N_Warp_Tile, - ck_tile::index_t K_Warp_Tile) -{ - assert(t.get_lengths().size() == 2); - int n_ = t.get_lengths()[1]; - int k_ = t.get_lengths()[0]; - int divisor = N_Warp_Tile == 32 ? 2 : 4; - ck_tile::HostTensor t_view( - {n_ / N_Warp_Tile, N_Warp_Tile, k_ / K_Warp_Tile, divisor, K_Warp_Tile / divisor}); - std::copy(t.begin(), t.end(), t_view.begin()); - return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); -} - -template -auto shuffle_b_permuteN(const ck_tile::HostTensor& t, - ck_tile::index_t N_Warp_Tile, - ck_tile::index_t K_Warp_Tile, - ck_tile::index_t N_Tile, - ck_tile::index_t N_Warp) -{ - assert(t.get_lengths().size() == 2); - - int n_ = t.get_lengths()[1]; - int k_ = t.get_lengths()[0]; - int divisor = N_Warp_Tile == 32 ? 2 : 4; - int NRepeat = N_Tile / N_Warp_Tile / N_Warp; - ck_tile::HostTensor t_view({n_ / N_Tile, - N_Warp, - N_Warp_Tile, - NRepeat, - k_ / K_Warp_Tile, - divisor, - K_Warp_Tile / divisor}); - std::copy(t.begin(), t.end(), t_view.begin()); - return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 5, 2, 6}); -} diff --git a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_profiler.hpp b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_profiler.hpp index 739bd7e677..cad53b472f 100644 --- a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_profiler.hpp +++ b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_profiler.hpp @@ -111,21 +111,30 @@ class GemmProfiler c_m_n_dev_buf.SetZero(); c_m_n_dev_result.SetZero(); + struct GemmConfig + { + ck_tile::index_t N_Warp_Tile; + ck_tile::index_t K_Warp_Tile; + ck_tile::index_t N_Tile; + ck_tile::index_t N_Warp; + }; + for(const auto& callable : callables) { - ck_tile::index_t N_Warp_Tile = std::get<1>(config.warp_tile_dims); - ck_tile::index_t K_Warp_Tile = std::get<2>(config.warp_tile_dims); - ck_tile::index_t N_Tile = std::get<1>(config.tile_dims); - ck_tile::index_t N_Warp = std::get<1>(config.warp_dims); + GemmConfig gemmConfig = {}; + gemmConfig.N_Warp_Tile = std::get<1>(config.warp_tile_dims); + gemmConfig.K_Warp_Tile = std::get<2>(config.warp_tile_dims); + gemmConfig.N_Tile = std::get<1>(config.tile_dims); + gemmConfig.N_Warp = std::get<1>(config.warp_dims); ck_tile::HostTensor b_shuffle_host = [&]() { if(config.permuteN) { - return shuffle_b_permuteN(b_k_n, N_Warp_Tile, K_Warp_Tile, N_Tile, N_Warp); + return ck_tile::shuffle_b_permuteN(b_k_n, gemmConfig); } else { - return shuffle_b(b_k_n, N_Warp_Tile, K_Warp_Tile); + return ck_tile::shuffle_b(b_k_n, gemmConfig); } }();