From fdb397b2c9337fc20a992d640bcc373b08828be3 Mon Sep 17 00:00:00 2001
From: Khushbu Agarwal
Date: Fri, 10 Oct 2025 15:36:24 -0700
Subject: [PATCH] supporting prefill shapes for preshuffle block scale gemm
 (#2975)

* debugging
* debugging for prefill shapes
* comment unused code
* fix for prefill shapes
* cleaning up the code
* add int4 to universal gemm example
* clang formatted
* adding test for prefill shapes in block scale gemm
* small improvement on the block pipeline
* Address Review Comment

---------

Co-authored-by: ThomasNing

[ROCm/composable_kernel commit: 3c39d279ab4569d1b33399e7746465744ed662c0]
---
 .../03_gemm/gemm_weight_preshuffle.cpp        |   7 ++
 .../gemm_weight_preshuffle_invoker.hpp        |   5 +-
 example/ck_tile/03_gemm/run_gemm_example.inc  |  16 +--
 example/ck_tile/38_block_scale_gemm/README.md |  14 ++-
 .../38_block_scale_gemm/gemm_quant_basic.cpp  |   3 +-
 .../38_block_scale_gemm/gemm_utils.hpp        |  22 +++-
 ...ock_universal_gemm_ar_flatbr_bquant_cr.hpp | 104 ++++++++++--------
 .../gemm_quant/kernel/gemm_quant_kernel.hpp   |   1 -
 .../test_gemm_quant_fixtures.hpp              |  41 +++----
 .../test_gemm_quant_typed.cpp                 |  13 ++-
 10 files changed, 137 insertions(+), 89 deletions(-)
 mode change 100755 => 100644 include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp

diff --git a/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp b/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp
index 0f323cb0e3..89f177b781 100644
--- a/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp
+++ b/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp
@@ -75,6 +75,13 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser)
                                           ck_tile::bf8_t,
                                           ck_tile::half_t>(a_layout, b_layout, arg_parser);
     }
+    else if(data_type == "int4")
+    {
+        return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
+                                          ck_tile::fp8_t,
+                                          ck_tile::pk_int4_t,
+                                          ck_tile::half_t>(a_layout, b_layout, arg_parser);
+    }
     else
     {
         throw std::runtime_error("Unsupported data type for this operation !!!");
diff --git a/example/ck_tile/03_gemm/gemm_weight_preshuffle_invoker.hpp b/example/ck_tile/03_gemm/gemm_weight_preshuffle_invoker.hpp
index d737a0f864..023b0336fe 100644
--- a/example/ck_tile/03_gemm/gemm_weight_preshuffle_invoker.hpp
+++ b/example/ck_tile/03_gemm/gemm_weight_preshuffle_invoker.hpp
@@ -194,10 +194,7 @@ struct WeightPreshuffleInvoker
         }
         else
         {
-            Run(has_hot_loop_,
-                tail_number_,
-                ck_tile::integral_constant{});
+            throw std::runtime_error("split-k is not supported yet!");
         }
     };
diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc
index e6875f97d5..42a2d70692 100644
--- a/example/ck_tile/03_gemm/run_gemm_example.inc
+++ b/example/ck_tile/03_gemm/run_gemm_example.inc
@@ -300,16 +300,8 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
 
     if(init_method == 0)
     {
-        if constexpr(preshuffle)
-        {
-            ck_tile::FillUniformDistribution{-.5f, .5f}(a_m_k);
-            ck_tile::FillUniformDistribution{-.5f, .5f}(b_k_n);
-        }
-        else
-        {
-            ck_tile::FillUniformDistribution{-5.f, 5.f}(a_m_k);
-            ck_tile::FillUniformDistribution{-5.f, 5.f}(b_k_n);
-        }
+        ck_tile::FillUniformDistribution{-5.f, 5.f}(a_m_k);
+        ck_tile::FillUniformDistribution{-5.f, 5.f}(b_k_n);
     }
     else if(init_method == 1)
     {
@@ -353,6 +345,10 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
         }
     }(); // shuffled buffer B for device implementation
 
+        if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
+        {
+            ck_tile::permute_vectors_i4x4_b(b_shuffle_host);
+        }
         b_k_n_dev_buf.ToDevice(b_shuffle_host.data());
     }
     else
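The int4 example path above calls `ck_tile::permute_vectors_i4x4_b` on the shuffled B buffer before uploading it, since packed int4 weights need a device-friendly nibble order. For background, here is a minimal sketch of pulling two 4-bit values out of one packed byte; it assumes `pk_int4_t` packs two signed nibbles per byte (an assumption), and the helper below is hypothetical, not library code:

```cpp
// Hypothetical illustration: unpack two signed 4-bit values from one byte.
// This is NOT ck_tile code; it only assumes a two-nibbles-per-byte packing.
#include <cstdint>
#include <utility>

inline std::pair<int8_t, int8_t> unpack_int4x2(uint8_t packed)
{
    // Sign-extend the low nibble by shifting it into the top 4 bits first.
    const int8_t lo = static_cast<int8_t>(static_cast<uint8_t>(packed << 4)) >> 4;
    // Sign-extend the high nibble with an arithmetic right shift.
    const int8_t hi = static_cast<int8_t>(packed) >> 4;
    return {lo, hi};
}
```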
diff --git a/example/ck_tile/38_block_scale_gemm/README.md b/example/ck_tile/38_block_scale_gemm/README.md
index 7f8aba7b3d..b7b14f9d13 100644
--- a/example/ck_tile/38_block_scale_gemm/README.md
+++ b/example/ck_tile/38_block_scale_gemm/README.md
@@ -4,8 +4,18 @@ This folder contains examples of quant GEMMs using the ck_tile tile-programming
 
 - AQuant kernel with blocks of A matrix sharing scales: custom GEMM pipeline
 - BQuant kernel with blocks of B matrix sharing scales: custom GEMM pipeline
-- Row and Column-wise scaled: scaling implemented in Epilogue
-- Tensor-wise scaled: scaling implemented in Epilogue
+- Row- and column-wise scaled: all row-wise elements of the A matrix and all column-wise elements of the B matrix share the same quantization scale; the element-wise scaling is done in the epilogue.
+- Tensor-wise scaled: a single scalar scale is shared across the whole A or B tensor.
+
+---
+
+## Features
+
+- **Preshuffled GEMM**: Shuffle the B (weight) matrix into the warp layout and bypass shared memory in the GEMM calculation; the best-performing GEMM variant.
+- **TransposeC**: Transpose the C matrix output layout for fully coalesced scale reads.
+- **Preshuffled Quant**: Preshuffle the input matrix so multiple quant warp blocks can be loaded along the selected dimension.
+- **Precision**: Supports fp16, bf16, fp8, bf8, and int4 (for the B matrix).
+- **Validation**: CPU/GPU validation and error-tolerance options.
 
 ## build
 ```
diff --git a/example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp b/example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp
index 00d1af5aaa..c9cc56d033 100644
--- a/example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp
+++ b/example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp
@@ -47,6 +47,7 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
                                                            QuantMode,
                                                            ALayout, // for AQLayout
                                                            BLayout, // for BQLayout
+                                                           false,
                                                            GemmConfig::DoubleSmemBuffer>;
     using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<...>;
@@ ... @@
-int main(int argc, char* argv[]) { return !run_gemm_example<...>(argc, argv); }
+int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
diff --git a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp
index cfe7b72af9..0206aa88a8 100644
--- a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp
+++ b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp
@@ -166,6 +166,26 @@ struct GemmConfigPreshuffleB_Bquant_decode : public GemmConfigBase
     static constexpr bool DoubleSmemBuffer = true;
 };
 
+template <typename PrecType>
+struct GemmConfigPreshuffleB_Bquant_prefill : public GemmConfigBase
+{
+    static constexpr ck_tile::index_t M_Tile = 128;
+    static constexpr ck_tile::index_t N_Tile = 128;
+    static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
+
+    static constexpr ck_tile::index_t M_Warp = 1;
+    static constexpr ck_tile::index_t N_Warp = 4;
+    static constexpr ck_tile::index_t K_Warp = 1;
+
+    static constexpr ck_tile::index_t M_Warp_Tile = 16;
+    static constexpr ck_tile::index_t N_Warp_Tile = 16;
+    static constexpr ck_tile::index_t K_Warp_Tile =
+        get_k_from_preshuffled_warp_tile<PrecType>();
+
+    static constexpr bool PreshuffleB = true;
+    static constexpr bool DoubleSmemBuffer = true;
+};
+
 template <typename PrecType>
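The prefill config above ties the K tile to the element width: `K_Tile = 128 / sizeof(PrecType)` keeps the tile's K extent at 128 bytes, so 1-byte fp8/bf8 gets 128 elements and 2-byte fp16/bf16 gets 64. A compile-time check of that arithmetic using stand-in types (the structs below are assumptions, not ck_tile types):

```cpp
// Compile-time illustration of the K_Tile formula used by
// GemmConfigPreshuffleB_Bquant_prefill. The stand-in structs are sized
// like fp8 (1 byte) and fp16 (2 bytes); they are not ck_tile types.
#include <cstddef>

template <typename PrecType>
constexpr std::size_t k_tile() { return 128 / sizeof(PrecType); }

struct fp8_stand_in  { unsigned char  v; }; // 1 byte, like ck_tile::fp8_t
struct fp16_stand_in { unsigned short v; }; // 2 bytes, like ck_tile::half_t

static_assert(k_tile<fp8_stand_in>() == 128, "1-byte elements -> K_Tile 128");
static_assert(k_tile<fp16_stand_in>() == 64, "2-byte elements -> K_Tile 64");
```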
diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp
old mode 100755
new mode 100644
--- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp
+++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp
@@ ... @@
-        static_for<0, QScalesPerBlockRow, 1>{}([&](auto kQScale) {
-            CWarpTensor c_warp_tensor;
-            static_for<0, KIterPerQScale, 1>{}([&](auto kIterInQScale) {
-                static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
-                    static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
-                        constexpr auto kIter = kQScale * KIterPerQScale + kIterInQScale;
+        statically_indexed_array<statically_indexed_array<CWarpTensor, NIterPerWarp>,
+                                 MIterPerWarp>
+            c_acc;
-                        constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
-
-                        // warp GEMM
-                        if constexpr(kIterInQScale == 0)
-                            c_warp_tensor = WG{}(a_warp_tensor(number<AwarpIter>{}),
-                                                 b_warp_tensor(nIter)(number<kIter>{}));
-                        else
-                            WG{}(c_warp_tensor,
-                                 a_warp_tensor(number<AwarpIter>{}),
-                                 b_warp_tensor(nIter)(number<kIter>{}));
-
-                        __builtin_amdgcn_sched_barrier(0x7F6);
-                        // preload next A from lds
-                        if constexpr((kIter * MIterPerWarp + mIter) <
-                                     (KIterPerWarp * MIterPerWarp - m_preload))
-                        {
-                            constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
-                            constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
-                            a_warp_tensor(number<AwarpIter>{}) =
-                                load_tile(a_warp_windows(number<AmIter>{})(number<AkIter>{}));
-                        }
-                        // barrier
-                        if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
-                        {
-                            block_sync_lds();
-                        }
-                    });
+        auto zero_accumulators = [&] {
+            static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
+                static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
+                    static_for<0, (WG::kM * WG::kN) / warp_size, 1>{}([&](auto i) {
+                        c_acc(mIter)(nIter).get_thread_buffer()[i] = 0.0f;
+                    }); // TODO: use a clear/zero helper if WG::CWarpTensor exposes one
+                });
+            });
+        };
+        static_for<0, QScalesPerBlockRow, 1>{}([&](auto kQScale) {
+            zero_accumulators();
+            static_for<0, KIterPerQScale, 1>{}([&](auto kIterInQScale) {
+                constexpr auto kIter = kQScale * KIterPerQScale + kIterInQScale;
+                static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
+                    constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
+                    static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
+                        // warp GEMM
+                        WG{}(c_acc(mIter)(nIter),
+                             a_warp_tensor(number<AwarpIter>{}),
+                             b_warp_tensor(nIter)(number<kIter>{}));
+                    });
+                    __builtin_amdgcn_sched_barrier(0x7F6);
+                    // preload next A from lds
+                    if constexpr((kIter * MIterPerWarp + mIter) <
+                                 (KIterPerWarp * MIterPerWarp - m_preload))
+                    {
+                        constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
+                        constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
+                        a_warp_tensor(number<AwarpIter>{}) =
+                            load_tile(a_warp_windows(number<AmIter>{})(number<AkIter>{}));
+                    }
+                    // barrier; could likely be removed
+                    if constexpr((mIter == MIter_2nd_last))
+                    {
+                        block_sync_lds();
+                    }
+                });
+            });
+            static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
+                static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
+                    constexpr auto tbuf_offset =
+                        number<...{},
+                               c_warp_y_index_zeros)) /
+                               CBlockTensor::PackedSize>{};
-                constexpr auto tbuf_offset =
-                    number<...{}, number<0>{}>{}, c_warp_y_index_zeros)) /
-                           CBlockTensor::PackedSize>{};
+                    constexpr index_t reg_offset = nIter * KPerBlockBQ + kQScale;
-                constexpr index_t reg_offset = kQScale;
-                // nIter * KPerBlockBQ + kQScale; //((kIter * WG::kK) / kQuantGroupSize);
+                    auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
+                    float scale_reg_f = cvt_scale_to_fp32(scale_reg);
-                auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
-                float scale_reg_f = cvt_scale_to_fp32(scale_reg);
-
-                static_for<0, WG::kM * WG::kN / warp_size, 1>{}([&](auto c_row) {
-                    c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
-                        (c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f);
+                    static_for<0, WG::kM * WG::kN / warp_size, 1>{}([&](auto c_row) {
+                        auto& c_ref = c_block_tensor.get_thread_buffer()[tbuf_offset + c_row];
+                        const auto acc_val = c_acc(mIter)(nIter).get_thread_buffer()[c_row];
+                        c_ref = c_ref + acc_val * scale_reg_f;
+                    });
+                });
             });
         });
     }
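The rewritten loop nest above accumulates a whole quantization group into `c_acc` and folds the per-group B scale in once per group, instead of scaling every warp-GEMM result. A plain scalar sketch of that accumulate-then-scale order (ordinary C++, not ck_tile code; matrix layouts and names are assumptions for illustration):

```cpp
// Scalar sketch of the accumulate-then-scale pattern: accumulate a full
// quantization group, then apply the per-group B scale once. Row-major
// layouts and the function name are hypothetical stand-ins.
#include <cstddef>
#include <vector>

void bquant_gemm_reference(const std::vector<float>& A,      // M x K
                           const std::vector<float>& B,      // K x N, dequant-ready values
                           const std::vector<float>& BScale, // (K / group_k) x N
                           std::vector<float>& C,            // M x N
                           std::size_t M, std::size_t N, std::size_t K,
                           std::size_t group_k)
{
    for(std::size_t m = 0; m < M; ++m)
        for(std::size_t n = 0; n < N; ++n)
        {
            float c = 0.0f;
            for(std::size_t kg = 0; kg < K / group_k; ++kg)
            {
                float acc = 0.0f; // fresh accumulator per quant group
                for(std::size_t ki = 0; ki < group_k; ++ki)
                {
                    const std::size_t k = kg * group_k + ki;
                    acc += A[m * K + k] * B[k * N + n];
                }
                c += acc * BScale[kg * N + n]; // scale applied once per group
            }
            C[m * N + n] = c;
        }
}
```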
diff --git a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp
index bba2bc8400..bc2c9c603a 100644
--- a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp
+++ b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp
@@ -1111,7 +1111,6 @@ struct QuantGemmKernel
 
         // allocate LDS
         __shared__ char smem_ptr_0[GetSmemSize()];
-        assert(kargs.k_batch == 1);
 
         if constexpr(GemmPipeline::DoubleSmemBuffer == true)
         {
diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp
index 21eabd6041..21f586499e 100644
--- a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp
+++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp
@@ -57,26 +57,10 @@ struct GemmConfigPreshuffleQuantTransposeC : public GemmConfigBase
     static constexpr bool TransposeC = true;
 };
 
-struct GemmConfigPreshuffleB
+struct GemmConfigPreshuffleBDecode : public GemmConfigBase
 {
-    static constexpr bool kPadM = false;
-    static constexpr bool kPadN = false;
-    static constexpr bool kPadK = false;
-
-    static constexpr bool PermuteA = false;
-    static constexpr bool PermuteB = false;
-
-    static constexpr bool TransposeC = false;
-    static constexpr bool UseStructuredSparsity = false;
-
-    static constexpr int kBlockPerCu = 1;
-    static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
-    static constexpr ck_tile::index_t TileParitionerM01 = 4;
-    static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
-    static constexpr ck_tile::index_t NumWaveGroups = 1;
-    static constexpr bool PreshuffleQuant = false;
-    static constexpr bool PreshuffleB = true;
-    static constexpr bool DoubleSmemBuffer = true;
+    static constexpr bool PreshuffleB = true;
+    static constexpr bool DoubleSmemBuffer = true;
 
     // Default GEMM tile sizes for tests
     static constexpr ck_tile::index_t M_Tile = 16;
@@ -92,6 +76,25 @@ struct GemmConfigPreshuffleB
     static constexpr ck_tile::index_t K_Warp_Tile = 64;
 };
 
+struct GemmConfigPreshuffleBPrefill : public GemmConfigBase
+{
+    static constexpr bool PreshuffleB = true;
+    static constexpr bool DoubleSmemBuffer = true;
+
+    // Default GEMM tile sizes for tests
+    static constexpr ck_tile::index_t M_Tile = 128;
+    static constexpr ck_tile::index_t N_Tile = 128;
+    static constexpr ck_tile::index_t K_Tile = 128;
+
+    static constexpr ck_tile::index_t M_Warp = 1;
+    static constexpr ck_tile::index_t N_Warp = 4;
+    static constexpr ck_tile::index_t K_Warp = 1;
+
+    static constexpr ck_tile::index_t M_Warp_Tile = 16;
+    static constexpr ck_tile::index_t N_Warp_Tile = 16;
+    static constexpr ck_tile::index_t K_Warp_Tile = 64;
+};
+
 template <typename Tuple>
 class TestCkTileGemmAQuant : public TestCkTileGemmQuantBase<TestCkTileGemmAQuant<Tuple>>
 {
diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_typed.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_typed.cpp
index ea7a88febb..b4c11d5c5a 100644
--- a/test/ck_tile/gemm_block_scale/test_gemm_quant_typed.cpp
+++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_typed.cpp
@@ -62,10 +62,15 @@ using BQuantTypes = ::testing::Types<
 
 // clang-format off
 using BPreshuffleBQuantTypes = ::testing::Types<
-    std::tuple<...>,
-    std::tuple<...>,
-    std::tuple<...>,
-    std::tuple<...>
+    std::tuple<...>,
+    std::tuple<...>,
+    std::tuple<...>,
+    std::tuple<...>,
+
+    std::tuple<...>,
+    std::tuple<...>,
+    std::tuple<...>,
+    std::tuple<...>
 >;
 // clang-format off
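As background on the typed-test lists above: GoogleTest instantiates the fixture (and every test in its suite) once per tuple in a `::testing::Types<>` list. A self-contained sketch of that machinery (the fixture name, dummy config type, and tuple element order are illustrative assumptions, not the repository's actual fixture):

```cpp
// Generic GoogleTest typed-test sketch: each tuple in the Types<> list below
// instantiates the fixture and all of its TYPED_TESTs once. Names here are
// hypothetical stand-ins, not the ck_tile test suite's.
#include <gtest/gtest.h>
#include <tuple>

struct DummyGemmConfig {}; // stand-in for a GemmConfigPreshuffleB* struct

template <typename Tuple>
class BPreshuffleBQuantSketch : public ::testing::Test
{
protected:
    using AType  = std::tuple_element_t<0, Tuple>;
    using BType  = std::tuple_element_t<1, Tuple>;
    using Config = std::tuple_element_t<2, Tuple>;
    using CType  = std::tuple_element_t<3, Tuple>;
};

using SketchTypes = ::testing::Types<std::tuple<float, float, DummyGemmConfig, float>>;
TYPED_TEST_SUITE(BPreshuffleBQuantSketch, SketchTypes);

TYPED_TEST(BPreshuffleBQuantSketch, InstantiatesPerTuple)
{
    // One instantiation of this body exists per tuple in SketchTypes.
    SUCCEED();
}
```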