From b7359bcfac062f3fb482acf3f5a009cff5743b09 Mon Sep 17 00:00:00 2001 From: aledudek Date: Thu, 3 Apr 2025 13:35:43 +0200 Subject: [PATCH] Post-merge changes for fully async args copy in ck grouped gemm (#1991) * Post-merge changes for fully async args copy in ck grouped gemm * Post-merge documentation and naming changes * Build fix and updated changelog * Revised comments [ROCm/composable_kernel commit: 9329432f6c3d4ddd8d5b836245bd44acef89be3d] --- CHANGELOG.md | 2 ++ .../run_grouped_gemm_example.inc | 35 +++++++++++++------ .../device_grouped_gemm_multiple_d_dl.hpp | 15 ++++++-- .../device/impl/device_grouped_gemm_xdl.hpp | 16 +++++++-- ...evice_grouped_gemm_xdl_splitk_cshuffle.hpp | 16 +++++++-- 5 files changed, 68 insertions(+), 16 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f9da2b3117..49ef2998eb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,8 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj ### Added * Added support for bf16, f32, and f16 for 2D and 3D NGCHW grouped convolution backward data +* Added a fully asynchronous HOST (CPU) arguments copy flow for CK grouped GEMM kernels. +* Added support for GKCYX layout for grouped convolution forward (NGCHW/GKCYX/NGKHW, number of instances in instance factory for NGCHW/GKYXC/NGKHW has been reduced). * Added support for GKCYX layout for grouped convolution forward (NGCHW/GKCYX/NGKHW). * Added support for GKCYX layout for grouped convolution backward weight (NGCHW/GKCYX/NGKHW). * Added support for GKCYX layout for grouped convolution backward data (NGCHW/GKCYX/NGKHW). 
diff --git a/example/15_grouped_gemm/run_grouped_gemm_example.inc b/example/15_grouped_gemm/run_grouped_gemm_example.inc index 86b3182a52..7186c22233 100644 --- a/example/15_grouped_gemm/run_grouped_gemm_example.inc +++ b/example/15_grouped_gemm/run_grouped_gemm_example.inc @@ -21,6 +21,7 @@ struct ExecutionConfig final bool do_verification = true; int init_method = 1; bool time_kernel = false; + bool async_hargs = false; }; bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) @@ -190,10 +191,10 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co gemm_workspace.Realloc(workspace_size); gemm.SetWorkSpacePointer(&argument, gemm_workspace.GetDeviceBuffer()); } - if(hargs_size > 0) + if(config.async_hargs && hargs_size > 0) { hip_check_error(hipHostMalloc(&gemm_hargs, hargs_size)); - gemm.SetHostKernelArgs(&argument, gemm_hargs); + gemm.SetHostKernelArgsPointer(&argument, gemm_hargs); } if(!gemm.IsSupportedArgument(argument)) @@ -203,16 +204,23 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co "not support this GEMM problem"); } - hipStream_t stream0 = nullptr; - hip_check_error(hipStreamCreate(&stream0)); + if(!config.async_hargs) + { + invoker.Run(argument, StreamConfig{nullptr, false}); + } + else + { + hipStream_t stream0 = nullptr; + hip_check_error(hipStreamCreate(&stream0)); - hipEvent_t event0 = nullptr; - hip_check_error(hipEventCreate(&event0)); + hipEvent_t event0 = nullptr; + hip_check_error(hipEventCreate(&event0)); - invoker.Run(argument, StreamConfig{nullptr, false}, stream0, event0); + invoker.Run(argument, StreamConfig{nullptr, false}, stream0, event0); - hip_check_error(hipEventSynchronize(event0)); - hip_check_error(hipStreamSynchronize(stream0)); + hip_check_error(hipEventSynchronize(event0)); + hip_check_error(hipStreamSynchronize(stream0)); + } bool pass = true; if(config.do_verification) @@ -280,18 +288,25 @@ bool run_grouped_gemm_example(int argc, 
char* argv[]) problem_size.stride_Bs.push_back(problem_size.Ks[i]); problem_size.stride_Cs.push_back(problem_size.Ns[i]); } - if(argc == 4) { config.do_verification = std::stoi(argv[1]); config.init_method = std::stoi(argv[2]); config.time_kernel = std::stoi(argv[3]); } + else if(argc == 5) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + config.async_hargs = std::stoi(argv[4]); + } else { printf("arg1: verification (0=no, 1=yes)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg3: time kernel (0=n0, 1=yes)\n"); + printf("arg4: async hargs (0=no, 1=yes)\n"); exit(0); } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp index c148d7dbb7..463b10de43 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp @@ -607,6 +607,9 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemmSetWorkSpacePointer(p_arg, p_dev_kernel_args); } - void SetHostKernelArgs(BaseArgument* p_arg, void* p_host_kernel_args) const + //---------------------------------------------------------------------------------------------- + /// @brief Sets the host kernel arguments pointer and copies that data on the host side. + /// This function can be utilised to use pinned memory for the host args and + /// achieve fully async data copy. + /// + /// @param p_arg The pointer to the Argument we're going to update. 
+ /// @param[in] p_host_kernel_args The pointer to the host memory where the kernel + /// arguments will be copied + void SetHostKernelArgsPointer(BaseArgument* p_arg, void* p_host_kernel_args) const { Argument* pArg_ = dynamic_cast<Argument*>(p_arg); if(!pArg_) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp index 2a6406aac3..d9a0249da8 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp @@ -560,6 +560,9 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemmSetWorkSpacePointer(p_arg, p_dev_kernel_args); } - void SetHostKernelArgs(BaseArgument* p_arg, void* p_host_kernel_args) const + //---------------------------------------------------------------------------------------------- + /// @brief Sets the host kernel arguments pointer and copies that data on the host side. + /// This function can be utilised to use pinned memory for the host args and + /// achieve fully async data copy. + /// + /// @param p_arg The pointer to the Argument we're going to update. + /// @param[in] p_host_kernel_args The pointer to the host memory where the kernel + /// arguments will be copied + /// + void SetHostKernelArgsPointer(BaseArgument* p_arg, void* p_host_kernel_args) const { Argument* pArg_ = dynamic_cast<Argument*>(p_arg); if(!pArg_) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp index 03431d7156..a2afb62eec 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp @@ -423,6 +423,9 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitKSetWorkSpacePointer(p_arg, p_dev_kernel_args); } - void SetHostKernelArgs(BaseArgument* p_arg, void* p_host_kernel_args) const + //---------------------------------------------------------------------------------------------- + /// @brief Sets the host kernel arguments pointer and copies that data on the host side. + /// This function can be utilised to use pinned memory for the host args and + /// achieve fully async data copy. + /// + /// @param p_arg The pointer to the Argument we're going to update. + /// @param[in] p_host_kernel_args The pointer to the host memory where the kernel + /// arguments will be copied + /// + void SetHostKernelArgsPointer(BaseArgument* p_arg, void* p_host_kernel_args) const { Argument* pArg_ = dynamic_cast<Argument*>(p_arg); if(!pArg_)