Add sequence padding and variable length support in fmha (#2932)

* [CK_TILE] Add sequence padding and variable length support in fmha (and v3)

 - Group Mode Padding: Introduces the `-s_qpad` argument to support
   physically padded layouts. Kernels now use padded start pointers
   (`seqstart_padded_*_ptr`) for memory addressing.

 - Batch Mode Variable Length: Adds `-q_eff_lens` and `-kv_eff_lens`
   arguments for efficient processing of variable-length sequences by
   passing cumulative effective lengths (`cu_seqlen_*_ptr`) to the kernel.

 - FMHA examples: Support padding and variable length both in
   group and batch mode. Dispatcher is updated as well (dispatch to
   kPadSeqLenK enabled pipeline).

 - New padding test cases: Add padding test cases to `smoke_test_fwd.sh` and
   `test_fmha_fwd.inc`, and add benchmarks to `benchmark_fwd.sh` and
   `benchmark_fwd_v3.sh` as well. These specifically validate and
   benchmark the new padding and variable-length functionality in both
   group and batch modes.

* [CK_TILE] Fix build error in fmha unit tests

* [CK_TILE] add mqa, gqa to sequence padding unit tests

* [CK_TILE] Reduce the number of padding seqlen unit tests in FMHA to avoid timeouts in CI

* [CK_TILE] remove unnecessary MakeKargs overload in FmhaFwdV3Kernel and FmhaFwdKernel
This commit is contained in:
Jeff Huang
2025-09-26 12:36:27 +08:00
committed by GitHub
parent b0a2d99d10
commit 518d24e662
14 changed files with 1155 additions and 72 deletions


@@ -36,6 +36,13 @@ args:
total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary
also with "-s=s0,s1,s2..." comma-separated ints to set per-batch seqlen (group mode)
-s_k seqlen_k (including new key/value), -1 means equal to s (default:-1)
also with "-s_k=s0,s1,s2..." comma-separated ints to set seqlen per batch (group mode)
-s_qpad seqlen_q stride between 2 batches (group-mode optional) (default:-1)
Provide positive strides per-batch to simulate physical padding on Q
-s_kpad seqlen_k stride between 2 batches, currently used in group-mode only (default:-1)
for kv-cache case, each batch [1,s,h,d]/[1,h,s,d] can have a stride
along seqlen, instead of packed, same as xformer kv_padding,
must be greater than or equal to s_k
-d head dim for q, k (default:128)
-d_v head dim for v, -1 means equal to d (default:-1)
-scale_s scale factor of S. 0 means equal to 1/sqrt(hdim). (default:0)
@@ -76,11 +83,20 @@ args:
-repeat number of iterations to benchmark the kernel (default:20)
-json 0: No Json, 1: Dump Results in Json format (default:0)
-jsonfile json file name to dump results (default:fmha_fwd.json)
-q_eff_lens Batch-mode only: per-batch effective seqlen for Q (exclude PAD) (default:"")
Comma-separated list of length 'b'. If empty, no override
-kv_eff_lens Batch-mode only: per-batch effective seqlen for KV (exclude PAD) (default:"")
Comma-separated list of length 'b'. If empty, no override
```
Example 1: `./bin/tile_example_fmha_fwd -b=1 -h=16 -s=16384 -d=128` will run an fmha case with batch=1, nhead=16, sequence length=16384, hdim=128, in fp16.
Example 2: `./bin/tile_example_fmha_fwd -b=1 -h=8 -s=16384 -d=64 -drop_prefs=1 -drop_seed=10 -drop_offset=1234` will run an fmha case with
batch=1, nhead=8, sequence length=16384, hdim=64, drop_seed=10 (in GPU memory), drop_offset=1234 (in GPU memory), in fp16.
## Padding Examples
Example 3 (Group mode with padding): `./bin/tile_example_fmha_fwd -mode=1 -b=2 -h=8 -s=1024,2048 -s_k=1024,2048 -s_qpad=1536,3072 -s_kpad=1536,3072 -d=128` will run group mode with 2 batches having different sequence lengths (1024, 2048) but physically padded to (1536, 3072) respectively.
Example 4 (Batch mode with effective lengths): `./bin/tile_example_fmha_fwd -mode=0 -b=2 -h=8 -s=2048 -s_k=2048 -d=128 -q_eff_lens=1024,1536 -kv_eff_lens=1024,1536` will run batch mode where all batches use 2048 as physical sequence length but have effective lengths of (1024, 1536) for Q and KV respectively.
## support features
Currently we are still in a rapid development stage, so more features/optimizations will be coming soon.
@@ -128,6 +144,15 @@ Note FA use bottom-right by default to express swa case, here we require you exp
### dropout
TBD
### sequence padding and variable length support
We support sequence padding and variable-length processing in fmha forward, in both batch and group modes, to handle real-world scenarios where sequences have different lengths.
**Group Mode Padding**: Use `-s_qpad` and `-s_kpad` to specify the physical stride between batches, enabling padded layouts. Each batch can have different logical sequence lengths (`-s`, `-s_k`) while using larger physical strides for memory alignment.
**Batch Mode Variable Length**: Use `-q_eff_lens` and `-kv_eff_lens` to specify effective sequence lengths per batch. All batches share the same physical sequence length, but the kernel processes only the effective portions. This enables efficient variable-length attention without memory waste.
Both approaches optimize memory access patterns while supporting flexible sequence length requirements commonly found in transformer inference scenarios.
## FP8 experimental support
As described in [this blog](https://blog.hippoml.com/8bit-hippoattention-up-to-3x-faster-compared-to-flashattentionv2-8f9def90b482), we have experimental support for fp8 fmha kernels; you can evaluate the performance by passing `-prec=fp8` to `tile_example_fmha_fwd` on a gfx942 machine with ROCm 6.0+.


@@ -259,11 +259,11 @@ class FmhaFwdApiTrait:
def skcheck(self) -> str:
if self.mode == 'group': return 'true/*group mode skpad always true*/' # group mode only generate spad/skpad == true
if self.pipeline_tag == 'qr_async':
if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0'
else : return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0'
if self.skpad == 't' : return f'(a.cu_seqlen_kv_ptr != nullptr) || (a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0)'
else : return f'(a.cu_seqlen_kv_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)'
elif self.pipeline_tag in ['qr', 'qs']:
if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.seqlen_k % {self.bn0} == 0'
else : return f'(a.cu_seqlen_kv_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)'
elif self.pipeline_tag == 'qr_async_trload':
if self.skpad == 't' : return 'true'
else: return 'true'


@@ -33,6 +33,10 @@ auto create_args(int argc, char* argv[])
"0",
"seqlen_k for new key/value, 0 means not to use this at all; "
"-1 to choose s_knew in [1, s] randomly.")
.insert("s_qpad",
"-1",
"seqlen_q stride between 2 batches (group-mode optional).\n"
"Provide positive strides per-batch to simulate physical padding on Q.")
.insert("s_kpad",
"-1",
"seqlen_k stride between 2 batches, currently used in group-mode only\n"
@@ -107,7 +111,15 @@ auto create_args(int argc, char* argv[])
.insert("warmup", "5", "number of iterations before benchmarking the kernel")
.insert("repeat", "20", "number of iterations to benchmark the kernel")
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
.insert("jsonfile", "fmha_fwd.json", "json file name to dump results");
.insert("jsonfile", "fmha_fwd.json", "json file name to dump results")
.insert("q_eff_lens",
"",
"Batch-mode only: per-batch effective seqlen for Q (exclude PAD).\n"
"Comma-separated list of length 'b'. If empty, no override.")
.insert("kv_eff_lens",
"",
"Batch-mode only: per-batch effective seqlen for KV (exclude PAD).\n"
"Comma-separated list of length 'b'. If empty, no override.");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
@@ -127,6 +139,9 @@ auto run(const ck_tile::ArgParser& arg_parser)
ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
ck_tile::index_t seqlen_knew = arg_parser.get_int("s_knew");
auto seqlen_kpads = arg_parser.get_int_vec("s_kpad");
auto seqlen_qpads = arg_parser.get_int_vec("s_qpad");
auto q_eff_lens_per_batch = arg_parser.get_int_vec("q_eff_lens");
auto kv_eff_lens_per_batch = arg_parser.get_int_vec("kv_eff_lens");
ck_tile::index_t rotary_dim = arg_parser.get_int("rotary_dim");
bool i_perm = arg_parser.get_bool("iperm");
bool o_perm = arg_parser.get_bool("operm");
@@ -174,7 +189,10 @@ auto run(const ck_tile::ArgParser& arg_parser)
hdim_q,
hdim_v,
seqlen_knew,
seqlen_qpads,
seqlen_kpads,
q_eff_lens_per_batch,
kv_eff_lens_per_batch,
rotary_dim,
i_perm,
o_perm,


@@ -52,7 +52,16 @@ auto parse_cmd_args(int argc, char* argv[]) -> std::pair<bool, ck_tile::ArgParse
"random seed used for initializing input tensors. 0 for "
"non-deterministic seed")
.insert("warmup", "5", "number of iterations before benchmarking the kernel")
.insert("repeat", "30", "number of iterations to benchmark the kernel");
.insert("repeat", "30", "number of iterations to benchmark the kernel")
// Optional effective seqlen override (exclude PAD) for batch mode
.insert("q_eff_lens",
"",
"Batch-mode only: per-batch effective seqlen for Q (exclude PAD).\n"
"Comma-separated list of length 'b'. If empty, no override.")
.insert("kv_eff_lens",
"",
"Batch-mode only: per-batch effective seqlen for KV (exclude PAD).\n"
"Comma-separated list of length 'b'. If empty, no override.");
bool result = arg_parser.parse(argc, argv);
return std::make_pair(result, arg_parser);
@@ -111,6 +120,8 @@ struct Problem
input_layout = args.get_int("iperm") == 1 ? TensorLayout::bhsd : TensorLayout::bshd;
output_layout = args.get_int("operm") == 1 ? TensorLayout::bhsd : TensorLayout::bshd;
q_eff_lens = args.get_int_vec("q_eff_lens");
kv_eff_lens = args.get_int_vec("kv_eff_lens");
}
std::vector<ck_tile::index_t> get_query_shape() const
@@ -172,6 +183,8 @@ struct Problem
mask_info mask;
TensorLayout input_layout;
TensorLayout output_layout;
std::vector<int> q_eff_lens;
std::vector<int> kv_eff_lens;
};
struct RunConfig
@@ -326,8 +339,10 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
q_buf.ToDevice(q.data());
k_buf.ToDevice(k.data());
v_buf.ToDevice(v.data());
// Ensure output buffer is zero-initialized so padded regions compare cleanly
o_buf.SetZero();
ck_tile::fmha_fwd_v3_args args;
ck_tile::fmha_fwd_v3_args args{};
args.data_type = problem.data_type;
args.batch = problem.batch;
@@ -380,6 +395,60 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
: problem.seqlen_q * problem.hdim;
args.batch_stride_o = problem.seqlen_q * problem.nhead_q * problem.hdim;
// Optional cumulative seqlen overrides (exclude PAD)
const bool has_varlen_q = !problem.q_eff_lens.empty() && problem.q_eff_lens[0] != -1;
const bool has_varlen_k = !problem.kv_eff_lens.empty() && problem.kv_eff_lens[0] != -1;
auto make_effective_vec = [&](const std::vector<int>& opt_vec, ck_tile::index_t fallback) {
std::vector<ck_tile::index_t> eff;
if(!opt_vec.empty() && opt_vec[0] != -1)
{
eff.assign(opt_vec.begin(), opt_vec.end());
if(eff.size() < static_cast<size_t>(problem.batch))
{
eff.resize(problem.batch, eff.back());
}
}
else
{
eff.assign(problem.batch, fallback);
}
return eff;
};
const auto eff_q_vec = make_effective_vec(problem.q_eff_lens, problem.seqlen_q);
const auto eff_kv_vec = make_effective_vec(problem.kv_eff_lens, problem.seqlen_k);
// Calculate cumulative sums for kernel arguments if varlen is used
std::vector<ck_tile::index_t> cuq_cum, cukv_cum;
auto calculate_cumulative = [&](const std::vector<ck_tile::index_t>& per_batch_vec,
std::vector<ck_tile::index_t>& cum_vec) {
cum_vec.resize(per_batch_vec.size() + 1);
cum_vec[0] = 0;
for(std::size_t i = 0; i < per_batch_vec.size(); ++i)
cum_vec[i + 1] = cum_vec[i] + per_batch_vec[i];
};
if(has_varlen_q)
{
calculate_cumulative(eff_q_vec, cuq_cum);
}
if(has_varlen_k)
{
calculate_cumulative(eff_kv_vec, cukv_cum);
}
ck_tile::DeviceMem cuq_buf(!cuq_cum.empty() ? cuq_cum.size() * sizeof(ck_tile::index_t) : 0);
ck_tile::DeviceMem cukv_buf(!cukv_cum.empty() ? cukv_cum.size() * sizeof(ck_tile::index_t) : 0);
cuq_buf.ToDevice(!cuq_cum.empty() ? cuq_cum.data() : nullptr);
cukv_buf.ToDevice(!cukv_cum.empty() ? cukv_cum.data() : nullptr);
args.cu_seqlen_q_ptr =
!cuq_cum.empty() ? reinterpret_cast<const ck_tile::index_t*>(cuq_buf.GetDeviceBuffer())
: nullptr;
args.cu_seqlen_kv_ptr =
!cukv_cum.empty() ? reinterpret_cast<const ck_tile::index_t*>(cukv_buf.GetDeviceBuffer())
: nullptr;
ck_tile::stream_config stream_config{nullptr,
true,
/*log_level=*/0,
@@ -442,15 +511,72 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
o_ref = o_ref.transpose({0, 2, 1, 3});
}
host::fmha_fwd<float, DataType>(q,
k,
v,
problem.mask,
o_ref,
ck_tile::identity{},
ck_tile::identity{},
ck_tile::identity{},
ck_tile::scales{problem.softmax_scale});
// If variable lengths are provided, compute per-batch references
// with the effective lengths; else compute a single full reference.
if(has_varlen_q || has_varlen_k)
{
// Variable-length aware verification: zero-fill padded region and only compute valid part.
o_ref.SetZero();
for(int b = 0; b < problem.batch; ++b)
{
const ck_tile::index_t seqlen_q_eff = eff_q_vec[b];
const ck_tile::index_t seqlen_kv_eff = eff_kv_vec[b];
if(seqlen_q_eff <= 0 || seqlen_kv_eff <= 0)
continue;
// Slice current batch from inputs (bshd) and build single-batch tensors
ck_tile::HostTensor<DataType> q_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim});
ck_tile::HostTensor<DataType> k_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim});
ck_tile::HostTensor<DataType> v_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim});
ck_tile::HostTensor<DataType> o_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim});
// Copy effective region
q_b.ForEach([&](auto& self, auto idx) {
// idx: [0, s, h, d]
self(idx) = q(b, idx[1], idx[2], idx[3]);
});
k_b.ForEach([&](auto& self, auto idx) { self(idx) = k(b, idx[1], idx[2], idx[3]); });
v_b.ForEach([&](auto& self, auto idx) { self(idx) = v(b, idx[1], idx[2], idx[3]); });
// Compute reference for this batch segment (host::fmha_fwd expects bshd tensors)
host::fmha_fwd<float, DataType>(q_b,
k_b,
v_b,
problem.mask,
o_b,
ck_tile::identity{},
ck_tile::identity{},
ck_tile::identity{},
ck_tile::scales{problem.softmax_scale});
// Scatter into o_ref's bshd descriptor memory
for(int s = 0; s < seqlen_q_eff; ++s)
{
for(int h = 0; h < problem.nhead_q; ++h)
{
for(int d = 0; d < problem.hdim; ++d)
{
o_ref(b, s, h, d) = o_b(0, s, h, d);
}
}
}
}
}
else
{
// No varlen override: compute the full reference once
host::fmha_fwd<float, DataType>(q,
k,
v,
problem.mask,
o_ref,
ck_tile::identity{},
ck_tile::identity{},
ck_tile::identity{},
ck_tile::scales{problem.softmax_scale});
}
ck_tile::HostTensor<DataType> o(problem.get_output_shape());
o_buf.FromDevice(o.data());


@@ -162,11 +162,20 @@ struct fmha_fwd_args
void* lse_ptr;
void* o_ptr;
// Optional cumulative sequence length arrays
// Batch mode: cu_seqlen_* override effective per-batch lengths (exclude PAD)
const ck_tile::index_t* cu_seqlen_q_ptr = nullptr; // [batch+1]
const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr; // [batch+1]
const void* seqstart_q_ptr;
const void* seqstart_k_ptr;
const void*
seqlen_k_ptr; // only used if both 'seqstart_q_ptr' & 'seqstart_k_ptr' are not nullptr
// Group mode: seqstart_padded_* provide physical starts including PAD (optional)
const void* seqstart_padded_q_ptr = nullptr; // [batch+1]
const void* seqstart_padded_k_ptr = nullptr; // [batch+1]
ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_k;
ck_tile::index_t batch;
@@ -554,7 +563,9 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args.min_seqlen_q,
args.p_drop,
args.s_randval,
args.drop_seed_offset);
args.drop_seed_offset,
args.seqstart_padded_q_ptr,
args.seqstart_padded_k_ptr);
}
else
{ // create batch mode kernel arguments
@@ -600,7 +611,9 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args.mask_type,
args.p_drop,
args.s_randval,
args.drop_seed_offset);
args.drop_seed_offset,
args.cu_seqlen_q_ptr,
args.cu_seqlen_kv_ptr);
}
}();


@@ -151,7 +151,10 @@ fwd_result fmha_fwd_run(mode_enum mode,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t seqlen_knew,
std::vector<ck_tile::index_t> seqlen_qpads,
std::vector<ck_tile::index_t> seqlen_kpads,
std::vector<ck_tile::index_t> q_eff_lens_per_batch,
std::vector<ck_tile::index_t> kv_eff_lens_per_batch,
ck_tile::index_t rotary_dim,
bool i_perm,
bool o_perm,
@@ -299,6 +302,24 @@ fwd_result fmha_fwd_run(mode_enum mode,
#endif
const bool use_kvcache = (need_append_kvcache || use_cache_batch_idx || 0 < page_block_size);
// Reject unsupported padding usage in special pipelines (appendkv / splitkv / pagedkv)
const bool has_group_padding =
(mode == mode_enum::group && (!seqlen_qpads.empty() && seqlen_qpads[0] != -1)) ||
(mode == mode_enum::group && (seqlen_kpads[0] >= 0));
const bool has_batch_efflens = (mode == mode_enum::batch && (!q_eff_lens_per_batch.empty() ||
!kv_eff_lens_per_batch.empty()));
const bool using_appendkv = (0 < seqlen_knew || 0 < rotary_dim);
const bool using_pagedkv = (0 < page_block_size);
const bool using_splitkv = (num_splits > 1) || use_cache_batch_idx;
if((using_appendkv || using_pagedkv || using_splitkv) &&
(has_group_padding || has_batch_efflens))
{
std::cerr << "Padding (physical or effective lengths) is not supported with "
"appendkv/splitkv/pagedkv pipelines"
<< std::endl;
return fwd_result::invalid_args;
}
std::tie(seqlen_qs, seqlen_ks, seqlen_kpads) =
generate_missing_seqlens(mode,
batch,
@@ -362,6 +383,44 @@ fwd_result fmha_fwd_run(mode_enum mode,
const auto seqstart_k_host = to_seqstarts(seqlen_ks);
const auto seqstart_k_with_padding_host = to_seqstarts(seqlen_kpads);
// Optional padded Q seqstarts (group-mode only)
std::vector<int32_t> seqstart_q_with_padding_host;
if(mode == mode_enum::group && !seqlen_qpads.empty() && seqlen_qpads[0] != -1)
{
if(seqlen_qpads.size() < static_cast<size_t>(batch))
{
seqlen_qpads.resize(batch, seqlen_qpads.back());
}
if(seqlen_qpads.size() == static_cast<size_t>(batch))
{
seqstart_q_with_padding_host = to_seqstarts(
ck_tile::span<const int32_t>(seqlen_qpads.data(), seqlen_qpads.size()));
}
}
// Optional batch-mode cumulative seqlen overrides
std::vector<ck_tile::index_t> cuq_cum, cukv_cum;
if(mode == mode_enum::batch)
{
auto calculate_cumulative = [&](std::vector<ck_tile::index_t>& per_batch_vec,
std::vector<ck_tile::index_t>& cum_vec) {
if(!per_batch_vec.empty() && per_batch_vec[0] != -1)
{
if(per_batch_vec.size() < static_cast<size_t>(batch))
{
per_batch_vec.resize(batch, per_batch_vec.back());
}
cum_vec.resize(batch + 1);
cum_vec[0] = 0;
for(int i = 0; i < batch; ++i)
cum_vec[i + 1] = cum_vec[i] + per_batch_vec[i];
}
};
calculate_cumulative(q_eff_lens_per_batch, cuq_cum);
calculate_cumulative(kv_eff_lens_per_batch, cukv_cum);
}
using TypeConfig = FmhaFwdTypeConfig<DataTypeConfig>;
using QDataType = typename TypeConfig::QDataType;
@@ -445,8 +504,15 @@ fwd_result fmha_fwd_run(mode_enum mode,
// host memory for storing all the tensor elements
const ck_tile::index_t shape_batch = (mode == mode_enum::batch ? batch : 1);
const ck_tile::index_t shape_seqlen_q =
// logical(unpadded) total seqlen_q for group; batch uses fixed seqlen
const ck_tile::index_t shape_seqlen_q_lse =
(mode == mode_enum::batch ? seqlen_qs[0] : seqstart_q_host.back());
// physical(padded) total seqlen_q for group when s_qpad is provided; else use logical
const ck_tile::index_t shape_seqlen_q =
(mode == mode_enum::batch
? seqlen_qs[0]
: (seqstart_q_with_padding_host.empty() ? seqstart_q_host.back()
: seqstart_q_with_padding_host.back()));
const ck_tile::index_t shape_seqlen_k =
(mode == mode_enum::batch ? seqlen_ks[0]
: (seqlen_kpads[0] < 0 ? seqstart_k_host.back()
@@ -504,7 +570,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
// batch mode of lse data layout is [batch, nhead, seqlen_q]
// group mode of lse data layout is [nhead, total_seqlen_q]
ck_tile::HostTensor<LSEDataType> lse_host(
lse ? std::array<ck_tile::index_t, 3>{shape_batch, nhead, shape_seqlen_q}
lse ? std::array<ck_tile::index_t, 3>{shape_batch, nhead, shape_seqlen_q_lse}
: std::array<ck_tile::index_t, 3>{1, 1, 1} /* dummy shape for simplifying code */);
ck_tile::HostTensor<ODataType> o_host(
@@ -602,6 +668,16 @@ fwd_result fmha_fwd_run(mode_enum mode,
ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t));
ck_tile::DeviceMem seqstart_q_padded_buf(seqstart_q_with_padding_host.empty()
? 0
: seqstart_q_with_padding_host.size() *
sizeof(int32_t));
ck_tile::DeviceMem seqstart_k_padded_buf(
seqlen_kpads[0] < 0 ? 0 : seqstart_k_with_padding_host.size() * sizeof(int32_t));
ck_tile::DeviceMem cu_seqlen_q_buf(cuq_cum.empty() ? 0
: cuq_cum.size() * sizeof(ck_tile::index_t));
ck_tile::DeviceMem cu_seqlen_kv_buf(
cukv_cum.empty() ? 0 : cukv_cum.size() * sizeof(ck_tile::index_t));
ck_tile::DeviceMem seqlen_k_buf((mode == mode_enum::batch && use_kvcache) ||
0 <= seqlen_kpads[0]
? seqlen_ks.size() * sizeof(int32_t)
@@ -693,8 +769,14 @@ fwd_result fmha_fwd_run(mode_enum mode,
vnew_buf.ToDevice(vnew_host.data());
bias_buf.ToDevice(bias_host.data());
seqstart_q.ToDevice(seqstart_q_host.data());
seqstart_k.ToDevice(seqlen_kpads[0] < 0 ? seqstart_k_host.data()
: seqstart_k_with_padding_host.data());
// Keep logical starts in seqstart_k; pass padded K via separate pointer
seqstart_k.ToDevice(seqstart_k_host.data());
seqstart_q_padded_buf.ToDevice(
seqstart_q_with_padding_host.empty() ? nullptr : seqstart_q_with_padding_host.data());
seqstart_k_padded_buf.ToDevice(seqlen_kpads[0] < 0 ? nullptr
: seqstart_k_with_padding_host.data());
cu_seqlen_q_buf.ToDevice(cuq_cum.empty() ? nullptr : cuq_cum.data());
cu_seqlen_kv_buf.ToDevice(cukv_cum.empty() ? nullptr : cukv_cum.data());
seqlen_k_buf.ToDevice((mode == mode_enum::batch && use_kvcache) || 0 <= seqlen_kpads[0]
? seqlen_ks.data()
: nullptr);
@@ -747,6 +829,54 @@ fwd_result fmha_fwd_run(mode_enum mode,
std::cout << ", cache_batch_idx:" << use_cache_batch_idx;
}
#endif
// Padding / effective length diagnostic logging
auto print_vec = [&](const char* label, const std::vector<int>& v) {
if(v.empty())
return;
std::cout << ", " << label << ":[";
for(std::size_t i = 0; i < v.size(); ++i)
{
if(i)
std::cout << ",";
std::cout << v[i];
}
std::cout << "]";
};
if(has_group_padding)
{
bool has_qpad = !seqstart_q_with_padding_host.empty();
bool has_kpad = (seqlen_kpads[0] >= 0);
if(has_qpad)
{
print_vec("q_logical", seqlen_qs);
print_vec("q_padded", seqlen_qpads);
}
if(has_kpad)
{
print_vec("k_logical", seqlen_ks);
print_vec("k_padded", seqlen_kpads);
}
}
else if(has_batch_efflens)
{
// derive effective lengths from cumulative arrays if present
if(!cuq_cum.empty())
{
std::vector<int> eff_q(batch);
for(int b_i = 0; b_i < batch; ++b_i)
eff_q[b_i] = static_cast<int>(cuq_cum[b_i + 1] - cuq_cum[b_i]);
print_vec("q_eff", eff_q);
}
if(!cukv_cum.empty())
{
std::vector<int> eff_kv(batch);
for(int b_i = 0; b_i < batch; ++b_i)
eff_kv[b_i] = static_cast<int>(cukv_cum[b_i + 1] - cukv_cum[b_i]);
print_vec("kv_eff", eff_kv);
}
}
std::cout << std::flush;
const auto init_traits = [&](auto& traits) {
@@ -830,8 +960,8 @@ fwd_result fmha_fwd_run(mode_enum mode,
const ck_tile::index_t nhead_stride_bias =
(i_perm ? 0 * shape_seqlen_q * max_seqlen_k : 0 * max_seqlen_k);
const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t nhead_stride_lse = shape_seqlen_q;
const ck_tile::index_t nhead_stride_lse_acc = (num_splits * shape_seqlen_q);
const ck_tile::index_t nhead_stride_lse = shape_seqlen_q_lse;
const ck_tile::index_t nhead_stride_lse_acc = (num_splits * shape_seqlen_q_lse);
const ck_tile::index_t nhead_stride_o_acc = (num_splits * shape_seqlen_q * hdim_v);
const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
// setup batch_stride_* arguments
@@ -846,8 +976,8 @@ fwd_result fmha_fwd_run(mode_enum mode,
const ck_tile::index_t batch_stride_vnew = (nhead_k * hdim_v * seqlen_knew);
const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q);
const ck_tile::index_t batch_stride_lse_acc = (nhead * num_splits * shape_seqlen_q);
const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q_lse);
const ck_tile::index_t batch_stride_lse_acc = (nhead * num_splits * shape_seqlen_q_lse);
const ck_tile::index_t batch_stride_o_acc = (nhead * num_splits * shape_seqlen_q * hdim_v);
const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);
const ck_tile::index_t batch_stride_block_table = (max_num_page_blocks / batch);
@@ -961,6 +1091,29 @@ fwd_result fmha_fwd_run(mode_enum mode,
{
args.drop_seed_offset = std::make_pair(drop_seed, drop_offset);
}
// Group-mode: optional physical padded starts for Q/K
if(mode == mode_enum::group)
{
args.seqstart_padded_q_ptr = (seqstart_q_with_padding_host.empty()
? nullptr
: seqstart_q_padded_buf.GetDeviceBuffer());
args.seqstart_padded_k_ptr =
(seqlen_kpads[0] < 0 ? nullptr : seqstart_k_padded_buf.GetDeviceBuffer());
}
// Batch-mode: optional cumulative effective seqlen overrides
if(mode == mode_enum::batch)
{
args.cu_seqlen_q_ptr = cuq_cum.empty()
? nullptr
: reinterpret_cast<const ck_tile::index_t*>(
cu_seqlen_q_buf.GetDeviceBuffer());
args.cu_seqlen_kv_ptr = cukv_cum.empty()
? nullptr
: reinterpret_cast<const ck_tile::index_t*>(
cu_seqlen_kv_buf.GetDeviceBuffer());
}
}
else if constexpr(std::is_same_v<fmha_fwd_splitkv_args, std::decay_t<decltype(args)>>)
{
@@ -1167,15 +1320,29 @@ fwd_result fmha_fwd_run(mode_enum mode,
for(ck_tile::index_t wb = 0; wb < batch; ++wb)
{
const ck_tile::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb];
const ck_tile::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb];
ck_tile::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb];
ck_tile::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb];
if(mode == mode_enum::batch)
{
if(!cuq_cum.empty())
{
real_seqlen_q = cuq_cum[wb + 1] - cuq_cum[wb];
}
if(!cukv_cum.empty())
{
real_seqlen_k = cukv_cum[wb + 1] - cukv_cum[wb];
}
}
// adjust matrix index according to the mode
const ck_tile::index_t b_idx = (mode == mode_enum::batch ? wb : 0);
const ck_tile::index_t cache_b_idx =
(use_cache_batch_idx ? cache_batch_idx_host(b_idx) : b_idx);
const ck_tile::index_t query_offset =
(mode == mode_enum::batch ? 0 : seqstart_q_host[wb]);
(mode == mode_enum::batch
? 0
: (seqstart_q_with_padding_host.empty() ? seqstart_q_host[wb]
: seqstart_q_with_padding_host[wb]));
const ck_tile::index_t key_offset =
(mode == mode_enum::batch
? 0
@@ -1538,8 +1705,10 @@ fwd_result fmha_fwd_run(mode_enum mode,
if(lse)
{
ck_tile::HostTensor<SMPLComputeDataType> lse_host_result({nhead, real_seqlen_q});
const ck_tile::index_t query_offset_lse =
(mode == mode_enum::batch ? 0 : seqstart_q_host[wb]);
lse_host_result.ForEach([&](auto& self, auto idx) {
self(idx) = lse_host(b_idx, idx[0], idx[1] + query_offset);
self(idx) = lse_host(b_idx, idx[0], idx[1] + query_offset_lse);
});
cur_pass = ck_tile::check_err(lse_host_result,


@@ -56,6 +56,11 @@ struct fmha_fwd_v3_args
index_t stride_o;
index_t nhead_stride_o;
index_t batch_stride_o;
// Optional batch-mode cumulative seqlen overrides (exclude PAD)
// If provided, they override per-batch effective lengths to skip tail padding.
const ck_tile::index_t* cu_seqlen_q_ptr = nullptr; // [batch+1]
const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr; // [batch+1]
};
std::ostream& operator<<(std::ostream& stream, const fmha_fwd_v3_args::data_type_enum& data_type);


@@ -158,7 +158,9 @@ float fmha_fwd_v3_kernel_launch(const fmha_fwd_v3_args& args, const stream_confi
args.window_size_left,
args.window_size_right,
args.mask_type,
remap_opt);
remap_opt,
args.cu_seqlen_q_ptr,
args.cu_seqlen_kv_ptr);
dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.seqlen_q, args.hdim_v);
constexpr dim3 blocks = Kernel::BlockSize();


@@ -18,3 +18,36 @@ $EXE -prec=$prec -b=1 -h=$nhead -d=$hdim -s=16384 -iperm=$perm -operm=$perm -kn
done
done
done
# Padding Benchmarks: batch mode (baseline vs low/med/high pad)
prec="fp16"
base_batch_args="-prec=$prec -mode=0 -b=4 -h=16 -h_k=16 -d=128 -s=1024 -bias=n -mask=0 -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=1 -v=$VALID"
# baseline (no pad)
$EXE $base_batch_args
# low pad (≈90-95% effective)
$EXE $base_batch_args -q_eff_lens=1024,960,992,896 -kv_eff_lens=1024,960,992,896
# medium pad (≈60-75% effective)
$EXE $base_batch_args -q_eff_lens=896,768,512,640 -kv_eff_lens=896,768,512,640
# high pad (≈30-40% effective)
$EXE $base_batch_args -q_eff_lens=512,384,256,320 -kv_eff_lens=512,384,256,320
# Padding Benchmarks: group mode (baseline vs low/med/high physical pad)
seqlens_q="1024,768,512,256"
seqlens_k="1024,768,512,256"
base_group_args="-prec=$prec -mode=1 -b=4 -h=16 -h_k=16 -d=128 -s=$seqlens_q -s_k=$seqlens_k -bias=n -mask=0 -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=1 -v=$VALID"
# baseline (no physical pad)
$EXE $base_group_args
# low physical pad
$EXE $base_group_args -s_qpad=1152,896,576,320 -s_kpad=1152,896,576,320
# medium physical pad
$EXE $base_group_args -s_qpad=1536,1152,768,384 -s_kpad=1536,1152,768,384
# high physical pad
$EXE $base_group_args -s_qpad=2048,1536,1024,512 -s_kpad=2048,1536,1024,512


@@ -23,3 +23,20 @@ done
done
done
done
# Padding benchmark comparisons for v3 (batch mode only)
# ==== V3 Padding Benchmarks: batch mode (baseline vs low/med/high pad) ====
prec="fp16"
base_v3_args="-prec=$prec -b=4 -h=16 -d=128 -s=1024 -mask=0 -iperm=0 -operm=0 -v=$VALID"
# baseline (no pad)
$EXE $base_v3_args
# low pad (≈90-95% effective)
$EXE $base_v3_args -q_eff_lens=1024,960,992,896 -kv_eff_lens=1024,960,992,896
# medium pad (≈60-75% effective)
$EXE $base_v3_args -q_eff_lens=896,768,512,640 -kv_eff_lens=896,768,512,640
# high pad (≈30-40% effective)
$EXE $base_v3_args -q_eff_lens=512,384,256,320 -kv_eff_lens=512,384,256,320


@@ -137,9 +137,118 @@ run_fp16_appendkv_tests() {
done ; done ; done
}
run_padding_smoke_tests() {
# Padding-only smoke tests for batch/group mode using COMMON_ARGS
local prec="fp16"
# Batch mode: padding via effective lengths (exclude PAD)
# Use lse=1 to select a non-trload kernel and avoid overly strict tolerance mismatches
local base_batch="-prec=$prec -mode=0 -b=4 -h=16 -h_k=16 -d=128 -s=1024 -bias=n -mask=0 -lse=1 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME $COMMON_ARGS"
# low pad (≈90-95% effective)
$EXE $base_batch -q_eff_lens=1024,960,992,896 -kv_eff_lens=1024,960,992,896
# medium pad (≈60-75% effective)
$EXE $base_batch -q_eff_lens=896,768,512,640 -kv_eff_lens=896,768,512,640
# high pad (≈30-40% effective)
$EXE $base_batch -q_eff_lens=512,384,256,320 -kv_eff_lens=512,384,256,320
# Group mode: padding via physical stride along seqlen
local seqlens_q="1024,768,512,256"
local seqlens_k="1024,768,512,256"
local base_group="-prec=$prec -mode=1 -b=4 -h=16 -h_k=16 -d=128 -s=$seqlens_q -s_k=$seqlens_k -bias=n -mask=0 -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME $COMMON_ARGS"
# low physical pad
$EXE $base_group -s_qpad=1152,896,576,320 -s_kpad=1152,896,576,320
# medium physical pad
$EXE $base_group -s_qpad=1536,1152,768,384 -s_kpad=1536,1152,768,384
# high physical pad
$EXE $base_group -s_qpad=2048,1536,1024,512 -s_kpad=2048,1536,1024,512
}
run_padding_basic_boundary_tests() {
# Basic padding and boundary tests (reference: smoke_test_fwd_pad.sh)
local prec
local perm
# Group mode: Q&K padded with per-batch different strides
for prec in fp16 bf16 ; do
for perm in 0 1 ; do
$EXE -prec=$prec -mode=1 -b=2 -h=2 -h_k=1 -d=16 -d_v=32 \
-s=55 -s_k=256 -s_qpad=64,60 -s_kpad=272,260 \
-bias=n -p_drop=0.0 -lse=0 -iperm=$perm -operm=$perm \
-num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS
done
done
# slightly larger, uneven padding strides
for prec in fp16 bf16 ; do
for perm in 0 1 ; do
$EXE -prec=$prec -mode=1 -b=3 -h=2 -h_k=1 -d=64 -d_v=64 \
-s=50,60,40 -s_k=128,256,192 -s_qpad=64,64,64 -s_kpad=160,288,224 \
-bias=n -p_drop=0.0 -lse=1 -iperm=$perm -operm=$perm \
-num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS
done
done
# only K padded; Q unpadded
for prec in fp16 bf16 ; do
for perm in 0 1 ; do
$EXE -prec=$prec -mode=1 -b=2 -h=2 -h_k=1 -d=32 -d_v=64 \
-s=55 -s_k=256 -s_kpad=272,260 \
-bias=n -p_drop=0.0 -lse=1 -iperm=$perm -operm=$perm \
-num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS
done
done
# use cu_seqlen overrides to skip tail PAD
for prec in fp16 bf16 ; do
for perm in 0 1 ; do
$EXE -prec=$prec -mode=0 -b=4 -h=8 -h_k=8 -d=128 -s=3 -s_k=3 \
-q_eff_lens=1,2,1,2 -kv_eff_lens=1,2,1,2 \
-bias=n -p_drop=0.0 -lse=1 -iperm=$perm -operm=$perm \
-num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=0 -b=2 -h=2 -h_k=1 -d=32 -d_v=64 -s=64 -s_k=256 \
-q_eff_lens=55,60 -kv_eff_lens=200,256 \
-bias=n -p_drop=0.0 -lse=0 -iperm=$perm -operm=$perm \
-num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS
done
done
# no padding (equal), mixed Q/KV, all len=1
for prec in fp16 bf16 ; do
$EXE -prec=$prec -mode=0 -b=4 -h=8 -d=64 -s=128 -s_k=128 \
-q_eff_lens=128,128,128,128 -kv_eff_lens=128,128,128,128 \
-bias=n -p_drop=0.0 -lse=1 -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=0 -b=4 -h=8 -d=64 -s=128 -s_k=128 \
-q_eff_lens=10,20,30,40 -kv_eff_lens=40,30,20,10 \
-bias=n -p_drop=0.0 -lse=1 -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=0 -b=4 -h=8 -d=64 -s=128 -s_k=128 \
-q_eff_lens=1,1,1,1 -kv_eff_lens=1,1,1,1 \
-bias=n -p_drop=0.0 -lse=1 -kname=$KNAME $COMMON_ARGS
done
# highly variable logical lengths
for prec in fp16 bf16 ; do
$EXE -prec=$prec -mode=1 -b=4 -h=4 -d=32 \
-s=1,127,3,65 -s_k=1,127,3,65 -s_kpad=128 \
-bias=n -p_drop=0.0 -lse=1 -kname=$KNAME $COMMON_ARGS
done
# GQA + Alibi + Causal mask (keep vlayout row-major for fp16/bf16)
for prec in fp16 bf16 ; do
$EXE -prec=$prec -mode=1 -b=2 -h=16 -h_k=4 -d=128 \
-s=256,129 -s_k=256,129 -s_kpad=256 \
-bias=a -mask=t -lse=1 -iperm=0 -operm=0 -vlayout=r \
-kname=$KNAME $COMMON_ARGS
done
}
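The `-s_qpad`/`-s_kpad` values in the tests above are per-batch physical strides; on the host they become the padded seqstart arrays by the same exclusive prefix sum that builds the unpadded seqstarts from logical lengths. A sketch (the helper name `prefix_starts` is hypothetical):

```cpp
#include <cstdint>
#include <vector>

// Hypothetical host-side helper: exclusive prefix sum over per-batch
// lengths. Applied to logical lengths it yields seqstart_q/k; applied to
// the -s_qpad/-s_kpad physical strides it yields seqstart_padded_q/k.
std::vector<int32_t> prefix_starts(const std::vector<int32_t>& lens)
{
    std::vector<int32_t> starts(lens.size() + 1, 0);
    for(std::size_t i = 0; i < lens.size(); ++i)
        starts[i + 1] = starts[i] + lens[i];
    return starts;
}
```

For the first group-mode case above (`-b=2 -s=55 -s_qpad=64,60`) the logical starts are `{0, 55, 110}` and the padded starts `{0, 64, 124}`.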
set -x
run_fp16_bf16_tests
run_padding_smoke_tests
run_padding_basic_boundary_tests
run_fp8_tests
run_fp8bf16_tests
run_fp8fp32_tests


@@ -293,6 +293,11 @@ struct FmhaFwdKernel
ck_tile::index_t batch_stride_k;
ck_tile::index_t batch_stride_v;
ck_tile::index_t batch_stride_o;
// Optional cumulative sequence length pointers for batch mode
// If provided, they override seqlen_q / seqlen_k per-batch to skip tail padding.
const ck_tile::index_t* cu_seqlen_q_ptr = nullptr; // cumulative, length without PAD
const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr; // cumulative, length without PAD
};
struct FmhaFwdGroupModeKargs
@@ -312,6 +317,11 @@ struct FmhaFwdKernel
const int32_t* seqstart_q_ptr;
const int32_t* seqstart_k_ptr;
const int32_t* seqlen_k_ptr;
// Optional cumulative padded sequence starts (including PAD tokens)
// Used solely to compute memory offsets when sequences are physically padded.
const int32_t* seqstart_padded_q_ptr = nullptr;
const int32_t* seqstart_padded_k_ptr = nullptr;
};
using Kargs = std::conditional_t<kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs>;
@@ -368,7 +378,9 @@ struct FmhaFwdKernel
float p_drop,
bool s_randval,
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset)
drop_seed_offset,
const ck_tile::index_t* cu_seqlen_q_ptr = nullptr,
const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr)
{
Kargs kargs{{q_ptr,
k_ptr,
@@ -459,6 +471,8 @@ struct FmhaFwdKernel
kargs.init_logits_soft_cap(logits_soft_cap);
}
kargs.cu_seqlen_q_ptr = cu_seqlen_q_ptr;
kargs.cu_seqlen_kv_ptr = cu_seqlen_kv_ptr;
return kargs;
}
@@ -507,7 +521,9 @@ struct FmhaFwdKernel
ck_tile::index_t mask_type,
float p_drop,
bool s_randval,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
const std::tuple<uint64_t, uint64_t>& drop_seed_offset,
const ck_tile::index_t* cu_seqlen_q_ptr = nullptr,
const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr)
{
return MakeKargsImpl(
q_ptr,
@@ -552,7 +568,9 @@ struct FmhaFwdKernel
mask_type,
p_drop,
s_randval,
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
cu_seqlen_q_ptr,
cu_seqlen_kv_ptr);
}
// std::variant<> can't be constructed from a list initializer; this overload is kept for backward compatibility
@@ -600,7 +618,9 @@ struct FmhaFwdKernel
ck_tile::index_t mask_type,
float p_drop,
bool s_randval,
const std::tuple<const void*, const void*>& drop_seed_offset)
const std::tuple<const void*, const void*>& drop_seed_offset,
const ck_tile::index_t* cu_seqlen_q_ptr = nullptr,
const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr)
{
return MakeKargsImpl(
q_ptr,
@@ -645,7 +665,9 @@ struct FmhaFwdKernel
mask_type,
p_drop,
s_randval,
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
cu_seqlen_q_ptr,
cu_seqlen_kv_ptr);
}
template <bool Cond = kIsGroupMode>
@@ -688,7 +710,9 @@ struct FmhaFwdKernel
float p_drop,
bool s_randval,
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset)
drop_seed_offset,
const void* seqstart_padded_q_ptr = nullptr,
const void* seqstart_padded_k_ptr = nullptr)
{
Kargs kargs{{q_ptr,
k_ptr,
@@ -780,6 +804,8 @@ struct FmhaFwdKernel
kargs.min_seqlen_q = min_seqlen_q;
}
kargs.seqstart_padded_q_ptr = reinterpret_cast<const int32_t*>(seqstart_padded_q_ptr);
kargs.seqstart_padded_k_ptr = reinterpret_cast<const int32_t*>(seqstart_padded_k_ptr);
return kargs;
}
@@ -823,7 +849,9 @@ struct FmhaFwdKernel
ck_tile::index_t min_seqlen_q,
float p_drop,
bool s_randval,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
const std::tuple<uint64_t, uint64_t>& drop_seed_offset,
const void* seqstart_padded_q_ptr = nullptr,
const void* seqstart_padded_k_ptr = nullptr)
{
return MakeKargsImpl(
q_ptr,
@@ -863,7 +891,9 @@ struct FmhaFwdKernel
min_seqlen_q,
p_drop,
s_randval,
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
seqstart_padded_q_ptr,
seqstart_padded_k_ptr);
}
// std::variant<> can't be constructed from a list initializer; this overload is kept for backward compatibility
@@ -906,7 +936,9 @@ struct FmhaFwdKernel
ck_tile::index_t min_seqlen_q,
float p_drop,
bool s_randval,
const std::tuple<const void*, const void*>& drop_seed_offset)
const std::tuple<const void*, const void*>& drop_seed_offset,
const void* seqstart_padded_q_ptr = nullptr,
const void* seqstart_padded_k_ptr = nullptr)
{
return MakeKargsImpl(
q_ptr,
@@ -946,7 +978,9 @@ struct FmhaFwdKernel
min_seqlen_q,
p_drop,
s_randval,
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
seqstart_padded_q_ptr,
seqstart_padded_k_ptr);
}
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
@@ -1075,35 +1109,44 @@ struct FmhaFwdKernel
if constexpr(kIsGroupMode)
{
// get starting offset for each batch
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
// logical and physical (padded) starts
const long_index_t query_start_unpadded = kargs.seqstart_q_ptr[i_batch];
const long_index_t key_start_unpadded = kargs.seqstart_k_ptr[i_batch];
batch_offset_q = query_start * kargs.stride_q;
batch_offset_k = key_start * kargs.stride_k;
const long_index_t query_start_padded = kargs.seqstart_padded_q_ptr
? kargs.seqstart_padded_q_ptr[i_batch]
: query_start_unpadded;
const long_index_t key_start_padded = kargs.seqstart_padded_k_ptr
? kargs.seqstart_padded_k_ptr[i_batch]
: key_start_unpadded;
// DRAM base offsets use physical padded starts
batch_offset_q = query_start_padded * kargs.stride_q;
batch_offset_k = key_start_padded * kargs.stride_k;
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
batch_offset_v = key_start * kargs.stride_v;
batch_offset_v = key_start_padded * kargs.stride_v;
}
else
{
batch_offset_v = key_start;
batch_offset_v = key_start_padded;
}
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
batch_offset_bias = query_start * kargs.stride_bias;
batch_offset_bias = query_start_padded * kargs.stride_bias;
}
if constexpr(kStoreLSE)
{
batch_offset_lse = query_start;
// LSE stays indexed by unpadded starts
batch_offset_lse = query_start_unpadded;
}
if constexpr(kHasDropout)
{
batch_offset_randval = query_start * kargs.stride_randval;
batch_offset_randval = query_start_padded * kargs.stride_randval;
}
batch_offset_o = query_start * kargs.stride_o;
batch_offset_o = query_start_padded * kargs.stride_o;
// get real # queries & # keys under group mode
// real logical lengths (exclude PAD)
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
@@ -1115,8 +1158,7 @@ struct FmhaFwdKernel
}
}
// # of required blocks is different in each groups, terminate unnecessary blocks
// earlier
// terminate unnecessary blocks earlier
if(kargs.seqlen_q <= i_m0)
{
return;
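The padded-start fallback introduced in this hunk can be sketched in isolation (array contents and the free function are hypothetical; field names mirror the kargs): DRAM base offsets use the padded start when `seqstart_padded_*_ptr` is provided and fall back to the unpadded start otherwise, so unpadded layouts are unaffected.

```cpp
#include <cstdint>

// Hypothetical standalone sketch of the group-mode addressing rule:
// prefer the physically padded start for DRAM offsets, fall back to the
// unpadded (logical) start when no padding info is given.
inline const int32_t kSeqstartQ[]       = {0, 55, 110}; // logical starts
inline const int32_t kSeqstartPaddedQ[] = {0, 64, 124}; // physical starts

int64_t query_base_offset(const int32_t* seqstart_q,
                          const int32_t* seqstart_padded_q, // may be nullptr
                          int i_batch,
                          int64_t stride_q)
{
    const int64_t start =
        seqstart_padded_q ? seqstart_padded_q[i_batch] : seqstart_q[i_batch];
    return start * stride_q;
}
```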
@@ -1152,6 +1194,18 @@ struct FmhaFwdKernel
static_cast<long_index_t>(i_batch) * kargs.batch_stride_randval;
}
batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
// If cumulative seqlen pointers are provided, override per-batch effective lengths
if(kargs.cu_seqlen_q_ptr != nullptr)
{
kargs.seqlen_q =
kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
}
if(kargs.cu_seqlen_kv_ptr != nullptr)
{
kargs.seqlen_k =
kargs.cu_seqlen_kv_ptr[i_batch + 1] - kargs.cu_seqlen_kv_ptr[i_batch];
}
}
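In batch mode the override above reduces to taking adjacent differences of the cumulative array; a minimal sketch (the helper name and sample array are hypothetical):

```cpp
#include <cstdint>

// Hypothetical sketch of the batch-mode override: each batch's effective
// length is the difference of adjacent cumulative entries.
inline const int32_t kCuSeqlenQ[] = {0, 1024, 1984, 2976, 3872};

int32_t effective_len(const int32_t* cu_seqlen, int i_batch)
{
    return cu_seqlen[i_batch + 1] - cu_seqlen[i_batch];
}
```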
// for simplicity, instead of applying a batch stride we just offset the pointer
@@ -1550,26 +1604,35 @@ struct FmhaFwdKernel
if constexpr(kIsGroupMode)
{
// get starting offset for each batch
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
const long_index_t query_start_unpadded = kargs.seqstart_q_ptr[i_batch];
const long_index_t key_start_unpadded = kargs.seqstart_k_ptr[i_batch];
batch_offset_q = query_start * kargs.stride_q;
batch_offset_k = key_start * kargs.stride_k;
const long_index_t query_start_padded = kargs.seqstart_padded_q_ptr
? kargs.seqstart_padded_q_ptr[i_batch]
: query_start_unpadded;
const long_index_t key_start_padded = kargs.seqstart_padded_k_ptr
? kargs.seqstart_padded_k_ptr[i_batch]
: key_start_unpadded;
batch_offset_q = query_start_padded * kargs.stride_q;
batch_offset_k = key_start_padded * kargs.stride_k;
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
batch_offset_v = key_start * kargs.stride_v;
batch_offset_v = key_start_padded * kargs.stride_v;
}
else
{
batch_offset_v = key_start;
// col-major V: offset along seqlen dimension is scalar index
batch_offset_v = key_start_padded;
}
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
batch_offset_bias = query_start * kargs.stride_bias;
batch_offset_bias = query_start_padded * kargs.stride_bias;
}
batch_offset_lse = query_start;
batch_offset_o = query_start * kargs.stride_o;
// LSE layout is [nhead, total_seqlen], index by unpadded start
batch_offset_lse = query_start_unpadded;
batch_offset_o = query_start_padded * kargs.stride_o;
// get real # queries & # keys under group mode
kargs.seqlen_q = kargs.seqstart_q_ptr[i_batch + 1] - kargs.seqstart_q_ptr[i_batch];
@@ -1607,6 +1670,18 @@ struct FmhaFwdKernel
batch_offset_bias =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
}
// If cumulative seqlen pointers are provided, override per-batch effective lengths
if(kargs.cu_seqlen_q_ptr != nullptr)
{
kargs.seqlen_q =
kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
}
if(kargs.cu_seqlen_kv_ptr != nullptr)
{
kargs.seqlen_k =
kargs.cu_seqlen_kv_ptr[i_batch + 1] - kargs.cu_seqlen_kv_ptr[i_batch];
}
}
// for simplicity, instead of applying a batch stride we just offset the pointer


@@ -100,6 +100,11 @@ struct FmhaFwdV3Kernel
ck_tile::index_t batch_stride_k;
ck_tile::index_t batch_stride_v;
ck_tile::index_t batch_stride_o;
// Optional cumulative sequence length pointers for batch mode
// If provided, they override seqlen_q / seqlen_k per-batch to skip tail padding.
const ck_tile::index_t* cu_seqlen_q_ptr = nullptr; // [batch+1]
const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr; // [batch+1]
};
struct FmhaFwdGroupModeKargs
@@ -110,6 +115,11 @@ struct FmhaFwdV3Kernel
const int32_t* seqstart_q_ptr;
const int32_t* seqstart_k_ptr;
const int32_t* seqlen_k_ptr;
// Optional cumulative padded sequence starts (including PAD tokens)
// Used solely to compute memory offsets when sequences are physically padded.
const int32_t* seqstart_padded_q_ptr = nullptr; // [batch+1]
const int32_t* seqstart_padded_k_ptr = nullptr; // [batch+1]
};
using Kargs = std::conditional_t<kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs>;
@@ -145,7 +155,9 @@ struct FmhaFwdV3Kernel
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
ck_tile::index_t remap_opt)
ck_tile::index_t remap_opt,
const ck_tile::index_t* cu_seqlen_q_ptr = nullptr,
const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr)
{
Kargs kargs{{q_ptr,
k_ptr,
@@ -187,6 +199,8 @@ struct FmhaFwdV3Kernel
kargs.batch_stride_lse = batch_stride_lse;
}
kargs.cu_seqlen_q_ptr = cu_seqlen_q_ptr;
kargs.cu_seqlen_kv_ptr = cu_seqlen_kv_ptr;
return kargs;
}
@@ -217,7 +231,9 @@ struct FmhaFwdV3Kernel
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
ck_tile::index_t remap_opt)
ck_tile::index_t remap_opt,
const void* seqstart_padded_q_ptr = nullptr,
const void* seqstart_padded_k_ptr = nullptr)
{
Kargs kargs{{q_ptr,
k_ptr,
@@ -257,6 +273,8 @@ struct FmhaFwdV3Kernel
kargs.nhead_stride_lse = nhead_stride_lse;
}
kargs.seqstart_padded_q_ptr = reinterpret_cast<const int32_t*>(seqstart_padded_q_ptr);
kargs.seqstart_padded_k_ptr = reinterpret_cast<const int32_t*>(seqstart_padded_k_ptr);
return kargs;
}
@@ -373,18 +391,26 @@ struct FmhaFwdV3Kernel
if constexpr(kIsGroupMode)
{
// get starting offset for each batch
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
const long_index_t query_start_unpadded = kargs.seqstart_q_ptr[i_batch];
const long_index_t key_start_unpadded = kargs.seqstart_k_ptr[i_batch];
batch_offset_q = query_start * kargs.stride_q;
batch_offset_k = key_start * kargs.stride_k;
batch_offset_v = key_start * kargs.stride_v;
const long_index_t query_start_padded = kargs.seqstart_padded_q_ptr
? kargs.seqstart_padded_q_ptr[i_batch]
: query_start_unpadded;
const long_index_t key_start_padded = kargs.seqstart_padded_k_ptr
? kargs.seqstart_padded_k_ptr[i_batch]
: key_start_unpadded;
batch_offset_q = query_start_padded * kargs.stride_q;
batch_offset_k = key_start_padded * kargs.stride_k;
batch_offset_v = key_start_padded * kargs.stride_v;
if constexpr(kStoreLSE)
{
batch_offset_lse = query_start;
// LSE layout is [nhead, total_seqlen], index by unpadded start
batch_offset_lse = query_start_unpadded;
}
batch_offset_o = query_start * kargs.stride_o;
batch_offset_o = query_start_padded * kargs.stride_o;
// get real # queries & # keys under group mode
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
@@ -417,6 +443,18 @@ struct FmhaFwdV3Kernel
batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
}
batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
// If cumulative seqlen pointers are provided, override per-batch effective lengths
if(kargs.cu_seqlen_q_ptr != nullptr)
{
kargs.seqlen_q =
kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
}
if(kargs.cu_seqlen_kv_ptr != nullptr)
{
kargs.seqlen_k =
kargs.cu_seqlen_kv_ptr[i_batch + 1] - kargs.cu_seqlen_kv_ptr[i_batch];
}
}
// for simplicity, instead of applying a batch stride we just offset the pointer


@@ -98,7 +98,10 @@ TEST_P(AllLong, Test)
hdim_q,
hdim_v,
0, // seqlen_knew
{-1}, // seqlen_qpads
{seqlen_kpad}, // seqlen_kpads
{}, // q_eff_lens_per_batch
{}, // kv_eff_lens_per_batch
0, // rotary_dim
perm, // i_perm
perm, // o_perm
@@ -121,6 +124,141 @@ TEST_P(AllLong, Test)
CHECK_RESULT(result);
}
// ---------------------------------------------------------------
// Negative tests: padding not supported with appendkv/splitkv/pagedkv
// ---------------------------------------------------------------
#if CK_TILE_FMHA_FWD_APPENDKV_API
TEST(TestCkTileFmhaFwd, AppendKvWithBatchEffLensShouldFail)
{
// batch mode effective lengths simulate padding
auto result = fmha_fwd_run<DataTypeConfig>(
mode_enum::batch,
2, // batch
4, // nhead
-1, // nhead_k
{128}, // seqlen_qs
{128}, // seqlen_ks
64, // hdim_q
64, // hdim_v
32, // seqlen_knew -> triggers appendkv
{}, // seqlen_qpads
{}, // seqlen_kpads
{100, 120}, // q_eff_lens_per_batch
{90, 110}, // kv_eff_lens_per_batch
0, // rotary_dim
true, // i_perm
true, // o_perm
0, // scale_s
0, // logits_soft_cap
def_is_v_rowmajor,
def_lse,
0, // page_block_size
false, // use_cache_batch_idx
"n", // bias
0.0f, // p_drop
0, // drop_seed
0, // drop_offset
false, // drop_prefs
"0", // mask
squant,
true, // is_rotary_interleaved
1, // num_splits
init_method,
static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
0,
stream_config);
ASSERT_EQ(result, fwd_result::invalid_args);
}
#endif
#if CK_TILE_FMHA_FWD_SPLITKV_API
TEST(TestCkTileFmhaFwd, SplitKvWithGroupPaddingShouldFail)
{
// group mode physical padding
auto result = fmha_fwd_run<DataTypeConfig>(
mode_enum::group,
2, // batch
4, // nhead
-1, // nhead_k
{96, 120}, // seqlen_qs logical
{96, 120}, // seqlen_ks logical
64, // hdim_q
64, // hdim_v
0, // seqlen_knew
{128, 128}, // seqlen_qpads
{128, 128}, // seqlen_kpads
{}, // q_eff
{}, // kv_eff
0, // rotary_dim
true, // i_perm
true, // o_perm
0, // scale_s
0, // logits_soft_cap
def_is_v_rowmajor,
def_lse,
0, // page_block_size
false, // use_cache_batch_idx
"n", // bias
0.0f,
0,
0,
false,
"0",
squant,
true,
2, // num_splits (>1 triggers splitkv)
init_method,
static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
0,
stream_config);
ASSERT_EQ(result, fwd_result::invalid_args);
}
#endif
#if CK_TILE_FMHA_FWD_PAGEDKV_API
TEST(TestCkTileFmhaFwd, PagedKvWithGroupPaddingShouldFail)
{
auto result = fmha_fwd_run<DataTypeConfig>(
mode_enum::group,
2,
4,
-1,
{80, 100},
{80, 100},
64,
64,
0, // seqlen_knew
{96, 128}, // seqlen_qpads
{96, 128}, // seqlen_kpads
{},
{},
0,
true,
true,
0,
0,
def_is_v_rowmajor,
def_lse,
128, // page_block_size triggers pagedkv
false,
"n",
0.0f,
0,
0,
false,
"0",
squant,
true,
1,
init_method,
static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
0,
stream_config);
ASSERT_EQ(result, fwd_result::invalid_args);
}
#endif
class HDimPadding
: public TestWithParam<std::tuple<std::tuple<int, int>,
bool,
@@ -160,7 +298,10 @@ TEST_P(HDimPadding, Test)
hdim_q,
hdim_v,
0, // seqlen_knew
{-1}, // seqlen_qpads
{seqlen_kpad}, // seqlen_kpads
{}, // q_eff_lens_per_batch
{}, // kv_eff_lens_per_batch
0, // rotary_dim
perm, // i_perm
perm, // o_perm
@@ -217,7 +358,10 @@ TEST_P(ElementwiseBias, Test)
hdim_q,
hdim_v,
0, // seqlen_knew
{-1}, // seqlen_qpads
{-1}, // seqlen_kpads
{}, // q_eff_lens_per_batch
{}, // kv_eff_lens_per_batch
0, // rotary_dim
i_perm, // i_perm
false, // o_perm
@@ -273,7 +417,10 @@ TEST_P(Alibi, Test)
hdim_q,
hdim_v,
0, // seqlen_knew
{-1}, // seqlen_qpads
{-1}, // seqlen_kpads
{}, // q_eff_lens_per_batch
{}, // kv_eff_lens_per_batch
0, // rotary_dim
true, // i_perm
true, // o_perm
@@ -331,7 +478,10 @@ TEST_P(Dropout, Test)
hdim_q,
hdim_v,
0, // seqlen_knew
{-1}, // seqlen_qpads
{-1}, // seqlen_kpads
{}, // q_eff_lens_per_batch
{}, // kv_eff_lens_per_batch
0, // rotary_dim
false, // i_perm
false, // o_perm
@@ -391,7 +541,10 @@ TEST_P(PagedKV, Test)
hdim_q,
hdim_v,
0, // seqlen_knew
{-1}, // seqlen_qpads
{-1}, // seqlen_kpads
{}, // q_eff_lens_per_batch
{}, // kv_eff_lens_per_batch
0, // rotary_dim
i_perm, // i_perm
false, // o_perm
@@ -457,7 +610,10 @@ TEST_P(SplitKV, Test)
hdim_q,
hdim_v,
0, // seqlen_knew
{-1}, // seqlen_qpads
{-1}, // seqlen_kpads
{}, // q_eff_lens_per_batch
{}, // kv_eff_lens_per_batch
0, // rotary_dim
i_perm, // i_perm
false, // o_perm
@@ -529,7 +685,10 @@ TEST_P(AppendKV, Test)
hdim_q,
hdim_v,
seqlen_knew, // seqlen_knew
{-1}, // seqlen_qpads
{-1}, // seqlen_kpads
{}, // q_eff_lens_per_batch
{}, // kv_eff_lens_per_batch
0, // rotary_dim
i_perm, // i_perm
true, // o_perm
@@ -599,7 +758,10 @@ TEST_P(AppendKVRoPE, Test)
hdim_q,
hdim_v,
seqlen_knew, // seqlen_knew
{-1}, // seqlen_qpads
{-1}, // seqlen_kpads
{}, // q_eff_lens_per_batch
{}, // kv_eff_lens_per_batch
rotary_dim, // rotary_dim
i_perm, // i_perm
true, // o_perm
@@ -623,3 +785,294 @@ TEST_P(AppendKVRoPE, Test)
}
#endif // CK_TILE_FMHA_FWD_APPENDKV_API
// ---------------------------------------------------------------
// Parameterized padding tests (batch & group) using Combine+Values
// ---------------------------------------------------------------
using PaddingParam = std::tuple<mode_enum, // mode
int, // batch
int, // nhead
int, // nhead_k
std::vector<int>, // seqlen_qs (logical)
std::vector<int>, // seqlen_ks (logical)
std::vector<int>, // seqlen_qpads (physical padded lengths)
std::vector<int>, // seqlen_kpads (physical padded lengths)
std::vector<int>, // q_eff_lens
std::vector<int>, // kv_eff_lens
bool, // i_perm
bool, // o_perm
std::string>; // mask_str
// Ensure headers for containers / algorithms used in padding param builder.
#include <vector>
#include <array>
#include <cmath>
#include <algorithm>
class PaddingCases : public TestWithParam<PaddingParam>
{
};
GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(PaddingCases);
// Build padding test params programmatically to enforce constraints
static std::vector<PaddingParam> BuildPaddingParams()
{
std::vector<PaddingParam> params;
// mask variants to cover
const std::vector<std::string> mask_variants{"0", "t:50,64", "b:32,40"};
const std::vector<std::string> mask_variants_reduced{"0", "t:50,64"}; // used for trimmed sets
// Representative ratio pairs (q_ratio, k_ratio) to avoid explosion
const std::vector<std::pair<double, double>> ratio_pairs_full{
{1.0, 1.0}, // both full
{1.0, 0.5}, // q full, k half
{0.5, 1.0}, // q half, k full
};
const std::vector<std::pair<double, double>> ratio_pairs_reduced{{1.0, 1.0}, {0.5, 1.0}};
// candidate physical seqlens for batch mode (single value) & for group mode (per batch)
const std::vector<int> physical_lengths_full{64, 128, 256};
const std::vector<int> physical_lengths_reduced{64};
// batch sizes to sample
const std::vector<int> batch_sizes{1, 4};
// --------------------------------------------------------------------
// Head configuration space (cover MHA, GQA, MQA)
// - Standard MHA: nhead_k == -1 (treated internally as nhead)
// - GQA: nhead_k > 0 and nhead % nhead_k == 0, nhead_k < nhead
// - MQA: nhead_k == 1
// We choose (9, -1), (9, 3), (9, 1) so that divisibility holds. Full
// combinatorics are applied only to the first (standard) configuration to
// avoid a test explosion.
// --------------------------------------------------------------------
struct HeadCfg
{
int nhead;
int nhead_k; // -1 for standard; else must divide nhead
bool full; // whether to use full coverage sets
};
const std::vector<HeadCfg> head_cfgs = {
{9, -1, true}, // MHA full
{9, 3, false}, // GQA reduced (nhead/nhead_k=3)
{9, 1, false} // MQA reduced
};
// Helper to clamp and ensure >=1
auto logical_len = [](int physical, double ratio) {
int v = static_cast<int>(std::round(physical * ratio));
v = std::max(1, std::min(v, physical));
return v;
};
// Iterate over head configurations
for(const auto& hc : head_cfgs)
{
const auto& ratio_pairs = hc.full ? ratio_pairs_full : ratio_pairs_reduced;
const auto& phys_lengths_batch = hc.full ? physical_lengths_full : physical_lengths_reduced;
const auto& phys_lengths_group_q = phys_lengths_batch; // reuse
const auto& phys_lengths_group_k = phys_lengths_batch; // reuse
const auto& masks = hc.full ? mask_variants : mask_variants_reduced;
// -----------------
// Batch mode params (effective lengths only)
// -----------------
for(int b : batch_sizes)
{
for(int phys_qkv : phys_lengths_batch)
{
for(const auto& rkpair : ratio_pairs)
{
double rq = rkpair.first;
double rk = rkpair.second;
std::vector<int> q_eff(b), kv_eff(b);
int log_q = logical_len(phys_qkv, rq);
int log_k = logical_len(phys_qkv, rk);
for(int i = 0; i < b; ++i)
{
q_eff[i] = log_q;
kv_eff[i] = log_k;
}
for(const auto& mask : masks)
{
params.emplace_back(PaddingParam{mode_enum::batch,
b,
hc.nhead,
hc.nhead_k,
{phys_qkv}, // seqlen_qs
{phys_qkv}, // seqlen_ks
{}, // seqlen_qpads
{}, // seqlen_kpads
q_eff,
kv_eff,
true,
true,
mask});
}
}
// Single-token logical length case (both q & k = 1)
for(const auto& mask : masks)
{
std::vector<int> q_eff(b, 1), kv_eff(b, 1);
params.emplace_back(PaddingParam{mode_enum::batch,
b,
hc.nhead,
hc.nhead_k,
{phys_qkv},
{phys_qkv},
{},
{},
q_eff,
kv_eff,
true,
true,
mask});
}
}
}
// -----------------
// Group mode params (physical padding + logical variants)
// -----------------
for(int b : batch_sizes)
{
for(int phys_q : phys_lengths_group_q)
{
for(int phys_k : phys_lengths_group_k)
{
for(const auto& rkpair : ratio_pairs)
{
double rq = rkpair.first;
double rk = rkpair.second;
std::vector<int> seqlen_qs(b), seqlen_ks(b), seqlen_qpads(b),
seqlen_kpads(b);
for(int i = 0; i < b; ++i)
{
seqlen_qpads[i] = phys_q;
seqlen_kpads[i] = phys_k;
seqlen_qs[i] = logical_len(phys_q, rq);
seqlen_ks[i] = logical_len(phys_k, rk);
}
std::array<std::pair<std::vector<int>, std::vector<int>>, 3> pad_variants{
std::pair{seqlen_qpads, seqlen_kpads}, // both
std::pair{seqlen_qpads, seqlen_ks}, // only q padding
std::pair{seqlen_qs, seqlen_kpads} // only kv padding
};
for(const auto& mask : masks)
{
for(const auto& pv : pad_variants)
{
params.emplace_back(PaddingParam{mode_enum::group,
b,
hc.nhead,
hc.nhead_k,
seqlen_qs,
seqlen_ks,
pv.first,
pv.second,
{},
{},
true,
true,
mask});
}
}
}
// Single-token logical length case
for(const auto& mask : masks)
{
std::vector<int> seqlen_qs(b, 1), seqlen_ks(b, 1);
std::vector<int> seqlen_qpads(b, phys_q), seqlen_kpads(b, phys_k);
// only the both-padded variant (the others are degenerate)
params.emplace_back(PaddingParam{mode_enum::group,
b,
hc.nhead,
hc.nhead_k,
seqlen_qs,
seqlen_ks,
seqlen_qpads,
seqlen_kpads,
{},
{},
true,
true,
mask});
}
}
}
}
}
return params;
}
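The `logical_len` lambda used in `BuildPaddingParams` rounds and clamps; an equivalent free-function sketch for reference:

```cpp
#include <algorithm>
#include <cmath>

// Equivalent of the logical_len lambda in BuildPaddingParams: scale the
// physical length by the ratio, round, then clamp to [1, physical].
int logical_len(int physical, double ratio)
{
    int v = static_cast<int>(std::round(physical * ratio));
    return std::max(1, std::min(v, physical));
}
```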
static const std::vector<PaddingParam> kPaddingParams = BuildPaddingParams();
INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaFwd_Padding, PaddingCases, ValuesIn(kPaddingParams));
TEST_P(PaddingCases, Test)
{
if constexpr(std::is_same_v<DataTypeConfig, FmhaFwdFp8>)
{
GTEST_SKIP() << "Skip for fp8";
}
auto [mode,
batch,
nhead,
nhead_k,
seqlen_qs,
seqlen_ks,
seqlen_qpads,
seqlen_kpads,
q_eff_lens,
kv_eff_lens,
i_perm,
o_perm,
mask_str] = GetParam();
// For batch mode we wrap single logical lengths with adjust_seqlen.
std::vector<int> adj_qs =
(mode == mode_enum::batch) ? std::vector<int>{adjust_seqlen(seqlen_qs.at(0))} : seqlen_qs;
std::vector<int> adj_ks =
(mode == mode_enum::batch) ? std::vector<int>{adjust_seqlen(seqlen_ks.at(0))} : seqlen_ks;
const int hdim_q = 64;
const int hdim_v = 64;
const int seqlen_knew = 0;
auto result = fmha_fwd_run<DataTypeConfig>(mode,
batch,
nhead,
nhead_k,
adj_qs,
adj_ks,
hdim_q,
hdim_v,
seqlen_knew, // seqlen_knew
seqlen_qpads, // seqlen_qpads
seqlen_kpads, // seqlen_kpads
q_eff_lens, // q_eff_lens_per_batch
kv_eff_lens, // kv_eff_lens_per_batch
0, // rotary_dim
i_perm, // i_perm
o_perm, // o_perm
0, // scale_s
0, // logits_soft_cap
def_is_v_rowmajor,
def_lse, // lse
0, // page_block_size
false, // use_cache_batch_idx
"n", // bias_str
0.0f, // p_drop
0, // drop_seed
0, // drop_offset
false, // drop_prefs
mask_str, // mask_str
squant,
true, // is_rotary_interleaved
1, // num_splits
COMMON_ARGS);
CHECK_RESULT(result);
}