[CK_TILE] Add sequence padding and variable length support in fmha (and v3) (#2851)

* [CK_TILE] Add sequence padding and variable length support in fmha (and v3)

 - Group Mode Padding: Introduces the `-s_qpad` argument to support
   physically padded layouts. Kernels now use padded start pointers
   (`seqstart_padded_*_ptr`) for memory addressing (see the group-mode
   sketch after this list).

 - Batch Mode Variable Length: Adds `-q_eff_lens` and `-kv_eff_lens`
   arguments for efficient processing of variable-length sequences by
   passing cumulative effective lengths (`cu_seqlen_*_ptr`) to the kernel
   (see the batch-mode sketch after this list).

 - FMHA examples: Support padding and variable length in both group and
   batch modes. The dispatcher is updated as well (it now dispatches to the
   `kPadSeqLenK`-enabled pipeline).

 - New padding test cases: Add padding test cases to `smoke_test_fwd.sh`,
   and add benchmarks to `benchmark_fwd.sh` and `benchmark_fwd_v3.sh` as well.
   These specifically validate and benchmark the new padding and
   variable-length functionality in both group and batch modes.
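
To make the group-mode addressing concrete, here is a minimal host-side sketch
of how the two cumulative-start arrays could be built. The helper name and the
example lengths are illustrative only and are not part of this commit.

```cpp
#include <cstdint>
#include <numeric>
#include <vector>

// Exclusive prefix sum: per-batch lengths -> cumulative sequence starts ([batch+1] entries).
std::vector<int32_t> make_seqstarts(const std::vector<int32_t>& lens)
{
    std::vector<int32_t> starts(lens.size() + 1, 0);
    std::partial_sum(lens.begin(), lens.end(), starts.begin() + 1);
    return starts;
}

// Two sequences with logical lengths {3, 5}, each physically padded to 8 rows:
//   seqstart_q        = make_seqstarts({3, 5}); // {0, 3, 8}  logical starts (seqlen / LSE indexing)
//   seqstart_padded_q = make_seqstarts({8, 8}); // {0, 8, 16} physical starts (Q/K/V/O memory offsets)
```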
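
For batch mode, a similarly minimal sketch (values are illustrative) of the
cumulative effective-length arrays and of how the kernel recovers each batch's
length:

```cpp
#include <cstdint>
#include <vector>

// Two batches allocated with seqlen_q = seqlen_k = 128, but only 120 and 37 tokens are valid.
// Entries are cumulative and exclude PAD tokens; the kernel-side pointer type is ck_tile::index_t.
std::vector<int32_t> cu_seqlen_q  = {0, 120, 157};
std::vector<int32_t> cu_seqlen_kv = {0, 120, 157};

// Inside the kernel (as in the diff below), batch i's effective length is a difference:
//   seqlen_q = cu_seqlen_q_ptr[i + 1] - cu_seqlen_q_ptr[i]; // 120 for i == 0, 37 for i == 1
// Passing nullptr for either pointer keeps the original fixed-length behavior.
```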

* [CK_TILE] Fix build error in fmha unit tests

---------

Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
Co-authored-by: Yi DING <yi.ding@amd.com>
Author: Jeff Huang
Date: 2025-09-19 17:36:49 +08:00 (committed via GitHub)
Parent: 2aec38f9ec
Commit: 86dd59cd01
13 changed files with 1032 additions and 60 deletions

@@ -291,6 +291,11 @@ struct FmhaFwdKernel
ck_tile::index_t batch_stride_k;
ck_tile::index_t batch_stride_v;
ck_tile::index_t batch_stride_o;
// Optional cumulative sequence length pointers for batch mode
// If provided, they override seqlen_q / seqlen_k per-batch to skip tail padding.
const ck_tile::index_t* cu_seqlen_q_ptr = nullptr; // cumulative, length without PAD
const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr; // cumulative, length without PAD
};
struct FmhaFwdGroupModeKargs
@@ -310,6 +315,11 @@ struct FmhaFwdKernel
const int32_t* seqstart_q_ptr;
const int32_t* seqstart_k_ptr;
const int32_t* seqlen_k_ptr;
// Optional cumulative padded sequence starts (including PAD tokens)
// Used solely to compute memory offsets when sequences are physically padded.
const int32_t* seqstart_padded_q_ptr = nullptr;
const int32_t* seqstart_padded_k_ptr = nullptr;
};
using Kargs = std::conditional_t<kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs>;
@@ -460,6 +470,105 @@ struct FmhaFwdKernel
return kargs;
}
// Overload: Batch mode with optional cu_seqlen pointers (unpadded cumulative lengths)
template <bool Cond = !kIsGroupMode>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargsImpl(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
const void* bias_ptr,
void* rand_val_ptr,
void* lse_ptr,
void* o_ptr,
ck_tile::index_t seqlen_q,
ck_tile::index_t seqlen_k,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
float scale_s,
float scale_p,
float scale_o,
float logits_soft_cap,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
ck_tile::index_t stride_bias,
ck_tile::index_t stride_randval,
ck_tile::index_t stride_o,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_bias,
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t batch_stride_q,
ck_tile::index_t batch_stride_k,
ck_tile::index_t batch_stride_v,
ck_tile::index_t batch_stride_bias,
ck_tile::index_t batch_stride_randval,
ck_tile::index_t batch_stride_lse,
ck_tile::index_t batch_stride_o,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
float p_drop,
bool s_randval,
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset,
const ck_tile::index_t* cu_seqlen_q_ptr,
const ck_tile::index_t* cu_seqlen_kv_ptr)
{
auto kargs = MakeKargsImpl(q_ptr,
k_ptr,
v_ptr,
bias_ptr,
rand_val_ptr,
lse_ptr,
o_ptr,
seqlen_q,
seqlen_k,
hdim_q,
hdim_v,
num_head_q,
nhead_ratio_qk,
scale_s,
scale_p,
scale_o,
logits_soft_cap,
stride_q,
stride_k,
stride_v,
stride_bias,
stride_randval,
stride_o,
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
nhead_stride_bias,
nhead_stride_randval,
nhead_stride_lse,
nhead_stride_o,
batch_stride_q,
batch_stride_k,
batch_stride_v,
batch_stride_bias,
batch_stride_randval,
batch_stride_lse,
batch_stride_o,
window_size_left,
window_size_right,
mask_type,
p_drop,
s_randval,
drop_seed_offset);
kargs.cu_seqlen_q_ptr = cu_seqlen_q_ptr;
kargs.cu_seqlen_kv_ptr = cu_seqlen_kv_ptr;
return kargs;
}
// std::variant<> can't take in a list initializer, overload for backward compatibility
template <bool Cond = !kIsGroupMode>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
@@ -781,6 +890,95 @@ struct FmhaFwdKernel
return kargs;
}
// Overload: Group mode with optional padded seqstarts for memory offsets
template <bool Cond = kIsGroupMode>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargsImpl(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
const void* bias_ptr,
void* rand_val_ptr,
void* lse_ptr,
void* o_ptr,
const void* seqstart_q_ptr,
const void* seqstart_k_ptr,
const void* seqlen_k_ptr,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
float scale_s,
float scale_p,
float scale_o,
float logits_soft_cap,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
ck_tile::index_t stride_bias,
ck_tile::index_t stride_randval,
ck_tile::index_t stride_o,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_bias,
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
ck_tile::index_t min_seqlen_q,
float p_drop,
bool s_randval,
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset,
const void* seqstart_padded_q_ptr,
const void* seqstart_padded_k_ptr)
{
auto kargs = MakeKargsImpl(q_ptr,
k_ptr,
v_ptr,
bias_ptr,
rand_val_ptr,
lse_ptr,
o_ptr,
seqstart_q_ptr,
seqstart_k_ptr,
seqlen_k_ptr,
hdim_q,
hdim_v,
num_head_q,
nhead_ratio_qk,
scale_s,
scale_p,
scale_o,
logits_soft_cap,
stride_q,
stride_k,
stride_v,
stride_bias,
stride_randval,
stride_o,
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
nhead_stride_bias,
nhead_stride_randval,
nhead_stride_lse,
nhead_stride_o,
window_size_left,
window_size_right,
mask_type,
min_seqlen_q,
p_drop,
s_randval,
drop_seed_offset);
kargs.seqstart_padded_q_ptr = reinterpret_cast<const int32_t*>(seqstart_padded_q_ptr);
kargs.seqstart_padded_k_ptr = reinterpret_cast<const int32_t*>(seqstart_padded_k_ptr);
return kargs;
}
// std::variant<> can't take in a list initializer, overload for backward compatibility
template <bool Cond = kIsGroupMode>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
@@ -1073,35 +1271,44 @@ struct FmhaFwdKernel
if constexpr(kIsGroupMode)
{
// get starting offset for each batch
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
// logical and physical (padded) starts
const long_index_t query_start_unpadded = kargs.seqstart_q_ptr[i_batch];
const long_index_t key_start_unpadded = kargs.seqstart_k_ptr[i_batch];
batch_offset_q = query_start * kargs.stride_q;
batch_offset_k = key_start * kargs.stride_k;
const long_index_t query_start_padded = kargs.seqstart_padded_q_ptr
? kargs.seqstart_padded_q_ptr[i_batch]
: query_start_unpadded;
const long_index_t key_start_padded = kargs.seqstart_padded_k_ptr
? kargs.seqstart_padded_k_ptr[i_batch]
: key_start_unpadded;
// DRAM base offsets use physical padded starts
batch_offset_q = query_start_padded * kargs.stride_q;
batch_offset_k = key_start_padded * kargs.stride_k;
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
batch_offset_v = key_start * kargs.stride_v;
batch_offset_v = key_start_padded * kargs.stride_v;
}
else
{
batch_offset_v = key_start;
batch_offset_v = key_start_padded;
}
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
batch_offset_bias = query_start * kargs.stride_bias;
batch_offset_bias = query_start_padded * kargs.stride_bias;
}
if constexpr(kStoreLSE)
{
batch_offset_lse = query_start;
// LSE stays indexed by unpadded starts
batch_offset_lse = query_start_unpadded;
}
if constexpr(kHasDropout)
{
batch_offset_randval = query_start * kargs.stride_randval;
batch_offset_randval = query_start_padded * kargs.stride_randval;
}
batch_offset_o = query_start * kargs.stride_o;
batch_offset_o = query_start_padded * kargs.stride_o;
// get real # queries & # keys under group mode
// real logical lengths (exclude PAD)
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
@@ -1113,8 +1320,7 @@ struct FmhaFwdKernel
}
}
// # of required blocks is different in each groups, terminate unnecessary blocks
// earlier
// terminate unnecessary blocks earlier
if(kargs.seqlen_q <= i_m0)
{
return;
@@ -1150,6 +1356,18 @@ struct FmhaFwdKernel
static_cast<long_index_t>(i_batch) * kargs.batch_stride_randval;
}
batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
// If cumulative seqlen pointers are provided, override per-batch effective lengths
if(kargs.cu_seqlen_q_ptr != nullptr)
{
kargs.seqlen_q =
kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
}
if(kargs.cu_seqlen_kv_ptr != nullptr)
{
kargs.seqlen_k =
kargs.cu_seqlen_kv_ptr[i_batch + 1] - kargs.cu_seqlen_kv_ptr[i_batch];
}
}
// for simplicity, batch stride we just modify the pointer
@@ -1548,26 +1766,35 @@ struct FmhaFwdKernel
if constexpr(kIsGroupMode)
{
// get starting offset for each batch
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
const long_index_t query_start_unpadded = kargs.seqstart_q_ptr[i_batch];
const long_index_t key_start_unpadded = kargs.seqstart_k_ptr[i_batch];
batch_offset_q = query_start * kargs.stride_q;
batch_offset_k = key_start * kargs.stride_k;
const long_index_t query_start_padded = kargs.seqstart_padded_q_ptr
? kargs.seqstart_padded_q_ptr[i_batch]
: query_start_unpadded;
const long_index_t key_start_padded = kargs.seqstart_padded_k_ptr
? kargs.seqstart_padded_k_ptr[i_batch]
: key_start_unpadded;
batch_offset_q = query_start_padded * kargs.stride_q;
batch_offset_k = key_start_padded * kargs.stride_k;
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
batch_offset_v = key_start * kargs.stride_v;
batch_offset_v = key_start_padded * kargs.stride_v;
}
else
{
batch_offset_v = key_start;
// col-major V: offset along seqlen dimension is scalar index
batch_offset_v = key_start_padded;
}
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
batch_offset_bias = query_start * kargs.stride_bias;
batch_offset_bias = query_start_padded * kargs.stride_bias;
}
batch_offset_lse = query_start;
batch_offset_o = query_start * kargs.stride_o;
// LSE layout is [nhead, total_seqlen], index by unpadded start
batch_offset_lse = query_start_unpadded;
batch_offset_o = query_start_padded * kargs.stride_o;
// get real # queries & # keys under group mode
kargs.seqlen_q = kargs.seqstart_q_ptr[i_batch + 1] - kargs.seqstart_q_ptr[i_batch];
@@ -1605,6 +1832,18 @@ struct FmhaFwdKernel
batch_offset_bias =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
}
// If cumulative seqlen pointers are provided, override per-batch effective lengths
if(kargs.cu_seqlen_q_ptr != nullptr)
{
kargs.seqlen_q =
kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
}
if(kargs.cu_seqlen_kv_ptr != nullptr)
{
kargs.seqlen_k =
kargs.cu_seqlen_kv_ptr[i_batch + 1] - kargs.cu_seqlen_kv_ptr[i_batch];
}
}
// for simplicity, batch stride we just modify the pointer

@@ -100,6 +100,11 @@ struct FmhaFwdV3Kernel
ck_tile::index_t batch_stride_k;
ck_tile::index_t batch_stride_v;
ck_tile::index_t batch_stride_o;
// Optional cumulative sequence length pointers for batch mode
// If provided, they override seqlen_q / seqlen_k per-batch to skip tail padding.
const ck_tile::index_t* cu_seqlen_q_ptr = nullptr; // [batch+1]
const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr; // [batch+1]
};
struct FmhaFwdGroupModeKargs
@@ -110,6 +115,11 @@ struct FmhaFwdV3Kernel
const int32_t* seqstart_q_ptr;
const int32_t* seqstart_k_ptr;
const int32_t* seqlen_k_ptr;
// Optional cumulative padded sequence starts (including PAD tokens)
// Used solely to compute memory offsets when sequences are physically padded.
const int32_t* seqstart_padded_q_ptr = nullptr; // [batch+1]
const int32_t* seqstart_padded_k_ptr = nullptr; // [batch+1]
};
using Kargs = std::conditional_t<kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs>;
@@ -190,6 +200,78 @@ struct FmhaFwdV3Kernel
return kargs;
}
// Overload: Batch mode with optional cu_seqlen pointers
template <bool Cond = !kIsGroupMode>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
void* lse_ptr,
void* o_ptr,
ck_tile::index_t seqlen_q,
ck_tile::index_t seqlen_k,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
float scale_s,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
ck_tile::index_t stride_o,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t batch_stride_q,
ck_tile::index_t batch_stride_k,
ck_tile::index_t batch_stride_v,
ck_tile::index_t batch_stride_lse,
ck_tile::index_t batch_stride_o,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
ck_tile::index_t remap_opt,
const ck_tile::index_t* cu_seqlen_q_ptr,
const ck_tile::index_t* cu_seqlen_kv_ptr)
{
auto kargs = MakeKargs(q_ptr,
k_ptr,
v_ptr,
lse_ptr,
o_ptr,
seqlen_q,
seqlen_k,
hdim_q,
hdim_v,
num_head_q,
nhead_ratio_qk,
scale_s,
stride_q,
stride_k,
stride_v,
stride_o,
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
nhead_stride_lse,
nhead_stride_o,
batch_stride_q,
batch_stride_k,
batch_stride_v,
batch_stride_lse,
batch_stride_o,
window_size_left,
window_size_right,
mask_type,
remap_opt);
kargs.cu_seqlen_q_ptr = cu_seqlen_q_ptr;
kargs.cu_seqlen_kv_ptr = cu_seqlen_kv_ptr;
return kargs;
}
template <bool Cond = kIsGroupMode>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr,
@@ -260,6 +342,70 @@ struct FmhaFwdV3Kernel
return kargs;
}
// Overload: Group mode with optional padded seqstarts for memory offsets
template <bool Cond = kIsGroupMode>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
void* lse_ptr,
void* o_ptr,
const void* seqstart_q_ptr,
const void* seqstart_k_ptr,
const void* seqlen_k_ptr,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
float scale_s,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
ck_tile::index_t stride_o,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
ck_tile::index_t remap_opt,
const void* seqstart_padded_q_ptr,
const void* seqstart_padded_k_ptr)
{
auto kargs = MakeKargs(q_ptr,
k_ptr,
v_ptr,
lse_ptr,
o_ptr,
seqstart_q_ptr,
seqstart_k_ptr,
seqlen_k_ptr,
hdim_q,
hdim_v,
num_head_q,
nhead_ratio_qk,
scale_s,
stride_q,
stride_k,
stride_v,
stride_o,
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
nhead_stride_lse,
nhead_stride_o,
window_size_left,
window_size_right,
mask_type,
remap_opt);
kargs.seqstart_padded_q_ptr = reinterpret_cast<const int32_t*>(seqstart_padded_q_ptr);
kargs.seqstart_padded_k_ptr = reinterpret_cast<const int32_t*>(seqstart_padded_k_ptr);
return kargs;
}
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
ck_tile::index_t nhead_,
ck_tile::index_t seqlen_q_,
@@ -373,18 +519,26 @@ struct FmhaFwdV3Kernel
if constexpr(kIsGroupMode)
{
// get starting offset for each batch
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
const long_index_t query_start_unpadded = kargs.seqstart_q_ptr[i_batch];
const long_index_t key_start_unpadded = kargs.seqstart_k_ptr[i_batch];
batch_offset_q = query_start * kargs.stride_q;
batch_offset_k = key_start * kargs.stride_k;
batch_offset_v = key_start * kargs.stride_v;
const long_index_t query_start_padded = kargs.seqstart_padded_q_ptr
? kargs.seqstart_padded_q_ptr[i_batch]
: query_start_unpadded;
const long_index_t key_start_padded = kargs.seqstart_padded_k_ptr
? kargs.seqstart_padded_k_ptr[i_batch]
: key_start_unpadded;
batch_offset_q = query_start_padded * kargs.stride_q;
batch_offset_k = key_start_padded * kargs.stride_k;
batch_offset_v = key_start_padded * kargs.stride_v;
if constexpr(kStoreLSE)
{
batch_offset_lse = query_start;
// LSE layout is [nhead, total_seqlen], index by unpadded start
batch_offset_lse = query_start_unpadded;
}
batch_offset_o = query_start * kargs.stride_o;
batch_offset_o = query_start_padded * kargs.stride_o;
// get real # queries & # keys under group mode
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
@@ -417,6 +571,18 @@ struct FmhaFwdV3Kernel
batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
}
batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
// If cumulative seqlen pointers are provided, override per-batch effective lengths
if(kargs.cu_seqlen_q_ptr != nullptr)
{
kargs.seqlen_q =
kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
}
if(kargs.cu_seqlen_kv_ptr != nullptr)
{
kargs.seqlen_k =
kargs.cu_seqlen_kv_ptr[i_batch + 1] - kargs.cu_seqlen_kv_ptr[i_batch];
}
}
// for simplicity, batch stride we just modify the pointer