diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index edcae41b49..d6493eb533 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -175,6 +175,16 @@ size_t fmha_bwd_dq_dk_dv_dq_ws_host_size_( return k_::GetWorkspaceHostSize(batch_size); }} +template <> +size_t fmha_bwd_dq_dk_dv_dq_ws_device_upper_bound_( + ck_tile::index_t max_batch, ck_tile::index_t hdim_q, ck_tile::index_t nhead_q, + ck_tile::index_t max_seqlen_q, ck_tile::index_t max_seqlen_k) +{{ + using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx}; + return k_::GetWorkspaceDeviceSizeUpperBound( + max_batch, hdim_q, nhead_q, max_seqlen_q, max_seqlen_k); +}} + template <> size_t fmha_bwd_dq_dk_dv_dq_prepare_ws_host_( void* cpu_ws, ck_tile::index_t batch_size, ck_tile::index_t hdim_q, diff --git a/example/ck_tile/01_fmha/fmha_bwd.hpp b/example/ck_tile/01_fmha/fmha_bwd.hpp index 14f4c210f0..1cf3581859 100644 --- a/example/ck_tile/01_fmha/fmha_bwd.hpp +++ b/example/ck_tile/01_fmha/fmha_bwd.hpp @@ -470,6 +470,12 @@ struct fmha_bwd_traits; template size_t fmha_bwd_dq_dk_dv_dq_ws_host_size_(int batch_size); template +size_t fmha_bwd_dq_dk_dv_dq_ws_device_upper_bound_(ck_tile::index_t max_batch, + ck_tile::index_t hdim_q, + ck_tile::index_t nhead_q, + ck_tile::index_t max_seqlen_q, + ck_tile::index_t max_seqlen_k); +template size_t fmha_bwd_dq_dk_dv_dq_prepare_ws_host_(void* cpu_ws, ck_tile::index_t batch_size, ck_tile::index_t hdim_q, @@ -539,11 +545,6 @@ struct fmha_bwd_traits bool has_dropout; bool is_store_randval; bool is_deterministic; - // Raw pointers for group mode: cumulative physical seqlen arrays of length batch+1. - // Only need to remain valid during fmha_bwd_launcher construction (i.e. through - // PrepareWorkspaceHost); they are not retained afterward. - const int* seqstart_qs = nullptr; - const int* seqstart_ks = nullptr; // TODO: padding check is inside this api }; @@ -589,53 +590,141 @@ struct fmha_bwd_launcher std::cerr << "fmha_bwd: no kernel found for given traits, skipping run\n"; return -1.0f; }}; + // Layout: [host_ws_size_ bytes (host-prepared metadata)][dq_acc region] size_t workspace_size = 0; - std::function prepare_workspace{[](void*) { - std::cerr << "fmha_bwd: no kernel found for given traits, skipping prepare_workspace\n"; - }}; fmha_bwd_launcher(const fmha_bwd_traits&); fmha_bwd_launcher(fmha_bwd_launcher&&) = delete; fmha_bwd_launcher& operator=(fmha_bwd_launcher&&) = delete; + ~fmha_bwd_launcher() noexcept { schedule_pin_staging_release(); } + + // Stream-async: zero dq_acc, D2H seqstart, host-pack metadata, H2D into device_ws. + // `pinned_host_alloc` returns a shared_ptr to a pinned host buffer; its deleter + // is invoked on the stream tail after the H2D completes. + void prepare_workspace_async( // + void* device_ws_ptr, + const int* seqstart_q_dev, + const int* seqstart_k_dev, + const ck_tile::stream_config& s, + const std::function(size_t)>& pinned_host_alloc) + { + hipStream_t stream = s.stream_id_; + if(needs_zero_dq_acc_ && workspace_size > host_ws_size_) + HIP_CHECK_ERROR(hipMemsetAsync(static_cast(device_ws_ptr) + host_ws_size_, + 0, + workspace_size - host_ws_size_, + stream)); + + if(host_ws_size_ == 0) + return; + + if(!pinned_host_alloc) + throw std::runtime_error( + "fmha_bwd_launcher::prepare_workspace_async: pinned_host_alloc is required"); + + const size_t seqstart_bytes = traits_.is_group_mode ? sizeof(int) * (traits_.batch + 1) : 0; + const size_t total_bytes = 2 * seqstart_bytes + host_ws_size_; + auto pin_base = pinned_host_alloc(total_bytes); + + char* base = static_cast(pin_base.get()); + int* pin_q = reinterpret_cast(base); + int* pin_k = reinterpret_cast(base + seqstart_bytes); + void* pin_w = base + 2 * seqstart_bytes; + const int* seqstart_q_pinned = traits_.is_group_mode ? pin_q : nullptr; + const int* seqstart_k_pinned = traits_.is_group_mode ? pin_k : nullptr; + + if(traits_.is_group_mode) + { + HIP_CHECK_ERROR(hipMemcpyAsync( + pin_q, seqstart_q_dev, seqstart_bytes, hipMemcpyDeviceToHost, stream)); + HIP_CHECK_ERROR(hipMemcpyAsync( + pin_k, seqstart_k_dev, seqstart_bytes, hipMemcpyDeviceToHost, stream)); + } + + auto* pack_closure = new std::function( + [=, fn = pack_workspace_host_]() { fn(pin_w, seqstart_q_pinned, seqstart_k_pinned); }); + HIP_CHECK_ERROR(hipLaunchHostFunc( + stream, + [](void* ud) { + auto* c = static_cast*>(ud); + (*c)(); + delete c; + }, + pack_closure)); + + HIP_CHECK_ERROR( + hipMemcpyAsync(device_ws_ptr, pin_w, host_ws_size_, hipMemcpyHostToDevice, stream)); + + // Release any previous in-flight buffer before taking a new one. + schedule_pin_staging_release(); + pin_staging_ = std::move(pin_base); + release_stream_ = stream; + } + private: - size_t host_ws_size = 0; - size_t device_ws_size = 0; - std::unique_ptr ws_host; + fmha_bwd_traits traits_{}; + size_t host_ws_size_ = 0; + bool needs_zero_dq_acc_ = false; + // Pure CPU; safe to invoke from a hipLaunchHostFunc callback. + std::function + pack_workspace_host_{[](void*, const int*, const int*) { + std::cerr + << "fmha_bwd: no kernel found for given traits, skipping pack_workspace_host\n"; + }}; + std::shared_ptr pin_staging_; + hipStream_t release_stream_ = nullptr; + + // The pin_staging_ deleter MUST NOT call any HIP API: it fires from the + // hipLaunchHostFunc callback on the driver helper thread, which holds + // runtime locks (would deadlock against main-thread hipFree). PyTorch's + // CachingHostAllocator is safe; bare hipHostMalloc users should defer + // hipHostFree via ck_tile::pinned_host_releaser. + void schedule_pin_staging_release() noexcept + { + if(!pin_staging_) + return; + auto* heap_ref = new std::shared_ptr(pin_staging_); + const hipError_t err = hipLaunchHostFunc( + release_stream_, + [](void* ud) { delete static_cast*>(ud); }, + heap_ref); + if(err != hipSuccess) + { + std::cerr << "fmha_bwd_launcher: hipLaunchHostFunc failed: " << hipGetErrorString(err) + << "; releasing eagerly\n"; + delete heap_ref; + } + } template - void init(const fmha_bwd_traits& traits) + void init(const fmha_bwd_traits& t) { - run = [](fmha_bwd_args a, const ck_tile::stream_config& s) { + traits_ = t; + run = [](fmha_bwd_args a, const ck_tile::stream_config& s) { return fmha_bwd_(s, a); }; - host_ws_size = fmha_bwd_dq_dk_dv_dq_ws_host_size_(traits.batch); - if(host_ws_size > 0) + host_ws_size_ = fmha_bwd_dq_dk_dv_dq_ws_host_size_(t.batch); + size_t device_ws_size = 0; + if(host_ws_size_ > 0) { - ws_host = std::make_unique(host_ws_size); // TODO: support host mem allocator - device_ws_size = fmha_bwd_dq_dk_dv_dq_prepare_ws_host_( // - ws_host.get(), - traits.batch, - traits.hdim_q, - traits.nhead_q, - traits.seqlen_q, - traits.seqlen_k, - traits.seqstart_qs, - traits.seqstart_ks); + device_ws_size = fmha_bwd_dq_dk_dv_dq_ws_device_upper_bound_( + t.batch, t.hdim_q, t.nhead_q, t.max_seqlen_q, t.max_seqlen_k); + pack_workspace_host_ = [batch = t.batch, + hdim_q = t.hdim_q, + nhead_q = t.nhead_q, + seqlen_q = t.seqlen_q, + seqlen_k = t.seqlen_k // + ](void* host_ws, const int* seqstart_q, const int* seqstart_k) { + fmha_bwd_dq_dk_dv_dq_prepare_ws_host_( + host_ws, batch, hdim_q, nhead_q, seqlen_q, seqlen_k, seqstart_q, seqstart_k); + }; } - workspace_size = host_ws_size + device_ws_size; - const bool needs_zero_dq_acc = fmha_bwd_dq_dk_dv_needs_zero_dq_acc_(); - prepare_workspace = [this, needs_zero_dq_acc](void* device_ws) { - if(host_ws_size > 0) - HIP_CHECK_ERROR( - hipMemcpy(device_ws, ws_host.get(), host_ws_size, hipMemcpyHostToDevice)); - if(needs_zero_dq_acc) - HIP_CHECK_ERROR( - hipMemset(static_cast(device_ws) + host_ws_size, 0, device_ws_size)); - }; + workspace_size = host_ws_size_ + device_ws_size; + needs_zero_dq_acc_ = fmha_bwd_dq_dk_dv_needs_zero_dq_acc_(); } public: diff --git a/example/ck_tile/01_fmha/fmha_bwd_runner.hpp b/example/ck_tile/01_fmha/fmha_bwd_runner.hpp index f81ae34501..3a328b96d7 100644 --- a/example/ck_tile/01_fmha/fmha_bwd_runner.hpp +++ b/example/ck_tile/01_fmha/fmha_bwd_runner.hpp @@ -8,10 +8,13 @@ #include "utils.hpp" #include "ck_tile/utility/json_dump.hpp" +#include "ck_tile/host/pinned_host_releaser.hpp" + #include #include #include #include +#include #include #include #include @@ -391,17 +394,39 @@ bwd_result fmha_bwd_run(mode_enum mode, p_drop > 0.0f, s_randval, deterministic, - (mode == mode_enum::group) ? seqstart_q_host.data() : nullptr, - (mode == mode_enum::group) ? seqstart_k_host.data() : nullptr, }); const auto t1_launcher = std::chrono::high_resolution_clock::now(); const double launcher_ctor_ms = std::chrono::duration(t1_launcher - t0_launcher).count(); const size_t ws_size = launcher.workspace_size; ck_tile::DeviceMem ws_buf(ws_size); + + // Stage seqstart to device before prepare_workspace_async (which D2Hs it back). + seqstart_q.ToDevice(seqstart_q_host.data()); + seqstart_k.ToDevice(seqstart_k_host.data()); + + // Pinned host allocator for the launcher's async prepare pipeline. The + // shared_ptr deleter MUST NOT call any HIP API: it runs from the launcher's + // tail hipLaunchHostFunc on the driver helper thread, which holds HIP + // runtime locks. Deleter enqueues to a worker thread that hipHostFrees off + // the callback path. + auto pinned_host_alloc = [](size_t bytes) -> std::shared_ptr { + void* p = nullptr; + HIP_CHECK_ERROR(hipHostMalloc(&p, bytes, hipHostMallocDefault)); + return std::shared_ptr( + p, [](void* q) { ck_tile::pinned_host_releaser::instance().enqueue(q); }); + }; + ck_tile::gpu_timer prepare_ws_timer; prepare_ws_timer.start(nullptr); - launcher.prepare_workspace(ws_buf.GetDeviceBuffer()); + launcher.prepare_workspace_async( + ws_buf.GetDeviceBuffer(), + (mode == mode_enum::group) ? static_cast(seqstart_q.GetDeviceBuffer()) + : nullptr, + (mode == mode_enum::group) ? static_cast(seqstart_k.GetDeviceBuffer()) + : nullptr, + stream_config, + pinned_host_alloc); prepare_ws_timer.stop(nullptr); q_buf.ToDevice(q_host.data()); @@ -409,8 +434,7 @@ bwd_result fmha_bwd_run(mode_enum mode, v_buf.ToDevice(v_host.data()); bias_buf.ToDevice(bias_host.data()); do_buf.ToDevice(do_host.data()); - seqstart_q.ToDevice(seqstart_q_host.data()); - seqstart_k.ToDevice(seqstart_k_host.data()); + // seqstart_q/k were already ToDevice'd above before prepare_workspace_async. if(mode == mode_enum::group) { std::vector seqlen_q_host(seqlen_qs.begin(), seqlen_qs.end()); @@ -902,8 +926,6 @@ bwd_result fmha_bwd_run(mode_enum mode, dq_buf.ToDevice(dq_host.data()); dk_buf.ToDevice(dk_host.data()); dv_buf.ToDevice(dv_host.data()); - // re-initialize workspace for validation run - launcher.prepare_workspace(ws_buf.GetDeviceBuffer()); o_buf.ToDevice(o_host.data()); lse_buf.ToDevice(lse_host.data()); @@ -912,6 +934,15 @@ bwd_result fmha_bwd_run(mode_enum mode, d_sink_buf.SetZero(); ck_tile::stream_config stream_config_v{nullptr, true, 0, 0, 1}; + // re-initialize workspace for validation run + launcher.prepare_workspace_async( + ws_buf.GetDeviceBuffer(), + (mode == mode_enum::group) ? static_cast(seqstart_q.GetDeviceBuffer()) + : nullptr, + (mode == mode_enum::group) ? static_cast(seqstart_k.GetDeviceBuffer()) + : nullptr, + stream_config_v, + pinned_host_alloc); launcher(fmha_args, stream_config_v); dq_buf.FromDevice(dq_host.data()); diff --git a/include/ck_tile/host/pinned_host_releaser.hpp b/include/ck_tile/host/pinned_host_releaser.hpp new file mode 100644 index 0000000000..8a24d5b201 --- /dev/null +++ b/include/ck_tile/host/pinned_host_releaser.hpp @@ -0,0 +1,77 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include +#include +#include +#include + +namespace ck_tile { + +// Defers hipHostFree off the HIP callback path. HIP callbacks hold runtime +// locks, so calling hipHostFree (or any HIP API) from one deadlocks against +// concurrent main-thread hipFree. enqueue() is HIP-API-free; a worker thread +// drains the queue and calls hipHostFree. Use instance() for a process-wide +// shared worker. +class pinned_host_releaser +{ + std::mutex mtx_; + std::condition_variable cv_; + std::queue q_; + std::thread worker_; + bool stop_ = false; + + void run() + { + for(;;) + { + void* p = nullptr; + { + std::unique_lock lk(mtx_); + cv_.wait(lk, [&] { return stop_ || !q_.empty(); }); + if(q_.empty()) + return; // stop_ && empty + p = q_.front(); + q_.pop(); + } + (void)hipHostFree(p); + } + } + + public: + pinned_host_releaser() : worker_([this] { run(); }) {} + + ~pinned_host_releaser() + { + { + std::lock_guard lk(mtx_); + stop_ = true; + } + cv_.notify_all(); + if(worker_.joinable()) + worker_.join(); + } + + pinned_host_releaser(const pinned_host_releaser&) = delete; + pinned_host_releaser& operator=(const pinned_host_releaser&) = delete; + + static pinned_host_releaser& instance() + { + static pinned_host_releaser r; + return r; + } + + void enqueue(void* p) + { + { + std::lock_guard lk(mtx_); + q_.push(p); + } + cv_.notify_one(); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp index 23c73e5f43..8716c93bfb 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp @@ -166,6 +166,47 @@ struct FmhaBwdWorkspaceManager // In these cases we need to zero out it first return kHasMask; } + + // Mirrors PrepareWorkspaceHost's return value but uses worst-case totals so + // device workspace can be pre-allocated before host has the seqstart values. + template + CK_TILE_HOST static size_t GetWorkspaceDeviceSizeUpperBound(index_t max_batch, + index_t hdim_q, + index_t nhead_q, + index_t max_seqlen_q, + index_t max_seqlen_k) + { + if constexpr(kUseQrQtrDorPipeline) + return 0; + + if constexpr(!kIsDeterministic) + { + return sizeof(AccDataType) * static_cast(max_batch) * nhead_q * + max_seqlen_q * hdim_q; + } + else if constexpr(kIsGroupMode) + { + const index_t nsplits_max = integer_divide_ceil(max_seqlen_k, kN0); + return sizeof(AccDataType) * static_cast(max_batch) * nhead_q * + nsplits_max * max_seqlen_q * hdim_q; + } + else // deterministic non-group mode (kUsePersistent) + { + const index_t dqdqkdv_workers = get_num_cus(); + const index_t jobs_per_head = integer_divide_ceil(max_seqlen_k, kN0); + const index_t total_jobs = max_batch * nhead_q * jobs_per_head; + const index_t jobs_per_worker = integer_divide_ceil(total_jobs, dqdqkdv_workers); + index_t nsplits; + if(jobs_per_head % jobs_per_worker == 0) + nsplits = jobs_per_head / jobs_per_worker; + else if(jobs_per_worker % jobs_per_head == 0) + nsplits = 1; + else + nsplits = 1 + integer_divide_ceil(jobs_per_head - 1, jobs_per_worker); + return sizeof(AccDataType) * static_cast(max_batch) * nhead_q * nsplits * + max_seqlen_q * hdim_q; + } + } }; template ( std::forward(args)...); } + template + CK_TILE_HOST static constexpr auto GetWorkspaceDeviceSizeUpperBound(Args&&... args) + { + return WorkspaceManager::template GetWorkspaceDeviceSizeUpperBound< + kUseQrQtrDorPipeline, + FmhaPipeline::BlockFmhaShape::kN0>(std::forward(args)...); + } CK_TILE_HOST static constexpr bool NeedsZeroDqAcc() { return WorkspaceManager::template NeedsZeroDqAcc();