diff --git a/codegen/test/rtc/include/rtc/kernel.hpp b/codegen/test/rtc/include/rtc/kernel.hpp index b1ee729f77..96337fe2c1 100644 --- a/codegen/test/rtc/include/rtc/kernel.hpp +++ b/codegen/test/rtc/include/rtc/kernel.hpp @@ -52,7 +52,7 @@ struct kernel template auto launch(hipStream_t stream, std::size_t global, std::size_t local, Ts... zs) const { - return [=](auto&&... xs) { + return [=, this](auto&&... xs) { launch(stream, global, local, std::vector{xs...}, zs...); }; } diff --git a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp index 3b4258d8b1..1a913fcfc1 100644 --- a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp +++ b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp @@ -29,7 +29,7 @@ template + ck_tile::QuantType QuantMode = ck_tile::QuantType::BQuantGrouped> float grouped_gemm_tileloop(const ck_tile::stream_config& s, const ck_tile::index_t num_groups, void* kargs_ptr) @@ -48,8 +48,8 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s, using GemmUniversalTraits = ck_tile::TileGemmQuantTraits; + using QuantGemmProblem = typename std::conditional< + QuantMode == ck_tile::QuantType::BQuantGrouped, + ck_tile::GemmBQuantPipelineProblem, // QuantGroupSize + ck_tile::GemmRowColTensorQuantPipelineProblem>::type; - using GemmPipeline = typename PipelineTypeTraits< - GemmConfig::Pipeline>::template GemmPipeline; + using GemmPipeline = + typename std::conditional, + ck_tile::GemmPipelineAgBgCrCompV3>::type; using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem constexpr ck_tile::index_t get_k_warp_tile() @@ -41,6 +42,14 @@ struct GemmTypeConfig using AccDataType = float; using CDataType = ck_tile::half_t; }; +template <> +struct GemmTypeConfig +{ + using ADataType = ck_tile::bf8_t; + using BDataType = ck_tile::bf8_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; +}; struct GemmConfigBase { @@ -77,24 +86,11 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr bool DoubleSmemBuffer = false; static constexpr int kBlockPerCu = 1; }; -template -struct PipelineTypeTraits; - -template <> -struct PipelineTypeTraits -{ - template - using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; - template - using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3; -}; - using grouped_gemm_kargs = ck_tile::QuantGroupedGemmHostArgs; auto create_args(int argc, char* argv[]) @@ -122,8 +118,7 @@ auto create_args(int argc, char* argv[]) .insert("repeat", "100", "number of iterations to benchmark the kernel.") .insert("group_count", "8", "group count.") .insert("kbatch", "1", "kbatch for SplitK") - .insert("quant_mode", "tensor", "Choose tensor (default), or rowcol"); - ; + .insert("quant_mode", "bquant", "Choose bquant (default), tensor, or rowcol"); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); diff --git a/example/ck_tile/17_grouped_gemm/quant_run_grouped_gemm_example.inc b/example/ck_tile/17_grouped_gemm/quant_run_grouped_gemm_example.inc index 19211ed494..152df38bff 100644 --- a/example/ck_tile/17_grouped_gemm/quant_run_grouped_gemm_example.inc +++ b/example/ck_tile/17_grouped_gemm/quant_run_grouped_gemm_example.inc @@ -43,8 +43,8 @@ template + ck_tile::QuantType QuantMode = ck_tile::QuantType::BQuantGrouped, + typename CDEElementWise = ck_tile::element_wise::PassThrough> float invoke_gemm(int n_warmup, int n_repeat, int group_count, @@ -159,11 +159,12 @@ int run_grouped_gemm_example_with_layouts(int argc, return group_count != 0 && ((args.size() == static_cast(group_count)) && ...); }; - const int group_count = arg_parser.get_int("group_count"); - const int repeat = arg_parser.get_int("repeat"); - const int warmup = arg_parser.get_int("warmup"); - const int kbatch = arg_parser.get_int("kbatch"); - bool validate = arg_parser.get_bool("validate"); + const int group_count = arg_parser.get_int("group_count"); + const int repeat = arg_parser.get_int("repeat"); + const int warmup = arg_parser.get_int("warmup"); + const int kbatch = arg_parser.get_int("kbatch"); + bool validate = arg_parser.get_bool("validate"); + const ck_tile::index_t QuantGroupSize = 128; if(kbatch > 1 && validate && warmup + repeat > 1) { @@ -172,9 +173,11 @@ int run_grouped_gemm_example_with_layouts(int argc, validate = false; } - std::vector Ms = arg_parser.get_int_vec("Ms"); - std::vector Ns = arg_parser.get_int_vec("Ns"); - std::vector Ks = arg_parser.get_int_vec("Ks"); + std::vector Ms = arg_parser.get_int_vec("Ms"); + std::vector Ns = arg_parser.get_int_vec("Ns"); + std::vector Ks = arg_parser.get_int_vec("Ks"); + std::vector AQs; // dimension of AQ tensor is calculated from A tensor + std::vector BQs; // dimension of BQ tensor is calculated from B tensor std::vector stride_As = arg_parser.get_int_vec("stride_As"); std::vector stride_Bs = arg_parser.get_int_vec("stride_Bs"); std::vector stride_Cs = arg_parser.get_int_vec("stride_Cs"); @@ -252,6 +255,15 @@ int run_grouped_gemm_example_with_layouts(int argc, AQK = 1; // Row quantization: tensor shape [M, 1] or [1] BQK = 1; // Column quantization: tensor shape [1, N] or [1] } + else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) + { + AQK = 0; // No A quantization + BQK = K / QuantGroupSize; // Group quantization: BQK = K / GroupSize + if(K % QuantGroupSize != 0) + { + throw std::runtime_error("K must be divisible by 128 for BQuantGrouped mode"); + } + } stride_As[i] = ck_tile::get_default_stride(M, K, stride_As[i], is_row_major(a_layout)); stride_Bs[i] = ck_tile::get_default_stride(K, N, stride_Bs[i], is_row_major(b_layout)); @@ -289,6 +301,13 @@ int run_grouped_gemm_example_with_layouts(int argc, bq_tensors.push_back(ck_tile::HostTensor( ck_tile::host_tensor_descriptor(1, 1, stride_BQs[i], is_row_major(bq_layout)))); } + else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) + { + aq_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(0, AQK, stride_AQs[i], is_row_major(aq_layout)))); + bq_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(BQK, N, stride_BQs[i], is_row_major(bq_layout)))); + } std::cout << "gemm[" << i << "]" << " a_m_k: " << a_m_k_tensors[i].mDesc << " b_k_n: " << b_k_n_tensors[i].mDesc << " c_m_n: " << c_m_n_tensors[i].mDesc @@ -394,6 +413,17 @@ int run_grouped_gemm_example_with_layouts(int argc, bq_tensors[i], c_m_n_host_ref); } + else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) + { + ck_tile::reference_gemm_quant( + a_m_k_tensors[i], bq_tensors[i], b_k_n_tensors[i], c_m_n_host_ref); + } const float max_accumulated_value = *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); @@ -441,42 +471,6 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a QuantMode>( argc, argv, Row{}, Row{}, Col{}, Col{}, Row{}); } - else if(a_layout == "R" && b_layout == "R") - { - return run_grouped_gemm_example_with_layouts( - argc, argv, Row{}, Row{}, Row{}, Col{}, Row{}); - } - else if(a_layout == "C" && b_layout == "R") - { - return run_grouped_gemm_example_with_layouts( - argc, argv, Row{}, Row{}, Col{}, Col{}, Row{}); - } - else if(a_layout == "C" && b_layout == "C") - { - return run_grouped_gemm_example_with_layouts( - argc, argv, Col{}, Col{}, Col{}, Col{}, Row{}); - } else { throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!"); @@ -513,6 +507,41 @@ int run_grouped_gemm_example(int argc, char* argv[]) ck_tile::QuantType::RowColQuant>( a_layout, b_layout, argc, argv); } + else if(quant_mode == "bquant") + { + return run_gemm_example_prec_type, + ck_tile::fp8_t, + ck_tile::QuantType::BQuantGrouped>( + a_layout, b_layout, argc, argv); + } + else + { + throw std::runtime_error("Unsupported quantization mode!"); + } + } + if(data_type == "bf8") + { + if(quant_mode == "tensor") + { + return run_gemm_example_prec_type, + ck_tile::bf8_t, + ck_tile::QuantType::TensorQuant>( + a_layout, b_layout, argc, argv); + } + else if(quant_mode == "rowcol") + { + return run_gemm_example_prec_type, + ck_tile::bf8_t, + ck_tile::QuantType::RowColQuant>( + a_layout, b_layout, argc, argv); + } + else if(quant_mode == "bquant") + { + return run_gemm_example_prec_type, + ck_tile::bf8_t, + ck_tile::QuantType::BQuantGrouped>( + a_layout, b_layout, argc, argv); + } else { throw std::runtime_error("Unsupported quantization mode!"); diff --git a/include/ck/utility/synchronization.hpp b/include/ck/utility/synchronization.hpp index 672fc8c31b..54391e7e86 100644 --- a/include/ck/utility/synchronization.hpp +++ b/include/ck/utility/synchronization.hpp @@ -16,10 +16,17 @@ __device__ void llvm_amdgcn_s_wait_dscnt(short cnt) __asm("llvm.amdgcn.s.wait.ds __device__ void block_sync_lds() { #if CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM -#ifdef __gfx12__ +#if defined(__gfx12__) llvm_amdgcn_s_wait_dscnt(0); asm volatile("s_barrier_signal -1\n\t" "s_barrier_wait -1"); +#elif defined(__gfx11__) + // asm volatile("\ + // s_waitcnt lgkmcnt(0) \n \ + // s_barrier \ + // " ::); + __builtin_amdgcn_s_waitcnt(0xfc07); + __builtin_amdgcn_s_barrier(); #else // asm volatile("\ // s_waitcnt lgkmcnt(0) \n \ diff --git a/include/ck_tile/ops/gemm/kernel/gemm_multi_abd_kernel.hpp b/include/ck_tile/ops/gemm/kernel/gemm_multi_abd_kernel.hpp index b4ddc33e8d..7b73b89ede 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_multi_abd_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_multi_abd_kernel.hpp @@ -185,14 +185,6 @@ struct GemmKernelMultiABD { return false; } - // Currently MultiABD kernel doesn't support F8 data type - if(ck_tile::get_device_name() == "gfx950" && - (std::is_same::value || - std::is_same::value || - std::is_same::value)) - { - return false; - } return UniversalGemmKernel::IsSupportedArgument(kargs); } diff --git a/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp index 72f133c997..eee4017c12 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp @@ -375,30 +375,48 @@ struct QuantGroupedGemmKernel const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop); const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop); - // Run GEMM pipeline - const auto& c_block_tile = GemmPipeline{}.template operator()( - a_block_window, b_block_window, num_loop, has_hot_loop, tail_num, smem_ptr_0); - // Run Epilogue Pipeline - auto& c_block_window = gemm_tile_windows.at(Base::I4); - if constexpr(kQuantType == QuantType::RowColQuant) + if constexpr(kQuantType == QuantType::BQuantGrouped) { - const auto& aq_block_window = gemm_tile_windows.at(Base::I1); const auto& bq_block_window = gemm_tile_windows.at(Base::I3); - EpiloguePipeline{}.template - operator()( - c_block_window, - c_block_tile, - c_block_window, - smem_ptr_0, - aq_block_window, - bq_block_window); + // Run GEMM pipeline + const auto& c_block_tile = GemmPipeline{}.template operator()(a_block_window, + b_block_window, + bq_block_window, + num_loop, + has_hot_loop, + tail_num, + smem_ptr_0); + + auto& c_block_window = gemm_tile_windows.at(Base::I4); + + // Run Epilogue Pipeline + EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0); } - else if constexpr(kQuantType == QuantType::TensorQuant) + else { - const AccDataType aq_scale = type_convert(*aq_ptr); - const AccDataType bq_scale = type_convert(*bq_ptr); - EpiloguePipeline{}( - c_block_window, c_block_tile, c_block_window, smem_ptr_0, aq_scale, bq_scale); + // Run GEMM pipeline + const auto& c_block_tile = GemmPipeline{}.template operator()( + a_block_window, b_block_window, num_loop, has_hot_loop, tail_num, smem_ptr_0); + // Run Epilogue Pipeline + auto& c_block_window = gemm_tile_windows.at(Base::I4); + if constexpr(kQuantType == QuantType::RowColQuant) + { + const auto& aq_block_window = gemm_tile_windows.at(Base::I1); + const auto& bq_block_window = gemm_tile_windows.at(Base::I3); + EpiloguePipeline{}(c_block_window, + c_block_tile, + c_block_window, + smem_ptr_0, + aq_block_window, + bq_block_window); + } + else if constexpr(kQuantType == QuantType::TensorQuant) + { + const AccDataType aq_scale = type_convert(*aq_ptr); + const AccDataType bq_scale = type_convert(*bq_ptr); + EpiloguePipeline{}( + c_block_window, c_block_tile, c_block_window, smem_ptr_0, aq_scale, bq_scale); + } } } diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp index 92b1316b34..92261e7344 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp @@ -472,6 +472,49 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV num_loop, p_smem); } + + /// @brief Runtime pipeline dispatch operator for grouped GEMM kernels. + /// + /// This operator is used by grouped GEMM kernels where pipeline parameters + /// (has_hot_loop, num_loop, tail_number) are calculated on the device side + /// at runtime, not on the host side during compilation. This is necessary + /// because different GEMM problems in the group may have different K dimensions, + /// requiring different pipeline configurations that cannot be determined at + /// compile time. + /// + /// @param a_dram_block_window_tmp Block window for A tensor in DRAM + /// @param b_dram_block_window_tmp Block window for B tensor in DRAM + /// @param bq_dram_block_window_tmp Block window for BQ (quantization scale) tensor in DRAM + /// @param num_loop Number of main loop iterations (calculated on device) + /// @param has_hot_loop Whether the pipeline has a hot loop (calculated on device) + /// @param tail_number Type of tail handling required (calculated on device) + /// @param p_smem Pointer to shared memory + /// @return Accumulated result tile in registers + template + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BQDramBlockWindowTmp& bq_dram_block_window_tmp, + index_t num_loop, + bool has_hot_loop, + TailNumber tail_number, + void* p_smem) const + { + const auto RunPipeline = [&](auto has_hot_loop_, auto tail_number_) { + constexpr bool hot_loop = has_hot_loop_.value; + constexpr auto tail_num = tail_number_.value; + return PipelineImpl{}.template operator()( + a_dram_block_window_tmp, + [](const ADataType& a) { return a; }, + b_dram_block_window_tmp, + [](const BDataType& b) { return b; }, + bq_dram_block_window_tmp, + num_loop, + p_smem); + }; + return Base::TailHandler(RunPipeline, has_hot_loop, tail_number); + } }; } // namespace ck_tile diff --git a/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_cshuffle.cpp b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_cshuffle.cpp index 87d6a9101c..1eef00813d 100644 --- a/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_cshuffle.cpp +++ b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_cshuffle.cpp @@ -20,20 +20,19 @@ using Col = ck_tile::tensor_layout::gemm::ColumnMajor; using KernelTypes = ::testing::Types< // Has cshuffle epilogue enabled // A0Layout, A1Layout, B0Layout, B1Layout CLayout, D0Layout, D1Layout, A0DataType, A01DataType B0DataType, B0DataType, D0DataType, D1DataType, AccDataType, EDataType, AElementWiseFn, BElementWiseFn, CDElementWiseFn, UseCshuffleEpilog - std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::true_type>, - std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::true_type>, - std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::true_type>, - std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::true_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F16, F16, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::true_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::true_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::true_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::true_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::true_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, F8, F8, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::true_type>, std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F16, F16, F32, F16, AddScale, AddScale, MultiplyMultiply, std::true_type>, std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type>, std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type>, std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, MultiplyMultiply, std::true_type>, - std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type> - - // Currently MultiABD kernel doesn't support F8 data type - //std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, F8, F8, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::true_type>, - //std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, F8, F8, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, F8, F8, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type> >; // clang-format on diff --git a/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_default2d.cpp b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_default2d.cpp index f2476e803f..5664ded81e 100644 --- a/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_default2d.cpp +++ b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_default2d.cpp @@ -20,19 +20,17 @@ using Col = ck_tile::tensor_layout::gemm::ColumnMajor; using KernelTypes = ::testing::Types< // Has cshuffle epilogue disabled // A0Layout, A1Layout, B0Layout, B1Layout CLayout, D0Layout, D1Layout, A0DataType, A01DataType B0DataType, B0DataType, D0DataType, D1DataType, AccDataType, EDataType, AElementWiseFn, BElementWiseFn, CDElementWiseFn, UseCshuffleEpilog - std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>, - std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>, - std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::false_type>, - std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, BF16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F16, F16, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::false_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, BF16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, BF16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>, std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F16, F16, F32, F16, AddScale, AddScale, MultiplyMultiply, std::false_type>, std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, MultiplyMultiply, std::false_type>, std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F32, AddScale, AddScale, MultiplyMultiply, std::false_type>, - std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, BF16, AddScale, AddScale, MultiplyMultiply, std::false_type> - - // Currently MultiABD kernel doesn't support F8 data type - //std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, BF16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>, - //std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, BF16, AddScale, AddScale, MultiplyMultiply, std::false_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, BF16, AddScale, AddScale, MultiplyMultiply, std::false_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, BF16, AddScale, AddScale, MultiplyMultiply, std::false_type> >; // clang-format on diff --git a/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_ut_cases_cshuffle.inc b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_ut_cases_cshuffle.inc index 33eb404fbe..e2a7e0115f 100644 --- a/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_ut_cases_cshuffle.inc +++ b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_ut_cases_cshuffle.inc @@ -1,5 +1,95 @@ #pragma once +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_256x512x256) +{ + constexpr int M = 256; + constexpr int N = 512; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_512x256x256) +{ + constexpr int M = 512; + constexpr int N = 256; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_512x512x256) +{ + constexpr int M = 512; + constexpr int N = 512; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_256x256x256) +{ + constexpr int M = 256; + constexpr int N = 256; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_512x768x256) +{ + constexpr int M = 512; + constexpr int N = 768; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_512x1280x256) +{ + constexpr int M = 512; + constexpr int N = 1280; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_256x1280x256) +{ + constexpr int M = 256; + constexpr int N = 1280; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_768x512x256) +{ + constexpr int M = 768; + constexpr int N = 512; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_1280x512x256) +{ + constexpr int M = 1280; + constexpr int N = 512; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2CShuffle_512x512x512) { constexpr int M = 512; diff --git a/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_util.hpp b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_util.hpp index 428bed4e25..8b5d845d6b 100644 --- a/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_util.hpp +++ b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_util.hpp @@ -13,40 +13,9 @@ #include "ck_tile/ops/gemm/kernel/gemm_multi_abd_kernel.hpp" #include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" -struct AddScale -{ - template - CK_TILE_HOST_DEVICE constexpr void operator()(E& a, const A0& a0, const A1& a1) const - { - a = scale * (ck_tile::type_convert(a0) + ck_tile::type_convert(a1)); - } - - float scale = 1.0; -}; - -struct MultiplyMultiply -{ - template - CK_TILE_HOST_DEVICE auto operator()(E& e, const C& c, const D0& d0, const D1& d1) const -> void - { - const float x0_f = ck_tile::type_convert(c) * ck_tile::type_convert(d0) * - ck_tile::type_convert(d1); - - e = ck_tile::type_convert(x0_f); - } -}; - -struct ElementWiseAddAdd -{ - template - CK_TILE_HOST_DEVICE auto operator()(E& e, const C& c, const D0& d0, const D1& d1) const -> void - { - const float x0_f = ck_tile::type_convert(c) + ck_tile::type_convert(d0) + - ck_tile::type_convert(d1); - - e = ck_tile::type_convert(x0_f); - } -}; +using AddScale = ck_tile::element_wise::AddScale; +using ElementWiseAddAdd = ck_tile::element_wise::MultiDAdd; +using MultiplyMultiply = ck_tile::element_wise::MultiDMultiply; template static constexpr inline auto is_row_major(Layout layout_) diff --git a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant.cpp b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant.cpp index acdc9f4400..599a669c0a 100644 --- a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant.cpp +++ b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant.cpp @@ -18,6 +18,7 @@ using True = ck_tile::bool_constant; using False = ck_tile::bool_constant; using RowColQuant = std::integral_constant; using TensorQuant = std::integral_constant; +using BQuant = std::integral_constant; // clang-format off using KernelTypes = ::testing::Types< @@ -31,16 +32,16 @@ using KernelTypes = ::testing::Types< std::tuple< Col, Col, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant>, std::tuple< Row, Row, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant>, std::tuple< Col, Row, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant>, - std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant>, std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant>, std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant>, std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant>, - std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant>, std::tuple< Col, Col, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant>, std::tuple< Row, Row, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant>, - std::tuple< Col, Row, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant> + std::tuple< Col, Row, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant>, + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant>, + std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, BQuant> >; // clang-format on diff --git a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_ut_cases.inc b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_ut_cases.inc index cef9c40b13..0b522f82f3 100644 --- a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_ut_cases.inc +++ b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_ut_cases.inc @@ -26,3 +26,32 @@ TYPED_TEST(TestCkTileGroupedGemmQuant, Basic) this->Run(Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs, stride_AQs, stride_BQs, group_count); } + +// No Hot Loop Test Case, this is to test the correctness of the kernel when there is no hot loop +// Using 256x256x128 to match the test kernel's tile size (M_Tile=256, N_Tile=256, K_Tile=128) +TYPED_TEST(TestCkTileGroupedGemmQuant, SmallUniform) // +{ + const int group_count = 2; + std::vector Ms; + std::vector Ns; + std::vector Ks; + std::vector stride_As; + std::vector stride_Bs; + std::vector stride_Cs; + std::vector stride_AQs; + std::vector stride_BQs; + for(int i = 0; i < group_count; i++) + { + Ms.push_back(256); + Ns.push_back(256); + Ks.push_back(256); + + stride_As.push_back(0); + stride_Bs.push_back(0); + stride_Cs.push_back(0); + stride_AQs.push_back(0); + stride_BQs.push_back(0); + } + + this->Run(Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs, stride_AQs, stride_BQs, group_count); +} diff --git a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_util_quant.hpp b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_util_quant.hpp index 101e444f75..5e9956ab5b 100644 --- a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_util_quant.hpp +++ b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_util_quant.hpp @@ -107,7 +107,15 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test constexpr bool transpose_c = false; // We create the GEMM pipeline without specifying hotloop or tailnumber. // These are automatically run inside the kernel based on the given input data. - using QuantGemmProblem = + using QuantGemmProblem = typename std::conditional< + QuantType == ck_tile::QuantType::BQuantGrouped, + ck_tile::GemmBQuantPipelineProblem, // QuantGroupSize ck_tile::GemmRowColTensorQuantPipelineProblem; + scheduler>>::type; + + using GemmPipeline = typename std::conditional< + QuantType == ck_tile::QuantType::BQuantGrouped, + ck_tile::BQuantGemmPipelineAgBgCrCompV3, + ck_tile::GemmPipelineAgBgCrCompV3>::type; - using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem( @@ -285,6 +312,15 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test ck_tile::HostTensor(ck_tile::host_tensor_descriptor( 1, 1, stride_BQs[i], is_row_major(BQLayout())))); } + else if constexpr(QuantType == ck_tile::QuantType::BQuantGrouped) + { + aq_tensors.push_back( + ck_tile::HostTensor(ck_tile::host_tensor_descriptor( + 0, AQK, stride_AQs[i], is_row_major(AQLayout{})))); + bq_tensors.push_back( + ck_tile::HostTensor(ck_tile::host_tensor_descriptor( + BQK, N, stride_BQs[i], is_row_major(BQLayout())))); + } std::cout << "gemm[" << i << "]" << " a_m_k: " << a_m_k_tensors[i].mDesc << " b_k_n: " << b_k_n_tensors[i].mDesc @@ -373,7 +409,6 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test kargs.size() * sizeof(ck_tile::QuantGemmTransKernelArg), hipMemcpyHostToDevice, stream.stream_id_)); - invoke_grouped_gemm_persistent( stream, group_count, kargs_ptr); } @@ -420,6 +455,17 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test bq_tensors[i], c_m_n_host_ref); } + else if constexpr(QuantType == ck_tile::QuantType::BQuantGrouped) + { + ck_tile::reference_gemm_quant( + a_m_k_tensors[i], bq_tensors[i], b_k_n_tensors[i], c_m_n_host_ref); + } const float max_accumulated_value = *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());