From b89ddd23e5ee6bf10c3df242220cdc05257e0797 Mon Sep 17 00:00:00 2001 From: ThomasNing Date: Fri, 18 Jul 2025 07:11:14 +0000 Subject: [PATCH] add the Preshuffle3 --- example/ck_tile/03_gemm/gemm_utils.hpp | 26 +++++++++++++++++-- .../03_gemm/gemm_weight_preshuffle.cpp | 2 +- 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/example/ck_tile/03_gemm/gemm_utils.hpp b/example/ck_tile/03_gemm/gemm_utils.hpp index e9508248cb..eed7423d9a 100644 --- a/example/ck_tile/03_gemm/gemm_utils.hpp +++ b/example/ck_tile/03_gemm/gemm_utils.hpp @@ -233,7 +233,7 @@ struct GemmConfigComputeV5 : public GemmConfigBase }; template -struct GemmConfigPreshufle_1 : public GemmConfigBase +struct GemmConfigPreshuffle_1 : public GemmConfigBase { static constexpr ck_tile::index_t M_Tile = 128; static constexpr ck_tile::index_t N_Tile = 128; @@ -255,7 +255,7 @@ struct GemmConfigPreshufle_1 : public GemmConfigBase }; template -struct GemmConfigPreshufle_2 : public GemmConfigBase +struct GemmConfigPreshuffle_2 : public GemmConfigBase { static constexpr ck_tile::index_t M_Tile = 128; static constexpr ck_tile::index_t N_Tile = 128; @@ -276,6 +276,28 @@ struct GemmConfigPreshufle_2 : public GemmConfigBase static constexpr bool DoubleSmemBuffer = true; }; +template +struct GemmConfigPreshuffle_3 : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 16; + static constexpr ck_tile::index_t N_Tile = 64; + static constexpr ck_tile::index_t K_Tile = 256; + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 4; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = 32; + + static constexpr int kBlockPerCu = 2; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2; + static constexpr bool Preshuffle = true; + static constexpr bool DoubleSmemBuffer = true; +}; + template struct GemmTypeConfig; diff --git a/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp b/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp index 18a8137aab..247b21a3fc 100644 --- a/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp +++ b/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp @@ -282,7 +282,7 @@ int main(int argc, char* argv[]) { try { - return !run_gemm_example(argc, argv); + return !run_gemm_example(argc, argv); } catch(const std::runtime_error& e) {