[rocm-libraries] ROCm/rocm-libraries#7331 (commit 5692db0)

[CK_TILE] Add async workspace prepare to FMHA BWD launcher
 (#7331)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Motivation

`aiter::mha_bwd` in group mode currently issues two synchronous
`hipMemcpy` D2H copies to read `seqstart_q/k` for launcher construction.
These sync copies block the host (~10–30 µs each) and implicitly
synchronize the device by draining the stream, breaking CPU/GPU overlap
on hot training paths.

This PR adds a fully stream-async workspace preparation path on the FMHA
BWD launcher so callers can pre-allocate the device workspace from
upper-bound shapes and stage seqstart-dependent metadata via
D2H/host-pack/H2D entirely on the user's stream.

## Technical Details

- `FmhaBwdWorkspaceManager::GetWorkspaceDeviceSizeUpperBound`
(`include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp`): computes the
worst-case device dq_acc size from `(max_batch, hdim_q, nhead_q,
max_seqlen_q, max_seqlen_k)` without dereferencing any seqstart array.
Mirrors `PrepareWorkspaceHost`'s return value with worst-case bounds.
- `fmha_bwd_launcher::prepare_workspace_async`
(`example/ck_tile/01_fmha/fmha_bwd.hpp`): on the caller's stream, in
order:
  1. `hipMemsetAsync` of the dq_acc region (when `NeedsZeroDqAcc()`)
2. group mode: `hipMemcpyAsync` D2H of `seqstart_q/k` into a pinned host
staging buffer
3. `hipLaunchHostFunc` runs `PrepareWorkspaceHost` on the pinned buffer
  4. `hipMemcpyAsync` H2D of the packed metadata into `device_ws_ptr`

The pinned staging buffer is held via `std::shared_ptr<void>` returned
by a caller-provided `pinned_host_alloc` callback. Lifetime is extended
past stream completion by a tail `hipLaunchHostFunc` scheduled in the
launcher's destructor.

- `ck_tile::pinned_host_releaser`
(`include/ck_tile/host/pinned_host_releaser.hpp`): worker-thread utility
for callers using bare `hipHostMalloc`. Defers `hipHostFree` off the HIP
driver callback thread, which holds runtime locks and would deadlock
against concurrent main-thread `hipFree`. PyTorch's
`CachingHostAllocator` does not need this.

- Example runner (`example/ck_tile/01_fmha/fmha_bwd_runner.hpp`):
switched to the async path.

## Test Plan

- `tile_example_fmha_bwd` (gfx950, dev preset `-Werror -Weverything`):
  - batch + nondet / batch + det / group + nondet / group + det
- group + det 4-batch varlen (`-b=4 -h=8 -s=4096,3072,2048,1024 -d=128`)
- FA (`flash-attention`) integration on ROCm 7.1.1 + PyTorch 2.9.1:
  - `tests/test_flash_attn_ck.py::test_flash_attn_varlen_deterministic`
  - `tests/test_flash_attn_ck.py::test_flash_attn_bwd_varlen_seqq_zero`

## Test Result

- All CK runner cases `valid:y`.
- FA pytest: **1952 passed in 44.82s**.

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
Yi DING
2026-05-14 13:34:32 +00:00
committed by assistant-librarian[bot]
parent 4d852e80fb
commit 83566edb0f
5 changed files with 339 additions and 44 deletions

View File

@@ -0,0 +1,77 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <condition_variable>
#include <hip/hip_runtime.h>
#include <mutex>
#include <queue>
#include <thread>
namespace ck_tile {
// Defers hipHostFree off the HIP callback path. HIP callbacks hold runtime
// locks, so calling hipHostFree (or any HIP API) from one deadlocks against
// concurrent main-thread hipFree. enqueue() is HIP-API-free; a worker thread
// drains the queue and calls hipHostFree. Use instance() for a process-wide
// shared worker.
class pinned_host_releaser
{
std::mutex mtx_;
std::condition_variable cv_;
std::queue<void*> q_;
std::thread worker_;
bool stop_ = false;
void run()
{
for(;;)
{
void* p = nullptr;
{
std::unique_lock<std::mutex> lk(mtx_);
cv_.wait(lk, [&] { return stop_ || !q_.empty(); });
if(q_.empty())
return; // stop_ && empty
p = q_.front();
q_.pop();
}
(void)hipHostFree(p);
}
}
public:
pinned_host_releaser() : worker_([this] { run(); }) {}
~pinned_host_releaser()
{
{
std::lock_guard<std::mutex> lk(mtx_);
stop_ = true;
}
cv_.notify_all();
if(worker_.joinable())
worker_.join();
}
pinned_host_releaser(const pinned_host_releaser&) = delete;
pinned_host_releaser& operator=(const pinned_host_releaser&) = delete;
static pinned_host_releaser& instance()
{
static pinned_host_releaser r;
return r;
}
void enqueue(void* p)
{
{
std::lock_guard<std::mutex> lk(mtx_);
q_.push(p);
}
cv_.notify_one();
}
};
} // namespace ck_tile