[CK_TILE] fmha: Add query padding support to backward pass (#3097)

* [CK_TILE] fmha: Add query padding support to backward pass

Introduce support for query sequence padding (q_padding) in the FMHA backward pass kernels:
- Pass `seqlen_q_ptr` to the backward kernels to distinguish logical from physical sequence lengths.
- Update the `OGradDotO`, `ConvertQGrad`, and `DQDKDV` kernels to respect logical lengths and handle zero-length sequences.
- Align LSE indexing in the forward kernel with the padded layout for consistency.
- Add a new GTest suite (`test_fmha_bwd_kernel_padding.cpp`) covering padding scenarios, including zero-length
  sequences and deterministic mode.
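
To illustrate the logical/physical distinction, here is a minimal host-side sketch. It is not code from this change: `build_physical_seqstart` is a hypothetical helper, and it assumes that a padded length of -1 means "no explicit padding for this sequence". Physical start offsets are prefix sums over the padded lengths, while the logical lengths are kept separately and may be zero.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper (not part of CK): build physical start offsets from
// per-sequence logical and padded lengths.
// Assumption: padded[i] < 0 means sequence i has no explicit padding, so its
// physical length equals its logical length.
inline std::vector<int> build_physical_seqstart(const std::vector<int>& logical,
                                                const std::vector<int>& padded)
{
    assert(logical.size() == padded.size());

    std::vector<int> seqstart(logical.size() + 1, 0);
    for(std::size_t i = 0; i < logical.size(); ++i)
    {
        const int physical = padded[i] < 0 ? logical[i] : padded[i];
        seqstart[i + 1]    = seqstart[i] + physical; // running prefix sum
    }
    return seqstart;
}
```

With logical lengths `{3, 0, 5}` and padded lengths `{4, 2, 8}`, the second sequence occupies two physical rows but contributes no logical tokens, which is the zero-length case a kernel has to skip.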

* fix clang format

* Adapt fmha_bwd_runner.cpp to the new q/kv sequence padding

Add backward q/kv sequence padding unit tests.

* [CK_TILE] fmha: Unify sequence length and padding handling

Refactor the handling of sequence lengths and padding in the
FMHA forward and backward kernels to provide a more unified and flexible
interface.

- Replaced `seqstart_padded_*_ptr` with a scheme in which
  `seqstart_*_ptr` gives physical sequence lengths, while
  `seqlen_*_ptr` and `cu_seqlen_*_ptr` give logical (unpadded) lengths.
- Established a clear order of precedence for determining sequence
  length: cumulative lengths (`cu_seqlen_*_ptr`) take priority,
  followed by per-sequence lengths (`seqlen_*_ptr`), and finally
  physical lengths derived from `seqstart_*_ptr`.
- Clarified the distinction between "group mode" and "batch mode" and
  how sequence lengths are handled in each case.
- Renamed `cu_seqlen_kv_ptr` to `cu_seqlen_k_ptr` for consistency.
- Updated comments and documentation to reflect the new argument
  structure and usage.
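
The order of precedence described above can be sketched as follows. This is an illustrative sketch, not the kernel code: `resolve_logical_seqlen` is a hypothetical name, and it assumes that `cu_seqlen_ptr` and `seqstart_ptr` hold `batch + 1` prefix-sum entries while `seqlen_ptr` holds one entry per sequence.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper (not part of CK): pick the sequence length for one
// batch entry using the precedence stated in the commit message.
// Assumption: cu_seqlen_ptr and seqstart_ptr are prefix sums with batch + 1
// entries; seqlen_ptr has one entry per sequence. A null pointer means
// "not provided".
inline std::int32_t resolve_logical_seqlen(const std::int32_t* seqstart_ptr,
                                           const std::int32_t* seqlen_ptr,
                                           const std::int32_t* cu_seqlen_ptr,
                                           std::int32_t batch_idx)
{
    // 1. Cumulative logical lengths take priority.
    if(cu_seqlen_ptr != nullptr)
        return cu_seqlen_ptr[batch_idx + 1] - cu_seqlen_ptr[batch_idx];

    // 2. Otherwise use the per-sequence logical length.
    if(seqlen_ptr != nullptr)
        return seqlen_ptr[batch_idx];

    // 3. Fall back to the physical length derived from the start offsets.
    assert(seqstart_ptr != nullptr);
    return seqstart_ptr[batch_idx + 1] - seqstart_ptr[batch_idx];
}
```

When neither logical source is supplied, the result is the physical length, so unpadded inputs behave the same with or without the extra pointers.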

---------

Co-authored-by: illsilin_amdeng <Illia.Silin@amd.com>
Jeff Huang
2025-10-29 13:56:11 +08:00
committed by GitHub
parent 13e13ce359
commit 7c6430eca0
11 changed files with 1292 additions and 214 deletions

@@ -24,11 +24,19 @@ auto create_args(int argc, char* argv[])
"total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary\n"
"also with \"-s=s0,s1,s2...\" comma-separated ints to set seqlen per batch "
"(group mode)")
.insert("s_qpad",
"-1",
"padded seqlen_q per batch (group mode only). "
"Use \"-s_qpad=p0,p1,...\"; -1 disables explicit padding")
.insert("s_k",
"-1",
"seqlen_k, -1 means equal to s\n"
"also with \"-s_k=s0,s1,s2...\" comma-separated ints to set seqlen per batch "
"(group mode)")
.insert("s_kpad",
"-1",
"padded seqlen_k per batch (group mode only). "
"Use \"-s_kpad=k0,k1,...\"; -1 disables explicit padding")
.insert("d", "128", "head dim for q, k")
.insert("d_v", "-1", "head dim for v, -1 means equal to d")
.insert("scale", "0", "scale factor. 0 means equal to 1/sqrt(hdim)")
@@ -96,7 +104,9 @@ auto run(const ck_tile::ArgParser& arg_parser)
ck_tile::index_t nhead = arg_parser.get_int("h");
ck_tile::index_t nhead_k = arg_parser.get_int("h_k");
auto seqlen_qs = arg_parser.get_int_vec("s");
auto seqlen_qpads = arg_parser.get_int_vec("s_qpad");
auto seqlen_ks = arg_parser.get_int_vec("s_k");
auto seqlen_kpads = arg_parser.get_int_vec("s_kpad");
ck_tile::index_t hdim_q = arg_parser.get_int("d");
ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
bool i_perm = arg_parser.get_bool("iperm");
@@ -130,6 +140,8 @@ auto run(const ck_tile::ArgParser& arg_parser)
nhead_k,
seqlen_qs,
seqlen_ks,
seqlen_qpads,
seqlen_kpads,
hdim_q,
hdim_v,
i_perm,