diff --git a/example/ck_tile/17_grouped_gemm/CMakeLists.txt b/example/ck_tile/17_grouped_gemm/CMakeLists.txt index bf52e0c3f4..9b51af22fe 100644 --- a/example/ck_tile/17_grouped_gemm/CMakeLists.txt +++ b/example/ck_tile/17_grouped_gemm/CMakeLists.txt @@ -3,7 +3,18 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95") add_executable(tile_example_grouped_gemm grouped_gemm.cpp) - add_executable(tile_example_quant_grouped_gemm quant_grouped_gemm.cpp) + add_executable(tile_example_quant_grouped_gemm + quant_grouped_gemm.cpp + quant_grouped_gemm_fp8_aquant.cpp + quant_grouped_gemm_fp8_bquant.cpp + quant_grouped_gemm_fp8_rowcol.cpp + quant_grouped_gemm_fp8_tensor.cpp + quant_grouped_gemm_bf8_aquant.cpp + quant_grouped_gemm_bf8_bquant.cpp + quant_grouped_gemm_bf8_rowcol.cpp + quant_grouped_gemm_bf8_tensor.cpp + ) + add_executable(tile_example_grouped_gemm_preshuffle grouped_gemm_preshuffle.cpp) add_executable(tile_example_grouped_gemm_multi_d grouped_gemm_multi_d.cpp) set(EXAMPLE_GEMM_COMPILE_OPTIONS) diff --git a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp index d3b75ac72f..ff66f26d61 100644 --- a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp +++ b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp @@ -3,332 +3,128 @@ #include -#include -#include -#include -#include -#include -#include -#include +#include "quant_run_grouped_gemm_example.hpp" -#include "ck_tile/core.hpp" -#include "ck_tile/ops/epilogue.hpp" -#include "ck_tile/ops/gemm.hpp" -#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp" -#include "ck_tile/ops/gemm_quant.hpp" -#include "ck_tile/host.hpp" -#include "quant_grouped_gemm.hpp" +extern template int run_gemm_example_persistency( + const ck_tile::ArgParser&, std::string, std::string, bool); +extern template int run_gemm_example_persistency( + const ck_tile::ArgParser&, std::string, std::string, bool); +extern template int run_gemm_example_persistency( + const ck_tile::ArgParser&, std::string, std::string, bool); +extern template int run_gemm_example_persistency( + const ck_tile::ArgParser&, std::string, std::string, bool); +extern template int run_gemm_example_persistency( + const ck_tile::ArgParser&, std::string, std::string, bool); +extern template int run_gemm_example_persistency( + const ck_tile::ArgParser&, std::string, std::string, bool); +extern template int run_gemm_example_persistency( + const ck_tile::ArgParser&, std::string, std::string, bool); +extern template int run_gemm_example_persistency( + const ck_tile::ArgParser&, std::string, std::string, bool); -template -float grouped_gemm(const std::vector& gemm_descs, - const ck_tile::stream_config& s, - void* kargs_ptr) +auto create_args(int argc, char* argv[]) { - constexpr ck_tile::index_t TileParitionerGroupNum = 8; - constexpr ck_tile::index_t TileParitionerM01 = 4; + ck_tile::ArgParser arg_parser; + arg_parser.insert("Ms", "", "M dimensions - empty by default.") + .insert("Ns", "", "N dimensions - empty by default.") + .insert("Ks", "", "K dimensions - empty by default.") + .insert( + "stride_As", + "", + "Tensor A strides - it is empty by default.") // stride_As/stride_Bs/stride_Cs/stride_AQs/stride_BQs + // can be set to zero if + // Ms/Ns/Ks is not empty + .insert("stride_Bs", "", "Tensor B strides - it is empty by default.") + .insert("stride_Cs", "", "Tensor C strides - it is empty by default.") + .insert("stride_AQs", "", "Tensor AQ strides - it is empty by default.") + .insert("stride_BQs", "", "Tensor BQ strides - it is empty by default.") + .insert("a_layout", "R", "A tensor data layout - Row by default.") + .insert("b_layout", "C", "B tensor data layout - Column by default.") + .insert("c_layout", "R", "C tensor data layout - Row by default.") + .insert("validate", "1", "0. No validation, 1. Validation on CPU.") + .insert("prec", "fp8", "data type. fp16/bf16/fp8/bf8") + .insert("warmup", "10", "number of iterations before benchmark the kernel.") + .insert("repeat", "100", "number of iterations to benchmark the kernel.") + .insert("group_count", "8", "group count.") + .insert("kbatch", "1", "kbatch for SplitK") + .insert("quant_mode", "bquant", "Choose aquant, bquant (default), tensor, or rowcol") + .insert("init", "0", "0. Random, 2. One(s) (Constant)") + .insert("persistent", "0", "Kernel persistency. 0: non-persistent. 1: persistent."); - using GemmShape = ck_tile::TileGemmShape< - ck_tile::sequence, - ck_tile::sequence, - ck_tile:: - sequence>; - using TilePartitioner = ck_tile:: - GemmSpatiallyLocalTilePartitioner; - - using Traits = ck_tile::TileGemmTraits; - using GemmUniversalTraits = ck_tile::TileGemmQuantTraits; - using GemmPipelineProblem = - ck_tile::GemmPipelineProblem; - - using BaseGemmPipeline = - GemmQuantConfig::template BaseGemmPipeline; - - const ck_tile::index_t k_grain = gemm_descs[0].k_batch * GemmConfig::K_Tile; - const ck_tile::index_t K_split = (gemm_descs[0].K + k_grain - 1) / k_grain * GemmConfig::K_Tile; - - const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); - const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); - const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); - - float ave_time{0}; - - const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = GemmConfig::Scheduler; - constexpr auto memory_operation = ck_tile::memory_operation_enum::set; - - constexpr bool UseGroupedQuant = QuantMode == ck_tile::QuantType::AQuantGrouped || - QuantMode == ck_tile::QuantType::BQuantGrouped; - using QuantGemmProblem = std::conditional_t< - UseGroupedQuant, - std::conditional_t, - ck_tile::GemmBQuantPipelineProblem>, - ck_tile::GemmRowColTensorQuantPipelineProblem>; - - using GemmPipeline = - GemmQuantConfig::template GemmPipeline; - - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem, - AccDataType, - CDataType, - ck_tile::tuple<>, - CLayout, - ck_tile::element_wise::PassThrough, - TilePartitioner::MPerBlock, - TilePartitioner::NPerBlock, - GemmConfig::M_Warp, - GemmConfig::N_Warp, - GemmConfig::M_Warp_Tile, - GemmConfig::N_Warp_Tile, - GemmConfig::K_Warp_Tile, - QuantGemmProblem::TransposeC, - memory_operation>>; - - using Kernel = ck_tile::QuantGroupedGemmKernel; - auto kargs = Kernel::MakeKargs(gemm_descs); - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Kernel arguments not supported!"); - } - - const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::GridSize(gemm_descs); - - HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, - kargs.data(), - get_workspace_size(gemm_descs), - hipMemcpyHostToDevice, - s.stream_id_)); - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" - << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" - << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; - } - - return ave_time = ck_tile::launch_kernel( - s, - ck_tile::make_kernel( - Kernel{}, - grids, - blocks, - 0, - ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), - gemm_descs.size())); - }; - - return ave_time = BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num); + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); } -template -float grouped_gemm_tileloop(const ck_tile::stream_config& s, - const ck_tile::index_t num_groups, - void* kargs_ptr) -{ - constexpr ck_tile::index_t TileParitionerGroupNum = 8; - constexpr ck_tile::index_t TileParitionerM01 = 4; - - using GemmShape = ck_tile::TileGemmShape< - ck_tile::sequence, - ck_tile::sequence, - ck_tile:: - sequence>; - using TilePartitioner = ck_tile:: - GemmSpatiallyLocalTilePartitioner; - - using GemmUniversalTraits = ck_tile::TileGemmQuantTraits; - - float ave_time{0}; - - const auto Run = [&](const auto memory_operation_) { - constexpr auto scheduler = GemmConfig::Scheduler; - constexpr auto memory_operation = memory_operation_.value; - - constexpr bool UseGroupedQuant = QuantMode == ck_tile::QuantType::AQuantGrouped || - QuantMode == ck_tile::QuantType::BQuantGrouped; - - using QuantGemmProblem = std::conditional_t< - UseGroupedQuant, - std::conditional_t, - ck_tile::GemmBQuantPipelineProblem>, - ck_tile::GemmRowColTensorQuantPipelineProblem>; - - using GemmPipeline = - GemmQuantConfig::template GemmPipeline; - - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem, - AccDataType, - CDataType, - ck_tile::tuple<>, - CLayout, - ck_tile::element_wise::PassThrough, - TilePartitioner::MPerBlock, - TilePartitioner::NPerBlock, - GemmConfig::M_Warp, - GemmConfig::N_Warp, - GemmConfig::M_Warp_Tile, - GemmConfig::N_Warp_Tile, - GemmConfig::K_Warp_Tile, - QuantGemmProblem::TransposeC, - memory_operation>>; - using Kernel = ck_tile::QuantGroupedGemmKernel; - const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::MaxOccupancyGridSize(s); - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" - << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" - << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; - } - - return ave_time = ck_tile::launch_kernel( - s, - ck_tile::make_kernel( - Kernel{}, - grids, - blocks, - 0, - ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), - num_groups)); - }; - - return ave_time = Run(ck_tile::integral_constant{}); -} - -#include "quant_run_grouped_gemm_example.inc" - int main(int argc, char* argv[]) { - int result1 = run_grouped_gemm_example(argc, argv); - return result1; + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + { + return -1; + } + + const std::string a_layout = arg_parser.get_str("a_layout"); + const std::string b_layout = arg_parser.get_str("b_layout"); + const std::string data_type = arg_parser.get_str("prec"); + std::string quant_mode = arg_parser.get_str("quant_mode"); + bool persistent = arg_parser.get_bool("persistent"); + + if(data_type == "fp8") + { + if(quant_mode == "tensor") + { + return run_gemm_example_persistency( + arg_parser, a_layout, b_layout, persistent); + } + else if(quant_mode == "rowcol") + { + return run_gemm_example_persistency( + arg_parser, a_layout, b_layout, persistent); + } + else if(quant_mode == "aquant") + { + return run_gemm_example_persistency( + arg_parser, a_layout, b_layout, persistent); + } + else if(quant_mode == "bquant") + { + return run_gemm_example_persistency( + arg_parser, a_layout, b_layout, persistent); + } + else + { + throw std::runtime_error("Unsupported quantization mode!"); + } + } + if(data_type == "bf8") + { + if(quant_mode == "tensor") + { + return run_gemm_example_persistency( + arg_parser, a_layout, b_layout, persistent); + } + else if(quant_mode == "rowcol") + { + return run_gemm_example_persistency( + arg_parser, a_layout, b_layout, persistent); + } + else if(quant_mode == "aquant") + { + return run_gemm_example_persistency( + arg_parser, a_layout, b_layout, persistent); + } + else if(quant_mode == "bquant") + { + return run_gemm_example_persistency( + arg_parser, a_layout, b_layout, persistent); + } + else + { + throw std::runtime_error("Unsupported quantization mode!"); + } + } + else + { + throw std::runtime_error("Unsupported data type configuration."); + } } diff --git a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm_bf8_aquant.cpp b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm_bf8_aquant.cpp new file mode 100644 index 0000000000..0da7a55343 --- /dev/null +++ b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm_bf8_aquant.cpp @@ -0,0 +1,7 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "quant_run_grouped_gemm_example.hpp" + +template int run_gemm_example_persistency( + const ck_tile::ArgParser&, std::string, std::string, bool); diff --git a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm_bf8_bquant.cpp b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm_bf8_bquant.cpp new file mode 100644 index 0000000000..135c8e20b8 --- /dev/null +++ b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm_bf8_bquant.cpp @@ -0,0 +1,7 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "quant_run_grouped_gemm_example.hpp" + +template int run_gemm_example_persistency( + const ck_tile::ArgParser&, std::string, std::string, bool); diff --git a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm_bf8_rowcol.cpp b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm_bf8_rowcol.cpp new file mode 100644 index 0000000000..9ed59c6efa --- /dev/null +++ b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm_bf8_rowcol.cpp @@ -0,0 +1,7 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "quant_run_grouped_gemm_example.hpp" + +template int run_gemm_example_persistency( + const ck_tile::ArgParser&, std::string, std::string, bool); diff --git a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm_bf8_tensor.cpp b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm_bf8_tensor.cpp new file mode 100644 index 0000000000..6c8d751f3f --- /dev/null +++ b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm_bf8_tensor.cpp @@ -0,0 +1,7 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "quant_run_grouped_gemm_example.hpp" + +template int run_gemm_example_persistency( + const ck_tile::ArgParser&, std::string, std::string, bool); diff --git a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.hpp b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm_config.hpp similarity index 68% rename from example/ck_tile/17_grouped_gemm/quant_grouped_gemm.hpp rename to example/ck_tile/17_grouped_gemm/quant_grouped_gemm_config.hpp index 1fa8a03087..a1f287df6b 100644 --- a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.hpp +++ b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm_config.hpp @@ -64,8 +64,8 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase static constexpr ck_tile::index_t N_Warp = 2; static constexpr ck_tile::index_t K_Warp = 1; - static constexpr ck_tile::index_t M_Warp_Tile = 32; - static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = ck_tile::get_k_warp_tile(); }; @@ -152,57 +152,7 @@ struct GemmQuantConfig using grouped_gemm_kargs = ck_tile::QuantGroupedGemmHostArgs; -auto create_args(int argc, char* argv[]) -{ - ck_tile::ArgParser arg_parser; - arg_parser.insert("Ms", "", "M dimensions - empty by default.") - .insert("Ns", "", "N dimensions - empty by default.") - .insert("Ks", "", "K dimensions - empty by default.") - .insert( - "stride_As", - "", - "Tensor A strides - it is empty by default.") // stride_As/stride_Bs/stride_Cs/stride_AQs/stride_BQs - // can be set to zero if - // Ms/Ns/Ks is not empty - .insert("stride_Bs", "", "Tensor B strides - it is empty by default.") - .insert("stride_Cs", "", "Tensor C strides - it is empty by default.") - .insert("stride_AQs", "", "Tensor AQ strides - it is empty by default.") - .insert("stride_BQs", "", "Tensor BQ strides - it is empty by default.") - .insert("a_layout", "R", "A tensor data layout - Row by default.") - .insert("b_layout", "C", "B tensor data layout - Row by default.") - .insert("c_layout", "R", "C tensor data layout - Row by default.") - .insert("validate", "1", "0. No validation, 1. Validation on CPU.") - .insert("prec", "fp8", "data type. fp16/bf16/fp8/bf8") - .insert("warmup", "10", "number of iterations before benchmark the kernel.") - .insert("repeat", "100", "number of iterations to benchmark the kernel.") - .insert("group_count", "8", "group count.") - .insert("kbatch", "1", "kbatch for SplitK") - .insert("quant_mode", "bquant", "Choose aquant, bquant (default), tensor, or rowcol") - .insert("init", "0", "0. Random, 2. One(s) (Constant)") - .insert("persistent", "0", "Kernel persistency. 0: non-persistent. 1: persistent."); - - bool result = arg_parser.parse(argc, argv); - return std::make_tuple(result, arg_parser); -} - inline std::size_t get_workspace_size(const std::vector& gemm_descs) { return gemm_descs.size() * sizeof(ck_tile::QuantGemmTransKernelArg); } - -template -float grouped_gemm_tileloop(const ck_tile::stream_config& s, - const ck_tile::index_t num_groups, - void* kargs_ptr); diff --git a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm_fp8_aquant.cpp b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm_fp8_aquant.cpp new file mode 100644 index 0000000000..1535848c62 --- /dev/null +++ b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm_fp8_aquant.cpp @@ -0,0 +1,7 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "quant_run_grouped_gemm_example.hpp" + +template int run_gemm_example_persistency( + const ck_tile::ArgParser&, std::string, std::string, bool); diff --git a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm_fp8_bquant.cpp b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm_fp8_bquant.cpp new file mode 100644 index 0000000000..4711e06a89 --- /dev/null +++ b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm_fp8_bquant.cpp @@ -0,0 +1,7 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "quant_run_grouped_gemm_example.hpp" + +template int run_gemm_example_persistency( + const ck_tile::ArgParser&, std::string, std::string, bool); diff --git a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm_fp8_rowcol.cpp b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm_fp8_rowcol.cpp new file mode 100644 index 0000000000..2ec60adb24 --- /dev/null +++ b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm_fp8_rowcol.cpp @@ -0,0 +1,7 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "quant_run_grouped_gemm_example.hpp" + +template int run_gemm_example_persistency( + const ck_tile::ArgParser&, std::string, std::string, bool); diff --git a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm_fp8_tensor.cpp b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm_fp8_tensor.cpp new file mode 100644 index 0000000000..9c7dd37687 --- /dev/null +++ b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm_fp8_tensor.cpp @@ -0,0 +1,7 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "quant_run_grouped_gemm_example.hpp" + +template int run_gemm_example_persistency( + const ck_tile::ArgParser&, std::string, std::string, bool); diff --git a/example/ck_tile/17_grouped_gemm/quant_invoke_grouped_gemm_kernel.hpp b/example/ck_tile/17_grouped_gemm/quant_invoke_grouped_gemm_kernel.hpp new file mode 100644 index 0000000000..16352722e1 --- /dev/null +++ b/example/ck_tile/17_grouped_gemm/quant_invoke_grouped_gemm_kernel.hpp @@ -0,0 +1,313 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm_quant.hpp" + +template +float grouped_gemm(const std::vector& gemm_descs, + const ck_tile::stream_config& s, + void* kargs_ptr) +{ + constexpr ck_tile::index_t TileParitionerGroupNum = 8; + constexpr ck_tile::index_t TileParitionerM01 = 4; + + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile:: + sequence>; + using TilePartitioner = ck_tile:: + GemmSpatiallyLocalTilePartitioner; + + using Traits = ck_tile::TileGemmTraits; + using GemmUniversalTraits = ck_tile::TileGemmQuantTraits; + using GemmPipelineProblem = + ck_tile::GemmPipelineProblem; + + using BaseGemmPipeline = + GemmQuantConfig::template BaseGemmPipeline; + + const ck_tile::index_t k_grain = gemm_descs[0].k_batch * GemmConfig::K_Tile; + const ck_tile::index_t K_split = (gemm_descs[0].K + k_grain - 1) / k_grain * GemmConfig::K_Tile; + + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + float ave_time{0}; + + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = GemmConfig::Scheduler; + constexpr auto memory_operation = ck_tile::memory_operation_enum::set; + + constexpr bool UseGroupedQuant = QuantMode == ck_tile::QuantType::AQuantGrouped || + QuantMode == ck_tile::QuantType::BQuantGrouped; + using QuantGemmProblem = std::conditional_t< + UseGroupedQuant, + std::conditional_t, + ck_tile::GemmBQuantPipelineProblem>, + ck_tile::GemmRowColTensorQuantPipelineProblem>; + + using GemmPipeline = + GemmQuantConfig::template GemmPipeline; + + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem, + AccDataType, + CDataType, + ck_tile::tuple<>, + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + GemmConfig::M_Warp, + GemmConfig::N_Warp, + GemmConfig::M_Warp_Tile, + GemmConfig::N_Warp_Tile, + GemmConfig::K_Warp_Tile, + QuantGemmProblem::TransposeC, + memory_operation>>; + + using Kernel = ck_tile::QuantGroupedGemmKernel; + auto kargs = Kernel::MakeKargs(gemm_descs); + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Kernel arguments not supported!"); + } + + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(gemm_descs); + + HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, + kargs.data(), + get_workspace_size(gemm_descs), + hipMemcpyHostToDevice, + s.stream_id_)); + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; + } + + return ave_time = ck_tile::launch_kernel( + s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + gemm_descs.size())); + }; + + return ave_time = BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num); +} + +template +float grouped_gemm_tileloop(const ck_tile::stream_config& s, + const ck_tile::index_t num_groups, + void* kargs_ptr) +{ + constexpr ck_tile::index_t TileParitionerGroupNum = 8; + constexpr ck_tile::index_t TileParitionerM01 = 4; + + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile:: + sequence>; + using TilePartitioner = ck_tile:: + GemmSpatiallyLocalTilePartitioner; + + using GemmUniversalTraits = ck_tile::TileGemmQuantTraits; + + float ave_time{0}; + + const auto Run = [&](const auto memory_operation_) { + constexpr auto scheduler = GemmConfig::Scheduler; + constexpr auto memory_operation = memory_operation_.value; + + constexpr bool UseGroupedQuant = QuantMode == ck_tile::QuantType::AQuantGrouped || + QuantMode == ck_tile::QuantType::BQuantGrouped; + + using QuantGemmProblem = std::conditional_t< + UseGroupedQuant, + std::conditional_t, + ck_tile::GemmBQuantPipelineProblem>, + ck_tile::GemmRowColTensorQuantPipelineProblem>; + + using GemmPipeline = + GemmQuantConfig::template GemmPipeline; + + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem, + AccDataType, + CDataType, + ck_tile::tuple<>, + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + GemmConfig::M_Warp, + GemmConfig::N_Warp, + GemmConfig::M_Warp_Tile, + GemmConfig::N_Warp_Tile, + GemmConfig::K_Warp_Tile, + QuantGemmProblem::TransposeC, + memory_operation>>; + using Kernel = ck_tile::QuantGroupedGemmKernel; + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::MaxOccupancyGridSize(s); + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; + } + + return ave_time = ck_tile::launch_kernel( + s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + num_groups)); + }; + + return ave_time = Run(ck_tile::integral_constant{}); +} diff --git a/example/ck_tile/17_grouped_gemm/quant_run_grouped_gemm_example.inc b/example/ck_tile/17_grouped_gemm/quant_run_grouped_gemm_example.hpp similarity index 87% rename from example/ck_tile/17_grouped_gemm/quant_run_grouped_gemm_example.inc rename to example/ck_tile/17_grouped_gemm/quant_run_grouped_gemm_example.hpp index 37832b54ba..6a5bf192ca 100644 --- a/example/ck_tile/17_grouped_gemm/quant_run_grouped_gemm_example.inc +++ b/example/ck_tile/17_grouped_gemm/quant_run_grouped_gemm_example.hpp @@ -3,6 +3,24 @@ #pragma once +#include +#include +#include +#include +#include +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp" +#include "ck_tile/ops/gemm_quant.hpp" +#include "ck_tile/host.hpp" + +#include "quant_grouped_gemm_config.hpp" +#include "quant_invoke_grouped_gemm_kernel.hpp" + template static constexpr inline auto is_row_major(Layout layout_) { @@ -11,9 +29,9 @@ static constexpr inline auto is_row_major(Layout layout_) } template -auto calculate_rtol_atol(const ck_tile::index_t K, - const ck_tile::index_t kbatch, - const float max_accumulated_value) +static auto calculate_rtol_atol(const ck_tile::index_t K, + const ck_tile::index_t kbatch, + const float max_accumulated_value) { using ComputeType = std::conditional_t; @@ -170,21 +188,13 @@ template -int run_grouped_gemm_example_with_layouts(int argc, - char* argv[], +int run_grouped_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, const ALayout a_layout = ALayout{}, const AQLayout aq_layout = AQLayout{}, const BLayout b_layout = BLayout{}, const BQLayout bq_layout = BQLayout{}, [[maybe_unused]] const CLayout c_layout = CLayout{}) { - auto [result, arg_parser] = create_args(argc, argv); - - if(!result) - { - return -1; - }; - auto valid_input_data = [&](int group_count, const auto&... args) { return group_count != 0 && ((args.size() == static_cast(group_count)) && ...); }; @@ -540,7 +550,9 @@ int run_grouped_gemm_example_with_layouts(int argc, } template -int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[]) +int run_gemm_example_prec_type(const ck_tile::ArgParser& arg_parser, + std::string a_layout, + std::string b_layout) { using Row = ck_tile::tensor_layout::gemm::RowMajor; using Col = ck_tile::tensor_layout::gemm::ColumnMajor; @@ -556,7 +568,6 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a if(a_layout == "R" && b_layout == "C") { - return run_grouped_gemm_example_with_layouts( - argc, argv, Row{}, Row{}, Col{}, Col{}, Row{}); + arg_parser, Row{}, Row{}, Col{}, Col{}, Row{}); } - else + + if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped || + (QuantMode == ck_tile::QuantType::BQuantGrouped && !GemmConfig::PreshuffleB)) { - throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!"); + if(a_layout == "R" && b_layout == "R") + { + return run_grouped_gemm_example_with_layouts( + arg_parser, Row{}, Row{}, Row{}, Row{}, Row{}); + } + else if(a_layout == "C" && b_layout == "R") + { + return run_grouped_gemm_example_with_layouts( + arg_parser, Col{}, Col{}, Row{}, Row{}, Row{}); + } + else if(a_layout == "C" && b_layout == "C") + { + return run_grouped_gemm_example_with_layouts( + arg_parser, Col{}, Col{}, Col{}, Col{}, Row{}); + } } + + throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!"); } template -int run_gemm_example_persistency( - std::string a_layout, std::string b_layout, bool persistent, int argc, char* argv[]) +int run_gemm_example_persistency(const ck_tile::ArgParser& arg_parser, + std::string a_layout, + std::string b_layout, + bool persistent) { if(persistent) { using GemmConfig = GemmQuantConfig::template GemmConfig; return run_gemm_example_prec_type( - a_layout, b_layout, argc, argv); + arg_parser, a_layout, b_layout); } else { using GemmConfig = GemmQuantConfig::template GemmConfig; return run_gemm_example_prec_type( - a_layout, b_layout, argc, argv); - } -} - -int run_grouped_gemm_example(int argc, char* argv[]) -{ - auto [result, arg_parser] = create_args(argc, argv); - if(!result) - { - return -1; - } - - const std::string a_layout = arg_parser.get_str("a_layout"); - const std::string b_layout = arg_parser.get_str("b_layout"); - const std::string data_type = arg_parser.get_str("prec"); - std::string quant_mode = arg_parser.get_str("quant_mode"); - bool persistent = arg_parser.get_bool("persistent"); - - if(data_type == "fp8") - { - if(quant_mode == "tensor") - { - return run_gemm_example_persistency( - a_layout, b_layout, persistent, argc, argv); - } - else if(quant_mode == "rowcol") - { - return run_gemm_example_persistency( - a_layout, b_layout, persistent, argc, argv); - } - else if(quant_mode == "aquant") - { - return run_gemm_example_persistency( - a_layout, b_layout, persistent, argc, argv); - } - else if(quant_mode == "bquant") - { - return run_gemm_example_persistency( - a_layout, b_layout, persistent, argc, argv); - } - else - { - throw std::runtime_error("Unsupported quantization mode!"); - } - } - if(data_type == "bf8") - { - if(quant_mode == "tensor") - { - return run_gemm_example_persistency( - a_layout, b_layout, persistent, argc, argv); - } - else if(quant_mode == "rowcol") - { - return run_gemm_example_persistency( - a_layout, b_layout, persistent, argc, argv); - } - else if(quant_mode == "aquant") - { - return run_gemm_example_persistency( - a_layout, b_layout, persistent, argc, argv); - } - else if(quant_mode == "bquant") - { - return run_gemm_example_persistency( - a_layout, b_layout, persistent, argc, argv); - } - else - { - throw std::runtime_error("Unsupported quantization mode!"); - } - } - else - { - throw std::runtime_error("Unsupported data type configuration."); + arg_parser, a_layout, b_layout); } } diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp index 91c69472a6..b43066cdc5 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp @@ -422,7 +422,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3( Policy::template MakeShuffledARegTileDistribution()); @@ -433,7 +433,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3( diff --git a/test/ck_tile/grouped_gemm_quant/CMakeLists.txt b/test/ck_tile/grouped_gemm_quant/CMakeLists.txt index 55f09726cc..5959e44f48 100644 --- a/test/ck_tile/grouped_gemm_quant/CMakeLists.txt +++ b/test/ck_tile/grouped_gemm_quant/CMakeLists.txt @@ -6,18 +6,21 @@ if(CK_USE_OCP_FP8) list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) endif() -# if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") -# # Split into three separate test executables for faster parallel compilation -# add_gtest_executable(test_ck_tile_grouped_gemm_quant_rowcol test_grouped_gemm_quant_rowcol.cpp) -# target_compile_options(test_ck_tile_grouped_gemm_quant_rowcol PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) +if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") + # Split into three separate test executables for faster parallel compilation + add_gtest_executable(test_ck_tile_grouped_gemm_quant_rowcol test_grouped_gemm_quant_rowcol.cpp) + target_compile_options(test_ck_tile_grouped_gemm_quant_rowcol PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) -# add_gtest_executable(test_ck_tile_grouped_gemm_quant_tensor test_grouped_gemm_quant_tensor.cpp) -# target_compile_options(test_ck_tile_grouped_gemm_quant_tensor PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + add_gtest_executable(test_ck_tile_grouped_gemm_quant_tensor test_grouped_gemm_quant_tensor.cpp) + target_compile_options(test_ck_tile_grouped_gemm_quant_tensor PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) -# add_gtest_executable(test_ck_tile_grouped_gemm_quant_aquant test_grouped_gemm_quant_aquant.cpp) -# target_compile_options(test_ck_tile_grouped_gemm_quant_aquant PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + add_gtest_executable(test_ck_tile_grouped_gemm_quant_aquant test_grouped_gemm_quant_aquant.cpp) + target_compile_options(test_ck_tile_grouped_gemm_quant_aquant PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) -# add_gtest_executable(test_ck_tile_grouped_gemm_quant_bquant test_grouped_gemm_quant_bquant.cpp) -# target_compile_options(test_ck_tile_grouped_gemm_quant_bquant PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) -# endif() + add_gtest_executable(test_ck_tile_grouped_gemm_quant_bquant test_grouped_gemm_quant_bquant.cpp) + target_compile_options(test_ck_tile_grouped_gemm_quant_bquant PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + + add_gtest_executable(test_ck_tile_grouped_gemm_quant_bquant_preshuffleb test_grouped_gemm_quant_bquant_preshuffleb.cpp) + target_compile_options(test_ck_tile_grouped_gemm_quant_bquant_preshuffleb PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) +endif() diff --git a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_aquant.cpp b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_aquant.cpp index 8dcd6d017d..3b1aa967d1 100644 --- a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_aquant.cpp +++ b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_aquant.cpp @@ -21,13 +21,29 @@ using AQuant = std::integral_constant, std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuant, False, True, False>, + + // RCR BF8 (with/without TransposeC) std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, AQuant, False, True, True>, std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, AQuant, False, True, False>, + // RCR non-persistent (with/without TransposeC) std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuant, False, False, True>, - std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuant, False, False, False> + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuant, False, False, False>, + + // RRR layout (with/without TransposeC) + std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, AQuant, False, True, True>, + std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, AQuant, False, True, False>, + + // CRR layout (with/without TransposeC) + // NOT SUPPORTED: std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, AQuant, False, True, True>, + std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, AQuant, False, True, False>, + + // CCR layout (with/without TransposeC) + // NOT SUPPORTED: std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuant, False, True, True>, + std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuant, False, True, False> >; // clang-format on diff --git a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_bquant.cpp b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_bquant.cpp index 6c0ad545b7..e7f4486b23 100644 --- a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_bquant.cpp +++ b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_bquant.cpp @@ -21,13 +21,18 @@ using BQuant = std::integral_constant, - std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, True, True, False>, std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, BQuant, False, True, False>, - std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, BQuant, True, True, False>, + // Non-persistent variant std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, False, False, False>, - std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, True, False, False> + + // Alternative layouts: RRR, CRR, CCR + std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, BQuant, False, True, False>, + std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, BQuant, False, True, False>, + std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, False, True, False> >; // clang-format on diff --git a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_bquant_preshuffleb.cpp b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_bquant_preshuffleb.cpp new file mode 100644 index 0000000000..cc6e84960b --- /dev/null +++ b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_bquant_preshuffleb.cpp @@ -0,0 +1,38 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include + +#include "gtest/gtest.h" + +#include "ck_tile/host.hpp" +#include "test_grouped_gemm_util_quant.hpp" + +using F16 = ck_tile::half_t; +using F32 = float; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; +using True = ck_tile::bool_constant; +using False = ck_tile::bool_constant; +using BQuant = std::integral_constant; + +// clang-format off +using KernelTypes_BQuant_PreshuffleB = ::testing::Types< + // ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType, PreshuffleB, Persistent, TransposeC + + // Base instances: RCR FP8/BF16 persistent + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, True, True, False>, + std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, BQuant, True, True, False>, + + // Non-persistent variant + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, True, False, False> + >; +// clang-format on + +TYPED_TEST_SUITE(TestCkTileGroupedGemmQuant_BQuant_PreshuffleB, KernelTypes_BQuant_PreshuffleB); + +#define TEST_CLASS_NAME TestCkTileGroupedGemmQuant_BQuant_PreshuffleB +#include "test_grouped_gemm_quant_ut_cases.inc" +#undef TEST_CLASS_NAME diff --git a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_ut_cases.inc b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_ut_cases.inc index bdb929d923..53bfb26bb2 100644 --- a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_ut_cases.inc +++ b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_ut_cases.inc @@ -5,7 +5,7 @@ TYPED_TEST(TEST_CLASS_NAME, Basic) { - const int group_count = 8; + const int group_count = 6; std::vector Ms; std::vector Ns; std::vector Ks; diff --git a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_util_quant.hpp b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_util_quant.hpp index 9941066c3e..b73221ac28 100644 --- a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_util_quant.hpp +++ b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_util_quant.hpp @@ -31,8 +31,8 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test using DsDataType = ck_tile::tuple<>; using Row = ck_tile::tensor_layout::gemm::RowMajor; using Col = ck_tile::tensor_layout::gemm::ColumnMajor; - using AQLayout = Row; - using BQLayout = Col; + using AQLayout = ALayout; + using BQLayout = BLayout; static constexpr bool PreshuffleB = std::tuple_element_t<10, Tuple>::value; static constexpr bool Persistent = std::tuple_element_t<11, Tuple>::value; static constexpr bool TransposeC = std::tuple_element_t<12, Tuple>::value; @@ -44,8 +44,8 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test static const bool kPadK = false; static const int kBlockPerCu = 1; - static const ck_tile::index_t M_Tile = 256; - static const ck_tile::index_t N_Tile = 256; + static const ck_tile::index_t M_Tile = 128; + static const ck_tile::index_t N_Tile = 128; static const ck_tile::index_t K_Tile = 128; static const ck_tile::index_t M_Warp = 2; @@ -782,3 +782,6 @@ using TestCkTileGroupedGemmQuant_AQuant = TestCkTileGroupedGemmQuant; template using TestCkTileGroupedGemmQuant_BQuant = TestCkTileGroupedGemmQuant; + +template +using TestCkTileGroupedGemmQuant_BQuant_PreshuffleB = TestCkTileGroupedGemmQuant;