mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 06:01:23 +00:00
[CK_TILE] fmha: Add query padding support to backward pass (#3097)
* [CK_TILE] fmha: Add query padding support to backward pass Introduces support for query sequence padding (q_padding) in the FMHA backward pass kernels. - Passing `seqlen_q_ptr` to the backward kernels to distinguish logical from physical sequence lengths. - Updating `OGradDotO`, `ConvertQGrad`, and `DQDKDV` kernels to respect logical lengths and handle zero-length sequences. - Aligning LSE indexing in the forward kernel with the padded layout for consistency. - Adding a new GTest suite (`test_fmha_bwd_kernel_padding.cpp`) with comprehensive tests for various padding scenarios, including zero-length sequences and deterministic mode. * fix clang format * Adapt fmha_bwd_runner.cpp to new q, kv sequence padding Add backward q/kv sequence padding unit tests. * [CK_TILE] fmha: Unify sequence length and padding handling Refactor the handling of sequence lengths and padding in the FMHA forward and backward kernels to provide a more unified and flexible interface. - Replaced `seqstart_padded_*_ptr` with a more robust system that uses `seqstart_*_ptr` for physical sequence lengths and introduces `seqlen_*_ptr` and `cu_seqlen_*_ptr` for logical (unpadded) lengths. - Established a clear order of precedence for determining sequence length: cumulative lengths (`cu_seqlen_*_ptr`) take priority, followed by per-sequence lengths (`seqlen_*_ptr`), and finally physical lengths derived from `seqstart_*_ptr`. - Clarified the distinction between "group mode" and "batch mode" and how sequence lengths are handled in each case. - Renamed `cu_seqlen_kv_ptr` to `cu_seqlen_k_ptr` for consistency. - Updated comments and documentation to reflect the new argument structure and usage. --------- Co-authored-by: illsilin_amdeng <Illia.Silin@amd.com>
This commit is contained in:
@@ -142,12 +142,14 @@ auto randints(ForwardIterator first,
|
||||
*/
|
||||
template <typename RandomEngine>
|
||||
std::tuple<std::vector<ck_tile::index_t>,
|
||||
std::vector<ck_tile::index_t>,
|
||||
std::vector<ck_tile::index_t>,
|
||||
std::vector<ck_tile::index_t>>
|
||||
generate_missing_seqlens(mode_enum mode,
|
||||
ck_tile::index_t batch,
|
||||
const std::vector<ck_tile::index_t>& q_val,
|
||||
const std::vector<ck_tile::index_t>& k_val,
|
||||
const std::vector<ck_tile::index_t>& q_pad_val,
|
||||
const std::vector<ck_tile::index_t>& k_pad_val,
|
||||
ck_tile::index_t seqlen_k_min,
|
||||
bool need_append_kvcache,
|
||||
@@ -177,7 +179,7 @@ generate_missing_seqlens(mode_enum mode,
|
||||
return seqlen_ks;
|
||||
}();
|
||||
auto s_kpad = std::vector<ck_tile::index_t>(batch, -1); // TODO: batch not support k_padding
|
||||
|
||||
auto s_qpad = std::vector<ck_tile::index_t>(batch, -1);
|
||||
// s_k should be greater than or equal to seqlen_k_min if provided
|
||||
if(s_k.back() < seqlen_k_min)
|
||||
{
|
||||
@@ -187,13 +189,14 @@ generate_missing_seqlens(mode_enum mode,
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
return std::make_tuple(s_q, s_k, s_kpad);
|
||||
return std::make_tuple(s_q, s_k, s_qpad, s_kpad);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::vector<ck_tile::index_t> s_q;
|
||||
std::vector<ck_tile::index_t> s_k;
|
||||
std::vector<ck_tile::index_t> s_kpad;
|
||||
std::vector<ck_tile::index_t> s_qpad;
|
||||
ck_tile::index_t idx = 0;
|
||||
for(; idx < std::min(static_cast<ck_tile::index_t>(q_val.size()), batch); ++idx)
|
||||
{
|
||||
@@ -205,9 +208,15 @@ generate_missing_seqlens(mode_enum mode,
|
||||
? -1
|
||||
: k_pad_val[std::min(idx, static_cast<ck_tile::index_t>(k_pad_val.size()) - 1)];
|
||||
|
||||
ck_tile::index_t qp =
|
||||
q_pad_val.empty()
|
||||
? -1
|
||||
: q_pad_val[std::min(idx, static_cast<ck_tile::index_t>(q_pad_val.size()) - 1)];
|
||||
|
||||
s_q.push_back(q);
|
||||
s_k.push_back(k < 0 ? q : k);
|
||||
s_kpad.push_back(kp);
|
||||
s_qpad.push_back(qp);
|
||||
|
||||
// s_k should be greater than or equal to seqlen_k_min
|
||||
if(s_k.back() < seqlen_k_min)
|
||||
@@ -228,8 +237,9 @@ generate_missing_seqlens(mode_enum mode,
|
||||
s_q.insert(s_q.end(), rem_q.begin(), rem_q.end());
|
||||
s_k.insert(s_k.end(), rem_k.begin(), rem_k.end());
|
||||
s_kpad.insert(s_kpad.end(), batch - idx, s_kpad.back());
|
||||
s_qpad.insert(s_qpad.end(), batch - idx, s_qpad.back());
|
||||
}
|
||||
return std::make_tuple(s_q, s_k, s_kpad);
|
||||
return std::make_tuple(s_q, s_k, s_qpad, s_kpad);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user