[rocm-libraries] ROCm/rocm-libraries#7331 (commit 5692db0)

[CK_TILE] Add async workspace prepare to FMHA BWD launcher
 (#7331)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Motivation

`aiter::mha_bwd` in group mode currently issues two synchronous
`hipMemcpy` D2H copies to read `seqstart_q/k` for launcher construction.
These sync copies block the host (~10–30 µs each) and implicitly
synchronize the device by draining the stream, breaking CPU/GPU overlap
on hot training paths.

This PR adds a fully stream-async workspace preparation path on the FMHA
BWD launcher so callers can pre-allocate the device workspace from
upper-bound shapes and stage seqstart-dependent metadata via
D2H/host-pack/H2D entirely on the user's stream.

## Technical Details

- `FmhaBwdWorkspaceManager::GetWorkspaceDeviceSizeUpperBound`
(`include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp`): computes the
worst-case device dq_acc size from `(max_batch, hdim_q, nhead_q,
max_seqlen_q, max_seqlen_k)` without dereferencing any seqstart array.
Mirrors `PrepareWorkspaceHost`'s return value with worst-case bounds.
- `fmha_bwd_launcher::prepare_workspace_async`
(`example/ck_tile/01_fmha/fmha_bwd.hpp`): on the caller's stream, in
order:
  1. `hipMemsetAsync` of the dq_acc region (when `NeedsZeroDqAcc()`)
2. group mode: `hipMemcpyAsync` D2H of `seqstart_q/k` into a pinned host
staging buffer
3. `hipLaunchHostFunc` runs `PrepareWorkspaceHost` on the pinned buffer
  4. `hipMemcpyAsync` H2D of the packed metadata into `device_ws_ptr`

The pinned staging buffer is held via `std::shared_ptr<void>` returned
by a caller-provided `pinned_host_alloc` callback. Lifetime is extended
past stream completion by a tail `hipLaunchHostFunc` scheduled in the
launcher's destructor.

- `ck_tile::pinned_host_releaser`
(`include/ck_tile/host/pinned_host_releaser.hpp`): worker-thread utility
for callers using bare `hipHostMalloc`. Defers `hipHostFree` off the HIP
driver callback thread, which holds runtime locks and would deadlock
against concurrent main-thread `hipFree`. PyTorch's
`CachingHostAllocator` does not need this.

- Example runner (`example/ck_tile/01_fmha/fmha_bwd_runner.hpp`):
switched to the async path.

## Test Plan

- `tile_example_fmha_bwd` (gfx950, dev preset `-Werror -Weverything`):
  - batch + nondet / batch + det / group + nondet / group + det
- group + det 4-batch varlen (`-b=4 -h=8 -s=4096,3072,2048,1024 -d=128`)
- FA (`flash-attention`) integration on ROCm 7.1.1 + PyTorch 2.9.1:
  - `tests/test_flash_attn_ck.py::test_flash_attn_varlen_deterministic`
  - `tests/test_flash_attn_ck.py::test_flash_attn_bwd_varlen_seqq_zero`

## Test Result

- All CK runner cases `valid:y`.
- FA pytest: **1952 passed in 44.82s**.

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
Yi DING
2026-05-14 13:34:32 +00:00
committed by assistant-librarian[bot]
parent 4d852e80fb
commit 83566edb0f
5 changed files with 339 additions and 44 deletions

View File

@@ -0,0 +1,77 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <condition_variable>
#include <hip/hip_runtime.h>
#include <mutex>
#include <queue>
#include <thread>
namespace ck_tile {
// Defers hipHostFree off the HIP callback path. HIP callbacks hold runtime
// locks, so calling hipHostFree (or any HIP API) from one deadlocks against
// concurrent main-thread hipFree. enqueue() is HIP-API-free; a worker thread
// drains the queue and calls hipHostFree. Use instance() for a process-wide
// shared worker.
class pinned_host_releaser
{
std::mutex mtx_;
std::condition_variable cv_;
std::queue<void*> q_;
std::thread worker_;
bool stop_ = false;
void run()
{
for(;;)
{
void* p = nullptr;
{
std::unique_lock<std::mutex> lk(mtx_);
cv_.wait(lk, [&] { return stop_ || !q_.empty(); });
if(q_.empty())
return; // stop_ && empty
p = q_.front();
q_.pop();
}
(void)hipHostFree(p);
}
}
public:
pinned_host_releaser() : worker_([this] { run(); }) {}
~pinned_host_releaser()
{
{
std::lock_guard<std::mutex> lk(mtx_);
stop_ = true;
}
cv_.notify_all();
if(worker_.joinable())
worker_.join();
}
pinned_host_releaser(const pinned_host_releaser&) = delete;
pinned_host_releaser& operator=(const pinned_host_releaser&) = delete;
static pinned_host_releaser& instance()
{
static pinned_host_releaser r;
return r;
}
void enqueue(void* p)
{
{
std::lock_guard<std::mutex> lk(mtx_);
q_.push(p);
}
cv_.notify_one();
}
};
} // namespace ck_tile

View File

@@ -166,6 +166,49 @@ struct FmhaBwdWorkspaceManager
// In these cases we need to zero out it first
return kHasMask;
}
// Upper bound on PrepareWorkspaceHost's size, computable without seqstart so
// the device workspace can be allocated before any D2H.
//
// total_seqlen_q_padded: total q tokens incl. per-batch padding.
// Batch: max_batch * seqlen_q. Group: seqstart_q[batch].
// max_seqlen_k: deterministic-only; pass per-batch padded max if the caller
// does internal k padding, otherwise the logical max is fine.
template <bool kUseQrQtrDorPipeline, index_t kN0>
CK_TILE_HOST static size_t GetWorkspaceDeviceSizeUpperBound(index_t max_batch,
index_t hdim_q,
index_t nhead_q,
index_t total_seqlen_q_padded,
index_t max_seqlen_k)
{
if constexpr(kUseQrQtrDorPipeline)
return 0;
index_t nsplits_factor = 1;
if constexpr(kIsDeterministic)
{
if constexpr(kIsGroupMode)
{
nsplits_factor = integer_divide_ceil(max_seqlen_k, kN0);
}
else // persistent
{
const index_t dqdqkdv_workers = get_num_cus();
const index_t jobs_per_head = integer_divide_ceil(max_seqlen_k, kN0);
const index_t total_jobs = max_batch * nhead_q * jobs_per_head;
const index_t jobs_per_worker = integer_divide_ceil(total_jobs, dqdqkdv_workers);
if(jobs_per_head % jobs_per_worker == 0)
nsplits_factor = jobs_per_head / jobs_per_worker;
else if(jobs_per_worker % jobs_per_head == 0)
nsplits_factor = 1;
else
nsplits_factor = 1 + integer_divide_ceil(jobs_per_head - 1, jobs_per_worker);
}
}
return sizeof(AccDataType) * static_cast<long_index_t>(nhead_q) * nsplits_factor *
total_seqlen_q_padded * hdim_q;
}
};
template <typename FmhaPipeline_,
@@ -281,6 +324,13 @@ struct FmhaBwdDQDKDVKernel
FmhaPipeline::BlockFmhaShape::kN0>(
std::forward<Args>(args)...);
}
template <typename... Args>
CK_TILE_HOST static size_t GetWorkspaceDeviceSizeUpperBound(Args&&... args)
{
return WorkspaceManager::template GetWorkspaceDeviceSizeUpperBound<
kUseQrQtrDorPipeline,
FmhaPipeline::BlockFmhaShape::kN0>(std::forward<Args>(args)...);
}
CK_TILE_HOST static constexpr bool NeedsZeroDqAcc()
{
return WorkspaceManager::template NeedsZeroDqAcc<kUseQrQtrDorPipeline, kHasMask>();