From 375e499d10ecb72261e2c779fd973a14006b6c44 Mon Sep 17 00:00:00 2001 From: "assistant-librarian[bot]" Date: Mon, 8 Dec 2025 21:13:22 +0000 Subject: [PATCH] Merge commit 'c363a98d4154c647c1a2d5331ad0d76879b84dfa' into develop --- .../17_grouped_gemm/quant_grouped_gemm.cpp | 225 ++++++++++-- .../17_grouped_gemm/quant_grouped_gemm.hpp | 75 +++- .../quant_run_grouped_gemm_example.inc | 233 ++++++++---- .../block_universal_gemm_as_bs_bquant_cr.hpp | 32 +- .../gemm_quant/kernel/gemm_quant_kernel.hpp | 66 +++- .../kernel/grouped_gemm_quant_kernel.hpp | 109 +++++- .../gemm_aquant_pipeline_ag_bg_cr_v3.hpp | 49 ++- .../gemm_bquant_pipeline_ag_bg_cr_base.hpp | 14 +- .../gemm_bquant_pipeline_ag_bg_cr_policy.hpp | 22 +- .../gemm_bquant_pipeline_ag_bg_cr_v3.hpp | 85 +++-- .../pipeline/gemm_group_quant_utils.hpp | 151 +++++--- .../gemm_block_scale/test_gemm_quant_base.hpp | 4 +- .../test_gemm_quant_bquant.cpp | 76 ++-- .../test_gemm_quant_bquant_preshuffle.cpp | 90 ++--- .../test_gemm_quant_fixtures.hpp | 8 +- .../ck_tile/grouped_gemm_quant/CMakeLists.txt | 3 + .../test_grouped_gemm_quant.cpp | 51 +-- .../test_grouped_gemm_quant_aquant.cpp | 38 ++ .../test_grouped_gemm_quant_bquant.cpp | 11 +- .../test_grouped_gemm_quant_rowcol.cpp | 13 +- .../test_grouped_gemm_quant_tensor.cpp | 13 +- .../test_grouped_gemm_util_quant.hpp | 334 +++++++++++++++--- 22 files changed, 1307 insertions(+), 395 deletions(-) create mode 100644 test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_aquant.cpp diff --git a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp index d8b905fe3d..d3b75ac72f 100644 --- a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp +++ b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp @@ -9,14 +9,190 @@ #include #include #include +#include #include "ck_tile/core.hpp" #include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/gemm.hpp" +#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp" #include "ck_tile/ops/gemm_quant.hpp" #include "ck_tile/host.hpp" #include "quant_grouped_gemm.hpp" +template +float grouped_gemm(const std::vector& gemm_descs, + const ck_tile::stream_config& s, + void* kargs_ptr) +{ + constexpr ck_tile::index_t TileParitionerGroupNum = 8; + constexpr ck_tile::index_t TileParitionerM01 = 4; + + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile:: + sequence>; + using TilePartitioner = ck_tile:: + GemmSpatiallyLocalTilePartitioner; + + using Traits = ck_tile::TileGemmTraits; + using GemmUniversalTraits = ck_tile::TileGemmQuantTraits; + using GemmPipelineProblem = + ck_tile::GemmPipelineProblem; + + using BaseGemmPipeline = + GemmQuantConfig::template BaseGemmPipeline; + + const ck_tile::index_t k_grain = gemm_descs[0].k_batch * GemmConfig::K_Tile; + const ck_tile::index_t K_split = (gemm_descs[0].K + k_grain - 1) / k_grain * GemmConfig::K_Tile; + + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + float ave_time{0}; + + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = GemmConfig::Scheduler; + constexpr auto memory_operation = ck_tile::memory_operation_enum::set; + + constexpr bool UseGroupedQuant = QuantMode == ck_tile::QuantType::AQuantGrouped || + QuantMode == ck_tile::QuantType::BQuantGrouped; + using QuantGemmProblem = std::conditional_t< + UseGroupedQuant, + std::conditional_t, + ck_tile::GemmBQuantPipelineProblem>, + ck_tile::GemmRowColTensorQuantPipelineProblem>; + + using GemmPipeline = + GemmQuantConfig::template GemmPipeline; + + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem, + AccDataType, + CDataType, + ck_tile::tuple<>, + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + GemmConfig::M_Warp, + GemmConfig::N_Warp, + GemmConfig::M_Warp_Tile, + GemmConfig::N_Warp_Tile, + GemmConfig::K_Warp_Tile, + QuantGemmProblem::TransposeC, + memory_operation>>; + + using Kernel = ck_tile::QuantGroupedGemmKernel; + auto kargs = Kernel::MakeKargs(gemm_descs); + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Kernel arguments not supported!"); + } + + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(gemm_descs); + + HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, + kargs.data(), + get_workspace_size(gemm_descs), + hipMemcpyHostToDevice, + s.stream_id_)); + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; + } + + return ave_time = ck_tile::launch_kernel( + s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + gemm_descs.size())); + }; + + return ave_time = BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num); +} + template ; // Persistence + GemmConfig::Persistent>; float ave_time{0}; const auto Run = [&](const auto memory_operation_) { constexpr auto scheduler = GemmConfig::Scheduler; constexpr auto memory_operation = memory_operation_.value; - constexpr bool transpose_c = false; - using QuantGemmProblem = typename std::conditional< - QuantMode == ck_tile::QuantType::BQuantGrouped, - ck_tile::GemmBQuantPipelineProblem, + constexpr bool UseGroupedQuant = QuantMode == ck_tile::QuantType::AQuantGrouped || + QuantMode == ck_tile::QuantType::BQuantGrouped; + + using QuantGemmProblem = std::conditional_t< + UseGroupedQuant, + std::conditional_t, + ck_tile::GemmBQuantPipelineProblem>, ck_tile::GemmRowColTensorQuantPipelineProblem>::type; + scheduler>>; - using GemmPipeline = std::conditional_t< - QuantMode == ck_tile::QuantType::RowColQuant || - QuantMode == ck_tile::QuantType::TensorQuant, - ck_tile::GemmPipelineAgBgCrCompV3, - std::conditional_t, - ck_tile::BQuantGemmPipelineAgBgCrCompV3>>; + using GemmPipeline = + GemmQuantConfig::template GemmPipeline; using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem(argc, argv); + int result1 = run_grouped_gemm_example(argc, argv); return result1; } diff --git a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.hpp b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.hpp index ede683abe6..0317685770 100644 --- a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.hpp +++ b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.hpp @@ -64,6 +64,7 @@ struct GemmTypeConfig using CDataType = ck_tile::half_t; }; +template struct GemmConfigBase { static constexpr bool kPadM = false; @@ -83,10 +84,11 @@ struct GemmConfigBase static constexpr ck_tile::index_t NumWaveGroups = 1; static constexpr bool DoubleSmemBuffer = false; static constexpr bool PreshuffleB = false; + static constexpr bool Persistent = Persistent_; }; -template -struct GemmConfigComputeV3_2 : public GemmConfigBase +template +struct GemmConfigComputeV3_2 : public GemmConfigBase { static constexpr ck_tile::index_t M_Tile = 128; static constexpr ck_tile::index_t N_Tile = 128; @@ -101,8 +103,8 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); }; -template -struct GemmConfigPreshuffleB_Bquant_prefill : public GemmConfigBase +template +struct GemmConfigPreshuffleB_Bquant_prefill : public GemmConfigBase { static constexpr ck_tile::index_t M_Tile = 128; static constexpr ck_tile::index_t N_Tile = 128; @@ -121,6 +123,66 @@ struct GemmConfigPreshuffleB_Bquant_prefill : public GemmConfigBase static constexpr bool DoubleSmemBuffer = true; }; +template +struct GemmQuantConfig; + +template <> +struct GemmQuantConfig +{ + template + using GemmConfig = GemmConfigComputeV3_2; + + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; + + template + using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3; +}; + +template <> +struct GemmQuantConfig +{ + template + using GemmConfig = GemmConfigComputeV3_2; + + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; + + template + using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3; +}; + +template <> +struct GemmQuantConfig +{ + template + using GemmConfig = GemmConfigComputeV3_2; + + template + using GemmPipeline = ck_tile::AQuantGemmPipelineAgBgCrCompV3; + + template + using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3; +}; + +template <> +struct GemmQuantConfig +{ + template + using GemmConfig = GemmConfigPreshuffleB_Bquant_prefill; + + template + using GemmPipeline = std::conditional_t, + ck_tile::BQuantGemmPipelineAgBgCrCompV3>; + + template + using BaseGemmPipeline = + std::conditional_t, + ck_tile::BaseGemmPipelineAgBgCrCompV3>; +}; + using grouped_gemm_kargs = ck_tile::QuantGroupedGemmHostArgs; auto create_args(int argc, char* argv[]) @@ -148,8 +210,9 @@ auto create_args(int argc, char* argv[]) .insert("repeat", "100", "number of iterations to benchmark the kernel.") .insert("group_count", "8", "group count.") .insert("kbatch", "1", "kbatch for SplitK") - .insert("quant_mode", "bquant", "Choose bquant (default), tensor, or rowcol") - .insert("init", "0", "0. Random, 2. One(s) (Constant)"); + .insert("quant_mode", "bquant", "Choose aquant, bquant (default), tensor, or rowcol") + .insert("init", "0", "0. Random, 2. One(s) (Constant)") + .insert("persistent", "0", "Kernel persistency. 0: non-persistent. 1: persistent."); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); diff --git a/example/ck_tile/17_grouped_gemm/quant_run_grouped_gemm_example.inc b/example/ck_tile/17_grouped_gemm/quant_run_grouped_gemm_example.inc index 37fab44f77..37832b54ba 100644 --- a/example/ck_tile/17_grouped_gemm/quant_run_grouped_gemm_example.inc +++ b/example/ck_tile/17_grouped_gemm/quant_run_grouped_gemm_example.inc @@ -57,56 +57,83 @@ float invoke_gemm(int n_warmup, float ave_time = 0; - // NOTE: With the persistent TileLoop kernel, we do not necessarily need to have - // the gemm problems known on the host. Instead, we can just pass the pointer - // to the kernel and let the workgroups figure out which tiles to work on. - // This is useful when the gemm problems are generated dynamically. - // In this example however, we generate the `kargs` using the known gemm_descs, - // and copy the gemm descriptions to the device memory. - // The contents of the memory pointed to by `kargs_ptr` pointer could be - // written by e.g. another kernel from earlier stage. - std::vector kargs; - void* kargs_ptr = gemm_workspace.GetDeviceBuffer(); - assert(args[0].k_batch == 1); - for(const auto& arg : args) + if constexpr(!GemmConfig::Persistent) { - kargs.emplace_back(ck_tile::QuantGroupedGemmKernelArgs{arg.a_ptr, - arg.b_ptr, - arg.aq_ptr, - arg.bq_ptr, - arg.e_ptr, - arg.M, - arg.N, - arg.K, - arg.QK_A, - arg.QK_B, - arg.stride_A, - arg.stride_B, - arg.stride_E, - arg.stride_AQ, - arg.stride_BQ, - arg.k_batch}); + ave_time = + grouped_gemm(args, + ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}, + gemm_workspace.GetDeviceBuffer()); + } + else + { + // NOTE: With the persistent TileLoop kernel, we do not necessarily need to have + // the gemm problems known on the host. Instead, we can just pass the pointer + // to the kernel and let the workgroups figure out which tiles to work on. + // This is useful when the gemm problems are generated dynamically. + // In this example however, we generate the `kargs` using the known gemm_descs, + // and copy the gemm descriptions to the device memory. + // The contents of the memory pointed to by `kargs_ptr` pointer could be + // written by e.g. another kernel from earlier stage. + std::vector kargs; + void* kargs_ptr = gemm_workspace.GetDeviceBuffer(); + if(args[0].k_batch != 1) + { + throw std::runtime_error("Split-K not supported yet for persistent kernel"); + } + + for(const auto& arg : args) + { + kargs.emplace_back(ck_tile::QuantGroupedGemmKernelArgs{arg.a_ptr, + arg.b_ptr, + arg.aq_ptr, + arg.bq_ptr, + arg.e_ptr, + arg.M, + arg.N, + arg.K, + arg.QK_A, + arg.QK_B, + arg.stride_A, + arg.stride_B, + arg.stride_E, + arg.stride_AQ, + arg.stride_BQ, + arg.k_batch}); + } + const auto stream = ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}; + HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, + kargs.data(), + kargs.size() * sizeof(ck_tile::QuantGemmTransKernelArg), + hipMemcpyHostToDevice, + stream.stream_id_)); + ave_time = grouped_gemm_tileloop(stream, group_count, kargs_ptr); } - const auto stream = ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}; - HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, - kargs.data(), - kargs.size() * sizeof(ck_tile::QuantGemmTransKernelArg), - hipMemcpyHostToDevice, - stream.stream_id_)); - ave_time = grouped_gemm_tileloop(stream, group_count, kargs_ptr); std::string op_name = "Quant Grouped Gemm (" + ck_tile::quant_type_to_string(QuantMode) + ")"; @@ -259,13 +286,24 @@ int run_grouped_gemm_example_with_layouts(int argc, AQK = 1; // Row quantization: tensor shape [M, 1] or [1] BQK = 1; // Column quantization: tensor shape [1, N] or [1] } + else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped) + { + AQK = K / QuantGroupSize::kK; // Group quantization: AQK = K / GroupSize + BQK = 0; // No B quantization + if(K % QuantGroupSize::kK != 0) + { + throw std::runtime_error( + "K must be divisible by QuantGroupSize::kK for AQuantGrouped mode"); + } + } else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) { AQK = 0; // No A quantization BQK = K / QuantGroupSize::kK; // Group quantization: BQK = K / GroupSize if(K % QuantGroupSize::kK != 0) { - throw std::runtime_error("K must be divisible by 128 for BQuantGrouped mode"); + throw std::runtime_error( + "K must be divisible by QuantGroupSize::kK for BQuantGrouped mode"); } } @@ -284,6 +322,12 @@ int run_grouped_gemm_example_with_layouts(int argc, stride_AQs[i] = 1; // Tensor quantization: tensor shape [1] stride_BQs[i] = 1; // Tensor quantization: tensor shape [1] } + else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped) + { + stride_AQs[i] = + ck_tile::get_default_stride(M, AQK, stride_AQs[i], is_row_major(aq_layout)); + stride_BQs[i] = 0; // No B quantization + } else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) { stride_AQs[i] = 0; // No A quantization @@ -311,10 +355,17 @@ int run_grouped_gemm_example_with_layouts(int argc, bq_tensors.push_back(ck_tile::HostTensor( ck_tile::host_tensor_descriptor(1, 1, stride_BQs[i], is_row_major(bq_layout)))); } + else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped) + { + aq_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(M, AQK, stride_AQs[i], is_row_major(aq_layout)))); + bq_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(0, 0, stride_BQs[i], is_row_major(bq_layout)))); + } else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) { aq_tensors.push_back(ck_tile::HostTensor( - ck_tile::host_tensor_descriptor(0, AQK, stride_AQs[i], is_row_major(aq_layout)))); + ck_tile::host_tensor_descriptor(0, 0, stride_AQs[i], is_row_major(aq_layout)))); bq_tensors.push_back(ck_tile::HostTensor( ck_tile::host_tensor_descriptor(BQK, N, stride_BQs[i], is_row_major(bq_layout)))); } @@ -444,7 +495,7 @@ int run_grouped_gemm_example_with_layouts(int argc, bq_tensors[i], c_m_n_host_ref); } - else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) + else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped) { ck_tile::reference_gemm_quant( + a_m_k_tensors[i], aq_tensors[i], b_k_n_tensors[i], c_m_n_host_ref); + } + else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) + { + ck_tile::reference_gemm_quant( a_m_k_tensors[i], bq_tensors[i], b_k_n_tensors[i], c_m_n_host_ref); } @@ -477,7 +539,7 @@ int run_grouped_gemm_example_with_layouts(int argc, return pass; } -template +template int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[]) { using Row = ck_tile::tensor_layout::gemm::RowMajor; @@ -494,6 +556,7 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a if(a_layout == "R" && b_layout == "C") { + return run_grouped_gemm_example_with_layouts typename GemmConfig> +template +int run_gemm_example_persistency( + std::string a_layout, std::string b_layout, bool persistent, int argc, char* argv[]) +{ + if(persistent) + { + using GemmConfig = GemmQuantConfig::template GemmConfig; + return run_gemm_example_prec_type( + a_layout, b_layout, argc, argv); + } + else + { + using GemmConfig = GemmQuantConfig::template GemmConfig; + return run_gemm_example_prec_type( + a_layout, b_layout, argc, argv); + } +} + int run_grouped_gemm_example(int argc, char* argv[]) { auto [result, arg_parser] = create_args(argc, argv); @@ -524,29 +604,29 @@ int run_grouped_gemm_example(int argc, char* argv[]) const std::string b_layout = arg_parser.get_str("b_layout"); const std::string data_type = arg_parser.get_str("prec"); std::string quant_mode = arg_parser.get_str("quant_mode"); + bool persistent = arg_parser.get_bool("persistent"); if(data_type == "fp8") { if(quant_mode == "tensor") { - return run_gemm_example_prec_type, - ck_tile::fp8_t, - ck_tile::QuantType::TensorQuant>( - a_layout, b_layout, argc, argv); + return run_gemm_example_persistency( + a_layout, b_layout, persistent, argc, argv); } else if(quant_mode == "rowcol") { - return run_gemm_example_prec_type, - ck_tile::fp8_t, - ck_tile::QuantType::RowColQuant>( - a_layout, b_layout, argc, argv); + return run_gemm_example_persistency( + a_layout, b_layout, persistent, argc, argv); + } + else if(quant_mode == "aquant") + { + return run_gemm_example_persistency( + a_layout, b_layout, persistent, argc, argv); } else if(quant_mode == "bquant") { - return run_gemm_example_prec_type, - ck_tile::fp8_t, - ck_tile::QuantType::BQuantGrouped>( - a_layout, b_layout, argc, argv); + return run_gemm_example_persistency( + a_layout, b_layout, persistent, argc, argv); } else { @@ -557,24 +637,23 @@ int run_grouped_gemm_example(int argc, char* argv[]) { if(quant_mode == "tensor") { - return run_gemm_example_prec_type, - ck_tile::bf8_t, - ck_tile::QuantType::TensorQuant>( - a_layout, b_layout, argc, argv); + return run_gemm_example_persistency( + a_layout, b_layout, persistent, argc, argv); } else if(quant_mode == "rowcol") { - return run_gemm_example_prec_type, - ck_tile::bf8_t, - ck_tile::QuantType::RowColQuant>( - a_layout, b_layout, argc, argv); + return run_gemm_example_persistency( + a_layout, b_layout, persistent, argc, argv); + } + else if(quant_mode == "aquant") + { + return run_gemm_example_persistency( + a_layout, b_layout, persistent, argc, argv); } else if(quant_mode == "bquant") { - return run_gemm_example_prec_type, - ck_tile::bf8_t, - ck_tile::QuantType::BQuantGrouped>( - a_layout, b_layout, argc, argv); + return run_gemm_example_persistency( + a_layout, b_layout, persistent, argc, argv); } else { diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp index d97145cbc3..628e9194ae 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp @@ -61,6 +61,7 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase using ADataType = remove_cvref_t; using BDataType = remove_cvref_t; using BQDataType = remove_cvref_t; + using BQLayout = remove_cvref_t; using ComputeDataType = remove_cvref_t; using CDataType = remove_cvref_t; using BlockGemmShape = remove_cvref_t; @@ -154,6 +155,10 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase using ComputeDataType = remove_cvref_t; using CDataType = remove_cvref_t; + // BDataType gets converted from PkInt4 during loading + using OverrideBDataType = + std::conditional_t, ADataType, BDataType>; + using Base = BlockGemmBQuantBase; using WarpGemm = remove_cvref_t; @@ -271,12 +276,20 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase ALdsTile a_warp_tile_; BLdsTile b_warp_tile_; - template + template CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window, - const BSmemBlockWindow& b_block_window) + const BSmemBlockWindow& b_block_window, + bool_constant = {}, + bool_constant = {}) { - load_int4_tile(a_warp_tile_, a_block_window); - load_int4_tile(b_warp_tile_, b_block_window); + load_int4_tile( + a_warp_tile_, a_block_window); + // If B datatype were pkint4 it would be converted prior to storing in LDS + load_int4_tile( + b_warp_tile_, b_block_window); } // C += A * B @@ -397,11 +410,16 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase MakeCBlockTile(); } - template + template CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window, - const BSmemBlockWindow& b_block_window) + const BSmemBlockWindow& b_block_window, + bool_constant a_load_tr = {}, + bool_constant b_load_tr = {}) { - block_gemm_impl_.LocalPrefetch(a_block_window, b_block_window); + block_gemm_impl_.LocalPrefetch(a_block_window, b_block_window, a_load_tr, b_load_tr); } // C += A * B diff --git a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp index dd85705cf2..203b79aec6 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp @@ -426,7 +426,6 @@ struct QuantGemmKernel if constexpr(kQuantType == QuantType::BQuantGrouped) { - static_assert(std::is_same_v); if(kargs.QK_B % GemmPipeline::GetVectorSizeBQ() != 0) { if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) @@ -781,7 +780,9 @@ struct QuantGemmKernel { if constexpr(PreshuffleQuant) { - static_assert(std::is_same_v); + static_assert(std::is_same_v, + "PreshuffleQuant with BQuantGrouped currently only supports " + "ColumnMajor BQ layout"); return MakePreshuffledQuantTensorView< GemmPipeline::KPerBlockBQ, @@ -791,14 +792,35 @@ struct QuantGemmKernel } else { - static_assert(std::is_same_v); using QuantGroupSize = remove_cvref_t; - return make_naive_tensor_view( - bq_ptr, - make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN), kargs.QK_B), - make_tuple(kargs.stride_BQ, 1), - number{}, - number<1>{}); + + if constexpr(std::is_same_v) + { + // For RowMajor BQ: memory layout is [K/QuantGroupK][N/QuantGroupN] + // Dimensions: [K/QuantGroupK, N/QuantGroupN] + // Strides: [N/QuantGroupN, 1] + return make_naive_tensor_view( + bq_ptr, + make_tuple(integer_divide_ceil(kargs.K, QuantGroupSize::kK), + integer_divide_ceil(kargs.N, QuantGroupSize::kN)), + make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN), 1), + number{}, + number<1>{}); + } + else + { + static_assert(std::is_same_v); + // For ColumnMajor BQ: memory layout is [N/QuantGroupN][K/QuantGroupK] + // Dimensions: [N/QuantGroupN, K/QuantGroupK] + // Strides: [K/QuantGroupK, 1] + return make_naive_tensor_view( + bq_ptr, + make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN), + integer_divide_ceil(kargs.K, QuantGroupSize::kK)), + make_tuple(integer_divide_ceil(kargs.K, QuantGroupSize::kK), 1), + number{}, + number<1>{}); + } } } else @@ -1023,10 +1045,10 @@ struct QuantGemmKernel } else if constexpr(kQuantType == QuantType::BQuantGrouped) { + using QuantGroupSize = remove_cvref_t; if constexpr(PreshuffleQuant) { static_assert(std::is_same_v); - using QuantGroupSize = remove_cvref_t; constexpr auto block_n = TilePartitioner::NPerBlock / QuantGroupSize::kN; constexpr auto warp_n = TilePartitioner::BlockGemmShape::WarpTile::at(I1); constexpr auto bqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK; @@ -1042,13 +1064,23 @@ struct QuantGemmKernel } else { - static_assert(std::is_same_v); - using QuantGroupSize = remove_cvref_t; - return make_tile_window( - bq_pad_view, - make_tuple(number{}, - number{}), - {i_n / QuantGroupSize::kN, 0}); + if constexpr(std::is_same_v) + { + return make_tile_window( + bq_pad_view, + make_tuple(number{}, + number{}), + {0, i_n / QuantGroupSize::kN}); + } + else + { + static_assert(std::is_same_v); + return make_tile_window( + bq_pad_view, + make_tuple(number{}, + number{}), + {i_n / QuantGroupSize::kN, 0}); + } } } else diff --git a/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp index caa6aad363..726f678d37 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp @@ -163,7 +163,6 @@ struct QuantGroupedGemmKernel static constexpr index_t kBlockSize = GemmPipeline::BlockSize; static constexpr bool UsePersistentKernel = GemmPipeline::UsePersistentKernel; - static_assert(UsePersistentKernel == true, "UsePersistentKernel must be true"); [[nodiscard]] CK_TILE_HOST static const std::string GetName() { @@ -262,10 +261,9 @@ struct QuantGroupedGemmKernel auto karg = QuantGroupedGemmKernelArgs{type_convert(gemm_descs[i].a_ptr), type_convert(gemm_descs[i].b_ptr), - type_convert(gemm_descs[i].e_ptr), type_convert(gemm_descs[i].aq_ptr), type_convert(gemm_descs[i].bq_ptr), - gemm_descs[i].k_batch, + type_convert(gemm_descs[i].e_ptr), M, N, K, @@ -275,7 +273,8 @@ struct QuantGroupedGemmKernel stride_b, stride_e, gemm_descs[i].stride_AQ, - gemm_descs[i].stride_BQ}; + gemm_descs[i].stride_BQ, + gemm_descs[i].k_batch}; gemm_kernel_args_.emplace_back(std::move(karg), block_start, block_end); } @@ -342,16 +341,32 @@ struct QuantGroupedGemmKernel else { - RunGemmWithPipelineSelection(a_ptr, - b_ptr, - aq_ptr, - bq_ptr, - c_ptr, - smem_ptr_0, - kargs, - splitk_batch_offset, - i_m, - i_n); + if constexpr(UsePersistentKernel) + { + RunGemmWithPipelineSelection(a_ptr, + b_ptr, + aq_ptr, + bq_ptr, + c_ptr, + smem_ptr_0, + kargs, + splitk_batch_offset, + i_m, + i_n); + } + else // Non-persistent kernel + { + Base::RunGemm({a_ptr}, + {b_ptr}, + aq_ptr, + bq_ptr, + c_ptr, + smem_ptr_0, + kargs, + splitk_batch_offset, + i_m, + i_n); + } } } @@ -451,7 +466,24 @@ struct QuantGroupedGemmKernel const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop); const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop); - if constexpr(kQuantType == QuantType::BQuantGrouped) + if constexpr(kQuantType == QuantType::AQuantGrouped) + { + const auto& aq_block_window = gemm_tile_windows.at(Base::I1); + // Run GEMM pipeline + const auto& c_block_tile = GemmPipeline{}.template operator()(a_block_window, + b_block_window, + aq_block_window, + num_loop, + has_hot_loop, + tail_num, + smem_ptr_0); + + auto& c_block_window = gemm_tile_windows.at(Base::I4); + + // Run Epilogue Pipeline + EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0); + } + else if constexpr(kQuantType == QuantType::BQuantGrouped) { const auto& bq_block_window = gemm_tile_windows.at(Base::I3); // Run GEMM pipeline @@ -496,6 +528,53 @@ struct QuantGroupedGemmKernel } } + CK_TILE_DEVICE index_t FindGroupId(const QuantGemmTransKernelArg* gemm_desc_ptr, + index_t block_id, + index_t group_count) const + { + index_t left = 0; + index_t right = group_count; + index_t group_id = index_t((left + right) >> 1); + + while((!(block_id >= gemm_desc_ptr[group_id].block_start && + block_id < gemm_desc_ptr[group_id].block_end)) && + left <= right) + { + if(block_id < gemm_desc_ptr[group_id].block_start) + { + right = group_id; + } + else + { + left = group_id; + } + group_id = index_t((left + right) >> 1); + } + + return group_id; + } + + // For non-persistent kernels + template > + CK_TILE_DEVICE void operator()(const void CK_TILE_CONSTANT_ADDRESS_SPACE* gemm_descs_const, + index_t group_count) const + { + const index_t block_id = ck_tile::get_block_1d_id(); + const auto gemm_desc_ptr = reinterpret_cast( + cast_pointer_to_generic_address_space(gemm_descs_const)); + + const index_t group_id = FindGroupId(gemm_desc_ptr, block_id, group_count); + const auto& kargs = gemm_desc_ptr[group_id]; + + const auto grid_size_2d = TilePartitioner::GridSize(kargs.group_karg.M, kargs.group_karg.N); + const auto block_idx_2d = OffsetTile1DPartitioner::GetOffsetedTileIndex( + 0, + kargs.group_karg.M, + kargs.group_karg.N, + (block_id - kargs.block_start) % grid_size_2d); + Run(kargs.group_karg, block_idx_2d, (block_id - kargs.block_start) / grid_size_2d); + } + // For persistent kernels template , diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp index 30b9d70eb8..e7bd4a2626 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp @@ -319,6 +319,8 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + const AQDramBlockWindowTmp& aq_dram_block_window_tmp, + index_t num_loop, + bool has_hot_loop, + TailNumber tail_number, + void* p_smem, + index_t m = 0) const + { + const auto RunPipeline = [&](auto has_hot_loop_, auto tail_number_) { + constexpr bool hot_loop = has_hot_loop_.value; + constexpr auto tail_num = tail_number_.value; + return PipelineImpl{}.template operator()( + a_dram_block_window_tmp, + [](const ADataType& a) { return a; }, + b_dram_block_window_tmp, + [](const BDataType& b) { return b; }, + aq_dram_block_window_tmp, + m, // dummy value, won't be used + num_loop, + p_smem); + }; + return Base::TailHandler(RunPipeline, has_hot_loop, tail_number); + } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp index 4cd343e640..c570d4a131 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp @@ -42,14 +42,18 @@ struct GemmBQuantPipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase); - - using YPerTile = number; - using XPerTile = number; + using YPerTile = + std::conditional_t, + number, + number>; + using XPerTile = + std::conditional_t, + number, + number>; auto bq_copy_dram_window = make_tile_window(bq_dram_block_window_tmp.get_bottom_tensor_view(), - make_tuple(YPerTile(), XPerTile()), + make_tuple(YPerTile{}, XPerTile{}), bq_dram_block_window_tmp.get_window_origin(), Policy::template MakeBQDramTileDistribution()); return bq_copy_dram_window; diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp index 870326cb9d..154d068f0a 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp @@ -25,8 +25,16 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; constexpr index_t KPerBlockBQ = KPerBlock / Problem::QuantGroupSize::kK; - static_assert(std::is_same_v); - return GetABQGlobalVectorLoadSize(); + // Support both RowMajor and ColumnMajor layouts for BQ + if constexpr(std::is_same_v) + { + return GetABQGlobalVectorLoadSize(); + } + else + { + static_assert(std::is_same_v); + return GetABQGlobalVectorLoadSize(); + } } template @@ -52,7 +60,6 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC WarpTile::at(I2), Problem::TransposeC>; - static_assert(std::is_same_v); if constexpr(PreshuffleQuant) { using TileEncodingPattern = tile_distribution_encoding_pattern_bq< @@ -62,18 +69,21 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC NPerBlock / WarpGemm::kN, ck_tile::integer_least_multiple(WarpGemm::kN * KPerBlockBQ, get_warp_size()), VecLoadSize, + BQLayout, PreshuffleQuant>; return TileEncodingPattern::make_2d_static_tile_distribution(); } else { + // KPerTile and NPerTile are LOGICAL dimensions (K quant groups and N quant groups) using TileEncodingPattern = tile_distribution_encoding_pattern_bq; + KPerBlockBQ, // Logical K dimension + NPerBlockBQ, // Logical N dimension + Problem::QuantGroupSize::kN, + BQLayout>; return TileEncodingPattern::make_2d_static_tile_distribution(); } diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp index 4883a30f57..2c191cc2b4 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp @@ -33,6 +33,10 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3; using QuantGroupSize = remove_cvref_t; + // BDataType gets converted from PkInt4 during loading + using OverrideBDataType = + std::conditional_t, ADataType, BDataType>; + static_assert(QuantGroupSize::kM == 1, "only N/K blocks for BQuant kernel!"); using I0 = number<0>; using I1 = number<1>; @@ -83,6 +87,9 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3{}; + static constexpr auto is_b_load_tr_v = bool_constant{}; + using Base::PrefetchStages; [[nodiscard]] CK_TILE_HOST static const std::string GetName() @@ -125,7 +132,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 + CK_TILE_DEVICE static void LoadAndConvertBTile(BBlockTile_& b_block_tile, + const BDramWindow& b_dram_window) + { + using DestDataType = typename BBlockTile_::DataType; + using SrcDataType = typename BDramWindow::Base::TileWindowBase::DataType; + constexpr index_t UnaryOpSize = 8; + load_int4_tile(b_block_tile, b_dram_window); + } + template ; - constexpr bool is_bq_col_major = - std::is_same_v; constexpr bool is_b_row_major = std::is_same_v; - - static_assert(is_bq_col_major, "Bq must be col major (row major not supported yet)"); + constexpr bool is_bq_row_major = + std::is_same_v; static_assert(is_a_col_major ? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && @@ -212,12 +227,22 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(p_smem); constexpr auto a_lds_load_tile_distr = make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode()); @@ -237,7 +262,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(ABlockTileDistr{})); using BBlockTile = - decltype(make_static_distributed_tensor(BBlockTileDistr{})); + decltype(make_static_distributed_tensor(BBlockTileDistr{})); using BQBlockTile = decltype(make_static_distributed_tensor(BQBlockTileDistr{})); @@ -258,18 +283,20 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3{}), 0) - : is_bq_col_major ? make_array(0, KPerBlockBQ) - : make_array(KPerBlockBQ, 0); + : is_bq_row_major ? make_array(KPerBlockBQ, 0) + : make_array(0, KPerBlockBQ); // DRAM prefetch (global read 0) Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); - Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); + // B tile gets converted to A datatype during loading + LoadAndConvertBTile(b_block_tile, b_copy_dram_window); + move_tile_window(b_copy_dram_window, b_dram_tile_window_step); Base::GlobalPrefetch( bq_block_tile[currIdx], bq_copy_dram_window, bq_dram_tile_window_step); tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); - if constexpr(is_a_col_major) + if constexpr(is_a_col_major && !is_a_load_tr_v()) { auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); @@ -281,9 +308,10 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3( + // B datatype is converted to A datatype during loading + auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); transpose_tile2d(b_shuffle_tmp, b_block_tile); Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); @@ -294,11 +322,13 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3( Policy::template MakeShuffledARegTileDistribution()); @@ -322,9 +352,10 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3( + // Note: BDataType PkInt4 gets converted during loading earlier + auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); transpose_tile2d(b_shuffle_tmp, b_block_tile); Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); @@ -335,7 +366,8 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3( + // Note: BDataType gets converted during loading from PkInt4 + auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); transpose_tile2d(b_shuffle_tmp, b_block_tile); Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); @@ -393,7 +427,8 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding_pattern { @@ -210,36 +211,41 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding /// @brief Creates a 2D tile distribution for BQ (B-matrix quantization scales) /// /// This function determines the optimal thread distribution pattern for loading and applying - /// quantization scales to the B matrix based on the quantization group size (XPerQ) relative + /// quantization scales to the B matrix based on the quantization group size (NPerQ) relative /// to warp dimensions. /// /// Three distinct distribution patterns are handled: /// - /// 1. Fine-grained quantization (XPerQ < WarpGemm::kN): + /// 1. Fine-grained quantization (NPerQ < WarpGemm::kN): /// - Multiple quantization groups exist within a single warp's N-dimension - /// - Each warp processes multiple scales (WarpGemm::kN / XPerQ scales per warp) - /// - Distribution includes explicit replication factor (XR = XPerQ) for scale broadcast - /// - Example: XPerQ=8, WarpGemm::kN=16, NWarps=4 → 2 scales per warp + /// - Each warp processes multiple scales (WarpGemm::kN / NPerQ scales per warp) + /// - Distribution includes explicit replication factor (XR = NPerQ) for scale broadcast + /// - Example: NPerQ=8, WarpGemm::kN=16, NWarps=4 → 2 scales per warp /// - /// 2. Medium-grained quantization (WarpGemm::kN <= XPerQ <= WarpGemm::kN * NWarps): + /// 2. Medium-grained quantization (WarpGemm::kN <= NPerQ <= WarpGemm::kN * NWarps): /// - Each warp handles exactly one quantization scale - /// - Scales are distributed across warps with replication factor XR = XPerQ / WarpGemm::kN - /// - Example: XPerQ=64, WarpGemm::kN=16, NWarps=4 → 1 scale per warp, XR=4 + /// - Scales are distributed across warps with replication factor XR = NPerQ / WarpGemm::kN + /// - Example: NPerQ=64, WarpGemm::kN=16, NWarps=4 → 1 scale per warp, XR=4 /// - /// 3. Coarse-grained quantization (XPerQ > WarpGemm::kN * NWarps): + /// 3. Coarse-grained quantization (NPerQ > WarpGemm::kN * NWarps): /// - Quantization group spans multiple warps /// - All warps share the same scale value - /// - Example: XPerQ=128, WarpGemm::kN=16, NWarps=4 → all warps use same scale + /// - Example: NPerQ=128, WarpGemm::kN=16, NWarps=4 → all warps use same scale /// /// @return A static tile distribution encoding for the BQ scale tensor CK_TILE_HOST_DEVICE static constexpr auto make_2d_static_tile_distribution() { + // Preshuffle only supported for ColumnMajor currently + static_assert(!(PreshuffleQuant && std::is_same_v), + "PreshuffleQuant only supported for ColumnMajor BQLayout"); + if constexpr(PreshuffleQuant) { + // ColumnMajor only for preshuffle constexpr index_t X1 = warp_size; - constexpr index_t X0 = XPerTile / warp_size; + constexpr index_t X0 = NPerTile / warp_size; constexpr index_t Y1 = NWarps; - constexpr index_t Y0 = YPerTile / Y1; + constexpr index_t Y0 = KPerTile / Y1; return make_static_tile_distribution( tile_distribution_encoding, @@ -251,52 +257,97 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding } else { - if constexpr(YPerQ < WarpGemm::kN) + if constexpr(NPerQ < WarpGemm::kN) { // Case 1: Fine-grained - multiple quantization scales within a single warp - constexpr index_t X = XPerTile; // Full X dimension of tile - constexpr index_t XR = 1; // No Y replication needed - constexpr index_t Y0 = NIterPerWarp; // Iterations per warp in N-dim - constexpr index_t Y1 = NWarps; // Number of warps in N-dim - constexpr index_t Y2 = WarpGemm::kN / YPerQ; // Number of scales per warp - constexpr index_t YR = YPerQ; // Elements per quantization group + // N dimension needs to be partitioned the same way regardless of layout + constexpr index_t NR = 1; // No N replication needed + constexpr index_t N0 = NIterPerWarp; // Iterations per warp in N-dim + constexpr index_t N1 = NWarps; // Number of warps in N-dim + constexpr index_t N2 = WarpGemm::kN / NPerQ; // Number of scales per warp - static_assert(Y0 * Y1 * Y2 == YPerTile, - "Y0, Y1, Y2 must cover the blocktile along Y."); + static_assert(N0 * N1 * N2 == NPerTile, + "N0, N1, N2 must cover the blocktile along N dimension."); - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<0, 1, 0>>, - tuple, sequence<1, 2, 2>>, - sequence<1, 2>, - sequence<0, 0>>{}); + if constexpr(std::is_same_v) + { + // ColumnMajor: [(N0, N1, N2), K] - N on Y-axis, partition Y + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<0, 1, 0>>, + tuple, sequence<1, 2, 2>>, + sequence<1, 2>, + sequence<0, 0>>{}); + } + else + { + // RowMajor: [K, (N0, N1, N2)] - N on X-axis, partition X + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<0, 2, 0>>, + tuple, sequence<1, 2, 2>>, + sequence<2, 1>, + sequence<0, 0>>{}); + } } - else if constexpr(YPerQ <= WarpGemm::kN * NWarps) + else if constexpr(NPerQ <= WarpGemm::kN * NWarps) { // Case 2: Medium-grained - one quantization scale per warp - constexpr auto YR = YPerQ / WarpGemm::kN; // Scale replication factor - constexpr auto Y1 = NWarps / YR; // Warps per unique scale - constexpr auto Y0 = YPerTile / Y1; // Iterations to cover X dimension - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<0>>, - tuple, sequence<2>>, - sequence<1, 2>, - sequence<0, 0>>{}); + constexpr auto NR = NPerQ / WarpGemm::kN; // Scale replication factor + constexpr auto N1 = NWarps / NR; // Warps per unique scale + constexpr auto N0 = NPerTile / N1; // Iterations to cover N dimension + + if constexpr(std::is_same_v) + { + // ColumnMajor: [(N0, N1), K] - N on Y-axis + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<0>>, + tuple, sequence<2>>, + sequence<1, 2>, + sequence<0, 0>>{}); + } + else + { + // RowMajor: [K, (N0, N1)] - N on X-axis + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<0>>, + tuple, sequence<2>>, + sequence<2, 1>, + sequence<0, 0>>{}); + } } - else // XPerQ > WarpGemm::kN * NWarps + else // NPerQ > WarpGemm::kN * NWarps { // Case 3: Coarse-grained - quantization group spans all warps // All warps in N-dimension share the same quantization scale - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<0>>, - tuple, sequence<2>>, - sequence<2, 1>, - sequence<0, 0>>{}); + if constexpr(std::is_same_v) + { + // ColumnMajor: [N, K] + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<0>>, + tuple, sequence<2>>, + sequence<1, 2>, + sequence<0, 0>>{}); + } + else + { + // RowMajor: [K, N] + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<0>>, + tuple, sequence<2>>, + sequence<2, 1>, + sequence<0, 0>>{}); + } } } } diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp index 38bd59b882..39a7c66f38 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp @@ -86,8 +86,8 @@ class TestCkTileGemmQuantBase : public ::testing::Test using TilePartitioner = ck_tile::GemmTile1DPartitioner; - // BQLayout is always ColumnMajor for BQuant - using BQLayout = ck_tile::tensor_layout::gemm::ColumnMajor; + // Re-use the AQLayout for BQLayout + using BQLayout = AQLayout; using CodegenGemmTraits = ck_tile::TileGemmQuantTraits>; using GroupSize2D128N = ck_tile::QuantGroupShape>; // Type combinations for BQuant tests (without PreshuffleB) -// Tuple format: // clang-format off using BQuantTypes = ::testing::Types< - // 1d cases with grouping only on k axis (AQLayout is always RowMajor for BQuant) - std::tuple, - std::tuple, - std::tuple, - std::tuple, + // 1d cases with grouping only on k axis + std::tuple, + std::tuple, + std::tuple, + std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, // 2d cases with grouping also on the n axis - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + + // some cases with transpose layouts + std::tuple< RowMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize64>, + std::tuple, + std::tuple, + std::tuple< RowMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>, + std::tuple, + std::tuple, + + // pkint4 + transpose cases + std::tuple< RowMajor, RowMajor, RowMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize64>, + std::tuple, + std::tuple, + std::tuple< RowMajor, RowMajor, RowMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>, + std::tuple, + std::tuple >; // clang-format on diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle.cpp index 6cde4bded5..3a62fc091a 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle.cpp @@ -26,60 +26,60 @@ using GroupSize2D32N = ck_tile::QuantGroupShape>; using GroupSize2D64N = ck_tile::QuantGroupShape>; // Type combinations for BQuant tests with PreshuffleB -// Tuple format: // clang-format off using BPreshuffleBQuantTypes = ::testing::Types< - std::tuple, - std::tuple, - std::tuple, - std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, // //2d cases with preshuffle B - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple >; // clang-format on diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp index 7b16529aa8..bf9c7a138d 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp @@ -389,6 +389,9 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBaseis_row_major(BQLayout{}) ? BQN : BQK; // Generate test data ck_tile::HostTensor a_m_k( ck_tile::host_tensor_descriptor(M, K, stride_A, this->is_row_major(ALayout{}))); ck_tile::HostTensor b_k_n( ck_tile::host_tensor_descriptor(K, N, stride_B, this->is_row_major(BLayout{}))); - // BQ is always ColumnMajor ck_tile::HostTensor bq_bqk_bqn( - ck_tile::host_tensor_descriptor(BQK, BQN, stride_BQ, ck_tile::bool_constant{})); + ck_tile::host_tensor_descriptor(BQK, BQN, stride_BQ, this->is_row_major(BQLayout{}))); // Initialize data with random values ck_tile::FillUniformDistribution{-0.5f, 0.5f}(a_m_k); diff --git a/test/ck_tile/grouped_gemm_quant/CMakeLists.txt b/test/ck_tile/grouped_gemm_quant/CMakeLists.txt index 2bd2571993..7a7ae77730 100644 --- a/test/ck_tile/grouped_gemm_quant/CMakeLists.txt +++ b/test/ck_tile/grouped_gemm_quant/CMakeLists.txt @@ -14,6 +14,9 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") add_gtest_executable(test_ck_tile_grouped_gemm_quant_tensor test_grouped_gemm_quant_tensor.cpp) target_compile_options(test_ck_tile_grouped_gemm_quant_tensor PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + add_gtest_executable(test_ck_tile_grouped_gemm_quant_aquant test_grouped_gemm_quant_aquant.cpp) + target_compile_options(test_ck_tile_grouped_gemm_quant_aquant PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + add_gtest_executable(test_ck_tile_grouped_gemm_quant_bquant test_grouped_gemm_quant_bquant.cpp) target_compile_options(test_ck_tile_grouped_gemm_quant_bquant PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) endif() diff --git a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant.cpp b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant.cpp index 551989421f..6a1a28884a 100644 --- a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant.cpp +++ b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant.cpp @@ -18,32 +18,41 @@ using True = ck_tile::bool_constant; using False = ck_tile::bool_constant; using RowColQuant = std::integral_constant; using TensorQuant = std::integral_constant; +using AQuant = std::integral_constant; using BQuant = std::integral_constant; // clang-format off using KernelTypes = ::testing::Types< - // ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType, PreshuffleB - std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False>, - std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False>, - std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False>, - std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False>, + // ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType, PreshuffleB, Persistent, TransposeC + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False, True, False>, + std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False, True, False>, + std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False, True, False>, + std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False, True, False>, - std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant, False>, - std::tuple< Col, Col, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant, False>, - std::tuple< Row, Row, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant, False>, - std::tuple< Col, Row, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant, False>, - std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False>, - std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False>, - std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False>, - std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False>, - std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant, False>, - std::tuple< Col, Col, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant, False>, - std::tuple< Row, Row, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant, False>, - std::tuple< Col, Row, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant, False>, - std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, False>, - std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, BQuant, False>, - std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, BQuant, True>, - std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, True> + std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant, False, True, False>, + std::tuple< Col, Col, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant, False, True, False>, + std::tuple< Row, Row, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant, False, True, False>, + std::tuple< Col, Row, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant, False, True, False>, + + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, True, False>, + std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, True, False>, + std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, True, False>, + std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, True, False>, + + std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant, False, True, False>, + std::tuple< Col, Col, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant, False, True, False>, + std::tuple< Row, Row, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant, False, True, False>, + std::tuple< Col, Row, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant, False, True, False>, + + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuant, False, True, True>, + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuant, False, True, False>, + std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, AQuant, False, True, True>, + std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, AQuant, False, True, False>, + + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, False, True, False>, + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, True, True, False>, + std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, BQuant, False, True, False>, + std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, BQuant, True, True, False> >; // clang-format on diff --git a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_aquant.cpp b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_aquant.cpp new file mode 100644 index 0000000000..8dcd6d017d --- /dev/null +++ b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_aquant.cpp @@ -0,0 +1,38 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include + +#include "gtest/gtest.h" + +#include "ck_tile/host.hpp" +#include "test_grouped_gemm_util_quant.hpp" + +using F16 = ck_tile::half_t; +using F32 = float; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; +using True = ck_tile::bool_constant; +using False = ck_tile::bool_constant; +using AQuant = std::integral_constant; + +// clang-format off +using KernelTypes_AQuant = ::testing::Types< + // ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType, PreshuffleB, Persistent, TransposeC + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuant, False, True, True>, + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuant, False, True, False>, + std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, AQuant, False, True, True>, + std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, AQuant, False, True, False>, + + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuant, False, False, True>, + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuant, False, False, False> + >; +// clang-format on + +TYPED_TEST_SUITE(TestCkTileGroupedGemmQuant_AQuant, KernelTypes_AQuant); + +#define TEST_CLASS_NAME TestCkTileGroupedGemmQuant_AQuant +#include "test_grouped_gemm_quant_ut_cases.inc" +#undef TEST_CLASS_NAME diff --git a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_bquant.cpp b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_bquant.cpp index 4f44acf4c4..6c0ad545b7 100644 --- a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_bquant.cpp +++ b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_bquant.cpp @@ -20,9 +20,14 @@ using BQuant = std::integral_constant, - std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, True> + // ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType, PreshuffleB, Persistent, TransposeC + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, False, True, False>, + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, True, True, False>, + std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, BQuant, False, True, False>, + std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, BQuant, True, True, False>, + + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, False, False, False>, + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, True, False, False> >; // clang-format on diff --git a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_rowcol.cpp b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_rowcol.cpp index 48720aeebf..cc1b32fb20 100644 --- a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_rowcol.cpp +++ b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_rowcol.cpp @@ -20,11 +20,14 @@ using RowColQuant = std::integral_constant, - std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False>, - std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False>, - std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False> + // ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType, PreshuffleB, Persistent, TransposeC + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False, True, False>, + std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False, True, False>, + std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False, True, False>, + std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False, True, False>, + + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False, False, False>, + std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False, False, False> >; // clang-format on diff --git a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_tensor.cpp b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_tensor.cpp index f59fa29ec2..e446f7b168 100644 --- a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_tensor.cpp +++ b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_tensor.cpp @@ -20,11 +20,14 @@ using TensorQuant = std::integral_constant, - std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False>, - std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False>, - std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False> + // ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType, PreshuffleB, Persistent, TransposeC + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, True, False>, + std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, True, False>, + std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, True, False>, + std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, True, False>, + + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, False, False>, + std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, False, False> >; // clang-format on diff --git a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_util_quant.hpp b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_util_quant.hpp index 68b6735655..9941066c3e 100644 --- a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_util_quant.hpp +++ b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_util_quant.hpp @@ -3,6 +3,7 @@ #pragma once #include #include +#include #include "ck_tile/core.hpp" #include "ck_tile/host.hpp" @@ -32,24 +33,9 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test using Col = ck_tile::tensor_layout::gemm::ColumnMajor; using AQLayout = Row; using BQLayout = Col; - static constexpr bool Persistent = true; static constexpr bool PreshuffleB = std::tuple_element_t<10, Tuple>::value; - - template - static constexpr ck_tile::index_t get_k_from_preshuffled_warp_tile() - { -#if defined(CK_GFX950_SUPPORT) - if constexpr(M_Warp_Tile == 32) - return sizeof(PrecType) == 2 ? 16 : 64; - else - return sizeof(PrecType) == 2 ? 32 : 128; -#else - if constexpr(M_Warp_Tile == 32) - return sizeof(PrecType) == 2 ? 16 : 32; - else - return sizeof(PrecType) == 2 ? 32 : 64; -#endif - } + static constexpr bool Persistent = std::tuple_element_t<11, Tuple>::value; + static constexpr bool TransposeC = std::tuple_element_t<12, Tuple>::value; struct GroupedGemKernelParam_Mfma { @@ -66,11 +52,9 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test static const ck_tile::index_t N_Warp = 2; static const ck_tile::index_t K_Warp = 1; - static const ck_tile::index_t M_Warp_Tile = 32; - static const ck_tile::index_t N_Warp_Tile = 32; - static const ck_tile::index_t K_Warp_Tile = - TestCkTileGroupedGemmQuant::template get_k_from_preshuffled_warp_tile(); + static const ck_tile::index_t M_Warp_Tile = 16; + static const ck_tile::index_t N_Warp_Tile = 16; + static const ck_tile::index_t K_Warp_Tile = 32; }; struct GroupedGemKernelParam_Wmma : public GroupedGemKernelParam_Mfma @@ -90,16 +74,201 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test return gemm_descs.size() * sizeof(ck_tile::QuantGemmTransKernelArg); } + template + float invoke_grouped_gemm(const std::vector& gemm_descs, + const ck_tile::stream_config& s, + void* kargs_ptr) + { + constexpr bool DoubleSmemBuffer = + PreshuffleB; // currently DoubleSmemBuffer is only supported for preshuffled B + + constexpr ck_tile::index_t TileParitionerGroupNum = 8; + constexpr ck_tile::index_t TileParitionerM01 = 4; + constexpr bool UseGroupedQuant = QuantType == ck_tile::QuantType::AQuantGrouped || + QuantType == ck_tile::QuantType::BQuantGrouped; + + using QuantGroupSize = ck_tile::QuantGroupShape>; + + using GemmShape = + ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence>; + using TilePartitioner = ck_tile:: + GemmSpatiallyLocalTilePartitioner; + + using Traits = ck_tile::TileGemmTraits; + using GemmUniversalTraits = ck_tile::TileGemmQuantTraits; + + using GemmPipelineProblem = + ck_tile::GemmPipelineProblem; + + using BaseGemmPipeline = std::conditional_t< + UseGroupedQuant, + std::conditional_t< + QuantType == ck_tile::QuantType::AQuantGrouped, + ck_tile::BaseGemmPipelineAgBgCrCompV3, + std::conditional_t< + PreshuffleB == true, + ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2, + ck_tile::BaseGemmPipelineAgBgCrCompV3>>, + ck_tile::BaseGemmPipelineAgBgCrCompV3>; + + const ck_tile::index_t k_grain = gemm_descs[0].k_batch * GroupedGemKernelParam::K_Tile; + const ck_tile::index_t K_split = + (gemm_descs[0].K + k_grain - 1) / k_grain * GroupedGemKernelParam::K_Tile; + + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + float ave_time{0}; + + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + constexpr auto memory_operation = ck_tile::memory_operation_enum::set; + + using QuantGemmProblem = std::conditional_t< + UseGroupedQuant, + std::conditional_t, + ck_tile::GemmBQuantPipelineProblem>, + ck_tile::GemmRowColTensorQuantPipelineProblem>; + + using GemmPipeline = std::conditional_t< + UseGroupedQuant, + std::conditional_t< + QuantType == ck_tile::QuantType::AQuantGrouped, + ck_tile::AQuantGemmPipelineAgBgCrCompV3, + std::conditional_t, + ck_tile::BQuantGemmPipelineAgBgCrCompV3>>, + ck_tile::GemmPipelineAgBgCrCompV3>; + + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem, + AccDataType, + CDataType, + ck_tile::tuple<>, + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + GroupedGemKernelParam::M_Warp, + GroupedGemKernelParam::N_Warp, + GroupedGemKernelParam::M_Warp_Tile, + GroupedGemKernelParam::N_Warp_Tile, + GroupedGemKernelParam::K_Warp_Tile, + QuantGemmProblem::TransposeC, + memory_operation>>; + + using Kernel = ck_tile::QuantGroupedGemmKernel; + auto kargs = Kernel::MakeKargs(gemm_descs); + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Kernel arguments not supported!"); + } + + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(gemm_descs); + + HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, + kargs.data(), + get_workspace_size(gemm_descs), + hipMemcpyHostToDevice, + s.stream_id_)); + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel: " << Kernel::GetName() + << " with args:" << " grid: {" << grids.x << ", " << grids.y << ", " + << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", " + << blocks.z << "}" << std::endl; + } + + return ave_time = ck_tile::launch_kernel( + s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + gemm_descs.size())); + }; + + return ave_time = BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num); + } + template void invoke_grouped_gemm_persistent(const ck_tile::stream_config& s, const ck_tile::index_t num_groups, void* kargs_ptr) { - constexpr bool TransposeC = false; constexpr bool DoubleSmemBuffer = PreshuffleB; // currently DoubleSmemBuffer is only supported for preshuffled B - constexpr int kBlockPerCu = 1; constexpr ck_tile::index_t TileParitionerGroupNum = 8; constexpr ck_tile::index_t TileParitionerM01 = 4; @@ -131,40 +300,53 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test BQLayout, TransposeC, DoubleSmemBuffer, - true>; + Persistent>; const auto Run = [&](const auto memory_operation_) { constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; constexpr auto memory_operation = memory_operation_.value; - constexpr bool transpose_c = false; // We create the GEMM pipeline without specifying hotloop or tailnumber. // These are automatically run inside the kernel based on the given input data. - using QuantGemmProblem = typename std::conditional< - QuantType == ck_tile::QuantType::BQuantGrouped, - ck_tile::GemmBQuantPipelineProblem, + + constexpr bool UseGroupedQuant = QuantType == ck_tile::QuantType::AQuantGrouped || + QuantType == ck_tile::QuantType::BQuantGrouped; + using QuantGemmProblem = std::conditional_t< + UseGroupedQuant, + std::conditional_t, + ck_tile::GemmBQuantPipelineProblem>, ck_tile::GemmRowColTensorQuantPipelineProblem>::type; + scheduler>>; using GemmPipeline = std::conditional_t< - QuantType == ck_tile::QuantType::RowColQuant || - QuantType == ck_tile::QuantType::TensorQuant, - ck_tile::GemmPipelineAgBgCrCompV3, - std::conditional_t, - ck_tile::BQuantGemmPipelineAgBgCrCompV3>>; + UseGroupedQuant, + std::conditional_t< + QuantType == ck_tile::QuantType::AQuantGrouped, + ck_tile::AQuantGemmPipelineAgBgCrCompV3, + std::conditional_t, + ck_tile::BQuantGemmPipelineAgBgCrCompV3>>, + ck_tile::GemmPipelineAgBgCrCompV3>; using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem( + ck_tile::make_kernel( Kernel{}, grids, blocks, @@ -292,13 +474,24 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test AQK = 1; // Row quantization: tensor shape [M, 1] or [1] BQK = 1; // Column quantization: tensor shape [1, N] or [1] } + else if constexpr(QuantType == ck_tile::QuantType::AQuantGrouped) + { + AQK = K / QuantGroupSize::kK; // Group quantization: AQK = K / GroupSize + BQK = 0; // No B quantization + if(K % QuantGroupSize::kK != 0) + { + throw std::runtime_error( + "K must be divisible by QuantGroupSize::kK for AQuantGrouped mode"); + } + } else if constexpr(QuantType == ck_tile::QuantType::BQuantGrouped) { - AQK = 0; // No A quantization - BQK = K / 128; // Group quantization: BQK = K / GroupSize - if(K % 128 != 0) + AQK = 0; // No A quantization + BQK = K / QuantGroupSize::kK; // Group quantization: BQK = K / GroupSize + if(K % QuantGroupSize::kK != 0) { - throw std::runtime_error("K must be divisible by 128 for BQuantGrouped mode"); + throw std::runtime_error( + "K must be divisible by QuantGroupSize::kK for BQuantGrouped mode"); } } @@ -317,6 +510,12 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test stride_AQs[i] = 1; // Tensor quantization: tensor shape [1] stride_BQs[i] = 1; // Tensor quantization: tensor shape [1] } + else if constexpr(QuantType == ck_tile::QuantType::AQuantGrouped) + { + stride_AQs[i] = + ck_tile::get_default_stride(M, AQK, stride_AQs[i], is_row_major(AQLayout())); + stride_BQs[i] = 0; // No B quantization + } else if constexpr(QuantType == ck_tile::QuantType::BQuantGrouped) { stride_AQs[i] = 0; // No A quantization @@ -348,11 +547,20 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test ck_tile::HostTensor(ck_tile::host_tensor_descriptor( 1, 1, stride_BQs[i], is_row_major(BQLayout())))); } + else if constexpr(QuantType == ck_tile::QuantType::AQuantGrouped) + { + aq_tensors.push_back( + ck_tile::HostTensor(ck_tile::host_tensor_descriptor( + M, AQK, stride_AQs[i], is_row_major(AQLayout{})))); + bq_tensors.push_back( + ck_tile::HostTensor(ck_tile::host_tensor_descriptor( + 0, 0, stride_BQs[i], is_row_major(BQLayout())))); + } else if constexpr(QuantType == ck_tile::QuantType::BQuantGrouped) { aq_tensors.push_back( ck_tile::HostTensor(ck_tile::host_tensor_descriptor( - 0, AQK, stride_AQs[i], is_row_major(AQLayout{})))); + 0, 0, stride_AQs[i], is_row_major(AQLayout{})))); bq_tensors.push_back( ck_tile::HostTensor(ck_tile::host_tensor_descriptor( BQK, N, stride_BQs[i], is_row_major(BQLayout())))); @@ -429,11 +637,12 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test ck_tile::DeviceMem gemm_workspace; gemm_workspace.Realloc(get_workspace_size(gemm_descs)); + void* kargs_ptr = gemm_workspace.GetDeviceBuffer(); + if constexpr(Persistent) { // Generate kernel arguments std::vector kargs; - void* kargs_ptr = gemm_workspace.GetDeviceBuffer(); assert(gemm_descs[0].k_batch == 1); for(const auto& arg : gemm_descs) { @@ -471,7 +680,14 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test } else { - GTEST_FAIL() << "Non-persistent kernel not implemented yet"; + const auto stream = ck_tile::stream_config{nullptr, false, 1}; +#if CK_TILE_USE_WMMA + invoke_grouped_gemm( + gemm_descs, stream, kargs_ptr); +#else + invoke_grouped_gemm( + gemm_descs, stream, kargs_ptr); +#endif } // Copy results back to host for validation @@ -512,7 +728,7 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test bq_tensors[i], c_m_n_host_ref); } - else if constexpr(QuantType == ck_tile::QuantType::BQuantGrouped) + else if constexpr(QuantType == ck_tile::QuantType::AQuantGrouped) { ck_tile::reference_gemm_quant( + a_m_k_tensors[i], aq_tensors[i], b_k_n_tensors[i], c_m_n_host_ref); + } + else if constexpr(QuantType == ck_tile::QuantType::BQuantGrouped) + { + ck_tile::reference_gemm_quant( a_m_k_tensors[i], bq_tensors[i], b_k_n_tensors[i], c_m_n_host_ref); } @@ -550,5 +777,8 @@ using TestCkTileGroupedGemmQuant_RowCol = TestCkTileGroupedGemmQuant; template using TestCkTileGroupedGemmQuant_Tensor = TestCkTileGroupedGemmQuant; +template +using TestCkTileGroupedGemmQuant_AQuant = TestCkTileGroupedGemmQuant; + template using TestCkTileGroupedGemmQuant_BQuant = TestCkTileGroupedGemmQuant;