diff --git a/example/ck_tile/01_fmha/fmha_bwd.hpp b/example/ck_tile/01_fmha/fmha_bwd.hpp index 76c72fc159..a06e679cde 100644 --- a/example/ck_tile/01_fmha/fmha_bwd.hpp +++ b/example/ck_tile/01_fmha/fmha_bwd.hpp @@ -613,23 +613,31 @@ struct fmha_bwd_launcher const std::function(size_t)>& pinned_host_alloc) { hipStream_t stream = s.stream_id_; - if(needs_zero_dq_acc_ && workspace_size > host_ws_size_) - HIP_CHECK_ERROR(hipMemsetAsync(static_cast(device_ws_ptr) + host_ws_size_, - 0, - workspace_size - host_ws_size_, - stream)); + // Fast path: no host-side metadata to stage; just zero dq_acc if needed. if(host_ws_size_ == 0) + { + if(needs_zero_dq_acc_ && workspace_size > 0) + HIP_CHECK_ERROR(hipMemsetAsync(device_ws_ptr, 0, workspace_size, stream)); return; + } if(!pinned_host_alloc) throw std::runtime_error( "fmha_bwd_launcher::prepare_workspace_async: pinned_host_alloc is required"); + // Allocate pinned host staging first: if it throws we haven't issued any + // stream work yet, leaving the workspace cleanly un-prepared. const size_t seqstart_bytes = traits_.is_group_mode ? sizeof(int) * (traits_.batch + 1) : 0; const size_t total_bytes = 2 * seqstart_bytes + host_ws_size_; auto pin_base = pinned_host_alloc(total_bytes); + if(needs_zero_dq_acc_ && workspace_size > host_ws_size_) + HIP_CHECK_ERROR(hipMemsetAsync(static_cast(device_ws_ptr) + host_ws_size_, + 0, + workspace_size - host_ws_size_, + stream)); + char* base = static_cast(pin_base.get()); int* pin_q = reinterpret_cast(base); int* pin_k = reinterpret_cast(base + seqstart_bytes); @@ -663,6 +671,10 @@ struct fmha_bwd_launcher } catch(const std::exception& e) { + // The H2D queued after this callback will copy indeterminate + // metadata to device and the kernel will produce wrong results; + // unlikely in practice since pack_workspace_host_ only throws on + // precondition violations. std::cerr << "fmha_bwd_launcher: pack_workspace_host threw: " << e.what() << '\n'; } @@ -706,7 +718,7 @@ struct fmha_bwd_launcher { if(!pin_staging_) return; - auto* heap_ref = new std::shared_ptr(pin_staging_); + auto* heap_ref = new std::shared_ptr(std::move(pin_staging_)); const hipError_t err = hipLaunchHostFunc( release_stream_, [](void* ud) { delete static_cast*>(ud); },