mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
Async grouped gemm v3 (#1940)
* Fully async grouped gemm * Remove commented code * Remove maybe_unused * host kernel args * Checkpoint segfault debugging... * Working part1 * Working part2 * Remove comments... * Use void ptr for gemm kernel host args * Fix device_grouped_gemm_multiple_d_dl build issue * Fix device_grouped_gemm_xdl build issue
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -173,8 +173,10 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
|
||||
|
||||
std::size_t workspace_size = gemm.GetWorkSpaceSize(&argument);
|
||||
std::size_t kargs_size = gemm.GetDeviceKernelArgSize(&argument);
|
||||
std::size_t hargs_size = gemm.GetHostKernelArgSize(&argument);
|
||||
|
||||
DeviceMem gemm_workspace, gemm_kargs;
|
||||
void* gemm_hargs;
|
||||
|
||||
// The following is necessary since the TwoStage kernel uses additional memory both
|
||||
// for the workspace and for the kernel arguments.
|
||||
@@ -188,6 +190,11 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
|
||||
gemm_workspace.Realloc(workspace_size);
|
||||
gemm.SetWorkSpacePointer(&argument, gemm_workspace.GetDeviceBuffer());
|
||||
}
|
||||
if(hargs_size > 0)
|
||||
{
|
||||
hip_check_error(hipHostMalloc(&gemm_hargs, hargs_size));
|
||||
gemm.SetHostKernelArgs(&argument, gemm_hargs);
|
||||
}
|
||||
|
||||
if(!gemm.IsSupportedArgument(argument))
|
||||
{
|
||||
@@ -196,7 +203,16 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
|
||||
"not support this GEMM problem");
|
||||
}
|
||||
|
||||
invoker.Run(argument, StreamConfig{nullptr, false});
|
||||
hipStream_t stream0 = nullptr;
|
||||
hip_check_error(hipStreamCreate(&stream0));
|
||||
|
||||
hipEvent_t event0 = nullptr;
|
||||
hip_check_error(hipEventCreate(&event0));
|
||||
|
||||
invoker.Run(argument, StreamConfig{nullptr, false}, stream0, event0);
|
||||
|
||||
hip_check_error(hipEventSynchronize(event0));
|
||||
hip_check_error(hipStreamSynchronize(stream0));
|
||||
|
||||
bool pass = true;
|
||||
if(config.do_verification)
|
||||
|
||||
Reference in New Issue
Block a user