mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 17:55:48 +00:00
[CK_TILE] FMHA BWD launcher: address PR #7331 review comments (round 2)
- prepare_workspace_async: allocate pinned host staging before enqueuing the dq_acc memset. If pinned_host_alloc throws, no stream work has been issued yet, so the workspace is left cleanly un-prepared rather than half-initialized.
- pack_workspace_host catch: note that if the catch fires, the H2D copy queued after the callback will copy indeterminate metadata to the device (and the kernel will produce wrong results); this is unlikely, since pack only throws on precondition violations.
- schedule_pin_staging_release: std::move pin_staging_ into the heap shared_ptr; the next line in prepare_workspace_async overwrites it, so the extra atomic inc/dec from a copy is wasted.
This commit is contained in:
@@ -613,23 +613,31 @@ struct fmha_bwd_launcher
|
||||
const std::function<std::shared_ptr<void>(size_t)>& pinned_host_alloc)
|
||||
{
|
||||
hipStream_t stream = s.stream_id_;
|
||||
if(needs_zero_dq_acc_ && workspace_size > host_ws_size_)
|
||||
HIP_CHECK_ERROR(hipMemsetAsync(static_cast<char*>(device_ws_ptr) + host_ws_size_,
|
||||
0,
|
||||
workspace_size - host_ws_size_,
|
||||
stream));
|
||||
|
||||
// Fast path: no host-side metadata to stage; just zero dq_acc if needed.
|
||||
if(host_ws_size_ == 0)
|
||||
{
|
||||
if(needs_zero_dq_acc_ && workspace_size > 0)
|
||||
HIP_CHECK_ERROR(hipMemsetAsync(device_ws_ptr, 0, workspace_size, stream));
|
||||
return;
|
||||
}
|
||||
|
||||
if(!pinned_host_alloc)
|
||||
throw std::runtime_error(
|
||||
"fmha_bwd_launcher::prepare_workspace_async: pinned_host_alloc is required");
|
||||
|
||||
// Allocate pinned host staging first: if it throws we haven't issued any
|
||||
// stream work yet, leaving the workspace cleanly un-prepared.
|
||||
const size_t seqstart_bytes = traits_.is_group_mode ? sizeof(int) * (traits_.batch + 1) : 0;
|
||||
const size_t total_bytes = 2 * seqstart_bytes + host_ws_size_;
|
||||
auto pin_base = pinned_host_alloc(total_bytes);
|
||||
|
||||
if(needs_zero_dq_acc_ && workspace_size > host_ws_size_)
|
||||
HIP_CHECK_ERROR(hipMemsetAsync(static_cast<char*>(device_ws_ptr) + host_ws_size_,
|
||||
0,
|
||||
workspace_size - host_ws_size_,
|
||||
stream));
|
||||
|
||||
char* base = static_cast<char*>(pin_base.get());
|
||||
int* pin_q = reinterpret_cast<int*>(base);
|
||||
int* pin_k = reinterpret_cast<int*>(base + seqstart_bytes);
|
||||
@@ -663,6 +671,10 @@ struct fmha_bwd_launcher
|
||||
}
|
||||
catch(const std::exception& e)
|
||||
{
|
||||
// The H2D queued after this callback will copy indeterminate
|
||||
// metadata to device and the kernel will produce wrong results;
|
||||
// unlikely in practice since pack_workspace_host_ only throws on
|
||||
// precondition violations.
|
||||
std::cerr << "fmha_bwd_launcher: pack_workspace_host threw: " << e.what()
|
||||
<< '\n';
|
||||
}
|
||||
@@ -706,7 +718,7 @@ struct fmha_bwd_launcher
|
||||
{
|
||||
if(!pin_staging_)
|
||||
return;
|
||||
auto* heap_ref = new std::shared_ptr<void>(pin_staging_);
|
||||
auto* heap_ref = new std::shared_ptr<void>(std::move(pin_staging_));
|
||||
const hipError_t err = hipLaunchHostFunc(
|
||||
release_stream_,
|
||||
[](void* ud) { delete static_cast<std::shared_ptr<void>*>(ud); },
|
||||
|
||||
Reference in New Issue
Block a user