mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
[CK_TILE] Add async workspace prepare to FMHA BWD launcher
This commit is contained in:
@@ -175,6 +175,16 @@ size_t fmha_bwd_dq_dk_dv_dq_ws_host_size_<dq_dk_dv_trait_{F_idx}, {F_arch.tag}>(
|
||||
return k_::GetWorkspaceHostSize(batch_size);
|
||||
}}
|
||||
|
||||
template <>
|
||||
size_t fmha_bwd_dq_dk_dv_dq_ws_device_upper_bound_<dq_dk_dv_trait_{F_idx}, {F_arch.tag}>(
|
||||
ck_tile::index_t max_batch, ck_tile::index_t hdim_q, ck_tile::index_t nhead_q,
|
||||
ck_tile::index_t max_seqlen_q, ck_tile::index_t max_seqlen_k)
|
||||
{{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
|
||||
return k_::GetWorkspaceDeviceSizeUpperBound(
|
||||
max_batch, hdim_q, nhead_q, max_seqlen_q, max_seqlen_k);
|
||||
}}
|
||||
|
||||
template <>
|
||||
size_t fmha_bwd_dq_dk_dv_dq_prepare_ws_host_<dq_dk_dv_trait_{F_idx}, {F_arch.tag}>(
|
||||
void* cpu_ws, ck_tile::index_t batch_size, ck_tile::index_t hdim_q,
|
||||
|
||||
@@ -470,6 +470,12 @@ struct fmha_bwd_traits;
|
||||
// Declaration only: size of the host-prepared workspace for `batch_size`.
// The per-trait specialization visible in this change returns the selected
// kernel's GetWorkspaceHostSize(batch_size).
template <typename Traits_, typename Arch = void>
|
||||
size_t fmha_bwd_dq_dk_dv_dq_ws_host_size_(int batch_size);
|
||||
// Declaration only: upper bound on the device workspace size computed from
// maximum batch / sequence lengths rather than actual seqstart values.
// The per-trait specialization visible in this change forwards all five
// arguments to the selected kernel's GetWorkspaceDeviceSizeUpperBound.
template <typename Traits_, typename Arch = void>
|
||||
size_t fmha_bwd_dq_dk_dv_dq_ws_device_upper_bound_(ck_tile::index_t max_batch,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t nhead_q,
|
||||
ck_tile::index_t max_seqlen_q,
|
||||
ck_tile::index_t max_seqlen_k);
|
||||
template <typename Traits_, typename Arch = void>
|
||||
size_t fmha_bwd_dq_dk_dv_dq_prepare_ws_host_(void* cpu_ws,
|
||||
ck_tile::index_t batch_size,
|
||||
ck_tile::index_t hdim_q,
|
||||
@@ -539,11 +545,6 @@ struct fmha_bwd_traits
|
||||
bool has_dropout;
|
||||
bool is_store_randval;
|
||||
bool is_deterministic;
|
||||
// Raw pointers for group mode: cumulative physical seqlen arrays of length batch+1.
|
||||
// Only need to remain valid during fmha_bwd_launcher construction (i.e. through
|
||||
// PrepareWorkspaceHost); they are not retained afterward.
|
||||
const int* seqstart_qs = nullptr;
|
||||
const int* seqstart_ks = nullptr;
|
||||
// TODO: padding check is inside this api
|
||||
};
|
||||
|
||||
@@ -589,53 +590,141 @@ struct fmha_bwd_launcher
|
||||
std::cerr << "fmha_bwd: no kernel found for given traits, skipping run\n";
|
||||
return -1.0f;
|
||||
}};
|
||||
// Layout: [host_ws_size_ bytes (host-prepared metadata)][dq_acc region]
|
||||
size_t workspace_size = 0;
|
||||
std::function<void(void*)> prepare_workspace{[](void*) {
|
||||
std::cerr << "fmha_bwd: no kernel found for given traits, skipping prepare_workspace\n";
|
||||
}};
|
||||
|
||||
fmha_bwd_launcher(const fmha_bwd_traits&);
|
||||
fmha_bwd_launcher(fmha_bwd_launcher&&) = delete;
|
||||
fmha_bwd_launcher& operator=(fmha_bwd_launcher&&) = delete;
|
||||
|
||||
~fmha_bwd_launcher() noexcept { schedule_pin_staging_release(); }
|
||||
|
||||
// Stream-async: zero dq_acc, D2H seqstart, host-pack metadata, H2D into device_ws.
|
||||
// `pinned_host_alloc` returns a shared_ptr to a pinned host buffer; its deleter
|
||||
// is invoked on the stream tail after the H2D completes.
|
||||
void prepare_workspace_async( //
|
||||
void* device_ws_ptr,
|
||||
const int* seqstart_q_dev,
|
||||
const int* seqstart_k_dev,
|
||||
const ck_tile::stream_config& s,
|
||||
const std::function<std::shared_ptr<void>(size_t)>& pinned_host_alloc)
|
||||
{
|
||||
hipStream_t stream = s.stream_id_;
|
||||
if(needs_zero_dq_acc_ && workspace_size > host_ws_size_)
|
||||
HIP_CHECK_ERROR(hipMemsetAsync(static_cast<char*>(device_ws_ptr) + host_ws_size_,
|
||||
0,
|
||||
workspace_size - host_ws_size_,
|
||||
stream));
|
||||
|
||||
if(host_ws_size_ == 0)
|
||||
return;
|
||||
|
||||
if(!pinned_host_alloc)
|
||||
throw std::runtime_error(
|
||||
"fmha_bwd_launcher::prepare_workspace_async: pinned_host_alloc is required");
|
||||
|
||||
const size_t seqstart_bytes = traits_.is_group_mode ? sizeof(int) * (traits_.batch + 1) : 0;
|
||||
const size_t total_bytes = 2 * seqstart_bytes + host_ws_size_;
|
||||
auto pin_base = pinned_host_alloc(total_bytes);
|
||||
|
||||
char* base = static_cast<char*>(pin_base.get());
|
||||
int* pin_q = reinterpret_cast<int*>(base);
|
||||
int* pin_k = reinterpret_cast<int*>(base + seqstart_bytes);
|
||||
void* pin_w = base + 2 * seqstart_bytes;
|
||||
const int* seqstart_q_pinned = traits_.is_group_mode ? pin_q : nullptr;
|
||||
const int* seqstart_k_pinned = traits_.is_group_mode ? pin_k : nullptr;
|
||||
|
||||
if(traits_.is_group_mode)
|
||||
{
|
||||
HIP_CHECK_ERROR(hipMemcpyAsync(
|
||||
pin_q, seqstart_q_dev, seqstart_bytes, hipMemcpyDeviceToHost, stream));
|
||||
HIP_CHECK_ERROR(hipMemcpyAsync(
|
||||
pin_k, seqstart_k_dev, seqstart_bytes, hipMemcpyDeviceToHost, stream));
|
||||
}
|
||||
|
||||
auto* pack_closure = new std::function<void()>(
|
||||
[=, fn = pack_workspace_host_]() { fn(pin_w, seqstart_q_pinned, seqstart_k_pinned); });
|
||||
HIP_CHECK_ERROR(hipLaunchHostFunc(
|
||||
stream,
|
||||
[](void* ud) {
|
||||
auto* c = static_cast<std::function<void()>*>(ud);
|
||||
(*c)();
|
||||
delete c;
|
||||
},
|
||||
pack_closure));
|
||||
|
||||
HIP_CHECK_ERROR(
|
||||
hipMemcpyAsync(device_ws_ptr, pin_w, host_ws_size_, hipMemcpyHostToDevice, stream));
|
||||
|
||||
// Release any previous in-flight buffer before taking a new one.
|
||||
schedule_pin_staging_release();
|
||||
pin_staging_ = std::move(pin_base);
|
||||
release_stream_ = stream;
|
||||
}
|
||||
|
||||
private:
|
||||
size_t host_ws_size = 0;
|
||||
size_t device_ws_size = 0;
|
||||
std::unique_ptr<char[]> ws_host;
|
||||
fmha_bwd_traits traits_{};
|
||||
size_t host_ws_size_ = 0;
|
||||
bool needs_zero_dq_acc_ = false;
|
||||
// Pure CPU; safe to invoke from a hipLaunchHostFunc callback.
|
||||
std::function<void(void* host_ws, const int* seqstart_q, const int* seqstart_k)>
|
||||
pack_workspace_host_{[](void*, const int*, const int*) {
|
||||
std::cerr
|
||||
<< "fmha_bwd: no kernel found for given traits, skipping pack_workspace_host\n";
|
||||
}};
|
||||
std::shared_ptr<void> pin_staging_;
|
||||
hipStream_t release_stream_ = nullptr;
|
||||
|
||||
// The pin_staging_ deleter MUST NOT call any HIP API: it fires from the
|
||||
// hipLaunchHostFunc callback on the driver helper thread, which holds
|
||||
// runtime locks (would deadlock against main-thread hipFree). PyTorch's
|
||||
// CachingHostAllocator is safe; bare hipHostMalloc users should defer
|
||||
// hipHostFree via ck_tile::pinned_host_releaser.
|
||||
void schedule_pin_staging_release() noexcept
|
||||
{
|
||||
if(!pin_staging_)
|
||||
return;
|
||||
auto* heap_ref = new std::shared_ptr<void>(pin_staging_);
|
||||
const hipError_t err = hipLaunchHostFunc(
|
||||
release_stream_,
|
||||
[](void* ud) { delete static_cast<std::shared_ptr<void>*>(ud); },
|
||||
heap_ref);
|
||||
if(err != hipSuccess)
|
||||
{
|
||||
std::cerr << "fmha_bwd_launcher: hipLaunchHostFunc failed: " << hipGetErrorString(err)
|
||||
<< "; releasing eagerly\n";
|
||||
delete heap_ref;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T0 /*dot_do_o_trait*/,
|
||||
typename T1 /*dq_dk_dv_trait*/,
|
||||
typename T2 /*convert_dq_trait*/,
|
||||
typename Arch>
|
||||
void init(const fmha_bwd_traits& traits)
|
||||
void init(const fmha_bwd_traits& t)
|
||||
{
|
||||
run = [](fmha_bwd_args a, const ck_tile::stream_config& s) {
|
||||
traits_ = t;
|
||||
run = [](fmha_bwd_args a, const ck_tile::stream_config& s) {
|
||||
return fmha_bwd_<T0, T1, T2, Arch>(s, a);
|
||||
};
|
||||
host_ws_size = fmha_bwd_dq_dk_dv_dq_ws_host_size_<T1, Arch>(traits.batch);
|
||||
if(host_ws_size > 0)
|
||||
host_ws_size_ = fmha_bwd_dq_dk_dv_dq_ws_host_size_<T1, Arch>(t.batch);
|
||||
size_t device_ws_size = 0;
|
||||
if(host_ws_size_ > 0)
|
||||
{
|
||||
ws_host = std::make_unique<char[]>(host_ws_size); // TODO: support host mem allocator
|
||||
device_ws_size = fmha_bwd_dq_dk_dv_dq_prepare_ws_host_<T1, Arch>( //
|
||||
ws_host.get(),
|
||||
traits.batch,
|
||||
traits.hdim_q,
|
||||
traits.nhead_q,
|
||||
traits.seqlen_q,
|
||||
traits.seqlen_k,
|
||||
traits.seqstart_qs,
|
||||
traits.seqstart_ks);
|
||||
device_ws_size = fmha_bwd_dq_dk_dv_dq_ws_device_upper_bound_<T1, Arch>(
|
||||
t.batch, t.hdim_q, t.nhead_q, t.max_seqlen_q, t.max_seqlen_k);
|
||||
pack_workspace_host_ = [batch = t.batch,
|
||||
hdim_q = t.hdim_q,
|
||||
nhead_q = t.nhead_q,
|
||||
seqlen_q = t.seqlen_q,
|
||||
seqlen_k = t.seqlen_k //
|
||||
](void* host_ws, const int* seqstart_q, const int* seqstart_k) {
|
||||
fmha_bwd_dq_dk_dv_dq_prepare_ws_host_<T1, Arch>(
|
||||
host_ws, batch, hdim_q, nhead_q, seqlen_q, seqlen_k, seqstart_q, seqstart_k);
|
||||
};
|
||||
}
|
||||
workspace_size = host_ws_size + device_ws_size;
|
||||
const bool needs_zero_dq_acc = fmha_bwd_dq_dk_dv_needs_zero_dq_acc_<T1, Arch>();
|
||||
prepare_workspace = [this, needs_zero_dq_acc](void* device_ws) {
|
||||
if(host_ws_size > 0)
|
||||
HIP_CHECK_ERROR(
|
||||
hipMemcpy(device_ws, ws_host.get(), host_ws_size, hipMemcpyHostToDevice));
|
||||
if(needs_zero_dq_acc)
|
||||
HIP_CHECK_ERROR(
|
||||
hipMemset(static_cast<char*>(device_ws) + host_ws_size, 0, device_ws_size));
|
||||
};
|
||||
workspace_size = host_ws_size_ + device_ws_size;
|
||||
needs_zero_dq_acc_ = fmha_bwd_dq_dk_dv_needs_zero_dq_acc_<T1, Arch>();
|
||||
}
|
||||
|
||||
public:
|
||||
|
||||
@@ -8,10 +8,13 @@
|
||||
#include "utils.hpp"
|
||||
#include "ck_tile/utility/json_dump.hpp"
|
||||
|
||||
#include "ck_tile/host/pinned_host_releaser.hpp"
|
||||
|
||||
#include <array>
|
||||
#include <chrono>
|
||||
#include <cstring>
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <numeric>
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
@@ -391,17 +394,39 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
p_drop > 0.0f,
|
||||
s_randval,
|
||||
deterministic,
|
||||
(mode == mode_enum::group) ? seqstart_q_host.data() : nullptr,
|
||||
(mode == mode_enum::group) ? seqstart_k_host.data() : nullptr,
|
||||
});
|
||||
const auto t1_launcher = std::chrono::high_resolution_clock::now();
|
||||
const double launcher_ctor_ms =
|
||||
std::chrono::duration<double, std::milli>(t1_launcher - t0_launcher).count();
|
||||
const size_t ws_size = launcher.workspace_size;
|
||||
ck_tile::DeviceMem ws_buf(ws_size);
|
||||
|
||||
// Stage seqstart to device before prepare_workspace_async (which D2Hs it back).
|
||||
seqstart_q.ToDevice(seqstart_q_host.data());
|
||||
seqstart_k.ToDevice(seqstart_k_host.data());
|
||||
|
||||
// Pinned host allocator for the launcher's async prepare pipeline. The
|
||||
// shared_ptr deleter MUST NOT call any HIP API: it runs from the launcher's
|
||||
// tail hipLaunchHostFunc on the driver helper thread, which holds HIP
|
||||
// runtime locks. Deleter enqueues to a worker thread that hipHostFrees off
|
||||
// the callback path.
|
||||
auto pinned_host_alloc = [](size_t bytes) -> std::shared_ptr<void> {
|
||||
void* p = nullptr;
|
||||
HIP_CHECK_ERROR(hipHostMalloc(&p, bytes, hipHostMallocDefault));
|
||||
return std::shared_ptr<void>(
|
||||
p, [](void* q) { ck_tile::pinned_host_releaser::instance().enqueue(q); });
|
||||
};
|
||||
|
||||
ck_tile::gpu_timer prepare_ws_timer;
|
||||
prepare_ws_timer.start(nullptr);
|
||||
launcher.prepare_workspace(ws_buf.GetDeviceBuffer());
|
||||
launcher.prepare_workspace_async(
|
||||
ws_buf.GetDeviceBuffer(),
|
||||
(mode == mode_enum::group) ? static_cast<const int*>(seqstart_q.GetDeviceBuffer())
|
||||
: nullptr,
|
||||
(mode == mode_enum::group) ? static_cast<const int*>(seqstart_k.GetDeviceBuffer())
|
||||
: nullptr,
|
||||
stream_config,
|
||||
pinned_host_alloc);
|
||||
prepare_ws_timer.stop(nullptr);
|
||||
|
||||
q_buf.ToDevice(q_host.data());
|
||||
@@ -409,8 +434,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
v_buf.ToDevice(v_host.data());
|
||||
bias_buf.ToDevice(bias_host.data());
|
||||
do_buf.ToDevice(do_host.data());
|
||||
seqstart_q.ToDevice(seqstart_q_host.data());
|
||||
seqstart_k.ToDevice(seqstart_k_host.data());
|
||||
// seqstart_q/k were already ToDevice'd above before prepare_workspace_async.
|
||||
if(mode == mode_enum::group)
|
||||
{
|
||||
std::vector<int32_t> seqlen_q_host(seqlen_qs.begin(), seqlen_qs.end());
|
||||
@@ -902,8 +926,6 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
dq_buf.ToDevice(dq_host.data());
|
||||
dk_buf.ToDevice(dk_host.data());
|
||||
dv_buf.ToDevice(dv_host.data());
|
||||
// re-initialize workspace for validation run
|
||||
launcher.prepare_workspace(ws_buf.GetDeviceBuffer());
|
||||
|
||||
o_buf.ToDevice(o_host.data());
|
||||
lse_buf.ToDevice(lse_host.data());
|
||||
@@ -912,6 +934,15 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
d_sink_buf.SetZero();
|
||||
|
||||
ck_tile::stream_config stream_config_v{nullptr, true, 0, 0, 1};
|
||||
// re-initialize workspace for validation run
|
||||
launcher.prepare_workspace_async(
|
||||
ws_buf.GetDeviceBuffer(),
|
||||
(mode == mode_enum::group) ? static_cast<const int*>(seqstart_q.GetDeviceBuffer())
|
||||
: nullptr,
|
||||
(mode == mode_enum::group) ? static_cast<const int*>(seqstart_k.GetDeviceBuffer())
|
||||
: nullptr,
|
||||
stream_config_v,
|
||||
pinned_host_alloc);
|
||||
launcher(fmha_args, stream_config_v);
|
||||
|
||||
dq_buf.FromDevice(dq_host.data());
|
||||
|
||||
77
include/ck_tile/host/pinned_host_releaser.hpp
Normal file
77
include/ck_tile/host/pinned_host_releaser.hpp
Normal file
@@ -0,0 +1,77 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <condition_variable>
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <mutex>
|
||||
#include <queue>
|
||||
#include <thread>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Defers hipHostFree off the HIP callback path. HIP callbacks hold runtime
|
||||
// locks, so calling hipHostFree (or any HIP API) from one deadlocks against
|
||||
// concurrent main-thread hipFree. enqueue() is HIP-API-free; a worker thread
|
||||
// drains the queue and calls hipHostFree. Use instance() for a process-wide
|
||||
// shared worker.
|
||||
class pinned_host_releaser
|
||||
{
|
||||
std::mutex mtx_;
|
||||
std::condition_variable cv_;
|
||||
std::queue<void*> q_;
|
||||
std::thread worker_;
|
||||
bool stop_ = false;
|
||||
|
||||
void run()
|
||||
{
|
||||
for(;;)
|
||||
{
|
||||
void* p = nullptr;
|
||||
{
|
||||
std::unique_lock<std::mutex> lk(mtx_);
|
||||
cv_.wait(lk, [&] { return stop_ || !q_.empty(); });
|
||||
if(q_.empty())
|
||||
return; // stop_ && empty
|
||||
p = q_.front();
|
||||
q_.pop();
|
||||
}
|
||||
(void)hipHostFree(p);
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
pinned_host_releaser() : worker_([this] { run(); }) {}
|
||||
|
||||
~pinned_host_releaser()
|
||||
{
|
||||
{
|
||||
std::lock_guard<std::mutex> lk(mtx_);
|
||||
stop_ = true;
|
||||
}
|
||||
cv_.notify_all();
|
||||
if(worker_.joinable())
|
||||
worker_.join();
|
||||
}
|
||||
|
||||
pinned_host_releaser(const pinned_host_releaser&) = delete;
|
||||
pinned_host_releaser& operator=(const pinned_host_releaser&) = delete;
|
||||
|
||||
static pinned_host_releaser& instance()
|
||||
{
|
||||
static pinned_host_releaser r;
|
||||
return r;
|
||||
}
|
||||
|
||||
void enqueue(void* p)
|
||||
{
|
||||
{
|
||||
std::lock_guard<std::mutex> lk(mtx_);
|
||||
q_.push(p);
|
||||
}
|
||||
cv_.notify_one();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -166,6 +166,47 @@ struct FmhaBwdWorkspaceManager
|
||||
// In these cases we need to zero out it first
|
||||
return kHasMask;
|
||||
}
|
||||
|
||||
// Mirrors PrepareWorkspaceHost's return value but uses worst-case totals so
|
||||
// device workspace can be pre-allocated before host has the seqstart values.
|
||||
template <bool kUseQrQtrDorPipeline, index_t kN0>
|
||||
CK_TILE_HOST static size_t GetWorkspaceDeviceSizeUpperBound(index_t max_batch,
|
||||
index_t hdim_q,
|
||||
index_t nhead_q,
|
||||
index_t max_seqlen_q,
|
||||
index_t max_seqlen_k)
|
||||
{
|
||||
if constexpr(kUseQrQtrDorPipeline)
|
||||
return 0;
|
||||
|
||||
if constexpr(!kIsDeterministic)
|
||||
{
|
||||
return sizeof(AccDataType) * static_cast<long_index_t>(max_batch) * nhead_q *
|
||||
max_seqlen_q * hdim_q;
|
||||
}
|
||||
else if constexpr(kIsGroupMode)
|
||||
{
|
||||
const index_t nsplits_max = integer_divide_ceil(max_seqlen_k, kN0);
|
||||
return sizeof(AccDataType) * static_cast<long_index_t>(max_batch) * nhead_q *
|
||||
nsplits_max * max_seqlen_q * hdim_q;
|
||||
}
|
||||
else // deterministic non-group mode (kUsePersistent)
|
||||
{
|
||||
const index_t dqdqkdv_workers = get_num_cus();
|
||||
const index_t jobs_per_head = integer_divide_ceil(max_seqlen_k, kN0);
|
||||
const index_t total_jobs = max_batch * nhead_q * jobs_per_head;
|
||||
const index_t jobs_per_worker = integer_divide_ceil(total_jobs, dqdqkdv_workers);
|
||||
index_t nsplits;
|
||||
if(jobs_per_head % jobs_per_worker == 0)
|
||||
nsplits = jobs_per_head / jobs_per_worker;
|
||||
else if(jobs_per_worker % jobs_per_head == 0)
|
||||
nsplits = 1;
|
||||
else
|
||||
nsplits = 1 + integer_divide_ceil(jobs_per_head - 1, jobs_per_worker);
|
||||
return sizeof(AccDataType) * static_cast<long_index_t>(max_batch) * nhead_q * nsplits *
|
||||
max_seqlen_q * hdim_q;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename FmhaPipeline_,
|
||||
@@ -281,6 +322,13 @@ struct FmhaBwdDQDKDVKernel
|
||||
FmhaPipeline::BlockFmhaShape::kN0>(
|
||||
std::forward<Args>(args)...);
|
||||
}
|
||||
template <typename... Args>
|
||||
CK_TILE_HOST static constexpr auto GetWorkspaceDeviceSizeUpperBound(Args&&... args)
|
||||
{
|
||||
return WorkspaceManager::template GetWorkspaceDeviceSizeUpperBound<
|
||||
kUseQrQtrDorPipeline,
|
||||
FmhaPipeline::BlockFmhaShape::kN0>(std::forward<Args>(args)...);
|
||||
}
|
||||
CK_TILE_HOST static constexpr bool NeedsZeroDqAcc()
|
||||
{
|
||||
return WorkspaceManager::template NeedsZeroDqAcc<kUseQrQtrDorPipeline, kHasMask>();
|
||||
|
||||
Reference in New Issue
Block a user