diff --git a/CHANGELOG.md b/CHANGELOG.md
index b149a74df3..3280ad07dc 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
 
 ## (Unreleased) Composable Kernel 1.3.0
 ### Added
+* Added preshuffleB support for abquant mode in blockscale GEMM.
 * Added support for explicit GEMM in CK_TILE grouped convolution forward and backward weight.
 * Added TF32 convolution support on gfx942 and gfx950 in CK. It could be enabled/disabled via `DTYPES` of "tf32".
 * Added attention sink support for FMHA FWD, include qr_ks_vs, qr_async and splitkv pipelines.
diff --git a/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.cpp b/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.cpp
index 4a90c07e05..155f19881e 100644
--- a/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.cpp
+++ b/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.cpp
@@ -69,4 +69,64 @@ void abquant_quantgrouped_instance_factory(
                                           BQuantGroupSize,
                                           ck_tile::QuantType::ABQuantGrouped>(arg_parser);
     };
+    lut[hash_multiple_strings({"fp8",
+                               "abquant",
+                               "preshuffleb",
+                               "non-preshufflequant",
+                               "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) {
+        using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
+        using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
+        using TypeConfig =
+            decltype(GemmQuantTypeConfig{});
+        return run_gemm_example_prec_type,
+                                          TypeConfig,
+                                          AQuantGroupSize,
+                                          BQuantGroupSize,
+                                          ck_tile::QuantType::ABQuantGrouped>(arg_parser);
+    };
+    lut[hash_multiple_strings({"fp8",
+                               "abquant",
+                               "preshuffleb",
+                               "non-preshufflequant",
+                               "1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) {
+        using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
+        using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
+        using TypeConfig =
+            decltype(GemmQuantTypeConfig{});
+        return run_gemm_example_prec_type,
+                                          TypeConfig,
+                                          AQuantGroupSize,
+                                          BQuantGroupSize,
+                                          ck_tile::QuantType::ABQuantGrouped>(arg_parser);
+    };
+    lut[hash_multiple_strings({"bf8",
+                               "abquant",
+                               "preshuffleb",
+                               "non-preshufflequant",
+                               "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) {
+        using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
+        using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
+        using TypeConfig =
+            decltype(GemmQuantTypeConfig{});
+        return run_gemm_example_prec_type,
+                                          TypeConfig,
+                                          AQuantGroupSize,
+                                          BQuantGroupSize,
+                                          ck_tile::QuantType::ABQuantGrouped>(arg_parser);
+    };
+    lut[hash_multiple_strings({"bf8",
+                               "abquant",
+                               "preshuffleb",
+                               "non-preshufflequant",
+                               "1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) {
+        using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
+        using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
+        using TypeConfig =
+            decltype(GemmQuantTypeConfig{});
+        return run_gemm_example_prec_type,
+                                          TypeConfig,
+                                          AQuantGroupSize,
+                                          BQuantGroupSize,
+                                          ck_tile::QuantType::ABQuantGrouped>(arg_parser);
+    };
 }
diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc
index d8988be7b0..398a61f368 100644
--- a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc
+++ b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc
@@ -74,9 +74,10 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
         std::conditional_t<
             QuantMode == ck_tile::QuantType::AQuantGrouped && GemmConfig::PreshuffleQuant == true,
             ck_tile::BaseGemmPipelineAgBgCrCompV3,
-            std::conditional_t,
-                               ck_tile::BaseGemmPipelineAgBgCrCompV3>>>;
+            std::conditional_t<
+                QuantMode == ck_tile::QuantType::AQuantGrouped,
+                ck_tile::BaseGemmPipelineAgBgCrMem,
+                ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2>>>;
 
     const ck_tile::index_t K_split =
         (args.K + GemmConfig::K_Tile - 1) / GemmConfig::K_Tile * GemmConfig::K_Tile;
@@ -145,26 +146,33 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
                                            GemmConfig::Scheduler,
                                            has_hot_loop_v,
                                            tail_number_v>>>>;
+    using AQuantPipeline =
+        std::conditional_t,
+                           ck_tile::AQuantGemmPipelineAgBgCrMem>;
+
+    using BQuantPipeline = std::conditional_t<
+        GemmConfig::PreshuffleB,
+        ck_tile::WPQuantBPipelineAgBgCrV2,
+        std::conditional_t<
+            std::is_same_v,
+            ck_tile::MxFp4GemmPipelineAgBgCrCompV3,
+            ck_tile::BQuantGemmPipelineAgBgCrCompV3>>;
+
+    using ABQuantPipeline =
+        std::conditional_t,
+                           ck_tile::ABQuantGemmPipelineAgBgCrCompV3>;
 
     using GemmPipeline = std::conditional_t<
         QuantMode == ck_tile::QuantType::RowColQuant ||
             QuantMode == ck_tile::QuantType::TensorQuant,
         ck_tile::GemmPipelineAgBgCrCompV3,
-        std::conditional_t<
-            QuantMode == ck_tile::QuantType::AQuantGrouped,
-            std::conditional_t,
-                               ck_tile::AQuantGemmPipelineAgBgCrMem>,
-            std::conditional_t<
-                QuantMode == ck_tile::QuantType::ABQuantGrouped,
-                ck_tile::ABQuantGemmPipelineAgBgCrCompV3,
-                std::conditional_t<
-                    GemmConfig::PreshuffleB == true,
-                    ck_tile::WPQuantBPipelineAgBgCrV2,
-                    std::conditional_t<
-                        std::is_same_v,
-                        ck_tile::MxFp4GemmPipelineAgBgCrCompV3,
-                        ck_tile::BQuantGemmPipelineAgBgCrCompV3>>>>>;
+        std::conditional_t>>;
 
     constexpr bool TiledPermuteN = (BQuantGroupSize::kN > 1) ? false : GemmConfig::TiledMMAPermuteN;
@@ -908,8 +916,7 @@ int run_gemm_example_prec_type(const ck_tile::ArgParser& arg_parser)
     using Row = ck_tile::tensor_layout::gemm::RowMajor;
     using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
 
-    if((QuantMode == ck_tile::QuantType::ABQuantGrouped ||
-        QuantMode == ck_tile::QuantType::AQuantGrouped ||
+    if((QuantMode == ck_tile::QuantType::AQuantGrouped ||
        QuantMode == ck_tile::QuantType::RowColQuant ||
        std::is_same_v) &&
       GemmConfig::PreshuffleB)
@@ -938,7 +945,7 @@ int run_gemm_example_prec_type(const ck_tile::ArgParser& arg_parser)
 
     if constexpr((QuantMode == ck_tile::QuantType::AQuantGrouped ||
                   QuantMode == ck_tile::QuantType::ABQuantGrouped) &&
-                 !GemmConfig::PreshuffleQuant)
+                 !GemmConfig::PreshuffleQuant && !GemmConfig::PreshuffleB)
     {
         if(a_layout == "R" && b_layout == "R")
         {
diff --git a/include/ck_tile/ops/gemm_quant.hpp b/include/ck_tile/ops/gemm_quant.hpp
index 1e4aece0d7..696de378aa 100644
--- a/include/ck_tile/ops/gemm_quant.hpp
+++ b/include/ck_tile/ops/gemm_quant.hpp
@@ -3,6 +3,7 @@
 #pragma once
 
 #include "ck_tile/ops/gemm_quant/block/block_gemm_quant_common.hpp"
+#include "ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_aquant_flatbr_bquant_cr.hpp"
 #include "ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp"
 #include "ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp"
 #include "ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp"
@@ -24,6 +25,8 @@
 #include "ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_policy.hpp"
 #include "ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_v3.hpp"
 #include "ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp"
+#include "ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_base_policy.hpp"
+#include "ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp"
 #include "ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_base_policy.hpp"
"ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_base_policy.hpp" #include "ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp" #include "ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp" diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_aquant_flatbr_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_aquant_flatbr_bquant_cr.hpp new file mode 100644 index 0000000000..63a5151108 --- /dev/null +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_aquant_flatbr_bquant_cr.hpp @@ -0,0 +1,282 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/block/block_wp_asmem_bsmem_creg_v1_custom_policy.hpp" +#include "ck_tile/ops/gemm_quant/block/block_gemm_quant_common.hpp" + +namespace ck_tile { + +// A is block window on shared memory +// BQ (scale tensor) is block distributed tensor. +// Consecutive QuantGroupSize elements of B are quantized with a separate scale. +// B is block window on block distributed tensor. +// C is block distributed tensor +template +struct BlockGemmWeightPreshuffleABQuantARegBRegCReg +{ + private: + template + struct GemmTraits_ + { + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using ADataType = remove_cvref_t; + using AQDataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using BQDataType = remove_cvref_t; + using BQLayout = remove_cvref_t; + using ComputeDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + using AQuantGroupSize = remove_cvref_t; + using BQuantGroupSize = remove_cvref_t; + + static constexpr index_t kBlockSize = Problem::kBlockSize; + static constexpr auto Scheduler = Problem::Scheduler; + + // Threadblock GEMM tile size + static constexpr index_t MPerBlock = BlockGemmShape::kM; + static constexpr index_t NPerBlock = BlockGemmShape::kN; + static constexpr index_t KPerBlock = BlockGemmShape::kK; + + static constexpr index_t NQPerBlock = NPerBlock / BQuantGroupSize::kN; + static constexpr index_t KQPerBlock = KPerBlock / BQuantGroupSize::kK; + static constexpr index_t AQPerBlock = KPerBlock / AQuantGroupSize::kK; + + static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + using WarpGemm = remove_cvref_t())>; + + // number of warps along M and N for threadblock's GEMM problem size + static constexpr index_t MWarp = config.template at<1>(); + static constexpr index_t NWarp = config.template at<2>(); + + using I0 = number<0>; + using I1 = number<1>; + + static_assert(MWarp == BlockGemmShape::BlockWarps::at(I0{}), + "Error! WarpGemm's MWarp is not consistent with BlockGemmShape!"); + static_assert(NWarp == BlockGemmShape::BlockWarps::at(I1{}), + "Error! WarpGemm's NWarp is not consistent with BlockGemmShape!"); + static_assert(WarpGemm::kM == BlockGemmShape::WarpTile::at(I0{}), + "Error! WarpGemm's M is not consistent with BlockGemmShape!"); + static_assert(WarpGemm::kN == BlockGemmShape::WarpTile::at(I1{}), + "Error! 
WarpGemm's N is not consistent with BlockGemmShape!"); + + static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM); + static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN); + static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK; + + static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant; + + static constexpr index_t QScalesPerBlockRow = + integer_divide_ceil(KPerBlock, BQuantGroupSize::kK); + static constexpr index_t QScalesPerWarpGemmRow = + integer_divide_ceil(WarpGemm::kK, BQuantGroupSize::kK); + + static constexpr index_t KIterPerQScale = KIterPerWarp / QScalesPerBlockRow; + + static_assert(BQuantGroupSize::kK % WarpGemm::kK == 0, + "Error! WarpGemm::kK should be a multiple of QuantGroupSize"); + static_assert(QScalesPerWarpGemmRow == 1, + "Error! QuantGroupSize shouldn't be smaller than WarpGemm::kK"); + static_assert(KIterPerWarp % QScalesPerBlockRow == 0, + "Error! KItersPerWarp should be a multiple of QscalesPerBlockRow"); + + static_assert(KPerBlock / BQuantGroupSize::kK > 0, + "Error! Each row of blockgemm should have a separate scale"); + + static_assert(MIterPerWarp * MWarp * WarpGemm::kM == MPerBlock, + "Error! Warps should cover all Block tile!"); + static_assert(NIterPerWarp * NWarp * WarpGemm::kN == NPerBlock, + "Error! Warps should cover all Block tile!"); + + // Currently tested combinations (A, B, BQ) + // 1. fp8, fp8, fp32 -> f32 + // 2. bf8, bf8, fp32 -> f32 + // 3. i4, fp8, (fp8/fp32) -> f32 + // 4. i4, bf8, (fp8/fp32) -> f32 + static_assert( + (std::is_same_v || std::is_same_v || + std::is_same_v) && + (std::is_same_v || std::is_same_v || + std::is_same_v) && + (std::is_same_v || std::is_same_v || + std::is_same_v) && + (std::is_same_v || std::is_same_v || + std::is_same_v) && + (std::is_same_v || std::is_same_v) && + std::is_same_v); + + static constexpr index_t InterWaveSchedulingMacClusters = 1; + + static constexpr index_t KPack = WarpGemm::kKPerThread; + static constexpr index_t KPerThread = KIterPerWarp * WarpGemm::kKPerThread; + static constexpr bool TransposeC = Problem::TransposeC; + }; + + public: + using Traits = GemmTraits_; + using Problem = remove_cvref_t; + using BlockPolicy = remove_cvref_t; + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using BQDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using ComputeDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; // TileFlatmmShape + using QuantGroupSize = remove_cvref_t; + + static_assert(QuantGroupSize::kM == 1, "only N/K blocks for BQuant preshuffle kernel!"); + + static constexpr auto I0 = number<0>(); + static constexpr auto I1 = number<1>(); + static constexpr auto I2 = number<2>(); + static constexpr auto idxM = I0; + static constexpr auto idxN = I1; + static constexpr auto idxK = I2; + using BlockTile = remove_cvref_t; + using BlockWarps = remove_cvref_t; + using WarpTile = remove_cvref_t; + + static constexpr auto config = BlockPolicy::template GetWarpGemmMWarpNWarp(); + + static constexpr auto warp_size = get_warp_size(); + + using WG = remove_cvref_t())>; + + static constexpr index_t MWarp = config.template at<1>(); + static constexpr index_t NWarp = config.template at<2>(); + + static constexpr index_t MPerBlock = BlockGemmShape::kM; + static constexpr index_t KPerBlock = BlockGemmShape::kK; + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); // 128 / (1 * 16) = 8 + static constexpr 
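+
+    // The inline numbers in the comments below assume a 128x128x128 block tile with
+    // 16x16x16 warp tiles and a 1x4 (MWarp x NWarp) wave layout; they are purely
+    // illustrative.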
+    static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); // 128 / (1 * 16) = 8
+    static constexpr index_t NIterPerWarp =
+        BlockTile::at(idxN) / (WarpTile::at(idxN) * BlockWarps::at(idxN)); // 128 / (4 * 16) = 2
+    static constexpr index_t KIterPerWarp = KPerBlock / WG::kK; // 128 / 16 = 8
+    static constexpr auto MIter_2nd_last =
+        (MIterPerWarp >= 2) ? MIterPerWarp - 2 : MIterPerWarp - 1;
+
+    static constexpr index_t KPerBlockBQ = KPerBlock / QuantGroupSize::kK;
+
+    static constexpr index_t QScalesPerBlockRow =
+        integer_divide_ceil(KPerBlock, QuantGroupSize::kK); // 128 / 128 = 1
+    static constexpr index_t QScalesPerWarpGemmRow =
+        integer_divide_ceil(WG::kK, QuantGroupSize::kK);
+
+    static constexpr index_t KIterPerQScale = KIterPerWarp / QScalesPerBlockRow; // 8 / 1 = 8
+    static constexpr index_t DsReadPreload = 2; // default 2: preload 2 LDS reads
+
+    static constexpr index_t m_preload = (MIterPerWarp * KIterPerWarp >= DsReadPreload)
+                                             ? DsReadPreload
+                                             : MIterPerWarp * KIterPerWarp;
+
+    CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
+    {
+        return BlockGemmQuantCommon::MakeCBlockTile();
+    }
+
+    // C += A * B
+    template
+    CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
+                                   ABlockTensor& a_warp_tensor,
+                                   BFlatBlockTensor& b_warp_tensor,
+                                   AQBlockTensor& aq_block_tensor,
+                                   BQBlockTensor& bq_block_tensor,
+                                   ABlockWindow& a_warp_windows) const
+    {
+        using CWarpDstr = typename WG::CWarpDstr;
+        using AccTensor = typename WG::CWarpTensor;
+
+        constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
+
+        statically_indexed_array, MIterPerWarp> c_acc;
+
+        auto zero_accumulators = [&] {
+            static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
+                static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
+                    static_for<0, (WG::kM * WG::kN) / warp_size, 1>{}([&](auto i) {
+                        c_acc(mIter)(nIter).get_thread_buffer()[i] = 0.0f;
+                    }); // make sure WG::CWarpTensor exposes a clear/zero
+                });
+            });
+        };
+        static_for<0, QScalesPerBlockRow, 1>{}([&](auto kQScale) {
+            zero_accumulators();
+            static_for<0, KIterPerQScale, 1>{}([&](auto kIterInQScale) {
+                constexpr auto kIter = kQScale * KIterPerQScale + kIterInQScale;
+                static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
+                    constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
+                    static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
+                        // warp GEMM
+                        WG{}(c_acc(mIter)(nIter),
+                             a_warp_tensor(number<AwarpIter>{}),
+                             b_warp_tensor(nIter)(number<kIter>{}));
+                    });
+                    __builtin_amdgcn_sched_barrier(0x7F6);
+                    // preload the next A from LDS
+                    if constexpr((kIter * MIterPerWarp + mIter) <
+                                 (KIterPerWarp * MIterPerWarp - m_preload))
+                    {
+                        constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
+                        constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
+                        a_warp_tensor(number<AwarpIter>{}) =
+                            load_tile(a_warp_windows(number<AmIter>{})(number<AkIter>{}));
+                    }
+                    // barrier
+                    // Could be deleted
+                    if constexpr((mIter == MIter_2nd_last))
+                    {
+                        block_sync_lds();
+                    }
+                });
+            });
+            static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
+                AQPickerCommon aq_picker(aq_block_tensor);
+                static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
+                    constexpr auto tbuf_offset =
+                        number{}, c_warp_y_index_zeros)) /
+                               CBlockTensor::PackedSize>{};
+
+                    index_t reg_offset = [&]() {
+                        if constexpr(QuantGroupSize::kN >= (NWarp * WG::kN))
+                        {
+                            return (nIter * NWarp * WG::kN) / QuantGroupSize::kN * KPerBlockBQ +
+                                   kQScale;
+                        }
+                        else
+                        {
+                            return nIter * KPerBlockBQ + kQScale;
+                        }
+                    }();
+                    auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
+                    float b_scale_reg_f =
+                        aq_picker.template cvt_scale_to_fp32(scale_reg);
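+
+                    // Apply the dequant scales to the fp32 partial sums: each C element
+                    // owned by this thread accumulates acc * b_scale * a_scale for the
+                    // current K quant group.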
+                    static_for<0, WG::kM * WG::kN / warp_size, 1>{}([&](auto c_row) {
+                        float a_scale_reg_f = aq_picker.template pick();
+                        auto& c_ref = c_block_tensor.get_thread_buffer()[tbuf_offset + c_row];
+                        const auto acc_val = c_acc(mIter)(nIter).get_thread_buffer()[c_row];
+                        c_ref = c_ref + acc_val * b_scale_reg_f * a_scale_reg_f;
+                    });
+                });
+            });
+        });
+    }
+};
+
+} // namespace ck_tile
diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_base_policy.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_base_policy.hpp
new file mode 100755
index 0000000000..80e41cad45
--- /dev/null
+++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_base_policy.hpp
@@ -0,0 +1,120 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#pragma once
+
+#include "ck_tile/ops/gemm/block/block_wp_asmem_bsmem_creg_v1.hpp"
+#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp"
+#include "ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp"
+#include "ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp"
+
+namespace ck_tile {
+
+struct GemmWPABQuantPipelineAgBgCrPolicy : public UniversalWeightPreshufflePipelineAgBgCrPolicy
+{
+    template
+    CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeAQ()
+    {
+        using AQDataType = remove_cvref_t;
+        constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
+        constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
+        constexpr index_t KPerBlockAQ = KPerBlock / Problem::AQuantGroupSize::kK;
+
+        return GetABQGlobalVectorLoadSize();
+    }
+
+    template
+    CK_TILE_HOST_DEVICE static constexpr auto MakeAQDramTileDistribution()
+    {
+        return GemmAQuantPipelineAgBgCrDefaultPolicy::MakeAQDramTileDistribution();
+    }
+
+    template
+    CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeBQ()
+    {
+        using BQDataType = remove_cvref_t;
+        constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
+        constexpr index_t NPerBlockBQ = NPerBlock / Problem::BQuantGroupSize::kN;
+        constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
+        constexpr index_t KPerBlockBQ = KPerBlock / Problem::BQuantGroupSize::kK;
+
+        return GetABQGlobalVectorLoadSize();
+    }
+
+    template
+    CK_TILE_HOST_DEVICE static constexpr auto MakeBQDramTileDistribution()
+    {
+        return GemmBQuantPipelineAgBgCrDefaultPolicy::MakeBQDramTileDistribution();
+    }
+
+    // Since UniversalWeightPreshufflePipelineAgBgCrPolicy's MakeBFlatDramTileDistribution
+    // has changed, the original implementation is temporarily kept here.
+    template
+    CK_TILE_DEVICE static constexpr auto MakeBFlatDramTileDistribution()
+    {
+        using TileShape = typename Problem::BlockGemmShape;
+
+        constexpr index_t BlockSize = Problem::kBlockSize;
+        constexpr index_t WaveSize = get_warp_size();
+        constexpr index_t WaveNum = BlockSize / WaveSize;
+        constexpr index_t KBPerLoad = GetKBPerLoad();
+#if defined(__gfx11__)
+        constexpr index_t KRepeatInWave = 2;
+#else
+        constexpr index_t KRepeatInWave = 1;
+#endif
+        constexpr index_t KThdPerWave = WaveSize / KRepeatInWave; // thread count in the K dim
+        constexpr index_t KWavePerBlk = 1;
+        constexpr index_t KRepeat = 1;
+        static_assert(TileShape::flatKPerWarp == KThdPerWave * KBPerLoad,
+                      "flatKPerWarp must equal KThdPerWave * KBPerLoad");
+
+        constexpr index_t NBPerLoad = 1;
+        constexpr index_t NThdPerWave = 1;
+        constexpr index_t NWavePerBlk = TileShape::BlockWarps::at(number<1>{}); // N_Warp
+        constexpr index_t NRepeat = 1;
+
+        constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp;
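+
+        // Encoding sketch: waves tile the flat-N dimension (NWavePerBlk = N_Warp) while
+        // lanes within a wave advance along flat-K, each loading KBPerLoad contiguous
+        // elements per vector access.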
+        return make_static_tile_distribution(
+            tile_distribution_encoding<
+                sequence, // ?
+                tuple, // second direction
+                sequence>, // first direction
+                // wave in blk, thd in wave
+                tuple, sequence<0, 1, 2>>, // which direction
+                tuple, sequence<1, 2, 2>>, // which index
+                sequence<1, 1, 2, 2>,
+                sequence<0, 3, 0, 3>>{});
+    }
+
+    template
+    CK_TILE_HOST_DEVICE static constexpr auto GetBlockWeightPreshuffleBQuant()
+    {
+        using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
+        using WarpTile = typename Problem::BlockGemmShape::WarpTile;
+
+        using BTypeToUse =
+            std::conditional_t,
+                               typename Problem::ADataType,
+                               typename Problem::BDataType>;
+
+        using WarpGemm = WarpGemmDispatcher;
+
+        // TODO: Use a custom block policy for AsBrCr
+        using BlockGemmPolicy =
+            BlockWeightPreshuffleASmemBSmemCRegV1CustomPolicy;
+        return BlockGemmWeightPreshuffleABQuantARegBRegCReg{};
+    }
+};
+
+} // namespace ck_tile
diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp
new file mode 100644
index 0000000000..0f3951ffcc
--- /dev/null
+++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp
@@ -0,0 +1,611 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#pragma once
+
+#include
+#include
+
+#include "ck_tile/core.hpp"
+#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
+#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
+#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp"
+#include "ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp"
+#include "ck_tile/host/concat.hpp"
+
+namespace ck_tile {
+
+template
+struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV2
+{
+    using Base = WeightPreshufflePipelineAGmemBGmemCRegV2;
+    using ADataType = remove_cvref_t;
+    using AQDataType = remove_cvref_t;
+    using BDataType = remove_cvref_t;
+    using BQDataType = remove_cvref_t;
+    using CDataType = remove_cvref_t;
+    using ComputeDataType = remove_cvref_t;
+    using BlockGemmShape = remove_cvref_t;
+    using AQuantGroupSize = remove_cvref_t;
+    using BQuantGroupSize = remove_cvref_t;
+
+    using ALayout = remove_cvref_t;
+    using BLayout = remove_cvref_t;
+    using BQLayout = remove_cvref_t;
+    using CLayout = remove_cvref_t;
+
+    using BlockWeightPreshuffle = remove_cvref_t<
+        decltype(PipelinePolicy::template GetBlockWeightPreshuffleBQuant())>;
+
+    static constexpr auto config =
+        BlockWeightPreshuffle::BlockPolicy::template GetWarpGemmMWarpNWarp();
+
+    using WG = remove_cvref_t<decltype(config.template at<0>())>;
+
+    using Base::kKPerBlock;
+    using Base::kMPerBlock;
+    using Base::kNPerBlock;
+
+    using Base::KIterPerWarp;
+    using Base::MIterPerWarp;
+    using Base::NIterPerWarp;
+
+    using Base::BlockSize;
+
+    using Base::kPadK;
+    using Base::kPadM;
+    using Base::kPadN;
+
+    using Base::I0;
+    using Base::I1;
+    using Base::I2;
+
+    using Base::MWarp;
+    using Base::NWarp;
+
+    using Base::KPerBlockPerIter;
+    using Base::MPerBlockPerIter;
+
+    using Base::flatKPerWarp;
+    using Base::flatNPerWarp;
+
+    using Base::m_preload;
+
+    static constexpr index_t VectorLoadSize = Problem::VectorLoadSize;
+    static constexpr index_t KPerBlockAQ =
+        integer_divide_ceil(BlockGemmShape::kK, AQuantGroupSize::kK);
+    static constexpr index_t KPerBlockBQ =
+        integer_divide_ceil(BlockGemmShape::kK, BQuantGroupSize::kK);
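+
+    // One B scale covers BQuantGroupSize::kK elements along K, so each K block row
+    // carries QScalesPerBlockRow scales.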
+    static constexpr index_t QScalesPerBlockRow =
+        integer_divide_ceil(kKPerBlock, BQuantGroupSize::kK);
+
+    static constexpr index_t GetVectorSizeAQ()
+    {
+        return PipelinePolicy::template GetVectorSizeAQ();
+    }
+    static constexpr index_t GetVectorSizeBQ()
+    {
+        return PipelinePolicy::template GetVectorSizeBQ();
+    }
+    static constexpr index_t KIterPerQScale = KIterPerWarp / QScalesPerBlockRow;
+
+    [[nodiscard]] CK_TILE_HOST static const std::string GetName()
+    {
+        // clang-format off
+        constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0);
+        constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1);
+        return concat('_', "abquant_pipeline_AgBgCrV2_preshuffleB",
+                      concat('x', kMPerBlock, kNPerBlock, kKPerBlock),
+                      BlockSize,
+                      concat('x', WaveNumM, WaveNumN),
+                      concat('x', Base::GetVectorSizeA(), Base::GetVectorSizeB(), GetVectorSizeAQ(), GetVectorSizeBQ()),
+                      concat('x', kPadM, kPadN, kPadK), AQuantGroupSize::GetName(), BQuantGroupSize::GetName());
+        // clang-format on
+    }
+
+    template
+    CK_TILE_HOST_DEVICE static constexpr auto HotLoopScheduler()
+    {
+        // Estimated number of VMEM vector loads for A per block:
+        // total A bytes / (threads per block * vector width)
+        constexpr index_t Aload_inst =
+            (kMPerBlock * kKPerBlock * sizeof(ADataType)) / BlockSize / VectorLoadSize;
+        // Estimated number of VMEM vector loads for B per block:
+        // total B bytes / (threads per block * vector width)
+        constexpr index_t Bload_inst =
+            (kKPerBlock * kNPerBlock * sizeof(BDataType)) / BlockSize / VectorLoadSize;
+
+        // Estimated number of VMEM loads for B's quant data (e.g. scales / zero points).
+        // First ceil-divide by the quant group size (how many elements share one scale),
+        // then by the vector width to get an approximate number of vector loads.
+        constexpr index_t BQload_inst = ck_tile::integer_divide_ceil(
+            ck_tile::integer_divide_ceil(kKPerBlock * kNPerBlock * sizeof(BQDataType),
+                                         BQuantGroupSize::kN * BQuantGroupSize::kK),
+            VectorLoadSize);
+
+        // TODO: hardcoded, may need to change in the future. How many instructions are
+        // emitted per iteration.
+        constexpr index_t kLdsInstCycle = 8;
+        // Total VMEM load instructions (A + B + quant data)
+        constexpr index_t buffer_load_inst = Aload_inst + Bload_inst + BQload_inst;
+        // Approximate number of LDS reads per block
+        constexpr index_t ds_read_inst = kMPerBlock / kLdsInstCycle;
+        // Approximate number of LDS writes per block
+        // (e.g., writing A from VMEM into LDS once per A load)
+        constexpr index_t ds_write_inst = Aload_inst;
+        // Number of MFMA instructions per wave for one block tile
+        constexpr index_t mfma_inst = (kMPerBlock / WG::kM) * (kNPerBlock / WG::kN);
+        // How often (in MFMA units) we should insert DS (LDS) operations.
+        constexpr index_t ds_rep = mfma_inst / (ds_read_inst + ds_write_inst);
+        // How often (in MFMA units) we should insert VMEM buffer loads.
+        // buffer_load_rep ≈ "MFMAs per VMEM read", clamped so that one buffer_load
+        // is assumed to cover at most 4 MFMA instructions.
+        constexpr index_t buffer_load_rep =
+            min(mfma_inst / buffer_load_inst, 4); // 1 buffer_load covers up to 4 MFMAs
+
+        static_for<0, nloop, 1>{}([&](auto) {
+            static_for<0, mfma_inst, 1>{}([&](auto i_inst) {
+                __builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::MFMA, 1, 0); // MFMA
+
+                // Insert LDS read/write groups periodically based on ds_rep.
+                // The % pattern staggers READ and WRITE so they don't collapse
+                // into the same cycle in the model.
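+                // Illustrative numbers only (assuming a hypothetical 128x128x128 fp8
+                // block tile, BlockSize = 256, VectorLoadSize = 16, 16x16 warp GEMM):
+                // Aload_inst = Bload_inst = 128*128/(256*16) = 4 and mfma_inst =
+                // (128/16)*(128/16) = 64, so buffer_load_rep clamps to 4 and one VMEM
+                // read is slotted after roughly every 4th MFMA.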
+                if constexpr(ds_rep > 0 && i_inst % ds_rep == 0)
+                {
+                    __builtin_amdgcn_sched_group_barrier(
+                        LLVMSchedGroupMask::DS_READ, 1, 0); // DS read
+                }
+                if constexpr(ds_rep > 0 && i_inst % ds_rep == 1)
+                {
+                    __builtin_amdgcn_sched_group_barrier(
+                        LLVMSchedGroupMask::DS_WRITE, 1, 0); // DS write
+                }
+
+                if constexpr(buffer_load_rep > 0 && i_inst % buffer_load_rep == 0)
+                {
+                    if constexpr(ds_write_inst > 0)
+                    {
+                        __builtin_amdgcn_sched_group_barrier(
+                            LLVMSchedGroupMask::VMEM_READ, 1, 0); // VMEM read
+                    }
+                }
+                // Always mark some VALU work in the loop to reflect the auxiliary scalar
+                // and vector ALU instructions that coexist with MFMA (blockscale
+                // calculation).
+                __builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::VALU, 2, 0); // VALU
+            });
+        });
+        __builtin_amdgcn_sched_barrier(0);
+    }
+
+    static constexpr bool PreshuffleB = Problem::PreshuffleB;
+    static constexpr auto TailNum = Problem::TailNum;
+
+    template
+    CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
+                                   const AElementFunction& a_element_func,
+                                   const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
+                                   const AQDramBlockWindowTmp& aq_dram_block_window_tmp,
+                                   const BQDramBlockWindowTmp& bq_dram_block_window_tmp,
+                                   index_t m,
+                                   index_t n,
+                                   index_t num_loop,
+                                   void* p_smem) const
+    {
+        (void)m;
+        (void)n;
+        static_assert(
+            std::is_same_v> &&
+                std::is_same_v> &&
+                std::is_same_v>,
+            "A/B/BQ DRAM block windows should have the same data types as the corresponding "
+            "[A|B|BQ]DataType defined in the Problem definition!");
+
+        constexpr bool is_a_col_major = std::is_same_v;
+        static_assert(!is_a_col_major, "A must be row major (col major not supported yet)");
+
+        constexpr bool is_bq_col_major = std::is_same_v;
+        static_assert(is_bq_col_major, "BQ must be col major (row major not supported yet)");
+
+        constexpr bool is_b_row_major = std::is_same_v;
+        static_assert(!is_b_row_major, "B must be col major (row major not supported yet)");
+
+        const index_t iMWarp = get_warp_id() / NWarp;
+        // Double buffering (loop_count = 2) for full load/compute overlap.
+        const index_t loop_count = 2;
+
+        __builtin_amdgcn_sched_barrier(0);
+
+        // A tile in LDS
+        constexpr index_t smem_size = PipelinePolicy::template GetSmemSize();
+        ADataType* p_a_lds_ping = static_cast<ADataType*>(p_smem);
+        ADataType* p_a_lds_pong =
+            reinterpret_cast<ADataType*>(static_cast<char*>(p_smem) + smem_size);
+
+        constexpr auto a_lds_block_desc =
+            PipelinePolicy::template MakeALdsBlockDescriptor();
+
+        auto a_lds_block_ping =
+            make_tensor_view(p_a_lds_ping, a_lds_block_desc);
+        auto a_lds_block_pong =
+            make_tensor_view(p_a_lds_pong, a_lds_block_desc);
+
+        // A DRAM tile window for load
+        auto a_copy_dram_window =
+            make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
+                             make_tuple(number{}, number{}),
+                             a_dram_block_window_tmp.get_window_origin(),
+                             PipelinePolicy::template MakeADramTileDistribution());
+
+        auto a_copy_lds_window_ping =
+            make_tile_window(a_lds_block_ping,
+                             make_tuple(number{}, number{}),
+                             {0, 0},
+                             PipelinePolicy::template MakeADramTileDistribution());
+
+        auto a_copy_lds_window_pong =
+            make_tile_window(a_lds_block_pong,
+                             make_tuple(number{}, number{}),
+                             {0, 0},
+                             PipelinePolicy::template MakeADramTileDistribution());
+
+        // ping-pong windows for A in LDS
+        auto a_warp_window_ping_tmp =
+            make_tile_window(a_lds_block_ping,
+                             make_tuple(number{}, number{}),
+                             {iMWarp * WG::kM, 0},
+                             make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));
+
+        auto a_warp_window_pong_tmp =
+            make_tile_window(a_lds_block_pong,
+                             make_tuple(number{}, number{}),
+                             {iMWarp * WG::kM, 0},
+                             make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));
+
+        statically_indexed_array<
+            statically_indexed_array,
+            MIterPerWarp>
+            a_warp_windows_ping;
+
+        statically_indexed_array<
+            statically_indexed_array,
+            MIterPerWarp>
+            a_warp_windows_pong;
+
+        static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
+            static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
+                a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp;
+
+                move_tile_window(a_warp_windows_ping(mIter)(kIter),
+                                 {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
+            });
+        });
+
+        static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
+            static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
+                a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp;
+
+                move_tile_window(a_warp_windows_pong(mIter)(kIter),
+                                 {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
+            });
+        });
+
+        // Block GEMM
+        auto block_weight_preshuffle = BlockWeightPreshuffle();
+        // Acc register tile
+        auto c_block_tile = block_weight_preshuffle.MakeCBlockTile();
+
+        // B flat DRAM window for load
+        auto b_flat_distribution =
+            PipelinePolicy::template MakeBFlatDramTileDistribution();
+        auto b_flat_dram_window = // tile_window_with_static_distribution
+            make_tile_window(
+                b_flat_dram_block_window_tmp.get_bottom_tensor_view(), // from kernel gemm_pad_views
+                make_tuple(number{}, number{}),
+                b_flat_dram_block_window_tmp.get_window_origin(),
+                b_flat_distribution);
+
+        using BTypeToUse =
+            std::conditional_t, ADataType, BDataType>;
+        using BTileType = decltype(make_static_distributed_tensor(b_flat_distribution));
+
+        // ping-pong buffers for B
+        statically_indexed_array<
+            statically_indexed_array,
+            NIterPerWarp>
+            b_flat_dram_windows;
+
+        statically_indexed_array, NIterPerWarp>
+            b_warp_tensor_ping;
+
+        statically_indexed_array, NIterPerWarp>
+            b_warp_tensor_pong;
+
+        auto aq_copy_dram_window =
+            make_tile_window(aq_dram_block_window_tmp.get_bottom_tensor_view(),
+                             aq_dram_block_window_tmp.get_window_lengths(),
+                             aq_dram_block_window_tmp.get_window_origin(),
+                             PipelinePolicy::template MakeAQDramTileDistribution());
+        // BQ DRAM window for load
+        auto bq_copy_dram_window =
+            make_tile_window(bq_dram_block_window_tmp.get_bottom_tensor_view(),
+                             bq_dram_block_window_tmp.get_window_lengths(),
+                             bq_dram_block_window_tmp.get_window_origin(),
+                             PipelinePolicy::template MakeBQDramTileDistribution());
+
+        // Prefetch A0
+        auto a_block_tile = load_tile(a_copy_dram_window);
+        // move A window to the next K
+        move_tile_window(a_copy_dram_window, {0, kKPerBlock});
+
+        // prefetch B
+        static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
+            static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
+                b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
+
+                move_tile_window(b_flat_dram_windows(nIter)(kIter),
+                                 {nIter * flatNPerWarp, kIter * flatKPerWarp});
+
+                load_int4_tile(
+                    b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
+            });
+        });
+        // move B window to the next flat K
+        move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
+
+        // Strictly not needed given type deduction, but helps with readability
+        using AQBlockTileDistr = decltype(aq_copy_dram_window.get_tile_distribution());
+        using AQBlockTile =
+            decltype(make_static_distributed_tensor(AQBlockTileDistr{}));
+        using BQBlockTileDistr = decltype(bq_copy_dram_window.get_tile_distribution());
+        using BQBlockTile =
+            decltype(make_static_distributed_tensor(BQBlockTileDistr{}));
+
+        // Load tile 0 of the AQ/BQ data directly into registers for the block tile
+        AQBlockTile aq_block_tile, aq_block_tile_2;
+        BQBlockTile bq_block_tile, bq_block_tile_2;
+        aq_block_tile = load_tile(aq_copy_dram_window);
+        bq_block_tile = load_tile(bq_copy_dram_window);
+        // move AQ/BQ to tile 1
+        move_tile_window(aq_copy_dram_window, {0, KPerBlockAQ});
+        move_tile_window(bq_copy_dram_window, {0, KPerBlockBQ});
+        // Prefill A0
+        auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
+        store_tile(a_copy_lds_window_ping, a_block_tile_tmp);
+
+        __builtin_amdgcn_sched_barrier(0);
+
+        // Prefetch A1
+        a_block_tile = load_tile(a_copy_dram_window);
+        // move A window to the next K
+        move_tile_window(a_copy_dram_window, {0, kKPerBlock});
+
+        // initialize C
+        tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
+
+        block_sync_lds();
+
+        // preload A00, A10 from LDS
+        statically_indexed_array<decltype(load_tile(a_warp_windows_ping(number<0>{})(number<0>{}))),
+                                 m_preload>
+            a_warp_tensor;
+
+        static_for<0, m_preload, 1>{}([&](auto loadIter) {
+            constexpr auto mIter = loadIter % MIterPerWarp;
+            constexpr auto kIter = loadIter / MIterPerWarp;
+            a_warp_tensor(loadIter) =
+                load_tile(a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
+        });
+        __builtin_amdgcn_sched_barrier(0);
+
+        // MAIN LOOP
+        index_t iCounter = (num_loop - 1) / loop_count;
+
+        while(iCounter > 0)
+        {
+            __builtin_amdgcn_sched_barrier(0);
+            // Prefill A(2i+1)
+            a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
+            store_tile(a_copy_lds_window_pong, a_block_tile_tmp);
+
+            // Prefetch A(2i+2)
+            a_block_tile = load_tile(a_copy_dram_window);
+            // move A window to the next K
+            move_tile_window(a_copy_dram_window, {0, kKPerBlock});
+
+            // GEMM 2i
+            block_weight_preshuffle(c_block_tile,
+                                    a_warp_tensor,
+                                    b_warp_tensor_ping,
+                                    aq_block_tile,
+                                    bq_block_tile,
+                                    a_warp_windows_ping);
+            // prefetch B(2i+1)
+            static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
+                static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
+                    b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
+
+                    move_tile_window(b_flat_dram_windows(nIter)(kIter),
+                                     {nIter * flatNPerWarp, kIter * flatKPerWarp});
+                    load_int4_tile(
+                        b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
+                });
+            });
+            move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
+            aq_block_tile_2 = load_tile(aq_copy_dram_window);
+            move_tile_window(aq_copy_dram_window, {0, KPerBlockAQ});
+            bq_block_tile_2 = load_tile(bq_copy_dram_window);
+            move_tile_window(bq_copy_dram_window, {0, KPerBlockBQ});
+            static_for<0, m_preload, 1>{}([&](auto loadIter) {
+                constexpr auto mIter = loadIter % MIterPerWarp;
+                constexpr auto kIter = loadIter / MIterPerWarp;
+                a_warp_tensor(loadIter) =
+                    load_tile(a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
+            });
+
+            // Next K
+
+            // prefetch B(2i+2)
+            static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
+                static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
+                    b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
+
+                    move_tile_window(b_flat_dram_windows(nIter)(kIter),
+                                     {nIter * flatNPerWarp, kIter * flatKPerWarp});
+                    load_int4_tile(
+                        b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
+                });
+            });
+            move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
+            aq_block_tile = load_tile(aq_copy_dram_window);
+            move_tile_window(aq_copy_dram_window, {0, KPerBlockAQ});
+            bq_block_tile = load_tile(bq_copy_dram_window);
+            move_tile_window(bq_copy_dram_window, {0, KPerBlockBQ});
+
+            // Prefill A(2i+2)
+            a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
+            store_tile(a_copy_lds_window_ping, a_block_tile_tmp);
+
+            // Prefetch A(2i+3)
+            a_block_tile = load_tile(a_copy_dram_window);
+            // move A window to the next K
+            move_tile_window(a_copy_dram_window, {0, kKPerBlock});
+
+            // GEMM 2i+1
+            block_weight_preshuffle(c_block_tile,
+                                    a_warp_tensor,
+                                    b_warp_tensor_pong,
+                                    aq_block_tile_2,
+                                    bq_block_tile_2,
+                                    a_warp_windows_pong);
+
+            static_for<0, m_preload, 1>{}([&](auto loadIter) {
+                constexpr auto mIter = loadIter % MIterPerWarp;
+                constexpr auto kIter = loadIter / MIterPerWarp;
+                a_warp_tensor(loadIter) =
+                    load_tile(a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
+            });
+            iCounter--;
+            HotLoopScheduler();
+        }
+
+        // tail
+        if constexpr(TailNum == TailNumber::Even)
+        {
+            // prefetch B(loopK)
+            static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
+                static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
+                    b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
+
+                    move_tile_window(b_flat_dram_windows(nIter)(kIter),
+                                     {nIter * flatNPerWarp, kIter * flatKPerWarp});
+
+                    load_int4_tile(
+                        b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
+                });
+            });
+            aq_block_tile_2 = load_tile(aq_copy_dram_window);
+            bq_block_tile_2 = load_tile(bq_copy_dram_window);
+
+            // Prefill A(loopK)
+            a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
+            store_tile(a_copy_lds_window_pong, a_block_tile_tmp);
+
+            // GEMM loopK-1
+            block_weight_preshuffle(c_block_tile,
+                                    a_warp_tensor,
+                                    b_warp_tensor_ping,
+                                    aq_block_tile,
+                                    bq_block_tile,
+                                    a_warp_windows_ping);
+
+            static_for<0, m_preload, 1>{}([&](auto loadIter) {
+                constexpr auto mIter = loadIter % MIterPerWarp;
+                constexpr auto kIter = loadIter / MIterPerWarp;
+                a_warp_tensor(loadIter) =
+                    load_tile(a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
+            });
+
+            // GEMM loopK
+            block_weight_preshuffle(c_block_tile,
+                                    a_warp_tensor,
+                                    b_warp_tensor_pong,
+                                    aq_block_tile_2,
+                                    bq_block_tile_2,
+                                    a_warp_windows_pong);
+            HotLoopScheduler();
+        }
+        else if constexpr(TailNum == TailNumber::Odd)
+        {
+            // GEMM loopK
+            block_weight_preshuffle(c_block_tile,
+                                    a_warp_tensor,
+                                    b_warp_tensor_ping,
+                                    aq_block_tile,
+                                    bq_block_tile,
+                                    a_warp_windows_ping);
+            Base::LastHotLoopScheduler();
+        }
+
+        return c_block_tile;
+    }
+
+    // Convenience overload that forwards with an identity A element function.
+    template
+    CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
+                                   const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
+                                   const AQDramBlockWindowTmp& aq_dram_block_window_tmp,
+                                   const BQDramBlockWindowTmp& bq_dram_block_window_tmp,
+                                   index_t num_loop,
+                                   void* p_smem,
+                                   index_t m = 0,
+                                   index_t n = 0) const // default values for the non-preshuffle case
+    {
+        return operator()(
+            a_dram_block_window_tmp,
+            [](const ADataType& a) { return a; },
+            b_flat_dram_block_window_tmp,
+            aq_dram_block_window_tmp,
+            bq_dram_block_window_tmp,
+            m,
+            n,
+            num_loop,
+            p_smem);
+    }
+
+    template
+    CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
+                                   const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
+                                   const AQDramBlockWindowTmp& aq_dram_block_window_tmp,
+                                   const BQDramBlockWindowTmp& bq_dram_block_window_tmp,
+                                   index_t num_loop,
+                                   TailNumber tail_number,
+                                   void* p_smem,
+                                   index_t n = 0) const
+    {
+        const auto RunPipeline = [&](auto bool_val, auto tail_num_) {
+            (void)bool_val; // suppress the unused-parameter warning
+            constexpr auto tail_num = tail_num_.value;
+            return operator()(
+                a_dram_block_window_tmp,
+                [](const ADataType& a) { return a; },
+                b_flat_dram_block_window_tmp,
+                aq_dram_block_window_tmp,
+                bq_dram_block_window_tmp,
+                n, // dummy value, won't be used
+                n, // dummy value, won't be used
+                num_loop,
+                p_smem);
+        };
+        return Base::TailHandler(RunPipeline, true, tail_number);
+    }
+};
+} // namespace ck_tile
diff --git a/test/ck_tile/gemm_block_scale/CMakeLists.txt b/test/ck_tile/gemm_block_scale/CMakeLists.txt
index f89aea1c17..2dad8be205 100644
--- a/test/ck_tile/gemm_block_scale/CMakeLists.txt
+++ b/test/ck_tile/gemm_block_scale/CMakeLists.txt
@@ -39,6 +39,12 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
     )
     target_compile_options(test_tile_gemm_quant_abquant_padding PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
 
+    add_gtest_executable(test_tile_gemm_quant_abquant_preshuffle
+        test_gemm_quant_abquant_preshuffle_2d.cpp
+    )
+    target_compile_options(test_tile_gemm_quant_abquant_preshuffle PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
+
+    # AQuant tests
     add_gtest_executable(test_tile_gemm_quant_aquant_prefill
         test_gemm_quant_aquant_prefill.cpp
     )
diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_preshuffle_2d.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_preshuffle_2d.cpp
new file mode 100644
index 0000000000..793c9bd1df
--- /dev/null
+++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_preshuffle_2d.cpp
@@ -0,0 +1,44 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#include "ck_tile/host.hpp"
+#include "ck_tile/ops/gemm.hpp"
+
+#include
+#include
+
+#include "test_gemm_quant_fixtures.hpp"
+
+// Type aliases for readability
+using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
+using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
+using FP8 = ck_tile::fp8_t;
+using BF8 = ck_tile::bf8_t;
+using Half = ck_tile::half_t;
+using PkInt4 = ck_tile::pk_int4_t;
+using ABQuantGrouped =
+    std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::ABQuantGrouped>;
+using GroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
+
+// 2D block sizes for BQuant
+using GroupSize2D128N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
+
+// Type combinations for the ABQuant tests
+// Tuple format:
+// clang-format off
+using ABQuantPreshuffleBTypes = ::testing::Types<
+    // PreshuffleQuant = false && TransposeC = false (RCR layout with RowMajor AQ)
+    std::tuple,
+    std::tuple
+>;
+// clang-format on
+
+// Test suite for ABQuant
+TYPED_TEST_SUITE(TestCkTileGemmABQuant, ABQuantPreshuffleBTypes);
+
+// ABQuant tests
+TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest)
+{
+    this->run_test_with_validation(1024, 1024, 1024);
+}
diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp
index 6fb1b77fa8..3798cc4443 100644
--- a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp
+++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp
@@ -894,10 +894,10 @@ class TestCkTileGemmABQuant : public TestCkTileGemmQuantBase;
-    using BaseGemmPipeline =
-        std::conditional_t,
-                           ck_tile::BaseGemmPipelineAgBgCrCompV3>;
+    using BaseGemmPipeline = std::conditional_t<
+        PreshuffleB == true,
+        ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2,
+        ck_tile::BaseGemmPipelineAgBgCrCompV3>;
 
     const ck_tile::index_t K_split = (args.K + Base::K_Tile - 1) / Base::K_Tile * Base::K_Tile;
     const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
@@ -926,8 +926,8 @@ class TestCkTileGemmABQuant : public TestCkTileGemmQuantBase;
         using GemmPipeline =
-            std::conditional_t,
+            std::conditional_t,
                                ck_tile::ABQuantGemmPipelineAgBgCrCompV3>;
 
         using GemmEpilogue = ck_tile::CShuffleEpilogue<