Async grouped gemm v3 (#1940)

* Fully async grouped gemm

* Remove commented code

* Remvoe maybe_unused

* host kernel args

* Checkpoint segfault debugging...

* Working part1

* Working part2

* Remvoe comments...

* Use void ptr for gemm kernel host args

* Fix device_grouped_gemm_multiple_d_dl build issue

* Fix device_grouped_gemm_xdl build issue
This commit is contained in:
aledudek
2025-03-17 16:42:43 +01:00
committed by GitHub
parent c2e4898b4b
commit 5095906975
4 changed files with 179 additions and 28 deletions

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -173,8 +173,10 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
std::size_t workspace_size = gemm.GetWorkSpaceSize(&argument);
std::size_t kargs_size = gemm.GetDeviceKernelArgSize(&argument);
std::size_t hargs_size = gemm.GetHostKernelArgSize(&argument);
DeviceMem gemm_workspace, gemm_kargs;
void* gemm_hargs;
// The following is necessary since TwoStage kernel is using additional memory both
// for Workspace and kernel arguments.
@@ -188,6 +190,11 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
gemm_workspace.Realloc(workspace_size);
gemm.SetWorkSpacePointer(&argument, gemm_workspace.GetDeviceBuffer());
}
if(hargs_size > 0)
{
hip_check_error(hipHostMalloc(&gemm_hargs, hargs_size));
gemm.SetHostKernelArgs(&argument, gemm_hargs);
}
if(!gemm.IsSupportedArgument(argument))
{
@@ -196,7 +203,16 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
"not support this GEMM problem");
}
invoker.Run(argument, StreamConfig{nullptr, false});
hipStream_t stream0 = nullptr;
hip_check_error(hipStreamCreate(&stream0));
hipEvent_t event0 = nullptr;
hip_check_error(hipEventCreate(&event0));
invoker.Run(argument, StreamConfig{nullptr, false}, stream0, event0);
hip_check_error(hipEventSynchronize(event0));
hip_check_error(hipStreamSynchronize(stream0));
bool pass = true;
if(config.do_verification)