From e8653f314dfb6b4a21cbf988ee3eda079c05436c Mon Sep 17 00:00:00 2001
From: AviralGoelAMD
Date: Thu, 9 Oct 2025 17:10:20 +0000
Subject: [PATCH] feat(grouped_gemm_multi_d): add support for bf16

[ROCm/composable_kernel commit: 8d8b49dec2a3a8e5e3c144dbdcc1280ca58dd52a]
---
 .../17_grouped_gemm/grouped_gemm_multi_d.hpp  |  45 ++++++--
 .../run_grouped_gemm_multi_d_example.inc      | 104 +++++++++++++-----
 2 files changed, 111 insertions(+), 38 deletions(-)

diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.hpp b/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.hpp
index 0789452ada..12d70eecb6 100644
--- a/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.hpp
+++ b/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.hpp
@@ -15,14 +15,6 @@
 #define CK_TILE_PIPELINE_MEMORY 2
 #define CK_TILE_PIPELINE_COMPUTE_V4 3
 
-using ADataType = ck_tile::half_t;
-using BDataType = ck_tile::half_t;
-using D0DataType = ck_tile::half_t;
-using D1DataType = ck_tile::half_t;
-using EDataType = ck_tile::half_t;
-using DsDataType = ck_tile::tuple<D0DataType, D1DataType>;
-using AccDataType = float;
-
 template <int PipelineId>
 constexpr ck_tile::index_t get_k_warp_tile()
 {
@@ -173,7 +165,38 @@ struct PipelineTypeTraits
     using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4;
 };
 
-using grouped_gemm_multi_d_kargs = ck_tile::GroupedGemmHostArgs;
+template <typename PrecType>
+struct GemmMultiDTypeConfig;
+
+template <>
+struct GemmMultiDTypeConfig<ck_tile::half_t>
+{
+    using ADataType = ck_tile::half_t;
+    using BDataType = ck_tile::half_t;
+    using D0DataType = ck_tile::half_t;
+    using D1DataType = ck_tile::half_t;
+    using EDataType = ck_tile::half_t;
+    using DsDataType = ck_tile::tuple<D0DataType, D1DataType>;
+    using AccDataType = float;
+};
+
+template <>
+struct GemmMultiDTypeConfig<ck_tile::bf16_t>
+{
+    using ADataType = ck_tile::bf16_t;
+    using BDataType = ck_tile::bf16_t;
+    using D0DataType = ck_tile::bf16_t;
+    using D1DataType = ck_tile::bf16_t;
+    using EDataType = ck_tile::bf16_t;
+    using DsDataType = ck_tile::tuple<D0DataType, D1DataType>;
+    using AccDataType = float;
+};
+
+// Deduce the number of D tensors from the DsDataType tuple size.
+// All precision configs have the same number of D tensors, so we can use any one.
+constexpr std::size_t NumDTensor = GemmMultiDTypeConfig<ck_tile::half_t>::DsDataType::size();
+
+using grouped_gemm_multi_d_kargs = ck_tile::GroupedGemmHostArgs<NumDTensor>;
 
 std::pair<bool, ck_tile::ArgParser> create_args(int argc, char* argv[])
 {
@@ -190,7 +213,7 @@ std::pair<bool, ck_tile::ArgParser> create_args(int argc, char* argv[])
         .insert("ds_layout", "R", "Ds tensor data layout - Row by default.")
         .insert("e_layout", "R", "E tensor data layout - Row by default.")
         .insert("validate", "1", "0. No validation, 1. Validation on CPU.")
-        .insert("prec", "fp16", "data type. fp16")
+        .insert("prec", "bf16", "data type. fp16/bf16")
         .insert("warmup", "10", "number of iterations before benchmark the kernel.")
         .insert("repeat", "100", "number of iterations to benchmark the kernel.")
         .insert("group_count", "8", "group count.")
@@ -204,7 +227,7 @@ std::pair<bool, ck_tile::ArgParser> create_args(int argc, char* argv[])
 
 inline std::size_t get_workspace_size(const std::vector<grouped_gemm_multi_d_kargs>& gemm_descs)
 {
-    return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg<DsDataType::size()>);
+    return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg<NumDTensor>);
 }
 
 template <typename ALayout,
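
Note on the type-config pattern: the header change above is the core of the
patch. The file-scope fp16 typedefs become a GemmMultiDTypeConfig trait with
one specialization per precision, so every data type the example needs is
selected through a single template parameter. Below is a minimal standalone
sketch of the same trait-dispatch idea; the half_t/bf16_t tag structs and
std::tuple are stand-ins for ck_tile::half_t, ck_tile::bf16_t and
ck_tile::tuple, not the real types.

    #include <cstddef>
    #include <tuple>

    struct half_t {}; // stand-in for ck_tile::half_t
    struct bf16_t {}; // stand-in for ck_tile::bf16_t

    // The primary template is declared but never defined, so instantiating it
    // with an unsupported precision fails at compile time instead of silently
    // falling back to a default.
    template <typename PrecType>
    struct GemmMultiDTypeConfig;

    template <>
    struct GemmMultiDTypeConfig<half_t>
    {
        using EDataType   = half_t;
        using DsDataType  = std::tuple<half_t, half_t>; // two D tensors
        using AccDataType = float;
    };

    template <>
    struct GemmMultiDTypeConfig<bf16_t>
    {
        using EDataType   = bf16_t;
        using DsDataType  = std::tuple<bf16_t, bf16_t>;
        using AccDataType = float;
    };

    // Mirrors the patch: the D-tensor count is deduced once from any
    // specialization, since all of them carry the same number of D tensors.
    constexpr std::size_t NumDTensor =
        std::tuple_size<GemmMultiDTypeConfig<half_t>::DsDataType>::value;

    static_assert(NumDTensor == 2, "the multi-D example uses two D tensors");

    int main() { return 0; }

Leaving the primary template undefined is what keeps the dispatch safe: a
hypothetical fp8 build would not compile until a matching
GemmMultiDTypeConfig specialization is added.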
fp16/bf16") .insert("warmup", "10", "number of iterations before benchmark the kernel.") .insert("repeat", "100", "number of iterations to benchmark the kernel.") .insert("group_count", "8", "group count.") @@ -204,7 +227,7 @@ std::pair create_args(int argc, char* argv[]) inline std::size_t get_workspace_size(const std::vector& gemm_descs) { - return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg); + return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg); } template >{}; } +template auto calculate_rtol_atol(const ck_tile::index_t K, const ck_tile::index_t kbatch, const float max_accumulated_value) @@ -86,31 +91,31 @@ float invoke_gemm(int n_warmup, } else { - std::vector> kargs; + std::vector> kargs; void* kargs_ptr = gemm_workspace.GetDeviceBuffer(); const bool splitk = args[0].k_batch > 1; for(const auto& arg : args) { - kargs.emplace_back(ck_tile::UniversalGemmKernelArgs<1, 1, 2>{{arg.a_ptr}, - {arg.b_ptr}, - arg.ds_ptr, - arg.e_ptr, - arg.M, - arg.N, - arg.K, - {arg.stride_A}, - {arg.stride_B}, - arg.stride_Ds, - arg.stride_E, - arg.k_batch}); + kargs.emplace_back(ck_tile::UniversalGemmKernelArgs<1, 1, NumDTensor>{{arg.a_ptr}, + {arg.b_ptr}, + arg.ds_ptr, + arg.e_ptr, + arg.M, + arg.N, + arg.K, + {arg.stride_A}, + {arg.stride_B}, + arg.stride_Ds, + arg.stride_E, + arg.k_batch}); } const auto stream = ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}; - HIP_CHECK_ERROR(hipMemcpyWithStream( - kargs_ptr, - kargs.data(), - kargs.size() * sizeof(ck_tile::GemmTransKernelArg), - hipMemcpyHostToDevice, - stream.stream_id_)); + HIP_CHECK_ERROR( + hipMemcpyWithStream(kargs_ptr, + kargs.data(), + kargs.size() * sizeof(ck_tile::GemmTransKernelArg), + hipMemcpyHostToDevice, + stream.stream_id_)); ave_time = grouped_gemm_multi_d_tileloop; + using DsDataType = ck_tile::tuple; auto valid_input_data = [&](int group_count, const auto&... args) { return !(args.empty() || ...) 
&& group_count == (args.size() == ...); @@ -360,7 +372,9 @@ int run_grouped_gemm_multi_d_example_with_layouts(int argc, const float max_accumulated_value = *std::max_element(e_m_n_host_refs[i].mData.begin(), e_m_n_host_refs[i].mData.end()); - const auto rtol_atol = calculate_rtol_atol(Ks[i], 1, max_accumulated_value); + const auto rtol_atol = + calculate_rtol_atol( + Ks[i], 1, max_accumulated_value); pass &= ck_tile::check_err(e_m_n_tensors[i], @@ -390,6 +404,38 @@ int run_grouped_gemm_multi_d_example_with_layouts(int argc, return pass; } +template +int run_gemm_multi_d_example_prec_type( + std::string a_layout, std::string b_layout, std::string ds_layout, int argc, char* argv[]) +{ + using Row = ck_tile::tensor_layout::gemm::RowMajor; + using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + using Types = GemmMultiDTypeConfig; + + using ADataType = typename Types::ADataType; + using BDataType = typename Types::BDataType; + using D0DataType = typename Types::D0DataType; + using D1DataType = typename Types::D1DataType; + using AccDataType = typename Types::AccDataType; + using EDataType = typename Types::EDataType; + + if(a_layout == "R" && b_layout == "C" && ds_layout == "R") + { + return run_grouped_gemm_multi_d_example_with_layouts( + argc, argv, Row{}, Col{}, Row{}, Row{}, Row{}); + } + else + { + throw std::runtime_error("Unsupported data layout configuration for provided tensors!"); + } +} + template int run_grouped_gemm_multi_d_example(int argc, char* argv[]) { @@ -401,17 +447,21 @@ int run_grouped_gemm_multi_d_example(int argc, char* argv[]) const std::string a_layout = arg_parser.get_str("a_layout"); const std::string b_layout = arg_parser.get_str("b_layout"); const std::string ds_layout = arg_parser.get_str("ds_layout"); + const std::string data_type = arg_parser.get_str("prec"); - using Row = ck_tile::tensor_layout::gemm::RowMajor; - using Col = ck_tile::tensor_layout::gemm::ColumnMajor; - - if(a_layout == "R" && b_layout == "C" && ds_layout == "R") + if(data_type == "fp16") { - return run_grouped_gemm_multi_d_example_with_layouts( - argc, argv, Row{}, Col{}, Row{}, Row{}, Row{}); + return run_gemm_multi_d_example_prec_type( + a_layout, b_layout, ds_layout, argc, argv); + } + else if(data_type == "bf16") + { + return run_gemm_multi_d_example_prec_type( + a_layout, b_layout, ds_layout, argc, argv); } else { - throw std::runtime_error("Unsupported data layout configuration for provided tensors!"); + throw std::runtime_error( + "Unsupported data type configuration. Only fp16 and bf16 are supported."); } }