mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 10:09:41 +00:00
[CK_TILE] FMHA BWD Use Persistent Kernels in Deterministic Mode (#5174)
## Motivation This PR enables a persistent-kernel execution path for FMHA backward (dQ/dK/dV) in deterministic mode, adjusting how dQ accumulation is split, stored, and converted back to final gradients. ## Technical Details - Introduces a persistent-kernel grid mapping in deterministic mode and updates split-count calculation accordingly. - Extends kernel kargs to carry batch-related info needed for persistent scheduling and dQ conversion. - Refactors dQ store conditions and adds mask-type traits/utilities and runner logging updates. ## Test Plan - Jenkins [base](http://micimaster.amd.com/blue/organizations/jenkins/rocm-libraries-folder%2FComposable%20Kernel/detail/PR-5174/10/pipeline) - Jenkins [AITER](http://micimaster.amd.com/blue/organizations/jenkins/rocm-libraries-folder%2FComposable%20Kernel/detail/PR-5174/12/pipeline) - Jenkins [FMHA](http://micimaster.amd.com/blue/organizations/jenkins/rocm-libraries-folder%2FComposable%20Kernel/detail/PR-5174/11/pipeline) - local FA tests ## Test Result <!-- Briefly summarize test outcomes. --> ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
@@ -169,10 +169,17 @@ int fmha_bwd_dq_dk_dv_maxq_<dq_dk_dv_trait_{F_idx}, {F_arch.tag}>()
|
||||
}}
|
||||
|
||||
template <>
|
||||
int fmha_bwd_dq_dk_dv_dq_acc_splits_<dq_dk_dv_trait_{F_idx}, {F_arch.tag}>(ck_tile::index_t seqlen_k)
|
||||
int fmha_bwd_dq_dk_dv_dq_acc_splits_<dq_dk_dv_trait_{F_idx}, {F_arch.tag}>(const fmha_bwd_traits& t)
|
||||
{{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
|
||||
return k_::GetDqAccSplits(seqlen_k);
|
||||
return k_::GetDqAccSplits(t.batch, t.nhead_q, t.max_seqlen_k);
|
||||
}}
|
||||
|
||||
template <>
|
||||
bool fmha_bwd_dq_dk_dv_needs_zero_dq_acc_<dq_dk_dv_trait_{F_idx}, {F_arch.tag}>()
|
||||
{{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
|
||||
return k_::NeedsZeroDqAcc();
|
||||
}}
|
||||
|
||||
template <>
|
||||
@@ -192,6 +199,7 @@ fmha_bwd_launcher::fmha_bwd_launcher(const fmha_bwd_traits& t){{
|
||||
{F_launcher}
|
||||
run = [](fmha_bwd_args, const ck_tile::stream_config&) {{ return -1.0f; }};
|
||||
dq_acc_splits = 1;
|
||||
needs_zero_dq_acc = false;
|
||||
}}
|
||||
|
||||
|
||||
@@ -231,7 +239,8 @@ FMHA_BWD_API_INNER_DISPATCH_LAUNCHER = """
|
||||
run = [](fmha_bwd_args a, const ck_tile::stream_config& s) {{
|
||||
return fmha_bwd_<dot_do_o_trait_, dq_dk_dv_trait_, std::conditional_t<{F_convert_dq_enabled}, convert_dq_trait_, void>, {F_arch.tag}>(s, a);
|
||||
}};
|
||||
dq_acc_splits = fmha_bwd_dq_dk_dv_dq_acc_splits_<dq_dk_dv_trait_, {F_arch.tag}>(t.max_seqlen_k);
|
||||
dq_acc_splits = fmha_bwd_dq_dk_dv_dq_acc_splits_<dq_dk_dv_trait_, {F_arch.tag}>(t);
|
||||
needs_zero_dq_acc = fmha_bwd_dq_dk_dv_needs_zero_dq_acc_<dq_dk_dv_trait_, {F_arch.tag}>();
|
||||
return;
|
||||
}}
|
||||
"""
|
||||
@@ -447,7 +456,7 @@ class KernelComponentFactoryGfx950(KernelComponentFactoryGfx9):
|
||||
results = KernelComponentFactoryGfx9.get_dq_dk_dv_tiles(dtype, tr_load)
|
||||
if dtype in ["fp16", "bf16"] and tr_load == "t":
|
||||
results.extend([
|
||||
FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 32, 1),
|
||||
FmhaBwdDQDKDVTileSize( 32, 256, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 32, 1),
|
||||
FmhaBwdDQDKDVTileSize( 32, 128, 128, 32, 128, 32, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 32, 1),
|
||||
FmhaBwdDQDKDVTileSize( 16, 192, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
# FmhaBwdDQDKDVTileSize( 32, 32, 64, 32, 64, 32, 32, 64, 64, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, 1, 32),
|
||||
@@ -823,7 +832,7 @@ class FmhaBwdApiTrait:
|
||||
|
||||
@property
|
||||
def extra_cond(self) -> str:
|
||||
if self.tr_load == "t" and self.tile.max_seq_q == 0 and self.tile.F_bn0 == 128:
|
||||
if self.tr_load == "t" and self.tile.max_seq_q == 0 and self.tile.F_bn0 == 128 and self.tile.F_bhdq == 128:
|
||||
return " && (t.seqlen_k <= 256)"
|
||||
else:
|
||||
return ""
|
||||
|
||||
@@ -251,6 +251,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
|
||||
args.seqlen_k_ptr,
|
||||
args.cu_seqlen_q_ptr,
|
||||
args.cu_seqlen_k_ptr,
|
||||
args.batch,
|
||||
args.hdim_q,
|
||||
args.hdim_v,
|
||||
args.nhead_q,
|
||||
@@ -300,6 +301,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
|
||||
dq_ptr,
|
||||
args.seqlen_q,
|
||||
args.seqlen_k,
|
||||
args.batch,
|
||||
args.hdim_q,
|
||||
args.hdim_v,
|
||||
args.nhead_q,
|
||||
@@ -429,7 +431,9 @@ auto fmha_bwd_convert_dq_create_kargs_and_grids(fmha_bwd_args args)
|
||||
args.nhead_stride_dq_acc,
|
||||
args.batch_stride_dq,
|
||||
args.batch_stride_dq_acc,
|
||||
args.split_stride_dq_acc);
|
||||
args.split_stride_dq_acc,
|
||||
args.batch,
|
||||
args.nhead_q);
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -465,8 +469,11 @@ template <typename Traits_, typename Arch = void>
|
||||
std::string fmha_bwd_dq_dk_dv_get_name_();
|
||||
template <typename Traits_, typename Arch = void>
|
||||
int fmha_bwd_dq_dk_dv_maxq_();
|
||||
struct fmha_bwd_traits;
|
||||
template <typename Traits_, typename Arch = void>
|
||||
int fmha_bwd_dq_dk_dv_dq_acc_splits_(ck_tile::index_t seqlen_k);
|
||||
int fmha_bwd_dq_dk_dv_dq_acc_splits_(const fmha_bwd_traits& t);
|
||||
template <typename Traits_, typename Arch = void>
|
||||
bool fmha_bwd_dq_dk_dv_needs_zero_dq_acc_();
|
||||
|
||||
template <ck_tile::index_t HDim_, typename DataType_, bool kIsGroupMode_, bool kPadS_, bool kPadDv_>
|
||||
struct fmha_bwd_dot_do_o_traits_
|
||||
@@ -569,6 +576,7 @@ struct fmha_bwd_launcher
|
||||
{
|
||||
std::function<float(fmha_bwd_args, const ck_tile::stream_config&)> run{};
|
||||
ck_tile::index_t dq_acc_splits{0};
|
||||
bool needs_zero_dq_acc{true};
|
||||
|
||||
fmha_bwd_launcher(const fmha_bwd_traits&);
|
||||
|
||||
|
||||
@@ -416,9 +416,10 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
<< "/" << seqlen_ks[0] << ", d:" << hdim_q << "/" << hdim_v << ", scale:" << scale
|
||||
<< ", bias:" << bias << ", dbias:" << use_dbias << ", p_drop:" << p_drop
|
||||
<< ", s_randval:" << s_randval << ", deterministic:" << deterministic
|
||||
<< (deterministic ? std::string(", workspace:") +
|
||||
std::to_string(workspace_size_in_megabytes) + "MiB"
|
||||
: "")
|
||||
<< (deterministic
|
||||
? std::string(", workspace:") + std::to_string(workspace_size_in_megabytes) +
|
||||
"MiB|" + std::to_string(nsplits) + "splits"
|
||||
: "")
|
||||
<< ", mask:" << mask << std::flush;
|
||||
|
||||
auto fmha_args = [&]() {
|
||||
@@ -842,10 +843,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
lse_buf.ToDevice(lse_host.data());
|
||||
dbias_buf.SetZero();
|
||||
|
||||
// non-deterministic kernels use atomic add to write dq
|
||||
// Some block may be skipped with causal mask and dq are not set to zeros
|
||||
// In these cases thus we need to zero out it first
|
||||
if(!deterministic || mask.type != mask_enum::no_mask)
|
||||
if(launcher.needs_zero_dq_acc)
|
||||
dq_acc_buf.SetZero();
|
||||
|
||||
ck_tile::stream_config stream_config_v{nullptr, true, 0, 0, 1};
|
||||
|
||||
Reference in New Issue
Block a user