mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-02 13:17:36 +00:00
[CK_TILE] fmha: Add query padding support to backward pass
Introduces support for query sequence padding (q_padding) in the FMHA backward pass kernels. - Passing `seqlen_q_ptr` to the backward kernels to distinguish logical from physical sequence lengths. - Updating `OGradDotO`, `ConvertQGrad`, and `DQDKDV` kernels to respect logical lengths and handle zero-length sequences. - Aligning LSE indexing in the forward kernel with the padded layout for consistency. - Adding a new GTest suite (`test_fmha_bwd_kernel_padding.cpp`) with comprehensive tests for various padding scenarios, including zero-length sequences and deterministic mode.
This commit is contained in:
@@ -24,11 +24,19 @@ auto create_args(int argc, char* argv[])
|
||||
"total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary\n"
|
||||
"also with \"-s=s0,s1,s2...\" comma-separated ints to set seqlen per batch "
|
||||
"(group mode)")
|
||||
.insert("s_qpad",
|
||||
"-1",
|
||||
"padded seqlen_q per batch (group mode only). "
|
||||
"Use \"-s_qpad=p0,p1,...\"; -1 disables explicit padding")
|
||||
.insert("s_k",
|
||||
"-1",
|
||||
"seqlen_k, -1 means equal to s\n"
|
||||
"also with \"-s_k=s0,s1,s2...\" comma-separated ints to set seqlen per batch "
|
||||
"(group mode)")
|
||||
.insert("s_kpad",
|
||||
"-1",
|
||||
"padded seqlen_k per batch (group mode only). "
|
||||
"Use \"-s_kpad=k0,k1,...\"; -1 disables explicit padding")
|
||||
.insert("d", "128", "head dim for q, k")
|
||||
.insert("d_v", "-1", "head dim for v, -1 means equal to d")
|
||||
.insert("scale", "0", "scale factor. 0 means equal to 1/sqrt(hdim)")
|
||||
@@ -96,7 +104,9 @@ auto run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::index_t nhead = arg_parser.get_int("h");
|
||||
ck_tile::index_t nhead_k = arg_parser.get_int("h_k");
|
||||
auto seqlen_qs = arg_parser.get_int_vec("s");
|
||||
auto seqlen_qpads = arg_parser.get_int_vec("s_qpad");
|
||||
auto seqlen_ks = arg_parser.get_int_vec("s_k");
|
||||
auto seqlen_kpads = arg_parser.get_int_vec("s_kpad");
|
||||
ck_tile::index_t hdim_q = arg_parser.get_int("d");
|
||||
ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
|
||||
bool i_perm = arg_parser.get_bool("iperm");
|
||||
@@ -130,6 +140,8 @@ auto run(const ck_tile::ArgParser& arg_parser)
|
||||
nhead_k,
|
||||
seqlen_qs,
|
||||
seqlen_ks,
|
||||
seqlen_qpads,
|
||||
seqlen_kpads,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
i_perm,
|
||||
|
||||
@@ -116,6 +116,7 @@ struct fmha_bwd_args
|
||||
void* dq_acc_ptr;
|
||||
const void* seqstart_q_ptr;
|
||||
const void* seqstart_k_ptr;
|
||||
const void* seqlen_q_ptr;
|
||||
const void* seqlen_k_ptr;
|
||||
ck_tile::index_t seqlen_q;
|
||||
ck_tile::index_t seqlen_k;
|
||||
@@ -203,6 +204,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
|
||||
dq_ptr,
|
||||
args.seqstart_q_ptr,
|
||||
args.seqstart_k_ptr,
|
||||
args.seqlen_q_ptr,
|
||||
args.seqlen_k_ptr,
|
||||
args.hdim_q,
|
||||
args.hdim_v,
|
||||
@@ -315,6 +317,7 @@ auto fmha_bwd_dot_do_o_create_kargs_and_grids(fmha_bwd_args args)
|
||||
args.d_ptr,
|
||||
args.p_undrop,
|
||||
args.seqstart_q_ptr,
|
||||
args.seqlen_q_ptr,
|
||||
args.hdim_v,
|
||||
args.stride_do,
|
||||
args.stride_o,
|
||||
@@ -356,6 +359,8 @@ auto fmha_bwd_convert_dq_create_kargs_and_grids(fmha_bwd_args args)
|
||||
args.dq_ptr,
|
||||
args.seqstart_q_ptr,
|
||||
args.seqstart_k_ptr,
|
||||
args.seqlen_q_ptr,
|
||||
args.seqlen_k_ptr,
|
||||
args.hdim_q,
|
||||
args.stride_dq,
|
||||
args.stride_dq_acc,
|
||||
|
||||
@@ -65,6 +65,8 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
ck_tile::index_t nhead_k,
|
||||
std::vector<ck_tile::index_t> seqlen_qs,
|
||||
std::vector<ck_tile::index_t> seqlen_ks,
|
||||
std::vector<ck_tile::index_t> seqlen_qpads,
|
||||
std::vector<ck_tile::index_t> seqlen_kpads,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
bool i_perm,
|
||||
@@ -119,10 +121,15 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
std::cerr << "dbias only exists when bias type is elementwise" << std::endl;
|
||||
return bwd_result::invalid_args;
|
||||
}
|
||||
std::vector<ck_tile::index_t> seqlen_kpads;
|
||||
std::tie(seqlen_qs, seqlen_ks, seqlen_kpads) =
|
||||
generate_missing_seqlens(mode, batch, seqlen_qs, seqlen_ks, {}, 0, false, random_engine);
|
||||
ck_tile::ignore = seqlen_kpads;
|
||||
|
||||
std::tie(seqlen_qs, seqlen_ks, seqlen_qpads, seqlen_kpads) = generate_missing_seqlens(
|
||||
mode, batch, seqlen_qs, seqlen_ks, seqlen_qpads, seqlen_kpads, 0, false, random_engine);
|
||||
|
||||
bool use_qpadding =
|
||||
mode == mode_enum::group && (!seqlen_qpads.empty() && seqlen_qpads[0] != -1);
|
||||
bool use_kpadding =
|
||||
mode == mode_enum::group && (!seqlen_kpads.empty() && seqlen_kpads[0] != -1);
|
||||
|
||||
#if 0
|
||||
std::cout << "seqlen_qs: " << seqlen_qs << std::endl;
|
||||
std::cout << "seqlen_ks: " << seqlen_ks << std::endl;
|
||||
@@ -146,8 +153,10 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
s_randval = true;
|
||||
}
|
||||
|
||||
const auto seqstart_q_host = to_seqstarts(seqlen_qs);
|
||||
const auto seqstart_k_host = to_seqstarts(seqlen_ks);
|
||||
const auto seqstart_q_host =
|
||||
(use_qpadding ? to_seqstarts(seqlen_qpads) : to_seqstarts(seqlen_qs));
|
||||
const auto seqstart_k_host =
|
||||
(use_kpadding ? to_seqstarts(seqlen_kpads) : to_seqstarts(seqlen_ks));
|
||||
|
||||
using TypeConfig = FmhaBwdTypeConfig<DataTypeConfig>;
|
||||
|
||||
@@ -336,6 +345,10 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
ck_tile::DeviceMem do_buf(do_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem dbias_buf(dbias_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
|
||||
ck_tile::DeviceMem seqlen_q_dev(mode == mode_enum::batch ? 0
|
||||
: seqlen_qs.size() * sizeof(int32_t));
|
||||
ck_tile::DeviceMem seqlen_k_dev(mode == mode_enum::batch ? 0
|
||||
: seqlen_ks.size() * sizeof(int32_t));
|
||||
ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t));
|
||||
ck_tile::DeviceMem drop_seed_buf(drop_prefs ? sizeof(uint64_t) : 0);
|
||||
ck_tile::DeviceMem drop_offset_buf(drop_prefs ? sizeof(uint64_t) : 0);
|
||||
@@ -349,6 +362,13 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
do_buf.ToDevice(do_host.data());
|
||||
seqstart_q.ToDevice(seqstart_q_host.data());
|
||||
seqstart_k.ToDevice(seqstart_k_host.data());
|
||||
if(mode == mode_enum::group)
|
||||
{
|
||||
std::vector<int32_t> seqlen_q_host(seqlen_qs.begin(), seqlen_qs.end());
|
||||
seqlen_q_dev.ToDevice(seqlen_q_host.data());
|
||||
std::vector<int32_t> seqlen_k_host(seqlen_ks.begin(), seqlen_ks.end());
|
||||
seqlen_k_dev.ToDevice(seqlen_k_host.data());
|
||||
}
|
||||
drop_seed_buf.ToDevice(drop_prefs ? &drop_seed : nullptr);
|
||||
drop_offset_buf.ToDevice(drop_prefs ? &drop_offset : nullptr);
|
||||
alibi_slope_buf.ToDevice(alibi_slope_host.data());
|
||||
@@ -440,11 +460,14 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
}
|
||||
}();
|
||||
|
||||
const void* seqlen_q_ptr_dev = use_qpadding ? seqlen_q_dev.GetDeviceBuffer() : nullptr;
|
||||
const void* seqlen_k_ptr_dev = use_kpadding ? seqlen_k_dev.GetDeviceBuffer() : nullptr;
|
||||
|
||||
return fmha_bwd_args{q_buf.GetDeviceBuffer(),
|
||||
k_buf.GetDeviceBuffer(),
|
||||
v_buf.GetDeviceBuffer(),
|
||||
bias.type == bias_enum::alibi ? alibi_slope_buf.GetDeviceBuffer()
|
||||
: bias_buf.GetDeviceBuffer(),
|
||||
: bias_buf.GetDeviceBuffer(),
|
||||
o_buf.GetDeviceBuffer(),
|
||||
lse_buf.GetDeviceBuffer(),
|
||||
do_buf.GetDeviceBuffer(),
|
||||
@@ -457,7 +480,8 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
dq_acc_buf.GetDeviceBuffer(),
|
||||
seqstart_q.GetDeviceBuffer(),
|
||||
seqstart_k.GetDeviceBuffer(),
|
||||
nullptr,
|
||||
seqlen_q_ptr_dev,
|
||||
seqlen_k_ptr_dev,
|
||||
shape_seqlen_q,
|
||||
shape_seqlen_k,
|
||||
batch,
|
||||
@@ -472,7 +496,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
stride_k,
|
||||
stride_v,
|
||||
bias.type == bias_enum::alibi ? (bias.rank_info == 0 ? 0 : nhead)
|
||||
: stride_bias,
|
||||
: stride_bias,
|
||||
stride_o,
|
||||
stride_randval,
|
||||
stride_do,
|
||||
|
||||
@@ -330,11 +330,12 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
return fwd_result::invalid_args;
|
||||
}
|
||||
|
||||
std::tie(seqlen_qs, seqlen_ks, seqlen_kpads) =
|
||||
std::tie(seqlen_qs, seqlen_ks, seqlen_qpads, seqlen_kpads) =
|
||||
generate_missing_seqlens(mode,
|
||||
batch,
|
||||
seqlen_qs,
|
||||
seqlen_ks,
|
||||
seqlen_qpads,
|
||||
seqlen_kpads,
|
||||
/*seqlen_k_min=*/0 < seqlen_knew ? seqlen_knew : 0,
|
||||
need_append_kvcache,
|
||||
@@ -346,7 +347,13 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
std::cerr << "kpad must be greater than or equal to seqlen for k" << std::endl;
|
||||
return fwd_result::invalid_args;
|
||||
}
|
||||
if(seqlen_qpads[wb] > 0 && seqlen_qpads[wb] < seqlen_qs[wb])
|
||||
{
|
||||
std::cerr << "qpad must be greater than or equal to seqlen for q" << std::endl;
|
||||
return fwd_result::invalid_args;
|
||||
}
|
||||
}
|
||||
|
||||
// compute kvcache seqlen_k (before appending knew/vnew)
|
||||
auto cache_seqlen_ks = seqlen_ks;
|
||||
std::transform(cache_seqlen_ks.begin(),
|
||||
@@ -357,6 +364,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
#if 0
|
||||
std::cout << "seqlen_qs: " << seqlen_qs << std::endl;
|
||||
std::cout << "seqlen_ks: " << seqlen_ks << std::endl;
|
||||
std::cout << "seqlen_qpads: " << seqlen_qpads << std::endl;
|
||||
std::cout << "seqlen_kpads: " << seqlen_kpads << std::endl;
|
||||
std::cout << "cache_seqlen_ks: " << cache_seqlen_ks << std::endl;
|
||||
#endif
|
||||
@@ -514,9 +522,6 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
|
||||
// host memory for storing all the tensor elements
|
||||
const ck_tile::index_t shape_batch = (mode == mode_enum::batch ? batch : 1);
|
||||
// logical(unpadded) total seqlen_q for group; batch uses fixed seqlen
|
||||
const ck_tile::index_t shape_seqlen_q_lse =
|
||||
(mode == mode_enum::batch ? seqlen_qs[0] : seqstart_q_host.back());
|
||||
// physical(padded) total seqlen_q for group when s_qpad is provided; else use logical
|
||||
const ck_tile::index_t shape_seqlen_q =
|
||||
(mode == mode_enum::batch
|
||||
@@ -580,7 +585,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
// batch mode of lse data layout is [batch, nhead, seqlen_q]
|
||||
// group mode of lse data layout is [nhead, total_seqlen_q]
|
||||
ck_tile::HostTensor<LSEDataType> lse_host(
|
||||
lse ? std::array<ck_tile::index_t, 3>{shape_batch, nhead, shape_seqlen_q_lse}
|
||||
lse ? std::array<ck_tile::index_t, 3>{shape_batch, nhead, shape_seqlen_q}
|
||||
: std::array<ck_tile::index_t, 3>{1, 1, 1} /* dummy shape for simplifying code */);
|
||||
|
||||
ck_tile::HostTensor<ODataType> o_host(
|
||||
@@ -970,8 +975,8 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
const ck_tile::index_t nhead_stride_bias =
|
||||
(i_perm ? 0 * shape_seqlen_q * max_seqlen_k : 0 * max_seqlen_k);
|
||||
const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k);
|
||||
const ck_tile::index_t nhead_stride_lse = shape_seqlen_q_lse;
|
||||
const ck_tile::index_t nhead_stride_lse_acc = (num_splits * shape_seqlen_q_lse);
|
||||
const ck_tile::index_t nhead_stride_lse = shape_seqlen_q;
|
||||
const ck_tile::index_t nhead_stride_lse_acc = (num_splits * shape_seqlen_q);
|
||||
const ck_tile::index_t nhead_stride_o_acc = (num_splits * shape_seqlen_q * hdim_v);
|
||||
const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
|
||||
// setup batch_stride_* arguments
|
||||
@@ -986,8 +991,8 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
const ck_tile::index_t batch_stride_vnew = (nhead_k * hdim_v * seqlen_knew);
|
||||
const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * max_seqlen_k);
|
||||
const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k);
|
||||
const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q_lse);
|
||||
const ck_tile::index_t batch_stride_lse_acc = (nhead * num_splits * shape_seqlen_q_lse);
|
||||
const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q);
|
||||
const ck_tile::index_t batch_stride_lse_acc = (nhead * num_splits * shape_seqlen_q);
|
||||
const ck_tile::index_t batch_stride_o_acc = (nhead * num_splits * shape_seqlen_q * hdim_v);
|
||||
const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);
|
||||
const ck_tile::index_t batch_stride_block_table = (max_num_page_blocks / batch);
|
||||
@@ -1727,12 +1732,11 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
if(lse)
|
||||
{
|
||||
ck_tile::HostTensor<SMPLComputeDataType> lse_host_result({nhead, real_seqlen_q});
|
||||
const ck_tile::index_t query_offset_lse =
|
||||
(mode == mode_enum::batch ? 0 : seqstart_q_host[wb]);
|
||||
lse_host_result.ForEach([&](auto& self, auto idx) {
|
||||
self(idx) = lse_host(b_idx, idx[0], idx[1] + query_offset_lse);
|
||||
self(idx) = lse_host(b_idx, idx[0], idx[1] + query_offset);
|
||||
});
|
||||
|
||||
std::cout << "lse_host_result shape: " << shape_batch << ", " << nhead << ", " << shape_seqlen_q << std::endl;
|
||||
|
||||
cur_pass = ck_tile::check_err(lse_host_result,
|
||||
lse_host_ref,
|
||||
"LSE Error: Incorrect results!",
|
||||
|
||||
@@ -142,12 +142,14 @@ auto randints(ForwardIterator first,
|
||||
*/
|
||||
template <typename RandomEngine>
|
||||
std::tuple<std::vector<ck_tile::index_t>,
|
||||
std::vector<ck_tile::index_t>,
|
||||
std::vector<ck_tile::index_t>,
|
||||
std::vector<ck_tile::index_t>>
|
||||
generate_missing_seqlens(mode_enum mode,
|
||||
ck_tile::index_t batch,
|
||||
const std::vector<ck_tile::index_t>& q_val,
|
||||
const std::vector<ck_tile::index_t>& k_val,
|
||||
const std::vector<ck_tile::index_t>& q_pad_val,
|
||||
const std::vector<ck_tile::index_t>& k_pad_val,
|
||||
ck_tile::index_t seqlen_k_min,
|
||||
bool need_append_kvcache,
|
||||
@@ -177,7 +179,7 @@ generate_missing_seqlens(mode_enum mode,
|
||||
return seqlen_ks;
|
||||
}();
|
||||
auto s_kpad = std::vector<ck_tile::index_t>(batch, -1); // TODO: batch not support k_padding
|
||||
|
||||
auto s_qpad = std::vector<ck_tile::index_t>(batch, -1);
|
||||
// s_k should be greater than or equal to seqlen_k_min if provided
|
||||
if(s_k.back() < seqlen_k_min)
|
||||
{
|
||||
@@ -187,13 +189,14 @@ generate_missing_seqlens(mode_enum mode,
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
return std::make_tuple(s_q, s_k, s_kpad);
|
||||
return std::make_tuple(s_q, s_k, s_qpad, s_kpad);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::vector<ck_tile::index_t> s_q;
|
||||
std::vector<ck_tile::index_t> s_k;
|
||||
std::vector<ck_tile::index_t> s_kpad;
|
||||
std::vector<ck_tile::index_t> s_qpad;
|
||||
ck_tile::index_t idx = 0;
|
||||
for(; idx < std::min(static_cast<ck_tile::index_t>(q_val.size()), batch); ++idx)
|
||||
{
|
||||
@@ -205,9 +208,15 @@ generate_missing_seqlens(mode_enum mode,
|
||||
? -1
|
||||
: k_pad_val[std::min(idx, static_cast<ck_tile::index_t>(k_pad_val.size()) - 1)];
|
||||
|
||||
ck_tile::index_t qp =
|
||||
q_pad_val.empty()
|
||||
? -1
|
||||
: q_pad_val[std::min(idx, static_cast<ck_tile::index_t>(q_pad_val.size()) - 1)];
|
||||
|
||||
s_q.push_back(q);
|
||||
s_k.push_back(k < 0 ? q : k);
|
||||
s_kpad.push_back(kp);
|
||||
s_qpad.push_back(qp);
|
||||
|
||||
// s_k should be greater than or equal to seqlen_k_min
|
||||
if(s_k.back() < seqlen_k_min)
|
||||
@@ -228,8 +237,9 @@ generate_missing_seqlens(mode_enum mode,
|
||||
s_q.insert(s_q.end(), rem_q.begin(), rem_q.end());
|
||||
s_k.insert(s_k.end(), rem_k.begin(), rem_k.end());
|
||||
s_kpad.insert(s_kpad.end(), batch - idx, s_kpad.back());
|
||||
s_qpad.insert(s_qpad.end(), batch - idx, s_qpad.back());
|
||||
}
|
||||
return std::make_tuple(s_q, s_k, s_kpad);
|
||||
return std::make_tuple(s_q, s_k, s_qpad, s_kpad);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user