[CK_TILE] fmha: Add query padding support to backward pass

Introduces support for query sequence padding (q_padding) in the FMHA backward pass kernels.
- Passing `seqlen_q_ptr` to the backward kernels to distinguish logical from physical sequence lengths.
- Updating `OGradDotO`, `ConvertQGrad`, and `DQDKDV` kernels to respect logical lengths and handle zero-length sequences.
- Aligning LSE indexing in the forward kernel with the padded layout for consistency.
- Adding a new GTest suite (`test_fmha_bwd_kernel_padding.cpp`) with comprehensive tests for various padding scenarios, including zero-length
  sequences and deterministic mode.
This commit is contained in:
Jeff Huang
2025-10-08 15:35:02 +08:00
parent 86d542f663
commit 4e06eaa417
10 changed files with 839 additions and 35 deletions

View File

@@ -24,11 +24,19 @@ auto create_args(int argc, char* argv[])
"total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary\n"
"also with \"-s=s0,s1,s2...\" comma-separated ints to set seqlen per batch "
"(group mode)")
.insert("s_qpad",
"-1",
"padded seqlen_q per batch (group mode only). "
"Use \"-s_qpad=p0,p1,...\"; -1 disables explicit padding")
.insert("s_k",
"-1",
"seqlen_k, -1 means equal to s\n"
"also with \"-s_k=s0,s1,s2...\" comma-separated ints to set seqlen per batch "
"(group mode)")
.insert("s_kpad",
"-1",
"padded seqlen_k per batch (group mode only). "
"Use \"-s_kpad=k0,k1,...\"; -1 disables explicit padding")
.insert("d", "128", "head dim for q, k")
.insert("d_v", "-1", "head dim for v, -1 means equal to d")
.insert("scale", "0", "scale factor. 0 means equal to 1/sqrt(hdim)")
@@ -96,7 +104,9 @@ auto run(const ck_tile::ArgParser& arg_parser)
ck_tile::index_t nhead = arg_parser.get_int("h");
ck_tile::index_t nhead_k = arg_parser.get_int("h_k");
auto seqlen_qs = arg_parser.get_int_vec("s");
auto seqlen_qpads = arg_parser.get_int_vec("s_qpad");
auto seqlen_ks = arg_parser.get_int_vec("s_k");
auto seqlen_kpads = arg_parser.get_int_vec("s_kpad");
ck_tile::index_t hdim_q = arg_parser.get_int("d");
ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
bool i_perm = arg_parser.get_bool("iperm");
@@ -130,6 +140,8 @@ auto run(const ck_tile::ArgParser& arg_parser)
nhead_k,
seqlen_qs,
seqlen_ks,
seqlen_qpads,
seqlen_kpads,
hdim_q,
hdim_v,
i_perm,

View File

@@ -116,6 +116,7 @@ struct fmha_bwd_args
void* dq_acc_ptr;
const void* seqstart_q_ptr;
const void* seqstart_k_ptr;
const void* seqlen_q_ptr;
const void* seqlen_k_ptr;
ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_k;
@@ -203,6 +204,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
dq_ptr,
args.seqstart_q_ptr,
args.seqstart_k_ptr,
args.seqlen_q_ptr,
args.seqlen_k_ptr,
args.hdim_q,
args.hdim_v,
@@ -315,6 +317,7 @@ auto fmha_bwd_dot_do_o_create_kargs_and_grids(fmha_bwd_args args)
args.d_ptr,
args.p_undrop,
args.seqstart_q_ptr,
args.seqlen_q_ptr,
args.hdim_v,
args.stride_do,
args.stride_o,
@@ -356,6 +359,8 @@ auto fmha_bwd_convert_dq_create_kargs_and_grids(fmha_bwd_args args)
args.dq_ptr,
args.seqstart_q_ptr,
args.seqstart_k_ptr,
args.seqlen_q_ptr,
args.seqlen_k_ptr,
args.hdim_q,
args.stride_dq,
args.stride_dq_acc,

View File

@@ -65,6 +65,8 @@ bwd_result fmha_bwd_run(mode_enum mode,
ck_tile::index_t nhead_k,
std::vector<ck_tile::index_t> seqlen_qs,
std::vector<ck_tile::index_t> seqlen_ks,
std::vector<ck_tile::index_t> seqlen_qpads,
std::vector<ck_tile::index_t> seqlen_kpads,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
bool i_perm,
@@ -119,10 +121,15 @@ bwd_result fmha_bwd_run(mode_enum mode,
std::cerr << "dbias only exists when bias type is elementwise" << std::endl;
return bwd_result::invalid_args;
}
std::vector<ck_tile::index_t> seqlen_kpads;
std::tie(seqlen_qs, seqlen_ks, seqlen_kpads) =
generate_missing_seqlens(mode, batch, seqlen_qs, seqlen_ks, {}, 0, false, random_engine);
ck_tile::ignore = seqlen_kpads;
std::tie(seqlen_qs, seqlen_ks, seqlen_qpads, seqlen_kpads) = generate_missing_seqlens(
mode, batch, seqlen_qs, seqlen_ks, seqlen_qpads, seqlen_kpads, 0, false, random_engine);
bool use_qpadding =
mode == mode_enum::group && (!seqlen_qpads.empty() && seqlen_qpads[0] != -1);
bool use_kpadding =
mode == mode_enum::group && (!seqlen_kpads.empty() && seqlen_kpads[0] != -1);
#if 0
std::cout << "seqlen_qs: " << seqlen_qs << std::endl;
std::cout << "seqlen_ks: " << seqlen_ks << std::endl;
@@ -146,8 +153,10 @@ bwd_result fmha_bwd_run(mode_enum mode,
s_randval = true;
}
const auto seqstart_q_host = to_seqstarts(seqlen_qs);
const auto seqstart_k_host = to_seqstarts(seqlen_ks);
const auto seqstart_q_host =
(use_qpadding ? to_seqstarts(seqlen_qpads) : to_seqstarts(seqlen_qs));
const auto seqstart_k_host =
(use_kpadding ? to_seqstarts(seqlen_kpads) : to_seqstarts(seqlen_ks));
using TypeConfig = FmhaBwdTypeConfig<DataTypeConfig>;
@@ -336,6 +345,10 @@ bwd_result fmha_bwd_run(mode_enum mode,
ck_tile::DeviceMem do_buf(do_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem dbias_buf(dbias_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
ck_tile::DeviceMem seqlen_q_dev(mode == mode_enum::batch ? 0
: seqlen_qs.size() * sizeof(int32_t));
ck_tile::DeviceMem seqlen_k_dev(mode == mode_enum::batch ? 0
: seqlen_ks.size() * sizeof(int32_t));
ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t));
ck_tile::DeviceMem drop_seed_buf(drop_prefs ? sizeof(uint64_t) : 0);
ck_tile::DeviceMem drop_offset_buf(drop_prefs ? sizeof(uint64_t) : 0);
@@ -349,6 +362,13 @@ bwd_result fmha_bwd_run(mode_enum mode,
do_buf.ToDevice(do_host.data());
seqstart_q.ToDevice(seqstart_q_host.data());
seqstart_k.ToDevice(seqstart_k_host.data());
if(mode == mode_enum::group)
{
std::vector<int32_t> seqlen_q_host(seqlen_qs.begin(), seqlen_qs.end());
seqlen_q_dev.ToDevice(seqlen_q_host.data());
std::vector<int32_t> seqlen_k_host(seqlen_ks.begin(), seqlen_ks.end());
seqlen_k_dev.ToDevice(seqlen_k_host.data());
}
drop_seed_buf.ToDevice(drop_prefs ? &drop_seed : nullptr);
drop_offset_buf.ToDevice(drop_prefs ? &drop_offset : nullptr);
alibi_slope_buf.ToDevice(alibi_slope_host.data());
@@ -440,11 +460,14 @@ bwd_result fmha_bwd_run(mode_enum mode,
}
}();
const void* seqlen_q_ptr_dev = use_qpadding ? seqlen_q_dev.GetDeviceBuffer() : nullptr;
const void* seqlen_k_ptr_dev = use_kpadding ? seqlen_k_dev.GetDeviceBuffer() : nullptr;
return fmha_bwd_args{q_buf.GetDeviceBuffer(),
k_buf.GetDeviceBuffer(),
v_buf.GetDeviceBuffer(),
bias.type == bias_enum::alibi ? alibi_slope_buf.GetDeviceBuffer()
: bias_buf.GetDeviceBuffer(),
: bias_buf.GetDeviceBuffer(),
o_buf.GetDeviceBuffer(),
lse_buf.GetDeviceBuffer(),
do_buf.GetDeviceBuffer(),
@@ -457,7 +480,8 @@ bwd_result fmha_bwd_run(mode_enum mode,
dq_acc_buf.GetDeviceBuffer(),
seqstart_q.GetDeviceBuffer(),
seqstart_k.GetDeviceBuffer(),
nullptr,
seqlen_q_ptr_dev,
seqlen_k_ptr_dev,
shape_seqlen_q,
shape_seqlen_k,
batch,
@@ -472,7 +496,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
stride_k,
stride_v,
bias.type == bias_enum::alibi ? (bias.rank_info == 0 ? 0 : nhead)
: stride_bias,
: stride_bias,
stride_o,
stride_randval,
stride_do,

View File

@@ -330,11 +330,12 @@ fwd_result fmha_fwd_run(mode_enum mode,
return fwd_result::invalid_args;
}
std::tie(seqlen_qs, seqlen_ks, seqlen_kpads) =
std::tie(seqlen_qs, seqlen_ks, seqlen_qpads, seqlen_kpads) =
generate_missing_seqlens(mode,
batch,
seqlen_qs,
seqlen_ks,
seqlen_qpads,
seqlen_kpads,
/*seqlen_k_min=*/0 < seqlen_knew ? seqlen_knew : 0,
need_append_kvcache,
@@ -346,7 +347,13 @@ fwd_result fmha_fwd_run(mode_enum mode,
std::cerr << "kpad must be greater than or equal to seqlen for k" << std::endl;
return fwd_result::invalid_args;
}
if(seqlen_qpads[wb] > 0 && seqlen_qpads[wb] < seqlen_qs[wb])
{
std::cerr << "qpad must be greater than or equal to seqlen for q" << std::endl;
return fwd_result::invalid_args;
}
}
// compute kvcache seqlen_k (before appending knew/vnew)
auto cache_seqlen_ks = seqlen_ks;
std::transform(cache_seqlen_ks.begin(),
@@ -357,6 +364,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
#if 0
std::cout << "seqlen_qs: " << seqlen_qs << std::endl;
std::cout << "seqlen_ks: " << seqlen_ks << std::endl;
std::cout << "seqlen_qpads: " << seqlen_qpads << std::endl;
std::cout << "seqlen_kpads: " << seqlen_kpads << std::endl;
std::cout << "cache_seqlen_ks: " << cache_seqlen_ks << std::endl;
#endif
@@ -514,9 +522,6 @@ fwd_result fmha_fwd_run(mode_enum mode,
// host memory for storing all the tensor elements
const ck_tile::index_t shape_batch = (mode == mode_enum::batch ? batch : 1);
// logical(unpadded) total seqlen_q for group; batch uses fixed seqlen
const ck_tile::index_t shape_seqlen_q_lse =
(mode == mode_enum::batch ? seqlen_qs[0] : seqstart_q_host.back());
// physical(padded) total seqlen_q for group when s_qpad is provided; else use logical
const ck_tile::index_t shape_seqlen_q =
(mode == mode_enum::batch
@@ -580,7 +585,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
// batch mode of lse data layout is [batch, nhead, seqlen_q]
// group mode of lse data layout is [nhead, total_seqlen_q]
ck_tile::HostTensor<LSEDataType> lse_host(
lse ? std::array<ck_tile::index_t, 3>{shape_batch, nhead, shape_seqlen_q_lse}
lse ? std::array<ck_tile::index_t, 3>{shape_batch, nhead, shape_seqlen_q}
: std::array<ck_tile::index_t, 3>{1, 1, 1} /* dummy shape for simplifying code */);
ck_tile::HostTensor<ODataType> o_host(
@@ -970,8 +975,8 @@ fwd_result fmha_fwd_run(mode_enum mode,
const ck_tile::index_t nhead_stride_bias =
(i_perm ? 0 * shape_seqlen_q * max_seqlen_k : 0 * max_seqlen_k);
const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t nhead_stride_lse = shape_seqlen_q_lse;
const ck_tile::index_t nhead_stride_lse_acc = (num_splits * shape_seqlen_q_lse);
const ck_tile::index_t nhead_stride_lse = shape_seqlen_q;
const ck_tile::index_t nhead_stride_lse_acc = (num_splits * shape_seqlen_q);
const ck_tile::index_t nhead_stride_o_acc = (num_splits * shape_seqlen_q * hdim_v);
const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
// setup batch_stride_* arguments
@@ -986,8 +991,8 @@ fwd_result fmha_fwd_run(mode_enum mode,
const ck_tile::index_t batch_stride_vnew = (nhead_k * hdim_v * seqlen_knew);
const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q_lse);
const ck_tile::index_t batch_stride_lse_acc = (nhead * num_splits * shape_seqlen_q_lse);
const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q);
const ck_tile::index_t batch_stride_lse_acc = (nhead * num_splits * shape_seqlen_q);
const ck_tile::index_t batch_stride_o_acc = (nhead * num_splits * shape_seqlen_q * hdim_v);
const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);
const ck_tile::index_t batch_stride_block_table = (max_num_page_blocks / batch);
@@ -1727,12 +1732,11 @@ fwd_result fmha_fwd_run(mode_enum mode,
if(lse)
{
ck_tile::HostTensor<SMPLComputeDataType> lse_host_result({nhead, real_seqlen_q});
const ck_tile::index_t query_offset_lse =
(mode == mode_enum::batch ? 0 : seqstart_q_host[wb]);
lse_host_result.ForEach([&](auto& self, auto idx) {
self(idx) = lse_host(b_idx, idx[0], idx[1] + query_offset_lse);
self(idx) = lse_host(b_idx, idx[0], idx[1] + query_offset);
});
std::cout << "lse_host_result shape: " << shape_batch << ", " << nhead << ", " << shape_seqlen_q << std::endl;
cur_pass = ck_tile::check_err(lse_host_result,
lse_host_ref,
"LSE Error: Incorrect results!",

View File

@@ -142,12 +142,14 @@ auto randints(ForwardIterator first,
*/
template <typename RandomEngine>
std::tuple<std::vector<ck_tile::index_t>,
std::vector<ck_tile::index_t>,
std::vector<ck_tile::index_t>,
std::vector<ck_tile::index_t>>
generate_missing_seqlens(mode_enum mode,
ck_tile::index_t batch,
const std::vector<ck_tile::index_t>& q_val,
const std::vector<ck_tile::index_t>& k_val,
const std::vector<ck_tile::index_t>& q_pad_val,
const std::vector<ck_tile::index_t>& k_pad_val,
ck_tile::index_t seqlen_k_min,
bool need_append_kvcache,
@@ -177,7 +179,7 @@ generate_missing_seqlens(mode_enum mode,
return seqlen_ks;
}();
auto s_kpad = std::vector<ck_tile::index_t>(batch, -1); // TODO: batch not support k_padding
auto s_qpad = std::vector<ck_tile::index_t>(batch, -1);
// s_k should be greater than or equal to seqlen_k_min if provided
if(s_k.back() < seqlen_k_min)
{
@@ -187,13 +189,14 @@ generate_missing_seqlens(mode_enum mode,
throw std::runtime_error(msg.str());
}
return std::make_tuple(s_q, s_k, s_kpad);
return std::make_tuple(s_q, s_k, s_qpad, s_kpad);
}
else
{
std::vector<ck_tile::index_t> s_q;
std::vector<ck_tile::index_t> s_k;
std::vector<ck_tile::index_t> s_kpad;
std::vector<ck_tile::index_t> s_qpad;
ck_tile::index_t idx = 0;
for(; idx < std::min(static_cast<ck_tile::index_t>(q_val.size()), batch); ++idx)
{
@@ -205,9 +208,15 @@ generate_missing_seqlens(mode_enum mode,
? -1
: k_pad_val[std::min(idx, static_cast<ck_tile::index_t>(k_pad_val.size()) - 1)];
ck_tile::index_t qp =
q_pad_val.empty()
? -1
: q_pad_val[std::min(idx, static_cast<ck_tile::index_t>(q_pad_val.size()) - 1)];
s_q.push_back(q);
s_k.push_back(k < 0 ? q : k);
s_kpad.push_back(kp);
s_qpad.push_back(qp);
// s_k should be greater than or equal to seqlen_k_min
if(s_k.back() < seqlen_k_min)
@@ -228,8 +237,9 @@ generate_missing_seqlens(mode_enum mode,
s_q.insert(s_q.end(), rem_q.begin(), rem_q.end());
s_k.insert(s_k.end(), rem_k.begin(), rem_k.end());
s_kpad.insert(s_kpad.end(), batch - idx, s_kpad.back());
s_qpad.insert(s_qpad.end(), batch - idx, s_qpad.back());
}
return std::make_tuple(s_q, s_k, s_kpad);
return std::make_tuple(s_q, s_k, s_qpad, s_kpad);
}
}