Mirror of https://github.com/ROCm/composable_kernel.git, synced 2026-04-20 14:59:17 +00:00
Add sequence padding and variable length support in fmha (#2932)
* [CK_TILE] Add sequence padding and variable length support in fmha (and v3)
  - Group Mode Padding: introduces the `-s_qpad` argument to support physically padded layouts. Kernels now use padded start pointers (`seqstart_padded_*_ptr`) for memory addressing.
  - Batch Mode Variable Length: adds `-q_eff_lens` and `-kv_eff_lens` arguments for efficient processing of variable-length sequences by passing cumulative effective lengths (`cu_seqlen_*_ptr`) to the kernel.
  - FMHA examples: support padding and variable length in both group and batch mode. The dispatcher is updated as well (dispatch to the kPadSeqLenK-enabled pipeline).
  - New padding test cases: add padding test cases to `smoke_test_fwd.sh` and `test_fmha_fwd.inc`, and benchmarks to `benchmark_fwd.sh` and `benchmark_fwd_v3.sh`. These specifically validate and benchmark the new padding and variable-length functionality in both group and batch modes.
* [CK_TILE] Fix build error in fmha unit tests
* [CK_TILE] add mqa, gqa to sequence padding unit tests
* [CK_TILE] Reduce the number of padding seqlen unit tests in FMHA to avoid timeouts in CI
* [CK_TILE] remove unnecessary MakeKargs overload in FmhaFwdV3Kernel and FmhaFwdKernel
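For context, a minimal host-side sketch of how the new trailing kernel arguments might be populated (variable names and values here are illustrative, not part of this patch): per-batch effective lengths are prefix-summed into a [batch+1] cumulative array, whose device pointer is then passed as the new optional trailing argument of MakeKargs.

    // Hedged sketch, batch mode: batch = 2, effective Q lengths {5, 7} (illustrative).
    // cu_seqlen_q holds cumulative *effective* lengths, so cu_seqlen_q = {0, 5, 12}.
    // The kernel expects const ck_tile::index_t* (int32_t), resident in device memory.
    #include <cstdint>
    #include <vector>
    std::vector<int32_t> q_eff_lens{5, 7};
    std::vector<int32_t> cu_seqlen_q(q_eff_lens.size() + 1, 0);
    for(std::size_t b = 0; b < q_eff_lens.size(); ++b)
        cu_seqlen_q[b + 1] = cu_seqlen_q[b] + q_eff_lens[b];
    // ... copied to device, then passed as the trailing MakeKargs arguments:
    //     ..., drop_seed_offset, cu_seqlen_q_dev_ptr, cu_seqlen_kv_dev_ptr);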
@@ -293,6 +293,11 @@ struct FmhaFwdKernel
         ck_tile::index_t batch_stride_k;
         ck_tile::index_t batch_stride_v;
         ck_tile::index_t batch_stride_o;
+
+        // Optional cumulative sequence length pointers for batch mode
+        // If provided, they override seqlen_q / seqlen_k per-batch to skip tail padding.
+        const ck_tile::index_t* cu_seqlen_q_ptr = nullptr;  // cumulative, length without PAD
+        const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr; // cumulative, length without PAD
     };
 
     struct FmhaFwdGroupModeKargs
@@ -312,6 +317,11 @@ struct FmhaFwdKernel
         const int32_t* seqstart_q_ptr;
         const int32_t* seqstart_k_ptr;
         const int32_t* seqlen_k_ptr;
+
+        // Optional cumulative padded sequence starts (including PAD tokens)
+        // Used solely to compute memory offsets when sequences are physically padded.
+        const int32_t* seqstart_padded_q_ptr = nullptr;
+        const int32_t* seqstart_padded_k_ptr = nullptr;
     };
 
     using Kargs = std::conditional_t<kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs>;
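To make the two group-mode arrays concrete, a small illustrative example (numbers invented for exposition): with per-batch logical lengths {3, 5} and every sequence physically padded to 8 tokens, `seqstart_q_ptr` indexes the packed logical view while `seqstart_padded_q_ptr` gives the physical memory offsets.

    // Illustrative group-mode layout, pad-to-8:
    //   seqstart_q_ptr        = {0, 3, 8}   cumulative logical lengths (exclude PAD)
    //   seqstart_padded_q_ptr = {0, 8, 16}  cumulative physical starts (include PAD)
    // Leaving seqstart_padded_q_ptr null keeps the old packed behaviour,
    // since the kernel falls back to the unpadded starts.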
@@ -368,7 +378,9 @@ struct FmhaFwdKernel
         float p_drop,
         bool s_randval,
         std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
-            drop_seed_offset)
+            drop_seed_offset,
+        const ck_tile::index_t* cu_seqlen_q_ptr = nullptr,
+        const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr)
     {
         Kargs kargs{{q_ptr,
                      k_ptr,
@@ -459,6 +471,8 @@ struct FmhaFwdKernel
             kargs.init_logits_soft_cap(logits_soft_cap);
         }
 
+        kargs.cu_seqlen_q_ptr = cu_seqlen_q_ptr;
+        kargs.cu_seqlen_kv_ptr = cu_seqlen_kv_ptr;
         return kargs;
     }
 
@@ -507,7 +521,9 @@ struct FmhaFwdKernel
         ck_tile::index_t mask_type,
         float p_drop,
         bool s_randval,
-        const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
+        const std::tuple<uint64_t, uint64_t>& drop_seed_offset,
+        const ck_tile::index_t* cu_seqlen_q_ptr = nullptr,
+        const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr)
     {
         return MakeKargsImpl(
             q_ptr,
@@ -552,7 +568,9 @@ struct FmhaFwdKernel
             mask_type,
             p_drop,
             s_randval,
-            std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
+            std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
+            cu_seqlen_q_ptr,
+            cu_seqlen_kv_ptr);
     }
 
     // std::variant<> can't take in a list initializer, overload for backward compatibility
@@ -600,7 +618,9 @@ struct FmhaFwdKernel
         ck_tile::index_t mask_type,
         float p_drop,
         bool s_randval,
-        const std::tuple<const void*, const void*>& drop_seed_offset)
+        const std::tuple<const void*, const void*>& drop_seed_offset,
+        const ck_tile::index_t* cu_seqlen_q_ptr = nullptr,
+        const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr)
     {
         return MakeKargsImpl(
             q_ptr,
@@ -645,7 +665,9 @@ struct FmhaFwdKernel
             mask_type,
             p_drop,
             s_randval,
-            std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
+            std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
+            cu_seqlen_q_ptr,
+            cu_seqlen_kv_ptr);
     }
 
     template <bool Cond = kIsGroupMode>
@@ -688,7 +710,9 @@ struct FmhaFwdKernel
         float p_drop,
         bool s_randval,
         std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
-            drop_seed_offset)
+            drop_seed_offset,
+        const void* seqstart_padded_q_ptr = nullptr,
+        const void* seqstart_padded_k_ptr = nullptr)
     {
         Kargs kargs{{q_ptr,
                      k_ptr,
@@ -780,6 +804,8 @@ struct FmhaFwdKernel
             kargs.min_seqlen_q = min_seqlen_q;
         }
 
+        kargs.seqstart_padded_q_ptr = reinterpret_cast<const int32_t*>(seqstart_padded_q_ptr);
+        kargs.seqstart_padded_k_ptr = reinterpret_cast<const int32_t*>(seqstart_padded_k_ptr);
         return kargs;
     }
 
@@ -823,7 +849,9 @@ struct FmhaFwdKernel
         ck_tile::index_t min_seqlen_q,
         float p_drop,
         bool s_randval,
-        const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
+        const std::tuple<uint64_t, uint64_t>& drop_seed_offset,
+        const void* seqstart_padded_q_ptr = nullptr,
+        const void* seqstart_padded_k_ptr = nullptr)
     {
         return MakeKargsImpl(
             q_ptr,
@@ -863,7 +891,9 @@ struct FmhaFwdKernel
             min_seqlen_q,
             p_drop,
             s_randval,
-            std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
+            std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
+            seqstart_padded_q_ptr,
+            seqstart_padded_k_ptr);
     }
 
     // std::variant<> can't take in a list initializer, overload for backward compatibility
@@ -906,7 +936,9 @@ struct FmhaFwdKernel
         ck_tile::index_t min_seqlen_q,
         float p_drop,
         bool s_randval,
-        const std::tuple<const void*, const void*>& drop_seed_offset)
+        const std::tuple<const void*, const void*>& drop_seed_offset,
+        const void* seqstart_padded_q_ptr = nullptr,
+        const void* seqstart_padded_k_ptr = nullptr)
     {
         return MakeKargsImpl(
             q_ptr,
@@ -946,7 +978,9 @@ struct FmhaFwdKernel
             min_seqlen_q,
             p_drop,
             s_randval,
-            std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
+            std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
+            seqstart_padded_q_ptr,
+            seqstart_padded_k_ptr);
     }
 
     CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
@@ -1075,35 +1109,44 @@ struct FmhaFwdKernel
 
         if constexpr(kIsGroupMode)
         {
-            // get starting offset for each batch
-            const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
-            const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
+            // logical and physical (padded) starts
+            const long_index_t query_start_unpadded = kargs.seqstart_q_ptr[i_batch];
+            const long_index_t key_start_unpadded = kargs.seqstart_k_ptr[i_batch];
 
-            batch_offset_q = query_start * kargs.stride_q;
-            batch_offset_k = key_start * kargs.stride_k;
+            const long_index_t query_start_padded = kargs.seqstart_padded_q_ptr
+                                                        ? kargs.seqstart_padded_q_ptr[i_batch]
+                                                        : query_start_unpadded;
+            const long_index_t key_start_padded = kargs.seqstart_padded_k_ptr
+                                                      ? kargs.seqstart_padded_k_ptr[i_batch]
+                                                      : key_start_unpadded;
+
+            // DRAM base offsets use physical padded starts
+            batch_offset_q = query_start_padded * kargs.stride_q;
+            batch_offset_k = key_start_padded * kargs.stride_k;
             if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
             {
-                batch_offset_v = key_start * kargs.stride_v;
+                batch_offset_v = key_start_padded * kargs.stride_v;
             }
             else
             {
-                batch_offset_v = key_start;
+                batch_offset_v = key_start_padded;
             }
             if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
             {
-                batch_offset_bias = query_start * kargs.stride_bias;
+                batch_offset_bias = query_start_padded * kargs.stride_bias;
             }
             if constexpr(kStoreLSE)
            {
-                batch_offset_lse = query_start;
+                // LSE stays indexed by unpadded starts
+                batch_offset_lse = query_start_unpadded;
             }
             if constexpr(kHasDropout)
             {
-                batch_offset_randval = query_start * kargs.stride_randval;
+                batch_offset_randval = query_start_padded * kargs.stride_randval;
             }
-            batch_offset_o = query_start * kargs.stride_o;
+            batch_offset_o = query_start_padded * kargs.stride_o;
 
-            // get real # queries & # keys under group mode
+            // real logical lengths (exclude PAD)
             const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
             kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
 
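Carrying the illustrative numbers forward: for i_batch = 1 with stride_q = 64 elements per token, the hunk above would yield

    // With seqstart_padded_q_ptr = {0, 8, 16}: batch_offset_q = 8 * 64 = 512
    // With a null padded pointer (packed):     batch_offset_q = 3 * 64 = 192
    // batch_offset_lse = 3 in both cases, since LSE stays indexed by unpadded starts.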
@@ -1115,8 +1158,7 @@ struct FmhaFwdKernel
             }
         }
 
-        // # of required blocks is different in each groups, terminate unnecessary blocks
-        // earlier
+        // terminate unnecessary blocks earlier
         if(kargs.seqlen_q <= i_m0)
         {
             return;
@@ -1152,6 +1194,18 @@ struct FmhaFwdKernel
                     static_cast<long_index_t>(i_batch) * kargs.batch_stride_randval;
             }
             batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
+
+            // If cumulative seqlen pointers are provided, override per-batch effective lengths
+            if(kargs.cu_seqlen_q_ptr != nullptr)
+            {
+                kargs.seqlen_q =
+                    kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
+            }
+            if(kargs.cu_seqlen_kv_ptr != nullptr)
+            {
+                kargs.seqlen_k =
+                    kargs.cu_seqlen_kv_ptr[i_batch + 1] - kargs.cu_seqlen_kv_ptr[i_batch];
+            }
         }
 
         // for simplicity, batch stride we just modify the pointer
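The batch-mode override is plain adjacent-difference arithmetic on the [batch+1] cumulative array; with the illustrative cu_seqlen_q = {0, 5, 12} from above:

    // i_batch = 0: kargs.seqlen_q = cu_seqlen_q[1] - cu_seqlen_q[0] = 5
    // i_batch = 1: kargs.seqlen_q = cu_seqlen_q[2] - cu_seqlen_q[1] = 7
    // A null pointer leaves kargs.seqlen_q at its launch-time value.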
@@ -1550,26 +1604,35 @@ struct FmhaFwdKernel
         if constexpr(kIsGroupMode)
         {
             // get starting offset for each batch
-            const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
-            const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
+            const long_index_t query_start_unpadded = kargs.seqstart_q_ptr[i_batch];
+            const long_index_t key_start_unpadded = kargs.seqstart_k_ptr[i_batch];
 
-            batch_offset_q = query_start * kargs.stride_q;
-            batch_offset_k = key_start * kargs.stride_k;
+            const long_index_t query_start_padded = kargs.seqstart_padded_q_ptr
+                                                        ? kargs.seqstart_padded_q_ptr[i_batch]
+                                                        : query_start_unpadded;
+            const long_index_t key_start_padded = kargs.seqstart_padded_k_ptr
+                                                      ? kargs.seqstart_padded_k_ptr[i_batch]
+                                                      : key_start_unpadded;
+
+            batch_offset_q = query_start_padded * kargs.stride_q;
+            batch_offset_k = key_start_padded * kargs.stride_k;
             if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
             {
-                batch_offset_v = key_start * kargs.stride_v;
+                batch_offset_v = key_start_padded * kargs.stride_v;
             }
             else
             {
-                batch_offset_v = key_start;
+                // col-major V: offset along seqlen dimension is scalar index
+                batch_offset_v = key_start_padded;
             }
             if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
             {
-                batch_offset_bias = query_start * kargs.stride_bias;
+                batch_offset_bias = query_start_padded * kargs.stride_bias;
             }
 
-            batch_offset_lse = query_start;
-            batch_offset_o = query_start * kargs.stride_o;
+            // LSE layout is [nhead, total_seqlen], index by unpadded start
+            batch_offset_lse = query_start_unpadded;
+            batch_offset_o = query_start_padded * kargs.stride_o;
 
             // get real # queries & # keys under group mode
             kargs.seqlen_q = kargs.seqstart_q_ptr[i_batch + 1] - kargs.seqstart_q_ptr[i_batch];
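A note on the LSE indexing in the hunk above: since LSE is laid out as [nhead, total_seqlen] over packed logical tokens, its per-batch base must use the unpadded start even though Q/K/V/O use padded starts. A sketch of the resulting element offset, assuming that layout:

    // lse element offset for (i_nhead, i_token), with i_token in [0, seqlen_q):
    //   i_nhead * nhead_stride_lse + query_start_unpadded + i_token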
@@ -1607,6 +1670,18 @@ struct FmhaFwdKernel
                 batch_offset_bias =
                     static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
             }
+
+            // If cumulative seqlen pointers are provided, override per-batch effective lengths
+            if(kargs.cu_seqlen_q_ptr != nullptr)
+            {
+                kargs.seqlen_q =
+                    kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
+            }
+            if(kargs.cu_seqlen_kv_ptr != nullptr)
+            {
+                kargs.seqlen_k =
+                    kargs.cu_seqlen_kv_ptr[i_batch + 1] - kargs.cu_seqlen_kv_ptr[i_batch];
+            }
         }
 
         // for simplicity, batch stride we just modify the pointer
@@ -100,6 +100,11 @@ struct FmhaFwdV3Kernel
         ck_tile::index_t batch_stride_k;
         ck_tile::index_t batch_stride_v;
         ck_tile::index_t batch_stride_o;
+
+        // Optional cumulative sequence length pointers for batch mode
+        // If provided, they override seqlen_q / seqlen_k per-batch to skip tail padding.
+        const ck_tile::index_t* cu_seqlen_q_ptr = nullptr;  // [batch+1]
+        const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr; // [batch+1]
     };
 
     struct FmhaFwdGroupModeKargs
@@ -110,6 +115,11 @@ struct FmhaFwdV3Kernel
         const int32_t* seqstart_q_ptr;
         const int32_t* seqstart_k_ptr;
        const int32_t* seqlen_k_ptr;
+
+        // Optional cumulative padded sequence starts (including PAD tokens)
+        // Used solely to compute memory offsets when sequences are physically padded.
+        const int32_t* seqstart_padded_q_ptr = nullptr; // [batch+1]
+        const int32_t* seqstart_padded_k_ptr = nullptr; // [batch+1]
     };
 
     using Kargs = std::conditional_t<kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs>;
@@ -145,7 +155,9 @@ struct FmhaFwdV3Kernel
         ck_tile::index_t window_size_left,
         ck_tile::index_t window_size_right,
         ck_tile::index_t mask_type,
-        ck_tile::index_t remap_opt)
+        ck_tile::index_t remap_opt,
+        const ck_tile::index_t* cu_seqlen_q_ptr = nullptr,
+        const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr)
     {
         Kargs kargs{{q_ptr,
                      k_ptr,
@@ -187,6 +199,8 @@ struct FmhaFwdV3Kernel
             kargs.batch_stride_lse = batch_stride_lse;
         }
 
+        kargs.cu_seqlen_q_ptr = cu_seqlen_q_ptr;
+        kargs.cu_seqlen_kv_ptr = cu_seqlen_kv_ptr;
         return kargs;
     }
 
@@ -217,7 +231,9 @@ struct FmhaFwdV3Kernel
         ck_tile::index_t window_size_left,
         ck_tile::index_t window_size_right,
         ck_tile::index_t mask_type,
-        ck_tile::index_t remap_opt)
+        ck_tile::index_t remap_opt,
+        const void* seqstart_padded_q_ptr = nullptr,
+        const void* seqstart_padded_k_ptr = nullptr)
     {
         Kargs kargs{{q_ptr,
                      k_ptr,
@@ -257,6 +273,8 @@ struct FmhaFwdV3Kernel
             kargs.nhead_stride_lse = nhead_stride_lse;
         }
 
+        kargs.seqstart_padded_q_ptr = reinterpret_cast<const int32_t*>(seqstart_padded_q_ptr);
+        kargs.seqstart_padded_k_ptr = reinterpret_cast<const int32_t*>(seqstart_padded_k_ptr);
         return kargs;
     }
 
@@ -373,18 +391,26 @@ struct FmhaFwdV3Kernel
         if constexpr(kIsGroupMode)
         {
             // get starting offset for each batch
-            const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
-            const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
+            const long_index_t query_start_unpadded = kargs.seqstart_q_ptr[i_batch];
+            const long_index_t key_start_unpadded = kargs.seqstart_k_ptr[i_batch];
 
-            batch_offset_q = query_start * kargs.stride_q;
-            batch_offset_k = key_start * kargs.stride_k;
-            batch_offset_v = key_start * kargs.stride_v;
+            const long_index_t query_start_padded = kargs.seqstart_padded_q_ptr
+                                                        ? kargs.seqstart_padded_q_ptr[i_batch]
+                                                        : query_start_unpadded;
+            const long_index_t key_start_padded = kargs.seqstart_padded_k_ptr
+                                                      ? kargs.seqstart_padded_k_ptr[i_batch]
+                                                      : key_start_unpadded;
+
+            batch_offset_q = query_start_padded * kargs.stride_q;
+            batch_offset_k = key_start_padded * kargs.stride_k;
+            batch_offset_v = key_start_padded * kargs.stride_v;
 
             if constexpr(kStoreLSE)
             {
-                batch_offset_lse = query_start;
+                // LSE layout is [nhead, total_seqlen], index by unpadded start
+                batch_offset_lse = query_start_unpadded;
             }
-            batch_offset_o = query_start * kargs.stride_o;
+            batch_offset_o = query_start_padded * kargs.stride_o;
 
             // get real # queries & # keys under group mode
             const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
@@ -417,6 +443,18 @@ struct FmhaFwdV3Kernel
             batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
         }
         batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
+
+        // If cumulative seqlen pointers are provided, override per-batch effective lengths
+        if(kargs.cu_seqlen_q_ptr != nullptr)
+        {
+            kargs.seqlen_q =
+                kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
+        }
+        if(kargs.cu_seqlen_kv_ptr != nullptr)
+        {
+            kargs.seqlen_k =
+                kargs.cu_seqlen_kv_ptr[i_batch + 1] - kargs.cu_seqlen_kv_ptr[i_batch];
+        }
     }
 
     // for simplicity, batch stride we just modify the pointer