[CK_TILE] FMHA BWD launcher: address PR #7331 review comments (round 2)

- prepare_workspace_async: allocate pinned host staging before enqueuing
  the dq_acc memset. If pinned_host_alloc throws, no stream work has
  been issued yet, so the workspace is left cleanly un-prepared rather
  than half-initialized.
- pack_workspace_host catch: document that if the catch fires, the H2D
  copy queued after the callback will transfer indeterminate metadata
  and the kernel will produce wrong results; this is unlikely because
  pack only throws on precondition violations.
- schedule_pin_staging_release: std::move pin_staging_ into the heap
  shared_ptr instead of copying it; prepare_workspace_async overwrites
  pin_staging_ on the very next line, so the copy's extra atomic
  refcount inc/dec was wasted.
This commit is contained in:
Ding, Yi
2026-05-13 02:52:49 -04:00
parent 422b6d6c16
commit ce838e19e5

View File

@@ -613,23 +613,31 @@ struct fmha_bwd_launcher
const std::function<std::shared_ptr<void>(size_t)>& pinned_host_alloc)
{
hipStream_t stream = s.stream_id_;
if(needs_zero_dq_acc_ && workspace_size > host_ws_size_)
HIP_CHECK_ERROR(hipMemsetAsync(static_cast<char*>(device_ws_ptr) + host_ws_size_,
0,
workspace_size - host_ws_size_,
stream));
// Fast path: no host-side metadata to stage; just zero dq_acc if needed.
if(host_ws_size_ == 0)
{
if(needs_zero_dq_acc_ && workspace_size > 0)
HIP_CHECK_ERROR(hipMemsetAsync(device_ws_ptr, 0, workspace_size, stream));
return;
}
if(!pinned_host_alloc)
throw std::runtime_error(
"fmha_bwd_launcher::prepare_workspace_async: pinned_host_alloc is required");
// Allocate pinned host staging first: if it throws we haven't issued any
// stream work yet, leaving the workspace cleanly un-prepared.
const size_t seqstart_bytes = traits_.is_group_mode ? sizeof(int) * (traits_.batch + 1) : 0;
const size_t total_bytes = 2 * seqstart_bytes + host_ws_size_;
auto pin_base = pinned_host_alloc(total_bytes);
if(needs_zero_dq_acc_ && workspace_size > host_ws_size_)
HIP_CHECK_ERROR(hipMemsetAsync(static_cast<char*>(device_ws_ptr) + host_ws_size_,
0,
workspace_size - host_ws_size_,
stream));
char* base = static_cast<char*>(pin_base.get());
int* pin_q = reinterpret_cast<int*>(base);
int* pin_k = reinterpret_cast<int*>(base + seqstart_bytes);
@@ -663,6 +671,10 @@ struct fmha_bwd_launcher
}
catch(const std::exception& e)
{
// The H2D queued after this callback will copy indeterminate
// metadata to device and the kernel will produce wrong results;
// unlikely in practice since pack_workspace_host_ only throws on
// precondition violations.
std::cerr << "fmha_bwd_launcher: pack_workspace_host threw: " << e.what()
<< '\n';
}
@@ -706,7 +718,7 @@ struct fmha_bwd_launcher
{
if(!pin_staging_)
return;
auto* heap_ref = new std::shared_ptr<void>(pin_staging_);
auto* heap_ref = new std::shared_ptr<void>(std::move(pin_staging_));
const hipError_t err = hipLaunchHostFunc(
release_stream_,
[](void* ud) { delete static_cast<std::shared_ptr<void>*>(ud); },