Post-merge changes for fully async args copy in ck grouped gemm (#1991)

* Post-merge changes for fully async args copy in ck grouped gemm

* Post-merge documentation and naming changes

* Build fix and updated changelog

* Revised comments

[ROCm/composable_kernel commit: 9329432f6c]
This commit is contained in:
aledudek
2025-04-03 13:35:43 +02:00
committed by GitHub
parent 169e3cb4f8
commit b7359bcfac
5 changed files with 68 additions and 16 deletions

View File

@@ -21,6 +21,7 @@ struct ExecutionConfig final
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
bool async_hargs = false;
};
bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
@@ -190,10 +191,10 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
gemm_workspace.Realloc(workspace_size);
gemm.SetWorkSpacePointer(&argument, gemm_workspace.GetDeviceBuffer());
}
if(hargs_size > 0)
if(config.async_hargs && hargs_size > 0)
{
hip_check_error(hipHostMalloc(&gemm_hargs, hargs_size));
gemm.SetHostKernelArgs(&argument, gemm_hargs);
gemm.SetHostKernelArgsPointer(&argument, gemm_hargs);
}
if(!gemm.IsSupportedArgument(argument))
@@ -203,16 +204,23 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
"not support this GEMM problem");
}
hipStream_t stream0 = nullptr;
hip_check_error(hipStreamCreate(&stream0));
if(!config.async_hargs)
{
invoker.Run(argument, StreamConfig{nullptr, false});
}
else
{
hipStream_t stream0 = nullptr;
hip_check_error(hipStreamCreate(&stream0));
hipEvent_t event0 = nullptr;
hip_check_error(hipEventCreate(&event0));
hipEvent_t event0 = nullptr;
hip_check_error(hipEventCreate(&event0));
invoker.Run(argument, StreamConfig{nullptr, false}, stream0, event0);
invoker.Run(argument, StreamConfig{nullptr, false}, stream0, event0);
hip_check_error(hipEventSynchronize(event0));
hip_check_error(hipStreamSynchronize(stream0));
hip_check_error(hipEventSynchronize(event0));
hip_check_error(hipStreamSynchronize(stream0));
}
bool pass = true;
if(config.do_verification)
@@ -280,18 +288,25 @@ bool run_grouped_gemm_example(int argc, char* argv[])
problem_size.stride_Bs.push_back(problem_size.Ks[i]);
problem_size.stride_Cs.push_back(problem_size.Ns[i]);
}
if(argc == 4)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
}
else if(argc == 5)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
config.async_hargs = std::stoi(argv[4]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=n0, 1=yes)\n");
printf("arg4: async hargs (0=n0, 1=yes)\n");
exit(0);
}