mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-20 21:09:08 +00:00
[rocm-libraries] ROCm/rocm-libraries#7331 (commit 5692db0)
[CK_TILE] Add async workspace prepare to FMHA BWD launcher (#7331) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Motivation `aiter::mha_bwd` in group mode currently issues two synchronous `hipMemcpy` D2H copies to read `seqstart_q/k` for launcher construction. These sync copies block the host (~10–30 µs each) and implicitly synchronize the device by draining the stream, breaking CPU/GPU overlap on hot training paths. This PR adds a fully stream-async workspace preparation path on the FMHA BWD launcher so callers can pre-allocate the device workspace from upper-bound shapes and stage seqstart-dependent metadata via D2H/host-pack/H2D entirely on the user's stream. ## Technical Details - `FmhaBwdWorkspaceManager::GetWorkspaceDeviceSizeUpperBound` (`include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp`): computes the worst-case device dq_acc size from `(max_batch, hdim_q, nhead_q, max_seqlen_q, max_seqlen_k)` without dereferencing any seqstart array. Mirrors `PrepareWorkspaceHost`'s return value with worst-case bounds. - `fmha_bwd_launcher::prepare_workspace_async` (`example/ck_tile/01_fmha/fmha_bwd.hpp`): on the caller's stream, in order: 1. `hipMemsetAsync` of the dq_acc region (when `NeedsZeroDqAcc()`) 2. group mode: `hipMemcpyAsync` D2H of `seqstart_q/k` into a pinned host staging buffer 3. `hipLaunchHostFunc` runs `PrepareWorkspaceHost` on the pinned buffer 4. `hipMemcpyAsync` H2D of the packed metadata into `device_ws_ptr` The pinned staging buffer is held via `std::shared_ptr<void>` returned by a caller-provided `pinned_host_alloc` callback. Lifetime is extended past stream completion by a tail `hipLaunchHostFunc` scheduled in the launcher's destructor. - `ck_tile::pinned_host_releaser` (`include/ck_tile/host/pinned_host_releaser.hpp`): worker-thread utility for callers using bare `hipHostMalloc`. Defers `hipHostFree` off the HIP driver callback thread, which holds runtime locks and would deadlock against concurrent main-thread `hipFree`. PyTorch's `CachingHostAllocator` does not need this. - Example runner (`example/ck_tile/01_fmha/fmha_bwd_runner.hpp`): switched to the async path. ## Test Plan - `tile_example_fmha_bwd` (gfx950, dev preset `-Werror -Weverything`): - batch + nondet / batch + det / group + nondet / group + det - group + det 4-batch varlen (`-b=4 -h=8 -s=4096,3072,2048,1024 -d=128`) - FA (`flash-attention`) integration on ROCm 7.1.1 + PyTorch 2.9.1: - `tests/test_flash_attn_ck.py::test_flash_attn_varlen_deterministic` - `tests/test_flash_attn_ck.py::test_flash_attn_bwd_varlen_seqq_zero` ## Test Result - All CK runner cases `valid:y`. - FA pytest: **1952 passed in 44.82s**. ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
4d852e80fb
commit
83566edb0f
@@ -175,6 +175,16 @@ size_t fmha_bwd_dq_dk_dv_dq_ws_host_size_<dq_dk_dv_trait_{F_idx}, {F_arch.tag}>(
|
||||
return k_::GetWorkspaceHostSize(batch_size);
|
||||
}}
|
||||
|
||||
template <>
|
||||
size_t fmha_bwd_dq_dk_dv_dq_ws_device_upper_bound_<dq_dk_dv_trait_{F_idx}, {F_arch.tag}>(
|
||||
ck_tile::index_t max_batch, ck_tile::index_t hdim_q, ck_tile::index_t nhead_q,
|
||||
ck_tile::index_t total_seqlen_q_padded, ck_tile::index_t max_seqlen_k)
|
||||
{{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
|
||||
return k_::GetWorkspaceDeviceSizeUpperBound(
|
||||
max_batch, hdim_q, nhead_q, total_seqlen_q_padded, max_seqlen_k);
|
||||
}}
|
||||
|
||||
template <>
|
||||
size_t fmha_bwd_dq_dk_dv_dq_prepare_ws_host_<dq_dk_dv_trait_{F_idx}, {F_arch.tag}>(
|
||||
void* cpu_ws, ck_tile::index_t batch_size, ck_tile::index_t hdim_q,
|
||||
|
||||
@@ -469,6 +469,15 @@ int fmha_bwd_dq_dk_dv_maxq_();
|
||||
struct fmha_bwd_traits;
|
||||
template <typename Traits_, typename Arch = void>
|
||||
size_t fmha_bwd_dq_dk_dv_dq_ws_host_size_(int batch_size);
|
||||
// `total_seqlen_q_padded` is total q tokens across all batches (incl. per-batch padding):
|
||||
// - batch mode: max_batch * seqlen_q
|
||||
// - group mode: seqstart_q[batch] (== varlen q tensor's first dim)
|
||||
template <typename Traits_, typename Arch = void>
|
||||
size_t fmha_bwd_dq_dk_dv_dq_ws_device_upper_bound_(ck_tile::index_t max_batch,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t nhead_q,
|
||||
ck_tile::index_t total_seqlen_q_padded,
|
||||
ck_tile::index_t max_seqlen_k);
|
||||
template <typename Traits_, typename Arch = void>
|
||||
size_t fmha_bwd_dq_dk_dv_dq_prepare_ws_host_(void* cpu_ws,
|
||||
ck_tile::index_t batch_size,
|
||||
@@ -539,11 +548,6 @@ struct fmha_bwd_traits
|
||||
bool has_dropout;
|
||||
bool is_store_randval;
|
||||
bool is_deterministic;
|
||||
// Raw pointers for group mode: cumulative physical seqlen arrays of length batch+1.
|
||||
// Only need to remain valid during fmha_bwd_launcher construction (i.e. through
|
||||
// PrepareWorkspaceHost); they are not retained afterward.
|
||||
const int* seqstart_qs = nullptr;
|
||||
const int* seqstart_ks = nullptr;
|
||||
// TODO: padding check is inside this api
|
||||
};
|
||||
|
||||
@@ -589,53 +593,176 @@ struct fmha_bwd_launcher
|
||||
std::cerr << "fmha_bwd: no kernel found for given traits, skipping run\n";
|
||||
return -1.0f;
|
||||
}};
|
||||
// Layout: [host_ws_size_ bytes (host-prepared metadata)][dq_acc region]
|
||||
size_t workspace_size = 0;
|
||||
std::function<void(void*)> prepare_workspace{[](void*) {
|
||||
std::cerr << "fmha_bwd: no kernel found for given traits, skipping prepare_workspace\n";
|
||||
}};
|
||||
|
||||
fmha_bwd_launcher(const fmha_bwd_traits&);
|
||||
fmha_bwd_launcher(fmha_bwd_launcher&&) = delete;
|
||||
fmha_bwd_launcher& operator=(fmha_bwd_launcher&&) = delete;
|
||||
|
||||
~fmha_bwd_launcher() noexcept { schedule_pin_staging_release(); }
|
||||
|
||||
// Stream-async: zero dq_acc, D2H seqstart, host-pack metadata, H2D into device_ws.
|
||||
// `pinned_host_alloc` returns a shared_ptr to a pinned host buffer; its deleter
|
||||
// is invoked on the stream tail after the H2D completes.
|
||||
void prepare_workspace_async( //
|
||||
void* device_ws_ptr,
|
||||
const int* seqstart_q_dev,
|
||||
const int* seqstart_k_dev,
|
||||
const ck_tile::stream_config& s,
|
||||
const std::function<std::shared_ptr<void>(size_t)>& pinned_host_alloc)
|
||||
{
|
||||
hipStream_t stream = s.stream_id_;
|
||||
|
||||
// Fast path: no host-side metadata to stage; just zero dq_acc if needed.
|
||||
if(host_ws_size_ == 0)
|
||||
{
|
||||
if(needs_zero_dq_acc_ && workspace_size > 0)
|
||||
HIP_CHECK_ERROR(hipMemsetAsync(device_ws_ptr, 0, workspace_size, stream));
|
||||
return;
|
||||
}
|
||||
|
||||
if(!pinned_host_alloc)
|
||||
throw std::runtime_error(
|
||||
"fmha_bwd_launcher::prepare_workspace_async: pinned_host_alloc is required");
|
||||
|
||||
// Allocate pinned host staging first: if it throws we haven't issued any
|
||||
// stream work yet, leaving the workspace cleanly un-prepared.
|
||||
const size_t seqstart_bytes = traits_.is_group_mode ? sizeof(int) * (traits_.batch + 1) : 0;
|
||||
const size_t total_bytes = 2 * seqstart_bytes + host_ws_size_;
|
||||
auto pin_base = pinned_host_alloc(total_bytes);
|
||||
|
||||
if(needs_zero_dq_acc_ && workspace_size > host_ws_size_)
|
||||
HIP_CHECK_ERROR(hipMemsetAsync(static_cast<char*>(device_ws_ptr) + host_ws_size_,
|
||||
0,
|
||||
workspace_size - host_ws_size_,
|
||||
stream));
|
||||
|
||||
char* base = static_cast<char*>(pin_base.get());
|
||||
int* pin_q = reinterpret_cast<int*>(base);
|
||||
int* pin_k = reinterpret_cast<int*>(base + seqstart_bytes);
|
||||
void* pin_w = base + 2 * seqstart_bytes;
|
||||
const int* seqstart_q_pinned = traits_.is_group_mode ? pin_q : nullptr;
|
||||
const int* seqstart_k_pinned = traits_.is_group_mode ? pin_k : nullptr;
|
||||
|
||||
if(traits_.is_group_mode)
|
||||
{
|
||||
if(!seqstart_q_dev || !seqstart_k_dev)
|
||||
throw std::runtime_error("fmha_bwd_launcher::prepare_workspace_async: "
|
||||
"seqstart_q_dev and seqstart_k_dev are required in "
|
||||
"group mode");
|
||||
HIP_CHECK_ERROR(hipMemcpyAsync(
|
||||
pin_q, seqstart_q_dev, seqstart_bytes, hipMemcpyDeviceToHost, stream));
|
||||
HIP_CHECK_ERROR(hipMemcpyAsync(
|
||||
pin_k, seqstart_k_dev, seqstart_bytes, hipMemcpyDeviceToHost, stream));
|
||||
}
|
||||
|
||||
auto pack_closure = std::make_unique<std::function<void()>>(
|
||||
[=, fn = pack_workspace_host_]() { fn(pin_w, seqstart_q_pinned, seqstart_k_pinned); });
|
||||
// Callback runs on the HIP driver helper thread across a C ABI boundary;
|
||||
// any exception escaping it would call std::terminate.
|
||||
HIP_CHECK_ERROR(hipLaunchHostFunc(
|
||||
stream,
|
||||
[](void* ud) {
|
||||
std::unique_ptr<std::function<void()>> c{static_cast<std::function<void()>*>(ud)};
|
||||
try
|
||||
{
|
||||
(*c)();
|
||||
}
|
||||
catch(const std::exception& e)
|
||||
{
|
||||
// The H2D queued after this callback will copy indeterminate
|
||||
// metadata to device and the kernel will produce wrong results;
|
||||
// unlikely in practice since pack_workspace_host_ only throws on
|
||||
// precondition violations.
|
||||
std::cerr << "fmha_bwd_launcher: pack_workspace_host threw: " << e.what()
|
||||
<< '\n';
|
||||
}
|
||||
catch(...)
|
||||
{
|
||||
std::cerr << "fmha_bwd_launcher: pack_workspace_host threw unknown\n";
|
||||
}
|
||||
},
|
||||
pack_closure.get()));
|
||||
// Ownership transferred to the callback only after a successful launch.
|
||||
pack_closure.release();
|
||||
|
||||
HIP_CHECK_ERROR(
|
||||
hipMemcpyAsync(device_ws_ptr, pin_w, host_ws_size_, hipMemcpyHostToDevice, stream));
|
||||
|
||||
// Release any previous in-flight buffer before taking a new one.
|
||||
schedule_pin_staging_release();
|
||||
pin_staging_ = std::move(pin_base);
|
||||
release_stream_ = stream;
|
||||
}
|
||||
|
||||
private:
|
||||
size_t host_ws_size = 0;
|
||||
size_t device_ws_size = 0;
|
||||
std::unique_ptr<char[]> ws_host;
|
||||
fmha_bwd_traits traits_{};
|
||||
size_t host_ws_size_ = 0;
|
||||
bool needs_zero_dq_acc_ = false;
|
||||
// Pure CPU; safe to invoke from a hipLaunchHostFunc callback.
|
||||
std::function<void(void* host_ws, const int* seqstart_q, const int* seqstart_k)>
|
||||
pack_workspace_host_{[](void*, const int*, const int*) {
|
||||
std::cerr
|
||||
<< "fmha_bwd: no kernel found for given traits, skipping pack_workspace_host\n";
|
||||
}};
|
||||
std::shared_ptr<void> pin_staging_;
|
||||
hipStream_t release_stream_ = nullptr;
|
||||
|
||||
// The pin_staging_ deleter MUST NOT call any HIP API: it fires from the
|
||||
// hipLaunchHostFunc callback on the driver helper thread, which holds
|
||||
// runtime locks (would deadlock against main-thread hipFree). PyTorch's
|
||||
// CachingHostAllocator is safe; bare hipHostMalloc users should defer
|
||||
// hipHostFree via ck_tile::pinned_host_releaser.
|
||||
void schedule_pin_staging_release() noexcept
|
||||
{
|
||||
if(!pin_staging_)
|
||||
return;
|
||||
auto* heap_ref = new std::shared_ptr<void>(std::move(pin_staging_));
|
||||
const hipError_t err = hipLaunchHostFunc(
|
||||
release_stream_,
|
||||
[](void* ud) { delete static_cast<std::shared_ptr<void>*>(ud); },
|
||||
heap_ref);
|
||||
if(err != hipSuccess)
|
||||
{
|
||||
std::cerr << "fmha_bwd_launcher: hipLaunchHostFunc failed: " << hipGetErrorString(err)
|
||||
<< "; releasing eagerly\n";
|
||||
delete heap_ref;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T0 /*dot_do_o_trait*/,
|
||||
typename T1 /*dq_dk_dv_trait*/,
|
||||
typename T2 /*convert_dq_trait*/,
|
||||
typename Arch>
|
||||
void init(const fmha_bwd_traits& traits)
|
||||
void init(const fmha_bwd_traits& t)
|
||||
{
|
||||
run = [](fmha_bwd_args a, const ck_tile::stream_config& s) {
|
||||
traits_ = t;
|
||||
run = [](fmha_bwd_args a, const ck_tile::stream_config& s) {
|
||||
return fmha_bwd_<T0, T1, T2, Arch>(s, a);
|
||||
};
|
||||
host_ws_size = fmha_bwd_dq_dk_dv_dq_ws_host_size_<T1, Arch>(traits.batch);
|
||||
if(host_ws_size > 0)
|
||||
host_ws_size_ = fmha_bwd_dq_dk_dv_dq_ws_host_size_<T1, Arch>(t.batch);
|
||||
size_t device_ws_size = 0;
|
||||
if(host_ws_size_ > 0)
|
||||
{
|
||||
ws_host = std::make_unique<char[]>(host_ws_size); // TODO: support host mem allocator
|
||||
device_ws_size = fmha_bwd_dq_dk_dv_dq_prepare_ws_host_<T1, Arch>( //
|
||||
ws_host.get(),
|
||||
traits.batch,
|
||||
traits.hdim_q,
|
||||
traits.nhead_q,
|
||||
traits.seqlen_q,
|
||||
traits.seqlen_k,
|
||||
traits.seqstart_qs,
|
||||
traits.seqstart_ks);
|
||||
// In group mode t.seqlen_q is already the padded total (== seqstart_q[batch]);
|
||||
// in batch mode it's per-batch and the total is batch * seqlen_q.
|
||||
const ck_tile::index_t total_seqlen_q_padded =
|
||||
t.is_group_mode ? t.seqlen_q : t.batch * t.seqlen_q;
|
||||
device_ws_size = fmha_bwd_dq_dk_dv_dq_ws_device_upper_bound_<T1, Arch>(
|
||||
t.batch, t.hdim_q, t.nhead_q, total_seqlen_q_padded, t.max_seqlen_k);
|
||||
pack_workspace_host_ = [batch = t.batch,
|
||||
hdim_q = t.hdim_q,
|
||||
nhead_q = t.nhead_q,
|
||||
seqlen_q = t.seqlen_q,
|
||||
seqlen_k = t.seqlen_k //
|
||||
](void* host_ws, const int* seqstart_q, const int* seqstart_k) {
|
||||
fmha_bwd_dq_dk_dv_dq_prepare_ws_host_<T1, Arch>(
|
||||
host_ws, batch, hdim_q, nhead_q, seqlen_q, seqlen_k, seqstart_q, seqstart_k);
|
||||
};
|
||||
}
|
||||
workspace_size = host_ws_size + device_ws_size;
|
||||
const bool needs_zero_dq_acc = fmha_bwd_dq_dk_dv_needs_zero_dq_acc_<T1, Arch>();
|
||||
prepare_workspace = [this, needs_zero_dq_acc](void* device_ws) {
|
||||
if(host_ws_size > 0)
|
||||
HIP_CHECK_ERROR(
|
||||
hipMemcpy(device_ws, ws_host.get(), host_ws_size, hipMemcpyHostToDevice));
|
||||
if(needs_zero_dq_acc)
|
||||
HIP_CHECK_ERROR(
|
||||
hipMemset(static_cast<char*>(device_ws) + host_ws_size, 0, device_ws_size));
|
||||
};
|
||||
workspace_size = host_ws_size_ + device_ws_size;
|
||||
needs_zero_dq_acc_ = fmha_bwd_dq_dk_dv_needs_zero_dq_acc_<T1, Arch>();
|
||||
}
|
||||
|
||||
public:
|
||||
|
||||
@@ -8,10 +8,13 @@
|
||||
#include "utils.hpp"
|
||||
#include "ck_tile/utility/json_dump.hpp"
|
||||
|
||||
#include "ck_tile/host/pinned_host_releaser.hpp"
|
||||
|
||||
#include <array>
|
||||
#include <chrono>
|
||||
#include <cstring>
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <numeric>
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
@@ -391,26 +394,47 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
p_drop > 0.0f,
|
||||
s_randval,
|
||||
deterministic,
|
||||
(mode == mode_enum::group) ? seqstart_q_host.data() : nullptr,
|
||||
(mode == mode_enum::group) ? seqstart_k_host.data() : nullptr,
|
||||
});
|
||||
const auto t1_launcher = std::chrono::high_resolution_clock::now();
|
||||
const double launcher_ctor_ms =
|
||||
std::chrono::duration<double, std::milli>(t1_launcher - t0_launcher).count();
|
||||
const size_t ws_size = launcher.workspace_size;
|
||||
ck_tile::DeviceMem ws_buf(ws_size);
|
||||
|
||||
// Stage seqstart to device before prepare_workspace_async (which D2Hs it back).
|
||||
seqstart_q.ToDevice(seqstart_q_host.data());
|
||||
seqstart_k.ToDevice(seqstart_k_host.data());
|
||||
|
||||
// Pinned host allocator for the launcher's async prepare pipeline. The
|
||||
// shared_ptr deleter MUST NOT call any HIP API: it runs from the launcher's
|
||||
// tail hipLaunchHostFunc on the driver helper thread, which holds HIP
|
||||
// runtime locks. Deleter enqueues to a worker thread that hipHostFrees off
|
||||
// the callback path.
|
||||
auto pinned_host_alloc = [](size_t bytes) -> std::shared_ptr<void> {
|
||||
void* p = nullptr;
|
||||
HIP_CHECK_ERROR(hipHostMalloc(&p, bytes, hipHostMallocDefault));
|
||||
return std::shared_ptr<void>(
|
||||
p, [](void* q) { ck_tile::pinned_host_releaser::instance().enqueue(q); });
|
||||
};
|
||||
|
||||
ck_tile::gpu_timer prepare_ws_timer;
|
||||
prepare_ws_timer.start(nullptr);
|
||||
launcher.prepare_workspace(ws_buf.GetDeviceBuffer());
|
||||
prepare_ws_timer.stop(nullptr);
|
||||
prepare_ws_timer.start(stream_config.stream_id_);
|
||||
launcher.prepare_workspace_async(
|
||||
ws_buf.GetDeviceBuffer(),
|
||||
(mode == mode_enum::group) ? static_cast<const int*>(seqstart_q.GetDeviceBuffer())
|
||||
: nullptr,
|
||||
(mode == mode_enum::group) ? static_cast<const int*>(seqstart_k.GetDeviceBuffer())
|
||||
: nullptr,
|
||||
stream_config,
|
||||
pinned_host_alloc);
|
||||
prepare_ws_timer.stop(stream_config.stream_id_);
|
||||
|
||||
q_buf.ToDevice(q_host.data());
|
||||
k_buf.ToDevice(k_host.data());
|
||||
v_buf.ToDevice(v_host.data());
|
||||
bias_buf.ToDevice(bias_host.data());
|
||||
do_buf.ToDevice(do_host.data());
|
||||
seqstart_q.ToDevice(seqstart_q_host.data());
|
||||
seqstart_k.ToDevice(seqstart_k_host.data());
|
||||
// seqstart_q/k were already ToDevice'd above before prepare_workspace_async.
|
||||
if(mode == mode_enum::group)
|
||||
{
|
||||
std::vector<int32_t> seqlen_q_host(seqlen_qs.begin(), seqlen_qs.end());
|
||||
@@ -902,8 +926,6 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
dq_buf.ToDevice(dq_host.data());
|
||||
dk_buf.ToDevice(dk_host.data());
|
||||
dv_buf.ToDevice(dv_host.data());
|
||||
// re-initialize workspace for validation run
|
||||
launcher.prepare_workspace(ws_buf.GetDeviceBuffer());
|
||||
|
||||
o_buf.ToDevice(o_host.data());
|
||||
lse_buf.ToDevice(lse_host.data());
|
||||
@@ -912,6 +934,15 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
d_sink_buf.SetZero();
|
||||
|
||||
ck_tile::stream_config stream_config_v{nullptr, true, 0, 0, 1};
|
||||
// re-initialize workspace for validation run
|
||||
launcher.prepare_workspace_async(
|
||||
ws_buf.GetDeviceBuffer(),
|
||||
(mode == mode_enum::group) ? static_cast<const int*>(seqstart_q.GetDeviceBuffer())
|
||||
: nullptr,
|
||||
(mode == mode_enum::group) ? static_cast<const int*>(seqstart_k.GetDeviceBuffer())
|
||||
: nullptr,
|
||||
stream_config_v,
|
||||
pinned_host_alloc);
|
||||
launcher(fmha_args, stream_config_v);
|
||||
|
||||
dq_buf.FromDevice(dq_host.data());
|
||||
|
||||
Reference in New Issue
Block a user