diff --git a/example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp b/example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp index 0752dfdde4..edde59081c 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp @@ -143,7 +143,11 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile, transpose_c, - ck_tile::memory_operation_enum::set>>; + ck_tile::memory_operation_enum::set, + 1, + false, + 1, + GemmConfig::TiledMMAPermuteN>>; using Kernel = ck_tile::QuantGemmKernel; diff --git a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp index 0206aa88a8..f60d383afb 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp @@ -93,6 +93,7 @@ struct GemmConfigBase static constexpr bool PreshuffleQuant = false; static constexpr bool PreshuffleB = false; static constexpr bool DoubleSmemBuffer = false; + static constexpr bool TiledMMAPermuteN = false; }; template @@ -164,6 +165,9 @@ struct GemmConfigPreshuffleB_Bquant_decode : public GemmConfigBase static constexpr bool PreshuffleB = true; static constexpr bool DoubleSmemBuffer = true; + + static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp; + static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0; }; template @@ -184,6 +188,9 @@ struct GemmConfigPreshuffleB_Bquant_prefill : public GemmConfigBase static constexpr bool PreshuffleB = true; static constexpr bool DoubleSmemBuffer = true; + + static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp; + static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0; }; template #include #include "ck_tile/host/permute_pk_int4.hpp" - -template -auto shuffle_aq(const ck_tile::HostTensor* t, int block_aq_k) -{ - if(t->get_lengths().size() != 2) - { - throw std::runtime_error("Host tensor is not rank 2 tensor."); - } - int m_ = t->get_lengths()[0]; - int aqk_ = t->get_lengths()[1]; - if(aqk_ % block_aq_k != 0) - { - throw std::runtime_error("shuffle_aq needs a aqk of multiple times of block_aq_k."); - } - ck_tile::HostTensor t_view({m_, aqk_ / block_aq_k, block_aq_k}); - std::copy(t->begin(), t->end(), t_view.begin()); - return ck_tile::reference_permute(t_view, {1, 0, 2}); -} - -template -auto shuffle_b(const ck_tile::HostTensor& t) -{ - assert(t.get_lengths().size() == 2); - int n_ = t.get_lengths()[1]; - int k_ = t.get_lengths()[0]; - constexpr int divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4; - ck_tile::HostTensor t_view({n_ / GemmConfig::N_Warp_Tile, - GemmConfig::N_Warp_Tile, - k_ / GemmConfig::K_Warp_Tile, - divisor, - GemmConfig::K_Warp_Tile / divisor}); - std::copy(t.begin(), t.end(), t_view.begin()); - return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); -} +#include "ck_tile/host/shuffle_utils.hpp" template aq_shuffle_host = - shuffle_aq(aq_tensor_ptr.get(), GemmConfig::K_Tile / QuantGroupSize); + ck_tile::shuffle_aq(aq_tensor_ptr.get(), GemmConfig::K_Tile / QuantGroupSize); aq_dev_buf_ptr->ToDevice(aq_shuffle_host.data()); } else @@ -412,25 +379,26 @@ int run_gemm_example_with_layouts(int argc, } ck_tile::HostTensor b_k_n_dev = b_k_n; + if constexpr(GemmConfig::PreshuffleB) + { + if constexpr(GemmConfig::TiledMMAPermuteN) + { + printf("PreshuffleB with TiledMMAPermuteN\n"); + b_k_n_dev = ck_tile::shuffle_b_permuteN(b_k_n); + } + else + { + printf("PreshuffleB without TiledMMAPermuteN\n"); + b_k_n_dev = ck_tile::shuffle_b(b_k_n); + } + } if constexpr(std::is_same_v) { - - if constexpr(GemmConfig::PreshuffleB) - { - b_k_n_dev = shuffle_b(b_k_n); - } ck_tile::permute_vectors_i4x4_b(b_k_n_dev); - b_k_n_dev_buf.ToDevice(b_k_n_dev.data()); - } - else - { - if constexpr(GemmConfig::PreshuffleB) - { - b_k_n_dev = shuffle_b(b_k_n); - } - b_k_n_dev_buf.ToDevice(b_k_n_dev.data()); } + b_k_n_dev_buf.ToDevice(b_k_n_dev.data()); + c_m_n_dev_buf.SetZero(); c_m_n_dev_result.SetZero(); @@ -438,7 +406,15 @@ int run_gemm_example_with_layouts(int argc, QuantMode == ck_tile::QuantType::RowColQuant || QuantMode == ck_tile::QuantType::TensorQuant) { - bq_dev_buf_ptr->ToDevice(bq_tensor_ptr->data()); + if constexpr(GemmConfig::PreshuffleB && GemmConfig::TiledMMAPermuteN) + { + printf("Preshuffle BQ with TiledMMAPermuteN \n"); + ck_tile::HostTensor bq_shuffle_host = + ck_tile::shuffle_bq_permuteN(*bq_tensor_ptr); + bq_dev_buf_ptr->ToDevice(bq_shuffle_host.data()); + } + else + bq_dev_buf_ptr->ToDevice(bq_tensor_ptr->data()); } invoke_gemm + +namespace ck_tile { +template +auto shuffle_aq(const ck_tile::HostTensor* t, int block_aq_k) +{ + if(t->get_lengths().size() != 2) + { + throw std::runtime_error("Host tensor is not rank 2 tensor."); + } + int m_ = t->get_lengths()[0]; + int aqk_ = t->get_lengths()[1]; + if(aqk_ % block_aq_k != 0) + { + throw std::runtime_error("shuffle_aq needs a aqk of multiple times of block_aq_k."); + } + ck_tile::HostTensor t_view({m_, aqk_ / block_aq_k, block_aq_k}); + std::copy(t->begin(), t->end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {1, 0, 2}); +} + +template +auto shuffle_b(const ck_tile::HostTensor& t) +{ + assert(t.get_lengths().size() == 2); + int n_ = t.get_lengths()[1]; + int k_ = t.get_lengths()[0]; + constexpr int divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4; + ck_tile::HostTensor t_view({n_ / GemmConfig::N_Warp_Tile, + GemmConfig::N_Warp_Tile, + k_ / GemmConfig::K_Warp_Tile, + divisor, + GemmConfig::K_Warp_Tile / divisor}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); +} + +template +auto shuffle_bq_permuteN(const ck_tile::HostTensor& t) +{ + assert(t.get_lengths().size() == 2); + + int n_ = t.get_lengths()[1]; + int bqk_ = t.get_lengths()[0]; + constexpr int NRepeat = GemmConfig::N_Tile / GemmConfig::N_Warp_Tile / GemmConfig::N_Warp; + + ck_tile::HostTensor t_view( + {n_ / GemmConfig::N_Tile, GemmConfig::N_Warp, GemmConfig::N_Warp_Tile, NRepeat, bqk_}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 3, 1, 2, 4}); +} + +template +auto shuffle_b_permuteN(const ck_tile::HostTensor& t) +{ + assert(t.get_lengths().size() == 2); + + int n_ = t.get_lengths()[1]; + int k_ = t.get_lengths()[0]; + constexpr int divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4; + constexpr int NRepeat = GemmConfig::N_Tile / GemmConfig::N_Warp_Tile / GemmConfig::N_Warp; + + ck_tile::HostTensor t_view({n_ / GemmConfig::N_Tile, + GemmConfig::N_Warp, + GemmConfig::N_Warp_Tile, + NRepeat, + k_ / GemmConfig::K_Warp_Tile, + divisor, + GemmConfig::K_Warp_Tile / divisor}); + + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 5, 2, 6}); +} +} // namespace ck_tile diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp index 80167a1d21..1720029eee 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp @@ -55,6 +55,7 @@ class TestCkTileGemmQuantBase : public ::testing::Test static constexpr ck_tile::index_t K_Warp_Tile = GemmConfig::K_Warp_Tile; static constexpr bool PreshuffleQuant = GemmConfig::PreshuffleQuant; static constexpr bool PreshuffleB = GemmConfig::PreshuffleB; + static constexpr bool TiledMMAPermuteN = GemmConfig::TiledMMAPermuteN; static constexpr bool DoubleSmemBuffer = GemmConfig::DoubleSmemBuffer; public: @@ -132,19 +133,6 @@ class TestCkTileGemmQuantBase : public ::testing::Test // Use higher threshold return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); } - - template - auto shuffle_b(const ck_tile::HostTensor& t) - { - assert(t.get_lengths().size() == 2); - int n_ = t.get_lengths()[1]; - int k_ = t.get_lengths()[0]; - constexpr int divisor = N_Warp_Tile == 32 ? 2 : 4; - ck_tile::HostTensor t_view( - {n_ / N_Warp_Tile, N_Warp_Tile, k_ / K_Warp_Tile, divisor, K_Warp_Tile / divisor}); - std::copy(t.begin(), t.end(), t_view.begin()); - return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); - } }; // Define generic QuantTypeTraits template (will be specialized) diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp index 21f586499e..b12259c773 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp @@ -5,6 +5,7 @@ #include "test_gemm_quant_base.hpp" #include "ck_tile/host/permute_pk_int4.hpp" +#include "ck_tile/host/shuffle_utils.hpp" struct GemmConfigBase { @@ -26,6 +27,7 @@ struct GemmConfigBase static constexpr bool PreshuffleQuant = false; static constexpr bool PreshuffleB = false; static constexpr bool DoubleSmemBuffer = false; + static constexpr bool TiledMMAPermuteN = false; // Default GEMM tile sizes for tests static constexpr ck_tile::index_t M_Tile = 16; @@ -95,6 +97,12 @@ struct GemmConfigPreshuffleBPrefill : public GemmConfigBase static constexpr ck_tile::index_t K_Warp_Tile = 64; }; +struct GemmConfigPreshuffleBPrefillTiledPermuteN : public GemmConfigPreshuffleBPrefill +{ + static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp; + static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0; +}; + template class TestCkTileGemmAQuant : public TestCkTileGemmQuantBase> { @@ -119,24 +127,6 @@ class TestCkTileGemmAQuant : public TestCkTileGemmQuantBase - auto shuffle_aq(const ck_tile::HostTensor* t, int block_aq_k) - { - if(t->get_lengths().size() != 2) - { - throw std::runtime_error("Host tensor is not rank 2 tensor."); - } - int m_ = t->get_lengths()[0]; - int aqk_ = t->get_lengths()[1]; - if(aqk_ % block_aq_k != 0) - { - throw std::runtime_error("shuffle_aq needs a aqk of multiple times of block_aq_k."); - } - ck_tile::HostTensor t_view({m_, aqk_ / block_aq_k, block_aq_k}); - std::copy(t->begin(), t->end(), t_view.begin()); - return ck_tile::reference_permute(t_view, {1, 0, 2}); - } - // AQuant-specific data generation void run_test_with_validation(ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K) { @@ -191,7 +181,7 @@ class TestCkTileGemmAQuant : public TestCkTileGemmQuantBase aq_shuffle_host = - shuffle_aq(&aq_m_aqk, Base::GemmConfig::K_Tile / QuantGroupSize); + ck_tile::shuffle_aq(&aq_m_aqk, Base::GemmConfig::K_Tile / QuantGroupSize); aq_m_aqk_dev_buf.ToDevice(aq_shuffle_host.data()); } else @@ -367,11 +357,13 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase b_k_n_dev = b_k_n; + if constexpr(PreshuffleB) + { + if constexpr(TiledMMAPermuteN) + { + printf("PreshuffleB with TiledMMAPermuteN\n"); + b_k_n_dev = ck_tile::shuffle_b_permuteN(b_k_n); + } + else + { + printf("PreshuffleB without TiledMMAPermuteN\n"); + b_k_n_dev = ck_tile::shuffle_b(b_k_n); + } + } if constexpr(std::is_same_v) { - if constexpr(PreshuffleB) - { - b_k_n_dev = this->shuffle_b(b_k_n); - } ck_tile::permute_vectors_i4x4_b(b_k_n_dev); - b_k_n_dev_buf.ToDevice(b_k_n_dev.data()); + } + + b_k_n_dev_buf.ToDevice(b_k_n_dev.data()); + + if constexpr(PreshuffleB && TiledMMAPermuteN) + { + printf("Preshuffle BQ with TiledMMAPermuteN \n"); + ck_tile::HostTensor bq_shuffle_host = + ck_tile::shuffle_bq_permuteN(bq_bqk_n); + bq_bqk_n_dev_buf.ToDevice(bq_shuffle_host.data()); } else - { - if constexpr(PreshuffleB) - { - b_k_n_dev = this->shuffle_b(b_k_n); - } - b_k_n_dev_buf.ToDevice(b_k_n_dev.data()); - } - bq_bqk_n_dev_buf.ToDevice(bq_bqk_n.data()); + bq_bqk_n_dev_buf.ToDevice(bq_bqk_n.data()); // Create args for kernel execution ck_tile::QuantGemmHostArgs args{ @@ -559,7 +562,11 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase>; + ck_tile::memory_operation_enum::set, + 1, + false, + 1, + TiledMMAPermuteN>>; using Kernel = ck_tile::QuantGemmKernel, std::tuple, std::tuple, - std::tuple + std::tuple, + + std::tuple, + std::tuple, + std::tuple, + std::tuple >; // clang-format off