diff --git a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp
index 64c9dda64a..3b4258d8b1 100644
--- a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp
+++ b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp
@@ -28,7 +28,8 @@ template <typename ADataType,
           typename AQDataType,
           typename BDataType,
           typename BQDataType,
           typename AccDataType,
-          typename CDataType>
+          typename CDataType,
+          ck_tile::QuantType QuantMode>
 float grouped_gemm_tileloop(const ck_tile::stream_config& s,
                             const ck_tile::index_t num_groups,
                             void* kargs_ptr)
@@ -44,19 +45,20 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
     using TilePartitioner = ck_tile::
         GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
 
-    constexpr ck_tile::QuantType QuantMode = ck_tile::QuantType::RowColQuant;
-    using GemmUniversalTraits = ck_tile::TileGemmQuantTraits;
+    using GemmUniversalTraits = ck_tile::TileGemmQuantTraits;
 
     float ave_time{0};
diff --git a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.hpp b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.hpp
index 93e461b9d3..bc271ac38e 100644
--- a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.hpp
+++ b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.hpp
@@ -11,12 +11,6 @@
 #include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
 
 #define CK_TILE_PIPELINE_COMPUTE_V3 1
-#define CK_TILE_PIPELINE_MEMORY 2
-#define CK_TILE_PIPELINE_COMPUTE_V4 3
-
-#ifndef CK_TILE_PIPELINE_DEFAULT
-#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE_V3
-#endif
 
 template <typename PrecType>
 constexpr ck_tile::index_t get_k_warp_tile()
@@ -66,7 +60,6 @@ struct GemmConfigBase
     static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
     static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
     static constexpr ck_tile::index_t NumWaveGroups = 1;
-    static constexpr bool Preshuffle = false;
 };
 
 template <ck_tile::index_t PipelineType>
@@ -102,15 +95,6 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V3>
     using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>;
 };
 
-template <>
-struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V4>
-{
-    template <typename GemmPipelineProblem>
-    using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4<GemmPipelineProblem>;
-    template <typename GemmPipelineProblem>
-    using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4<GemmPipelineProblem>;
-};
-
 using grouped_gemm_kargs = ck_tile::QuantGroupedGemmHostArgs;
 
 auto create_args(int argc, char* argv[])
@@ -119,7 +103,12 @@ auto create_args(int argc, char* argv[])
     arg_parser.insert("Ms", "", "M dimensions - empty by default.")
         .insert("Ns", "", "N dimensions - empty by default.")
         .insert("Ks", "", "K dimensions - empty by default.")
-        .insert("stride_As", "", "Tensor A strides - it is empty by default.")
+        .insert(
+            "stride_As",
+            "",
+            "Tensor A strides - it is empty by default.") // stride_As/stride_Bs/stride_Cs/stride_AQs/stride_BQs
+                                                          // can be set to zero if
+                                                          // Ms/Ns/Ks is not empty
         .insert("stride_Bs", "", "Tensor B strides - it is empty by default.")
         .insert("stride_Cs", "", "Tensor C strides - it is empty by default.")
         .insert("stride_AQs", "", "Tensor AQ strides - it is empty by default.")
         .insert("stride_BQs", "", "Tensor BQ strides - it is empty by default.")
@@ -132,7 +121,9 @@ auto create_args(int argc, char* argv[])
         .insert("warmup", "10", "number of iterations before benchmark the kernel.")
         .insert("repeat", "100", "number of iterations to benchmark the kernel.")
         .insert("group_count", "8", "group count.")
-        .insert("kbatch", "1", "kbatch for SplitK");
+        .insert("kbatch", "1", "kbatch for SplitK")
+        .insert("quant_mode", "tensor", "Choose tensor (default) or rowcol.");
 
     bool result = arg_parser.parse(argc, argv);
     return std::make_tuple(result, arg_parser);
@@ -145,13 +136,17 @@ inline std::size_t get_workspace_size(const std::vector<grouped_gemm_kargs>& gemm_descs)
 
 template <typename ADataType,
           typename AQDataType,
           typename BDataType,
           typename BQDataType,
           typename AccDataType,
-          typename CDataType>
+          typename CDataType,
+          ck_tile::QuantType QuantMode>
 float grouped_gemm_tileloop(const ck_tile::stream_config& s,
                             const ck_tile::index_t num_groups,
-                            void* kargs_ptr,
-                            bool splitk = false);
+                            void* kargs_ptr);
diff --git a/example/ck_tile/17_grouped_gemm/quant_run_grouped_gemm_example.inc b/example/ck_tile/17_grouped_gemm/quant_run_grouped_gemm_example.inc
index 10d317a2c7..19211ed494 100644
--- a/example/ck_tile/17_grouped_gemm/quant_run_grouped_gemm_example.inc
+++ b/example/ck_tile/17_grouped_gemm/quant_run_grouped_gemm_example.inc
@@ -43,6 +43,7 @@ template <typename ADataType,
           typename AQLayout,
           typename BLayout,
           typename BQLayout,
-          typename CLayout>
+          typename CLayout,
+          ck_tile::QuantType QuantMode>
 float invoke_gemm(int n_warmup,
                   int n_repeat,
@@ -102,9 +103,10 @@ float invoke_gemm(int n_warmup,
                                            BDataType,
                                            BQDataType,
                                            AccDataType,
-                                           CDataType>(stream, group_count, kargs_ptr);
+                                           CDataType,
+                                           QuantMode>(stream, group_count, kargs_ptr);
 
-    std::string op_name{"Grouped Gemm"};
+    std::string op_name = "Quant Grouped Gemm (" + ck_tile::quant_type_to_string(QuantMode) + ")";
 
     std::size_t flop = 0, num_btype = 0;
     for(int j = 0; j < group_count; ++j)
@@ -132,6 +134,7 @@ template <typename ADataType,
           typename BDataType,
           typename BQDataType,
           typename CDataType,
-          typename AccDataType>
+          typename AccDataType,
+          ck_tile::QuantType QuantMode>
 int run_grouped_gemm_example_with_layouts(int argc,
@@ -145,7 +148,7 @@ int run_grouped_gemm_example_with_layouts(int argc,
     auto [result, arg_parser] = create_args(argc, argv);
 
     auto valid_input_data = [&](int group_count, const auto&... args) {
-        return !(args.empty() || ...) && group_count == (args.size() == ...);
+        return group_count != 0 && ((args.size() == static_cast<std::size_t>(group_count)) && ...);
     };
 
     const int group_count = arg_parser.get_int("group_count");
@@ -180,7 +183,8 @@ int run_grouped_gemm_example_with_layouts(int argc,
 
     ck_tile::index_t AQK, BQK;
 
-    if(!valid_input_data(group_count, Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs))
+    if(!valid_input_data(
+           group_count, Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs, stride_AQs, stride_BQs))
     {
         std::cout << "Please check the input data. Default values will be used." << std::endl;
@@ -242,25 +246,49 @@ int run_grouped_gemm_example_with_layouts(int argc,
         const ck_tile::index_t M = Ms[i];
         const ck_tile::index_t N = Ns[i];
         const ck_tile::index_t K = Ks[i];
+        if constexpr(QuantMode == ck_tile::QuantType::RowColQuant ||
+                     QuantMode == ck_tile::QuantType::TensorQuant)
+        {
+            AQK = 1; // Row quantization: tensor shape [M, 1] or [1]
+            BQK = 1; // Column quantization: tensor shape [1, N] or [1]
+        }
 
-        AQK = 1; // Row quantization: tensor shape [M, 1]. Only for NT
-        BQK = N; // Column quantization: tensor shape [1, N]. Only for NT
+        stride_As[i] = ck_tile::get_default_stride(M, K, stride_As[i], is_row_major(a_layout));
+        stride_Bs[i] = ck_tile::get_default_stride(K, N, stride_Bs[i], is_row_major(b_layout));
+        stride_Cs[i] = ck_tile::get_default_stride(M, N, stride_Cs[i], is_row_major(CLayout{}));
+        if constexpr(QuantMode == ck_tile::QuantType::RowColQuant)
+        {
+            stride_AQs[i] =
+                ck_tile::get_default_stride(M, 1, stride_AQs[i], is_row_major(aq_layout));
+            stride_BQs[i] =
+                ck_tile::get_default_stride(1, N, stride_BQs[i], is_row_major(bq_layout));
+        }
+        else if constexpr(QuantMode == ck_tile::QuantType::TensorQuant)
+        {
+            stride_AQs[i] = 1; // Tensor quantization: tensor shape [1]
+            stride_BQs[i] = 1; // Tensor quantization: tensor shape [1]
+        }
 
-        stride_As[i] = ck_tile::get_default_stride(M, K, stride_As[i], is_row_major(a_layout));
-        stride_Bs[i] = ck_tile::get_default_stride(K, N, stride_Bs[i], is_row_major(b_layout));
-        stride_Cs[i] = ck_tile::get_default_stride(M, N, stride_Cs[i], is_row_major(CLayout{}));
-        stride_AQs[i] = ck_tile::get_default_stride(M, AQK, stride_AQs[i], is_row_major(aq_layout));
-        stride_BQs[i] = ck_tile::get_default_stride(1, N, stride_BQs[i], is_row_major(bq_layout));
 
         a_m_k_tensors.push_back(ck_tile::HostTensor<ADataType>(
             ck_tile::host_tensor_descriptor(M, K, stride_As[i], is_row_major(a_layout))));
         b_k_n_tensors.push_back(ck_tile::HostTensor<BDataType>(
             ck_tile::host_tensor_descriptor(K, N, stride_Bs[i], is_row_major(b_layout))));
         c_m_n_tensors.push_back(ck_tile::HostTensor<CDataType>(
             ck_tile::host_tensor_descriptor(M, N, stride_Cs[i], is_row_major(CLayout{}))));
-        aq_tensors.push_back(ck_tile::HostTensor<AQDataType>(
-            ck_tile::host_tensor_descriptor(M, AQK, stride_AQs[i], is_row_major(aq_layout))));
-        bq_tensors.push_back(ck_tile::HostTensor<BQDataType>(
-            ck_tile::host_tensor_descriptor(1, N, stride_BQs[i], is_row_major(bq_layout))));
+        if constexpr(QuantMode == ck_tile::QuantType::RowColQuant)
+        {
+            aq_tensors.push_back(ck_tile::HostTensor<AQDataType>(
+                ck_tile::host_tensor_descriptor(M, AQK, stride_AQs[i], is_row_major(aq_layout))));
+            bq_tensors.push_back(ck_tile::HostTensor<BQDataType>(
+                ck_tile::host_tensor_descriptor(BQK, N, stride_BQs[i], is_row_major(bq_layout))));
+        }
+        else if constexpr(QuantMode == ck_tile::QuantType::TensorQuant)
+        {
+            aq_tensors.push_back(ck_tile::HostTensor<AQDataType>(
+                ck_tile::host_tensor_descriptor(1, 1, stride_AQs[i], is_row_major(aq_layout))));
+            bq_tensors.push_back(ck_tile::HostTensor<BQDataType>(
+                ck_tile::host_tensor_descriptor(1, 1, stride_BQs[i], is_row_major(bq_layout))));
+        }
 
         std::cout << "gemm[" << i << "]"
                   << " a_m_k: " << a_m_k_tensors[i].mDesc << " b_k_n: " << b_k_n_tensors[i].mDesc
                   << " c_m_n: " << c_m_n_tensors[i].mDesc
@@ -324,7 +352,8 @@ int run_grouped_gemm_example_with_layouts(int argc,
                                      AQLayout,
                                      BLayout,
                                      BQLayout,
-                                     CLayout>(warmup, repeat, group_count, gemm_descs);
+                                     CLayout,
+                                     QuantMode>(warmup, repeat, group_count, gemm_descs);
 
     for(int i = 0; i < group_count; i++)
     {
@@ -339,13 +368,33 @@ int run_grouped_gemm_example_with_layouts(int argc,
         ck_tile::HostTensor<CDataType> c_m_n_host_ref(ck_tile::host_tensor_descriptor(
             Ms[i], Ns[i], stride_Cs[i], is_row_major(CLayout{})));
         c_m_n_host_ref.SetZero();
-        ck_tile::reference_gemm_rowcol_quant(
-            a_m_k_tensors[i], aq_tensors[i], b_k_n_tensors[i], bq_tensors[i], c_m_n_host_ref);
+        if constexpr(QuantMode == ck_tile::QuantType::RowColQuant)
+        {
+            ck_tile::reference_gemm_rowcol_quant(a_m_k_tensors[i],
+                                                 aq_tensors[i],
+                                                 b_k_n_tensors[i],
+                                                 bq_tensors[i],
+                                                 c_m_n_host_ref);
+        }
+        else if constexpr(QuantMode == ck_tile::QuantType::TensorQuant)
+        {
+            ck_tile::reference_gemm_tensor_quant(a_m_k_tensors[i],
+                                                 aq_tensors[i],
+                                                 b_k_n_tensors[i],
+                                                 bq_tensors[i],
+                                                 c_m_n_host_ref);
+        }
+
         const float max_accumulated_value =
             *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
         const auto rtol_atol =
@@ -367,7 +416,7 @@ int run_grouped_gemm_example_with_layouts(int argc,
     return pass;
 }
 
-template <typename GemmConfig, typename ADataType>
+template <typename GemmConfig, typename ADataType, ck_tile::QuantType QuantMode>
 int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
 {
     using Row = ck_tile::tensor_layout::gemm::RowMajor;
     using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
@@ -388,7 +437,8 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
                                                      BDataType,
                                                      BQDataType,
                                                      CDataType,
-                                                     AccDataType>(
+                                                     AccDataType,
+                                                     QuantMode>(
             argc, argv, Row{}, Row{}, Col{}, Col{}, Row{});
     }
     else if(a_layout == "R" && b_layout == "R")
@@ -399,8 +449,9 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
                                                      BDataType,
                                                      BQDataType,
                                                      CDataType,
-                                                     AccDataType>(
-            argc, argv, Row{}, Row{}, Row{}, Row{}, Row{});
+                                                     AccDataType,
+                                                     QuantMode>(
+            argc, argv, Row{}, Row{}, Row{}, Col{}, Row{});
     }
     else if(a_layout == "C" && b_layout == "R")
     {
@@ -410,7 +461,8 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
                                                      BDataType,
                                                      BQDataType,
                                                      CDataType,
-                                                     AccDataType>(
+                                                     AccDataType,
+                                                     QuantMode>(
             argc, argv, Row{}, Row{}, Col{}, Col{}, Row{});
     }
     else if(a_layout == "C" && b_layout == "C")
    {
@@ -421,7 +473,8 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
                                                      BDataType,
                                                      BQDataType,
                                                      CDataType,
-                                                     AccDataType>(
+                                                     AccDataType,
+                                                     QuantMode>(
             argc, argv, Col{}, Col{}, Col{}, Col{}, Row{});
     }
     else
@@ -442,11 +495,28 @@ int run_grouped_gemm_example(int argc, char* argv[])
     const std::string a_layout  = arg_parser.get_str("a_layout");
     const std::string b_layout  = arg_parser.get_str("b_layout");
     const std::string data_type = arg_parser.get_str("prec");
+    std::string quant_mode      = arg_parser.get_str("quant_mode");
 
     if(data_type == "fp8")
     {
-        return run_gemm_example_prec_type<GemmConfigComputeV3<ck_tile::fp8_t>, ck_tile::fp8_t>(
-            a_layout, b_layout, argc, argv);
+        if(quant_mode == "tensor")
+        {
+            return run_gemm_example_prec_type<GemmConfigComputeV3<ck_tile::fp8_t>,
+                                              ck_tile::fp8_t,
+                                              ck_tile::QuantType::TensorQuant>(
+                a_layout, b_layout, argc, argv);
+        }
+        else if(quant_mode == "rowcol")
+        {
+            return run_gemm_example_prec_type<GemmConfigComputeV3<ck_tile::fp8_t>,
+                                              ck_tile::fp8_t,
+                                              ck_tile::QuantType::RowColQuant>(
+                a_layout, b_layout, argc, argv);
+        }
+        else
+        {
+            throw std::runtime_error("Unsupported quantization mode!");
+        }
     }
     else
     {
diff --git a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc
index f822c7d8a7..dbdbe80c5d 100644
--- a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc
+++ b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc
@@ -143,7 +143,7 @@ int run_grouped_gemm_example_with_layouts(int argc,
     auto [result, arg_parser] = create_args(argc, argv);
 
     auto valid_input_data = [&](int group_count, const auto&... args) {
-        return !(args.empty() || ...) && group_count == (args.size() == ...);
+        return group_count != 0 && ((args.size() == static_cast<std::size_t>(group_count)) && ...);
     };
 
     const int group_count = arg_parser.get_int("group_count");
diff --git a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_multi_d_example.inc b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_multi_d_example.inc
index db66d9a54b..1abb541e65 100644
--- a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_multi_d_example.inc
+++ b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_multi_d_example.inc
@@ -159,7 +159,7 @@ int run_grouped_gemm_multi_d_example_with_layouts(int argc,
     using DsDataType = ck_tile::tuple;
 
     auto valid_input_data = [&](int group_count, const auto&... args) {
-        return !(args.empty() || ...) && group_count == (args.size() == ...);
+        return group_count != 0 && ((args.size() == static_cast<std::size_t>(group_count)) && ...);
     };
 
     const int group_count = arg_parser.get_int("group_count");
diff --git a/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp
index 39c8e406b7..72f133c997 100644
--- a/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp
+++ b/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp
@@ -393,6 +393,13 @@ struct QuantGroupedGemmKernel
                              aq_block_window,
                              bq_block_window);
         }
+        else if constexpr(kQuantType == QuantType::TensorQuant)
+        {
+            const AccDataType aq_scale = type_convert<AccDataType>(*aq_ptr);
+            const AccDataType bq_scale = type_convert<AccDataType>(*bq_ptr);
+            EpiloguePipeline{}(
+                c_block_window, c_block_tile, c_block_window, smem_ptr_0, aq_scale, bq_scale);
+        }
     }
 
     // For persistent kernels
diff --git a/test/ck_tile/CMakeLists.txt b/test/ck_tile/CMakeLists.txt
index 5fa6918c10..d58c80377a 100644
--- a/test/ck_tile/CMakeLists.txt
+++ b/test/ck_tile/CMakeLists.txt
@@ -5,6 +5,7 @@ add_subdirectory(batched_gemm)
 add_subdirectory(grouped_gemm)
 add_subdirectory(grouped_gemm_preshuffle)
 add_subdirectory(grouped_gemm_multi_d)
+add_subdirectory(grouped_gemm_quant)
 add_subdirectory(gemm_multi_d)
 add_subdirectory(gemm_multi_abd)
 add_subdirectory(gemm_streamk)
diff --git a/test/ck_tile/grouped_gemm_quant/CMakeLists.txt b/test/ck_tile/grouped_gemm_quant/CMakeLists.txt
new file mode 100644
index 0000000000..fddd8b69b2
--- /dev/null
+++ b/test/ck_tile/grouped_gemm_quant/CMakeLists.txt
@@ -0,0 +1,10 @@
+set(EXAMPLE_GEMM_COMPILE_OPTIONS)
+if(CK_USE_OCP_FP8)
+    list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
+endif()
+
+if(GPU_TARGETS MATCHES "gfx94|gfx95")
+    add_gtest_executable(test_ck_tile_grouped_gemm_quant test_grouped_gemm_quant.cpp)
+    target_compile_options(test_ck_tile_grouped_gemm_quant PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
+endif()
+
diff --git a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant.cpp b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant.cpp
new file mode 100644
index 0000000000..acdc9f4400
--- /dev/null
+++ b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant.cpp
@@ -0,0 +1,49 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
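+
+// Typed test suite for the quantized grouped GEMM kernel: each tuple below fixes
+// the layouts, the element types, and the quantization mode (RowColQuant or TensorQuant).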
+
+#include <tuple>
+
+#include "gtest/gtest.h"
+
+#include "ck_tile/host.hpp"
+#include "test_grouped_gemm_util_quant.hpp"
+
+using F16 = ck_tile::half_t;
+using F32 = float;
+using FP8 = ck_tile::fp8_t;
+using BF8 = ck_tile::bf8_t;
+using Row = ck_tile::tensor_layout::gemm::RowMajor;
+using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
+using True = ck_tile::bool_constant<true>;
+using False = ck_tile::bool_constant<false>;
+using RowColQuant = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::RowColQuant>;
+using TensorQuant = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::TensorQuant>;
+
+// clang-format off
+using KernelTypes = ::testing::Types<
+    // ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType
+    std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant>,
+    std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant>,
+    std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant>,
+    std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant>,
+
+    std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant>,
+    std::tuple< Col, Col, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant>,
+    std::tuple< Row, Row, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant>,
+    std::tuple< Col, Row, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant>,
+
+    std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant>,
+    std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant>,
+    std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant>,
+    std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant>,
+
+    std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant>,
+    std::tuple< Col, Col, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant>,
+    std::tuple< Row, Row, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant>,
+    std::tuple< Col, Row, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant>
+    >;
+// clang-format on
+
+TYPED_TEST_SUITE(TestCkTileGroupedGemmQuant, KernelTypes);
+
+#include "test_grouped_gemm_quant_ut_cases.inc"
diff --git a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_ut_cases.inc b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_ut_cases.inc
new file mode 100644
index 0000000000..cef9c40b13
--- /dev/null
+++ b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_ut_cases.inc
@@ -0,0 +1,28 @@
+#pragma once
+
+TYPED_TEST(TestCkTileGroupedGemmQuant, Basic)
+{
+    const int group_count = 8;
+    std::vector<int> Ms;
+    std::vector<int> Ns;
+    std::vector<int> Ks;
+    std::vector<int> stride_As;
+    std::vector<int> stride_Bs;
+    std::vector<int> stride_Cs;
+    std::vector<int> stride_AQs;
+    std::vector<int> stride_BQs;
+    for(int i = 0; i < group_count; i++)
+    {
+        Ms.push_back(256 + 256 * i);
+        Ns.push_back(256 + 512 * i);
+        Ks.push_back(512 + 128 * i);
+
+        stride_As.push_back(0);
+        stride_Bs.push_back(0);
+        stride_Cs.push_back(0);
+        stride_AQs.push_back(0);
+        stride_BQs.push_back(0);
+    }
+
+    this->Run(Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs, stride_AQs, stride_BQs, group_count);
+}
diff --git a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_util_quant.hpp b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_util_quant.hpp
new file mode 100644
index 0000000000..101e444f75
--- /dev/null
+++ b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_util_quant.hpp
@@ -0,0 +1,441 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
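+
+// Test fixture for the quantized grouped GEMM kernel: builds per-group host tensors
+// and quant scales, launches the persistent kernel, and verifies against a CPU reference.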
+#pragma once
+
+#include <memory>
+#include <vector>
+
+#include "ck_tile/core.hpp"
+#include "ck_tile/host.hpp"
+#include "ck_tile/host/kernel_launch.hpp"
+#include "ck_tile/ops/epilogue.hpp"
+#include "ck_tile/ops/gemm.hpp"
+#include "ck_tile/ops/gemm_quant.hpp"
+#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
+
+template <typename Tuple>
+class TestCkTileGroupedGemmQuant : public ::testing::Test
+{
+    protected:
+    using ALayout     = std::tuple_element_t<0, Tuple>;
+    using BLayout     = std::tuple_element_t<1, Tuple>;
+    using CLayout     = std::tuple_element_t<2, Tuple>;
+    using ADataType   = std::tuple_element_t<3, Tuple>;
+    using AQDataType  = std::tuple_element_t<4, Tuple>;
+    using BDataType   = std::tuple_element_t<5, Tuple>;
+    using BQDataType  = std::tuple_element_t<6, Tuple>;
+    using AccDataType = std::tuple_element_t<7, Tuple>;
+    using CDataType   = std::tuple_element_t<8, Tuple>;
+    static constexpr auto QuantType = std::tuple_element_t<9, Tuple>::value;
+    using DsLayout   = ck_tile::tuple<>;
+    using DsDataType = ck_tile::tuple<>;
+    using Row        = ck_tile::tensor_layout::gemm::RowMajor;
+    using Col        = ck_tile::tensor_layout::gemm::ColumnMajor;
+    using AQLayout   = Row;
+    using BQLayout   = Col;
+    static constexpr bool Persistent = true;
+
+    struct GroupedGemKernelParam_Mfma
+    {
+        static const bool kPadM = false;
+        static const bool kPadN = false;
+        static const bool kPadK = false;
+
+        static const int kBlockPerCu = 1;
+        static const ck_tile::index_t M_Tile = 256;
+        static const ck_tile::index_t N_Tile = 256;
+        static const ck_tile::index_t K_Tile = 128;
+
+        static const ck_tile::index_t M_Warp = 2;
+        static const ck_tile::index_t N_Warp = 2;
+        static const ck_tile::index_t K_Warp = 1;
+
+        static const ck_tile::index_t M_Warp_Tile = 32;
+        static const ck_tile::index_t N_Warp_Tile = 32;
+        static const ck_tile::index_t K_Warp_Tile = 16;
+    };
+
+    using grouped_gemm_kargs = ck_tile::QuantGroupedGemmHostArgs;
+
+    std::size_t get_workspace_size(const std::vector<grouped_gemm_kargs>& gemm_descs)
+    {
+        return gemm_descs.size() * sizeof(ck_tile::QuantGemmTransKernelArg);
+    }
+
+    template <typename ADataType,
+              typename AQDataType,
+              typename BDataType,
+              typename BQDataType,
+              typename AccDataType,
+              typename CDataType>
+    void invoke_grouped_gemm_persistent(const ck_tile::stream_config& s,
+                                        const ck_tile::index_t num_groups,
+                                        void* kargs_ptr)
+    {
+        constexpr bool TransposeC       = false;
+        constexpr bool DoubleSmemBuffer = false;
+
+        constexpr int kBlockPerCu = 1;
+        constexpr ck_tile::index_t TileParitionerGroupNum = 8;
+        constexpr ck_tile::index_t TileParitionerM01      = 4;
+
+        using GemmShape =
+            ck_tile::TileGemmShape<ck_tile::sequence<GroupedGemKernelParam_Mfma::M_Tile,
+                                                     GroupedGemKernelParam_Mfma::N_Tile,
+                                                     GroupedGemKernelParam_Mfma::K_Tile>,
+                                   ck_tile::sequence<GroupedGemKernelParam_Mfma::M_Warp,
+                                                     GroupedGemKernelParam_Mfma::N_Warp,
+                                                     GroupedGemKernelParam_Mfma::K_Warp>,
+                                   ck_tile::sequence<GroupedGemKernelParam_Mfma::M_Warp_Tile,
+                                                     GroupedGemKernelParam_Mfma::N_Warp_Tile,
+                                                     GroupedGemKernelParam_Mfma::K_Warp_Tile>>;
+        using TilePartitioner = ck_tile::
+            GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
+
+        using GemmUniversalTraits = ck_tile::TileGemmQuantTraits;
+
+        const auto Run = [&](const auto memory_operation_) {
+            constexpr auto scheduler        = ck_tile::GemmPipelineScheduler::Intrawave;
+            constexpr auto memory_operation = memory_operation_.value;
+            constexpr bool transpose_c      = false;
+
+            // We create the GEMM pipeline without specifying hotloop or tailnumber.
+            // These are automatically run inside the kernel based on the given input data.
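+            // Note: the pipeline problem below carries the QuantType from the test tuple
+            // through to the kernel, whose epilogue dispatches on it at compile time
+            // (per-row/per-column scale windows vs. one scalar scale per tensor).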
+            using QuantGemmProblem = ck_tile::GemmRowColTensorQuantPipelineProblem;
+
+            using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<QuantGemmProblem>;
+            using GemmEpilogue = ck_tile::CShuffleEpilogue<
+                ck_tile::CShuffleEpilogueProblem>;
+            using Kernel =
+                ck_tile::QuantGroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
+
+            const dim3 blocks = Kernel::BlockSize();
+            const dim3 grids  = Kernel::MaxOccupancyGridSize(s);
+
+            if(s.log_level_ > 0)
+            {
+                std::cout << "Launching kernel: " << Kernel::GetName() << " with args:"
+                          << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
+                          << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
+                          << "}" << std::endl;
+            }
+
+            ck_tile::launch_kernel(s,
+                                   ck_tile::make_kernel(
+                                       Kernel{},
+                                       grids,
+                                       blocks,
+                                       0,
+                                       ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
+                                       num_groups));
+        };
+
+        Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
+                                       ck_tile::memory_operation_enum::set>{});
+    }
+
+    template <typename Layout>
+    static constexpr inline auto is_row_major(Layout layout_)
+    {
+        return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
+                                                     ck_tile::tensor_layout::gemm::RowMajor>>{};
+    }
+
+    auto calculate_rtol_atol(const ck_tile::index_t K,
+                             const ck_tile::index_t kbatch,
+                             const float max_accumulated_value)
+    {
+        using ComputeType =
+            std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
+        // Calculate thresholds
+        const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
+            ck_tile::integer_divide_ceil(K, kbatch));
+        const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
+            max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
+        // Calculate error due to split_k accumulation
+        const auto rtol_split_k =
+            ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
+        const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
+            max_accumulated_value, kbatch);
+        // Use higher threshold
+        return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
+    }
+
+    public:
+    void Run(const std::vector<int>& Ms,
+             const std::vector<int>& Ns,
+             const std::vector<int>& Ks,
+             std::vector<int>& stride_As,
+             std::vector<int>& stride_Bs,
+             std::vector<int>& stride_Cs,
+             std::vector<int>& stride_AQs,
+             std::vector<int>& stride_BQs,
+             const int group_count = 16)
+    {
+        ck_tile::index_t AQK, BQK;
+        using namespace ck_tile::literals;
+
+        std::vector<ck_tile::HostTensor<ADataType>> a_m_k_tensors;
+        std::vector<ck_tile::HostTensor<BDataType>> b_k_n_tensors;
+        std::vector<ck_tile::HostTensor<CDataType>> c_m_n_tensors;
+        std::vector<ck_tile::HostTensor<AQDataType>> aq_tensors;
+        std::vector<ck_tile::HostTensor<BQDataType>> bq_tensors;
+        a_m_k_tensors.reserve(group_count);
+        b_k_n_tensors.reserve(group_count);
+        c_m_n_tensors.reserve(group_count);
+        aq_tensors.reserve(group_count);
+        bq_tensors.reserve(group_count);
+
+        std::vector<std::unique_ptr<ck_tile::DeviceMem>> a_m_k_dev_buf;
+        std::vector<std::unique_ptr<ck_tile::DeviceMem>> b_k_n_dev_buf;
+        std::vector<std::unique_ptr<ck_tile::DeviceMem>> c_m_n_dev_buf;
+        std::vector<std::unique_ptr<ck_tile::DeviceMem>> aq_dev_buf;
+        std::vector<std::unique_ptr<ck_tile::DeviceMem>> bq_dev_buf;
+
+        a_m_k_dev_buf.reserve(group_count);
+        b_k_n_dev_buf.reserve(group_count);
+        c_m_n_dev_buf.reserve(group_count);
+        aq_dev_buf.reserve(group_count);
+        bq_dev_buf.reserve(group_count);
+
+        std::vector<grouped_gemm_kargs> gemm_descs;
+        gemm_descs.reserve(group_count);
+
+        for(int i = 0; i < group_count; ++i)
+        {
+            const ck_tile::index_t M = Ms[i];
+            const ck_tile::index_t N = Ns[i];
+            const ck_tile::index_t K = Ks[i];
+            if constexpr(QuantType == ck_tile::QuantType::RowColQuant ||
+                         QuantType == ck_tile::QuantType::TensorQuant)
+            {
+                AQK = 1; // Row quantization: tensor shape [M, 1] or [1]
+                BQK = 1; // Column quantization: tensor shape [1, N] or [1]
+            }
+
+            stride_As[i] = ck_tile::get_default_stride(M, K, stride_As[i], is_row_major(ALayout{}));
+            stride_Bs[i] = ck_tile::get_default_stride(K, N, stride_Bs[i], is_row_major(BLayout{}));
+            stride_Cs[i] = ck_tile::get_default_stride(M, N, stride_Cs[i], is_row_major(CLayout{}));
+            if constexpr(QuantType == ck_tile::QuantType::RowColQuant)
+            {
+                stride_AQs[i] =
+                    ck_tile::get_default_stride(M, 1, stride_AQs[i], is_row_major(AQLayout{}));
+                stride_BQs[i] =
+                    ck_tile::get_default_stride(1, N, stride_BQs[i], is_row_major(BQLayout{}));
+            }
+            else if constexpr(QuantType == ck_tile::QuantType::TensorQuant)
+            {
+                stride_AQs[i] = 1; // Tensor quantization: tensor shape [1]
+                stride_BQs[i] = 1; // Tensor quantization: tensor shape [1]
+            }
+
+            a_m_k_tensors.push_back(ck_tile::HostTensor<ADataType>(
+                ck_tile::host_tensor_descriptor(M, K, stride_As[i], is_row_major(ALayout{}))));
+            b_k_n_tensors.push_back(ck_tile::HostTensor<BDataType>(
+                ck_tile::host_tensor_descriptor(K, N, stride_Bs[i], is_row_major(BLayout{}))));
+            c_m_n_tensors.push_back(ck_tile::HostTensor<CDataType>(
+                ck_tile::host_tensor_descriptor(M, N, stride_Cs[i], is_row_major(CLayout{}))));
+            if constexpr(QuantType == ck_tile::QuantType::RowColQuant)
+            {
+                aq_tensors.push_back(
+                    ck_tile::HostTensor<AQDataType>(ck_tile::host_tensor_descriptor(
+                        M, AQK, stride_AQs[i], is_row_major(AQLayout{}))));
+                bq_tensors.push_back(
+                    ck_tile::HostTensor<BQDataType>(ck_tile::host_tensor_descriptor(
+                        BQK, N, stride_BQs[i], is_row_major(BQLayout{}))));
+            }
+            else if constexpr(QuantType == ck_tile::QuantType::TensorQuant)
+            {
+                aq_tensors.push_back(
+                    ck_tile::HostTensor<AQDataType>(ck_tile::host_tensor_descriptor(
+                        1, 1, stride_AQs[i], is_row_major(AQLayout{}))));
+                bq_tensors.push_back(
+                    ck_tile::HostTensor<BQDataType>(ck_tile::host_tensor_descriptor(
+                        1, 1, stride_BQs[i], is_row_major(BQLayout{}))));
+            }
+
+            std::cout << "gemm[" << i << "]"
+                      << " a_m_k: " << a_m_k_tensors[i].mDesc
+                      << " b_k_n: " << b_k_n_tensors[i].mDesc
+                      << " c_m_n: " << c_m_n_tensors[i].mDesc << " aq: " << aq_tensors[i].mDesc
+                      << " bq: " << bq_tensors[i].mDesc << std::endl;
+
+            ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k_tensors[i]);
+            ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n_tensors[i]);
+            ck_tile::FillUniformDistribution<AQDataType>{-1.f, 1.f}(aq_tensors[i]);
+            ck_tile::FillUniformDistribution<BQDataType>{-1.f, 1.f}(bq_tensors[i]);
+
+            a_m_k_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
+                a_m_k_tensors[i].get_element_space_size_in_bytes()));
+            b_k_n_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
+                b_k_n_tensors[i].get_element_space_size_in_bytes()));
+            c_m_n_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
+                c_m_n_tensors[i].get_element_space_size_in_bytes()));
+            aq_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
+                aq_tensors[i].get_element_space_size_in_bytes()));
+            bq_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
+                bq_tensors[i].get_element_space_size_in_bytes()));
+
+            a_m_k_dev_buf[i]->ToDevice(a_m_k_tensors[i].data());
+            b_k_n_dev_buf[i]->ToDevice(b_k_n_tensors[i].data());
+            aq_dev_buf[i]->ToDevice(aq_tensors[i].data());
+            bq_dev_buf[i]->ToDevice(bq_tensors[i].data());
+            c_m_n_dev_buf[i]->SetZero();
+            c_m_n_tensors[i].SetZero();
+
+            const void* p_a  = a_m_k_dev_buf[i]->GetDeviceBuffer();
+            const void* p_b  = b_k_n_dev_buf[i]->GetDeviceBuffer();
+            void* p_c        = c_m_n_dev_buf[i]->GetDeviceBuffer();
+            const void* p_aq = aq_dev_buf[i]->GetDeviceBuffer();
+            const void* p_bq = bq_dev_buf[i]->GetDeviceBuffer();
+
+            gemm_descs.push_back({p_a,
+                                  p_b,
+                                  p_c,
+                                  p_aq,
+                                  p_bq,
+                                  1,
+                                  M,
+                                  N,
+                                  K,
+                                  AQK,
+                                  BQK,
+                                  stride_As[i],
+                                  stride_Bs[i],
+                                  stride_Cs[i],
+                                  stride_AQs[i],
+                                  stride_BQs[i]});
+        }
+
+        ck_tile::DeviceMem gemm_workspace;
+        gemm_workspace.Realloc(get_workspace_size(gemm_descs));
+
+        if constexpr(Persistent)
+        {
+            // Generate kernel arguments
+            std::vector<ck_tile::QuantGemmTransKernelArg> kargs;
+            void* kargs_ptr = gemm_workspace.GetDeviceBuffer();
+            assert(gemm_descs[0].k_batch == 1);
+            for(const auto& arg : gemm_descs)
+            {
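+                // Field order must match ck_tile::QuantGroupedGemmKernelArgs: the packed
+                // array is copied to device memory below and indexed per group by the kernel.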
+                kargs.emplace_back(ck_tile::QuantGroupedGemmKernelArgs{arg.a_ptr,
+                                                                       arg.b_ptr,
+                                                                       arg.aq_ptr,
+                                                                       arg.bq_ptr,
+                                                                       arg.e_ptr,
+                                                                       arg.M,
+                                                                       arg.N,
+                                                                       arg.K,
+                                                                       arg.QK_A,
+                                                                       arg.QK_B,
+                                                                       arg.stride_A,
+                                                                       arg.stride_B,
+                                                                       arg.stride_E,
+                                                                       arg.stride_AQ,
+                                                                       arg.stride_BQ,
+                                                                       arg.k_batch});
+            }
+            const auto stream = ck_tile::stream_config{nullptr, false, 1};
+            ck_tile::hip_check_error(
+                hipMemcpyWithStream(kargs_ptr,
+                                    kargs.data(),
+                                    kargs.size() * sizeof(ck_tile::QuantGemmTransKernelArg),
+                                    hipMemcpyHostToDevice,
+                                    stream.stream_id_));
+
+            invoke_grouped_gemm_persistent<ADataType,
+                                           AQDataType,
+                                           BDataType,
+                                           BQDataType,
+                                           AccDataType,
+                                           CDataType>(stream, group_count, kargs_ptr);
+        }
+        else
+        {
+            GTEST_FAIL() << "Non-persistent kernel not implemented yet";
+        }
+
+        // Copy results back to host for validation
+        for(int i = 0; i < group_count; i++)
+        {
+            c_m_n_dev_buf[i]->FromDevice(c_m_n_tensors[i].data());
+        }
+
+        bool pass{true};
+        for(int i = 0; i < group_count; ++i)
+        {
+            ck_tile::HostTensor<CDataType> c_m_n_host_ref(ck_tile::host_tensor_descriptor(
+                Ms[i], Ns[i], stride_Cs[i], is_row_major(CLayout{})));
+            c_m_n_host_ref.SetZero();
+            if constexpr(QuantType == ck_tile::QuantType::RowColQuant)
+            {
+                ck_tile::reference_gemm_rowcol_quant(a_m_k_tensors[i],
+                                                     aq_tensors[i],
+                                                     b_k_n_tensors[i],
+                                                     bq_tensors[i],
+                                                     c_m_n_host_ref);
+            }
+            else if constexpr(QuantType == ck_tile::QuantType::TensorQuant)
+            {
+                ck_tile::reference_gemm_tensor_quant(a_m_k_tensors[i],
+                                                     aq_tensors[i],
+                                                     b_k_n_tensors[i],
+                                                     bq_tensors[i],
+                                                     c_m_n_host_ref);
+            }
+
+            const float max_accumulated_value =
+                *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
+            const auto rtol_atol = calculate_rtol_atol(Ks[i], 1, max_accumulated_value);
+            pass &= ck_tile::check_err(c_m_n_tensors[i],
+                                       c_m_n_host_ref,
+                                       "Error: Incorrect results!",
+                                       rtol_atol.at(ck_tile::number<0>{}),
+                                       rtol_atol.at(ck_tile::number<1>{}));
+            std::cout << "gemm[" << i
+                      << "] Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
+                      << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
+                      << std::endl;
+        }
+        std::cout << "The CPU verification result is: " << (pass ? "correct" : "fail")
+                  << std::endl;
+
+        EXPECT_TRUE(pass);
+    }
+};