Add sequence padding and variable length support in fmha (#2932)

* [CK_TILE] Add sequence padding and variable length support in fmha (and v3)

 - Group Mode Padding: Introduces the `-s_qpad` argument to support
   physically padded layouts. Kernels now use padded start pointers
   (`seqstart_padded_*_ptr`) for memory addressing.

 - Batch Mode Variable Length: Adds `-q_eff_lens` and `-kv_eff_lens`
   arguments for efficient processing of variable-length sequences by
   passing cumulative effective lengths (`cu_seqlen_*_ptr`) to the kernel.

 - FMHA examples: Support padding and variable length both in
   group and batch mode. Dispatcher is updated as well (dispatch to
   kPadSeqLenK enabled pipeline).

 - New padding test cases: Add padding test cases to `smoke_test_fwd.sh` and
   `test_fmha_fwd.inc`, and add benchmarks to `benchmark_fwd.sh` and
   `benchmark_fwd_v3.sh` as well. These test cases and benchmarks
   specifically validate/benchmark the new padding and variable-length
   functionalities in both group and batch modes.

* [CK_TILE] Fix build error in fmha unit tests

* [CK_TILE] add mqa, gqa to sequence padding unit tests

* [CK_TILE] Reduce the number of padding seqlen unit tests in FMHA to avoid timeouts in CI

* [CK_TILE] remove unnecessary MakeKargs overload in FmhaFwdV3Kernel and FmhaFwdKernel
This commit is contained in:
Jeff Huang
2025-09-26 12:36:27 +08:00
committed by GitHub
parent b0a2d99d10
commit 518d24e662
14 changed files with 1155 additions and 72 deletions

View File

@@ -293,6 +293,11 @@ struct FmhaFwdKernel
ck_tile::index_t batch_stride_k;
ck_tile::index_t batch_stride_v;
ck_tile::index_t batch_stride_o;
// Optional cumulative sequence length pointers for batch mode
// If provided, they override seqlen_q / seqlen_k per-batch to skip tail padding.
const ck_tile::index_t* cu_seqlen_q_ptr = nullptr; // cumulative, length without PAD
const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr; // cumulative, length without PAD
};
struct FmhaFwdGroupModeKargs
@@ -312,6 +317,11 @@ struct FmhaFwdKernel
const int32_t* seqstart_q_ptr;
const int32_t* seqstart_k_ptr;
const int32_t* seqlen_k_ptr;
// Optional cumulative padded sequence starts (including PAD tokens)
// Used solely to compute memory offsets when sequences are physically padded.
const int32_t* seqstart_padded_q_ptr = nullptr;
const int32_t* seqstart_padded_k_ptr = nullptr;
};
using Kargs = std::conditional_t<kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs>;
@@ -368,7 +378,9 @@ struct FmhaFwdKernel
float p_drop,
bool s_randval,
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset)
drop_seed_offset,
const ck_tile::index_t* cu_seqlen_q_ptr = nullptr,
const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr)
{
Kargs kargs{{q_ptr,
k_ptr,
@@ -459,6 +471,8 @@ struct FmhaFwdKernel
kargs.init_logits_soft_cap(logits_soft_cap);
}
kargs.cu_seqlen_q_ptr = cu_seqlen_q_ptr;
kargs.cu_seqlen_kv_ptr = cu_seqlen_kv_ptr;
return kargs;
}
@@ -507,7 +521,9 @@ struct FmhaFwdKernel
ck_tile::index_t mask_type,
float p_drop,
bool s_randval,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
const std::tuple<uint64_t, uint64_t>& drop_seed_offset,
const ck_tile::index_t* cu_seqlen_q_ptr = nullptr,
const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr)
{
return MakeKargsImpl(
q_ptr,
@@ -552,7 +568,9 @@ struct FmhaFwdKernel
mask_type,
p_drop,
s_randval,
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
cu_seqlen_q_ptr,
cu_seqlen_kv_ptr);
}
// std::variant<> can't take in a list initializer, overload for backward compatibility
@@ -600,7 +618,9 @@ struct FmhaFwdKernel
ck_tile::index_t mask_type,
float p_drop,
bool s_randval,
const std::tuple<const void*, const void*>& drop_seed_offset)
const std::tuple<const void*, const void*>& drop_seed_offset,
const ck_tile::index_t* cu_seqlen_q_ptr = nullptr,
const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr)
{
return MakeKargsImpl(
q_ptr,
@@ -645,7 +665,9 @@ struct FmhaFwdKernel
mask_type,
p_drop,
s_randval,
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
cu_seqlen_q_ptr,
cu_seqlen_kv_ptr);
}
template <bool Cond = kIsGroupMode>
@@ -688,7 +710,9 @@ struct FmhaFwdKernel
float p_drop,
bool s_randval,
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset)
drop_seed_offset,
const void* seqstart_padded_q_ptr = nullptr,
const void* seqstart_padded_k_ptr = nullptr)
{
Kargs kargs{{q_ptr,
k_ptr,
@@ -780,6 +804,8 @@ struct FmhaFwdKernel
kargs.min_seqlen_q = min_seqlen_q;
}
kargs.seqstart_padded_q_ptr = reinterpret_cast<const int32_t*>(seqstart_padded_q_ptr);
kargs.seqstart_padded_k_ptr = reinterpret_cast<const int32_t*>(seqstart_padded_k_ptr);
return kargs;
}
@@ -823,7 +849,9 @@ struct FmhaFwdKernel
ck_tile::index_t min_seqlen_q,
float p_drop,
bool s_randval,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
const std::tuple<uint64_t, uint64_t>& drop_seed_offset,
const void* seqstart_padded_q_ptr = nullptr,
const void* seqstart_padded_k_ptr = nullptr)
{
return MakeKargsImpl(
q_ptr,
@@ -863,7 +891,9 @@ struct FmhaFwdKernel
min_seqlen_q,
p_drop,
s_randval,
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
seqstart_padded_q_ptr,
seqstart_padded_k_ptr);
}
// std::variant<> can't take in a list initializer, overload for backward compatibility
@@ -906,7 +936,9 @@ struct FmhaFwdKernel
ck_tile::index_t min_seqlen_q,
float p_drop,
bool s_randval,
const std::tuple<const void*, const void*>& drop_seed_offset)
const std::tuple<const void*, const void*>& drop_seed_offset,
const void* seqstart_padded_q_ptr = nullptr,
const void* seqstart_padded_k_ptr = nullptr)
{
return MakeKargsImpl(
q_ptr,
@@ -946,7 +978,9 @@ struct FmhaFwdKernel
min_seqlen_q,
p_drop,
s_randval,
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
seqstart_padded_q_ptr,
seqstart_padded_k_ptr);
}
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
@@ -1075,35 +1109,44 @@ struct FmhaFwdKernel
if constexpr(kIsGroupMode)
{
// get starting offset for each batch
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
// logical and physical (padded) starts
const long_index_t query_start_unpadded = kargs.seqstart_q_ptr[i_batch];
const long_index_t key_start_unpadded = kargs.seqstart_k_ptr[i_batch];
batch_offset_q = query_start * kargs.stride_q;
batch_offset_k = key_start * kargs.stride_k;
const long_index_t query_start_padded = kargs.seqstart_padded_q_ptr
? kargs.seqstart_padded_q_ptr[i_batch]
: query_start_unpadded;
const long_index_t key_start_padded = kargs.seqstart_padded_k_ptr
? kargs.seqstart_padded_k_ptr[i_batch]
: key_start_unpadded;
// DRAM base offsets use physical padded starts
batch_offset_q = query_start_padded * kargs.stride_q;
batch_offset_k = key_start_padded * kargs.stride_k;
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
batch_offset_v = key_start * kargs.stride_v;
batch_offset_v = key_start_padded * kargs.stride_v;
}
else
{
batch_offset_v = key_start;
batch_offset_v = key_start_padded;
}
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
batch_offset_bias = query_start * kargs.stride_bias;
batch_offset_bias = query_start_padded * kargs.stride_bias;
}
if constexpr(kStoreLSE)
{
batch_offset_lse = query_start;
// LSE stays indexed by unpadded starts
batch_offset_lse = query_start_unpadded;
}
if constexpr(kHasDropout)
{
batch_offset_randval = query_start * kargs.stride_randval;
batch_offset_randval = query_start_padded * kargs.stride_randval;
}
batch_offset_o = query_start * kargs.stride_o;
batch_offset_o = query_start_padded * kargs.stride_o;
// get real # queries & # keys under group mode
// real logical lengths (exclude PAD)
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
@@ -1115,8 +1158,7 @@ struct FmhaFwdKernel
}
}
// # of required blocks is different in each groups, terminate unnecessary blocks
// earlier
// terminate unnecessary blocks earlier
if(kargs.seqlen_q <= i_m0)
{
return;
@@ -1152,6 +1194,18 @@ struct FmhaFwdKernel
static_cast<long_index_t>(i_batch) * kargs.batch_stride_randval;
}
batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
// If cumulative seqlen pointers are provided, override per-batch effective lengths
if(kargs.cu_seqlen_q_ptr != nullptr)
{
kargs.seqlen_q =
kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
}
if(kargs.cu_seqlen_kv_ptr != nullptr)
{
kargs.seqlen_k =
kargs.cu_seqlen_kv_ptr[i_batch + 1] - kargs.cu_seqlen_kv_ptr[i_batch];
}
}
// for simplicity, batch stride we just modify the pointer
@@ -1550,26 +1604,35 @@ struct FmhaFwdKernel
if constexpr(kIsGroupMode)
{
// get starting offset for each batch
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
const long_index_t query_start_unpadded = kargs.seqstart_q_ptr[i_batch];
const long_index_t key_start_unpadded = kargs.seqstart_k_ptr[i_batch];
batch_offset_q = query_start * kargs.stride_q;
batch_offset_k = key_start * kargs.stride_k;
const long_index_t query_start_padded = kargs.seqstart_padded_q_ptr
? kargs.seqstart_padded_q_ptr[i_batch]
: query_start_unpadded;
const long_index_t key_start_padded = kargs.seqstart_padded_k_ptr
? kargs.seqstart_padded_k_ptr[i_batch]
: key_start_unpadded;
batch_offset_q = query_start_padded * kargs.stride_q;
batch_offset_k = key_start_padded * kargs.stride_k;
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
batch_offset_v = key_start * kargs.stride_v;
batch_offset_v = key_start_padded * kargs.stride_v;
}
else
{
batch_offset_v = key_start;
// col-major V: offset along seqlen dimension is scalar index
batch_offset_v = key_start_padded;
}
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
batch_offset_bias = query_start * kargs.stride_bias;
batch_offset_bias = query_start_padded * kargs.stride_bias;
}
batch_offset_lse = query_start;
batch_offset_o = query_start * kargs.stride_o;
// LSE layout is [nhead, total_seqlen], index by unpadded start
batch_offset_lse = query_start_unpadded;
batch_offset_o = query_start_padded * kargs.stride_o;
// get real # queries & # keys under group mode
kargs.seqlen_q = kargs.seqstart_q_ptr[i_batch + 1] - kargs.seqstart_q_ptr[i_batch];
@@ -1607,6 +1670,18 @@ struct FmhaFwdKernel
batch_offset_bias =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
}
// If cumulative seqlen pointers are provided, override per-batch effective lengths
if(kargs.cu_seqlen_q_ptr != nullptr)
{
kargs.seqlen_q =
kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
}
if(kargs.cu_seqlen_kv_ptr != nullptr)
{
kargs.seqlen_k =
kargs.cu_seqlen_kv_ptr[i_batch + 1] - kargs.cu_seqlen_kv_ptr[i_batch];
}
}
// for simplicity, batch stride we just modify the pointer

View File

@@ -100,6 +100,11 @@ struct FmhaFwdV3Kernel
ck_tile::index_t batch_stride_k;
ck_tile::index_t batch_stride_v;
ck_tile::index_t batch_stride_o;
// Optional cumulative sequence length pointers for batch mode
// If provided, they override seqlen_q / seqlen_k per-batch to skip tail padding.
const ck_tile::index_t* cu_seqlen_q_ptr = nullptr; // [batch+1]
const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr; // [batch+1]
};
struct FmhaFwdGroupModeKargs
@@ -110,6 +115,11 @@ struct FmhaFwdV3Kernel
const int32_t* seqstart_q_ptr;
const int32_t* seqstart_k_ptr;
const int32_t* seqlen_k_ptr;
// Optional cumulative padded sequence starts (including PAD tokens)
// Used solely to compute memory offsets when sequences are physically padded.
const int32_t* seqstart_padded_q_ptr = nullptr; // [batch+1]
const int32_t* seqstart_padded_k_ptr = nullptr; // [batch+1]
};
using Kargs = std::conditional_t<kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs>;
@@ -145,7 +155,9 @@ struct FmhaFwdV3Kernel
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
ck_tile::index_t remap_opt)
ck_tile::index_t remap_opt,
const ck_tile::index_t* cu_seqlen_q_ptr = nullptr,
const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr)
{
Kargs kargs{{q_ptr,
k_ptr,
@@ -187,6 +199,8 @@ struct FmhaFwdV3Kernel
kargs.batch_stride_lse = batch_stride_lse;
}
kargs.cu_seqlen_q_ptr = cu_seqlen_q_ptr;
kargs.cu_seqlen_kv_ptr = cu_seqlen_kv_ptr;
return kargs;
}
@@ -217,7 +231,9 @@ struct FmhaFwdV3Kernel
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
ck_tile::index_t remap_opt)
ck_tile::index_t remap_opt,
const void* seqstart_padded_q_ptr = nullptr,
const void* seqstart_padded_k_ptr = nullptr)
{
Kargs kargs{{q_ptr,
k_ptr,
@@ -257,6 +273,8 @@ struct FmhaFwdV3Kernel
kargs.nhead_stride_lse = nhead_stride_lse;
}
kargs.seqstart_padded_q_ptr = reinterpret_cast<const int32_t*>(seqstart_padded_q_ptr);
kargs.seqstart_padded_k_ptr = reinterpret_cast<const int32_t*>(seqstart_padded_k_ptr);
return kargs;
}
@@ -373,18 +391,26 @@ struct FmhaFwdV3Kernel
if constexpr(kIsGroupMode)
{
// get starting offset for each batch
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
const long_index_t query_start_unpadded = kargs.seqstart_q_ptr[i_batch];
const long_index_t key_start_unpadded = kargs.seqstart_k_ptr[i_batch];
batch_offset_q = query_start * kargs.stride_q;
batch_offset_k = key_start * kargs.stride_k;
batch_offset_v = key_start * kargs.stride_v;
const long_index_t query_start_padded = kargs.seqstart_padded_q_ptr
? kargs.seqstart_padded_q_ptr[i_batch]
: query_start_unpadded;
const long_index_t key_start_padded = kargs.seqstart_padded_k_ptr
? kargs.seqstart_padded_k_ptr[i_batch]
: key_start_unpadded;
batch_offset_q = query_start_padded * kargs.stride_q;
batch_offset_k = key_start_padded * kargs.stride_k;
batch_offset_v = key_start_padded * kargs.stride_v;
if constexpr(kStoreLSE)
{
batch_offset_lse = query_start;
// LSE layout is [nhead, total_seqlen], index by unpadded start
batch_offset_lse = query_start_unpadded;
}
batch_offset_o = query_start * kargs.stride_o;
batch_offset_o = query_start_padded * kargs.stride_o;
// get real # queries & # keys under group mode
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
@@ -417,6 +443,18 @@ struct FmhaFwdV3Kernel
batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
}
batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
// If cumulative seqlen pointers are provided, override per-batch effective lengths
if(kargs.cu_seqlen_q_ptr != nullptr)
{
kargs.seqlen_q =
kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
}
if(kargs.cu_seqlen_kv_ptr != nullptr)
{
kargs.seqlen_k =
kargs.cu_seqlen_kv_ptr[i_batch + 1] - kargs.cu_seqlen_kv_ptr[i_batch];
}
}
// for simplicity, batch stride we just modify the pointer