From ce838e19e5ea1931041537112d85185bb7ce7fa7 Mon Sep 17 00:00:00 2001 From: "Ding, Yi" Date: Wed, 13 May 2026 02:52:49 -0400 Subject: [PATCH] [CK_TILE] FMHA BWD launcher: address PR #7331 review comments (round 2) - prepare_workspace_async: allocate pinned host staging before enqueuing the dq_acc memset. If pinned_host_alloc throws, no stream work has been issued yet, so the workspace is left cleanly un-prepared rather than half-initialized. - pack_workspace_host catch: note that the H2D queued after the callback will copy indeterminate metadata if the catch fires (kernel will produce wrong results); unlikely since pack only throws on precondition violations. - schedule_pin_staging_release: std::move pin_staging_ into the heap shared_ptr; the next line in prepare_workspace_async overwrites it, so the extra atomic inc/dec from a copy is wasted. --- example/ck_tile/01_fmha/fmha_bwd.hpp | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/example/ck_tile/01_fmha/fmha_bwd.hpp b/example/ck_tile/01_fmha/fmha_bwd.hpp index 76c72fc159..a06e679cde 100644 --- a/example/ck_tile/01_fmha/fmha_bwd.hpp +++ b/example/ck_tile/01_fmha/fmha_bwd.hpp @@ -613,23 +613,31 @@ struct fmha_bwd_launcher const std::function(size_t)>& pinned_host_alloc) { hipStream_t stream = s.stream_id_; - if(needs_zero_dq_acc_ && workspace_size > host_ws_size_) - HIP_CHECK_ERROR(hipMemsetAsync(static_cast(device_ws_ptr) + host_ws_size_, - 0, - workspace_size - host_ws_size_, - stream)); + // Fast path: no host-side metadata to stage; just zero dq_acc if needed. if(host_ws_size_ == 0) + { + if(needs_zero_dq_acc_ && workspace_size > 0) + HIP_CHECK_ERROR(hipMemsetAsync(device_ws_ptr, 0, workspace_size, stream)); return; + } if(!pinned_host_alloc) throw std::runtime_error( "fmha_bwd_launcher::prepare_workspace_async: pinned_host_alloc is required"); + // Allocate pinned host staging first: if it throws we haven't issued any + // stream work yet, leaving the workspace cleanly un-prepared. const size_t seqstart_bytes = traits_.is_group_mode ? sizeof(int) * (traits_.batch + 1) : 0; const size_t total_bytes = 2 * seqstart_bytes + host_ws_size_; auto pin_base = pinned_host_alloc(total_bytes); + if(needs_zero_dq_acc_ && workspace_size > host_ws_size_) + HIP_CHECK_ERROR(hipMemsetAsync(static_cast(device_ws_ptr) + host_ws_size_, + 0, + workspace_size - host_ws_size_, + stream)); + char* base = static_cast(pin_base.get()); int* pin_q = reinterpret_cast(base); int* pin_k = reinterpret_cast(base + seqstart_bytes); @@ -663,6 +671,10 @@ struct fmha_bwd_launcher } catch(const std::exception& e) { + // The H2D queued after this callback will copy indeterminate + // metadata to device and the kernel will produce wrong results; + // unlikely in practice since pack_workspace_host_ only throws on + // precondition violations. std::cerr << "fmha_bwd_launcher: pack_workspace_host threw: " << e.what() << '\n'; } @@ -706,7 +718,7 @@ struct fmha_bwd_launcher { if(!pin_staging_) return; - auto* heap_ref = new std::shared_ptr(pin_staging_); + auto* heap_ref = new std::shared_ptr(std::move(pin_staging_)); const hipError_t err = hipLaunchHostFunc( release_stream_, [](void* ud) { delete static_cast*>(ud); },