mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 21:21:22 +00:00
[CK_TILE] fmha: Add query padding support to backward pass (#3097)
* [CK_TILE] fmha: Add query padding support to backward pass Introduces support for query sequence padding (q_padding) in the FMHA backward pass kernels. - Passing `seqlen_q_ptr` to the backward kernels to distinguish logical from physical sequence lengths. - Updating `OGradDotO`, `ConvertQGrad`, and `DQDKDV` kernels to respect logical lengths and handle zero-length sequences. - Aligning LSE indexing in the forward kernel with the padded layout for consistency. - Adding a new GTest suite (`test_fmha_bwd_kernel_padding.cpp`) with comprehensive tests for various padding scenarios, including zero-length sequences and deterministic mode. * fix clang format * Adapt fmha_bwd_runner.cpp to new q, kv sequence padding Add backward q/kv sequence padding unit tests. * [CK_TILE] fmha: Unify sequence length and padding handling Refactor the handling of sequence lengths and padding in the FMHA forward and backward kernels to provide a more unified and flexible interface. - Replaced `seqstart_padded_*_ptr` with a more robust system that uses `seqstart_*_ptr` for physical sequence lengths and introduces `seqlen_*_ptr` and `cu_seqlen_*_ptr` for logical (unpadded) lengths. - Established a clear order of precedence for determining sequence length: cumulative lengths (`cu_seqlen_*_ptr`) take priority, followed by per-sequence lengths (`seqlen_*_ptr`), and finally physical lengths derived from `seqstart_*_ptr`. - Clarified the distinction between "group mode" and "batch mode" and how sequence lengths are handled in each case. - Renamed `cu_seqlen_kv_ptr` to `cu_seqlen_k_ptr` for consistency. - Updated comments and documentation to reflect the new argument structure and usage. --------- Co-authored-by: illsilin_amdeng <Illia.Silin@amd.com>
This commit is contained in:
@@ -24,11 +24,19 @@ auto create_args(int argc, char* argv[])
|
||||
"total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary\n"
|
||||
"also with \"-s=s0,s1,s2...\" comma-separated ints to set seqlen per batch "
|
||||
"(group mode)")
|
||||
.insert("s_qpad",
|
||||
"-1",
|
||||
"padded seqlen_q per batch (group mode only). "
|
||||
"Use \"-s_qpad=p0,p1,...\"; -1 disables explicit padding")
|
||||
.insert("s_k",
|
||||
"-1",
|
||||
"seqlen_k, -1 means equal to s\n"
|
||||
"also with \"-s_k=s0,s1,s2...\" comma-separated ints to set seqlen per batch "
|
||||
"(group mode)")
|
||||
.insert("s_kpad",
|
||||
"-1",
|
||||
"padded seqlen_k per batch (group mode only). "
|
||||
"Use \"-s_kpad=k0,k1,...\"; -1 disables explicit padding")
|
||||
.insert("d", "128", "head dim for q, k")
|
||||
.insert("d_v", "-1", "head dim for v, -1 means equal to d")
|
||||
.insert("scale", "0", "scale factor. 0 means equal to 1/sqrt(hdim)")
|
||||
@@ -96,7 +104,9 @@ auto run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::index_t nhead = arg_parser.get_int("h");
|
||||
ck_tile::index_t nhead_k = arg_parser.get_int("h_k");
|
||||
auto seqlen_qs = arg_parser.get_int_vec("s");
|
||||
auto seqlen_qpads = arg_parser.get_int_vec("s_qpad");
|
||||
auto seqlen_ks = arg_parser.get_int_vec("s_k");
|
||||
auto seqlen_kpads = arg_parser.get_int_vec("s_kpad");
|
||||
ck_tile::index_t hdim_q = arg_parser.get_int("d");
|
||||
ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
|
||||
bool i_perm = arg_parser.get_bool("iperm");
|
||||
@@ -130,6 +140,8 @@ auto run(const ck_tile::ArgParser& arg_parser)
|
||||
nhead_k,
|
||||
seqlen_qs,
|
||||
seqlen_ks,
|
||||
seqlen_qpads,
|
||||
seqlen_kpads,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
i_perm,
|
||||
|
||||
Reference in New Issue
Block a user