[CK_TILE] Add async workspace prepare to FMHA BWD launcher

This commit is contained in:
Ding, Yi
2026-05-12 01:59:15 -05:00
parent 733bfee092
commit d434410e52
5 changed files with 297 additions and 42 deletions

View File

@@ -0,0 +1,77 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <condition_variable>
#include <hip/hip_runtime.h>
#include <mutex>
#include <queue>
#include <thread>
namespace ck_tile {
// Defers hipHostFree off the HIP callback path. HIP callbacks hold runtime
// locks, so calling hipHostFree (or any HIP API) from one deadlocks against
// concurrent main-thread hipFree. enqueue() is HIP-API-free; a worker thread
// drains the queue and calls hipHostFree. Use instance() for a process-wide
// shared worker.
//
// Lifetime: the worker thread is started by the constructor and joined by the
// destructor. Pointers still queued when the destructor runs are released
// before the worker exits.
class pinned_host_releaser
{
    std::mutex mtx_;             // guards q_ and stop_
    std::condition_variable cv_; // signalled by enqueue() and by the destructor
    std::queue<void*> q_;        // pointers waiting to be passed to hipHostFree
    bool stop_ = false;          // set (under mtx_) by the destructor to request shutdown

    // Must stay the LAST data member. Members are initialised in declaration
    // order and the thread starts running run() as soon as it is constructed,
    // so every member run() touches has to be fully initialised before it.
    std::thread worker_;

    // Worker loop: pop one pointer at a time under the lock, release it with
    // the lock dropped so enqueue() is never blocked behind a HIP call.
    void run()
    {
        for(;;)
        {
            void* p = nullptr;
            {
                std::unique_lock<std::mutex> lk(mtx_);
                cv_.wait(lk, [&] { return stop_ || !q_.empty(); });
                if(q_.empty())
                    return; // stop_ && empty
                p = q_.front();
                q_.pop();
            }
            // Best effort: the status is deliberately ignored, there is no
            // caller left to report a failure to.
            (void)hipHostFree(p);
        }
    }

    public:
    // Starts the worker thread.
    pinned_host_releaser() : worker_([this] { run(); }) {}

    // Requests shutdown, lets the worker drain whatever is still queued, and
    // joins it.
    ~pinned_host_releaser()
    {
        {
            std::lock_guard<std::mutex> lk(mtx_);
            stop_ = true;
        }
        cv_.notify_all();
        if(worker_.joinable())
            worker_.join();
    }

    pinned_host_releaser(const pinned_host_releaser&)            = delete;
    pinned_host_releaser& operator=(const pinned_host_releaser&) = delete;

    // Returns the process-wide shared releaser (function-local static).
    static pinned_host_releaser& instance()
    {
        static pinned_host_releaser r;
        return r;
    }

    // Queues p for a later hipHostFree on the worker thread. Makes no HIP
    // call itself; it only takes mtx_ and notifies the worker.
    void enqueue(void* p)
    {
        {
            std::lock_guard<std::mutex> lk(mtx_);
            q_.push(p);
        }
        cv_.notify_one();
    }
};
} // namespace ck_tile

View File

@@ -166,6 +166,47 @@ struct FmhaBwdWorkspaceManager
// In these cases we need to zero out it first
return kHasMask;
}
// Mirrors PrepareWorkspaceHost's return value but uses worst-case totals so
// device workspace can be pre-allocated before host has the seqstart values.
//
// Returns a size in bytes:
//   sizeof(AccDataType) * max_batch * nhead_q * nsplits * max_seqlen_q * hdim_q
// where nsplits depends on the mode:
//   - kUseQrQtrDorPipeline:        no workspace, returns 0
//   - non-deterministic:           nsplits = 1
//   - deterministic, group mode:   nsplits = ceil(max_seqlen_k / kN0)
//   - deterministic, non-group:    nsplits derived from how the per-head jobs
//                                  are spread over get_num_cus() workers
// The product is evaluated in long_index_t / size_t to avoid 32-bit overflow.
template <bool kUseQrQtrDorPipeline, index_t kN0>
CK_TILE_HOST static size_t GetWorkspaceDeviceSizeUpperBound(index_t max_batch,
                                                            index_t hdim_q,
                                                            index_t nhead_q,
                                                            index_t max_seqlen_q,
                                                            index_t max_seqlen_k)
{
    if constexpr(kUseQrQtrDorPipeline)
        return 0;
    if constexpr(!kIsDeterministic)
    {
        return sizeof(AccDataType) * static_cast<long_index_t>(max_batch) * nhead_q *
               max_seqlen_q * hdim_q;
    }
    else if constexpr(kIsGroupMode)
    {
        const index_t nsplits_max = integer_divide_ceil(max_seqlen_k, kN0);
        return sizeof(AccDataType) * static_cast<long_index_t>(max_batch) * nhead_q *
               nsplits_max * max_seqlen_q * hdim_q;
    }
    else // deterministic non-group mode (kUsePersistent)
    {
        const index_t dqdqkdv_workers = get_num_cus();
        const index_t jobs_per_head   = integer_divide_ceil(max_seqlen_k, kN0);
        const index_t total_jobs      = max_batch * nhead_q * jobs_per_head;
        // Degenerate problem (zero batch, heads or key length): there are no
        // jobs, so jobs_per_worker would be 0 and the modulo below would divide
        // by zero. No job means no split to store, same as the group-mode
        // branch where max_seqlen_k == 0 gives zero splits.
        if(total_jobs <= 0)
            return 0;
        const index_t jobs_per_worker = integer_divide_ceil(total_jobs, dqdqkdv_workers);
        index_t nsplits;
        if(jobs_per_head % jobs_per_worker == 0)
            nsplits = jobs_per_head / jobs_per_worker; // a head is cut into equal worker-sized pieces
        else if(jobs_per_worker % jobs_per_head == 0)
            nsplits = 1; // each worker covers a whole number of heads
        else
            nsplits = 1 + integer_divide_ceil(jobs_per_head - 1, jobs_per_worker);
        return sizeof(AccDataType) * static_cast<long_index_t>(max_batch) * nhead_q * nsplits *
               max_seqlen_q * hdim_q;
    }
}
};
template <typename FmhaPipeline_,
@@ -281,6 +322,13 @@ struct FmhaBwdDQDKDVKernel
FmhaPipeline::BlockFmhaShape::kN0>(
std::forward<Args>(args)...);
}
// Forwards to WorkspaceManager::GetWorkspaceDeviceSizeUpperBound, supplying
// this kernel's pipeline flag and its kN0 tile size as template arguments.
template <typename... Args>
CK_TILE_HOST static constexpr auto GetWorkspaceDeviceSizeUpperBound(Args&&... args)
{
    constexpr auto kTileN0 = FmhaPipeline::BlockFmhaShape::kN0;
    return WorkspaceManager::
        template GetWorkspaceDeviceSizeUpperBound<kUseQrQtrDorPipeline, kTileN0>(
            std::forward<Args>(args)...);
}
CK_TILE_HOST static constexpr bool NeedsZeroDqAcc()
{
return WorkspaceManager::template NeedsZeroDqAcc<kUseQrQtrDorPipeline, kHasMask>();