[CK_TILE] FMHA BWD launcher: address PR #7331 review comments

This commit is contained in:
Ding, Yi
2026-05-12 02:46:41 -05:00
parent d434410e52
commit 1f4cc34e68
3 changed files with 27 additions and 8 deletions

View File

@@ -636,22 +636,41 @@ struct fmha_bwd_launcher
if(traits_.is_group_mode)
{
if(!seqstart_q_dev || !seqstart_k_dev)
throw std::runtime_error("fmha_bwd_launcher::prepare_workspace_async: "
"seqstart_q_dev and seqstart_k_dev are required in "
"group mode");
HIP_CHECK_ERROR(hipMemcpyAsync(
pin_q, seqstart_q_dev, seqstart_bytes, hipMemcpyDeviceToHost, stream));
HIP_CHECK_ERROR(hipMemcpyAsync(
pin_k, seqstart_k_dev, seqstart_bytes, hipMemcpyDeviceToHost, stream));
}
auto* pack_closure = new std::function<void()>(
auto pack_closure = std::make_unique<std::function<void()>>(
[=, fn = pack_workspace_host_]() { fn(pin_w, seqstart_q_pinned, seqstart_k_pinned); });
// Callback runs on the HIP driver helper thread across a C ABI boundary;
// any exception escaping it would call std::terminate.
HIP_CHECK_ERROR(hipLaunchHostFunc(
stream,
[](void* ud) {
auto* c = static_cast<std::function<void()>*>(ud);
(*c)();
delete c;
std::unique_ptr<std::function<void()>> c{static_cast<std::function<void()>*>(ud)};
try
{
(*c)();
}
catch(const std::exception& e)
{
std::cerr << "fmha_bwd_launcher: pack_workspace_host threw: " << e.what()
<< '\n';
}
catch(...)
{
std::cerr << "fmha_bwd_launcher: pack_workspace_host threw unknown\n";
}
},
pack_closure));
pack_closure.get()));
// Ownership transferred to the callback only after a successful launch.
pack_closure.release();
HIP_CHECK_ERROR(
hipMemcpyAsync(device_ws_ptr, pin_w, host_ws_size_, hipMemcpyHostToDevice, stream));

View File

@@ -418,7 +418,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
};
ck_tile::gpu_timer prepare_ws_timer;
prepare_ws_timer.start(nullptr);
prepare_ws_timer.start(stream_config.stream_id_);
launcher.prepare_workspace_async(
ws_buf.GetDeviceBuffer(),
(mode == mode_enum::group) ? static_cast<const int*>(seqstart_q.GetDeviceBuffer())
@@ -427,7 +427,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
: nullptr,
stream_config,
pinned_host_alloc);
prepare_ws_timer.stop(nullptr);
prepare_ws_timer.stop(stream_config.stream_id_);
q_buf.ToDevice(q_host.data());
k_buf.ToDevice(k_host.data());

View File

@@ -323,7 +323,7 @@ struct FmhaBwdDQDKDVKernel
std::forward<Args>(args)...);
}
template <typename... Args>
CK_TILE_HOST static constexpr auto GetWorkspaceDeviceSizeUpperBound(Args&&... args)
CK_TILE_HOST static size_t GetWorkspaceDeviceSizeUpperBound(Args&&... args)
{
return WorkspaceManager::template GetWorkspaceDeviceSizeUpperBound<
kUseQrQtrDorPipeline,