mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Async grouped gemm v3 (#1940)
* Fully async grouped gemm * Remove commented code * Remove maybe_unused * host kernel args * Checkpoint segfault debugging... * Working part1 * Working part2 * Remove comments... * Use void ptr for gemm kernel host args * Fix device_grouped_gemm_multiple_d_dl build issue * Fix device_grouped_gemm_xdl build issue
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -173,8 +173,10 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
|
||||
|
||||
std::size_t workspace_size = gemm.GetWorkSpaceSize(&argument);
|
||||
std::size_t kargs_size = gemm.GetDeviceKernelArgSize(&argument);
|
||||
std::size_t hargs_size = gemm.GetHostKernelArgSize(&argument);
|
||||
|
||||
DeviceMem gemm_workspace, gemm_kargs;
|
||||
void* gemm_hargs;
|
||||
|
||||
// The following is necessary since TwoStage kernel is using additional memory both
|
||||
// for Workspace and kernel arguments.
|
||||
@@ -188,6 +190,11 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
|
||||
gemm_workspace.Realloc(workspace_size);
|
||||
gemm.SetWorkSpacePointer(&argument, gemm_workspace.GetDeviceBuffer());
|
||||
}
|
||||
if(hargs_size > 0)
|
||||
{
|
||||
hip_check_error(hipHostMalloc(&gemm_hargs, hargs_size));
|
||||
gemm.SetHostKernelArgs(&argument, gemm_hargs);
|
||||
}
|
||||
|
||||
if(!gemm.IsSupportedArgument(argument))
|
||||
{
|
||||
@@ -196,7 +203,16 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
|
||||
"not support this GEMM problem");
|
||||
}
|
||||
|
||||
invoker.Run(argument, StreamConfig{nullptr, false});
|
||||
hipStream_t stream0 = nullptr;
|
||||
hip_check_error(hipStreamCreate(&stream0));
|
||||
|
||||
hipEvent_t event0 = nullptr;
|
||||
hip_check_error(hipEventCreate(&event0));
|
||||
|
||||
invoker.Run(argument, StreamConfig{nullptr, false}, stream0, event0);
|
||||
|
||||
hip_check_error(hipEventSynchronize(event0));
|
||||
hip_check_error(hipStreamSynchronize(stream0));
|
||||
|
||||
bool pass = true;
|
||||
if(config.do_verification)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#pragma once
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -420,7 +420,8 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm<ALayout,
|
||||
CDEElementwiseOperation cde_element_op)
|
||||
: a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
cde_element_op_{cde_element_op}
|
||||
cde_element_op_{cde_element_op},
|
||||
gemm_kernel_host_args_{nullptr}
|
||||
{
|
||||
grid_size_ = 0;
|
||||
|
||||
@@ -538,6 +539,7 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm<ALayout,
|
||||
std::vector<Tuple<index_t, index_t>> b_mtx_nraw_kraw_;
|
||||
|
||||
index_t grid_size_;
|
||||
void* gemm_kernel_host_args_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
@@ -545,7 +547,10 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm<ALayout,
|
||||
{
|
||||
using Argument = DeviceOp::Argument;
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
float Run(const Argument& arg,
|
||||
const StreamConfig& stream_config = StreamConfig{},
|
||||
hipStream_t cpy_stream = nullptr,
|
||||
hipEvent_t cpy_event = nullptr)
|
||||
{
|
||||
auto K0 = arg.gemm_desc_kernel_arg_[0].a_grid_desc_k0_m_k1_.GetLength(I0);
|
||||
bool all_has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0);
|
||||
@@ -602,12 +607,33 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm<ALayout,
|
||||
}
|
||||
}
|
||||
|
||||
hipGetErrorString(
|
||||
hipMemcpyAsync(arg.p_workspace_,
|
||||
arg.gemm_desc_kernel_arg_.data(),
|
||||
arg.gemm_desc_kernel_arg_.size() * sizeof(GemmKernelArg),
|
||||
hipMemcpyHostToDevice,
|
||||
stream_config.stream_id_));
|
||||
if(cpy_stream && cpy_event)
|
||||
{
|
||||
if(arg.gemm_kernel_host_args_ == nullptr)
|
||||
{
|
||||
std::ostringstream err;
|
||||
err << "No memory has been allocated for gemm kernel host args "
|
||||
<< "when providing the copy stream and copy event! In " << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__;
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
hipGetErrorString(hipMemcpyAsync(arg.p_workspace_,
|
||||
arg.gemm_kernel_host_args_,
|
||||
arg.group_count_ * sizeof(GemmKernelArg),
|
||||
hipMemcpyHostToDevice,
|
||||
cpy_stream));
|
||||
hipGetErrorString(hipEventRecord(cpy_event, cpy_stream));
|
||||
hipGetErrorString(hipEventSynchronize(cpy_event));
|
||||
}
|
||||
else
|
||||
{
|
||||
hipGetErrorString(
|
||||
hipMemcpyAsync(arg.p_workspace_,
|
||||
arg.gemm_desc_kernel_arg_.data(),
|
||||
arg.gemm_desc_kernel_arg_.size() * sizeof(GemmKernelArg),
|
||||
hipMemcpyHostToDevice,
|
||||
stream_config.stream_id_));
|
||||
}
|
||||
|
||||
auto launch_kernel = [&](auto has_main_k_block_loop,
|
||||
auto has_double_tail_k_block_loop) {
|
||||
@@ -762,6 +788,32 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm<ALayout,
|
||||
{
|
||||
return dynamic_cast<const Argument*>(p_arg)->group_count_ * sizeof(GemmKernelArg);
|
||||
}
|
||||
|
||||
size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override
|
||||
{
|
||||
return GetWorkSpaceSize(p_arg);
|
||||
}
|
||||
|
||||
size_t GetHostKernelArgSize(const BaseArgument* p_arg) const { return GetWorkSpaceSize(p_arg); }
|
||||
|
||||
void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const override
|
||||
{
|
||||
return this->SetWorkSpacePointer(p_arg, p_dev_kernel_args);
|
||||
}
|
||||
|
||||
void SetHostKernelArgs(BaseArgument* p_arg, void* p_host_kernel_args) const
|
||||
{
|
||||
Argument* pArg_ = dynamic_cast<Argument*>(p_arg);
|
||||
if(!pArg_)
|
||||
{
|
||||
throw std::runtime_error("Failed to cast argument pointer!");
|
||||
}
|
||||
|
||||
pArg_->gemm_kernel_host_args_ = p_host_kernel_args;
|
||||
std::copy(pArg_->gemm_desc_kernel_arg_.begin(),
|
||||
pArg_->gemm_desc_kernel_arg_.end(),
|
||||
static_cast<GemmKernelArg*>(pArg_->gemm_kernel_host_args_));
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#pragma once
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -500,6 +500,7 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
|
||||
std::vector<Tuple<index_t, index_t>> b_mtx_nraw_kraw_;
|
||||
|
||||
index_t grid_size_;
|
||||
void* gemm_kernel_host_args_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
@@ -507,7 +508,10 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
|
||||
{
|
||||
using Argument = DeviceOp::Argument;
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
float Run(const Argument& arg,
|
||||
const StreamConfig& stream_config = StreamConfig{},
|
||||
hipStream_t cpy_stream = nullptr,
|
||||
hipEvent_t cpy_event = nullptr)
|
||||
{
|
||||
bool has_main_k_block_loop = true;
|
||||
|
||||
@@ -556,12 +560,33 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
|
||||
}
|
||||
}
|
||||
|
||||
hipGetErrorString(
|
||||
hipMemcpyAsync(arg.p_workspace_,
|
||||
arg.gemm_desc_kernel_arg_.data(),
|
||||
arg.gemm_desc_kernel_arg_.size() * sizeof(GemmBiasTransKernelArg),
|
||||
hipMemcpyHostToDevice,
|
||||
stream_config.stream_id_));
|
||||
if(cpy_stream && cpy_event)
|
||||
{
|
||||
if(arg.gemm_kernel_host_args_ == nullptr)
|
||||
{
|
||||
std::ostringstream err;
|
||||
err << "No memory has been allocated for gemm kernel host args "
|
||||
<< "when providing the copy stream and copy event! In " << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__;
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
hipGetErrorString(hipMemcpyAsync(arg.p_workspace_,
|
||||
arg.gemm_kernel_host_args_,
|
||||
arg.group_count_ * sizeof(GemmBiasTransKernelArg),
|
||||
hipMemcpyHostToDevice,
|
||||
cpy_stream));
|
||||
hipGetErrorString(hipEventRecord(cpy_event, cpy_stream));
|
||||
hipGetErrorString(hipEventSynchronize(cpy_event));
|
||||
}
|
||||
else
|
||||
{
|
||||
hipGetErrorString(hipMemcpyAsync(arg.p_workspace_,
|
||||
arg.gemm_desc_kernel_arg_.data(),
|
||||
arg.gemm_desc_kernel_arg_.size() *
|
||||
sizeof(GemmBiasTransKernelArg),
|
||||
hipMemcpyHostToDevice,
|
||||
stream_config.stream_id_));
|
||||
}
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
@@ -735,6 +760,22 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
|
||||
{
|
||||
return this->SetWorkSpacePointer(p_arg, p_dev_kernel_args);
|
||||
}
|
||||
|
||||
size_t GetHostKernelArgSize(const BaseArgument* p_arg) const { return GetWorkSpaceSize(p_arg); }
|
||||
|
||||
void SetHostKernelArgs(BaseArgument* p_arg, void* p_host_kernel_args) const
|
||||
{
|
||||
Argument* pArg_ = dynamic_cast<Argument*>(p_arg);
|
||||
if(!pArg_)
|
||||
{
|
||||
throw std::runtime_error("Failed to cast argument pointer!");
|
||||
}
|
||||
|
||||
pArg_->gemm_kernel_host_args_ = p_host_kernel_args;
|
||||
std::copy(pArg_->gemm_desc_kernel_arg_.begin(),
|
||||
pArg_->gemm_desc_kernel_arg_.end(),
|
||||
static_cast<GemmBiasTransKernelArg*>(pArg_->gemm_kernel_host_args_));
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -244,7 +244,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
|
||||
std::vector<void*>& p_Es,
|
||||
std::vector<GemmDesc>& gemm_descs,
|
||||
index_t kbatch)
|
||||
: K_BATCH{kbatch}
|
||||
: K_BATCH{kbatch}, gemm_kernel_host_args_{nullptr}
|
||||
{
|
||||
grid_size_ = 0;
|
||||
group_count_ = ck::type_convert<ck::index_t>(gemm_descs.size());
|
||||
@@ -365,13 +365,17 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
|
||||
index_t skipped_group_count_;
|
||||
|
||||
std::vector<GemmTransKernelArg> gemm_kernel_args_;
|
||||
void* gemm_kernel_host_args_;
|
||||
index_t grid_size_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
float Run(const Argument& arg,
|
||||
const StreamConfig& stream_config = StreamConfig{},
|
||||
hipStream_t cpy_stream = nullptr,
|
||||
hipEvent_t cpy_event = nullptr)
|
||||
{
|
||||
index_t K0 = arg.gemm_kernel_args_[0].karg_.K0Padded;
|
||||
bool all_have_kbatch_gt_one = arg.gemm_kernel_args_[0].karg_.k_batch > 1;
|
||||
@@ -419,12 +423,34 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
|
||||
}
|
||||
}
|
||||
|
||||
hip_check_error(
|
||||
hipMemcpyAsync(arg.p_workspace_,
|
||||
arg.gemm_kernel_args_.data(),
|
||||
arg.gemm_kernel_args_.size() * sizeof(GemmTransKernelArg),
|
||||
hipMemcpyHostToDevice,
|
||||
stream_config.stream_id_));
|
||||
if(cpy_stream && cpy_event)
|
||||
{
|
||||
if(arg.gemm_kernel_host_args_ == nullptr)
|
||||
{
|
||||
std::ostringstream err;
|
||||
err << "No memory has been allocated for gemm kernel host args "
|
||||
<< "when providing the copy stream and copy event! In " << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__;
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
hip_check_error(hipMemcpyAsync(arg.p_workspace_,
|
||||
arg.gemm_kernel_host_args_,
|
||||
arg.group_count_ * sizeof(GemmTransKernelArg),
|
||||
hipMemcpyHostToDevice,
|
||||
cpy_stream));
|
||||
hip_check_error(hipEventRecord(cpy_event, cpy_stream));
|
||||
hip_check_error(hipEventSynchronize(cpy_event));
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
hip_check_error(
|
||||
hipMemcpyAsync(arg.p_workspace_,
|
||||
arg.gemm_kernel_args_.data(),
|
||||
arg.gemm_kernel_args_.size() * sizeof(GemmTransKernelArg),
|
||||
hipMemcpyHostToDevice,
|
||||
stream_config.stream_id_));
|
||||
}
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
@@ -652,6 +678,8 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
|
||||
return GetWorkSpaceSize(p_arg);
|
||||
}
|
||||
|
||||
size_t GetHostKernelArgSize(const BaseArgument* p_arg) const { return GetWorkSpaceSize(p_arg); }
|
||||
|
||||
// TODO: deprecation notice.
|
||||
static void SetKBatchSize(Argument& arg, index_t kbatch) { arg.UpdateKBatch(kbatch); }
|
||||
|
||||
@@ -673,6 +701,20 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
|
||||
{
|
||||
return this->SetWorkSpacePointer(p_arg, p_dev_kernel_args);
|
||||
}
|
||||
|
||||
void SetHostKernelArgs(BaseArgument* p_arg, void* p_host_kernel_args) const
|
||||
{
|
||||
Argument* pArg_ = dynamic_cast<Argument*>(p_arg);
|
||||
if(!pArg_)
|
||||
{
|
||||
throw std::runtime_error("Failed to cast argument pointer!");
|
||||
}
|
||||
|
||||
pArg_->gemm_kernel_host_args_ = p_host_kernel_args;
|
||||
std::copy(pArg_->gemm_kernel_args_.begin(),
|
||||
pArg_->gemm_kernel_args_.end(),
|
||||
static_cast<GemmTransKernelArg*>(pArg_->gemm_kernel_host_args_));
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
|
||||
Reference in New Issue
Block a user