From fae4ebac66c60614d1fd503e5eca911e476e2d34 Mon Sep 17 00:00:00 2001 From: Feng Shijie Date: Wed, 9 Jul 2025 06:37:29 +0000 Subject: [PATCH] Add support for contiguous grouped gemm --- .../19_grouped_flatmm/grouped_flatmm.cpp | 86 +++++--- .../19_grouped_flatmm/grouped_flatmm.hpp | 47 ++-- .../run_grouped_flatmm_example.inc | 207 ++++++++++++++++++ .../flatmm/kernel/grouped_flatmm_kernel.hpp | 108 +++++++-- 4 files changed, 387 insertions(+), 61 deletions(-) diff --git a/example/ck_tile/19_grouped_flatmm/grouped_flatmm.cpp b/example/ck_tile/19_grouped_flatmm/grouped_flatmm.cpp index b53dd96e2a..5c45e37d73 100644 --- a/example/ck_tile/19_grouped_flatmm/grouped_flatmm.cpp +++ b/example/ck_tile/19_grouped_flatmm/grouped_flatmm.cpp @@ -19,8 +19,9 @@ template -float grouped_flatmm(const ck_tile::GroupedFlatmmHostArgs& args, const ck_tile::stream_config& s) + typename CLayout, + typename KernelArguments> +float grouped_flatmm(const KernelArguments& args, const ck_tile::stream_config& s) { // The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part. 
constexpr bool kPadM = false; @@ -76,12 +77,12 @@ float grouped_flatmm(const ck_tile::GroupedFlatmmHostArgs& args, const ck_tile:: constexpr auto tail_number_v = tail_number_.value; constexpr auto memory_operation = memory_operation_.value; using CodegenPipelineProblem = ck_tile::FlatmmPipelineProblem; + BDataType, + AccDataType, + CodegenFlatmmShape, + CodegenGemmTraits, + has_hot_loop_v, + tail_number_v>; using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem( - argc, argv, Row{}, Col{}, Row{}); + if(data_type == "fp16") + { + run_grouped_flatmm_example_with_layouts( + argc, argv, Row{}, Col{}, Row{}); + } + else if(data_type == "bf16") + { + run_grouped_flatmm_example_with_layouts( + argc, argv, Row{}, Col{}, Row{}); + } + else if(data_type == "fp8") + { + run_grouped_flatmm_example_with_layouts( + argc, argv, Row{}, Col{}, Row{}); + } + else if(data_type == "bf8") + { + run_grouped_flatmm_example_with_layouts( + argc, argv, Row{}, Col{}, Row{}); + } + else + { + throw std::runtime_error("Unsupported data_type!"); + } } - else if(data_type == "bf16") + else if(mode == "contiguous") { - run_grouped_flatmm_example_with_layouts( - argc, argv, Row{}, Col{}, Row{}); - } - else if(data_type == "fp8") - { - run_grouped_flatmm_example_with_layouts( - argc, argv, Row{}, Col{}, Row{}); - } - else if(data_type == "bf8") - { - run_grouped_flatmm_example_with_layouts( - argc, argv, Row{}, Col{}, Row{}); + if(data_type == "fp16") + { + run_contiguous_grouped_flatmm_example_with_layouts( + argc, argv, Row{}, Col{}, Row{}); + } + else if(data_type == "bf16") + { + run_contiguous_grouped_flatmm_example_with_layouts( + argc, argv, Row{}, Col{}, Row{}); + } + else if(data_type == "fp8") + { + run_contiguous_grouped_flatmm_example_with_layouts( + argc, argv, Row{}, Col{}, Row{}); + } + else if(data_type == "bf8") + { + run_contiguous_grouped_flatmm_example_with_layouts( + argc, argv, Row{}, Col{}, Row{}); + } + else + { + throw 
std::runtime_error("Unsupported data_type!"); + } } else { - throw std::runtime_error("Unsupported data_type!"); + throw std::runtime_error("Unsupported mode!"); } } else diff --git a/example/ck_tile/19_grouped_flatmm/grouped_flatmm.hpp b/example/ck_tile/19_grouped_flatmm/grouped_flatmm.hpp index 7c955709c9..a5bc9aee00 100644 --- a/example/ck_tile/19_grouped_flatmm/grouped_flatmm.hpp +++ b/example/ck_tile/19_grouped_flatmm/grouped_flatmm.hpp @@ -29,11 +29,11 @@ #define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE #endif -#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) +#if (CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) #define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrMem #define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrMem #define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Interwave -#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE) +#elif (CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE) #define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV3 #define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV3 #define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave @@ -153,9 +153,9 @@ struct GemmConfig static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = 128; #elif defined(USING_MFMA_32x32x64_F8) // MI350 FP8 32X32 (need tune) - static constexpr ck_tile::index_t M_Tile = 128; - static constexpr ck_tile::index_t N_Tile = 256; - static constexpr ck_tile::index_t K_Tile = 128; + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 128; static constexpr ck_tile::index_t M_Warp = 1; static constexpr ck_tile::index_t N_Warp = 4; @@ -165,9 +165,9 @@ struct GemmConfig static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = 64; #elif defined(USING_MFMA_16x16x32_F8) // MI300 FP8 16X16 - static constexpr 
ck_tile::index_t M_Tile = 16; - static constexpr ck_tile::index_t N_Tile = 128; - static constexpr ck_tile::index_t K_Tile = 256; + static constexpr ck_tile::index_t M_Tile = 16; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 256; static constexpr ck_tile::index_t M_Warp = 1; static constexpr ck_tile::index_t N_Warp = 4; @@ -177,9 +177,9 @@ struct GemmConfig static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = 64; #elif defined(USING_MFMA_32x32x16_F8) // MI300 FP8 32X32 (need tune) - static constexpr ck_tile::index_t M_Tile = 128; - static constexpr ck_tile::index_t N_Tile = 256; - static constexpr ck_tile::index_t K_Tile = 128; + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 128; static constexpr ck_tile::index_t M_Warp = 1; static constexpr ck_tile::index_t N_Warp = 8; @@ -222,9 +222,9 @@ struct GemmConfig static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = 128; #elif defined(USING_MFMA_32x32x64_F8) // MI350 FP8 32X32 (need tune) - static constexpr ck_tile::index_t M_Tile = 128; - static constexpr ck_tile::index_t N_Tile = 128; - static constexpr ck_tile::index_t K_Tile = 128; + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128; static constexpr ck_tile::index_t M_Warp = 1; static constexpr ck_tile::index_t N_Warp = 4; @@ -234,9 +234,9 @@ struct GemmConfig static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = 64; #elif defined(USING_MFMA_16x16x32_F16) // MI350 FP16 16X16 (need tune) - static constexpr ck_tile::index_t M_Tile = 128; - static constexpr ck_tile::index_t N_Tile = 128; - static constexpr ck_tile::index_t K_Tile = 128; + static constexpr ck_tile::index_t M_Tile = 128; + static 
constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128; static constexpr ck_tile::index_t M_Warp = 1; static constexpr ck_tile::index_t N_Warp = 4; @@ -246,9 +246,9 @@ struct GemmConfig static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = 32; #elif defined(USING_MFMA_32x32x16_F16) // MI350 FP16 32X32 (need tune) - static constexpr ck_tile::index_t M_Tile = 128; - static constexpr ck_tile::index_t N_Tile = 128; - static constexpr ck_tile::index_t K_Tile = 128; + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128; static constexpr ck_tile::index_t M_Warp = 1; static constexpr ck_tile::index_t N_Warp = 4; @@ -323,15 +323,16 @@ struct GemmConfig auto create_args(int argc, char* argv[]) { ck_tile::ArgParser arg_parser; - arg_parser.insert("Ms", "512,128,1024", "m dimension") - .insert("Ns", "256,512,1024", "n dimension") - .insert("Ks", "512,1024,1024", "k dimension") + arg_parser.insert("Ms", "512,256,1024", "m dimension") + .insert("Ns", "1024,512,256", "n dimension") + .insert("Ks", "1024,1024,512", "k dimension") .insert("group_count", "3", "group count") .insert("a_layout", "R", "A tensor data layout - Row by default") .insert("b_layout", "C", "B tensor data layout - Row by default") .insert("c_layout", "R", "C tensor data layout - Row by default") .insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU") .insert("prec", "fp8", "data type. 
fp16/bf16/fp8/bf8") + .insert("mode", "general", "grouped gemm mode: [general | contiguous], general by default") .insert("warmup", "5", "number of iterations before benchmark the kernel") .insert("repeat", "10", "number of iterations to benchmark the kernel") .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") diff --git a/example/ck_tile/19_grouped_flatmm/run_grouped_flatmm_example.inc b/example/ck_tile/19_grouped_flatmm/run_grouped_flatmm_example.inc index 07e3c345ab..40a42aa94d 100644 --- a/example/ck_tile/19_grouped_flatmm/run_grouped_flatmm_example.inc +++ b/example/ck_tile/19_grouped_flatmm/run_grouped_flatmm_example.inc @@ -89,6 +89,35 @@ float invoke_gemm(int n_warmup, int n_repeat, const ck_tile::GroupedFlatmmHostAr return ave_time; } +template +float invoke_gemm(int n_warmup, int n_repeat, const ck_tile::ContiguousGroupedFlatmmHostArgs& args) +{ + float ave_time = + grouped_flatmm( + args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); + + std::string op_name{"Grouped Gemm"}; + + std::size_t flop = std::size_t(2) * args.M * args.N * args.K; + std::size_t num_byte = sizeof(ADataType) * args.M * args.K + + sizeof(BDataType) * args.N * args.K + + sizeof(CDataType) * args.M * args.N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_byte / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + return ave_time; +} + template int run_grouped_flatmm_example_with_layouts(int argc, char* argv[], @@ -318,3 +347,181 @@ int run_grouped_flatmm_example_with_layouts(int argc, return pass; } + +template +int run_contiguous_grouped_flatmm_example_with_layouts( + int argc, + char* argv[], + const ALayout a_layout = ALayout{}, + const BLayout b_layout = BLayout{}, + [[maybe_unused]] const CLayout c_layout = CLayout{}) +{ + auto [result, arg_parser] = create_args(argc, argv); + + if(!result) + { + return -1; + }; + 
+ using ADataType = typename GemmBasicTypeConfig::ADataType; + using BDataType = typename GemmBasicTypeConfig::BDataType; + using CDataType = typename GemmBasicTypeConfig::CDataType; + using AccDataType = typename GemmBasicTypeConfig::AccDataType; + + constexpr int BlockM = GemmConfig::M_Tile; + + const int group_count = arg_parser.get_int("group_count"); + const int repeat = arg_parser.get_int("repeat"); + const int warmup = arg_parser.get_int("warmup"); + + std::vector Ms = arg_parser.get_int_vec("Ms"); + std::vector Ns = arg_parser.get_int_vec("Ns"); + std::vector Ks = arg_parser.get_int_vec("Ks"); + + if(!(int(Ms.size()) == group_count)) + { + std::cout << "Please check the input data." << std::endl; + // padding additional Ms if needed + for(int i = 0; i < group_count; i++) + { + Ms.push_back(256 + 64 * i); + } + } + + ck_tile::index_t M = + std::accumulate(Ms.begin(), Ms.begin() + group_count, 0, [](auto acc, auto group_m) { + // round up to the multiple of BlockM; the op is not commutative, so a strict left fold (std::accumulate, not std::reduce) is required + return acc + (group_m + BlockM - 1) / BlockM * BlockM; + }); + std::cout << "Total M: " << M << std::endl; + ck_tile::index_t N = Ns[0]; + ck_tile::index_t K = Ks[0]; + + ck_tile::index_t kbatch = arg_parser.get_int("split_k"); + + ck_tile::index_t stride_A = 0; + ck_tile::index_t stride_B = 0; + ck_tile::index_t stride_C = 0; + + stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout)); + stride_B = ck_tile::get_default_stride(K, N * group_count, stride_B, is_row_major(b_layout)); + stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(c_layout)); + + ck_tile::HostTensor a_m_k_tensor( + ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout))); + ck_tile::HostTensor b_k_n_tensor(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(K, N * group_count, stride_B, is_row_major(b_layout)))); + ck_tile::HostTensor c_m_n_tensor(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(c_layout)))); + + std::vector 
m_indices(std::size_t(M), -1); + int indices_fill_start = 0; + for(int i = 0; i < group_count; ++i) + { + int group_m = Ms[i]; + int padded_group_m = (group_m + BlockM - 1) / BlockM * BlockM; + for(int j = 0; j < padded_group_m; j++) + { + m_indices[indices_fill_start + j] = j < group_m ? i : -1; // -1 for padding + } + indices_fill_start += padded_group_m; + } + + ck_tile::FillUniformDistribution{-1.f, 1.f}(a_m_k_tensor); + ck_tile::FillUniformDistribution{-4.f, 4.f}(b_k_n_tensor); + c_m_n_tensor.SetZero(); + + ck_tile::HostTensor b_shuffle_host = shuffle_b(b_k_n_tensor); + + std::unique_ptr a_m_k_dev_buf( + std::make_unique(a_m_k_tensor.get_element_space_size_in_bytes())); + std::unique_ptr b_shfl_dev_buf( + std::make_unique(b_shuffle_host.get_element_space_size_in_bytes())); + std::unique_ptr c_m_n_dev_buf( + std::make_unique(c_m_n_tensor.get_element_space_size_in_bytes())); + c_m_n_dev_buf->SetZero(); + + ck_tile::DeviceMem m_indices_dev_buf(M * sizeof(ck_tile::index_t)); + m_indices_dev_buf.ToDevice(m_indices.data()); + + ck_tile::ContiguousGroupedFlatmmHostArgs kernal_args{ + static_cast(m_indices_dev_buf.GetDeviceBuffer()), + M, + N, + K, + a_m_k_dev_buf->GetDeviceBuffer(), + stride_A, + b_shfl_dev_buf->GetDeviceBuffer(), + stride_B, + c_m_n_dev_buf->GetDeviceBuffer(), + stride_C, + kbatch, + }; + + invoke_gemm( + warmup, repeat, kernal_args); + c_m_n_dev_buf->FromDevice(c_m_n_tensor.data()); + + bool pass{true}; + if(arg_parser.get_int("v") == 1) + { + throw std::runtime_error( + "Not support v=1 host verification in contiguous grouped gemm, use " + "v=2 device verification instead"); + } + else if(arg_parser.get_int("v") == 2) + { + BDataType* d_B; + CDataType* d_C; + ck_tile::hip_check_error(hipMalloc(&d_B, N * K * sizeof(BDataType))); + ck_tile::hip_check_error(hipMalloc(&d_C, M * N * sizeof(CDataType))); + + ck_tile::HostTensor c_gpu_ref_host( + ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); + + ck_tile::index_t acc_m = 0; + 
for(int i = 0; i < group_count; ++i) + { + ck_tile::index_t padded_M = (Ms[i] + BlockM - 1) / BlockM * BlockM; + + ck_tile::hip_check_error(hipMemcpy(d_B, + b_k_n_tensor.data() + i * N * K, // group i's B is the i-th N*K block (same offset the kernel applies to b_shuffle_ptr); group_count * N * K pointed past the end + N * K * sizeof(BDataType), + hipMemcpyHostToDevice)); + ck_tile::reference_gemm_gpu( + static_cast(a_m_k_dev_buf->GetDeviceBuffer()) + acc_m * K, + d_B, + d_C + acc_m * N, + padded_M, + N, + K, + stride_A, + stride_B, + stride_C); + acc_m += padded_M; + } + ck_tile::hip_check_error(hipMemcpy( + c_gpu_ref_host.data(), d_C, M * N * sizeof(CDataType), hipMemcpyDeviceToHost)); + + ck_tile::hip_check_error(hipFree(d_B)); + ck_tile::hip_check_error(hipFree(d_C)); + + float rtol = 1e-3; + float atol = 1e-3; + + pass = ck_tile::check_err( + c_m_n_tensor, c_gpu_ref_host, "Error: Incorrect results!", rtol, atol); + + std::cout << "Relative error threshold: " << rtol << " Absolute error threshold: " << atol + << std::endl; + std::cout << "The GPU veification result is: " << (pass ? "correct" : "fail") << std::endl; + } + + return pass; +} diff --git a/include/ck_tile/ops/flatmm/kernel/grouped_flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/grouped_flatmm_kernel.hpp index 6b8ffcecfb..d67ef6f33b 100644 --- a/include/ck_tile/ops/flatmm/kernel/grouped_flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/grouped_flatmm_kernel.hpp @@ -53,10 +53,52 @@ struct GroupedFlatmmHostArgs index_t k_batch; }; +struct ContiguousGroupedFlatmmHostArgs +{ + CK_TILE_HOST ContiguousGroupedFlatmmHostArgs() = default; + CK_TILE_HOST ContiguousGroupedFlatmmHostArgs(index_t* M_indices_, + index_t M_, + index_t N_, + index_t K_, + const void* a_ptr_, + index_t stride_A_, + const void* b_shuffle_ptr_, + index_t stride_B_, + void* c_ptr_, + index_t stride_C_, + index_t k_batch_) + : M_indices(M_indices_), + M(M_), + N(N_), + K(K_), + a_ptr(a_ptr_), + stride_A(stride_A_), + b_shuffle_ptr(b_shuffle_ptr_), + stride_B(stride_B_), + c_ptr(c_ptr_), + stride_C(stride_C_), + k_batch(k_batch_) + { + } + + 
index_t* M_indices; + index_t M; + index_t N; + index_t K; + const void* a_ptr; + index_t stride_A; + const void* b_shuffle_ptr; + index_t stride_B; + void* c_ptr; + index_t stride_C; + index_t k_batch; +}; + template struct GroupedFlatmmKernel : FlatmmKernel { using UnderlyingGemmKernel = FlatmmKernel; + using BlockGemmShape = typename UnderlyingGemmKernel::BlockGemmShape; using TilePartitioner = remove_cvref_t; using FlatmmPipeline = remove_cvref_t; @@ -68,15 +110,13 @@ struct GroupedFlatmmKernel : FlatmmKernel; - using GroupedFlatmmKernelArgs = GroupedFlatmmHostArgs; - CK_TILE_HOST static const std::string GetName() { return concat( '_', "grouped_flatmm", gemm_prec_str, FlatmmPipeline::GetName()); } - - CK_TILE_HOST_DEVICE static auto GridSize(const GroupedFlatmmKernelArgs& kernelArgs) + template + CK_TILE_HOST_DEVICE static auto GridSizeImpl(const KernelArgs& kernelArgs) { hipDeviceProp_t prop; int deviceId = 0; // default device @@ -89,29 +129,41 @@ struct GroupedFlatmmKernel : FlatmmKernel(GroupedFlatmmKernel::Kernel), - reinterpret_cast( - kentry2), + reinterpret_cast(kentry2), block_size, dync_smem_size); const int persistent_block_size = prop.multiProcessorCount * maxActiveBlocksPerCU; - // print maxActiveBlocksPerCU and persistent_block_size - // std::cout << "maxActiveBlocksPerCU: " << maxActiveBlocksPerCU - // << ", persistent_block_size: " << persistent_block_size << std::endl; + std::cout << "maxActiveBlocksPerCU: " << maxActiveBlocksPerCU + << ", persistent_block_size: " << persistent_block_size << std::endl; assert(kernelArgs.k_batch == 1); return dim3(persistent_block_size, 1, kernelArgs.k_batch); } - CK_TILE_HOST static constexpr GroupedFlatmmKernelArgs - MakeKernelArgs(const GroupedFlatmmHostArgs& hostArgs) + CK_TILE_HOST_DEVICE static auto + GridSize([[maybe_unused]] const GroupedFlatmmHostArgs& kernelArgs) + { + return GridSizeImpl(kernelArgs); + } + CK_TILE_HOST_DEVICE static auto + GridSize([[maybe_unused]] const 
ContiguousGroupedFlatmmHostArgs& kernelArgs) + { + return GridSizeImpl(kernelArgs); + } + + CK_TILE_HOST static constexpr auto MakeKernelArgs(const GroupedFlatmmHostArgs& hostArgs) + { + return hostArgs; + } + CK_TILE_HOST static constexpr auto + MakeKernelArgs(const ContiguousGroupedFlatmmHostArgs& hostArgs) { return hostArgs; } - CK_TILE_DEVICE void operator()(GroupedFlatmmKernelArgs kargs) const + CK_TILE_DEVICE void operator()(GroupedFlatmmHostArgs kargs) const { int group_idx = 0; int block_linear_idx = blockIdx.x; @@ -147,6 +199,36 @@ struct GroupedFlatmmKernel : FlatmmKernel(kargs.b_shuffle_ptr) + group_idx * kargs.N * kargs.K, + kargs.c_ptr, + kargs.M, + kargs.N, + kargs.K, + kargs.stride_A, + kargs.stride_B, + kargs.stride_C, + kargs.k_batch, + }; + // call the underlying flatmm kernel + underlying_kernel(impl_kargs, block_linear_idx); + } + } }; } // namespace ck_tile