mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
[CK_TILE] FMHA BWD launcher: address PR #7331 review comments
This commit is contained in:
@@ -636,22 +636,41 @@ struct fmha_bwd_launcher
|
||||
|
||||
if(traits_.is_group_mode)
|
||||
{
|
||||
if(!seqstart_q_dev || !seqstart_k_dev)
|
||||
throw std::runtime_error("fmha_bwd_launcher::prepare_workspace_async: "
|
||||
"seqstart_q_dev and seqstart_k_dev are required in "
|
||||
"group mode");
|
||||
HIP_CHECK_ERROR(hipMemcpyAsync(
|
||||
pin_q, seqstart_q_dev, seqstart_bytes, hipMemcpyDeviceToHost, stream));
|
||||
HIP_CHECK_ERROR(hipMemcpyAsync(
|
||||
pin_k, seqstart_k_dev, seqstart_bytes, hipMemcpyDeviceToHost, stream));
|
||||
}
|
||||
|
||||
auto* pack_closure = new std::function<void()>(
|
||||
auto pack_closure = std::make_unique<std::function<void()>>(
|
||||
[=, fn = pack_workspace_host_]() { fn(pin_w, seqstart_q_pinned, seqstart_k_pinned); });
|
||||
// Callback runs on the HIP driver helper thread across a C ABI boundary;
|
||||
// any exception escaping it would call std::terminate.
|
||||
HIP_CHECK_ERROR(hipLaunchHostFunc(
|
||||
stream,
|
||||
[](void* ud) {
|
||||
auto* c = static_cast<std::function<void()>*>(ud);
|
||||
(*c)();
|
||||
delete c;
|
||||
std::unique_ptr<std::function<void()>> c{static_cast<std::function<void()>*>(ud)};
|
||||
try
|
||||
{
|
||||
(*c)();
|
||||
}
|
||||
catch(const std::exception& e)
|
||||
{
|
||||
std::cerr << "fmha_bwd_launcher: pack_workspace_host threw: " << e.what()
|
||||
<< '\n';
|
||||
}
|
||||
catch(...)
|
||||
{
|
||||
std::cerr << "fmha_bwd_launcher: pack_workspace_host threw unknown\n";
|
||||
}
|
||||
},
|
||||
pack_closure));
|
||||
pack_closure.get()));
|
||||
// Ownership transferred to the callback only after a successful launch.
|
||||
pack_closure.release();
|
||||
|
||||
HIP_CHECK_ERROR(
|
||||
hipMemcpyAsync(device_ws_ptr, pin_w, host_ws_size_, hipMemcpyHostToDevice, stream));
|
||||
|
||||
@@ -418,7 +418,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
};
|
||||
|
||||
ck_tile::gpu_timer prepare_ws_timer;
|
||||
prepare_ws_timer.start(nullptr);
|
||||
prepare_ws_timer.start(stream_config.stream_id_);
|
||||
launcher.prepare_workspace_async(
|
||||
ws_buf.GetDeviceBuffer(),
|
||||
(mode == mode_enum::group) ? static_cast<const int*>(seqstart_q.GetDeviceBuffer())
|
||||
@@ -427,7 +427,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
: nullptr,
|
||||
stream_config,
|
||||
pinned_host_alloc);
|
||||
prepare_ws_timer.stop(nullptr);
|
||||
prepare_ws_timer.stop(stream_config.stream_id_);
|
||||
|
||||
q_buf.ToDevice(q_host.data());
|
||||
k_buf.ToDevice(k_host.data());
|
||||
|
||||
@@ -323,7 +323,7 @@ struct FmhaBwdDQDKDVKernel
|
||||
std::forward<Args>(args)...);
|
||||
}
|
||||
template <typename... Args>
|
||||
CK_TILE_HOST static constexpr auto GetWorkspaceDeviceSizeUpperBound(Args&&... args)
|
||||
CK_TILE_HOST static size_t GetWorkspaceDeviceSizeUpperBound(Args&&... args)
|
||||
{
|
||||
return WorkspaceManager::template GetWorkspaceDeviceSizeUpperBound<
|
||||
kUseQrQtrDorPipeline,
|
||||
|
||||
Reference in New Issue
Block a user