Add sequence padding and variable length support in fmha (#2932)

* * [CK_TILE] Add sequence padding and variable length support in fmha (and v3)

 - Group Mode Padding: Introduces the `-s_qpad` argument to support
   physically padded layouts. Kernels now use padded start pointers
   (`seqstart_padded_*_ptr`) for memory addressing.

 - Batch Mode Variable Length: Adds `-q_eff_lens` and `-kv_eff_lens`
   arguments for efficient processing of variable-length sequences by
   passing cumulative effective lengths (`cu_seqlen_*_ptr`) to the kernel.

 - FMHA examples: Support padding and variable length both in
   group and batch mode. Dispatcher is updated as well (dispatch to
   kPadSeqLenK enabled pipeline).

 - New padding test cases: Add padding test cases to `smoke_test_fwd.sh` and
   `test_fmha_fwd.inc`, and add benchmarks to `benchmark_fwd.sh` and
   `benchmark_fwd_v3.sh` as well. These test cases and benchmarks that
   specifically validate/benchmark the new padding and variable-length
   functionalities in both group and batch modes.

* [CK_TILE] Fix build error in fmha unit tests

* [CK_TILE] add mqa, gqa to sequence padding unit tests

* [CI_TILE] Reduce the number of padding seqlen unit tests in FMHA to avoid timeouts in CI

* [CK_TILE] remove unnecessary MageKArgs overload in FmhaFwdV3Kernel and FmhaFwdKernel
This commit is contained in:
Jeff Huang
2025-09-26 12:36:27 +08:00
committed by GitHub
parent b0a2d99d10
commit 518d24e662
14 changed files with 1155 additions and 72 deletions

View File

@@ -36,6 +36,13 @@ args:
total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary
also with "-s=s0,s1,s2..." comma seperated int to set per batch seqlen(group-mode)
-s_k seqlen_k (including new key/value), -1 means equal to s (default:-1)
also with "-s_k=s0,s1,s2..." comma-separated ints to set seqlen per batch (group mode)
-s_qpad seqlen_q stride between 2 batches (group-mode optional) (default:-1)
Provide positive strides per-batch to simulate physical padding on Q
-s_kpad seqlen_k stride between 2 batches, currently used in group-mode only (default:-1)
for kv-cache case, each batch [1,s,h,d]/[1,h,s,d] can have a stride
along seqlen, instead of packed, same as xformer kv_padding,
must be greater than or equal to s_k
-d head dim for q, k (default:128)
-d_v head dim for v, -1 means equal to d (default:-1)
-scale_s scale factor of S. 0 means equal to 1/sqrt(hdim). (default:0)
@@ -76,11 +83,20 @@ args:
-repeat number of iterations to benchmark the kernel (default:20)
-json 0: No Json, 1: Dump Results in Json format (default:0)
-jsonfile json file name to dump results (default:fmha_fwd.json)
-q_eff_lens Batch-mode only: per-batch effective seqlen for Q (exclude PAD) (default:"")
Comma-separated list of length 'b'. If empty, no override
-kv_eff_lens Batch-mode only: per-batch effective seqlen for KV (exclude PAD) (default:"")
Comma-separated list of length 'b'. If empty, no override
```
Example 1: `./bin/tile_example_fmha_fwd -b=1 -h=16 -s=16384 -d=128` will run a fmha case with batch=1, nhead=16, sequence length=16384, hdim=128, fp16 case.
Example 2: `./bin/tile_example_fmha_fwd -b=1 -h=8 -s=16384 -d=64 -drop_prefs=1 -drop_seed=10 -drop_offset=1234` will run a fmha case with
batch=1, nhead=8, sequence length=16384, hdim=64, drop_seed=0 (in GPU memory), drop_offset=1234 (in GPU memory) fp16 case
## Padding Examples
Example 3 (Group mode with padding): `./bin/tile_example_fmha_fwd -mode=1 -b=2 -h=8 -s=1024,2048 -s_k=1024,2048 -s_qpad=1536,3072 -s_kpad=1536,3072 -d=128` will run group mode with 2 batches having different sequence lengths (1024, 2048) but physically padded to (1536, 3072) respectively.
Example 4 (Batch mode with effective lengths): `./bin/tile_example_fmha_fwd -mode=0 -b=2 -h=8 -s=2048 -s_k=2048 -d=128 -q_eff_lens=1024,1536 -kv_eff_lens=1024,1536` will run batch mode where all batches use 2048 as physical sequence length but have effective lengths of (1024, 1536) for Q and KV respectively.
## support features
Currently we are still in rapid development stage, so more features/optimizations will be coming soon.
@@ -128,6 +144,15 @@ Note FA use bottom-right by default to express swa case, here we require you exp
### dropout
TBD
### sequence padding and variable length support
We support sequence padding and variable-length processing in both batch and group modes fmha forward to handle real-world scenarios where sequences have different lengths.
**Group Mode Padding**: Use `-s_qpad` and `-s_kpad` to specify physical stride between batches, enabling padded layouts. Each batch can have different logical sequence lengths (`-s`, `-s_k`) but use larger physical strides for memory alignment.
**Batch Mode Variable Length**: Use `-q_eff_lens` and `-kv_eff_lens` to specify effective sequence lengths per batch. All batches share the same physical sequence length, but the kernel processes only the effective portions. This enables efficient variable-length attention without memory waste.
Both approaches optimize memory access patterns while supporting flexible sequence length requirements commonly found in transformer inference scenarios.
## FP8 experimental support
As described in [this blog](https://blog.hippoml.com/8bit-hippoattention-up-to-3x-faster-compared-to-flashattentionv2-8f9def90b482), we have an experimental support for fp8 fmha kernels, you can evaluate the performance by setting the arg `-prec=fp8` to the `tile_example_fmha_fwd`, on a gfx942 machine and ROCm 6.0+.