From 83566edb0fded5e1c618c2c19110adbb74532762 Mon Sep 17 00:00:00 2001 From: Yi DING <28386673+DDEle@users.noreply.github.com> Date: Thu, 14 May 2026 13:34:32 +0000 Subject: [PATCH] [rocm-libraries] ROCm/rocm-libraries#7331 (commit 5692db0) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit [CK_TILE] Add async workspace prepare to FMHA BWD launcher (#7331) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Motivation `aiter::mha_bwd` in group mode currently issues two synchronous `hipMemcpy` D2H copies to read `seqstart_q/k` for launcher construction. These sync copies block the host (~10–30 µs each) and implicitly synchronize the device by draining the stream, breaking CPU/GPU overlap on hot training paths. This PR adds a fully stream-async workspace preparation path on the FMHA BWD launcher so callers can pre-allocate the device workspace from upper-bound shapes and stage seqstart-dependent metadata via D2H/host-pack/H2D entirely on the user's stream. ## Technical Details - `FmhaBwdWorkspaceManager::GetWorkspaceDeviceSizeUpperBound` (`include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp`): computes the worst-case device dq_acc size from `(max_batch, hdim_q, nhead_q, total_seqlen_q_padded, max_seqlen_k)` without dereferencing any seqstart array, where `total_seqlen_q_padded` is the total number of q tokens across all batches including per-batch padding (batch mode: `max_batch * seqlen_q`; group mode: `seqstart_q[batch]`). Mirrors `PrepareWorkspaceHost`'s return value with worst-case bounds. - `fmha_bwd_launcher::prepare_workspace_async` (`example/ck_tile/01_fmha/fmha_bwd.hpp`): on the caller's stream, in order: 1. `hipMemsetAsync` of the dq_acc region (when `NeedsZeroDqAcc()`) 2. group mode: `hipMemcpyAsync` D2H of `seqstart_q/k` into a pinned host staging buffer 3. `hipLaunchHostFunc` runs `PrepareWorkspaceHost` on the pinned buffer 4. `hipMemcpyAsync` H2D of the packed metadata into `device_ws_ptr` The pinned staging buffer is held via `std::shared_ptr` returned by a caller-provided `pinned_host_alloc` callback. 
Lifetime is extended past stream completion by a tail `hipLaunchHostFunc` that drops the launcher's reference to the buffer; it is scheduled in the launcher's destructor, or by a later `prepare_workspace_async` call just before that call installs a new staging buffer. - `ck_tile::pinned_host_releaser` (`include/ck_tile/host/pinned_host_releaser.hpp`): worker-thread utility for callers using bare `hipHostMalloc`. Defers `hipHostFree` off the HIP driver callback thread, which holds runtime locks and would deadlock against concurrent main-thread `hipFree`. PyTorch's `CachingHostAllocator` does not need this. - Example runner (`example/ck_tile/01_fmha/fmha_bwd_runner.hpp`): switched to the async path. ## Test Plan - `tile_example_fmha_bwd` (gfx950, dev preset `-Werror -Weverything`): - batch + nondet / batch + det / group + nondet / group + det - group + det 4-batch varlen (`-b=4 -h=8 -s=4096,3072,2048,1024 -d=128`) - FA (`flash-attention`) integration on ROCm 7.1.1 + PyTorch 2.9.1: - `tests/test_flash_attn_ck.py::test_flash_attn_varlen_deterministic` - `tests/test_flash_attn_ck.py::test_flash_attn_bwd_varlen_seqq_zero` ## Test Result - All CK runner cases `valid:y`. - FA pytest: **1952 passed in 44.82s**. ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. 
--- .../ck_tile/01_fmha/codegen/ops/fmha_bwd.py | 10 + example/ck_tile/01_fmha/fmha_bwd.hpp | 197 ++++++++++++++---- example/ck_tile/01_fmha/fmha_bwd_runner.hpp | 49 ++++- include/ck_tile/host/pinned_host_releaser.hpp | 77 +++++++ .../ops/fmha/kernel/fmha_bwd_kernel.hpp | 50 +++++ 5 files changed, 339 insertions(+), 44 deletions(-) create mode 100644 include/ck_tile/host/pinned_host_releaser.hpp diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index f89a7d75e4..8da6eed212 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -175,6 +175,16 @@ size_t fmha_bwd_dq_dk_dv_dq_ws_host_size_( return k_::GetWorkspaceHostSize(batch_size); }} +template <> +size_t fmha_bwd_dq_dk_dv_dq_ws_device_upper_bound_( + ck_tile::index_t max_batch, ck_tile::index_t hdim_q, ck_tile::index_t nhead_q, + ck_tile::index_t total_seqlen_q_padded, ck_tile::index_t max_seqlen_k) +{{ + using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx}; + return k_::GetWorkspaceDeviceSizeUpperBound( + max_batch, hdim_q, nhead_q, total_seqlen_q_padded, max_seqlen_k); +}} + template <> size_t fmha_bwd_dq_dk_dv_dq_prepare_ws_host_( void* cpu_ws, ck_tile::index_t batch_size, ck_tile::index_t hdim_q, diff --git a/example/ck_tile/01_fmha/fmha_bwd.hpp b/example/ck_tile/01_fmha/fmha_bwd.hpp index 14f4c210f0..a06e679cde 100644 --- a/example/ck_tile/01_fmha/fmha_bwd.hpp +++ b/example/ck_tile/01_fmha/fmha_bwd.hpp @@ -469,6 +469,15 @@ int fmha_bwd_dq_dk_dv_maxq_(); struct fmha_bwd_traits; template size_t fmha_bwd_dq_dk_dv_dq_ws_host_size_(int batch_size); +// `total_seqlen_q_padded` is total q tokens across all batches (incl. 
per-batch padding): +// - batch mode: max_batch * seqlen_q +// - group mode: seqstart_q[batch] (== varlen q tensor's first dim) +template +size_t fmha_bwd_dq_dk_dv_dq_ws_device_upper_bound_(ck_tile::index_t max_batch, + ck_tile::index_t hdim_q, + ck_tile::index_t nhead_q, + ck_tile::index_t total_seqlen_q_padded, + ck_tile::index_t max_seqlen_k); template size_t fmha_bwd_dq_dk_dv_dq_prepare_ws_host_(void* cpu_ws, ck_tile::index_t batch_size, @@ -539,11 +548,6 @@ struct fmha_bwd_traits bool has_dropout; bool is_store_randval; bool is_deterministic; - // Raw pointers for group mode: cumulative physical seqlen arrays of length batch+1. - // Only need to remain valid during fmha_bwd_launcher construction (i.e. through - // PrepareWorkspaceHost); they are not retained afterward. - const int* seqstart_qs = nullptr; - const int* seqstart_ks = nullptr; // TODO: padding check is inside this api }; @@ -589,53 +593,176 @@ struct fmha_bwd_launcher std::cerr << "fmha_bwd: no kernel found for given traits, skipping run\n"; return -1.0f; }}; + // Layout: [host_ws_size_ bytes (host-prepared metadata)][dq_acc region] size_t workspace_size = 0; - std::function prepare_workspace{[](void*) { - std::cerr << "fmha_bwd: no kernel found for given traits, skipping prepare_workspace\n"; - }}; fmha_bwd_launcher(const fmha_bwd_traits&); fmha_bwd_launcher(fmha_bwd_launcher&&) = delete; fmha_bwd_launcher& operator=(fmha_bwd_launcher&&) = delete; + ~fmha_bwd_launcher() noexcept { schedule_pin_staging_release(); } + + // Stream-async: zero dq_acc, D2H seqstart, host-pack metadata, H2D into device_ws. + // `pinned_host_alloc` returns a shared_ptr to a pinned host buffer; its deleter + // is invoked on the stream tail after the H2D completes. 
+ void prepare_workspace_async( // + void* device_ws_ptr, + const int* seqstart_q_dev, + const int* seqstart_k_dev, + const ck_tile::stream_config& s, + const std::function(size_t)>& pinned_host_alloc) + { + hipStream_t stream = s.stream_id_; + + // Fast path: no host-side metadata to stage; just zero dq_acc if needed. + if(host_ws_size_ == 0) + { + if(needs_zero_dq_acc_ && workspace_size > 0) + HIP_CHECK_ERROR(hipMemsetAsync(device_ws_ptr, 0, workspace_size, stream)); + return; + } + + if(!pinned_host_alloc) + throw std::runtime_error( + "fmha_bwd_launcher::prepare_workspace_async: pinned_host_alloc is required"); + + // Allocate pinned host staging first: if it throws we haven't issued any + // stream work yet, leaving the workspace cleanly un-prepared. + const size_t seqstart_bytes = traits_.is_group_mode ? sizeof(int) * (traits_.batch + 1) : 0; + const size_t total_bytes = 2 * seqstart_bytes + host_ws_size_; + auto pin_base = pinned_host_alloc(total_bytes); + + if(needs_zero_dq_acc_ && workspace_size > host_ws_size_) + HIP_CHECK_ERROR(hipMemsetAsync(static_cast(device_ws_ptr) + host_ws_size_, + 0, + workspace_size - host_ws_size_, + stream)); + + char* base = static_cast(pin_base.get()); + int* pin_q = reinterpret_cast(base); + int* pin_k = reinterpret_cast(base + seqstart_bytes); + void* pin_w = base + 2 * seqstart_bytes; + const int* seqstart_q_pinned = traits_.is_group_mode ? pin_q : nullptr; + const int* seqstart_k_pinned = traits_.is_group_mode ? 
pin_k : nullptr; + + if(traits_.is_group_mode) + { + if(!seqstart_q_dev || !seqstart_k_dev) + throw std::runtime_error("fmha_bwd_launcher::prepare_workspace_async: " + "seqstart_q_dev and seqstart_k_dev are required in " + "group mode"); + HIP_CHECK_ERROR(hipMemcpyAsync( + pin_q, seqstart_q_dev, seqstart_bytes, hipMemcpyDeviceToHost, stream)); + HIP_CHECK_ERROR(hipMemcpyAsync( + pin_k, seqstart_k_dev, seqstart_bytes, hipMemcpyDeviceToHost, stream)); + } + + auto pack_closure = std::make_unique>( + [=, fn = pack_workspace_host_]() { fn(pin_w, seqstart_q_pinned, seqstart_k_pinned); }); + // Callback runs on the HIP driver helper thread across a C ABI boundary; + // any exception escaping it would call std::terminate. + HIP_CHECK_ERROR(hipLaunchHostFunc( + stream, + [](void* ud) { + std::unique_ptr> c{static_cast*>(ud)}; + try + { + (*c)(); + } + catch(const std::exception& e) + { + // The H2D queued after this callback will copy indeterminate + // metadata to device and the kernel will produce wrong results; + // unlikely in practice since pack_workspace_host_ only throws on + // precondition violations. + std::cerr << "fmha_bwd_launcher: pack_workspace_host threw: " << e.what() + << '\n'; + } + catch(...) + { + std::cerr << "fmha_bwd_launcher: pack_workspace_host threw unknown\n"; + } + }, + pack_closure.get())); + // Ownership transferred to the callback only after a successful launch. + pack_closure.release(); + + HIP_CHECK_ERROR( + hipMemcpyAsync(device_ws_ptr, pin_w, host_ws_size_, hipMemcpyHostToDevice, stream)); + + // Release any previous in-flight buffer before taking a new one. + schedule_pin_staging_release(); + pin_staging_ = std::move(pin_base); + release_stream_ = stream; + } + private: - size_t host_ws_size = 0; - size_t device_ws_size = 0; - std::unique_ptr ws_host; + fmha_bwd_traits traits_{}; + size_t host_ws_size_ = 0; + bool needs_zero_dq_acc_ = false; + // Pure CPU; safe to invoke from a hipLaunchHostFunc callback. 
+ std::function + pack_workspace_host_{[](void*, const int*, const int*) { + std::cerr + << "fmha_bwd: no kernel found for given traits, skipping pack_workspace_host\n"; + }}; + std::shared_ptr pin_staging_; + hipStream_t release_stream_ = nullptr; + + // The pin_staging_ deleter MUST NOT call any HIP API: it fires from the + // hipLaunchHostFunc callback on the driver helper thread, which holds + // runtime locks (would deadlock against main-thread hipFree). PyTorch's + // CachingHostAllocator is safe; bare hipHostMalloc users should defer + // hipHostFree via ck_tile::pinned_host_releaser. + void schedule_pin_staging_release() noexcept + { + if(!pin_staging_) + return; + auto* heap_ref = new std::shared_ptr(std::move(pin_staging_)); + const hipError_t err = hipLaunchHostFunc( + release_stream_, + [](void* ud) { delete static_cast*>(ud); }, + heap_ref); + if(err != hipSuccess) + { + std::cerr << "fmha_bwd_launcher: hipLaunchHostFunc failed: " << hipGetErrorString(err) + << "; releasing eagerly\n"; + delete heap_ref; + } + } template - void init(const fmha_bwd_traits& traits) + void init(const fmha_bwd_traits& t) { - run = [](fmha_bwd_args a, const ck_tile::stream_config& s) { + traits_ = t; + run = [](fmha_bwd_args a, const ck_tile::stream_config& s) { return fmha_bwd_(s, a); }; - host_ws_size = fmha_bwd_dq_dk_dv_dq_ws_host_size_(traits.batch); - if(host_ws_size > 0) + host_ws_size_ = fmha_bwd_dq_dk_dv_dq_ws_host_size_(t.batch); + size_t device_ws_size = 0; + if(host_ws_size_ > 0) { - ws_host = std::make_unique(host_ws_size); // TODO: support host mem allocator - device_ws_size = fmha_bwd_dq_dk_dv_dq_prepare_ws_host_( // - ws_host.get(), - traits.batch, - traits.hdim_q, - traits.nhead_q, - traits.seqlen_q, - traits.seqlen_k, - traits.seqstart_qs, - traits.seqstart_ks); + // In group mode t.seqlen_q is already the padded total (== seqstart_q[batch]); + // in batch mode it's per-batch and the total is batch * seqlen_q. 
+ const ck_tile::index_t total_seqlen_q_padded = + t.is_group_mode ? t.seqlen_q : t.batch * t.seqlen_q; + device_ws_size = fmha_bwd_dq_dk_dv_dq_ws_device_upper_bound_( + t.batch, t.hdim_q, t.nhead_q, total_seqlen_q_padded, t.max_seqlen_k); + pack_workspace_host_ = [batch = t.batch, + hdim_q = t.hdim_q, + nhead_q = t.nhead_q, + seqlen_q = t.seqlen_q, + seqlen_k = t.seqlen_k // + ](void* host_ws, const int* seqstart_q, const int* seqstart_k) { + fmha_bwd_dq_dk_dv_dq_prepare_ws_host_( + host_ws, batch, hdim_q, nhead_q, seqlen_q, seqlen_k, seqstart_q, seqstart_k); + }; } - workspace_size = host_ws_size + device_ws_size; - const bool needs_zero_dq_acc = fmha_bwd_dq_dk_dv_needs_zero_dq_acc_(); - prepare_workspace = [this, needs_zero_dq_acc](void* device_ws) { - if(host_ws_size > 0) - HIP_CHECK_ERROR( - hipMemcpy(device_ws, ws_host.get(), host_ws_size, hipMemcpyHostToDevice)); - if(needs_zero_dq_acc) - HIP_CHECK_ERROR( - hipMemset(static_cast(device_ws) + host_ws_size, 0, device_ws_size)); - }; + workspace_size = host_ws_size_ + device_ws_size; + needs_zero_dq_acc_ = fmha_bwd_dq_dk_dv_needs_zero_dq_acc_(); } public: diff --git a/example/ck_tile/01_fmha/fmha_bwd_runner.hpp b/example/ck_tile/01_fmha/fmha_bwd_runner.hpp index ac86bf4635..b99649074d 100644 --- a/example/ck_tile/01_fmha/fmha_bwd_runner.hpp +++ b/example/ck_tile/01_fmha/fmha_bwd_runner.hpp @@ -8,10 +8,13 @@ #include "utils.hpp" #include "ck_tile/utility/json_dump.hpp" +#include "ck_tile/host/pinned_host_releaser.hpp" + #include #include #include #include +#include #include #include #include @@ -391,26 +394,47 @@ bwd_result fmha_bwd_run(mode_enum mode, p_drop > 0.0f, s_randval, deterministic, - (mode == mode_enum::group) ? seqstart_q_host.data() : nullptr, - (mode == mode_enum::group) ? 
seqstart_k_host.data() : nullptr, }); const auto t1_launcher = std::chrono::high_resolution_clock::now(); const double launcher_ctor_ms = std::chrono::duration(t1_launcher - t0_launcher).count(); const size_t ws_size = launcher.workspace_size; ck_tile::DeviceMem ws_buf(ws_size); + + // Stage seqstart to device before prepare_workspace_async (which D2Hs it back). + seqstart_q.ToDevice(seqstart_q_host.data()); + seqstart_k.ToDevice(seqstart_k_host.data()); + + // Pinned host allocator for the launcher's async prepare pipeline. The + // shared_ptr deleter MUST NOT call any HIP API: it runs from the launcher's + // tail hipLaunchHostFunc on the driver helper thread, which holds HIP + // runtime locks. Deleter enqueues to a worker thread that hipHostFrees off + // the callback path. + auto pinned_host_alloc = [](size_t bytes) -> std::shared_ptr { + void* p = nullptr; + HIP_CHECK_ERROR(hipHostMalloc(&p, bytes, hipHostMallocDefault)); + return std::shared_ptr( + p, [](void* q) { ck_tile::pinned_host_releaser::instance().enqueue(q); }); + }; + ck_tile::gpu_timer prepare_ws_timer; - prepare_ws_timer.start(nullptr); - launcher.prepare_workspace(ws_buf.GetDeviceBuffer()); - prepare_ws_timer.stop(nullptr); + prepare_ws_timer.start(stream_config.stream_id_); + launcher.prepare_workspace_async( + ws_buf.GetDeviceBuffer(), + (mode == mode_enum::group) ? static_cast(seqstart_q.GetDeviceBuffer()) + : nullptr, + (mode == mode_enum::group) ? static_cast(seqstart_k.GetDeviceBuffer()) + : nullptr, + stream_config, + pinned_host_alloc); + prepare_ws_timer.stop(stream_config.stream_id_); q_buf.ToDevice(q_host.data()); k_buf.ToDevice(k_host.data()); v_buf.ToDevice(v_host.data()); bias_buf.ToDevice(bias_host.data()); do_buf.ToDevice(do_host.data()); - seqstart_q.ToDevice(seqstart_q_host.data()); - seqstart_k.ToDevice(seqstart_k_host.data()); + // seqstart_q/k were already ToDevice'd above before prepare_workspace_async. 
if(mode == mode_enum::group) { std::vector seqlen_q_host(seqlen_qs.begin(), seqlen_qs.end()); @@ -902,8 +926,6 @@ bwd_result fmha_bwd_run(mode_enum mode, dq_buf.ToDevice(dq_host.data()); dk_buf.ToDevice(dk_host.data()); dv_buf.ToDevice(dv_host.data()); - // re-initialize workspace for validation run - launcher.prepare_workspace(ws_buf.GetDeviceBuffer()); o_buf.ToDevice(o_host.data()); lse_buf.ToDevice(lse_host.data()); @@ -912,6 +934,15 @@ bwd_result fmha_bwd_run(mode_enum mode, d_sink_buf.SetZero(); ck_tile::stream_config stream_config_v{nullptr, true, 0, 0, 1}; + // re-initialize workspace for validation run + launcher.prepare_workspace_async( + ws_buf.GetDeviceBuffer(), + (mode == mode_enum::group) ? static_cast(seqstart_q.GetDeviceBuffer()) + : nullptr, + (mode == mode_enum::group) ? static_cast(seqstart_k.GetDeviceBuffer()) + : nullptr, + stream_config_v, + pinned_host_alloc); launcher(fmha_args, stream_config_v); dq_buf.FromDevice(dq_host.data()); diff --git a/include/ck_tile/host/pinned_host_releaser.hpp b/include/ck_tile/host/pinned_host_releaser.hpp new file mode 100644 index 0000000000..8a24d5b201 --- /dev/null +++ b/include/ck_tile/host/pinned_host_releaser.hpp @@ -0,0 +1,77 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include +#include +#include +#include + +namespace ck_tile { + +// Defers hipHostFree off the HIP callback path. HIP callbacks hold runtime +// locks, so calling hipHostFree (or any HIP API) from one deadlocks against +// concurrent main-thread hipFree. enqueue() is HIP-API-free; a worker thread +// drains the queue and calls hipHostFree. Use instance() for a process-wide +// shared worker. 
+class pinned_host_releaser +{ + std::mutex mtx_; + std::condition_variable cv_; + std::queue q_; + std::thread worker_; + bool stop_ = false; + + void run() + { + for(;;) + { + void* p = nullptr; + { + std::unique_lock lk(mtx_); + cv_.wait(lk, [&] { return stop_ || !q_.empty(); }); + if(q_.empty()) + return; // stop_ && empty + p = q_.front(); + q_.pop(); + } + (void)hipHostFree(p); + } + } + + public: + pinned_host_releaser() : worker_([this] { run(); }) {} + + ~pinned_host_releaser() + { + { + std::lock_guard lk(mtx_); + stop_ = true; + } + cv_.notify_all(); + if(worker_.joinable()) + worker_.join(); + } + + pinned_host_releaser(const pinned_host_releaser&) = delete; + pinned_host_releaser& operator=(const pinned_host_releaser&) = delete; + + static pinned_host_releaser& instance() + { + static pinned_host_releaser r; + return r; + } + + void enqueue(void* p) + { + { + std::lock_guard lk(mtx_); + q_.push(p); + } + cv_.notify_one(); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp index 23c73e5f43..7aff21530d 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp @@ -166,6 +166,49 @@ struct FmhaBwdWorkspaceManager // In these cases we need to zero out it first return kHasMask; } + + // Upper bound on PrepareWorkspaceHost's size, computable without seqstart so + // the device workspace can be allocated before any D2H. + // + // total_seqlen_q_padded: total q tokens incl. per-batch padding. + // Batch: max_batch * seqlen_q. Group: seqstart_q[batch]. + // max_seqlen_k: deterministic-only; pass per-batch padded max if the caller + // does internal k padding, otherwise the logical max is fine. 
+ template + CK_TILE_HOST static size_t GetWorkspaceDeviceSizeUpperBound(index_t max_batch, + index_t hdim_q, + index_t nhead_q, + index_t total_seqlen_q_padded, + index_t max_seqlen_k) + { + if constexpr(kUseQrQtrDorPipeline) + return 0; + + index_t nsplits_factor = 1; + if constexpr(kIsDeterministic) + { + if constexpr(kIsGroupMode) + { + nsplits_factor = integer_divide_ceil(max_seqlen_k, kN0); + } + else // persistent + { + const index_t dqdqkdv_workers = get_num_cus(); + const index_t jobs_per_head = integer_divide_ceil(max_seqlen_k, kN0); + const index_t total_jobs = max_batch * nhead_q * jobs_per_head; + const index_t jobs_per_worker = integer_divide_ceil(total_jobs, dqdqkdv_workers); + if(jobs_per_head % jobs_per_worker == 0) + nsplits_factor = jobs_per_head / jobs_per_worker; + else if(jobs_per_worker % jobs_per_head == 0) + nsplits_factor = 1; + else + nsplits_factor = 1 + integer_divide_ceil(jobs_per_head - 1, jobs_per_worker); + } + } + + return sizeof(AccDataType) * static_cast(nhead_q) * nsplits_factor * + total_seqlen_q_padded * hdim_q; + } }; template ( std::forward(args)...); } + template + CK_TILE_HOST static size_t GetWorkspaceDeviceSizeUpperBound(Args&&... args) + { + return WorkspaceManager::template GetWorkspaceDeviceSizeUpperBound< + kUseQrQtrDorPipeline, + FmhaPipeline::BlockFmhaShape::kN0>(std::forward(args)...); + } CK_TILE_HOST static constexpr bool NeedsZeroDqAcc() { return WorkspaceManager::template NeedsZeroDqAcc();