diff --git a/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.cpp b/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.cpp index c03155c116..5979ca5e4d 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.cpp @@ -18,11 +18,11 @@ void abquant_quantgrouped_instance_factory( using BQuantGroupSize = ck_tile::QuantGroupShape>; using TypeConfig = decltype(GemmQuantTypeConfig{}); - return run_gemm_example_prec_type_layout, - TypeConfig, - AQuantGroupSize, - BQuantGroupSize, - ck_tile::QuantType::ABQuantGrouped>(arg_parser); + return run_gemm_example_prec_type, + TypeConfig, + AQuantGroupSize, + BQuantGroupSize, + ck_tile::QuantType::ABQuantGrouped>(arg_parser); }; lut[hash_multiple_strings({"fp8", "abquant", @@ -33,49 +33,40 @@ void abquant_quantgrouped_instance_factory( using BQuantGroupSize = ck_tile::QuantGroupShape>; using TypeConfig = decltype(GemmQuantTypeConfig{}); - return run_gemm_example_prec_type_layout, - TypeConfig, - AQuantGroupSize, - BQuantGroupSize, - ck_tile::QuantType::ABQuantGrouped>(arg_parser); + return run_gemm_example_prec_type, + TypeConfig, + AQuantGroupSize, + BQuantGroupSize, + ck_tile::QuantType::ABQuantGrouped>(arg_parser); }; lut[hash_multiple_strings({"bf8", "abquant", "non-preshuffleb", "non-preshufflequant", "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using QuantGroupSize = ck_tile::QuantGroupShape>; + using AQuantGroupSize = ck_tile::QuantGroupShape>; + using BQuantGroupSize = ck_tile::QuantGroupShape>; using TypeConfig = decltype(GemmQuantTypeConfig{}); return run_gemm_example_prec_type, TypeConfig, - QuantGroupSize, + AQuantGroupSize, + BQuantGroupSize, + ck_tile::QuantType::ABQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"bf8", + "abquant", + "non-preshuffleb", + "non-preshufflequant", + "1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using AQuantGroupSize = ck_tile::QuantGroupShape>; + using BQuantGroupSize = ck_tile::QuantGroupShape>; + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + AQuantGroupSize, + BQuantGroupSize, ck_tile::QuantType::ABQuantGrouped>(arg_parser); }; - lut[hash_multiple_strings( - {"fp8i4", "abquant", "non-preshuffleb", "non-preshufflequant", "1x1x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::ABQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings( - {"bf8i4", "abquant", "non-preshuffleb", "non-preshufflequant", "1x1x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::ABQuantGrouped>(arg_parser); - }; } diff --git a/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp b/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp index 2c7b9a5ba4..d5961455a4 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp @@ -32,7 +32,7 @@ auto create_args(int argc, char* argv[]) .insert("prec", "fp8", "Data type. For AQuant: fp8, bf8, i4fp8, or i4bf8; for Bquant: fp8, bf8, fp8i4, " - "or bf8i4; for ABQuant: fp8, bf8, i4fp8, or i4bf8") + "or bf8i4; for ABQuant: fp8, bf8") .insert("warmup", "50", "Number of iterations before benchmarking the kernel") .insert("repeat", "1000", "Number of iterations to benchmark the kernel") .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc index 482c930e7f..5b41acf891 100644 --- a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc @@ -162,11 +162,11 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str ck_tile::ABQuantGemmPipelineAgBgCrCompV3>>>; constexpr bool TiledPermuteN = - (QuantGroupSize::kN > 1) ? false : GemmConfig::TiledMMAPermuteN; + (BQuantGroupSize::kN > 1) ? false : GemmConfig::TiledMMAPermuteN; if(s.log_level_ > 0) { printf( - "TiledPermuteN: %d (QuantGroupSize::kN=%d)\n", TiledPermuteN, QuantGroupSize::kN); + "TiledPermuteN: %d (QuantGroupSize::kN=%d)\n", TiledPermuteN, BQuantGroupSize::kN); } using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem> bq_tensor_ptr = nullptr; if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped || + QuantMode == ck_tile::QuantType::ABQuantGrouped || QuantMode == ck_tile::QuantType::RowColQuant) { bq_tensor_ptr = std::make_unique>( ck_tile::host_tensor_descriptor(BQK, BQN, stride_BQ, is_row_major(bq_layout))); } - else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped) - { - bq_tensor_ptr = std::make_unique>( - ck_tile::host_tensor_descriptor(BQK, BQN, stride_BQ, is_row_major(bq_layout))); - } else if constexpr(QuantMode == ck_tile::QuantType::TensorQuant) { bq_tensor_ptr = std::make_unique>( @@ -715,7 +710,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, ck_tile::HostTensor b_k_n_dev = b_k_n; if constexpr(GemmConfig::PreshuffleB) { - if constexpr(GemmConfig::TiledMMAPermuteN && QuantGroupSize::kN == 1) + if constexpr(GemmConfig::TiledMMAPermuteN && BQuantGroupSize::kN == 1) { printf("PreshuffleB with TiledMMAPermuteN\n"); b_k_n_dev = ck_tile::shuffle_b_permuteN(b_k_n); @@ -742,10 +737,10 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, QuantMode == ck_tile::QuantType::TensorQuant) { if constexpr(GemmConfig::PreshuffleB && GemmConfig::TiledMMAPermuteN && - QuantGroupSize::kN == 1) + BQuantGroupSize::kN == 1) { ck_tile::HostTensor bq_permuted_host = - ck_tile::bq_permuteN(*bq_tensor_ptr, QuantGroupSize::kN); + ck_tile::bq_permuteN(*bq_tensor_ptr, BQuantGroupSize::kN); if constexpr(GemmConfig::PreshuffleQuant) { @@ -895,66 +890,6 @@ template -int run_gemm_example_prec_type_layout(const ck_tile::ArgParser& arg_parser) -{ - using Row = ck_tile::tensor_layout::gemm::RowMajor; - using Col = ck_tile::tensor_layout::gemm::ColumnMajor; - - if(QuantMode == ck_tile::QuantType::ABQuantGrouped && GemmConfig::PreshuffleB) - { - throw std::runtime_error("Preshuffling weight matrix is not supported for ABQuant"); - } - - if constexpr(std::is_same_v || - std::is_same_v || - std::is_same_v) - { - std::string a_layout = arg_parser.get_str("a_layout"); - std::string b_layout = arg_parser.get_str("b_layout"); - if(a_layout == "R" && b_layout == "C") - { - return run_gemm_example_with_layouts( - arg_parser, Row{}, Row{}, Col{}, Col{}, Row{}); - } - else if(a_layout == "C" && b_layout == "R") - { - return run_gemm_example_with_layouts( - arg_parser, Col{}, Row{}, Row{}, Col{}, Row{}); - } - else if(a_layout == "R" && b_layout == "R") - { - return run_gemm_example_with_layouts( - arg_parser, Row{}, Row{}, Row{}, Col{}, Row{}); - } - else - { - throw std::runtime_error("Unsupported memory layout for the input matrices!"); - } - } - else - { - throw std::runtime_error("Unsupported data type for A."); - } - - return 0; -} - -template int run_gemm_example_prec_type(const ck_tile::ArgParser& arg_parser) { using Row = ck_tile::tensor_layout::gemm::RowMajor; @@ -979,19 +914,22 @@ int run_gemm_example_prec_type(const ck_tile::ArgParser& arg_parser) { return run_gemm_example_with_layouts( arg_parser, Row{}, Row{}, Col{}, Col{}, Row{}); } - if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped && !GemmConfig::PreshuffleQuant) + if constexpr((QuantMode == ck_tile::QuantType::AQuantGrouped || + QuantMode == ck_tile::QuantType::ABQuantGrouped) && + !GemmConfig::PreshuffleQuant) { if(a_layout == "R" && b_layout == "R") { return run_gemm_example_with_layouts( arg_parser, Row{}, Row{}, Row{}, Col{}, Row{}); } @@ -999,24 +937,24 @@ int run_gemm_example_prec_type(const ck_tile::ArgParser& arg_parser) { return run_gemm_example_with_layouts( arg_parser, Col{}, Row{}, Row{}, Col{}, Row{}); } - else if(a_layout == "C" && b_layout == "C") + } + if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped && !GemmConfig::PreshuffleQuant) + { + if(a_layout == "C" && b_layout == "C") { return run_gemm_example_with_layouts( arg_parser, Col{}, Col{}, Col{}, Col{}, Row{}); } - else - { - throw std::runtime_error("Unsupported memory layout for the input matrices!"); - } } - else { throw std::runtime_error("Unsupported memory layout for the input matrices!"); @@ -1029,3 +967,16 @@ int run_gemm_example_prec_type(const ck_tile::ArgParser& arg_parser) return 0; } + +template +int run_gemm_example_prec_type(const ck_tile::ArgParser& arg_parser) +{ + return run_gemm_example_prec_type(arg_parser); +} diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp index 3de91031e0..c44d330d13 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp @@ -33,6 +33,7 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase using AQDataType = remove_cvref_t; using BDataType = remove_cvref_t; using BQDataType = remove_cvref_t; + using BQLayout = remove_cvref_t; using ComputeDataType = remove_cvref_t; using CDataType = remove_cvref_t; using BlockGemmShape = remove_cvref_t; @@ -79,13 +80,13 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase static constexpr index_t QScalesPerBlockRow = integer_divide_ceil(KPerBlock, BQuantGroupSize::kK); static constexpr index_t QScalesPerWarpGemmRow = - integer_divide_ceil(BQuantGroupSize::kK, WarpGemm::kK); + integer_divide_ceil(WarpGemm::kK, BQuantGroupSize::kK); static constexpr index_t KIterPerQScale = KIterPerWarp / QScalesPerBlockRow; static_assert(BQuantGroupSize::kK % WarpGemm::kK == 0, "Error! WarpGemm::kK should be a multiple of QuantGroupSize"); - static_assert(QScalesPerWarpGemmRow > 1, + static_assert(QScalesPerWarpGemmRow == 1, "Error! QuantGroupSize shouldn't be smaller than WarpGemm::kK"); static_assert(KIterPerWarp % QScalesPerBlockRow == 0, "Error! KItersPerWarp should be a multiple of QscalesPerBlockRow"); @@ -132,6 +133,9 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase using ComputeDataType = remove_cvref_t; using CDataType = remove_cvref_t; + // BDataType gets converted from PkInt4 during loading + using OverrideBDataType = + std::conditional_t, ADataType, BDataType>; using Base = BlockGemmQuantBase; using WarpGemm = remove_cvref_t; @@ -152,6 +156,8 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase using BWarpTensor = typename WarpGemm::BWarpTensor; using CWarpTensor = typename WarpGemm::CWarpTensor; + static constexpr bool PreshuffleQuant = Traits::PreshuffleQuant; + static_assert(std::is_same_v); static constexpr auto a_warp_y_lengths = @@ -235,7 +241,6 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase template struct BlockGemmImpl { - public: static constexpr auto ALdsTileDistr = decltype(make_static_tile_distribution(MakeABlockDistributionEncode())){}; static constexpr auto BLdsTileDistr = @@ -247,12 +252,20 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase ALdsTile a_warp_tile_; BLdsTile b_warp_tile_; - template + template CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window, - const BSmemBlockWindow& b_block_window) + const BSmemBlockWindow& b_block_window, + bool_constant = {}, + bool_constant = {}) { - load_int4_tile(a_warp_tile_, a_block_window); - load_int4_tile(b_warp_tile_, b_block_window); + load_int4_tile( + a_warp_tile_, a_block_window); + // If B datatype were pkint4 it would be converted prior to storing in LDS + load_int4_tile( + b_warp_tile_, b_block_window); } // C += A * B @@ -267,7 +280,6 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase [[maybe_unused]] ASmemBlockWindow& a_block_window, [[maybe_unused]] BSmemBlockWindow& b_block_window) { - static_assert(std::is_same_v, "The CDataType as defined in traits should be the same as corresponding " "C block tensor data type!"); @@ -303,47 +315,78 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); } }); + + constexpr auto tbuf_offset = + number{}, + c_warp_y_index_zeros)) / + CBlockTensor::PackedSize>{}; // a_scale AQPickerCommon aq_picker( aq_block_tensor); - // Multiply bquant with accumulated C - constexpr index_t reg_offset = [&]() { - if constexpr(GemmTraits::BQuantGroupSize::kN >= (NWarp * WarpGemm::kN)) - return (nIter * NWarp * WarpGemm::kN) / - GemmTraits::BQuantGroupSize::kN * Traits::KQPerBlock + - kQScale; + if constexpr(PreshuffleQuant) + { + constexpr index_t reg_offset = nIter; + auto pull_from_lane = + (__lane_id() & (WarpGemm::kN - 1)) * Traits::KQPerBlock + kQScale; + auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset]; + // cross lane ops + uint32_t scale_reg_dword; + + if constexpr(std::is_same_v) + { + scale_reg_dword = ck_tile::bit_cast(scale_reg); + } else { - return nIter * Traits::KQPerBlock + kQScale; + scale_reg_dword = static_cast(scale_reg); } - }(); - constexpr auto tbuf_offset = - number{}, - c_warp_y_index_zeros)) / - CBlockTensor::PackedSize>{}; - auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset]; - float b_scale_reg_f = - Base::cvt_scale_to_fp32(scale_reg); + // cross lane ops to get the value of scale_reg. + int gathered_scale_reg = __builtin_amdgcn_ds_bpermute( + pull_from_lane << 2, __builtin_bit_cast(int, scale_reg_dword)); - constexpr auto tbuf_offset = - number{}, - c_warp_y_index_zeros)) / - CBlockTensor::PackedSize>{}; - auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset]; - float b_scale_reg_f = - Base::cvt_scale_to_fp32(scale_reg); + float b_scale_reg_f = + Base::cvt_scale_to_fp32( + gathered_scale_reg); - static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}( - [&](auto c_row) { - float a_scale_reg_f = aq_picker.template pick(); - c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] += - (c_warp_tensor.get_thread_buffer()[c_row] * a_scale_reg_f * - b_scale_reg_f); - }); + static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}( + [&](auto c_row) { + float a_scale_reg_f = aq_picker.template pick(); + c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] += + (c_warp_tensor.get_thread_buffer()[c_row] * a_scale_reg_f * + b_scale_reg_f); + }); + } + else + { + // Multiply bquant with accumulated C + constexpr index_t reg_offset = [&]() { + if constexpr(GemmTraits::BQuantGroupSize::kN >= + (NWarp * WarpGemm::kN)) + return (nIter * NWarp * WarpGemm::kN) / + GemmTraits::BQuantGroupSize::kN * + Traits::KQPerBlock + + kQScale; + else + { + return nIter * Traits::KQPerBlock + kQScale; + } + }(); + + auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset]; + float b_scale_reg_f = + Base::cvt_scale_to_fp32(scale_reg); + + static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}( + [&](auto c_row) { + float a_scale_reg_f = aq_picker.template pick(); + c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] += + (c_warp_tensor.get_thread_buffer()[c_row] * a_scale_reg_f * + b_scale_reg_f); + }); + } }); }); }); @@ -357,11 +400,16 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase MakeCBlockTile(); } - template + template CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window, - const BSmemBlockWindow& b_block_window) + const BSmemBlockWindow& b_block_window, + bool_constant a_load_tr = {}, + bool_constant b_load_tr = {}) { - block_gemm_impl_.LocalPrefetch(a_block_window, b_block_window); + block_gemm_impl_.LocalPrefetch(a_block_window, b_block_window, a_load_tr, b_load_tr); } // C += A * B diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp index a360271a09..1f7717ed63 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp @@ -125,6 +125,9 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase using ComputeDataType = remove_cvref_t; using CDataType = remove_cvref_t; + // BDataType gets converted from PkInt4 during loading + using OverrideBDataType = + std::conditional_t, ADataType, BDataType>; using Base = BlockGemmQuantBase; using WarpGemm = remove_cvref_t; diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_v3.hpp index c13d43500e..cd70c2ca86 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_v3.hpp @@ -34,6 +34,9 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3; using AQuantGroupSize = remove_cvref_t; using BQuantGroupSize = remove_cvref_t; + // BDataType gets converted from PkInt4 during loading + using OverrideBDataType = + std::conditional_t, ADataType, BDataType>; static_assert(BQuantGroupSize::kM == 1, "only N/K blocks for BQuant kernel!"); static_assert(AQuantGroupSize::kN == 1, "only M/K blocks for AQuant kernel!"); @@ -98,6 +101,9 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3{}; + static constexpr auto is_b_load_tr_v = bool_constant{}; + using Base::PrefetchStages; [[nodiscard]] CK_TILE_HOST static const std::string GetName() @@ -110,7 +116,7 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 + CK_TILE_DEVICE static void LoadAndConvertATile(ABlockTile_& a_block_tile, + const ADramWindow& a_dram_window) + { + using DestDataType = typename ABlockTile_::DataType; + using SrcDataType = typename ADramWindow::Base::TileWindowBase::DataType; + constexpr index_t UnaryOpSize = 8; + load_int4_tile(a_block_tile, a_dram_window); + } + + template + CK_TILE_DEVICE static void LoadAndConvertBTile(BBlockTile_& b_block_tile, + const BDramWindow& b_dram_window) + { + using DestDataType = typename BBlockTile_::DataType; + using SrcDataType = typename BDramWindow::Base::TileWindowBase::DataType; + constexpr index_t UnaryOpSize = 8; + load_int4_tile(b_block_tile, b_dram_window); + } + template ; constexpr bool is_aq_col_major = std::is_same_v; - constexpr bool is_bq_col_major = - std::is_same_v; constexpr bool is_b_row_major = std::is_same_v; - - static_assert(!is_aq_col_major, "Aq must be row major (col major not supported yet)"); - static_assert(is_bq_col_major, "Bq must be col major (row major not supported yet)"); + constexpr bool is_bq_row_major = + std::is_same_v; static_assert(is_a_col_major ? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && @@ -240,13 +263,23 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(p_smem); constexpr auto a_lds_load_tile_distr = make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode()); @@ -290,20 +323,28 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3{}), 0) - : is_bq_col_major ? make_array(0, KPerBlockBQ) - : make_array(KPerBlockBQ, 0); + : is_bq_row_major ? make_array(KPerBlockBQ, 0) + : make_array(0, KPerBlockBQ); // DRAM prefetch (global read 0) - Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); - Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); + // Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); + // Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); + + LoadAndConvertATile(a_block_tile, a_copy_dram_window); + move_tile_window(a_copy_dram_window, a_dram_tile_window_step); + // B tile gets converted to A datatype during loading + LoadAndConvertBTile(b_block_tile, b_copy_dram_window); + move_tile_window(b_copy_dram_window, b_dram_tile_window_step); + Base::GlobalPrefetch( aq_block_tile[currIdx], aq_copy_dram_window, aq_dram_tile_window_step); Base::GlobalPrefetch( @@ -311,7 +352,7 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3( Policy::template MakeShuffledARegTileDistribution()); @@ -323,7 +364,7 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3( Policy::template MakeShuffledBRegTileDistribution()); @@ -335,12 +376,18 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3( Policy::template MakeShuffledARegTileDistribution()); @@ -364,9 +411,10 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3( + // Note: BDataType PkInt4 gets converted during loading earlier + auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); transpose_tile2d(b_shuffle_tmp, b_block_tile); Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); @@ -376,8 +424,16 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3( + // Note: BDataType gets converted during loading from PkInt4 + auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); transpose_tile2d(b_shuffle_tmp, b_block_tile); Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); @@ -450,7 +508,8 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3; + KPerBlockBQ, // Logical K dimension + NPerBlockBQ, // Logical N dimension + Problem::BQuantGroupSize::kN, + BQLayout>; return TileEncodingPattern::make_2d_static_tile_distribution(); } diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp index ceb8e0e917..d8458c0b39 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp @@ -52,7 +52,7 @@ struct GemmQuantPipelineProblemBase : public GemmPipelineProblemBase