diff --git a/example/ck_tile/38_block_scale_gemm/README.md b/example/ck_tile/38_block_scale_gemm/README.md
index 9b2610813c..7f8aba7b3d 100644
--- a/example/ck_tile/38_block_scale_gemm/README.md
+++ b/example/ck_tile/38_block_scale_gemm/README.md
@@ -47,5 +47,6 @@ User need to select correct mapping of config for each quant mode:
 | For selecting AQuant | aquant | GemmConfigQuant |
 | For selecting Aquant with Preshuffle | aquant | GemmConfigPreshuffleQuant |
 | For selecting BQuant | bquant | GemmConfigQuant |
+| For selecting BQuant with preshuffled weight matrix | bquant | GemmConfigPreshuffleB_Bquant_decode or GemmConfigPreshuffleB_Bquant_prefill |
 | For selecting RowCol quant | rowcolquant | GemmConfigRowColQuant |
diff --git a/example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp b/example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp
old mode 100644
new mode 100755
index 91f799f194..fa9ad967ad
--- a/example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp
+++ b/example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp
@@ -23,7 +23,6 @@ template );
-    // B datatype is safe to use as compute type as it should be at least fp8
     using ComputeDataType = std::conditional_t;
+                                               QuantMode,
+                                               ALayout, // for AQLayout
+                                               BLayout, // for BQLayout
+                                               GemmConfig::DoubleSmemBuffer>;
     using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase;
-    using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3;
+    using BaseGemmPipeline = std::conditional_t<
+        GemmConfig::PreshuffleB == true,
+        ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2,
+        ck_tile::BaseGemmPipelineAgBgCrCompV3>;
     const ck_tile::index_t K_split =
         (args.K + GemmConfig::K_Tile - 1) / GemmConfig::K_Tile * GemmConfig::K_Tile;
@@ -110,9 +116,12 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
             QuantMode == ck_tile::QuantType::RowColQuant ||
                 QuantMode == ck_tile::QuantType::TensorQuant,
             ck_tile::GemmPipelineAgBgCrCompV3,
-            std::conditional_t,
-                               ck_tile::BQuantGemmPipelineAgBgCrCompV3>>;
+            std::conditional_t<
+                QuantMode == ck_tile::QuantType::AQuantGrouped,
+                ck_tile::AQuantGemmPipelineAgBgCrCompV3,
+                std::conditional_t,
+                                   ck_tile::BQuantGemmPipelineAgBgCrCompV3>>>;
         using GemmEpilogue = ck_tile::CShuffleEpilogue<
             ck_tile::CShuffleEpilogueProblem(Kernel{}, grids, blocks, 0, kargs));
+            ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor(
+                args.M, args.K, args.stride_A, is_row_major(ALayout{})));
+            ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor(
+                args.K, args.N, args.stride_B, is_row_major(BLayout{})));
+
+            auto size_a_buffer = a_m.get_element_space_size_in_bytes();
+            auto size_b_buffer = b_n.get_element_space_size_in_bytes();
+
+            ck_tile::RotatingMemWrapper
+                rotating_mem(
+                    kargs.a_ptr, kargs.b_ptr, s.rotating_count_, size_a_buffer, size_b_buffer);
+            rotating_mem.Print();
+
+            auto run_flush_cache = [&]() {
+                // flush icache
+                ck_tile::flush_icache();
+                // rotating mem
+                rotating_mem.Next();
+                // clear c mem
+                if(args.k_batch > 1)
+                    hipGetErrorString(
+                        hipMemsetAsync(args.c_ptr,
+                                       0,
+                                       args.M * args.N * sizeof(typename TypeConfig::CDataType),
+                                       s.stream_id_));
+            };
+            ave_time = ck_tile::launch_kernel_time_mask(
+                s,
+                run_flush_cache,
+                ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs));
+        }
+        else
+        {
+            ave_time = ck_tile::launch_kernel(
+                s,
+                ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs));
+        }

         return ave_time;
     };
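The timed path above rotates A/B between several identical buffer copies and flushes the instruction cache before every timed launch, so repeated launches do not simply benchmark a warm L2. A minimal host-side sketch of the rotation idea; `RotatingBuffers` is illustrative and not the ck_tile `RotatingMemWrapper` API:

```cpp
// Illustrative sketch only: mimics what a rotating-buffer wrapper does for
// benchmarking. The real ck_tile types (RotatingMemWrapper, flush_icache) differ.
#include <cstddef>
#include <vector>

struct RotatingBuffers
{
    std::vector<void*> a_copies; // each copy holds the same A data
    std::vector<void*> b_copies; // each copy holds the same B data
    std::size_t idx = 0;

    // Before every timed launch, switch to the next copy so the kernel
    // cannot hit in cache on data warmed by the previous iteration.
    void next() { idx = (idx + 1) % a_copies.size(); }

    void* a() const { return a_copies[idx]; }
    void* b() const { return b_copies[idx]; }
};
```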
@@ -180,6 +229,14 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
     using Row = ck_tile::tensor_layout::gemm::RowMajor;
     using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
+    if((QuantMode == ck_tile::QuantType::AQuantGrouped ||
+        QuantMode == ck_tile::QuantType::RowColQuant) &&
+       GemmConfig::PreshuffleB)
+    {
+        throw std::runtime_error(
+            "Preshuffling the weight matrix is not supported for AQuant or RowColQuant");
+    }
+
     if constexpr(std::is_same_v ||
                  std::is_same_v ||
                  std::is_same_v)
@@ -391,4 +448,7 @@ int run_gemm_example(int argc, char* argv[])
     }
 }
-int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
+int main(int argc, char* argv[])
+{
+    return !run_gemm_example(argc, argv);
+}
diff --git a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp
index e5313d8aaf..cfe7b72af9 100644
--- a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp
+++ b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp
@@ -91,6 +91,7 @@ struct GemmConfigBase
     static constexpr ck_tile::index_t TileParitionerM01 = 4;
     static constexpr bool PreshuffleQuant = false;
+    static constexpr bool PreshuffleB = false;
     static constexpr bool DoubleSmemBuffer = false;
 };
@@ -145,6 +146,26 @@ struct GemmConfigPreshuffleQuant : public GemmConfigBase
     static constexpr bool PreshuffleQuant = true;
 };
+template
+struct GemmConfigPreshuffleB_Bquant_decode : public GemmConfigBase
+{
+    static constexpr ck_tile::index_t M_Tile = 16;
+    static constexpr ck_tile::index_t N_Tile = 64;
+    static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType);
+
+    static constexpr ck_tile::index_t M_Warp = 1;
+    static constexpr ck_tile::index_t N_Warp = 4;
+    static constexpr ck_tile::index_t K_Warp = 1;
+
+    static constexpr ck_tile::index_t M_Warp_Tile = 16;
+    static constexpr ck_tile::index_t N_Warp_Tile = 16;
+    static constexpr ck_tile::index_t K_Warp_Tile =
+        get_k_from_preshuffled_warp_tile();
+
+    static constexpr bool PreshuffleB = true;
+    static constexpr bool DoubleSmemBuffer = true;
+};
+
+template
 * t, int block_aq_k) return ck_tile::reference_permute(t_view, {1, 0, 2}); }
+template
+auto shuffle_b(const ck_tile::HostTensor& t)
+{
+    assert(t.get_lengths().size() == 2);
+    int n_ = t.get_lengths()[1];
+    int k_ = t.get_lengths()[0];
+    constexpr int divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4;
+    ck_tile::HostTensor t_view({n_ / GemmConfig::N_Warp_Tile,
+                                GemmConfig::N_Warp_Tile,
+                                k_ / GemmConfig::K_Warp_Tile,
+                                divisor,
+                                GemmConfig::K_Warp_Tile / divisor});
+    std::copy(t.begin(), t.end(), t_view.begin());
+    return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
+}
 template b_k_n_dev = b_k_n;
     if constexpr(std::is_same_v)
     {
-        // Permute vector pk_i4x4 data for device implementation
-        ck_tile::HostTensor b_k_n_dev = b_k_n;
+
+        if constexpr(GemmConfig::PreshuffleB)
+        {
+            b_k_n_dev = shuffle_b(b_k_n);
+        }
         ck_tile::permute_vectors_i4x4_b(b_k_n_dev);
         b_k_n_dev_buf.ToDevice(b_k_n_dev.data());
     }
     else
     {
-        b_k_n_dev_buf.ToDevice(b_k_n.data());
+        if constexpr(GemmConfig::PreshuffleB)
+        {
+            b_k_n_dev = shuffle_b(b_k_n);
+        }
+        b_k_n_dev_buf.ToDevice(b_k_n_dev.data());
     }
+
     c_m_n_dev_buf.SetZero();
     c_m_n_dev_result.SetZero();
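`shuffle_b` reshapes the K x N weight tensor into warp-tile bricks and permutes them into the flat order the preshuffled pipeline consumes. The shape arithmetic below is checked with illustrative decode-config numbers; the exact `K_Warp_Tile` comes from `get_k_from_preshuffled_warp_tile()` and is assumed to be 64 here:

```cpp
// Shape arithmetic for this PR's shuffle_b (illustrative values: N = 64,
// K = 256, N_Warp_Tile = 16, K_Warp_Tile = 64, so divisor = 4 since
// N_Warp_Tile != 32).
#include <array>
#include <cassert>

int main()
{
    constexpr int N = 64, K = 256, NWT = 16, KWT = 64;
    constexpr int divisor = (NWT == 32) ? 2 : 4;

    // 5-D view of the K x N weight tensor:
    constexpr std::array<int, 5> view{N / NWT, NWT, K / KWT, divisor, KWT / divisor};
    // reference_permute(t_view, {0, 2, 3, 1, 4}) reorders the axes to:
    std::array<int, 5> shuffled{view[0], view[2], view[3], view[1], view[4]};

    // Element count is preserved; only the layout changes.
    static_assert(N * K == (N / NWT) * NWT * (K / KWT) * divisor * (KWT / divisor));
    assert(shuffled == (std::array<int, 5>{4, 4, 4, 16, 16}));
}
```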
"correct" : "fail") << std::endl; } else if(arg_parser.get_int("v") == 2) { diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp index 1f8b4f8adc..d66438528e 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp @@ -125,6 +125,7 @@ struct WarpGemmAttributeMfmaIterateK static constexpr index_t kN = Impl::kN; static constexpr index_t kK = Impl::kK * kKIter; static constexpr index_t kKPerThread = Impl::kABKPerLane * kKIter; + static constexpr index_t kCMLane = Impl::kCMLane; CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; } diff --git a/include/ck_tile/ops/gemm_quant.hpp b/include/ck_tile/ops/gemm_quant.hpp index 9f90050899..531cd676a5 100644 --- a/include/ck_tile/ops/gemm_quant.hpp +++ b/include/ck_tile/ops/gemm_quant.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp" +#include "ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp" #include "ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp" #include "ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp" #include "ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp" @@ -13,6 +14,8 @@ #include "ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp" #include "ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp" #include "ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp" +#include "ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_base_policy.hpp" +#include "ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp" #include "ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp" #include "ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp" #include "ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp" diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp new file mode 100755 index 0000000000..c4c1f1bbf7 --- /dev/null +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp @@ -0,0 +1,191 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/block/block_wp_asmem_bsmem_creg_v1_custom_policy.hpp" + +namespace ck_tile { + +// A is block window on shared memory +// BQ (scale tensor) is block distributed tensor. +// Consecutive kQuantGroupSize elements of B are quantized with a separate scale. +// B is block window on block distributed tensor. 
diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp
new file mode 100755
index 0000000000..c4c1f1bbf7
--- /dev/null
+++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp
@@ -0,0 +1,191 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include "ck_tile/core.hpp"
+#include "ck_tile/ops/gemm/block/block_wp_asmem_bsmem_creg_v1_custom_policy.hpp"
+
+namespace ck_tile {
+
+// A is a block window on shared memory.
+// BQ (the scale tensor) is a block-distributed tensor; consecutive
+// kQuantGroupSize elements of B are quantized with a separate scale.
+// B is a block-distributed tensor holding the preshuffled weight data.
+// C is a block-distributed tensor.
+template
+struct BlockGemmWeightPreshuffleBQuantARegBRegCReg
+{
+    using Problem         = remove_cvref_t;
+    using BlockPolicy     = remove_cvref_t;
+    using ADataType       = remove_cvref_t;
+    using BDataType       = remove_cvref_t;
+    using BQDataType      = remove_cvref_t;
+    using CDataType       = remove_cvref_t;
+    using ComputeDataType = remove_cvref_t;
+    using BlockGemmShape  = remove_cvref_t; // TileFlatmmShape
+
+    static constexpr auto I0   = number<0>();
+    static constexpr auto I1   = number<1>();
+    static constexpr auto I2   = number<2>();
+    static constexpr auto idxM = I0;
+    static constexpr auto idxN = I1;
+    static constexpr auto idxK = I2;
+    using BlockTile  = remove_cvref_t;
+    using BlockWarps = remove_cvref_t;
+    using WarpTile   = remove_cvref_t;
+
+    static constexpr auto config = BlockPolicy::template GetWarpGemmMWarpNWarp();
+
+    static constexpr auto warp_size = get_warp_size();
+
+    using WG = remove_cvref_t())>;
+
+    static constexpr index_t MWarp = config.template at<1>();
+    static constexpr index_t NWarp = config.template at<2>();
+
+    static constexpr index_t MPerBlock = BlockGemmShape::kM;
+    static constexpr index_t KPerBlock = BlockGemmShape::kK;
+
+    static constexpr index_t kQuantGroupSize = Problem::kQuantGroupSize;
+    static constexpr index_t kBlockSize      = Problem::kBlockSize;
+
+    static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
+    static constexpr index_t NIterPerWarp =
+        BlockTile::at(idxN) / (WarpTile::at(idxN) * BlockWarps::at(idxN));
+    static constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
+
+    static constexpr auto MIter_2nd_last =
+        (MIterPerWarp >= 2) ? MIterPerWarp - 2 : MIterPerWarp - 1;
+
+    static constexpr index_t KPerBlockBQ = KPerBlock / kQuantGroupSize;
+
+    static constexpr index_t QScalesPerBlockRow =
+        (KPerBlock + kQuantGroupSize - 1) / kQuantGroupSize;
+
+    static constexpr index_t QScalesPerWarpGemmRow =
+        (WG::kK + kQuantGroupSize - 1) / kQuantGroupSize;
+
+    static constexpr index_t KIterPerQScale = KIterPerWarp / QScalesPerBlockRow;
+    static constexpr index_t DsReadPreload  = 2; // by default preload 2 ds_reads
+
+    static constexpr index_t m_preload = (MIterPerWarp * KIterPerWarp >= DsReadPreload)
+                                             ? DsReadPreload
+                                             : MIterPerWarp * KIterPerWarp;
+
+    template
+    CK_TILE_DEVICE static float cvt_scale_to_fp32(T& scale)
+    {
+        float scale_reg_f = 0.f;
+        if constexpr(std::is_same_v)
+        {
+            scale_reg_f = element_wise::amd_assembly_fp8_to_fp32(static_cast(scale));
+        }
+        else if constexpr(std::is_same_v)
+        {
+            scale_reg_f = element_wise::amd_assembly_bf8_to_fp32(static_cast(scale));
+        }
+        else if constexpr(std::is_same_v)
+        {
+            scale_reg_f = ck_tile::bit_cast(scale);
+        }
+        else
+        {
+            static_assert(false, "BQDataType must be float, fp8_t or bf8_t.");
+        }
+        return scale_reg_f;
+    }
+
+    CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
+    {
+        constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
+            sequence<>,
+            tuple, sequence>,
+            tuple>,
+            tuple>,
+            sequence<1, 2>,
+            sequence<0, 0>>{};
+
+        constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
+            c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
+
+        constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
+
+        auto c_block_tensor = make_static_distributed_tensor(c_block_dstr);
+        return c_block_tensor;
+    }
+
+    // C += A * B
+    template
+    CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
+                                   ABlockTensor& a_warp_tensor,
+                                   BFlatBlockTensor& b_warp_tensor,
+                                   BQBlockTensor& bq_block_tensor,
+                                   ABlockWindow& a_warp_windows) const
+    {
+        using CWarpDstr   = typename WG::CWarpDstr;
+        using CWarpTensor = typename WG::CWarpTensor;
+
+        constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{};
+
+        static_for<0, QScalesPerBlockRow, 1>{}([&](auto kQScale) {
+            CWarpTensor c_warp_tensor;
+            static_for<0, KIterPerQScale, 1>{}([&](auto kIterInQScale) {
+                static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
+                    static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
+                        constexpr auto kIter = kQScale * KIterPerQScale + kIterInQScale;
+
+                        constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
+
+                        // warp GEMM
+                        if constexpr(kIterInQScale == 0)
+                            c_warp_tensor = WG{}(a_warp_tensor(number{}),
+                                                 b_warp_tensor(nIter)(number{}));
+                        else
+                            WG{}(c_warp_tensor,
+                                 a_warp_tensor(number{}),
+                                 b_warp_tensor(nIter)(number{}));
+
+                        __builtin_amdgcn_sched_barrier(0x7F6);
+                        // preload next A from lds
+                        if constexpr((kIter * MIterPerWarp + mIter) <
+                                     (KIterPerWarp * MIterPerWarp - m_preload))
+                        {
+                            constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
+                            constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
+                            a_warp_tensor(number{}) =
+                                load_tile(a_warp_windows(number{})(number{}));
+                        }
+                        // barrier
+                        if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
+                        {
+                            block_sync_lds();
+                        }
+                    });
+                });
+            });
+
+            constexpr auto tbuf_offset =
+                number{}, number<0>{}>{}, c_warp_y_index_zeros)) /
+                       CBlockTensor::PackedSize>{};
+
+            constexpr index_t reg_offset = kQScale;
+
+            auto& scale_reg   = bq_block_tensor.get_thread_buffer()[reg_offset];
+            float scale_reg_f = cvt_scale_to_fp32(scale_reg);
+
+            static_for<0, WG::kM * WG::kN / warp_size, 1>{}([&](auto c_row) {
+                c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
+                    (c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f);
+            });
+        });
+    }
+};
+
+} // namespace ck_tile
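To make the iteration constants above concrete, here is how they evaluate for the decode config (M_Tile 16, K_Tile 256 for fp8, one M warp of 16); the K warp-tile of 128 and kQuantGroupSize of 128 are illustrative assumptions:

```cpp
// Worked example of the compile-time constants in
// BlockGemmWeightPreshuffleBQuantARegBRegCReg; values are assumptions.
#include <cstddef>

int main()
{
    constexpr std::size_t MPerBlock = 16, KPerBlock = 256;
    constexpr std::size_t MWarp = 1, WG_kM = 16, WG_kK = 128;
    constexpr std::size_t kQuantGroupSize = 128, DsReadPreload = 2;

    constexpr std::size_t MIterPerWarp = MPerBlock / (MWarp * WG_kM);         // 1
    constexpr std::size_t KIterPerWarp = KPerBlock / WG_kK;                   // 2
    constexpr std::size_t QScalesPerBlockRow =
        (KPerBlock + kQuantGroupSize - 1) / kQuantGroupSize;                  // 2
    constexpr std::size_t KIterPerQScale = KIterPerWarp / QScalesPerBlockRow; // 1
    constexpr std::size_t m_preload = (MIterPerWarp * KIterPerWarp >= DsReadPreload)
                                          ? DsReadPreload
                                          : MIterPerWarp * KIterPerWarp;      // 2
    static_assert(MIterPerWarp == 1 && KIterPerWarp == 2 && QScalesPerBlockRow == 2 &&
                  KIterPerQScale == 1 && m_preload == 2);
}
```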
diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp
index f75d02f1a6..d4bece1a83 100644
--- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp
+++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp
@@ -344,11 +344,11 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase

         if constexpr(Traits::PreshuffleQuant)
         {
-            static_assert(false,
-                          "It is not supported yet to enable both Preshuffle and "
-                          "TransposeC.");
             if constexpr(Traits::TransposeC) // transposed C
             {
+                static_assert(false,
+                              "It is not supported yet to enable both Preshuffle "
+                              "and TransposeC.");
                 // TODO:
                 // A new tile distribution is needed for the Preshuffle and
                 // Transpose combination. For instance, with mnk at 16x16x32, lanes
diff --git a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp
index 0c9c816672..a0b6fc5821 100644
--- a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp
+++ b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp
@@ -77,6 +77,18 @@ struct is_quantpreshuffle_enabled
 {
     static constexpr bool value = T::PreshuffleQuant;
 };
+
+template
+struct is_preshuffleB_enabled
+{
+    static constexpr bool value = false;
+};
+
+template
+struct is_preshuffleB_enabled>
+{
+    static constexpr bool value = T::PreshuffleB;
+};
 } // namespace detail

 struct QuantGemmProblem
@@ -196,6 +208,7 @@ struct QuantGemmKernel
     static constexpr index_t kBlockSize = GemmPipeline::BlockSize;

     static constexpr bool PreshuffleQuant = detail::is_quantpreshuffle_enabled::value;
+    static constexpr bool PreshuffleB = detail::is_preshuffleB_enabled::value;

     using ADataType = remove_cvref_t;
     using BDataType = remove_cvref_t;
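`is_preshuffleB_enabled` defaults to false and only reads `T::PreshuffleB` when that member exists, so traits types that predate the flag keep compiling. The template parameters were garbled in this diff, so the following is a plausible self-contained reconstruction of the detection idiom, not the exact ck_tile code:

```cpp
// Detection idiom sketch for an optional static boolean member.
#include <type_traits>

template <typename T, typename = void>
struct is_preshuffleB_enabled : std::false_type
{
};

template <typename T>
struct is_preshuffleB_enabled<T, std::void_t<decltype(T::PreshuffleB)>>
    : std::bool_constant<T::PreshuffleB>
{
};

struct WithFlag    { static constexpr bool PreshuffleB = true; };
struct WithoutFlag { };

static_assert(is_preshuffleB_enabled<WithFlag>::value);
static_assert(!is_preshuffleB_enabled<WithoutFlag>::value);
```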
@@ -630,12 +643,30 @@ struct QuantGemmKernel
         }
         else
         {
-            return make_naive_tensor_view(
-                b_ptr,
-                make_tuple(kargs.N, splitk_batch_offset.splitted_k),
-                make_tuple(kargs.stride_B, 1),
-                number{},
-                number<1>{});
+            if constexpr(PreshuffleB)
+            {
+                index_t kFlatK =
+                    GemmPipeline::flatKPerWarp *
+                    (splitk_batch_offset.splitted_k /
+                     TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{}));
+                index_t kFlatN = kargs.N * kargs.K / kFlatK;
+
+                return make_naive_tensor_view(
+                    b_ptr,
+                    make_tuple(kFlatN, kFlatK),
+                    make_tuple(kFlatK, 1),
+                    number{},
+                    number<1>{});
+            }
+            else
+            {
+                return make_naive_tensor_view(
+                    b_ptr,
+                    make_tuple(kargs.N, splitk_batch_offset.splitted_k),
+                    make_tuple(kargs.stride_B, 1),
+                    number{},
+                    number<1>{});
+            }
         }
     }();
@@ -716,6 +747,8 @@ struct QuantGemmKernel
         // no padding
         const auto& aq_pad_view = [&]() { return views.at(I1); }();

+        const auto& b_flat_view = views.at(I2); // no padding is applied to the flat B view
+
         const auto& b_pad_view = [&]() {
             const auto& b_tensor_view = views.at(I2);
             if constexpr(std::is_same_v)
@@ -755,8 +788,14 @@ struct QuantGemmKernel
                     sequence{});
             }
         }();
-
-        return make_tuple(a_pad_view, aq_pad_view, b_pad_view, bq_pad_view, c_pad_view);
+        if constexpr(PreshuffleB)
+        {
+            return make_tuple(a_pad_view, aq_pad_view, b_flat_view, bq_pad_view, c_pad_view);
+        }
+        else
+        {
+            return make_tuple(a_pad_view, aq_pad_view, b_pad_view, bq_pad_view, c_pad_view);
+        }
     }

     template
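The flat view above is a pure reshape of the same N*K elements: `kFlatK` packs `flatKPerWarp` elements for each K warp tile, and `kFlatN` shrinks accordingly. A quick sanity check with illustrative sizes; the `flatKPerWarp = N_Warp_Tile * K_Warp_Tile` formula is an assumption here:

```cpp
// Sanity check of the flat B view shape used when PreshuffleB is on.
#include <cassert>

int main()
{
    constexpr long N = 512, K = 1024, KWarpTile = 64, flatKPerWarp = 16 * 64;

    long kFlatK = flatKPerWarp * (K / KWarpTile); // 1024 * 16 = 16384
    long kFlatN = N * K / kFlatK;                 // 524288 / 16384 = 32

    // The flat view is a reshape: it must cover exactly the N*K elements.
    assert(kFlatN * kFlatK == N * K);
}
```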
@@ -826,19 +865,30 @@ struct QuantGemmKernel
         }();

         const auto& b_block_window = [&]() {
-            if constexpr(std::is_same_v)
+            if constexpr(PreshuffleB)
             {
-                return make_tile_window(b_pad_view,
-                                        make_tuple(number{},
-                                                   number{}),
-                                        {i_n, 0});
+                return make_tile_window(
+                    b_pad_view,
+                    make_tuple(number{},
+                               number{}),
+                    {static_cast(i_n / TilePartitioner::BlockGemmShape::WarpTile::at(I1)), 0});
             }
             else
             {
-                return make_tile_window(b_pad_view,
-                                        make_tuple(number{},
-                                                   number{}),
-                                        {0, i_n});
+                if constexpr(std::is_same_v)
+                {
+                    return make_tile_window(b_pad_view,
+                                            make_tuple(number{},
+                                                       number{}),
+                                            {i_n, 0});
+                }
+                else
+                {
+                    return make_tile_window(b_pad_view,
+                                            make_tuple(number{},
+                                                       number{}),
+                                            {0, i_n});
+                }
             }
         }();
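Note the preshuffled window origin above: `i_n` is divided by the N warp-tile size because, in my reading of the flat layout, one row of the flat view spans a whole N_Warp_Tile-wide brick. An illustrative check with the decode shape (N_Tile 64, N_Warp_Tile 16):

```cpp
// Why the preshuffled window origin divides i_n by the N warp tile size.
#include <cassert>

int main()
{
    constexpr int N_Warp_Tile = 16;
    int i_n = 128;                         // logical N offset of this workgroup's tile
    int flat_row = i_n / N_Warp_Tile;      // = 8: brick row in the flattened view
    assert(flat_row * N_Warp_Tile == i_n); // tile starts must be brick-aligned
}
```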
@@ -969,6 +1019,80 @@ struct QuantGemmKernel
             c_block_window, c_block_tile, c_block_window, smem_ptr_0, aq_scale, bq_scale);
         }
     }
+    /**
+     * @brief Runs a single GEMM problem cooperatively by the whole workgroup,
+     *        double-buffering the A tile across two LDS blocks.
+     *
+     * @param a_ptr input A pointer
+     * @param b_ptr input B pointer
+     * @param aq_ptr input AQ pointer
+     * @param bq_ptr input BQ pointer
+     * @param c_ptr output C pointer
+     * @param smem_ptr_0 The start memory pointer of the first shared memory block.
+     * @param smem_ptr_1 The start memory pointer of the second shared memory block.
+     * @param kargs GEMM kernel arguments
+     * @param splitk_batch_offset Utility structure used to calculate the k batch.
+     * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
+     * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
+     *
+     * @tparam DstInMemOp Destination memory operation (default: set).
+     */
+    template
+    CK_TILE_DEVICE static void RunGemm2LDS(const ADataType* a_ptr,
+                                           const BDataType* b_ptr,
+                                           const AQDataType* aq_ptr,
+                                           const BQDataType* bq_ptr,
+                                           CDataType* c_ptr,
+                                           void* smem_ptr_0,
+                                           void* smem_ptr_1,
+                                           const QuantGemmKernelArgs& kargs,
+                                           const SplitKBatchOffset& splitk_batch_offset,
+                                           const index_t block_idx_m,
+                                           const index_t block_idx_n)
+    {
+        // Create Gemm tensor views, pad views and tile windows
+        const auto& gemm_tensor_views_tuple = MakeGemmTensorViews(
+            a_ptr, b_ptr, aq_ptr, bq_ptr, c_ptr, kargs, splitk_batch_offset);
+
+        const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
+        auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
+
+        const index_t num_loop = __builtin_amdgcn_readfirstlane(
+            TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
+
+        // Run GEMM cooperatively by whole workgroup.
+        const auto& a_block_window = gemm_tile_windows.at(I0);
+        const auto& b_block_window = gemm_tile_windows.at(I2);
+
+        const auto& c_block_tile = [&]() {
+            if constexpr(kQuantType == QuantType::BQuantGrouped)
+            {
+                const auto& bq_block_window = gemm_tile_windows.at(I3);
+                return GemmPipeline{}.template operator()(a_block_window,
+                                                          b_block_window,
+                                                          bq_block_window,
+                                                          num_loop,
+                                                          smem_ptr_0,
+                                                          smem_ptr_1);
+            }
+            else
+            {
+                return nullptr;
+            }
+        }();
+
+        // Run Epilogue Pipeline
+        auto& c_block_window = gemm_tile_windows.at(I4);
+
+        if constexpr(kQuantType == QuantType::BQuantGrouped)
+        {
+            EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
+        }
+        else
+        {
+            // DoubleSmemBuffer is not implemented for AQuantGrouped or
+            // RowColQuant; those modes never dispatch to this path.
+            return;
+        }
+    }

     CK_TILE_DEVICE void operator()(QuantGemmKernelArgs kargs) const
     {
@@ -989,8 +1113,35 @@
         __shared__ char smem_ptr_0[GetSmemSize()];

         assert(kargs.k_batch == 1);
-        RunGemm(
-            a_ptr, b_ptr, aq_ptr, bq_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n);
+        if constexpr(GemmPipeline::DoubleSmemBuffer == true)
+        {
+            __shared__ char smem_ptr_1[GetSmemSize()];
+
+            RunGemm2LDS(a_ptr,
+                        b_ptr,
+                        aq_ptr,
+                        bq_ptr,
+                        c_ptr,
+                        smem_ptr_0,
+                        smem_ptr_1,
+                        kargs,
+                        splitk_batch_offset,
+                        i_m,
+                        i_n);
+        }
+        else
+        {
+            RunGemm(a_ptr,
+                    b_ptr,
+                    aq_ptr,
+                    bq_ptr,
+                    c_ptr,
+                    smem_ptr_0,
+                    kargs,
+                    splitk_batch_offset,
+                    i_m,
+                    i_n);
+        }
     }
 };
diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp
index d49204c64d..4978e70099 100644
--- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp
+++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp
@@ -53,15 +53,15 @@ struct GemmQuantPipelineProblemBase : public GemmPipelineProblemBase;
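The dispatch above only materializes the second `__shared__` buffer in kernels whose pipeline asks for it, so single-buffer configurations pay no extra LDS. A compile-only HIP/CUDA-style sketch of the pattern; `Pipeline::run` and the 1 KiB sizes are placeholders for `GetSmemSize()` and the real pipeline call:

```cpp
// Compile-time LDS dispatch sketch (not the real QuantGemmKernel).
struct OneLds { static constexpr bool DoubleSmemBuffer = false; static __device__ void run(void*, void*) {} };
struct TwoLds { static constexpr bool DoubleSmemBuffer = true;  static __device__ void run(void*, void*) {} };

template <typename Pipeline>
__global__ void gemm_kernel_sketch()
{
    __shared__ char smem_ping[1024];
    if constexpr(Pipeline::DoubleSmemBuffer)
    {
        // The pong buffer only exists in double-buffered instantiations.
        __shared__ char smem_pong[1024];
        Pipeline::run(smem_ping, smem_pong);
    }
    else
    {
        Pipeline::run(smem_ping, nullptr);
    }
}

template __global__ void gemm_kernel_sketch<OneLds>();
template __global__ void gemm_kernel_sketch<TwoLds>();
```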
diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_base_policy.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_base_policy.hpp
new file mode 100644
index 0000000000..19c1223b78
--- /dev/null
+++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_base_policy.hpp
@@ -0,0 +1,60 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp"
+#include "ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp"
+
+namespace ck_tile {
+
+struct GemmWPQuantPipelineAgBgCrPolicy : public UniversalWeightPreshufflePipelineAgBgCrPolicy
+{
+    template
+    CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeBQ()
+    {
+        using BQDataType = remove_cvref_t;
+        constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
+        constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
+        constexpr index_t KPerBlockBQ = KPerBlock / Problem::kQuantGroupSize;
+
+        return GetABQGlobalVectorLoadSize();
+    }
+
+    template
+    CK_TILE_HOST_DEVICE static constexpr auto MakeBQDramTileDistribution()
+    {
+        return GemmBQuantPipelineAgBgCrDefaultPolicy::MakeBQDramTileDistribution();
+    }
+
+    template
+    CK_TILE_HOST_DEVICE static constexpr auto GetBlockWeightPreshuffleBQuant()
+    {
+        using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
+        using WarpTile   = typename Problem::BlockGemmShape::WarpTile;
+
+        using BTypeToUse =
+            std::conditional_t,
+                               typename Problem::ADataType,
+                               typename Problem::BDataType>;
+
+        using WarpGemm = WarpGemmDispatcher;
+
+        // TODO: use a custom block policy for AsBrCr
+        using BlockGemmPolicy =
+            BlockWeightPreshuffleASmemBSmemCRegV1CustomPolicy;
+        return BlockGemmWeightPreshuffleBQuantARegBRegCReg{};
+    }
+};
+
+} // namespace ck_tile
diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp
new file mode 100644
index 0000000000..01c1a72335
--- /dev/null
+++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp
@@ -0,0 +1,471 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include
+#include
+
+#include "ck_tile/core.hpp"
+#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
+#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
+#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
+#include "ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp"
+#include "ck_tile/host/concat.hpp"
+
+namespace ck_tile {
+
+template
+struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV2
+{
+    using Base            = WeightPreshufflePipelineAGmemBGmemCRegV2;
+    using ADataType       = remove_cvref_t;
+    using BDataType       = remove_cvref_t;
+    using BQDataType      = remove_cvref_t;
+    using CDataType       = remove_cvref_t;
+    using ComputeDataType = remove_cvref_t;
+    using BlockGemmShape  = remove_cvref_t;
+
+    using ALayout  = remove_cvref_t;
+    using BLayout  = remove_cvref_t;
+    using BQLayout = remove_cvref_t;
+    using CLayout  = remove_cvref_t;
+
+    using BlockWeightPreshuffle = remove_cvref_t<
+        decltype(PipelinePolicy::template GetBlockWeightPreshuffleBQuant())>;
+
+    static constexpr auto config =
+        BlockWeightPreshuffle::BlockPolicy::template GetWarpGemmMWarpNWarp();
+
+    using WG = remove_cvref_t())>;
+
+    using Base::kKPerBlock;
+    using Base::kMPerBlock;
+    using Base::kNPerBlock;
+
+    using Base::KIterPerWarp;
+    using Base::MIterPerWarp;
+    using Base::NIterPerWarp;
+
+    using Base::BlockSize;
+
+    using Base::kPadK;
+    using Base::kPadM;
+    using Base::kPadN;
+
+    using Base::I0;
+    using Base::I1;
+    using Base::I2;
+
+    using Base::MWarp;
+    using Base::NWarp;
+
+    using Base::KPerBlockPerIter;
+    using Base::MPerBlockPerIter;
+
+    using Base::flatKPerWarp;
+    using Base::flatNPerWarp;
+
+    using Base::m_preload;
+
+    static constexpr index_t QuantGroupSize = Problem::kQuantGroupSize;
+    static constexpr index_t KPerBlockBQ    = BlockGemmShape::kK / QuantGroupSize;
+    static constexpr index_t QScalesPerBlockRow =
+        (kKPerBlock + QuantGroupSize - 1) / QuantGroupSize;
+
+    static constexpr index_t GetVectorSizeBQ()
+    {
+        return PipelinePolicy::template GetVectorSizeBQ();
+    }
+    static constexpr index_t KIterPerQScale = KIterPerWarp / QScalesPerBlockRow;
+
+    [[nodiscard]] CK_TILE_HOST static const std::string GetName()
+    {
+        // clang-format off
+        constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0);
+        constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1);
+        return concat('_', "bquant_pipeline_AgBgCrV2_preshuffleB",
+                      concat('x', kMPerBlock, kNPerBlock, kKPerBlock),
+                      BlockSize,
+                      concat('x', WaveNumM, WaveNumN),
+                      concat('x', Base::GetVectorSizeA(), Base::GetVectorSizeB(), GetVectorSizeBQ()),
+                      concat('x', kPadM, kPadN, kPadK), QuantGroupSize);
+        // clang-format on
+    }
+
+    static constexpr bool PreshuffleB = Problem::PreshuffleB;
+    static constexpr auto TailNum     = Problem::TailNum;
+
+    template
+    CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
+                                   const AElementFunction& a_element_func,
+                                   const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
+                                   const BQDramBlockWindowTmp& bq_dram_block_window_tmp,
+                                   index_t num_loop,
+                                   void* p_smem_ping,
+                                   void* p_smem_pong) const
+    {
+        static_assert(
+            std::is_same_v> &&
+                std::is_same_v> &&
+                std::is_same_v>,
+            "A/B/BQ Dram block window should have the same data type as appropriate "
+            "([A|B|BQ]DataType) defined in Problem definition!");
+
+        constexpr bool is_a_col_major = std::is_same_v;
+        static_assert(!is_a_col_major, "A must be row major (col major not supported yet)");
+
+        constexpr bool is_bq_col_major = std::is_same_v;
+        static_assert(is_bq_col_major, "Bq must be col major (row major not supported yet)");
+
+        constexpr bool is_b_row_major = std::is_same_v;
+        static_assert(!is_b_row_major, "B must be col major (row major not supported yet)");
+
+        const index_t iMWarp = get_warp_id() / NWarp;
+
+        __builtin_amdgcn_sched_barrier(0);
+
+        // A tile in LDS
+        ADataType* p_a_lds_ping = static_cast(p_smem_ping);
+        ADataType* p_a_lds_pong = static_cast(p_smem_pong);
+
+        constexpr auto a_lds_block_desc =
+            PipelinePolicy::template MakeALdsBlockDescriptor();
+
+        auto a_lds_block_ping =
+            make_tensor_view(p_a_lds_ping, a_lds_block_desc);
+        auto a_lds_block_pong =
+            make_tensor_view(p_a_lds_pong, a_lds_block_desc);
+
+        // A DRAM tile window for load
+        auto a_copy_dram_window =
+            make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
+                             make_tuple(number{}, number{}),
+                             a_dram_block_window_tmp.get_window_origin(),
+                             PipelinePolicy::template MakeADramTileDistribution());
+
+        auto a_copy_lds_window_ping =
+            make_tile_window(a_lds_block_ping,
+                             make_tuple(number{}, number{}),
+                             {0, 0},
+                             PipelinePolicy::template MakeADramTileDistribution());
+
+        auto a_copy_lds_window_pong =
+            make_tile_window(a_lds_block_pong,
+                             make_tuple(number{}, number{}),
+                             {0, 0},
+                             PipelinePolicy::template MakeADramTileDistribution());
+
+        // ping-pong window for A LDS
+        auto a_warp_window_ping_tmp =
+            make_tile_window(a_lds_block_ping,
+                             make_tuple(number{}, number{}),
+                             {iMWarp * WG::kM, 0},
+                             make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));
+
+        auto a_warp_window_pong_tmp =
+            make_tile_window(a_lds_block_pong,
+                             make_tuple(number{}, number{}),
+                             {iMWarp * WG::kM, 0},
+                             make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));
+
+        statically_indexed_array<
+            statically_indexed_array,
+            MIterPerWarp>
+            a_warp_windows_ping;
+
+        statically_indexed_array<
+            statically_indexed_array,
+            MIterPerWarp>
+            a_warp_windows_pong;
+
+        static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
+            static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
+                a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp;
+
+                move_tile_window(a_warp_windows_ping(mIter)(kIter),
+                                 {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
+            });
+        });
+
+        static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
+            static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
+                a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp;
+
+                move_tile_window(a_warp_windows_pong(mIter)(kIter),
+                                 {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
+            });
+        });
+
+        // Block GEMM
+        auto block_weight_preshuffle = BlockWeightPreshuffle();
+        // Acc register tile
+        auto c_block_tile = block_weight_preshuffle.MakeCBlockTile();
+
+        // B flat DRAM window for load
+        auto b_flat_distribution =
+            PipelinePolicy::template MakeBFlatDramTileDistribution();
+        auto b_flat_dram_window = // tile_window_with_static_distribution
+            make_tile_window(
+                b_flat_dram_block_window_tmp.get_bottom_tensor_view(), // from kernel gemm_pad_views
+                make_tuple(number{}, number{}),
+                b_flat_dram_block_window_tmp.get_window_origin(),
+                b_flat_distribution);
+
+        using BTypeToUse =
+            std::conditional_t, ADataType, BDataType>;
+        using BTileType = decltype(make_static_distributed_tensor(b_flat_distribution));
+
+        // ping-pong buffer for B
+        statically_indexed_array<
+            statically_indexed_array,
+            NIterPerWarp>
+            b_flat_dram_windows;
+
+        statically_indexed_array, NIterPerWarp>
+            b_warp_tensor_ping;
+
+        statically_indexed_array, NIterPerWarp>
+            b_warp_tensor_pong;
+
+        // BQ DRAM window for load
+        auto bq_copy_dram_window =
+            make_tile_window(bq_dram_block_window_tmp.get_bottom_tensor_view(),
+                             make_tuple(number{}, number{}),
+                             bq_dram_block_window_tmp.get_window_origin(),
+                             PipelinePolicy::template MakeBQDramTileDistribution());
+
+        // Prefetch A0
+        auto a_block_tile = load_tile(a_copy_dram_window);
+        // move A window to next k
+        move_tile_window(a_copy_dram_window, {0, kKPerBlock});
+
+        // prefetch B
+        static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
+            static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
+                b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
+
+                move_tile_window(b_flat_dram_windows(nIter)(kIter),
+                                 {nIter * flatNPerWarp, kIter * flatKPerWarp});
+
+                load_int4_tile(
+                    b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
+            });
+        });
+        // move B window to next flat K
+        move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
+
+        // Strictly not needed given type deduction, but helps with readability
+        using BQBlockTileDistr = decltype(bq_copy_dram_window.get_tile_distribution());
+        using BQBlockTile =
+            decltype(make_static_distributed_tensor(BQBlockTileDistr{}));
+
+        // Load tile 0 of the BQ data directly into registers for the block tile
+        BQBlockTile bq_block_tile, bq_block_tile_2;
+        bq_block_tile = load_tile(bq_copy_dram_window);
+        // move BQ to tile 1
+        move_tile_window(bq_copy_dram_window, {0, KPerBlockBQ});
+
+        // Prefill A0
+        auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
+        store_tile(a_copy_lds_window_ping, a_block_tile_tmp);
+
+        __builtin_amdgcn_sched_barrier(0);
+
+        // Prefetch A1
+        a_block_tile = load_tile(a_copy_dram_window);
+        // move A window to next k
+        move_tile_window(a_copy_dram_window, {0, kKPerBlock});
+
+        // initialize C
+        tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
+
+        block_sync_lds();
+
+        // preload A00,A10 from lds
+        statically_indexed_array{})(number<0>{}))),
+                                 m_preload>
+            a_warp_tensor;
+
+        static_for<0, m_preload, 1>{}([&](auto loadIter) {
+            constexpr auto mIter = loadIter % MIterPerWarp;
+            constexpr auto kIter = loadIter / MIterPerWarp;
+            a_warp_tensor(loadIter) =
+                load_tile(a_warp_windows_ping(number{})(number{}));
+        });
+        __builtin_amdgcn_sched_barrier(0);
+
+        // MAIN LOOP
+        index_t iCounter = (num_loop - 1) / 2;
+        while(iCounter > 0)
+        {
+            // prefetch B(2i+1)
+            static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
+                static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
+                    b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
+
+                    move_tile_window(b_flat_dram_windows(nIter)(kIter),
+                                     {nIter * flatNPerWarp, kIter * flatKPerWarp});
+                    load_int4_tile(
+                        b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
+                });
+            });
+            move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
+
+            bq_block_tile_2 = load_tile(bq_copy_dram_window);
+            move_tile_window(bq_copy_dram_window, {0, KPerBlockBQ});
+
+            // Prefill A(2i+1)
+            a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
+            store_tile(a_copy_lds_window_pong, a_block_tile_tmp);
+
+            // Prefetch A(2i+2)
+            a_block_tile = load_tile(a_copy_dram_window);
+            // move A window to next k
+            move_tile_window(a_copy_dram_window, {0, kKPerBlock});
+
+            // GEMM 2i
+            block_weight_preshuffle(c_block_tile,
+                                    a_warp_tensor,
+                                    b_warp_tensor_ping,
+                                    bq_block_tile,
+                                    a_warp_windows_ping);
+
+            static_for<0, m_preload, 1>{}([&](auto loadIter) {
+                constexpr auto mIter = loadIter % MIterPerWarp;
+                constexpr auto kIter = loadIter / MIterPerWarp;
+                a_warp_tensor(loadIter) =
+                    load_tile(a_warp_windows_pong(number{})(number{}));
+            });
+            Base::HotLoopScheduler();
+
+            // Next K
+
+            // prefetch B(2i+2)
+            static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
+                static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
+                    b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
+
+                    move_tile_window(b_flat_dram_windows(nIter)(kIter),
+                                     {nIter * flatNPerWarp, kIter * flatKPerWarp});
+                    load_int4_tile(
+                        b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
+                });
+            });
+            move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
+
+            bq_block_tile = load_tile(bq_copy_dram_window);
+            move_tile_window(bq_copy_dram_window, {0, KPerBlockBQ});
+
+            // Prefill A(2i+2)
+            a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
+            store_tile(a_copy_lds_window_ping, a_block_tile_tmp);
+
+            // Prefetch A(2i+3)
+            a_block_tile = load_tile(a_copy_dram_window);
+            // move A window to next k
+            move_tile_window(a_copy_dram_window, {0, kKPerBlock});
+
+            // GEMM 2i+1
+            block_weight_preshuffle(c_block_tile,
+                                    a_warp_tensor,
+                                    b_warp_tensor_pong,
+                                    bq_block_tile_2,
+                                    a_warp_windows_pong);
+
+            static_for<0, m_preload, 1>{}([&](auto loadIter) {
+                constexpr auto mIter = loadIter % MIterPerWarp;
+                constexpr auto kIter = loadIter / MIterPerWarp;
+                a_warp_tensor(loadIter) =
+                    load_tile(a_warp_windows_ping(number{})(number{}));
+            });
+            Base::HotLoopScheduler();
+
+            iCounter--;
+        }
+
+        // tail
+        if constexpr(TailNum == TailNumber::Even)
+        {
+            // prefetch B(loopK)
+            static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
+                static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
+                    b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
+
+                    move_tile_window(b_flat_dram_windows(nIter)(kIter),
+                                     {nIter * flatNPerWarp, kIter * flatKPerWarp});
+
+                    load_int4_tile(
+                        b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
+                });
+            });
+            bq_block_tile_2 = load_tile(bq_copy_dram_window);
+
+            // Prefill A(loopK)
+            a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
+            store_tile(a_copy_lds_window_pong, a_block_tile_tmp);
+
+            // GEMM loopK-1
+            block_weight_preshuffle(c_block_tile,
+                                    a_warp_tensor,
+                                    b_warp_tensor_ping,
+                                    bq_block_tile,
+                                    a_warp_windows_ping);
+
+            static_for<0, m_preload, 1>{}([&](auto loadIter) {
+                constexpr auto mIter = loadIter % MIterPerWarp;
+                constexpr auto kIter = loadIter / MIterPerWarp;
+                a_warp_tensor(loadIter) =
+                    load_tile(a_warp_windows_pong(number{})(number{}));
+            });
+
+            Base::Last2ndHotLoopScheduler();
+
+            // GEMM loopK
+            block_weight_preshuffle(c_block_tile,
+                                    a_warp_tensor,
+                                    b_warp_tensor_pong,
+                                    bq_block_tile_2,
+                                    a_warp_windows_pong);
+            Base::LastHotLoopScheduler();
+        }
+        else if constexpr(TailNum == TailNumber::Odd)
+        {
+            // GEMM loopK
+            block_weight_preshuffle(c_block_tile,
+                                    a_warp_tensor,
+                                    b_warp_tensor_ping,
+                                    bq_block_tile,
+                                    a_warp_windows_ping);
+            Base::LastHotLoopScheduler();
+        }
+
+        return c_block_tile;
+    }
+
+    template
+    CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
+                                   const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
+                                   const BQDramBlockWindowTmp& bq_dram_block_window_tmp,
+                                   index_t num_loop,
+                                   void* p_smem_ping,
+                                   void* p_smem_pong) const
+    {
+        return operator()(
+            a_dram_block_window_tmp,
+            [](const ADataType& a) { return a; },
+            b_flat_dram_block_window_tmp,
+            bq_dram_block_window_tmp,
+            num_loop,
+            p_smem_ping,
+            p_smem_pong);
+    }
+};
+} // namespace ck_tile
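Stripped of tile details, the pipeline above processes two K blocks per trip of the while loop, alternating ping and pong resources, and finishes with a one- or two-GEMM tail depending on `TailNumber`. A host-side trace of that schedule; every helper name below is a stand-in, not a ck_tile call:

```cpp
// Structure-only trace of the ping-pong main loop.
#include <cstdio>

int main()
{
    int num_loop = 7; // odd, so the TailNumber::Odd path runs one tail GEMM
    int i = 0;

    // each trip of the real while loop covers two K blocks:
    for(int iCounter = (num_loop - 1) / 2; iCounter > 0; --iCounter)
    {
        std::printf("prefetch B/BQ(%d+1), prefill A(%d+1), GEMM %d on ping\n", i, i, i);
        std::printf("prefetch B/BQ(%d+2), prefill A(%d+2), GEMM %d on pong\n", i, i, i + 1);
        i += 2;
    }
    // tail: one GEMM remains for odd num_loop, two for even
    std::printf("tail GEMM %d on ping\n", i);
}
```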
diff --git a/include/ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp b/include/ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp
index 3b5bff03d4..52a326a897 100644
--- a/include/ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp
+++ b/include/ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp
@@ -32,6 +32,7 @@ template (this)->SetUpQuantTypeSpecific(); }
@@ -62,10 +65,9 @@ class TestCkTileGemmQuantBase : public ::testing::Test
     // Common test execution logic
     void invoke_quant_gemm(const ck_tile::QuantGemmHostArgs& args, const ck_tile::stream_config& s)
     {
-        constexpr bool kPadM       = false;
-        constexpr bool kPadN       = false;
-        constexpr bool kPadK       = false;
-        constexpr bool kPreshuffle = false;
+        constexpr bool kPadM = false;
+        constexpr bool kPadN = false;
+        constexpr bool kPadK = false;

         using CodegenGemmShape =
             ck_tile::TileGemmShape,
@@ -77,11 +79,15 @@ class TestCkTileGemmQuantBase : public ::testing::Test
         using CodegenGemmTraits =
             ck_tile::TileGemmQuantTraits;
+                                        QuantType,
+                                        ALayout,
+                                        BLayout,
+                                        DoubleSmemBuffer>;

         // Let the derived class create the appropriate pipeline and epilogue
         static_cast(this)
@@ -125,6 +131,19 @@ class TestCkTileGemmQuantBase : public ::testing::Test
         // Use higher threshold
         return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
     }
+
+    template
+    auto shuffle_b(const ck_tile::HostTensor& t)
+    {
+        assert(t.get_lengths().size() == 2);
+        int n_ = t.get_lengths()[1];
+        int k_ = t.get_lengths()[0];
+        constexpr int divisor = N_Warp_Tile == 32 ? 2 : 4;
+        ck_tile::HostTensor t_view(
+            {n_ / N_Warp_Tile, N_Warp_Tile, k_ / K_Warp_Tile, divisor, K_Warp_Tile / divisor});
+        std::copy(t.begin(), t.end(), t_view.begin());
+        return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
+    }
 };

 // Define generic QuantTypeTraits template (will be specialized)
diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp
index 5fc6b2f15c..98f88f4d53 100644
--- a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp
+++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp
@@ -24,6 +24,7 @@ struct GemmConfigBase
     static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
     static constexpr ck_tile::index_t NumWaveGroups = 1;
     static constexpr bool PreshuffleQuant = false;
+    static constexpr bool PreshuffleB = false;
     static constexpr bool DoubleSmemBuffer = false;

     // Default GEMM tile sizes for tests
@@ -40,6 +41,41 @@ struct GemmConfigBase
     static constexpr ck_tile::index_t K_Warp_Tile = 32;
 };

+struct GemmConfigPreshuffleB
+{
+    static constexpr bool kPadM = false;
+    static constexpr bool kPadN = false;
+    static constexpr bool kPadK = false;
+
+    static constexpr bool PermuteA = false;
+    static constexpr bool PermuteB = false;
+
+    static constexpr bool TransposeC = false;
+    static constexpr bool UseStructuredSparsity = false;
+
+    static constexpr int kBlockPerCu = 1;
+    static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
+    static constexpr ck_tile::index_t TileParitionerM01 = 4;
+    static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
+    static constexpr ck_tile::index_t NumWaveGroups = 1;
+    static constexpr bool PreshuffleQuant = false;
+    static constexpr bool PreshuffleB = true;
+    static constexpr bool DoubleSmemBuffer = true;
+
+    // Default GEMM tile sizes for tests
+    static constexpr ck_tile::index_t M_Tile = 16;
+    static constexpr ck_tile::index_t N_Tile = 64;
+    static constexpr ck_tile::index_t K_Tile = 256;
+
+    static constexpr ck_tile::index_t M_Warp = 1;
+    static constexpr ck_tile::index_t N_Warp = 4;
+    static constexpr ck_tile::index_t K_Warp = 1;
+
+    static constexpr ck_tile::index_t M_Warp_Tile = 16;
+    static constexpr ck_tile::index_t N_Warp_Tile = 16;
+    static constexpr ck_tile::index_t K_Warp_Tile = 64;
+};
+
 template
 class TestCkTileGemmAQuant : public TestCkTileGemmQuantBase>
 {
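The new fixture below reuses `TestCkTileGemmBQuant` wholesale and only swaps in the preshuffle config, so registration follows the usual GoogleTest typed-test pattern. A self-contained miniature of that mechanism, where `SketchTypes` stands in for the real layout/datatype tuples:

```cpp
// Typed-test registration sketch (not the real fixture or type list).
#include <gtest/gtest.h>
#include <tuple>

template <typename Tuple>
class TestCkTileGemmPreshuffleBBQuantSketch : public ::testing::Test
{
};

using SketchTypes = ::testing::Types<std::tuple<float>, std::tuple<double>>;
TYPED_TEST_SUITE(TestCkTileGemmPreshuffleBBQuantSketch, SketchTypes);

// Each type in SketchTypes instantiates this body once, exactly like
// TYPED_TEST(TestCkTileGemmPreshuffleBBQuant, BQuantGroupedTest) does below.
TYPED_TEST(TestCkTileGemmPreshuffleBBQuantSketch, CompilesForEveryTuple)
{
    SUCCEED();
}
```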
@@ -288,6 +324,7 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase b_k_n_dev = b_k_n;
         if constexpr(std::is_same_v)
         {
-            // Permute vector pk_i4x4 data for device implementation
-            ck_tile::HostTensor temp = b_k_n;
-            ck_tile::permute_vectors_i4x4_b(temp);
-            b_k_n_dev_buf.ToDevice(temp.data());
+            if constexpr(PreshuffleB)
+            {
+                b_k_n_dev = this->shuffle_b(b_k_n);
+            }
+            ck_tile::permute_vectors_i4x4_b(b_k_n_dev);
+            b_k_n_dev_buf.ToDevice(b_k_n_dev.data());
         }
         else
         {
-            b_k_n_dev_buf.ToDevice(b_k_n.data());
+            if constexpr(PreshuffleB)
+            {
+                b_k_n_dev = this->shuffle_b(b_k_n);
+            }
+            b_k_n_dev_buf.ToDevice(b_k_n_dev.data());
         }

         bq_bqk_n_dev_buf.ToDevice(bq_bqk_n.data());
@@ -419,7 +463,10 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase;
-        using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3;
+        using BaseGemmPipeline = std::conditional_t<
+            PreshuffleB == false,
+            ck_tile::BaseGemmPipelineAgBgCrCompV3,
+            ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2>;

         const ck_tile::index_t K_split = (args.K + Base::K_Tile - 1) / Base::K_Tile * Base::K_Tile;
         const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
@@ -443,7 +490,11 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase;
-        using GemmPipeline = ck_tile::BQuantGemmPipelineAgBgCrCompV3;
+        using GemmPipeline =
+            std::conditional_t,
+                               ck_tile::WPQuantBPipelineAgBgCrV2>;
+
         using GemmEpilogue = ck_tile::CShuffleEpilogue<
             ck_tile::CShuffleEpilogueProblem
+// Test fixture for BQuant with PreshuffleB
+template
+class TestCkTileGemmPreshuffleBBQuant : public TestCkTileGemmBQuant
+{
+};
+
 // RowColQuant-specific test fixture
 template
 class TestCkTileGemmRowColQuant
diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_typed.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_typed.cpp
index 1926b7cd0f..e131c03189 100644
--- a/test/ck_tile/gemm_block_scale/test_gemm_quant_typed.cpp
+++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_typed.cpp
@@ -41,6 +41,14 @@ using BQuantTypes = ::testing::Types<
 >;
 // clang-format on

+// clang-format off
+using BPreshuffleBQuantTypes = ::testing::Types<
+    std::tuple,
+    std::tuple,
+    std::tuple,
+    std::tuple
+>;
+// clang-format on
+
 // clang-format off
 using RowColQuantTypes = ::testing::Types<
     std::tuple,
@@ -58,6 +66,7 @@ using TensorQuantTypes = ::testing::Types<
 // Test suites for each quantization type
 TYPED_TEST_SUITE(TestCkTileGemmAQuant, AQuantTypes);
 TYPED_TEST_SUITE(TestCkTileGemmBQuant, BQuantTypes);
+TYPED_TEST_SUITE(TestCkTileGemmPreshuffleBBQuant, BPreshuffleBQuantTypes);
 TYPED_TEST_SUITE(TestCkTileGemmRowColQuant, RowColQuantTypes);
 TYPED_TEST_SUITE(TestCkTileGemmTensorQuant, TensorQuantTypes);
diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_ut_cases.inc b/test/ck_tile/gemm_block_scale/test_gemm_quant_ut_cases.inc
index 9b07afa2b3..042735eccb 100644
--- a/test/ck_tile/gemm_block_scale/test_gemm_quant_ut_cases.inc
+++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_ut_cases.inc
@@ -15,6 +15,11 @@ TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedTest)
     this->run_test_with_validation(1024, 1024, 1024);
 }

+// BQuant with preshuffled B tests
+TYPED_TEST(TestCkTileGemmPreshuffleBBQuant, BQuantGroupedTest)
+{
+    this->run_test_with_validation(1024, 1024, 1024);
+}
 // RowColQuant tests
 TYPED_TEST(TestCkTileGemmRowColQuant, RowColQuantTest)
 {