Add sequence padding and variable length support in fmha (#2932)
* [CK_TILE] Add sequence padding and variable length support in fmha (and v3)
  - Group-mode padding: introduces the `-s_qpad` argument to support physically padded layouts. Kernels now use padded start pointers (`seqstart_padded_*_ptr`) for memory addressing.
  - Batch-mode variable length: adds the `-q_eff_lens` and `-kv_eff_lens` arguments for efficient processing of variable-length sequences by passing cumulative effective lengths (`cu_seqlen_*_ptr`) to the kernel.
  - FMHA examples: support padding and variable length in both group and batch mode. The dispatcher is updated as well (dispatch to the kPadSeqLenK-enabled pipeline).
  - New padding test cases: add padding test cases to `smoke_test_fwd.sh` and `test_fmha_fwd.inc`, and benchmarks to `benchmark_fwd.sh` and `benchmark_fwd_v3.sh`. These specifically validate/benchmark the new padding and variable-length functionality in both group and batch modes.
* [CK_TILE] Fix build error in fmha unit tests
* [CK_TILE] Add mqa, gqa to sequence padding unit tests
* [CK_TILE] Reduce the number of padding seqlen unit tests in FMHA to avoid timeouts in CI
* [CK_TILE] Remove unnecessary MakeKargs overload in FmhaFwdV3Kernel and FmhaFwdKernel
@@ -36,6 +36,13 @@ args:
                  total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary
                  also with "-s=s0,s1,s2..." comma-separated ints to set per-batch seqlen (group mode)
  -s_k            seqlen_k (including new key/value), -1 means equal to s (default:-1)
                  also with "-s_k=s0,s1,s2..." comma-separated ints to set seqlen per batch (group mode)
  -s_qpad         seqlen_q stride between 2 batches (group-mode optional) (default:-1)
                  provide positive per-batch strides to simulate physical padding on Q
  -s_kpad         seqlen_k stride between 2 batches, currently used in group mode only (default:-1)
                  for the kv-cache case, each batch [1,s,h,d]/[1,h,s,d] can have a stride
                  along seqlen instead of being packed, same as xformers kv_padding;
                  must be greater than or equal to s_k
  -d              head dim for q, k (default:128)
  -d_v            head dim for v, -1 means equal to d (default:-1)
  -scale_s        scale factor of S, 0 means equal to 1/sqrt(hdim) (default:0)
@@ -76,11 +83,20 @@ args:
  -repeat         number of iterations to benchmark the kernel (default:20)
  -json           0: no json, 1: dump results in json format (default:0)
  -jsonfile       json file name to dump results (default:fmha_fwd.json)
  -q_eff_lens     batch mode only: per-batch effective seqlen for Q, excluding PAD (default:"")
                  comma-separated list of length 'b'; if empty, no override
  -kv_eff_lens    batch mode only: per-batch effective seqlen for KV, excluding PAD (default:"")
                  comma-separated list of length 'b'; if empty, no override
```
Example 1: `./bin/tile_example_fmha_fwd -b=1 -h=16 -s=16384 -d=128` will run an fmha case with batch=1, nhead=16, sequence length=16384, hdim=128, fp16.

Example 2: `./bin/tile_example_fmha_fwd -b=1 -h=8 -s=16384 -d=64 -drop_prefs=1 -drop_seed=10 -drop_offset=1234` will run an fmha case with batch=1, nhead=8, sequence length=16384, hdim=64, drop_seed=10 (in GPU memory), drop_offset=1234 (in GPU memory), fp16.

## Padding Examples
Example 3 (group mode with padding): `./bin/tile_example_fmha_fwd -mode=1 -b=2 -h=8 -s=1024,2048 -s_k=1024,2048 -s_qpad=1536,3072 -s_kpad=1536,3072 -d=128` will run group mode with 2 batches whose logical sequence lengths (1024, 2048) are physically padded to strides of (1536, 3072) respectively.

Example 4 (batch mode with effective lengths): `./bin/tile_example_fmha_fwd -mode=0 -b=2 -h=8 -s=2048 -s_k=2048 -d=128 -q_eff_lens=1024,1536 -kv_eff_lens=1024,1536` will run batch mode where every batch uses 2048 as the physical sequence length but has effective lengths of (1024, 1536) for Q and KV respectively.

## support features
Currently we are still in a rapid development stage, so more features/optimizations will be coming soon.

@@ -128,6 +144,15 @@ Note FA use bottom-right by default to express swa case, here we require you exp
### dropout
TBD

### sequence padding and variable length support
We support sequence padding and variable-length processing in the fmha forward kernel, in both batch and group mode, to handle real-world scenarios where sequences have different lengths.

**Group Mode Padding**: Use `-s_qpad` and `-s_kpad` to specify the physical stride between batches, enabling padded layouts. Each batch can have a different logical sequence length (`-s`, `-s_k`) while using a larger physical stride for memory alignment; see the sketch below.

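A minimal sketch of how the padded start offsets could be derived on the host, using the numbers from Example 3 above. The helper `make_seqstart` and the array names are illustrative: they merely mirror the `seqstart_padded_*_ptr` arguments mentioned in the commit message, not the example's actual host code.

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Illustrative helper: cumulative start offsets from per-batch lengths.
std::vector<int32_t> make_seqstart(const std::vector<int32_t>& lens)
{
    std::vector<int32_t> seqstart(lens.size() + 1, 0);
    for(std::size_t b = 0; b < lens.size(); ++b)
        seqstart[b + 1] = seqstart[b] + lens[b];
    return seqstart;
}

int main()
{
    // Example 3: -s=1024,2048 -s_qpad=1536,3072
    std::vector<int32_t> seqlen_q{1024, 2048}; // logical lengths
    std::vector<int32_t> stride_q{1536, 3072}; // physical strides (>= logical)

    auto seqstart_q        = make_seqstart(seqlen_q); // {0, 1024, 3072}
    auto seqstart_q_padded = make_seqstart(stride_q); // {0, 1536, 4608}

    // Addressing uses the padded starts while the logical lengths still bound
    // the computation: batch b occupies rows
    // [seqstart_q_padded[b], seqstart_q_padded[b] + seqlen_q[b]); the rows up
    // to seqstart_q_padded[b+1] are physical padding.
    for(std::size_t b = 0; b < seqlen_q.size(); ++b)
        std::printf("batch %zu: packed start %d, padded Q rows [%d, %d), pad up to %d\n",
                    b, seqstart_q[b], seqstart_q_padded[b],
                    seqstart_q_padded[b] + seqlen_q[b], seqstart_q_padded[b + 1]);
    return 0;
}
```
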
**Batch Mode Variable Length**: Use `-q_eff_lens` and `-kv_eff_lens` to specify effective sequence lengths per batch. All batches share the same physical sequence length, but the kernel processes only the effective portion of each one, enabling variable-length attention without wasting computation on padding tokens; see the sketch below.

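Likewise, a minimal sketch of the cumulative effective-length arrays for Example 4. Again, `make_cu_seqlen` is a hypothetical helper named after the `cu_seqlen_*_ptr` kernel arguments from the commit message.

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Illustrative helper: cumulative effective lengths, one entry per batch
// boundary, in the spirit of what the cu_seqlen_*_ptr arguments carry.
std::vector<int32_t> make_cu_seqlen(const std::vector<int32_t>& eff_lens)
{
    std::vector<int32_t> cu(eff_lens.size() + 1, 0);
    for(std::size_t b = 0; b < eff_lens.size(); ++b)
        cu[b + 1] = cu[b] + eff_lens[b];
    return cu;
}

int main()
{
    // Example 4: physical seqlen 2048 per batch, -q_eff_lens=1024,1536
    auto cu_seqlen_q = make_cu_seqlen({1024, 1536}); // {0, 1024, 2560}

    // Batch b's effective length is recovered as the difference of adjacent
    // entries; rows beyond it (up to the physical 2048) are skipped.
    for(std::size_t b = 0; b + 1 < cu_seqlen_q.size(); ++b)
        std::printf("batch %zu: effective seqlen_q = %d of 2048\n",
                    b, cu_seqlen_q[b + 1] - cu_seqlen_q[b]);
    return 0;
}
```
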
Both approaches preserve efficient memory access patterns while supporting the flexible sequence-length requirements common in transformer inference.

## FP8 experimental support
As described in [this blog](https://blog.hippoml.com/8bit-hippoattention-up-to-3x-faster-compared-to-flashattentionv2-8f9def90b482), we have experimental support for fp8 fmha kernels. You can evaluate the performance by passing `-prec=fp8` to `tile_example_fmha_fwd` on a gfx942 machine with ROCm 6.0+.