mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 10:09:41 +00:00
Post-merge changes for fully async args copy in ck grouped gemm (#1991)
* Post-merge changes for fully async args copy in ck grouped gemm
* Post-merge documentation and naming changes
* Build fix and updated changelog
* Revised comments
[ROCm/composable_kernel commit: 9329432f6c]
This commit is contained in:
@@ -21,6 +21,7 @@ struct ExecutionConfig final
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = false;
|
||||
bool async_hargs = false;
|
||||
};
|
||||
|
||||
bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
|
||||
@@ -190,10 +191,10 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
|
||||
gemm_workspace.Realloc(workspace_size);
|
||||
gemm.SetWorkSpacePointer(&argument, gemm_workspace.GetDeviceBuffer());
|
||||
}
|
||||
if(hargs_size > 0)
|
||||
if(config.async_hargs && hargs_size > 0)
|
||||
{
|
||||
hip_check_error(hipHostMalloc(&gemm_hargs, hargs_size));
|
||||
gemm.SetHostKernelArgs(&argument, gemm_hargs);
|
||||
gemm.SetHostKernelArgsPointer(&argument, gemm_hargs);
|
||||
}
|
||||
|
||||
if(!gemm.IsSupportedArgument(argument))
|
||||
@@ -203,16 +204,23 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
|
||||
"not support this GEMM problem");
|
||||
}
|
||||
|
||||
hipStream_t stream0 = nullptr;
|
||||
hip_check_error(hipStreamCreate(&stream0));
|
||||
if(!config.async_hargs)
|
||||
{
|
||||
invoker.Run(argument, StreamConfig{nullptr, false});
|
||||
}
|
||||
else
|
||||
{
|
||||
hipStream_t stream0 = nullptr;
|
||||
hip_check_error(hipStreamCreate(&stream0));
|
||||
|
||||
hipEvent_t event0 = nullptr;
|
||||
hip_check_error(hipEventCreate(&event0));
|
||||
hipEvent_t event0 = nullptr;
|
||||
hip_check_error(hipEventCreate(&event0));
|
||||
|
||||
invoker.Run(argument, StreamConfig{nullptr, false}, stream0, event0);
|
||||
invoker.Run(argument, StreamConfig{nullptr, false}, stream0, event0);
|
||||
|
||||
hip_check_error(hipEventSynchronize(event0));
|
||||
hip_check_error(hipStreamSynchronize(stream0));
|
||||
hip_check_error(hipEventSynchronize(event0));
|
||||
hip_check_error(hipStreamSynchronize(stream0));
|
||||
}
|
||||
|
||||
bool pass = true;
|
||||
if(config.do_verification)
|
||||
@@ -280,18 +288,25 @@ bool run_grouped_gemm_example(int argc, char* argv[])
|
||||
problem_size.stride_Bs.push_back(problem_size.Ks[i]);
|
||||
problem_size.stride_Cs.push_back(problem_size.Ns[i]);
|
||||
}
|
||||
|
||||
if(argc == 4)
|
||||
{
|
||||
config.do_verification = std::stoi(argv[1]);
|
||||
config.init_method = std::stoi(argv[2]);
|
||||
config.time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else if(argc == 5)
|
||||
{
|
||||
config.do_verification = std::stoi(argv[1]);
|
||||
config.init_method = std::stoi(argv[2]);
|
||||
config.time_kernel = std::stoi(argv[3]);
|
||||
config.async_hargs = std::stoi(argv[4]);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: time kernel (0=n0, 1=yes)\n");
|
||||
printf("arg4: async hargs (0=n0, 1=yes)\n");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user