From 790dfe9bcd7d5834f980644bf28176091a5b96a3 Mon Sep 17 00:00:00 2001 From: Khushbu Agarwal Date: Mon, 21 Apr 2025 08:39:45 -0700 Subject: [PATCH] multi instance generation for CkTileEngine (#2080) * Add support for multi-instance verification, print detail for each instance, documentation fix * clang formatted * Added Readme file * updated readme * Addressing review comments * clang formatted * Updated ReadMe and GPU reference code * simplified dispatch kernel code * indentation [ROCm/composable_kernel commit: 7cadf187e28693eb211c9cfb76d72ba0d6fb28b8] --- tile_engine/ops/gemm/README.md | 51 ++++++ .../gemm/configs/instance_combination.json | 2 +- tile_engine/ops/gemm/gemm_host_api.cpp | 79 +++++----- tile_engine/ops/gemm/gemm_host_api.hpp | 146 +++++++----------- tile_engine/ops/gemm/gemm_instance_builder.py | 64 ++++++-- 5 files changed, 202 insertions(+), 140 deletions(-) create mode 100644 tile_engine/ops/gemm/README.md diff --git a/tile_engine/ops/gemm/README.md b/tile_engine/ops/gemm/README.md new file mode 100644 index 0000000000..495232f19b --- /dev/null +++ b/tile_engine/ops/gemm/README.md @@ -0,0 +1,51 @@ +# GEMM Matrix Multiplication + +Use the files in this folder to generate and build applications that run Matrix multiplications using ck_tile programming based on the kernel parameters mentioned in the config file `./configs/instance_combination.json`. + +# Kernel Configurations + +User needs to provide kernel configuration such as datatype, layout, tile size, warp size, padding, pipeline, scheduler and epilogue in the config file. For reference please see `./configs/instance_combination.json` + +## Build +``` +# in the root of ck_tile +mkdir build && cd build +# you can replace with the appropriate architecture (for example gfx90a or gfx942) or leave it blank +sh ../script/cmake-ck-dev.sh ../ +# To generate the executable +make tile_engine_gemm -j +``` +`tile_engine_gemm` will be located in the `./bin/` directory. + +## tile_engine_gemm inputs +``` + + -m m dimension (default:3840) + -n n dimension (default:4096) + -k k dimension (default:2048) + -stride_a Tensor A stride (default:0) + -stride_b Tensor B stride (default:0) + -stride_c Tensor C stride (default:0) + -split_k SplitK value (default:1) + -v No validation: 0, Validation on CPU: 1, Validation on GPU: 2 (default:2) + -warmup Number of iterations before benchmark the kernel (default:50) + -repeat Number of iterations to benchmark the kernel (default:100) + -timer gpu:gpu timer, cpu:cpu timer (default:gpu) + -init Value for initializing tensor- random: 0, linear: 1, constant(1): 2 (default:0) + -pipeline possible values are: compv3, compv4, mem (default:compv3) + -scheduler possible values are: intrawave, interwave (default:intrawave) + -epilogue possible values are: cshuffle, default (default:cshuffle) + -pad_m Pad in m direction - true/false (default:false) + -pad_n Pad in n direction - true/false (default:false) + -pad_k Pad in k direction - true/false (default:false) + +Note: pipeline, scheduler, epilogue, pad_m, pad_n, pad_k should be one of the options specified in instance_combination.json +``` + +## Example + +Below example will run gemm kernel with default dimensions of matrices, for compv3 pipeline, intrawave scheduler and default epilogue with all possible tile sizes mentioned in Config file. + +``` +./bin/tile_engine_gemm -pipeline=compv3 -scheduler=intrawave -epilogue=default +``` diff --git a/tile_engine/ops/gemm/configs/instance_combination.json b/tile_engine/ops/gemm/configs/instance_combination.json index e21197d1de..e23df11500 100644 --- a/tile_engine/ops/gemm/configs/instance_combination.json +++ b/tile_engine/ops/gemm/configs/instance_combination.json @@ -19,7 +19,7 @@ "values": [256] }, "tile_k": { - "values": [64] + "values": [64, 32] }, "warp_m": { "values": [2] diff --git a/tile_engine/ops/gemm/gemm_host_api.cpp b/tile_engine/ops/gemm/gemm_host_api.cpp index 508f634920..3cef425a51 100644 --- a/tile_engine/ops/gemm/gemm_host_api.cpp +++ b/tile_engine/ops/gemm/gemm_host_api.cpp @@ -6,11 +6,16 @@ #include "gemm_dispatcher.hpp" #include "gemm_host_api.hpp" -float gemm_kernel_launch(KernelTraits& trait, - ck_tile::GemmHostArgs& args, - const ck_tile::stream_config& s) +void gemm_kernel_launch(ck_tile::DeviceMem& c_m_n_dev_buf, + ck_tile::HostTensor& c_m_n_host_result, + ck_tile::HostTensor& c_m_n_dev_result, + int verify, + KernelTraits& trait, + ck_tile::GemmHostArgs& args, + const ck_tile::stream_config& s) { - return GemmDispatcher::dispatch(trait, args, s); + return GemmDispatcher::dispatch( + c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, trait, args, s); } template -bool run(const ck_tile::ArgParser& arg_parser) +void run(const ck_tile::ArgParser& arg_parser) { const ALayout a_layout = ALayout{}; const BLayout b_layout = BLayout{}; - // const CLayout c_layout = CLayout{}; ck_tile::index_t kbatch = arg_parser.get_int("split_k"); ck_tile::index_t M = arg_parser.get_int("m"); @@ -113,43 +117,47 @@ bool run(const ck_tile::ArgParser& arg_parser) trait.kPadN = arg_parser.get_bool("pad_n"); trait.kPadK = arg_parser.get_bool("pad_k"); - float ave_time = gemm_kernel_launch( - trait, gemm_args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); - - std::size_t flop = std::size_t(2) * M * N * K; - std::size_t num_byte = - sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N; - float tflops = static_cast(flop) / 1.E9 / ave_time; - float gb_per_sec = num_byte / 1.E6 / ave_time; - std::cout << "Run Gemm kernel with M =" << M << " N =" << N << " K =" << K << " StrideA =" << stride_A << " StrideB =" << stride_B << " StrideC =" << stride_C << " A_Layout =" << ALayout::name << " B_Layout =" << BLayout::name << " C_Layout =" << CLayout::name << " A Type = " << DataTypeTraits::name << " B Type = " << DataTypeTraits::name - << " C Type = " << DataTypeTraits::name << " : " << ave_time << " ms, " - << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl; + << " C Type = " << DataTypeTraits::name << std::endl; + + ck_tile::HostTensor c_m_n_host_result( + ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); - c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data()); - bool pass = true; if(verify) { - pass = gemm_verify( - verify, - a_m_k, - b_k_n, - c_m_n_dev_result, - a_m_k_dev_buf, - b_k_n_dev_buf, - M, - N, - K, - stride_A, - stride_B, - stride_C, - kbatch); + gemm_host_reference(verify, + a_m_k, + b_k_n, + c_m_n_host_result, + a_m_k_dev_buf, + b_k_n_dev_buf, + M, + N, + K, + stride_A, + stride_B, + stride_C); } - return pass; + + gemm_kernel_launch(c_m_n_dev_buf, + c_m_n_host_result, + c_m_n_dev_result, + verify, + trait, + gemm_args, + ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); + + return; } int main(int argc, char* argv[]) @@ -159,7 +167,8 @@ int main(int argc, char* argv[]) auto [result, parser] = create_args(argc, argv); if(!result) return EXIT_FAILURE; - return run(parser); + run(parser); + return 0; } catch(const std::exception& e) { diff --git a/tile_engine/ops/gemm/gemm_host_api.hpp b/tile_engine/ops/gemm/gemm_host_api.hpp index 375f808966..c1e1e1dc4f 100644 --- a/tile_engine/ops/gemm/gemm_host_api.hpp +++ b/tile_engine/ops/gemm/gemm_host_api.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + #include #include @@ -54,24 +57,21 @@ struct DataTypeTraits static constexpr const char* name = "pk_int4_t"; }; -/** - * @brief trait for GEMM kernel - * @param pipeline: pipeline name - * @param scheduler: scheduler name - * @param epilogue: epilogue name - * @param kPadM: padding for M dimension - * @param kPadN: padding for N dimension - * @param kPadK: padding for K dimension - * - */ - +/// @brief Defines the configuration parameters for a GEMM operation, enabling the selection of a +/// specific kernel instance based on the provided settings. struct KernelTraits { + /// @brief The name of the pipeline. std::string pipeline; + /// @brief The name of the scheduler (e.g., "intrawave", "interwave"). std::string scheduler; + /// @brief The name of the epilogue (e.g., "cshuffle", "default"). std::string epilogue; + /// @brief Indicates whether padding is applied to the M dimension. bool kPadM; + /// @brief Indicates whether padding is applied to the N dimension. bool kPadN; + /// @brief Indicates whether padding is applied to the K dimension. bool kPadK; }; @@ -184,11 +184,28 @@ void permute_vectors_i4x4_b(Tensor& tensor) } } -/** - * @brief Function to verify the kernel output with reference implementation on CPU/GPU - * - */ +/// @brief Function to compare the results of the device and host computations +void compare(ck_tile::index_t K, + ck_tile::index_t kbatch, + ck_tile::HostTensor& c_m_n_dev_result, + ck_tile::HostTensor& c_m_n_host_result) +{ + const float max_accumulated_value = + *std::max_element(c_m_n_host_result.mData.begin(), c_m_n_host_result.mData.end()); + const auto rtol_atol = calculate_rtol_atol( + K, kbatch, max_accumulated_value); + bool pass = ck_tile::check_err(c_m_n_dev_result, + c_m_n_host_result, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) + << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) << std::endl; + std::cout << "The verification result is:" << (pass ? "correct" : "fail") << std::endl; +} + +/// @brief Function to get the kernel output with reference implementation on CPU/GPU template -bool gemm_verify(int verify, - ck_tile::HostTensor& a_m_k, - ck_tile::HostTensor& b_k_n, - ck_tile::HostTensor& c_m_n_dev_result, - ck_tile::DeviceMem& a_m_k_dev_buf, - ck_tile::DeviceMem& b_k_n_dev_buf, - ck_tile::index_t M, - ck_tile::index_t N, - ck_tile::index_t K, - ck_tile::index_t stride_A, - ck_tile::index_t stride_B, - ck_tile::index_t stride_C, - ck_tile::index_t kbatch) +void gemm_host_reference(int verify, + ck_tile::HostTensor& a_m_k, + ck_tile::HostTensor& b_k_n, + ck_tile::HostTensor& c_m_n_host_result, + ck_tile::DeviceMem& a_m_k_dev_buf, + ck_tile::DeviceMem& b_k_n_dev_buf, + ck_tile::index_t M, + ck_tile::index_t N, + ck_tile::index_t K, + ck_tile::index_t stride_A, + ck_tile::index_t stride_B, + ck_tile::index_t stride_C) { - bool pass = true; if(verify == 1) { - ck_tile::HostTensor c_m_n_host_ref( - ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); - c_m_n_host_ref.SetZero(); + c_m_n_host_result.SetZero(); ck_tile::reference_gemm( - a_m_k, b_k_n, c_m_n_host_ref); - const float max_accumulated_value = - *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); - const auto rtol_atol = calculate_rtol_atol( - K, kbatch, max_accumulated_value); - pass = ck_tile::check_err(c_m_n_dev_result, - c_m_n_host_ref, - "Error: Incorrect results!", - rtol_atol.at(ck_tile::number<0>{}), - rtol_atol.at(ck_tile::number<1>{})); - - std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) - << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) - << std::endl; - std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl; + a_m_k, b_k_n, c_m_n_host_result); } else if(verify == 2) { @@ -241,29 +240,14 @@ bool gemm_verify(int verify, // Restore input for B for gpu reference b_k_n_dev_buf.ToDevice(b_k_n.data()); } - ck_tile::HostTensor c_m_n_gpu_ref( - ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); - ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_gpu_ref.get_element_space_size_in_bytes()); - c_m_n_gpu_ref.SetZero(); + + ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_host_result.get_element_space_size_in_bytes()); + c_m_n_host_result.SetZero(); c_m_n_gpu_buf_ref.SetZero(); - ADataType* d_A; - BDataType* d_B; - CDataType* d_C; - - ck_tile::hip_check_error(hipMalloc(&d_A, a_m_k.get_element_space_size_in_bytes())); - ck_tile::hip_check_error(hipMalloc(&d_B, b_k_n.get_element_space_size_in_bytes())); - ck_tile::hip_check_error( - hipMalloc(&d_C, c_m_n_dev_result.get_element_space_size_in_bytes())); - - ck_tile::hip_check_error(hipMemcpy(d_A, - a_m_k_dev_buf.GetDeviceBuffer(), - a_m_k.get_element_space_size_in_bytes(), - hipMemcpyHostToDevice)); - ck_tile::hip_check_error(hipMemcpy(d_B, - b_k_n_dev_buf.GetDeviceBuffer(), - b_k_n.get_element_space_size_in_bytes(), - hipMemcpyHostToDevice)); + ADataType* d_A = static_cast(a_m_k_dev_buf.GetDeviceBuffer()); + BDataType* d_B = static_cast(b_k_n_dev_buf.GetDeviceBuffer()); + CDataType* d_C = static_cast(c_m_n_gpu_buf_ref.GetDeviceBuffer()); ck_tile::reference_gemm_gpu(d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C); - ck_tile::hip_check_error(hipMemcpy(c_m_n_gpu_buf_ref.GetDeviceBuffer(), - d_C, - c_m_n_dev_result.get_element_space_size_in_bytes(), - hipMemcpyDeviceToHost)); - - ck_tile::hip_check_error(hipFree(d_A)); - ck_tile::hip_check_error(hipFree(d_B)); - ck_tile::hip_check_error(hipFree(d_C)); - - c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data()); - const float max_accumulated_value = - *std::max_element(c_m_n_gpu_ref.mData.begin(), c_m_n_gpu_ref.mData.end()); - const auto rtol_atol = calculate_rtol_atol( - K, kbatch, max_accumulated_value); - pass = ck_tile::check_err(c_m_n_dev_result, - c_m_n_gpu_ref, - "Error: Incorrect results!", - rtol_atol.at(ck_tile::number<0>{}), - rtol_atol.at(ck_tile::number<1>{})); - - std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) - << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) - << std::endl; - std::cout << "The GPU verification result is: " << (pass ? "correct" : "fail") << std::endl; + c_m_n_gpu_buf_ref.FromDevice(c_m_n_host_result.data()); } - return pass; } diff --git a/tile_engine/ops/gemm/gemm_instance_builder.py b/tile_engine/ops/gemm/gemm_instance_builder.py index e449dff94d..cfefd38cd2 100755 --- a/tile_engine/ops/gemm/gemm_instance_builder.py +++ b/tile_engine/ops/gemm/gemm_instance_builder.py @@ -447,6 +447,17 @@ struct GemmKernel {{ return ave_time; }} + static std::string get_name() {{ + return std::string("GemmKernel> kernel_map; + std::function& c_m_n_host_result, + ck_tile::HostTensor& c_m_n_dev_result, + int verify, ck_tile::GemmHostArgs&, const ck_tile::stream_config&)>> kernel_map; return kernel_map; } @@ -499,9 +513,12 @@ struct GemmDispatcher { for group in self.all_kernels: - content += f""" kernel_map["{group}"] = [](ck_tile::GemmHostArgs& args, - const ck_tile::stream_config& s) {{ - std::vector results;""" + content += f""" kernel_map["{group}"] = [](ck_tile::DeviceMem& c_m_n_dev_buf, + ck_tile::HostTensor& c_m_n_host_result, + ck_tile::HostTensor& c_m_n_dev_result, + int verify, ck_tile::GemmHostArgs& args, + const ck_tile::stream_config& s) {{ + """ for tile in tile_params: # Check if we have valid tile/warp combinations # (tile_m/(warp_m*warp_tile_m)) * warp_m * warp_tile_m == tile_m @@ -509,21 +526,46 @@ struct GemmDispatcher { ((tile[1]/(tile[4] * tile[8]) * tile[4] * tile[8]) != tile[1]): continue content += f""" - //we can have multiple tiles config for the one kernel_trait - return {group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}>::launch(args, s);""" - content += """ - };\n""" + run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, s);""" + content += f""" + }};\n""" content += """ } - - static float dispatch(const KernelTraits &trait, ck_tile::GemmHostArgs& gemm_args, + template + static void run_kernel(ck_tile::DeviceMem& c_m_n_dev_buf, + ck_tile::HostTensor& c_m_n_host_result, + ck_tile::HostTensor& c_m_n_dev_result, + int verify, ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) + { + float avg_time = Kernel::launch(args, s); + std::string description = Kernel::get_name(); + c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data()); + + std::size_t flop = std::size_t(2) * args.M * args.N * args.K; + std::size_t num_byte = sizeof(ADataType) * args.M * args.K + sizeof(BDataType) * args.N * args.K + sizeof(CDataType) * args.M * args.N; + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_byte / 1.E6 / avg_time; + + std::cout << "Performance for " << description << " : " << avg_time << " ms, " + << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl; + + if(verify) + compare(args.K, args.k_batch, c_m_n_dev_result, c_m_n_host_result); + c_m_n_dev_buf.SetZero(); + c_m_n_dev_result.SetZero(); + } + + static auto dispatch(ck_tile::DeviceMem& c_m_n_dev_buf, + ck_tile::HostTensor& c_m_n_host_result, + ck_tile::HostTensor& c_m_n_dev_result, + int verify, const KernelTraits &trait, ck_tile::GemmHostArgs& gemm_args, const ck_tile::stream_config& s) { init(); const std::string key = assemble_key(trait); auto& kernel_map = get_kernel_map(); if(auto it = kernel_map.find(key); it != kernel_map.end()) { - return it->second(gemm_args, s); //Running single instance + return it->second(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify,gemm_args, s); } throw std::runtime_error("No suitable kernel found: " + key); }