mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-20 12:59:49 +00:00
[CK_TILE] Add async workspace prepare to FMHA BWD launcher
This commit is contained in:
@@ -166,6 +166,47 @@ struct FmhaBwdWorkspaceManager
|
||||
// In these cases we need to zero out it first
|
||||
return kHasMask;
|
||||
}
|
||||
|
||||
// Mirrors PrepareWorkspaceHost's return value but uses worst-case totals so
|
||||
// device workspace can be pre-allocated before host has the seqstart values.
|
||||
template <bool kUseQrQtrDorPipeline, index_t kN0>
|
||||
CK_TILE_HOST static size_t GetWorkspaceDeviceSizeUpperBound(index_t max_batch,
|
||||
index_t hdim_q,
|
||||
index_t nhead_q,
|
||||
index_t max_seqlen_q,
|
||||
index_t max_seqlen_k)
|
||||
{
|
||||
if constexpr(kUseQrQtrDorPipeline)
|
||||
return 0;
|
||||
|
||||
if constexpr(!kIsDeterministic)
|
||||
{
|
||||
return sizeof(AccDataType) * static_cast<long_index_t>(max_batch) * nhead_q *
|
||||
max_seqlen_q * hdim_q;
|
||||
}
|
||||
else if constexpr(kIsGroupMode)
|
||||
{
|
||||
const index_t nsplits_max = integer_divide_ceil(max_seqlen_k, kN0);
|
||||
return sizeof(AccDataType) * static_cast<long_index_t>(max_batch) * nhead_q *
|
||||
nsplits_max * max_seqlen_q * hdim_q;
|
||||
}
|
||||
else // deterministic non-group mode (kUsePersistent)
|
||||
{
|
||||
const index_t dqdqkdv_workers = get_num_cus();
|
||||
const index_t jobs_per_head = integer_divide_ceil(max_seqlen_k, kN0);
|
||||
const index_t total_jobs = max_batch * nhead_q * jobs_per_head;
|
||||
const index_t jobs_per_worker = integer_divide_ceil(total_jobs, dqdqkdv_workers);
|
||||
index_t nsplits;
|
||||
if(jobs_per_head % jobs_per_worker == 0)
|
||||
nsplits = jobs_per_head / jobs_per_worker;
|
||||
else if(jobs_per_worker % jobs_per_head == 0)
|
||||
nsplits = 1;
|
||||
else
|
||||
nsplits = 1 + integer_divide_ceil(jobs_per_head - 1, jobs_per_worker);
|
||||
return sizeof(AccDataType) * static_cast<long_index_t>(max_batch) * nhead_q * nsplits *
|
||||
max_seqlen_q * hdim_q;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename FmhaPipeline_,
|
||||
@@ -281,6 +322,13 @@ struct FmhaBwdDQDKDVKernel
|
||||
FmhaPipeline::BlockFmhaShape::kN0>(
|
||||
std::forward<Args>(args)...);
|
||||
}
|
||||
template <typename... Args>
|
||||
CK_TILE_HOST static constexpr auto GetWorkspaceDeviceSizeUpperBound(Args&&... args)
|
||||
{
|
||||
return WorkspaceManager::template GetWorkspaceDeviceSizeUpperBound<
|
||||
kUseQrQtrDorPipeline,
|
||||
FmhaPipeline::BlockFmhaShape::kN0>(std::forward<Args>(args)...);
|
||||
}
|
||||
CK_TILE_HOST static constexpr bool NeedsZeroDqAcc()
|
||||
{
|
||||
return WorkspaceManager::template NeedsZeroDqAcc<kUseQrQtrDorPipeline, kHasMask>();
|
||||
|
||||
Reference in New Issue
Block a user