mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-20 21:09:08 +00:00
[rocm-libraries] ROCm/rocm-libraries#7331 (commit 5692db0)
[CK_TILE] Add async workspace prepare to FMHA BWD launcher (#7331) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Motivation `aiter::mha_bwd` in group mode currently issues two synchronous `hipMemcpy` D2H copies to read `seqstart_q/k` for launcher construction. These sync copies block the host (~10–30 µs each) and implicitly synchronize the device by draining the stream, breaking CPU/GPU overlap on hot training paths. This PR adds a fully stream-async workspace preparation path on the FMHA BWD launcher so callers can pre-allocate the device workspace from upper-bound shapes and stage seqstart-dependent metadata via D2H/host-pack/H2D entirely on the user's stream. ## Technical Details - `FmhaBwdWorkspaceManager::GetWorkspaceDeviceSizeUpperBound` (`include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp`): computes the worst-case device dq_acc size from `(max_batch, hdim_q, nhead_q, max_seqlen_q, max_seqlen_k)` without dereferencing any seqstart array. Mirrors `PrepareWorkspaceHost`'s return value with worst-case bounds. - `fmha_bwd_launcher::prepare_workspace_async` (`example/ck_tile/01_fmha/fmha_bwd.hpp`): on the caller's stream, in order: 1. `hipMemsetAsync` of the dq_acc region (when `NeedsZeroDqAcc()`) 2. group mode: `hipMemcpyAsync` D2H of `seqstart_q/k` into a pinned host staging buffer 3. `hipLaunchHostFunc` runs `PrepareWorkspaceHost` on the pinned buffer 4. `hipMemcpyAsync` H2D of the packed metadata into `device_ws_ptr` The pinned staging buffer is held via `std::shared_ptr<void>` returned by a caller-provided `pinned_host_alloc` callback. Lifetime is extended past stream completion by a tail `hipLaunchHostFunc` scheduled in the launcher's destructor. - `ck_tile::pinned_host_releaser` (`include/ck_tile/host/pinned_host_releaser.hpp`): worker-thread utility for callers using bare `hipHostMalloc`. Defers `hipHostFree` off the HIP driver callback thread, which holds runtime locks and would deadlock against concurrent main-thread `hipFree`. PyTorch's `CachingHostAllocator` does not need this. - Example runner (`example/ck_tile/01_fmha/fmha_bwd_runner.hpp`): switched to the async path. ## Test Plan - `tile_example_fmha_bwd` (gfx950, dev preset `-Werror -Weverything`): - batch + nondet / batch + det / group + nondet / group + det - group + det 4-batch varlen (`-b=4 -h=8 -s=4096,3072,2048,1024 -d=128`) - FA (`flash-attention`) integration on ROCm 7.1.1 + PyTorch 2.9.1: - `tests/test_flash_attn_ck.py::test_flash_attn_varlen_deterministic` - `tests/test_flash_attn_ck.py::test_flash_attn_bwd_varlen_seqq_zero` ## Test Result - All CK runner cases `valid:y`. - FA pytest: **1952 passed in 44.82s**. ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
4d852e80fb
commit
83566edb0f
77
include/ck_tile/host/pinned_host_releaser.hpp
Normal file
77
include/ck_tile/host/pinned_host_releaser.hpp
Normal file
@@ -0,0 +1,77 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <condition_variable>
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <mutex>
|
||||
#include <queue>
|
||||
#include <thread>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Defers hipHostFree off the HIP callback path. HIP callbacks hold runtime
|
||||
// locks, so calling hipHostFree (or any HIP API) from one deadlocks against
|
||||
// concurrent main-thread hipFree. enqueue() is HIP-API-free; a worker thread
|
||||
// drains the queue and calls hipHostFree. Use instance() for a process-wide
|
||||
// shared worker.
|
||||
class pinned_host_releaser
|
||||
{
|
||||
std::mutex mtx_;
|
||||
std::condition_variable cv_;
|
||||
std::queue<void*> q_;
|
||||
std::thread worker_;
|
||||
bool stop_ = false;
|
||||
|
||||
void run()
|
||||
{
|
||||
for(;;)
|
||||
{
|
||||
void* p = nullptr;
|
||||
{
|
||||
std::unique_lock<std::mutex> lk(mtx_);
|
||||
cv_.wait(lk, [&] { return stop_ || !q_.empty(); });
|
||||
if(q_.empty())
|
||||
return; // stop_ && empty
|
||||
p = q_.front();
|
||||
q_.pop();
|
||||
}
|
||||
(void)hipHostFree(p);
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
pinned_host_releaser() : worker_([this] { run(); }) {}
|
||||
|
||||
~pinned_host_releaser()
|
||||
{
|
||||
{
|
||||
std::lock_guard<std::mutex> lk(mtx_);
|
||||
stop_ = true;
|
||||
}
|
||||
cv_.notify_all();
|
||||
if(worker_.joinable())
|
||||
worker_.join();
|
||||
}
|
||||
|
||||
pinned_host_releaser(const pinned_host_releaser&) = delete;
|
||||
pinned_host_releaser& operator=(const pinned_host_releaser&) = delete;
|
||||
|
||||
static pinned_host_releaser& instance()
|
||||
{
|
||||
static pinned_host_releaser r;
|
||||
return r;
|
||||
}
|
||||
|
||||
void enqueue(void* p)
|
||||
{
|
||||
{
|
||||
std::lock_guard<std::mutex> lk(mtx_);
|
||||
q_.push(p);
|
||||
}
|
||||
cv_.notify_one();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -166,6 +166,49 @@ struct FmhaBwdWorkspaceManager
|
||||
// In these cases we need to zero out it first
|
||||
return kHasMask;
|
||||
}
|
||||
|
||||
// Upper bound on PrepareWorkspaceHost's size, computable without seqstart so
|
||||
// the device workspace can be allocated before any D2H.
|
||||
//
|
||||
// total_seqlen_q_padded: total q tokens incl. per-batch padding.
|
||||
// Batch: max_batch * seqlen_q. Group: seqstart_q[batch].
|
||||
// max_seqlen_k: deterministic-only; pass per-batch padded max if the caller
|
||||
// does internal k padding, otherwise the logical max is fine.
|
||||
template <bool kUseQrQtrDorPipeline, index_t kN0>
|
||||
CK_TILE_HOST static size_t GetWorkspaceDeviceSizeUpperBound(index_t max_batch,
|
||||
index_t hdim_q,
|
||||
index_t nhead_q,
|
||||
index_t total_seqlen_q_padded,
|
||||
index_t max_seqlen_k)
|
||||
{
|
||||
if constexpr(kUseQrQtrDorPipeline)
|
||||
return 0;
|
||||
|
||||
index_t nsplits_factor = 1;
|
||||
if constexpr(kIsDeterministic)
|
||||
{
|
||||
if constexpr(kIsGroupMode)
|
||||
{
|
||||
nsplits_factor = integer_divide_ceil(max_seqlen_k, kN0);
|
||||
}
|
||||
else // persistent
|
||||
{
|
||||
const index_t dqdqkdv_workers = get_num_cus();
|
||||
const index_t jobs_per_head = integer_divide_ceil(max_seqlen_k, kN0);
|
||||
const index_t total_jobs = max_batch * nhead_q * jobs_per_head;
|
||||
const index_t jobs_per_worker = integer_divide_ceil(total_jobs, dqdqkdv_workers);
|
||||
if(jobs_per_head % jobs_per_worker == 0)
|
||||
nsplits_factor = jobs_per_head / jobs_per_worker;
|
||||
else if(jobs_per_worker % jobs_per_head == 0)
|
||||
nsplits_factor = 1;
|
||||
else
|
||||
nsplits_factor = 1 + integer_divide_ceil(jobs_per_head - 1, jobs_per_worker);
|
||||
}
|
||||
}
|
||||
|
||||
return sizeof(AccDataType) * static_cast<long_index_t>(nhead_q) * nsplits_factor *
|
||||
total_seqlen_q_padded * hdim_q;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename FmhaPipeline_,
|
||||
@@ -281,6 +324,13 @@ struct FmhaBwdDQDKDVKernel
|
||||
FmhaPipeline::BlockFmhaShape::kN0>(
|
||||
std::forward<Args>(args)...);
|
||||
}
|
||||
template <typename... Args>
|
||||
CK_TILE_HOST static size_t GetWorkspaceDeviceSizeUpperBound(Args&&... args)
|
||||
{
|
||||
return WorkspaceManager::template GetWorkspaceDeviceSizeUpperBound<
|
||||
kUseQrQtrDorPipeline,
|
||||
FmhaPipeline::BlockFmhaShape::kN0>(std::forward<Args>(args)...);
|
||||
}
|
||||
CK_TILE_HOST static constexpr bool NeedsZeroDqAcc()
|
||||
{
|
||||
return WorkspaceManager::template NeedsZeroDqAcc<kUseQrQtrDorPipeline, kHasMask>();
|
||||
|
||||
Reference in New Issue
Block a user