From 49ec224deb1a8e0f050720ac7ed101bd54e62f67 Mon Sep 17 00:00:00 2001 From: "assistant-librarian[bot]" Date: Fri, 15 Aug 2025 22:11:54 +0000 Subject: [PATCH] Merge commit 'c06e8b4a66e03c50790d077d30afe1b1aa0b6f85' into develop --- example/ck_tile/03_gemm/README.md | 2 + example/ck_tile/03_gemm/gemm_utils.hpp | 37 ++++++------------- .../03_gemm/gemm_weight_preshuffle.cpp | 4 +- 3 files changed, 15 insertions(+), 28 deletions(-) diff --git a/example/ck_tile/03_gemm/README.md b/example/ck_tile/03_gemm/README.md index 59ef2640b7..c9e392dbd5 100644 --- a/example/ck_tile/03_gemm/README.md +++ b/example/ck_tile/03_gemm/README.md @@ -12,6 +12,8 @@ sh ../script/cmake-ck-dev.sh ../ make tile_example_gemm_basic -j # The memory bound pipeline on the gemm calculation make tile_example_gemm_universal -j +# The weight preshuffle pipeline on the gemm calculation +make tile_example_gemm_weight_preshuffle -j ``` This will result in an executable `build/bin/tile_example_gemm_basic` & `build/bin/tile_example_gemm_universal` diff --git a/example/ck_tile/03_gemm/gemm_utils.hpp b/example/ck_tile/03_gemm/gemm_utils.hpp index 5f477b3821..ab481b97a0 100644 --- a/example/ck_tile/03_gemm/gemm_utils.hpp +++ b/example/ck_tile/03_gemm/gemm_utils.hpp @@ -34,21 +34,6 @@ constexpr ck_tile::index_t get_k_warp_tile() return 32; #endif } -template -constexpr ck_tile::index_t get_k_warp_tile_flatmm() -{ -#if defined(CK_GFX950_SUPPORT) - if constexpr(M_Warp_Tile == 32) - return sizeof(PrecType) == 2 ? 16 : 64; - else - return sizeof(PrecType) == 2 ? 32 : 128; -#else - if constexpr(M_Warp_Tile == 32) - return sizeof(PrecType) == 2 ? 16 : 32; - else - return sizeof(PrecType) == 2 ? 32 : 64; -#endif -} struct GemmConfigBase { @@ -232,11 +217,11 @@ struct GemmConfigComputeV5 : public GemmConfigBase }; template -struct GemmConfigPreshuffle_1 : public GemmConfigBase +struct GemmConfigPreshuffleDecode : public GemmConfigBase { - static constexpr ck_tile::index_t M_Tile = 128; - static constexpr ck_tile::index_t N_Tile = 128; - static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + static constexpr ck_tile::index_t M_Tile = 16; + static constexpr ck_tile::index_t N_Tile = 64; + static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType); static constexpr ck_tile::index_t M_Warp = 1; static constexpr ck_tile::index_t N_Warp = 4; @@ -244,17 +229,17 @@ struct GemmConfigPreshuffle_1 : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm(); + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); - static constexpr int kBlockPerCu = 2; + static constexpr int kBlockPerCu = 1; static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V1; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2; static constexpr bool Preshuffle = true; - static constexpr bool DoubleSmemBuffer = false; + static constexpr bool DoubleSmemBuffer = true; }; template -struct GemmConfigPreshuffle_2 : public GemmConfigBase +struct GemmConfigPreshufflePrefill : public GemmConfigBase { static constexpr ck_tile::index_t M_Tile = 128; static constexpr ck_tile::index_t N_Tile = 128; @@ -266,7 +251,7 @@ struct GemmConfigPreshuffle_2 : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm(); + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); static constexpr int kBlockPerCu = 2; static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; @@ -470,7 +455,7 @@ auto create_args(int argc, char* argv[]) .insert("init", "0", "0:random, 1:linear, 2:constant(1)") .insert("persistent", "0", "0:non-persistent, 1:persistent") .insert("flush_cache", "true", "flush cache before running the kernel, defaults to true") - .insert("rotating_count", "1", "rotating count, defaults to 1"); + .insert("rotating_count", "1000", "rotating count, defaults to 1000"); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); diff --git a/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp b/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp index 8a7560bf86..2057f1e4f5 100644 --- a/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp +++ b/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp @@ -141,7 +141,7 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) << "pipeline: " << GemmPipeline::GetName() << '\n' << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" - << std::endl; + << ", kBlockPerCu: {" << GemmConfig::kBlockPerCu << "}" << std::endl; } if(s.flush_cache_) { @@ -280,7 +280,7 @@ int main(int argc, char* argv[]) try { - return !run_gemm_example(arg_parser); + return !run_gemm_example(arg_parser); } catch(const std::runtime_error& e) {