Post-merge changes for fully async args copy in ck grouped gemm (#1991)

* Post-merge changes for fully async args copy in ck grouped gemm

* Post-merge documentation and naming changes

* Build fix and updated changelog

* Revised comments

[ROCm/composable_kernel commit: 9329432f6c]
This commit is contained in:
aledudek
2025-04-03 13:35:43 +02:00
committed by GitHub
parent 49565538fe
commit 7a78bc823a
5 changed files with 68 additions and 16 deletions

View File

@@ -7,6 +7,8 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
### Added
* Added support for bf16, f32, and f16 for 2D and 3D NGCHW grouped convolution backward data
* Added a fully asynchronous HOST (CPU) arguments copy flow for CK grouped GEMM kernels.
* Added support for GKCYX layout for grouped convolution forward (NGCHW/GKCYX/NGKHW, number of instances in instance factory for NGCHW/GKYXC/NGKHW has been reduced).
* Added support for GKCYX layout for grouped convolution forward (NGCHW/GKCYX/NGKHW).
* Added support for GKCYX layout for grouped convolution backward weight (NGCHW/GKCYX/NGKHW).
* Added support for GKCYX layout for grouped convolution backward data (NGCHW/GKCYX/NGKHW).

View File

@@ -21,6 +21,7 @@ struct ExecutionConfig final
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
bool async_hargs = false;
};
bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
@@ -190,10 +191,10 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
gemm_workspace.Realloc(workspace_size);
gemm.SetWorkSpacePointer(&argument, gemm_workspace.GetDeviceBuffer());
}
if(hargs_size > 0)
if(config.async_hargs && hargs_size > 0)
{
hip_check_error(hipHostMalloc(&gemm_hargs, hargs_size));
gemm.SetHostKernelArgs(&argument, gemm_hargs);
gemm.SetHostKernelArgsPointer(&argument, gemm_hargs);
}
if(!gemm.IsSupportedArgument(argument))
@@ -203,16 +204,23 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
"not support this GEMM problem");
}
hipStream_t stream0 = nullptr;
hip_check_error(hipStreamCreate(&stream0));
if(!config.async_hargs)
{
invoker.Run(argument, StreamConfig{nullptr, false});
}
else
{
hipStream_t stream0 = nullptr;
hip_check_error(hipStreamCreate(&stream0));
hipEvent_t event0 = nullptr;
hip_check_error(hipEventCreate(&event0));
hipEvent_t event0 = nullptr;
hip_check_error(hipEventCreate(&event0));
invoker.Run(argument, StreamConfig{nullptr, false}, stream0, event0);
invoker.Run(argument, StreamConfig{nullptr, false}, stream0, event0);
hip_check_error(hipEventSynchronize(event0));
hip_check_error(hipStreamSynchronize(stream0));
hip_check_error(hipEventSynchronize(event0));
hip_check_error(hipStreamSynchronize(stream0));
}
bool pass = true;
if(config.do_verification)
@@ -280,18 +288,25 @@ bool run_grouped_gemm_example(int argc, char* argv[])
problem_size.stride_Bs.push_back(problem_size.Ks[i]);
problem_size.stride_Cs.push_back(problem_size.Ns[i]);
}
if(argc == 4)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
}
else if(argc == 5)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
config.async_hargs = std::stoi(argv[4]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=n0, 1=yes)\n");
printf("arg4: async hargs (0=n0, 1=yes)\n");
exit(0);
}

View File

@@ -607,6 +607,9 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm<ALayout,
}
}
// If the user provides copy stream and copy event, we assume that they're also
// responsible for providing allocated host memory (eg. pinned) which
// would be used to copy kernel arguments to the device.
if(cpy_stream && cpy_event)
{
if(arg.gemm_kernel_host_args_ == nullptr)
@@ -625,7 +628,7 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm<ALayout,
hipGetErrorString(hipEventRecord(cpy_event, cpy_stream));
hipGetErrorString(hipEventSynchronize(cpy_event));
}
else
else // In this case CK owns memory allocated on host.
{
hipGetErrorString(
hipMemcpyAsync(arg.p_workspace_,
@@ -801,7 +804,15 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm<ALayout,
return this->SetWorkSpacePointer(p_arg, p_dev_kernel_args);
}
void SetHostKernelArgs(BaseArgument* p_arg, void* p_host_kernel_args) const
//----------------------------------------------------------------------------------------------
/// @brief Sets the host kernel arguments pointer and copies that data on the host side.
/// This function can be utilised to use pinned memory for the host args and
/// achieve fully async data copy.
///
/// @param p_arg The pointer to the Argument we're going to update.
/// @param[in] p_host_kernel_args The pointer to the host memory where the kernel
/// arguments will be copied
void SetHostKernelArgsPointer(BaseArgument* p_arg, void* p_host_kernel_args) const
{
Argument* pArg_ = dynamic_cast<Argument*>(p_arg);
if(!pArg_)

View File

@@ -560,6 +560,9 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
}
}
// If the user provides copy stream and copy event, we assume that they're also
// responsible for providing allocated host memory (eg. pinned) which
// would be used to copy kernel arguments to the device.
if(cpy_stream && cpy_event)
{
if(arg.gemm_kernel_host_args_ == nullptr)
@@ -578,7 +581,7 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
hipGetErrorString(hipEventRecord(cpy_event, cpy_stream));
hipGetErrorString(hipEventSynchronize(cpy_event));
}
else
else // In this case CK owns memory allocated on host.
{
hipGetErrorString(hipMemcpyAsync(arg.p_workspace_,
arg.gemm_desc_kernel_arg_.data(),
@@ -763,7 +766,16 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
/// @brief Returns the size in bytes of the buffer needed for host-side kernel
///        arguments; delegates to GetWorkSpaceSize, so it equals the device
///        workspace size for the same argument.
size_t GetHostKernelArgSize(const BaseArgument* p_arg) const { return GetWorkSpaceSize(p_arg); }
void SetHostKernelArgs(BaseArgument* p_arg, void* p_host_kernel_args) const
//----------------------------------------------------------------------------------------------
/// @brief Sets the host kernel arguments pointer and copies that data on the host side.
/// This function can be utilised to use pinned memory for the host args and
/// achieve fully async data copy.
///
/// @param p_arg The pointer to the Argument we're going to update.
/// @param[in] p_host_kernel_args The pointer to the host memory where the kernel
/// arguments will be copied
///
void SetHostKernelArgsPointer(BaseArgument* p_arg, void* p_host_kernel_args) const
{
Argument* pArg_ = dynamic_cast<Argument*>(p_arg);
if(!pArg_)

View File

@@ -423,6 +423,9 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
}
}
// If the user provides copy stream and copy event, we assume that they're also
// responsible for providing allocated host memory (eg. pinned) which
// would be used to copy kernel arguments to the device.
if(cpy_stream && cpy_event)
{
if(arg.gemm_kernel_host_args_ == nullptr)
@@ -441,7 +444,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
hip_check_error(hipEventRecord(cpy_event, cpy_stream));
hip_check_error(hipEventSynchronize(cpy_event));
}
else
else // In this case CK owns memory allocated on host.
{
hip_check_error(
@@ -702,7 +705,16 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
return this->SetWorkSpacePointer(p_arg, p_dev_kernel_args);
}
void SetHostKernelArgs(BaseArgument* p_arg, void* p_host_kernel_args) const
//----------------------------------------------------------------------------------------------
/// @brief Sets the host kernel arguments pointer and copies that data on the host side.
/// This function can be utilised to use pinned memory for the host args and
/// achieve fully async data copy.
///
/// @param p_arg The pointer to the Argument we're going to update.
/// @param[in] p_host_kernel_args The pointer to the host memory where the kernel
/// arguments will be copied
///
void SetHostKernelArgsPointer(BaseArgument* p_arg, void* p_host_kernel_args) const
{
Argument* pArg_ = dynamic_cast<Argument*>(p_arg);
if(!pArg_)