From f1f9b9635cb588cdbda092da16b6591a900f2e52 Mon Sep 17 00:00:00 2001 From: Khushbu Agarwal Date: Wed, 16 Jul 2025 22:33:03 -0700 Subject: [PATCH] Fixing numerical error, and interchange preshuffle configs to match with flatmm (#2515) [ROCm/composable_kernel commit: 579bd73435bf544a2dfdf39aaa5fe62be1a01f2c] --- example/ck_tile/03_gemm/gemm_utils.hpp | 8 ++++---- example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp | 2 +- example/ck_tile/03_gemm/run_gemm_example.inc | 12 ++++++++++-- example/ck_tile/18_flatmm/run_flatmm_example.inc | 2 +- 4 files changed, 16 insertions(+), 8 deletions(-) diff --git a/example/ck_tile/03_gemm/gemm_utils.hpp b/example/ck_tile/03_gemm/gemm_utils.hpp index 9deccc7f16..7a9b5afaa2 100644 --- a/example/ck_tile/03_gemm/gemm_utils.hpp +++ b/example/ck_tile/03_gemm/gemm_utils.hpp @@ -241,8 +241,8 @@ struct GemmConfigPreshufle_1 : public GemmConfigBase static constexpr ck_tile::index_t N_Warp = 4; static constexpr ck_tile::index_t K_Warp = 1; - static constexpr ck_tile::index_t M_Warp_Tile = 32; - static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm(); static constexpr int kBlockPerCu = 2; @@ -263,8 +263,8 @@ struct GemmConfigPreshufle_2 : public GemmConfigBase static constexpr ck_tile::index_t N_Warp = 4; static constexpr ck_tile::index_t K_Warp = 1; - static constexpr ck_tile::index_t M_Warp_Tile = 16; - static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm(); static constexpr int kBlockPerCu = 2; diff --git a/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp b/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp index f57c24f458..b7b0701080 100644 --- a/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp +++ b/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp @@ -220,7 +220,7 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a auto [result, arg_parser] = create_args(argc, argv); bool preshuffle = GemmConfig::Preshuffle; - if(preshuffle && a_layout != "R" && b_layout != "C") + if(preshuffle && (a_layout != "R" || b_layout != "C")) { throw std::runtime_error( "Preshuffle is supported only for A(Row major), B(column major) input matrices!"); diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc index f13a4b693b..83836117e9 100644 --- a/example/ck_tile/03_gemm/run_gemm_example.inc +++ b/example/ck_tile/03_gemm/run_gemm_example.inc @@ -315,8 +315,16 @@ int run_gemm_example_with_layouts(int argc, if(init_method == 0) { - ck_tile::FillUniformDistribution{-5.f, 5.f}(a_m_k); - ck_tile::FillUniformDistribution{-5.f, 5.f}(b_k_n); + if constexpr(preshuffle) + { + ck_tile::FillUniformDistribution{-.5f, .5f}(a_m_k); + ck_tile::FillUniformDistribution{-.5f, .5f}(b_k_n); + } + else + { + ck_tile::FillUniformDistribution{-5.f, 5.f}(a_m_k); + ck_tile::FillUniformDistribution{-5.f, 5.f}(b_k_n); + } } else if(init_method == 1) { diff --git a/example/ck_tile/18_flatmm/run_flatmm_example.inc b/example/ck_tile/18_flatmm/run_flatmm_example.inc index b583612cfb..8f39b07be5 100644 --- a/example/ck_tile/18_flatmm/run_flatmm_example.inc +++ b/example/ck_tile/18_flatmm/run_flatmm_example.inc @@ -18,7 +18,7 @@ constexpr const char* DataTypeToString() { return "bf8"; } - else if constexpr(std::is_same_v) + else if constexpr(std::is_same_v) { return "bf16"; }