[CK_TILE] Add async workspace prepare to FMHA BWD launcher

This commit is contained in:
Ding, Yi
2026-05-12 01:59:15 -05:00
parent 733bfee092
commit d434410e52
5 changed files with 297 additions and 42 deletions

View File

@@ -175,6 +175,16 @@ size_t fmha_bwd_dq_dk_dv_dq_ws_host_size_<dq_dk_dv_trait_{F_idx}, {F_arch.tag}>(
return k_::GetWorkspaceHostSize(batch_size);
}}
// Generated explicit specialization: reports a worst-case device workspace
// size for this kernel instance so the buffer can be allocated before the
// host has seen the actual seqstart values. Pairs with the exact size later
// produced by the prepare-workspace-host path.
template <>
size_t fmha_bwd_dq_dk_dv_dq_ws_device_upper_bound_<dq_dk_dv_trait_{F_idx}, {F_arch.tag}>(
    ck_tile::index_t max_batch, ck_tile::index_t hdim_q, ck_tile::index_t nhead_q,
    ck_tile::index_t max_seqlen_q, ck_tile::index_t max_seqlen_k)
{{
    using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
    return k_::GetWorkspaceDeviceSizeUpperBound(
        max_batch, hdim_q, nhead_q, max_seqlen_q, max_seqlen_k);
}}
template <>
size_t fmha_bwd_dq_dk_dv_dq_prepare_ws_host_<dq_dk_dv_trait_{F_idx}, {F_arch.tag}>(
void* cpu_ws, ck_tile::index_t batch_size, ck_tile::index_t hdim_q,

View File

@@ -470,6 +470,12 @@ struct fmha_bwd_traits;
template <typename Traits_, typename Arch = void>
size_t fmha_bwd_dq_dk_dv_dq_ws_host_size_(int batch_size);
template <typename Traits_, typename Arch = void>
size_t fmha_bwd_dq_dk_dv_dq_ws_device_upper_bound_(ck_tile::index_t max_batch,
ck_tile::index_t hdim_q,
ck_tile::index_t nhead_q,
ck_tile::index_t max_seqlen_q,
ck_tile::index_t max_seqlen_k);
template <typename Traits_, typename Arch = void>
size_t fmha_bwd_dq_dk_dv_dq_prepare_ws_host_(void* cpu_ws,
ck_tile::index_t batch_size,
ck_tile::index_t hdim_q,
@@ -539,11 +545,6 @@ struct fmha_bwd_traits
bool has_dropout;
bool is_store_randval;
bool is_deterministic;
// Raw pointers for group mode: cumulative physical seqlen arrays of length batch+1.
// Only need to remain valid during fmha_bwd_launcher construction (i.e. through
// PrepareWorkspaceHost); they are not retained afterward.
const int* seqstart_qs = nullptr;
const int* seqstart_ks = nullptr;
// TODO: padding check is inside this api
};
@@ -589,53 +590,141 @@ struct fmha_bwd_launcher
std::cerr << "fmha_bwd: no kernel found for given traits, skipping run\n";
return -1.0f;
}};
// Layout: [host_ws_size_ bytes (host-prepared metadata)][dq_acc region]
size_t workspace_size = 0;
std::function<void(void*)> prepare_workspace{[](void*) {
std::cerr << "fmha_bwd: no kernel found for given traits, skipping prepare_workspace\n";
}};
fmha_bwd_launcher(const fmha_bwd_traits&);
fmha_bwd_launcher(fmha_bwd_launcher&&) = delete;
fmha_bwd_launcher& operator=(fmha_bwd_launcher&&) = delete;
~fmha_bwd_launcher() noexcept { schedule_pin_staging_release(); }
// Stream-async: zero dq_acc, D2H seqstart, host-pack metadata, H2D into device_ws.
// `pinned_host_alloc` returns a shared_ptr to a pinned host buffer; its deleter
// is invoked on the stream tail after the H2D completes.
// Asynchronously prepare the device workspace on the stream carried by `s`:
//   1. memset the dq_acc region (bytes past host_ws_size_) when required,
//   2. group mode only: D2H-copy seqstart_q/k into a pinned staging buffer,
//   3. hipLaunchHostFunc: run the CPU packer into the pinned buffer,
//   4. H2D-copy the packed metadata into device_ws_ptr.
// The pinned staging buffer is retained in pin_staging_ and released via
// schedule_pin_staging_release() on the next call or at destruction.
// Throws std::runtime_error if metadata is needed but no allocator is given.
void prepare_workspace_async( //
void* device_ws_ptr,
const int* seqstart_q_dev,
const int* seqstart_k_dev,
const ck_tile::stream_config& s,
const std::function<std::shared_ptr<void>(size_t)>& pinned_host_alloc)
{
hipStream_t stream = s.stream_id_;
// Step 1: dq_acc lives after the host-prepared metadata region; clear it
// asynchronously when the kernel needs a zeroed accumulator.
if(needs_zero_dq_acc_ && workspace_size > host_ws_size_)
HIP_CHECK_ERROR(hipMemsetAsync(static_cast<char*>(device_ws_ptr) + host_ws_size_,
0,
workspace_size - host_ws_size_,
stream));
// No host-prepared metadata: nothing to stage or upload.
if(host_ws_size_ == 0)
return;
if(!pinned_host_alloc)
throw std::runtime_error(
"fmha_bwd_launcher::prepare_workspace_async: pinned_host_alloc is required");
// Pinned staging layout: [seqstart_q][seqstart_k][host workspace bytes].
// seqstart regions are empty (0 bytes) outside group mode.
const size_t seqstart_bytes = traits_.is_group_mode ? sizeof(int) * (traits_.batch + 1) : 0;
const size_t total_bytes = 2 * seqstart_bytes + host_ws_size_;
auto pin_base = pinned_host_alloc(total_bytes);
char* base = static_cast<char*>(pin_base.get());
int* pin_q = reinterpret_cast<int*>(base);
int* pin_k = reinterpret_cast<int*>(base + seqstart_bytes);
void* pin_w = base + 2 * seqstart_bytes;
const int* seqstart_q_pinned = traits_.is_group_mode ? pin_q : nullptr;
const int* seqstart_k_pinned = traits_.is_group_mode ? pin_k : nullptr;
// Step 2: stage the device-resident seqstart arrays back to pinned host
// memory; the packer reads them from there after these copies complete
// (stream order guarantees the host callback below sees the data).
if(traits_.is_group_mode)
{
HIP_CHECK_ERROR(hipMemcpyAsync(
pin_q, seqstart_q_dev, seqstart_bytes, hipMemcpyDeviceToHost, stream));
HIP_CHECK_ERROR(hipMemcpyAsync(
pin_k, seqstart_k_dev, seqstart_bytes, hipMemcpyDeviceToHost, stream));
}
// Step 3: heap-allocate the packing closure so it survives until the
// stream reaches it; the trampoline deletes it after invoking.
auto* pack_closure = new std::function<void()>(
[=, fn = pack_workspace_host_]() { fn(pin_w, seqstart_q_pinned, seqstart_k_pinned); });
HIP_CHECK_ERROR(hipLaunchHostFunc(
stream,
[](void* ud) {
auto* c = static_cast<std::function<void()>*>(ud);
(*c)();
delete c;
},
pack_closure));
// Step 4: upload the freshly packed metadata region to the device.
HIP_CHECK_ERROR(
hipMemcpyAsync(device_ws_ptr, pin_w, host_ws_size_, hipMemcpyHostToDevice, stream));
// Release any previous in-flight buffer before taking a new one.
schedule_pin_staging_release();
pin_staging_ = std::move(pin_base);
release_stream_ = stream;
}
private:
size_t host_ws_size = 0;
size_t device_ws_size = 0;
std::unique_ptr<char[]> ws_host;
fmha_bwd_traits traits_{};
size_t host_ws_size_ = 0;
bool needs_zero_dq_acc_ = false;
// Pure CPU; safe to invoke from a hipLaunchHostFunc callback.
std::function<void(void* host_ws, const int* seqstart_q, const int* seqstart_k)>
pack_workspace_host_{[](void*, const int*, const int*) {
std::cerr
<< "fmha_bwd: no kernel found for given traits, skipping pack_workspace_host\n";
}};
std::shared_ptr<void> pin_staging_;
hipStream_t release_stream_ = nullptr;
// The pin_staging_ deleter MUST NOT call any HIP API: it fires from the
// hipLaunchHostFunc callback on the driver helper thread, which holds
// runtime locks (would deadlock against main-thread hipFree). PyTorch's
// CachingHostAllocator is safe; bare hipHostMalloc users should defer
// hipHostFree via ck_tile::pinned_host_releaser.
// Hand the current pin_staging_ buffer off to the stream: a host callback
// enqueued on release_stream_ drops a reference behind all previously
// submitted work, so the pinned memory outlives any async copy still
// reading it. Must remain noexcept (it is called from the destructor), so
// the heap allocation below uses the nothrow form instead of plain `new`,
// which could throw std::bad_alloc and terminate the process.
void schedule_pin_staging_release() noexcept
{
    if(!pin_staging_)
        return;
    // Heap-copy the shared_ptr so the callback owns its own reference.
    auto* heap_ref = new (std::nothrow) std::shared_ptr<void>(pin_staging_);
    if(heap_ref == nullptr)
    {
        // Allocation failure: keep the reference in pin_staging_; it is
        // dropped when the member is reassigned or destroyed (same
        // best-effort semantics as the hipLaunchHostFunc failure path).
        std::cerr << "fmha_bwd_launcher: allocation failed in "
                     "schedule_pin_staging_release; releasing eagerly\n";
        return;
    }
    const hipError_t err = hipLaunchHostFunc(
        release_stream_,
        [](void* ud) { delete static_cast<std::shared_ptr<void>*>(ud); },
        heap_ref);
    if(err != hipSuccess)
    {
        std::cerr << "fmha_bwd_launcher: hipLaunchHostFunc failed: " << hipGetErrorString(err)
                  << "; releasing eagerly\n";
        delete heap_ref;
    }
}
template <typename T0 /*dot_do_o_trait*/,
typename T1 /*dq_dk_dv_trait*/,
typename T2 /*convert_dq_trait*/,
typename Arch>
void init(const fmha_bwd_traits& traits)
void init(const fmha_bwd_traits& t)
{
run = [](fmha_bwd_args a, const ck_tile::stream_config& s) {
traits_ = t;
run = [](fmha_bwd_args a, const ck_tile::stream_config& s) {
return fmha_bwd_<T0, T1, T2, Arch>(s, a);
};
host_ws_size = fmha_bwd_dq_dk_dv_dq_ws_host_size_<T1, Arch>(traits.batch);
if(host_ws_size > 0)
host_ws_size_ = fmha_bwd_dq_dk_dv_dq_ws_host_size_<T1, Arch>(t.batch);
size_t device_ws_size = 0;
if(host_ws_size_ > 0)
{
ws_host = std::make_unique<char[]>(host_ws_size); // TODO: support host mem allocator
device_ws_size = fmha_bwd_dq_dk_dv_dq_prepare_ws_host_<T1, Arch>( //
ws_host.get(),
traits.batch,
traits.hdim_q,
traits.nhead_q,
traits.seqlen_q,
traits.seqlen_k,
traits.seqstart_qs,
traits.seqstart_ks);
device_ws_size = fmha_bwd_dq_dk_dv_dq_ws_device_upper_bound_<T1, Arch>(
t.batch, t.hdim_q, t.nhead_q, t.max_seqlen_q, t.max_seqlen_k);
pack_workspace_host_ = [batch = t.batch,
hdim_q = t.hdim_q,
nhead_q = t.nhead_q,
seqlen_q = t.seqlen_q,
seqlen_k = t.seqlen_k //
](void* host_ws, const int* seqstart_q, const int* seqstart_k) {
fmha_bwd_dq_dk_dv_dq_prepare_ws_host_<T1, Arch>(
host_ws, batch, hdim_q, nhead_q, seqlen_q, seqlen_k, seqstart_q, seqstart_k);
};
}
workspace_size = host_ws_size + device_ws_size;
const bool needs_zero_dq_acc = fmha_bwd_dq_dk_dv_needs_zero_dq_acc_<T1, Arch>();
prepare_workspace = [this, needs_zero_dq_acc](void* device_ws) {
if(host_ws_size > 0)
HIP_CHECK_ERROR(
hipMemcpy(device_ws, ws_host.get(), host_ws_size, hipMemcpyHostToDevice));
if(needs_zero_dq_acc)
HIP_CHECK_ERROR(
hipMemset(static_cast<char*>(device_ws) + host_ws_size, 0, device_ws_size));
};
workspace_size = host_ws_size_ + device_ws_size;
needs_zero_dq_acc_ = fmha_bwd_dq_dk_dv_needs_zero_dq_acc_<T1, Arch>();
}
public:

View File

@@ -8,10 +8,13 @@
#include "utils.hpp"
#include "ck_tile/utility/json_dump.hpp"
#include "ck_tile/host/pinned_host_releaser.hpp"
#include <array>
#include <chrono>
#include <cstring>
#include <functional>
#include <memory>
#include <numeric>
#include <ostream>
#include <string>
@@ -391,17 +394,39 @@ bwd_result fmha_bwd_run(mode_enum mode,
p_drop > 0.0f,
s_randval,
deterministic,
(mode == mode_enum::group) ? seqstart_q_host.data() : nullptr,
(mode == mode_enum::group) ? seqstart_k_host.data() : nullptr,
});
const auto t1_launcher = std::chrono::high_resolution_clock::now();
const double launcher_ctor_ms =
std::chrono::duration<double, std::milli>(t1_launcher - t0_launcher).count();
const size_t ws_size = launcher.workspace_size;
ck_tile::DeviceMem ws_buf(ws_size);
// Stage seqstart to device before prepare_workspace_async (which D2Hs it back).
seqstart_q.ToDevice(seqstart_q_host.data());
seqstart_k.ToDevice(seqstart_k_host.data());
// Pinned host allocator for the launcher's async prepare pipeline. The
// shared_ptr deleter MUST NOT call any HIP API: it runs from the launcher's
// tail hipLaunchHostFunc on the driver helper thread, which holds HIP
// runtime locks. Deleter enqueues to a worker thread that hipHostFrees off
// the callback path.
auto pinned_host_alloc = [](size_t bytes) -> std::shared_ptr<void> {
void* p = nullptr;
HIP_CHECK_ERROR(hipHostMalloc(&p, bytes, hipHostMallocDefault));
return std::shared_ptr<void>(
p, [](void* q) { ck_tile::pinned_host_releaser::instance().enqueue(q); });
};
ck_tile::gpu_timer prepare_ws_timer;
prepare_ws_timer.start(nullptr);
launcher.prepare_workspace(ws_buf.GetDeviceBuffer());
launcher.prepare_workspace_async(
ws_buf.GetDeviceBuffer(),
(mode == mode_enum::group) ? static_cast<const int*>(seqstart_q.GetDeviceBuffer())
: nullptr,
(mode == mode_enum::group) ? static_cast<const int*>(seqstart_k.GetDeviceBuffer())
: nullptr,
stream_config,
pinned_host_alloc);
prepare_ws_timer.stop(nullptr);
q_buf.ToDevice(q_host.data());
@@ -409,8 +434,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
v_buf.ToDevice(v_host.data());
bias_buf.ToDevice(bias_host.data());
do_buf.ToDevice(do_host.data());
seqstart_q.ToDevice(seqstart_q_host.data());
seqstart_k.ToDevice(seqstart_k_host.data());
// seqstart_q/k were already ToDevice'd above before prepare_workspace_async.
if(mode == mode_enum::group)
{
std::vector<int32_t> seqlen_q_host(seqlen_qs.begin(), seqlen_qs.end());
@@ -902,8 +926,6 @@ bwd_result fmha_bwd_run(mode_enum mode,
dq_buf.ToDevice(dq_host.data());
dk_buf.ToDevice(dk_host.data());
dv_buf.ToDevice(dv_host.data());
// re-initialize workspace for validation run
launcher.prepare_workspace(ws_buf.GetDeviceBuffer());
o_buf.ToDevice(o_host.data());
lse_buf.ToDevice(lse_host.data());
@@ -912,6 +934,15 @@ bwd_result fmha_bwd_run(mode_enum mode,
d_sink_buf.SetZero();
ck_tile::stream_config stream_config_v{nullptr, true, 0, 0, 1};
// re-initialize workspace for validation run
launcher.prepare_workspace_async(
ws_buf.GetDeviceBuffer(),
(mode == mode_enum::group) ? static_cast<const int*>(seqstart_q.GetDeviceBuffer())
: nullptr,
(mode == mode_enum::group) ? static_cast<const int*>(seqstart_k.GetDeviceBuffer())
: nullptr,
stream_config_v,
pinned_host_alloc);
launcher(fmha_args, stream_config_v);
dq_buf.FromDevice(dq_host.data());

View File

@@ -0,0 +1,77 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <condition_variable>
#include <hip/hip_runtime.h>
#include <mutex>
#include <queue>
#include <thread>
namespace ck_tile {
// Defers hipHostFree off the HIP callback path. HIP callbacks hold runtime
// locks, so calling hipHostFree (or any HIP API) from one deadlocks against
// concurrent main-thread hipFree. enqueue() is HIP-API-free; a worker thread
// drains the queue and calls hipHostFree. Use instance() for a process-wide
// shared worker.
class pinned_host_releaser
{
    std::mutex mtx_;
    std::condition_variable cv_;
    std::queue<void*> q_;
    bool stop_ = false;
    // NOTE: worker_ MUST be the last member. Members are initialized in
    // declaration order, so the thread it launches must not start before
    // mtx_, cv_, q_ and stop_ are fully constructed; declaring it earlier
    // would let run() observe an uninitialized stop_ (undefined behavior,
    // and a truthy garbage value makes the worker exit immediately and
    // leak every pointer ever enqueued).
    std::thread worker_;

    // Worker loop: block until stopped or work arrives; drain and
    // hipHostFree queued pointers. Exits only when stop_ is set AND the
    // queue is empty, so pending frees are flushed before shutdown.
    void run()
    {
        for(;;)
        {
            void* p = nullptr;
            {
                std::unique_lock<std::mutex> lk(mtx_);
                cv_.wait(lk, [&] { return stop_ || !q_.empty(); });
                if(q_.empty())
                    return; // stop_ && empty
                p = q_.front();
                q_.pop();
            }
            // Free outside the lock so enqueue() never waits on a HIP call.
            (void)hipHostFree(p);
        }
    }

    public:
    pinned_host_releaser() : worker_([this] { run(); }) {}
    /// Signals the worker to finish draining and joins it.
    ~pinned_host_releaser()
    {
        {
            std::lock_guard<std::mutex> lk(mtx_);
            stop_ = true;
        }
        cv_.notify_all();
        if(worker_.joinable())
            worker_.join();
    }
    pinned_host_releaser(const pinned_host_releaser&) = delete;
    pinned_host_releaser& operator=(const pinned_host_releaser&) = delete;
    /// Process-wide shared worker instance.
    static pinned_host_releaser& instance()
    {
        static pinned_host_releaser r;
        return r;
    }
    /// HIP-API-free; safe to call from a HIP host-function callback.
    void enqueue(void* p)
    {
        {
            std::lock_guard<std::mutex> lk(mtx_);
            q_.push(p);
        }
        cv_.notify_one();
    }
};
} // namespace ck_tile

View File

@@ -166,6 +166,47 @@ struct FmhaBwdWorkspaceManager
// In these cases we need to zero out it first
return kHasMask;
}
// Mirrors PrepareWorkspaceHost's return value but uses worst-case totals so
// device workspace can be pre-allocated before host has the seqstart values.
// Returns the bound in bytes; max_seqlen_q/max_seqlen_k bound the per-batch
// physical sequence lengths.
template <bool kUseQrQtrDorPipeline, index_t kN0>
CK_TILE_HOST static size_t GetWorkspaceDeviceSizeUpperBound(index_t max_batch,
index_t hdim_q,
index_t nhead_q,
index_t max_seqlen_q,
index_t max_seqlen_k)
{
// This pipeline reports no device workspace at all.
if constexpr(kUseQrQtrDorPipeline)
return 0;
// Non-deterministic: one AccDataType dq_acc buffer covering
// [max_batch, nhead_q, max_seqlen_q, hdim_q].
if constexpr(!kIsDeterministic)
{
return sizeof(AccDataType) * static_cast<long_index_t>(max_batch) * nhead_q *
max_seqlen_q * hdim_q;
}
// Deterministic group mode: one split per kN0-wide K tile, worst case
// ceil(max_seqlen_k / kN0) splits per (batch, head).
else if constexpr(kIsGroupMode)
{
const index_t nsplits_max = integer_divide_ceil(max_seqlen_k, kN0);
return sizeof(AccDataType) * static_cast<long_index_t>(max_batch) * nhead_q *
nsplits_max * max_seqlen_q * hdim_q;
}
else // deterministic non-group mode (kUsePersistent)
{
// Persistent scheduling: batch*head*K-tile jobs are spread over the
// available CUs; the split count per head follows from how evenly the
// jobs divide across workers.
const index_t dqdqkdv_workers = get_num_cus();
const index_t jobs_per_head = integer_divide_ceil(max_seqlen_k, kN0);
const index_t total_jobs = max_batch * nhead_q * jobs_per_head;
const index_t jobs_per_worker = integer_divide_ceil(total_jobs, dqdqkdv_workers);
// Split count mirrors PrepareWorkspaceHost's persistent-mode formula.
// NOTE(review): confirm the remainder case (last branch) upper-bounds
// every actual-seqlen configuration, not just the max-seqlen one.
index_t nsplits;
if(jobs_per_head % jobs_per_worker == 0)
nsplits = jobs_per_head / jobs_per_worker;
else if(jobs_per_worker % jobs_per_head == 0)
nsplits = 1;
else
nsplits = 1 + integer_divide_ceil(jobs_per_head - 1, jobs_per_worker);
return sizeof(AccDataType) * static_cast<long_index_t>(max_batch) * nhead_q * nsplits *
max_seqlen_q * hdim_q;
}
}
};
template <typename FmhaPipeline_,
@@ -281,6 +322,13 @@ struct FmhaBwdDQDKDVKernel
FmhaPipeline::BlockFmhaShape::kN0>(
std::forward<Args>(args)...);
}
// Forwards to WorkspaceManager::GetWorkspaceDeviceSizeUpperBound, supplying
// this kernel's pipeline selection and N0 tile size; the forwarded arguments
// are (max_batch, hdim_q, nhead_q, max_seqlen_q, max_seqlen_k).
template <typename... Args>
CK_TILE_HOST static constexpr auto GetWorkspaceDeviceSizeUpperBound(Args&&... args)
{
return WorkspaceManager::template GetWorkspaceDeviceSizeUpperBound<
kUseQrQtrDorPipeline,
FmhaPipeline::BlockFmhaShape::kN0>(std::forward<Args>(args)...);
}
CK_TILE_HOST static constexpr bool NeedsZeroDqAcc()
{
return WorkspaceManager::template NeedsZeroDqAcc<kUseQrQtrDorPipeline, kHasMask>();