[rocm-libraries] ROCm/rocm-libraries#7331 (commit 5692db0)

[CK_TILE] Add async workspace prepare to FMHA BWD launcher
 (#7331)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Motivation

`aiter::mha_bwd` in group mode currently issues two synchronous
`hipMemcpy` D2H copies to read `seqstart_q/k` for launcher construction.
These sync copies block the host (~10–30 µs each) and implicitly
synchronize the device by draining the stream, breaking CPU/GPU overlap
on hot training paths.

This PR adds a fully stream-async workspace preparation path on the FMHA
BWD launcher so callers can pre-allocate the device workspace from
upper-bound shapes and stage seqstart-dependent metadata via
D2H/host-pack/H2D entirely on the user's stream.

## Technical Details

- `FmhaBwdWorkspaceManager::GetWorkspaceDeviceSizeUpperBound`
(`include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp`): computes the
worst-case device dq_acc size from `(max_batch, hdim_q, nhead_q,
max_seqlen_q, max_seqlen_k)` without dereferencing any seqstart array.
Mirrors `PrepareWorkspaceHost`'s return value with worst-case bounds.
- `fmha_bwd_launcher::prepare_workspace_async`
(`example/ck_tile/01_fmha/fmha_bwd.hpp`): on the caller's stream, in
order:
  1. `hipMemsetAsync` of the dq_acc region (when `NeedsZeroDqAcc()`)
2. group mode: `hipMemcpyAsync` D2H of `seqstart_q/k` into a pinned host
staging buffer
3. `hipLaunchHostFunc` runs `PrepareWorkspaceHost` on the pinned buffer
  4. `hipMemcpyAsync` H2D of the packed metadata into `device_ws_ptr`

The pinned staging buffer is held via `std::shared_ptr<void>` returned
by a caller-provided `pinned_host_alloc` callback. Lifetime is extended
past stream completion by a tail `hipLaunchHostFunc` scheduled in the
launcher's destructor.

- `ck_tile::pinned_host_releaser`
(`include/ck_tile/host/pinned_host_releaser.hpp`): worker-thread utility
for callers using bare `hipHostMalloc`. Defers `hipHostFree` off the HIP
driver callback thread, which holds runtime locks and would deadlock
against concurrent main-thread `hipFree`. PyTorch's
`CachingHostAllocator` does not need this.

- Example runner (`example/ck_tile/01_fmha/fmha_bwd_runner.hpp`):
switched to the async path.

## Test Plan

- `tile_example_fmha_bwd` (gfx950, dev preset `-Werror -Weverything`):
  - batch + nondet / batch + det / group + nondet / group + det
- group + det 4-batch varlen (`-b=4 -h=8 -s=4096,3072,2048,1024 -d=128`)
- FA (`flash-attention`) integration on ROCm 7.1.1 + PyTorch 2.9.1:
  - `tests/test_flash_attn_ck.py::test_flash_attn_varlen_deterministic`
  - `tests/test_flash_attn_ck.py::test_flash_attn_bwd_varlen_seqq_zero`

## Test Result

- All CK runner cases `valid:y`.
- FA pytest: **1952 passed in 44.82s**.

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
Yi DING
2026-05-14 13:34:32 +00:00
committed by assistant-librarian[bot]
parent 4d852e80fb
commit 83566edb0f
5 changed files with 339 additions and 44 deletions

View File

@@ -175,6 +175,16 @@ size_t fmha_bwd_dq_dk_dv_dq_ws_host_size_<dq_dk_dv_trait_{F_idx}, {F_arch.tag}>(
return k_::GetWorkspaceHostSize(batch_size);
}}
template <>
size_t fmha_bwd_dq_dk_dv_dq_ws_device_upper_bound_<dq_dk_dv_trait_{F_idx}, {F_arch.tag}>(
ck_tile::index_t max_batch, ck_tile::index_t hdim_q, ck_tile::index_t nhead_q,
ck_tile::index_t total_seqlen_q_padded, ck_tile::index_t max_seqlen_k)
{{
using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
return k_::GetWorkspaceDeviceSizeUpperBound(
max_batch, hdim_q, nhead_q, total_seqlen_q_padded, max_seqlen_k);
}}
template <>
size_t fmha_bwd_dq_dk_dv_dq_prepare_ws_host_<dq_dk_dv_trait_{F_idx}, {F_arch.tag}>(
void* cpu_ws, ck_tile::index_t batch_size, ck_tile::index_t hdim_q,

View File

@@ -469,6 +469,15 @@ int fmha_bwd_dq_dk_dv_maxq_();
struct fmha_bwd_traits;
template <typename Traits_, typename Arch = void>
size_t fmha_bwd_dq_dk_dv_dq_ws_host_size_(int batch_size);
// `total_seqlen_q_padded` is total q tokens across all batches (incl. per-batch padding):
// - batch mode: max_batch * seqlen_q
// - group mode: seqstart_q[batch] (== varlen q tensor's first dim)
template <typename Traits_, typename Arch = void>
size_t fmha_bwd_dq_dk_dv_dq_ws_device_upper_bound_(ck_tile::index_t max_batch,
ck_tile::index_t hdim_q,
ck_tile::index_t nhead_q,
ck_tile::index_t total_seqlen_q_padded,
ck_tile::index_t max_seqlen_k);
template <typename Traits_, typename Arch = void>
size_t fmha_bwd_dq_dk_dv_dq_prepare_ws_host_(void* cpu_ws,
ck_tile::index_t batch_size,
@@ -539,11 +548,6 @@ struct fmha_bwd_traits
bool has_dropout;
bool is_store_randval;
bool is_deterministic;
// Raw pointers for group mode: cumulative physical seqlen arrays of length batch+1.
// Only need to remain valid during fmha_bwd_launcher construction (i.e. through
// PrepareWorkspaceHost); they are not retained afterward.
const int* seqstart_qs = nullptr;
const int* seqstart_ks = nullptr;
// TODO: padding check is inside this api
};
@@ -589,53 +593,176 @@ struct fmha_bwd_launcher
std::cerr << "fmha_bwd: no kernel found for given traits, skipping run\n";
return -1.0f;
}};
// Layout: [host_ws_size_ bytes (host-prepared metadata)][dq_acc region]
size_t workspace_size = 0;
std::function<void(void*)> prepare_workspace{[](void*) {
std::cerr << "fmha_bwd: no kernel found for given traits, skipping prepare_workspace\n";
}};
fmha_bwd_launcher(const fmha_bwd_traits&);
fmha_bwd_launcher(fmha_bwd_launcher&&) = delete;
fmha_bwd_launcher& operator=(fmha_bwd_launcher&&) = delete;
~fmha_bwd_launcher() noexcept { schedule_pin_staging_release(); }
// Stream-async: zero dq_acc, D2H seqstart, host-pack metadata, H2D into device_ws.
// `pinned_host_alloc` returns a shared_ptr to a pinned host buffer; its deleter
// is invoked on the stream tail after the H2D completes.
void prepare_workspace_async( //
void* device_ws_ptr,
const int* seqstart_q_dev,
const int* seqstart_k_dev,
const ck_tile::stream_config& s,
const std::function<std::shared_ptr<void>(size_t)>& pinned_host_alloc)
{
hipStream_t stream = s.stream_id_;
// Fast path: no host-side metadata to stage; just zero dq_acc if needed.
if(host_ws_size_ == 0)
{
if(needs_zero_dq_acc_ && workspace_size > 0)
HIP_CHECK_ERROR(hipMemsetAsync(device_ws_ptr, 0, workspace_size, stream));
return;
}
if(!pinned_host_alloc)
throw std::runtime_error(
"fmha_bwd_launcher::prepare_workspace_async: pinned_host_alloc is required");
// Allocate pinned host staging first: if it throws we haven't issued any
// stream work yet, leaving the workspace cleanly un-prepared.
const size_t seqstart_bytes = traits_.is_group_mode ? sizeof(int) * (traits_.batch + 1) : 0;
const size_t total_bytes = 2 * seqstart_bytes + host_ws_size_;
auto pin_base = pinned_host_alloc(total_bytes);
if(needs_zero_dq_acc_ && workspace_size > host_ws_size_)
HIP_CHECK_ERROR(hipMemsetAsync(static_cast<char*>(device_ws_ptr) + host_ws_size_,
0,
workspace_size - host_ws_size_,
stream));
char* base = static_cast<char*>(pin_base.get());
int* pin_q = reinterpret_cast<int*>(base);
int* pin_k = reinterpret_cast<int*>(base + seqstart_bytes);
void* pin_w = base + 2 * seqstart_bytes;
const int* seqstart_q_pinned = traits_.is_group_mode ? pin_q : nullptr;
const int* seqstart_k_pinned = traits_.is_group_mode ? pin_k : nullptr;
if(traits_.is_group_mode)
{
if(!seqstart_q_dev || !seqstart_k_dev)
throw std::runtime_error("fmha_bwd_launcher::prepare_workspace_async: "
"seqstart_q_dev and seqstart_k_dev are required in "
"group mode");
HIP_CHECK_ERROR(hipMemcpyAsync(
pin_q, seqstart_q_dev, seqstart_bytes, hipMemcpyDeviceToHost, stream));
HIP_CHECK_ERROR(hipMemcpyAsync(
pin_k, seqstart_k_dev, seqstart_bytes, hipMemcpyDeviceToHost, stream));
}
auto pack_closure = std::make_unique<std::function<void()>>(
[=, fn = pack_workspace_host_]() { fn(pin_w, seqstart_q_pinned, seqstart_k_pinned); });
// Callback runs on the HIP driver helper thread across a C ABI boundary;
// any exception escaping it would call std::terminate.
HIP_CHECK_ERROR(hipLaunchHostFunc(
stream,
[](void* ud) {
std::unique_ptr<std::function<void()>> c{static_cast<std::function<void()>*>(ud)};
try
{
(*c)();
}
catch(const std::exception& e)
{
// The H2D queued after this callback will copy indeterminate
// metadata to device and the kernel will produce wrong results;
// unlikely in practice since pack_workspace_host_ only throws on
// precondition violations.
std::cerr << "fmha_bwd_launcher: pack_workspace_host threw: " << e.what()
<< '\n';
}
catch(...)
{
std::cerr << "fmha_bwd_launcher: pack_workspace_host threw unknown\n";
}
},
pack_closure.get()));
// Ownership transferred to the callback only after a successful launch.
pack_closure.release();
HIP_CHECK_ERROR(
hipMemcpyAsync(device_ws_ptr, pin_w, host_ws_size_, hipMemcpyHostToDevice, stream));
// Release any previous in-flight buffer before taking a new one.
schedule_pin_staging_release();
pin_staging_ = std::move(pin_base);
release_stream_ = stream;
}
private:
size_t host_ws_size = 0;
size_t device_ws_size = 0;
std::unique_ptr<char[]> ws_host;
fmha_bwd_traits traits_{};
size_t host_ws_size_ = 0;
bool needs_zero_dq_acc_ = false;
// Pure CPU; safe to invoke from a hipLaunchHostFunc callback.
std::function<void(void* host_ws, const int* seqstart_q, const int* seqstart_k)>
pack_workspace_host_{[](void*, const int*, const int*) {
std::cerr
<< "fmha_bwd: no kernel found for given traits, skipping pack_workspace_host\n";
}};
std::shared_ptr<void> pin_staging_;
hipStream_t release_stream_ = nullptr;
// The pin_staging_ deleter MUST NOT call any HIP API: it fires from the
// hipLaunchHostFunc callback on the driver helper thread, which holds
// runtime locks (would deadlock against main-thread hipFree). PyTorch's
// CachingHostAllocator is safe; bare hipHostMalloc users should defer
// hipHostFree via ck_tile::pinned_host_releaser.
void schedule_pin_staging_release() noexcept
{
if(!pin_staging_)
return;
auto* heap_ref = new std::shared_ptr<void>(std::move(pin_staging_));
const hipError_t err = hipLaunchHostFunc(
release_stream_,
[](void* ud) { delete static_cast<std::shared_ptr<void>*>(ud); },
heap_ref);
if(err != hipSuccess)
{
std::cerr << "fmha_bwd_launcher: hipLaunchHostFunc failed: " << hipGetErrorString(err)
<< "; releasing eagerly\n";
delete heap_ref;
}
}
template <typename T0 /*dot_do_o_trait*/,
typename T1 /*dq_dk_dv_trait*/,
typename T2 /*convert_dq_trait*/,
typename Arch>
void init(const fmha_bwd_traits& traits)
void init(const fmha_bwd_traits& t)
{
run = [](fmha_bwd_args a, const ck_tile::stream_config& s) {
traits_ = t;
run = [](fmha_bwd_args a, const ck_tile::stream_config& s) {
return fmha_bwd_<T0, T1, T2, Arch>(s, a);
};
host_ws_size = fmha_bwd_dq_dk_dv_dq_ws_host_size_<T1, Arch>(traits.batch);
if(host_ws_size > 0)
host_ws_size_ = fmha_bwd_dq_dk_dv_dq_ws_host_size_<T1, Arch>(t.batch);
size_t device_ws_size = 0;
if(host_ws_size_ > 0)
{
ws_host = std::make_unique<char[]>(host_ws_size); // TODO: support host mem allocator
device_ws_size = fmha_bwd_dq_dk_dv_dq_prepare_ws_host_<T1, Arch>( //
ws_host.get(),
traits.batch,
traits.hdim_q,
traits.nhead_q,
traits.seqlen_q,
traits.seqlen_k,
traits.seqstart_qs,
traits.seqstart_ks);
// In group mode t.seqlen_q is already the padded total (== seqstart_q[batch]);
// in batch mode it's per-batch and the total is batch * seqlen_q.
const ck_tile::index_t total_seqlen_q_padded =
t.is_group_mode ? t.seqlen_q : t.batch * t.seqlen_q;
device_ws_size = fmha_bwd_dq_dk_dv_dq_ws_device_upper_bound_<T1, Arch>(
t.batch, t.hdim_q, t.nhead_q, total_seqlen_q_padded, t.max_seqlen_k);
pack_workspace_host_ = [batch = t.batch,
hdim_q = t.hdim_q,
nhead_q = t.nhead_q,
seqlen_q = t.seqlen_q,
seqlen_k = t.seqlen_k //
](void* host_ws, const int* seqstart_q, const int* seqstart_k) {
fmha_bwd_dq_dk_dv_dq_prepare_ws_host_<T1, Arch>(
host_ws, batch, hdim_q, nhead_q, seqlen_q, seqlen_k, seqstart_q, seqstart_k);
};
}
workspace_size = host_ws_size + device_ws_size;
const bool needs_zero_dq_acc = fmha_bwd_dq_dk_dv_needs_zero_dq_acc_<T1, Arch>();
prepare_workspace = [this, needs_zero_dq_acc](void* device_ws) {
if(host_ws_size > 0)
HIP_CHECK_ERROR(
hipMemcpy(device_ws, ws_host.get(), host_ws_size, hipMemcpyHostToDevice));
if(needs_zero_dq_acc)
HIP_CHECK_ERROR(
hipMemset(static_cast<char*>(device_ws) + host_ws_size, 0, device_ws_size));
};
workspace_size = host_ws_size_ + device_ws_size;
needs_zero_dq_acc_ = fmha_bwd_dq_dk_dv_needs_zero_dq_acc_<T1, Arch>();
}
public:

View File

@@ -8,10 +8,13 @@
#include "utils.hpp"
#include "ck_tile/utility/json_dump.hpp"
#include "ck_tile/host/pinned_host_releaser.hpp"
#include <array>
#include <chrono>
#include <cstring>
#include <functional>
#include <memory>
#include <numeric>
#include <ostream>
#include <string>
@@ -391,26 +394,47 @@ bwd_result fmha_bwd_run(mode_enum mode,
p_drop > 0.0f,
s_randval,
deterministic,
(mode == mode_enum::group) ? seqstart_q_host.data() : nullptr,
(mode == mode_enum::group) ? seqstart_k_host.data() : nullptr,
});
const auto t1_launcher = std::chrono::high_resolution_clock::now();
const double launcher_ctor_ms =
std::chrono::duration<double, std::milli>(t1_launcher - t0_launcher).count();
const size_t ws_size = launcher.workspace_size;
ck_tile::DeviceMem ws_buf(ws_size);
// Stage seqstart to device before prepare_workspace_async (which D2Hs it back).
seqstart_q.ToDevice(seqstart_q_host.data());
seqstart_k.ToDevice(seqstart_k_host.data());
// Pinned host allocator for the launcher's async prepare pipeline. The
// shared_ptr deleter MUST NOT call any HIP API: it runs from the launcher's
// tail hipLaunchHostFunc on the driver helper thread, which holds HIP
// runtime locks. Deleter enqueues to a worker thread that hipHostFrees off
// the callback path.
auto pinned_host_alloc = [](size_t bytes) -> std::shared_ptr<void> {
void* p = nullptr;
HIP_CHECK_ERROR(hipHostMalloc(&p, bytes, hipHostMallocDefault));
return std::shared_ptr<void>(
p, [](void* q) { ck_tile::pinned_host_releaser::instance().enqueue(q); });
};
ck_tile::gpu_timer prepare_ws_timer;
prepare_ws_timer.start(nullptr);
launcher.prepare_workspace(ws_buf.GetDeviceBuffer());
prepare_ws_timer.stop(nullptr);
prepare_ws_timer.start(stream_config.stream_id_);
launcher.prepare_workspace_async(
ws_buf.GetDeviceBuffer(),
(mode == mode_enum::group) ? static_cast<const int*>(seqstart_q.GetDeviceBuffer())
: nullptr,
(mode == mode_enum::group) ? static_cast<const int*>(seqstart_k.GetDeviceBuffer())
: nullptr,
stream_config,
pinned_host_alloc);
prepare_ws_timer.stop(stream_config.stream_id_);
q_buf.ToDevice(q_host.data());
k_buf.ToDevice(k_host.data());
v_buf.ToDevice(v_host.data());
bias_buf.ToDevice(bias_host.data());
do_buf.ToDevice(do_host.data());
seqstart_q.ToDevice(seqstart_q_host.data());
seqstart_k.ToDevice(seqstart_k_host.data());
// seqstart_q/k were already ToDevice'd above before prepare_workspace_async.
if(mode == mode_enum::group)
{
std::vector<int32_t> seqlen_q_host(seqlen_qs.begin(), seqlen_qs.end());
@@ -902,8 +926,6 @@ bwd_result fmha_bwd_run(mode_enum mode,
dq_buf.ToDevice(dq_host.data());
dk_buf.ToDevice(dk_host.data());
dv_buf.ToDevice(dv_host.data());
// re-initialize workspace for validation run
launcher.prepare_workspace(ws_buf.GetDeviceBuffer());
o_buf.ToDevice(o_host.data());
lse_buf.ToDevice(lse_host.data());
@@ -912,6 +934,15 @@ bwd_result fmha_bwd_run(mode_enum mode,
d_sink_buf.SetZero();
ck_tile::stream_config stream_config_v{nullptr, true, 0, 0, 1};
// re-initialize workspace for validation run
launcher.prepare_workspace_async(
ws_buf.GetDeviceBuffer(),
(mode == mode_enum::group) ? static_cast<const int*>(seqstart_q.GetDeviceBuffer())
: nullptr,
(mode == mode_enum::group) ? static_cast<const int*>(seqstart_k.GetDeviceBuffer())
: nullptr,
stream_config_v,
pinned_host_alloc);
launcher(fmha_args, stream_config_v);
dq_buf.FromDevice(dq_host.data());