diff --git a/example/ck_tile/17_grouped_gemm/abquant_grouped_gemm.cpp b/example/ck_tile/17_grouped_gemm/abquant_grouped_gemm.cpp index 3172450e49..210dd05c2b 100644 --- a/example/ck_tile/17_grouped_gemm/abquant_grouped_gemm.cpp +++ b/example/ck_tile/17_grouped_gemm/abquant_grouped_gemm.cpp @@ -19,6 +19,7 @@ #include "ck_tile/host.hpp" #include "abquant_grouped_gemm.hpp" +// Non-persistent grouped gemm for ABQuant template + ck_tile::QuantType QuantMode> +float grouped_gemm_abquant(const std::vector& gemm_descs, + const ck_tile::stream_config& s, + void* kargs_ptr) +{ + constexpr ck_tile::index_t TileParitionerGroupNum = 8; + constexpr ck_tile::index_t TileParitionerM01 = 4; + + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile:: + sequence>; + using TilePartitioner = ck_tile:: + GemmSpatiallyLocalTilePartitioner; + + using Traits = ck_tile::TileGemmTraits; + using GemmUniversalTraits = ck_tile::TileGemmQuantTraits; + using GemmPipelineProblem = + ck_tile::GemmPipelineProblem; + + using BaseGemmPipeline = + GemmQuantConfig::template BaseGemmPipeline; + + const ck_tile::index_t k_grain = gemm_descs[0].k_batch * GemmConfig::K_Tile; + const ck_tile::index_t K_split = (gemm_descs[0].K + k_grain - 1) / k_grain * GemmConfig::K_Tile; + + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + float ave_time{0}; + + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = GemmConfig::Scheduler; + constexpr auto memory_operation = ck_tile::memory_operation_enum::set; + + using QuantGemmProblem = ck_tile::GemmABQuantPipelineProblem; + + using GemmPipeline = + GemmQuantConfig::template GemmPipeline; + + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem, + AccDataType, + CDataType, + ck_tile::tuple<>, + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + GemmConfig::M_Warp, + GemmConfig::N_Warp, + GemmConfig::M_Warp_Tile, + GemmConfig::N_Warp_Tile, + GemmConfig::K_Warp_Tile, + QuantGemmProblem::TransposeC, + memory_operation>>; + + using Kernel = ck_tile::QuantGroupedGemmKernel; + auto kargs = Kernel::MakeKargs(gemm_descs); + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Kernel arguments not supported!"); + } + + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(gemm_descs); + + HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, + kargs.data(), + get_workspace_size(gemm_descs), + hipMemcpyHostToDevice, + s.stream_id_)); + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; + } + + return ave_time = ck_tile::launch_kernel( + s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + gemm_descs.size())); + }; + + return ave_time = BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num); +} + +// Persistent grouped gemm tileloop for ABQuant +template float grouped_gemm_tileloop(const ck_tile::stream_config& s, const ck_tile::index_t num_groups, void* kargs_ptr) diff --git a/example/ck_tile/17_grouped_gemm/abquant_grouped_gemm.hpp b/example/ck_tile/17_grouped_gemm/abquant_grouped_gemm.hpp index 658ed5737a..442c5fb32c 100644 --- a/example/ck_tile/17_grouped_gemm/abquant_grouped_gemm.hpp +++ b/example/ck_tile/17_grouped_gemm/abquant_grouped_gemm.hpp @@ -143,6 +143,27 @@ inline std::size_t get_workspace_size(const std::vector& gem { return gemm_descs.size() * sizeof(ck_tile::QuantGemmTransKernelArg); } + +// Forward declaration of the non-persistent version +template +float grouped_gemm_abquant(const std::vector& gemm_descs, + const ck_tile::stream_config& s, + void* kargs_ptr); + // Forward declaration of the tileloop version for persistent kernels template -float grouped_gemm_abquant_tileloop(const ck_tile::stream_config& s, - const ck_tile::index_t num_groups, - void* kargs_ptr); + typename BQuantGroupSize, + ck_tile::QuantType QuantMode = ck_tile::QuantType::ABQuantGrouped> +float grouped_gemm_tileloop(const ck_tile::stream_config& s, + const ck_tile::index_t num_groups, + void* kargs_ptr); diff --git a/example/ck_tile/17_grouped_gemm/abquant_run_grouped_gemm_example.inc b/example/ck_tile/17_grouped_gemm/abquant_run_grouped_gemm_example.inc index b15ef4dac8..df2da03d0e 100644 --- a/example/ck_tile/17_grouped_gemm/abquant_run_grouped_gemm_example.inc +++ b/example/ck_tile/17_grouped_gemm/abquant_run_grouped_gemm_example.inc @@ -74,54 +74,85 @@ float invoke_abquant_gemm(int n_warmup, float ave_time = 0; - // Persistent TileLoop kernel only - std::vector kargs; - void* kargs_ptr = gemm_workspace.GetDeviceBuffer(); - if(args[0].k_batch != 1) + if constexpr(!GemmConfig::Persistent) { - throw std::runtime_error("Split-K not supported yet for persistent kernel"); + ave_time = + grouped_gemm_abquant(args, + ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}, + gemm_workspace.GetDeviceBuffer()); } + else + { + // NOTE: With the persistent TileLoop kernel, we do not necessarily need to have + // the gemm problems known on the host. Instead, we can just pass the pointer + // to the kernel and let the workgroups figure out which tiles to work on. + // This is useful when the gemm problems are generated dynamically. + // In this example however, we generate the `kargs` using the known gemm_descs, + // and copy the gemm descriptions to the device memory. + // The contents of the memory pointed to by `kargs_ptr` pointer could be + // written by e.g. another kernel from earlier stage. + std::vector kargs; + void* kargs_ptr = gemm_workspace.GetDeviceBuffer(); + if(args[0].k_batch != 1) + { + throw std::runtime_error("Split-K not supported yet for persistent kernel"); + } - for(const auto& arg : args) - { - kargs.emplace_back(ck_tile::QuantGroupedGemmKernelArgs{arg.a_ptr, - arg.b_ptr, - arg.aq_ptr, - arg.bq_ptr, - arg.e_ptr, - arg.M, - arg.N, - arg.K, - arg.QK_A, - arg.QK_B, - arg.stride_A, - arg.stride_B, - arg.stride_E, - arg.stride_AQ, - arg.stride_BQ, - arg.k_batch}); + for(const auto& arg : args) + { + kargs.emplace_back(ck_tile::QuantGroupedGemmKernelArgs{arg.a_ptr, + arg.b_ptr, + arg.aq_ptr, + arg.bq_ptr, + arg.e_ptr, + arg.M, + arg.N, + arg.K, + arg.QK_A, + arg.QK_B, + arg.stride_A, + arg.stride_B, + arg.stride_E, + arg.stride_AQ, + arg.stride_BQ, + arg.k_batch}); + } + const auto stream = ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}; + HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, + kargs.data(), + kargs.size() * sizeof(ck_tile::QuantGemmTransKernelArg), + hipMemcpyHostToDevice, + stream.stream_id_)); + ave_time = grouped_gemm_tileloop(stream, group_count, kargs_ptr); } - const auto stream = ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}; - HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, - kargs.data(), - kargs.size() * sizeof(ck_tile::QuantGemmTransKernelArg), - hipMemcpyHostToDevice, - stream.stream_id_)); - ave_time = grouped_gemm_tileloop(stream, group_count, kargs_ptr); std::string op_name = "ABQuant Grouped Gemm"; @@ -426,11 +457,10 @@ int run_abquant_grouped_gemm_example_with_layouts(int argc, return pass; } -template +template int run_abquant_grouped_gemm_example_prec_type(std::string a_layout, std::string b_layout, std::string c_layout, - [[maybe_unused]] bool persistent, int argc, char* argv[]) { @@ -447,9 +477,6 @@ int run_abquant_grouped_gemm_example_prec_type(std::string a_layout, using AQuantGroupSize = ck_tile::QuantGroupShape>; using BQuantGroupSize = ck_tile::QuantGroupShape>; - using GemmConfig = typename GemmQuantConfig:: - template GemmConfig; - // Support RCR, RRR, CRR layouts if(a_layout == "R" && b_layout == "C" && c_layout == "R") { @@ -496,6 +523,30 @@ int run_abquant_grouped_gemm_example_prec_type(std::string a_layout, } } +template +int run_abquant_gemm_example_persistency(std::string a_layout, + std::string b_layout, + std::string c_layout, + bool persistent, + int argc, + char* argv[]) +{ + if(persistent) + { + using GemmConfig = typename GemmQuantConfig:: + template GemmConfig; + return run_abquant_grouped_gemm_example_prec_type( + a_layout, b_layout, c_layout, argc, argv); + } + else + { + using GemmConfig = typename GemmQuantConfig:: + template GemmConfig; + return run_abquant_grouped_gemm_example_prec_type( + a_layout, b_layout, c_layout, argc, argv); + } +} + int run_abquant_grouped_gemm_example(int argc, char* argv[]) { auto [result, arg_parser] = create_args(argc, argv); @@ -508,7 +559,7 @@ int run_abquant_grouped_gemm_example(int argc, char* argv[]) const std::string b_layout = arg_parser.get_str("b_layout"); const std::string c_layout = arg_parser.get_str("c_layout"); const std::string data_type = arg_parser.get_str("prec"); - const bool persistent = arg_parser.get_bool("persistent"); + bool persistent = arg_parser.get_bool("persistent"); // Validate layout combinations if(!((a_layout == "R" && b_layout == "C" && c_layout == "R") || @@ -522,12 +573,12 @@ int run_abquant_grouped_gemm_example(int argc, char* argv[]) if(data_type == "fp8") { - return run_abquant_grouped_gemm_example_prec_type( + return run_abquant_gemm_example_persistency( a_layout, b_layout, c_layout, persistent, argc, argv); } else if(data_type == "bf8") { - return run_abquant_grouped_gemm_example_prec_type( + return run_abquant_gemm_example_persistency( a_layout, b_layout, c_layout, persistent, argc, argv); } else diff --git a/test/ck_tile/grouped_gemm_quant/CMakeLists.txt b/test/ck_tile/grouped_gemm_quant/CMakeLists.txt index 55f09726cc..b8d5f12d0c 100644 --- a/test/ck_tile/grouped_gemm_quant/CMakeLists.txt +++ b/test/ck_tile/grouped_gemm_quant/CMakeLists.txt @@ -17,7 +17,10 @@ endif() # add_gtest_executable(test_ck_tile_grouped_gemm_quant_aquant test_grouped_gemm_quant_aquant.cpp) # target_compile_options(test_ck_tile_grouped_gemm_quant_aquant PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) -# add_gtest_executable(test_ck_tile_grouped_gemm_quant_bquant test_grouped_gemm_quant_bquant.cpp) -# target_compile_options(test_ck_tile_grouped_gemm_quant_bquant PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) -# endif() + add_gtest_executable(test_ck_tile_grouped_gemm_quant_bquant test_grouped_gemm_quant_bquant.cpp) + target_compile_options(test_ck_tile_grouped_gemm_quant_bquant PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + + add_gtest_executable(test_ck_tile_grouped_gemm_quant_abquant test_grouped_gemm_quant_abquant.cpp) + target_compile_options(test_ck_tile_grouped_gemm_quant_abquant PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) +endif() diff --git a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_abquant.cpp b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_abquant.cpp new file mode 100644 index 0000000000..1a2d8d0b35 --- /dev/null +++ b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_abquant.cpp @@ -0,0 +1,39 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include + +#include "gtest/gtest.h" + +#include "ck_tile/host.hpp" +#include "test_grouped_gemm_util_quant.hpp" + +using F16 = ck_tile::half_t; +using F32 = float; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; +using True = ck_tile::bool_constant; +using False = ck_tile::bool_constant; +using ABQuant = std::integral_constant; + +// clang-format off +using KernelTypes_ABQuant = ::testing::Types< + // ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType, PreshuffleB, Persistent, TransposeC + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, ABQuant, False, True, False>, + std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, ABQuant, False, True, False>, + std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, ABQuant, False, True, False>, + std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, ABQuant, False, True, False>, + + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, ABQuant, False, False, False>, + std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, ABQuant, False, False, False> + >; +// clang-format on + +TYPED_TEST_SUITE(TestCkTileGroupedGemmQuant_ABQuant, KernelTypes_ABQuant); + +#define TEST_CLASS_NAME TestCkTileGroupedGemmQuant_ABQuant +#include "test_grouped_gemm_quant_ut_cases.inc" +#undef TEST_CLASS_NAME + diff --git a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_util_quant.hpp b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_util_quant.hpp index 9941066c3e..838c59d51c 100644 --- a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_util_quant.hpp +++ b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_util_quant.hpp @@ -85,7 +85,8 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test constexpr ck_tile::index_t TileParitionerGroupNum = 8; constexpr ck_tile::index_t TileParitionerM01 = 4; constexpr bool UseGroupedQuant = QuantType == ck_tile::QuantType::AQuantGrouped || - QuantType == ck_tile::QuantType::BQuantGrouped; + QuantType == ck_tile::QuantType::BQuantGrouped || + QuantType == ck_tile::QuantType::ABQuantGrouped; using QuantGroupSize = ck_tile::QuantGroupShape>; @@ -168,17 +169,32 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test scheduler, has_hot_loop_v, tail_number_v>, - ck_tile::GemmBQuantPipelineProblem>, + std::conditional_t, + ck_tile::GemmABQuantPipelineProblem>>, ck_tile::GemmRowColTensorQuantPipelineProblem, - std::conditional_t, - ck_tile::BQuantGemmPipelineAgBgCrCompV3>>, + std::conditional_t< + QuantType == ck_tile::QuantType::BQuantGrouped, + std::conditional_t, + ck_tile::BQuantGemmPipelineAgBgCrCompV3>, + ck_tile::ABQuantGemmPipelineAgBgCrCompV3>>, ck_tile::GemmPipelineAgBgCrCompV3>; using GemmEpilogue = ck_tile::CShuffleEpilogue< @@ -309,7 +328,8 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test // These are automatically run inside the kernel based on the given input data. constexpr bool UseGroupedQuant = QuantType == ck_tile::QuantType::AQuantGrouped || - QuantType == ck_tile::QuantType::BQuantGrouped; + QuantType == ck_tile::QuantType::BQuantGrouped || + QuantType == ck_tile::QuantType::ABQuantGrouped; using QuantGemmProblem = std::conditional_t< UseGroupedQuant, std::conditional_t, - ck_tile::GemmBQuantPipelineProblem>, + std::conditional_t, + ck_tile::GemmABQuantPipelineProblem>>, ck_tile::GemmRowColTensorQuantPipelineProblem, - std::conditional_t, - ck_tile::BQuantGemmPipelineAgBgCrCompV3>>, + std::conditional_t< + QuantType == ck_tile::QuantType::BQuantGrouped, + std::conditional_t, + ck_tile::BQuantGemmPipelineAgBgCrCompV3>, + ck_tile::ABQuantGemmPipelineAgBgCrCompV3>>, ck_tile::GemmPipelineAgBgCrCompV3>; using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem( ck_tile::host_tensor_descriptor(M, K, stride_As[i], is_row_major(ALayout{})))); @@ -565,6 +616,15 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test ck_tile::HostTensor(ck_tile::host_tensor_descriptor( BQK, N, stride_BQs[i], is_row_major(BQLayout())))); } + else if constexpr(QuantType == ck_tile::QuantType::ABQuantGrouped) + { + aq_tensors.push_back( + ck_tile::HostTensor(ck_tile::host_tensor_descriptor( + M, AQK, stride_AQs[i], is_row_major(AQLayout{})))); + bq_tensors.push_back( + ck_tile::HostTensor(ck_tile::host_tensor_descriptor( + BQK, N, stride_BQs[i], is_row_major(BQLayout())))); + } std::cout << "gemm[" << i << "]" << " a_m_k: " << a_m_k_tensors[i].mDesc << " b_k_n: " << b_k_n_tensors[i].mDesc @@ -750,6 +810,18 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test false>( a_m_k_tensors[i], bq_tensors[i], b_k_n_tensors[i], c_m_n_host_ref); } + else if constexpr(QuantType == ck_tile::QuantType::ABQuantGrouped) + { + ck_tile::reference_gemm_abquant( + a_m_k_tensors[i], aq_tensors[i], b_k_n_tensors[i], bq_tensors[i], c_m_n_host_ref); + } const float max_accumulated_value = *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); @@ -782,3 +854,6 @@ using TestCkTileGroupedGemmQuant_AQuant = TestCkTileGroupedGemmQuant; template using TestCkTileGroupedGemmQuant_BQuant = TestCkTileGroupedGemmQuant; + +template +using TestCkTileGroupedGemmQuant_ABQuant = TestCkTileGroupedGemmQuant;