[CK_TILE] fmha: Add query padding support to backward pass (#3097)

* [CK_TILE] fmha: Add query padding support to backward pass

Introduces support for query sequence padding (q_padding) in the FMHA backward pass kernels.
- Passing `seqlen_q_ptr` to the backward kernels to distinguish logical from physical sequence lengths.
- Updating `OGradDotO`, `ConvertQGrad`, and `DQDKDV` kernels to respect logical lengths and handle zero-length sequences.
- Aligning LSE indexing in the forward kernel with the padded layout for consistency.
- Adding a new GTest suite (`test_fmha_bwd_kernel_padding.cpp`) with comprehensive tests for various padding scenarios, including zero-length
  sequences and deterministic mode.

* fix clang format

* Adapt fmha_bwd_runner.cpp to new q, kv sequence padding
Add backward q/kv sequence padding unit tests.

* [CK_TILE] fmha: Unify sequence length and padding handling

Refactor the handling of sequence lengths and padding in the
FMHA forward and backward kernels to provide a more unified and flexible
interface.

- Replaced `seqstart_padded_*_ptr` with a more robust system that uses
  `seqstart_*_ptr` for physical sequence lengths and introduces
  `seqlen_*_ptr` and `cu_seqlen_*_ptr` for logical (unpadded) lengths.
- Established a clear order of precedence for determining sequence
  length: cumulative lengths (`cu_seqlen_*_ptr`) take priority,
  followed by per-sequence lengths (`seqlen_*_ptr`), and finally
  physical lengths derived from `seqstart_*_ptr`.
- Clarified the distinction between "group mode" and "batch mode" and
  how sequence lengths are handled in each case.
- Renamed `cu_seqlen_kv_ptr` to `cu_seqlen_k_ptr` for consistency.
- Updated comments and documentation to reflect the new argument
  structure and usage.

---------

Co-authored-by: illsilin_amdeng <Illia.Silin@amd.com>

[ROCm/composable_kernel commit: 7c6430eca0]
This commit is contained in:
Jeff Huang
2025-10-29 13:56:11 +08:00
committed by GitHub
parent dc1cd3df0c
commit 9ad15a658c
11 changed files with 1292 additions and 214 deletions

View File

@@ -269,14 +269,14 @@ class FmhaFwdApiTrait:
return "true/*group mode skpad always true*/" # group mode only generate spad/skpad == true
if self.pipeline_tag == "qr_async":
if self.skpad == "t":
return f"(a.cu_seqlen_kv_ptr != nullptr) || (a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0)"
return f"(a.cu_seqlen_k_ptr != nullptr) || (a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0)"
else:
return f"(a.cu_seqlen_kv_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)"
return f"(a.cu_seqlen_k_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)"
elif self.pipeline_tag in ["qr", "qs"]:
if self.skpad == "t":
return f"true /*a.seqlen_k % {self.bn0} != 0*/" # TODO: order of get_pipelines() matters! (ugly)
else:
return f"(a.cu_seqlen_kv_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)"
return f"(a.cu_seqlen_k_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)"
elif self.pipeline_tag == "qr_async_trload":
if self.skpad == "t":
return "true"

View File

@@ -24,11 +24,19 @@ auto create_args(int argc, char* argv[])
"total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary\n"
"also with \"-s=s0,s1,s2...\" comma-separated ints to set seqlen per batch "
"(group mode)")
.insert("s_qpad",
"-1",
"padded seqlen_q per batch (group mode only). "
"Use \"-s_qpad=p0,p1,...\"; -1 disables explicit padding")
.insert("s_k",
"-1",
"seqlen_k, -1 means equal to s\n"
"also with \"-s_k=s0,s1,s2...\" comma-separated ints to set seqlen per batch "
"(group mode)")
.insert("s_kpad",
"-1",
"padded seqlen_k per batch (group mode only). "
"Use \"-s_kpad=k0,k1,...\"; -1 disables explicit padding")
.insert("d", "128", "head dim for q, k")
.insert("d_v", "-1", "head dim for v, -1 means equal to d")
.insert("scale", "0", "scale factor. 0 means equal to 1/sqrt(hdim)")
@@ -96,7 +104,9 @@ auto run(const ck_tile::ArgParser& arg_parser)
ck_tile::index_t nhead = arg_parser.get_int("h");
ck_tile::index_t nhead_k = arg_parser.get_int("h_k");
auto seqlen_qs = arg_parser.get_int_vec("s");
auto seqlen_qpads = arg_parser.get_int_vec("s_qpad");
auto seqlen_ks = arg_parser.get_int_vec("s_k");
auto seqlen_kpads = arg_parser.get_int_vec("s_kpad");
ck_tile::index_t hdim_q = arg_parser.get_int("d");
ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
bool i_perm = arg_parser.get_bool("iperm");
@@ -130,6 +140,8 @@ auto run(const ck_tile::ArgParser& arg_parser)
nhead_k,
seqlen_qs,
seqlen_ks,
seqlen_qpads,
seqlen_kpads,
hdim_q,
hdim_v,
i_perm,

View File

@@ -114,9 +114,51 @@ struct fmha_bwd_args
void* dv_ptr;
void* dbias_ptr;
void* dq_acc_ptr;
const void* seqstart_q_ptr;
const void* seqstart_k_ptr;
const void* seqlen_k_ptr;
// Usage notes for sequence length pointer parameters:
//
// [Note: Define "Group mode" vs "Batch mode" here if possible, e.g., "Group mode handles
// MQA/GQA..."]
//
// With padding:
// Group mode:
// - seqstart_q_ptr, seqstart_k_ptr: Record cumulative physical (including padding) sequence
// lengths. [array size: batch + 1]
// - seqlen_q_ptr/seqlen_k_ptr: Records logical (excluding padding) length for each
// sequence. [array size: batch]
// - cu_seqlen_q_ptr/cu_seqlen_k_ptr: Records cumulative logical (excluding padding)
// sequence lengths. [array size: batch + 1]
// - seqlen_q_ptr (per-sequence) and cu_seqlen_q_ptr (cumulative logical) are mutually
// exclusive. Use one set, not both.
//
// Batch mode:
// - cu_seqlen_q_ptr/cu_seqlen_k_ptr: Records cumulative logical (excluding padding)
// sequence lengths. [array size: batch + 1]
// - seqstart_* and seqlen_* pointers must be nullptr.
//
// Without padding:
// (Note: Physical length equals logical length)
//
// Group mode:
// - seqstart_q_ptr, seqstart_k_ptr: Record cumulative physical sequence lengths. [array
// size: batch + 1]
// - seqlen_q_ptr/seqlen_k_ptr and cu_seqlen_q_ptr/cu_seqlen_k_ptr must be nullptr.
//
// Batch mode:
// - All sequence length pointers (seqstart_*, seqlen_*, cu_seqlen_*) must be nullptr.
//
const void* seqstart_q_ptr =
nullptr; // Cumulative physical sequence length array [batch + 1]. (Used in Group mode)
const void* seqstart_k_ptr =
nullptr; // Cumulative physical sequence length array [batch + 1]. (Used in Group mode)
const void* seqlen_q_ptr = nullptr; // Per-sequence logical (excluding padding) length array
// [batch]. (Used in Group mode with padding)
const void* seqlen_k_ptr = nullptr; // Per-sequence logical (excluding padding) length array
// [batch]. (Used in Group mode with padding)
const void* cu_seqlen_q_ptr = nullptr; // Cumulative logical (excluding padding) sequence length
// array [batch + 1]. (Used with padding)
const void* cu_seqlen_k_ptr = nullptr; // Cumulative logical (excluding padding) sequence length
// array [batch + 1]. (Used with padding)
ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_k;
ck_tile::index_t batch;
@@ -203,7 +245,10 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
dq_ptr,
args.seqstart_q_ptr,
args.seqstart_k_ptr,
args.seqlen_q_ptr,
args.seqlen_k_ptr,
args.cu_seqlen_q_ptr,
args.cu_seqlen_k_ptr,
args.hdim_q,
args.hdim_v,
args.nhead_q,
@@ -315,6 +360,8 @@ auto fmha_bwd_dot_do_o_create_kargs_and_grids(fmha_bwd_args args)
args.d_ptr,
args.p_undrop,
args.seqstart_q_ptr,
args.seqlen_q_ptr,
args.cu_seqlen_q_ptr,
args.hdim_v,
args.stride_do,
args.stride_o,
@@ -356,6 +403,10 @@ auto fmha_bwd_convert_dq_create_kargs_and_grids(fmha_bwd_args args)
args.dq_ptr,
args.seqstart_q_ptr,
args.seqstart_k_ptr,
args.seqlen_q_ptr,
args.seqlen_k_ptr,
args.cu_seqlen_q_ptr,
args.cu_seqlen_k_ptr,
args.hdim_q,
args.stride_dq,
args.stride_dq_acc,

View File

@@ -65,6 +65,8 @@ bwd_result fmha_bwd_run(mode_enum mode,
ck_tile::index_t nhead_k,
std::vector<ck_tile::index_t> seqlen_qs,
std::vector<ck_tile::index_t> seqlen_ks,
std::vector<ck_tile::index_t> seqlen_qpads,
std::vector<ck_tile::index_t> seqlen_kpads,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
bool i_perm,
@@ -119,13 +121,26 @@ bwd_result fmha_bwd_run(mode_enum mode,
std::cerr << "dbias only exists when bias type is elementwise" << std::endl;
return bwd_result::invalid_args;
}
std::vector<ck_tile::index_t> seqlen_kpads;
std::tie(seqlen_qs, seqlen_ks, seqlen_kpads) =
generate_missing_seqlens(mode, batch, seqlen_qs, seqlen_ks, {}, 0, false, random_engine);
ck_tile::ignore = seqlen_kpads;
std::tie(seqlen_qs, seqlen_ks, seqlen_qpads, seqlen_kpads) = generate_missing_seqlens(
mode, batch, seqlen_qs, seqlen_ks, seqlen_qpads, seqlen_kpads, 0, false, random_engine);
bool use_qpadding =
mode == mode_enum::group && (!seqlen_qpads.empty() && seqlen_qpads[0] != -1);
bool use_kpadding =
mode == mode_enum::group && (!seqlen_kpads.empty() && seqlen_kpads[0] != -1);
#if 0
std::cout << "use_qpadding: " << use_qpadding << std::endl;
std::cout << "use_kpadding: " << use_kpadding << std::endl;
std::cout << "seqlen_qs: " << seqlen_qs << std::endl;
std::cout << "seqlen_ks: " << seqlen_ks << std::endl;
if (use_qpadding) {
std::cout << "seqlen_qpads: " << seqlen_qpads << std::endl;
}
if (use_kpadding) {
std::cout << "seqlen_kpads: " << seqlen_kpads << std::endl;
}
#endif
mask_info mask = mask_info::decode(mask_str, seqlen_qs[0], seqlen_ks[0]);
@@ -146,8 +161,10 @@ bwd_result fmha_bwd_run(mode_enum mode,
s_randval = true;
}
const auto seqstart_q_host = to_seqstarts(seqlen_qs);
const auto seqstart_k_host = to_seqstarts(seqlen_ks);
const auto seqstart_q_host =
(use_qpadding ? to_seqstarts(seqlen_qpads) : to_seqstarts(seqlen_qs));
const auto seqstart_k_host =
(use_kpadding ? to_seqstarts(seqlen_kpads) : to_seqstarts(seqlen_ks));
using TypeConfig = FmhaBwdTypeConfig<DataTypeConfig>;
@@ -176,8 +193,11 @@ bwd_result fmha_bwd_run(mode_enum mode,
{
for(ck_tile::index_t wb = 0; wb < batch; ++wb)
{
const int32_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb];
const int32_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb];
// When padding is enabled, use logical lengths for flop/bandwidth calculation
const int32_t real_seqlen_q =
use_qpadding ? seqlen_qs[wb] : (seqstart_q_host[wb + 1] - seqstart_q_host[wb]);
const int32_t real_seqlen_k =
use_kpadding ? seqlen_ks[wb] : (seqstart_k_host[wb + 1] - seqstart_k_host[wb]);
if(max_seqlen_q < real_seqlen_q)
{
@@ -336,6 +356,10 @@ bwd_result fmha_bwd_run(mode_enum mode,
ck_tile::DeviceMem do_buf(do_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem dbias_buf(dbias_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
ck_tile::DeviceMem seqlen_q_dev(mode == mode_enum::batch ? 0
: seqlen_qs.size() * sizeof(int32_t));
ck_tile::DeviceMem seqlen_k_dev(mode == mode_enum::batch ? 0
: seqlen_ks.size() * sizeof(int32_t));
ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t));
ck_tile::DeviceMem drop_seed_buf(drop_prefs ? sizeof(uint64_t) : 0);
ck_tile::DeviceMem drop_offset_buf(drop_prefs ? sizeof(uint64_t) : 0);
@@ -349,6 +373,13 @@ bwd_result fmha_bwd_run(mode_enum mode,
do_buf.ToDevice(do_host.data());
seqstart_q.ToDevice(seqstart_q_host.data());
seqstart_k.ToDevice(seqstart_k_host.data());
if(mode == mode_enum::group)
{
std::vector<int32_t> seqlen_q_host(seqlen_qs.begin(), seqlen_qs.end());
seqlen_q_dev.ToDevice(seqlen_q_host.data());
std::vector<int32_t> seqlen_k_host(seqlen_ks.begin(), seqlen_ks.end());
seqlen_k_dev.ToDevice(seqlen_k_host.data());
}
drop_seed_buf.ToDevice(drop_prefs ? &drop_seed : nullptr);
drop_offset_buf.ToDevice(drop_prefs ? &drop_offset : nullptr);
alibi_slope_buf.ToDevice(alibi_slope_host.data());
@@ -440,6 +471,9 @@ bwd_result fmha_bwd_run(mode_enum mode,
}
}();
const void* seqlen_q_ptr_dev = use_qpadding ? seqlen_q_dev.GetDeviceBuffer() : nullptr;
const void* seqlen_k_ptr_dev = use_kpadding ? seqlen_k_dev.GetDeviceBuffer() : nullptr;
return fmha_bwd_args{q_buf.GetDeviceBuffer(),
k_buf.GetDeviceBuffer(),
v_buf.GetDeviceBuffer(),
@@ -457,6 +491,9 @@ bwd_result fmha_bwd_run(mode_enum mode,
dq_acc_buf.GetDeviceBuffer(),
seqstart_q.GetDeviceBuffer(),
seqstart_k.GetDeviceBuffer(),
seqlen_q_ptr_dev,
seqlen_k_ptr_dev,
nullptr,
nullptr,
shape_seqlen_q,
shape_seqlen_k,
@@ -551,8 +588,18 @@ bwd_result fmha_bwd_run(mode_enum mode,
for(ck_tile::index_t wb = 0; wb < batch; ++wb)
{
const ck_tile::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb];
const ck_tile::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb];
// When padding is enabled, use logical lengths instead of computing from padded
// prefix-sum
const ck_tile::index_t real_seqlen_q =
use_qpadding ? seqlen_qs[wb] : (seqstart_q_host[wb + 1] - seqstart_q_host[wb]);
const ck_tile::index_t real_seqlen_k =
use_kpadding ? seqlen_ks[wb] : (seqstart_k_host[wb + 1] - seqstart_k_host[wb]);
// Skip forward reference computation for batches with zero length sequences
if(real_seqlen_q == 0 || real_seqlen_k == 0)
{
continue;
}
// adjust matrix index according to the mode
const ck_tile::index_t b = (mode == mode_enum::batch ? wb : 0);
@@ -797,10 +844,23 @@ bwd_result fmha_bwd_run(mode_enum mode,
dv_buf.FromDevice(dv_host.data());
dbias_buf.FromDevice(dbias_host.data());
// Track the index into reference vectors (may differ from wb if batches were skipped)
ck_tile::index_t ref_idx = 0;
for(ck_tile::index_t wb = 0; wb < batch; ++wb)
{
const ck_tile::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb];
const ck_tile::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb];
// When padding is enabled, use logical lengths instead of computing from padded
// prefix-sum
const ck_tile::index_t real_seqlen_q =
use_qpadding ? seqlen_qs[wb] : (seqstart_q_host[wb + 1] - seqstart_q_host[wb]);
const ck_tile::index_t real_seqlen_k =
use_kpadding ? seqlen_ks[wb] : (seqstart_k_host[wb + 1] - seqstart_k_host[wb]);
// Skip validation for batches with zero length sequences
if(real_seqlen_q == 0 || real_seqlen_k == 0)
{
continue;
}
// adjust matrix index according to the mode
const ck_tile::index_t b = (mode == mode_enum::batch ? wb : 0);
@@ -833,14 +893,14 @@ bwd_result fmha_bwd_run(mode_enum mode,
// dP = dO@V x Z w/ dropout
// dP = dO@V w/o dropout
auto v_t_host_ref = v_host_refs[wb].transpose({0, 2, 1}); // v_g_o_n -> v_g_n_o
auto v_t_host_ref = v_host_refs[ref_idx].transpose({0, 2, 1}); // v_g_o_n -> v_g_n_o
ck_tile::reference_batched_gemm<OGradDataType, VDataType, AccDataType, AccDataType>(
do_host_ref, v_t_host_ref, dp_hp_host_ref); // dp_g_m_n = do_g_m_o@v_g_n_o
if(p_drop > 0)
{
ck_tile::reference_batched_dropout(
dp_hp_host_ref, randval_host_refs[wb], p_undrop_in_uint8_t, rp_undrop);
dp_hp_host_ref, randval_host_refs[ref_idx], p_undrop_in_uint8_t, rp_undrop);
}
// dS_i_j = P_i_j .* (dP_i_j - dO_i dot O_i)
@@ -849,11 +909,13 @@ bwd_result fmha_bwd_run(mode_enum mode,
AccDataType do_dot_o = 0;
for(int o = 0; o < hdim_v; o++)
{
do_dot_o += ck_tile::type_convert<AccDataType>(do_host_ref(i0, i1, o)) *
ck_tile::type_convert<AccDataType>(o_host_refs[wb](i0, i1, o));
do_dot_o +=
ck_tile::type_convert<AccDataType>(do_host_ref(i0, i1, o)) *
ck_tile::type_convert<AccDataType>(o_host_refs[ref_idx](i0, i1, o));
}
ds_hp_host_ref(i0, i1, i2) = ck_tile::type_convert<AccDataType>(
p_hp_host_refs[wb](i0, i1, i2) * (dp_hp_host_ref(i0, i1, i2) - do_dot_o));
ds_hp_host_ref(i0, i1, i2) =
ck_tile::type_convert<AccDataType>(p_hp_host_refs[ref_idx](i0, i1, i2) *
(dp_hp_host_ref(i0, i1, i2) - do_dot_o));
},
ds_hp_host_ref.mDesc.get_lengths()[0],
ds_hp_host_ref.mDesc.get_lengths()[1],
@@ -869,14 +931,14 @@ bwd_result fmha_bwd_run(mode_enum mode,
// dV = P_drop^T@dO^T
// dV = P^T@dO^T w/o dropout
auto p_t_lp_host_ref =
p_lp_host_refs[wb].transpose({0, 2, 1}); // p_lp_g_m_n -> p_lp_g_n_m
p_lp_host_refs[ref_idx].transpose({0, 2, 1}); // p_lp_g_m_n -> p_lp_g_n_m
auto do_t_host_ref = do_host_ref.transpose({0, 2, 1}); // do_g_m_o -> do_g_o_m
ck_tile::
reference_batched_gemm<GemmDataType, OGradDataType, AccDataType, VGradDataType>(
p_t_lp_host_ref, do_t_host_ref, dv_host_ref); // dv_g_n_o = p_lp_g_n_m@do_g_o_m
// dQ = scale * dS@K^T
auto k_t_host_ref = k_host_refs[wb].transpose({0, 2, 1}); // k_g_n_k -> k_g_k_n
auto k_t_host_ref = k_host_refs[ref_idx].transpose({0, 2, 1}); // k_g_n_k -> k_g_k_n
ck_tile::reference_batched_gemm<GemmDataType, KDataType, AccDataType, QGradDataType>(
ds_lp_host_ref,
k_t_host_ref,
@@ -886,8 +948,8 @@ bwd_result fmha_bwd_run(mode_enum mode,
ck_tile::scales(scale)); // dq_g_m_k = ds_g_m_n@k_g_k_n
// dK = scale * dS^T@Q^T
auto ds_t_lp_host_ref = ds_lp_host_ref.transpose({0, 2, 1}); // ds_g_m_n -> ds_g_n_m
auto q_t_host_ref = q_host_refs[wb].transpose({0, 2, 1}); // q_g_m_k -> q_g_k_m
auto ds_t_lp_host_ref = ds_lp_host_ref.transpose({0, 2, 1}); // ds_g_m_n -> ds_g_n_m
auto q_t_host_ref = q_host_refs[ref_idx].transpose({0, 2, 1}); // q_g_m_k -> q_g_k_m
ck_tile::reference_batched_gemm<GemmDataType, QDataType, AccDataType, KGradDataType>(
ds_t_lp_host_ref,
q_t_host_ref,
@@ -961,6 +1023,9 @@ bwd_result fmha_bwd_run(mode_enum mode,
break;
}
// Increment reference vector index for successfully validated batches
ref_idx++;
}
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;

View File

@@ -182,19 +182,50 @@ struct fmha_fwd_args
void* lse_ptr;
void* o_ptr;
// Optional cumulative sequence length arrays
// Batch mode: cu_seqlen_* override effective per-batch lengths (exclude PAD)
const ck_tile::index_t* cu_seqlen_q_ptr = nullptr; // [batch+1]
const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr; // [batch+1]
const void* seqstart_q_ptr;
const void* seqstart_k_ptr;
const void*
seqlen_k_ptr; // only used if both 'seqstart_q_ptr' & 'seqstart_k_ptr' are not nullptr
// Group mode: seqstart_padded_* provide physical starts including PAD (optional)
const void* seqstart_padded_q_ptr = nullptr; // [batch+1]
const void* seqstart_padded_k_ptr = nullptr; // [batch+1]
// Usage notes for sequence length pointer parameters:
//
// [Note: Define "Group mode" vs "Batch mode" here if possible, e.g., "Group mode handles
// MQA/GQA..."]
//
// With padding:
// Group mode:
// - seqstart_q_ptr, seqstart_k_ptr: Record cumulative physical (including padding) sequence
// lengths. [array size: batch + 1]
// - seqlen_q_ptr/seqlen_k_ptr: Records logical (excluding padding) length for each
// sequence. [array size: batch]
// - cu_seqlen_q_ptr/cu_seqlen_k_ptr: Records cumulative logical (excluding padding)
// sequence lengths. [array size: batch + 1]
// - seqlen_q_ptr (per-sequence) and cu_seqlen_q_ptr (cumulative logical) are mutually
// exclusive. Use one set, not both.
//
// Batch mode:
// - cu_seqlen_q_ptr/cu_seqlen_k_ptr: Records cumulative logical (excluding padding)
// sequence lengths. [array size: batch + 1]
// - seqstart_* and seqlen_* pointers must be nullptr.
//
// Without padding:
// (Note: Physical length equals logical length)
//
// Group mode:
// - seqstart_q_ptr, seqstart_k_ptr: Record cumulative physical sequence lengths. [array
// size: batch + 1]
// - seqlen_q_ptr/seqlen_k_ptr and cu_seqlen_q_ptr/cu_seqlen_k_ptr must be nullptr.
//
// Batch mode:
// - All sequence length pointers (seqstart_*, seqlen_*, cu_seqlen_*) must be nullptr.
//
const void* seqstart_q_ptr =
nullptr; // Cumulative physical sequence length array [batch + 1]. (Used in Group mode)
const void* seqstart_k_ptr =
nullptr; // Cumulative physical sequence length array [batch + 1]. (Used in Group mode)
const void* seqlen_q_ptr = nullptr; // Per-sequence logical (excluding padding) length array
// [batch]. (Used in Group mode with padding)
const void* seqlen_k_ptr = nullptr; // Per-sequence logical (excluding padding) length array
// [batch]. (Used in Group mode with padding)
const void* cu_seqlen_q_ptr = nullptr; // Cumulative logical (excluding padding) sequence length
// array [batch + 1]. (Used with padding)
const void* cu_seqlen_k_ptr = nullptr; // Cumulative logical (excluding padding) sequence length
// array [batch + 1]. (Used with padding)
ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_k;
@@ -555,6 +586,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args.o_ptr,
args.seqstart_q_ptr,
args.seqstart_k_ptr,
args.seqlen_q_ptr,
args.seqlen_k_ptr,
args.hdim_q,
args.hdim_v,
@@ -584,8 +616,8 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args.p_drop,
args.s_randval,
args.drop_seed_offset,
args.seqstart_padded_q_ptr,
args.seqstart_padded_k_ptr);
args.cu_seqlen_q_ptr,
args.cu_seqlen_k_ptr);
}
else
{ // create batch mode kernel arguments
@@ -633,7 +665,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args.s_randval,
args.drop_seed_offset,
args.cu_seqlen_q_ptr,
args.cu_seqlen_kv_ptr);
args.cu_seqlen_k_ptr);
}
}();

View File

@@ -313,16 +313,19 @@ fwd_result fmha_fwd_run(mode_enum mode,
const bool use_kvcache = (need_append_kvcache || use_cache_batch_idx || 0 < page_block_size);
// Reject unsupported padding usage in special pipelines (appendkv / splitkv / pagedkv)
const bool has_group_padding =
(mode == mode_enum::group && (!seqlen_qpads.empty() && seqlen_qpads[0] != -1)) ||
(mode == mode_enum::group && (seqlen_kpads[0] >= 0));
const bool has_batch_efflens = (mode == mode_enum::batch && (!q_eff_lens_per_batch.empty() ||
!kv_eff_lens_per_batch.empty()));
const bool using_appendkv = (0 < seqlen_knew || 0 < rotary_dim);
const bool using_pagedkv = (0 < page_block_size);
const bool using_splitkv = (num_splits > 1) || use_cache_batch_idx;
const bool has_group_q_padding =
mode == mode_enum::group && (!seqlen_qpads.empty() && seqlen_qpads[0] > 0);
const bool has_group_k_padding =
mode == mode_enum::group && (!seqlen_kpads.empty() && seqlen_kpads[0] > 0);
const bool has_group_padding = has_group_q_padding || has_group_k_padding;
const bool has_batch_q_padding = mode == mode_enum::batch && !q_eff_lens_per_batch.empty();
const bool has_batch_k_padding = mode == mode_enum::batch && !kv_eff_lens_per_batch.empty();
const bool has_batch_padding = has_batch_q_padding || has_batch_k_padding;
const bool using_appendkv = (0 < seqlen_knew || 0 < rotary_dim);
const bool using_pagedkv = (0 < page_block_size);
const bool using_splitkv = (num_splits > 1) || use_cache_batch_idx;
if((using_appendkv || using_pagedkv || using_splitkv) &&
(has_group_padding || has_batch_efflens))
(has_group_padding || has_batch_padding))
{
std::cerr << "Padding (physical or effective lengths) is not supported with "
"appendkv/splitkv/pagedkv pipelines"
@@ -330,11 +333,12 @@ fwd_result fmha_fwd_run(mode_enum mode,
return fwd_result::invalid_args;
}
std::tie(seqlen_qs, seqlen_ks, seqlen_kpads) =
std::tie(seqlen_qs, seqlen_ks, seqlen_qpads, seqlen_kpads) =
generate_missing_seqlens(mode,
batch,
seqlen_qs,
seqlen_ks,
seqlen_qpads,
seqlen_kpads,
/*seqlen_k_min=*/0 < seqlen_knew ? seqlen_knew : 0,
need_append_kvcache,
@@ -346,7 +350,13 @@ fwd_result fmha_fwd_run(mode_enum mode,
std::cerr << "kpad must be greater than or equal to seqlen for k" << std::endl;
return fwd_result::invalid_args;
}
if(seqlen_qpads[wb] > 0 && seqlen_qpads[wb] < seqlen_qs[wb])
{
std::cerr << "qpad must be greater than or equal to seqlen for q" << std::endl;
return fwd_result::invalid_args;
}
}
// compute kvcache seqlen_k (before appending knew/vnew)
auto cache_seqlen_ks = seqlen_ks;
std::transform(cache_seqlen_ks.begin(),
@@ -357,6 +367,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
#if 0
std::cout << "seqlen_qs: " << seqlen_qs << std::endl;
std::cout << "seqlen_ks: " << seqlen_ks << std::endl;
std::cout << "seqlen_qpads: " << seqlen_qpads << std::endl;
std::cout << "seqlen_kpads: " << seqlen_kpads << std::endl;
std::cout << "cache_seqlen_ks: " << cache_seqlen_ks << std::endl;
#endif
@@ -391,23 +402,9 @@ fwd_result fmha_fwd_run(mode_enum mode,
const auto seqstart_q_host = to_seqstarts(seqlen_qs);
const auto seqstart_k_host = to_seqstarts(seqlen_ks);
const auto seqstart_q_with_padding_host = to_seqstarts(seqlen_qpads);
const auto seqstart_k_with_padding_host = to_seqstarts(seqlen_kpads);
// Optional padded Q seqstarts (group-mode only)
std::vector<int32_t> seqstart_q_with_padding_host;
if(mode == mode_enum::group && !seqlen_qpads.empty() && seqlen_qpads[0] != -1)
{
if(seqlen_qpads.size() < static_cast<size_t>(batch))
{
seqlen_qpads.resize(batch, seqlen_qpads.back());
}
if(seqlen_qpads.size() == static_cast<size_t>(batch))
{
seqstart_q_with_padding_host = to_seqstarts(
ck_tile::span<const int32_t>(seqlen_qpads.data(), seqlen_qpads.size()));
}
}
// Optional batch-mode cumulative seqlen overrides
std::vector<ck_tile::index_t> cuq_cum, cukv_cum;
if(mode == mode_enum::batch)
@@ -514,19 +511,17 @@ fwd_result fmha_fwd_run(mode_enum mode,
// host memory for storing all the tensor elements
const ck_tile::index_t shape_batch = (mode == mode_enum::batch ? batch : 1);
// logical(unpadded) total seqlen_q for group; batch uses fixed seqlen
const ck_tile::index_t shape_seqlen_q_lse =
(mode == mode_enum::batch ? seqlen_qs[0] : seqstart_q_host.back());
// physical(padded) total seqlen_q for group when s_qpad is provided; else use logical
const ck_tile::index_t shape_seqlen_q =
(mode == mode_enum::batch
? seqlen_qs[0]
: (seqstart_q_with_padding_host.empty() ? seqstart_q_host.back()
: seqstart_q_with_padding_host.back()));
(mode == mode_enum::batch ? seqlen_qs[0]
: (has_group_q_padding && !seqstart_q_with_padding_host.empty()
? seqstart_q_with_padding_host.back()
: seqstart_q_host.back()));
const ck_tile::index_t shape_seqlen_k =
(mode == mode_enum::batch ? seqlen_ks[0]
: (seqlen_kpads[0] < 0 ? seqstart_k_host.back()
: seqstart_k_with_padding_host.back()));
: (has_group_k_padding && !seqstart_k_with_padding_host.empty()
? seqstart_k_with_padding_host.back()
: seqstart_k_host.back()));
ck_tile::HostTensor<QDataType> q_host(
get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q));
@@ -580,7 +575,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
// batch mode of lse data layout is [batch, nhead, seqlen_q]
// group mode of lse data layout is [nhead, total_seqlen_q]
ck_tile::HostTensor<LSEDataType> lse_host(
lse ? std::array<ck_tile::index_t, 3>{shape_batch, nhead, shape_seqlen_q_lse}
lse ? std::array<ck_tile::index_t, 3>{shape_batch, nhead, shape_seqlen_q}
: std::array<ck_tile::index_t, 3>{1, 1, 1} /* dummy shape for simplifying code */);
ck_tile::HostTensor<ODataType> o_host(
@@ -684,14 +679,18 @@ fwd_result fmha_fwd_run(mode_enum mode,
sizeof(int32_t));
ck_tile::DeviceMem seqstart_k_padded_buf(
seqlen_kpads[0] < 0 ? 0 : seqstart_k_with_padding_host.size() * sizeof(int32_t));
// Buffers for query per-sequence logical (unpadded) lengths (used in group mode with padding
// enabled)
ck_tile::DeviceMem seqlen_q_buf(has_group_q_padding ? seqlen_qs.size() * sizeof(int32_t) : 0);
// Buffers for key/value per-sequence logical (unpadded) lengths (used in batch mode with
// kvcache or group mode with padding enabled)
ck_tile::DeviceMem seqlen_k_buf((mode == mode_enum::batch && use_kvcache) || has_group_k_padding
? seqlen_ks.size() * sizeof(int32_t)
: 0);
ck_tile::DeviceMem cu_seqlen_q_buf(cuq_cum.empty() ? 0
: cuq_cum.size() * sizeof(ck_tile::index_t));
ck_tile::DeviceMem cu_seqlen_kv_buf(
cukv_cum.empty() ? 0 : cukv_cum.size() * sizeof(ck_tile::index_t));
ck_tile::DeviceMem seqlen_k_buf((mode == mode_enum::batch && use_kvcache) ||
0 <= seqlen_kpads[0]
? seqlen_ks.size() * sizeof(int32_t)
: 0);
ck_tile::DeviceMem cache_seqlen_k_buf(
need_append_kvcache ? cache_seqlen_ks.size() * sizeof(int32_t) : 0);
ck_tile::DeviceMem rotary_cos_buf(rotary_cos_host.get_element_space_size_in_bytes());
@@ -787,7 +786,8 @@ fwd_result fmha_fwd_run(mode_enum mode,
: seqstart_k_with_padding_host.data());
cu_seqlen_q_buf.ToDevice(cuq_cum.empty() ? nullptr : cuq_cum.data());
cu_seqlen_kv_buf.ToDevice(cukv_cum.empty() ? nullptr : cukv_cum.data());
seqlen_k_buf.ToDevice((mode == mode_enum::batch && use_kvcache) || 0 <= seqlen_kpads[0]
seqlen_q_buf.ToDevice(has_group_q_padding ? seqlen_qs.data() : nullptr);
seqlen_k_buf.ToDevice((mode == mode_enum::batch && use_kvcache) || has_group_k_padding
? seqlen_ks.data()
: nullptr);
cache_seqlen_k_buf.ToDevice(need_append_kvcache ? cache_seqlen_ks.data() : nullptr);
@@ -868,7 +868,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
print_vec("k_padded", seqlen_kpads);
}
}
else if(has_batch_efflens)
else if(has_batch_padding)
{
// derive effective lengths from cumulative arrays if present
if(!cuq_cum.empty())
@@ -970,8 +970,8 @@ fwd_result fmha_fwd_run(mode_enum mode,
const ck_tile::index_t nhead_stride_bias =
(i_perm ? 0 * shape_seqlen_q * max_seqlen_k : 0 * max_seqlen_k);
const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t nhead_stride_lse = shape_seqlen_q_lse;
const ck_tile::index_t nhead_stride_lse_acc = (num_splits * shape_seqlen_q_lse);
const ck_tile::index_t nhead_stride_lse = shape_seqlen_q;
const ck_tile::index_t nhead_stride_lse_acc = (num_splits * shape_seqlen_q);
const ck_tile::index_t nhead_stride_o_acc = (num_splits * shape_seqlen_q * hdim_v);
const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
// setup batch_stride_* arguments
@@ -986,8 +986,8 @@ fwd_result fmha_fwd_run(mode_enum mode,
const ck_tile::index_t batch_stride_vnew = (nhead_k * hdim_v * seqlen_knew);
const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q_lse);
const ck_tile::index_t batch_stride_lse_acc = (nhead * num_splits * shape_seqlen_q_lse);
const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q);
const ck_tile::index_t batch_stride_lse_acc = (nhead * num_splits * shape_seqlen_q);
const ck_tile::index_t batch_stride_o_acc = (nhead * num_splits * shape_seqlen_q * hdim_v);
const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);
const ck_tile::index_t batch_stride_block_table = (max_num_page_blocks / batch);
@@ -1051,14 +1051,6 @@ fwd_result fmha_fwd_run(mode_enum mode,
args.lse_ptr = lse_buf.GetDeviceBuffer();
args.o_ptr = o_buf.GetDeviceBuffer();
args.seqstart_q_ptr =
(mode == mode_enum::group ? seqstart_q.GetDeviceBuffer() : nullptr);
args.seqstart_k_ptr =
(mode == mode_enum::group ? seqstart_k.GetDeviceBuffer() : nullptr);
args.seqlen_k_ptr = ((mode == mode_enum::batch && use_kvcache) || 0 <= k_paddings_[0]
? seqlen_k_buf.GetDeviceBuffer()
: nullptr);
args.seqlen_k = shape_seqlen_k; // unused in group mode (or kvcache enabled)
args.max_seqlen_q = max_seqlen_q;
@@ -1102,27 +1094,54 @@ fwd_result fmha_fwd_run(mode_enum mode,
args.drop_seed_offset = std::make_pair(drop_seed, drop_offset);
}
// Group-mode: optional physical padded starts for Q/K
// Sequence length and padding parameters (mode-specific)
if(mode == mode_enum::group)
{
args.seqstart_padded_q_ptr = (seqstart_q_with_padding_host.empty()
? nullptr
: seqstart_q_padded_buf.GetDeviceBuffer());
args.seqstart_padded_k_ptr =
(seqlen_kpads[0] < 0 ? nullptr : seqstart_k_padded_buf.GetDeviceBuffer());
}
// Group mode: use physical (padded) cumulative starts + logical per-sequence
// lengths
// Batch-mode: optional cumulative effective seqlen overrides
if(mode == mode_enum::batch)
// Physical cumulative starts (including padding)
args.seqstart_q_ptr =
has_group_q_padding && !seqstart_q_with_padding_host.empty()
? seqstart_q_padded_buf.GetDeviceBuffer()
: seqstart_q.GetDeviceBuffer();
args.seqstart_k_ptr =
has_group_k_padding && !seqstart_k_with_padding_host.empty()
? seqstart_k_padded_buf.GetDeviceBuffer()
: seqstart_k.GetDeviceBuffer();
// Logical (unpadded) per-sequence lengths, used when padding is enabled
args.seqlen_q_ptr =
(has_group_q_padding && !seqstart_q_with_padding_host.empty())
? seqlen_q_buf.GetDeviceBuffer()
: nullptr;
args.seqlen_k_ptr =
(has_group_k_padding && !seqstart_k_with_padding_host.empty())
? seqlen_k_buf.GetDeviceBuffer()
: nullptr;
// Cumulative lengths not used in group mode
args.cu_seqlen_q_ptr = nullptr;
args.cu_seqlen_k_ptr = nullptr;
}
else // mode == mode_enum::batch
{
args.cu_seqlen_q_ptr = cuq_cum.empty()
? nullptr
: reinterpret_cast<const ck_tile::index_t*>(
cu_seqlen_q_buf.GetDeviceBuffer());
args.cu_seqlen_kv_ptr = cukv_cum.empty()
? nullptr
: reinterpret_cast<const ck_tile::index_t*>(
cu_seqlen_kv_buf.GetDeviceBuffer());
// Batch mode: use cumulative logical lengths for tail padding
// seqstart pointers not used in batch mode
args.seqstart_q_ptr = nullptr;
args.seqstart_k_ptr = nullptr;
// seqlen_q_ptr/seqlen_k_ptr not used in batch mode
args.seqlen_q_ptr = nullptr;
args.seqlen_k_ptr = nullptr;
// Cumulative logical lengths for effective length handling
args.cu_seqlen_q_ptr = has_batch_q_padding && !cuq_cum.empty()
? cu_seqlen_q_buf.GetDeviceBuffer()
: nullptr;
args.cu_seqlen_k_ptr = has_batch_k_padding && !cukv_cum.empty()
? cu_seqlen_kv_buf.GetDeviceBuffer()
: nullptr;
}
}
else if constexpr(std::is_same_v<fmha_fwd_splitkv_args, std::decay_t<decltype(args)>>)
@@ -1148,6 +1167,15 @@ fwd_result fmha_fwd_run(mode_enum mode,
args.batch_stride_o_acc = batch_stride_o_acc;
args.split_stride_lse_acc = split_stride_lse_acc;
args.split_stride_o_acc = split_stride_o_acc;
args.seqstart_q_ptr =
(mode == mode_enum::group ? seqstart_q.GetDeviceBuffer() : nullptr);
args.seqstart_k_ptr =
(mode == mode_enum::group ? seqstart_k.GetDeviceBuffer() : nullptr);
args.seqlen_k_ptr =
((mode == mode_enum::batch && use_kvcache) || 0 <= k_paddings_[0]
? seqlen_k_buf.GetDeviceBuffer()
: nullptr);
}
else if constexpr(std::is_same_v<fmha_fwd_pagedkv_args, std::decay_t<decltype(args)>>)
{
@@ -1159,6 +1187,15 @@ fwd_result fmha_fwd_run(mode_enum mode,
args.cache_batch_idx =
(use_cache_batch_idx ? cache_batch_idx_buf.GetDeviceBuffer() : nullptr);
args.seqstart_q_ptr =
(mode == mode_enum::group ? seqstart_q.GetDeviceBuffer() : nullptr);
args.seqstart_k_ptr =
(mode == mode_enum::group ? seqstart_k.GetDeviceBuffer() : nullptr);
args.seqlen_k_ptr =
((mode == mode_enum::batch && use_kvcache) || 0 <= k_paddings_[0]
? seqlen_k_buf.GetDeviceBuffer()
: nullptr);
}
}
};
@@ -1360,16 +1397,19 @@ fwd_result fmha_fwd_run(mode_enum mode,
const ck_tile::index_t b_idx = (mode == mode_enum::batch ? wb : 0);
const ck_tile::index_t cache_b_idx =
(use_cache_batch_idx ? cache_batch_idx_host(b_idx) : b_idx);
// Use physical offset if padding info is valid (not -1) and buffers are available
const ck_tile::index_t query_offset =
(mode == mode_enum::batch
? 0
: (seqstart_q_with_padding_host.empty() ? seqstart_q_host[wb]
: seqstart_q_with_padding_host[wb]));
: ((seqstart_q_with_padding_host.empty() || seqlen_qpads[0] < 0)
? seqstart_q_host[wb]
: seqstart_q_with_padding_host[wb]));
const ck_tile::index_t key_offset =
(mode == mode_enum::batch
? 0
: (seqlen_kpads[0] < 0 ? seqstart_k_host[wb]
: seqstart_k_with_padding_host[wb]));
: ((seqstart_k_with_padding_host.empty() || seqlen_kpads[0] < 0)
? seqstart_k_host[wb]
: seqstart_k_with_padding_host[wb]));
ck_tile::HostTensor<QDataType> q_host_ref({nhead, real_seqlen_q, hdim_q});
ck_tile::HostTensor<KDataType> k_host_ref({nhead, real_seqlen_k, hdim_q});
@@ -1718,8 +1758,14 @@ fwd_result fmha_fwd_run(mode_enum mode,
std::cerr << "OUT mismatch found at batch: " << wb << std::endl
<< "\tseqlen_q: " << real_seqlen_q << std::endl
<< "\tseqlen_k: " << real_seqlen_k << std::endl
<< "\tseqstart_q: " << seqstart_q_host << std::endl
<< "\tseqstart_k: " << seqstart_k_host << std::endl;
<< "\tseqstart_q (logical): " << seqstart_q_host << std::endl
<< "\tseqstart_q (physical): " << seqstart_q_with_padding_host
<< std::endl
<< "\tseqstart_k (logical): " << seqstart_k_host << std::endl
<< "\tseqstart_k (physical): " << seqstart_k_with_padding_host
<< std::endl
<< "\tquery_offset used: " << query_offset << std::endl
<< "\tkey_offset used: " << key_offset << std::endl;
break;
}
@@ -1727,10 +1773,8 @@ fwd_result fmha_fwd_run(mode_enum mode,
if(lse)
{
ck_tile::HostTensor<SMPLComputeDataType> lse_host_result({nhead, real_seqlen_q});
const ck_tile::index_t query_offset_lse =
(mode == mode_enum::batch ? 0 : seqstart_q_host[wb]);
lse_host_result.ForEach([&](auto& self, auto idx) {
self(idx) = lse_host(b_idx, idx[0], idx[1] + query_offset_lse);
self(idx) = lse_host(b_idx, idx[0], idx[1] + query_offset);
});
cur_pass = ck_tile::check_err(lse_host_result,

View File

@@ -142,12 +142,14 @@ auto randints(ForwardIterator first,
*/
template <typename RandomEngine>
std::tuple<std::vector<ck_tile::index_t>,
std::vector<ck_tile::index_t>,
std::vector<ck_tile::index_t>,
std::vector<ck_tile::index_t>>
generate_missing_seqlens(mode_enum mode,
ck_tile::index_t batch,
const std::vector<ck_tile::index_t>& q_val,
const std::vector<ck_tile::index_t>& k_val,
const std::vector<ck_tile::index_t>& q_pad_val,
const std::vector<ck_tile::index_t>& k_pad_val,
ck_tile::index_t seqlen_k_min,
bool need_append_kvcache,
@@ -177,7 +179,7 @@ generate_missing_seqlens(mode_enum mode,
return seqlen_ks;
}();
auto s_kpad = std::vector<ck_tile::index_t>(batch, -1); // TODO: batch not support k_padding
auto s_qpad = std::vector<ck_tile::index_t>(batch, -1);
// s_k should be greater than or equal to seqlen_k_min if provided
if(s_k.back() < seqlen_k_min)
{
@@ -187,13 +189,14 @@ generate_missing_seqlens(mode_enum mode,
throw std::runtime_error(msg.str());
}
return std::make_tuple(s_q, s_k, s_kpad);
return std::make_tuple(s_q, s_k, s_qpad, s_kpad);
}
else
{
std::vector<ck_tile::index_t> s_q;
std::vector<ck_tile::index_t> s_k;
std::vector<ck_tile::index_t> s_kpad;
std::vector<ck_tile::index_t> s_qpad;
ck_tile::index_t idx = 0;
for(; idx < std::min(static_cast<ck_tile::index_t>(q_val.size()), batch); ++idx)
{
@@ -205,9 +208,15 @@ generate_missing_seqlens(mode_enum mode,
? -1
: k_pad_val[std::min(idx, static_cast<ck_tile::index_t>(k_pad_val.size()) - 1)];
ck_tile::index_t qp =
q_pad_val.empty()
? -1
: q_pad_val[std::min(idx, static_cast<ck_tile::index_t>(q_pad_val.size()) - 1)];
s_q.push_back(q);
s_k.push_back(k < 0 ? q : k);
s_kpad.push_back(kp);
s_qpad.push_back(qp);
// s_k should be greater than or equal to seqlen_k_min
if(s_k.back() < seqlen_k_min)
@@ -228,8 +237,9 @@ generate_missing_seqlens(mode_enum mode,
s_q.insert(s_q.end(), rem_q.begin(), rem_q.end());
s_k.insert(s_k.end(), rem_k.begin(), rem_k.end());
s_kpad.insert(s_kpad.end(), batch - idx, s_kpad.back());
s_qpad.insert(s_qpad.end(), batch - idx, s_qpad.back());
}
return std::make_tuple(s_q, s_k, s_kpad);
return std::make_tuple(s_q, s_k, s_qpad, s_kpad);
}
}

View File

@@ -313,7 +313,10 @@ struct FmhaBwdDQDKDVKernel
{
const int32_t* seqstart_q_ptr;
const int32_t* seqstart_k_ptr;
const int32_t* seqlen_k_ptr;
const int32_t* seqlen_q_ptr; // per-batch actual length [batch]
const int32_t* seqlen_k_ptr; // per-batch actual length [batch]
const int32_t* cu_seqlen_q_ptr; // cumulative seqlen [batch+1], optional
const int32_t* cu_seqlen_k_ptr; // cumulative seqlen [batch+1], optional
};
using Kargs = std::conditional_t<kIsGroupMode, FmhaBwdGroupModeKargs, FmhaBwdBatchModeKargs>;
@@ -520,7 +523,10 @@ struct FmhaBwdDQDKDVKernel
void* dq_acc_ptr,
const void* seqstart_q_ptr,
const void* seqstart_k_ptr,
const void* seqlen_q_ptr,
const void* seqlen_k_ptr,
const void* cu_seqlen_q_ptr,
const void* cu_seqlen_k_ptr,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
@@ -594,7 +600,10 @@ struct FmhaBwdDQDKDVKernel
{}, // placeholder for deterministic
reinterpret_cast<const int32_t*>(seqstart_q_ptr),
reinterpret_cast<const int32_t*>(seqstart_k_ptr),
reinterpret_cast<const int32_t*>(seqlen_k_ptr)};
reinterpret_cast<const int32_t*>(seqlen_q_ptr),
reinterpret_cast<const int32_t*>(seqlen_k_ptr),
reinterpret_cast<const int32_t*>(cu_seqlen_q_ptr),
reinterpret_cast<const int32_t*>(cu_seqlen_k_ptr)};
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
@@ -736,10 +745,29 @@ struct FmhaBwdDQDKDVKernel
batch_offset_randval = query_start * kargs.stride_randval;
}
// get real # queries & # keys under group mode
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
if(kargs.seqlen_k_ptr != nullptr)
// Priority: cu_seqlen_q_ptr > seqlen_q_ptr > physical_seqlen_q
if(kargs.cu_seqlen_q_ptr != nullptr)
{
kargs.seqlen_q =
kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
}
else
{
// get real # queries & # keys under group mode
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
const ck_tile::index_t physical_seqlen_q =
adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
kargs.seqlen_q =
kargs.seqlen_q_ptr ? kargs.seqlen_q_ptr[i_batch] : physical_seqlen_q;
}
// Priority: cu_seqlen_k_ptr > seqlen_k_ptr > seqstart_k
if(kargs.cu_seqlen_k_ptr != nullptr)
{
kargs.seqlen_k =
kargs.cu_seqlen_k_ptr[i_batch + 1] - kargs.cu_seqlen_k_ptr[i_batch];
}
else if(kargs.seqlen_k_ptr != nullptr)
{
kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
}
@@ -749,6 +777,12 @@ struct FmhaBwdDQDKDVKernel
kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0];
}
// skip if logical lengths are zero
if(kargs.seqlen_q == 0 || kargs.seqlen_k == 0)
{
return;
}
// # of required blocks is different in each groups, terminate unnecessary blocks
// earlier
if constexpr(!kUseQrQtrDorPipeline)
@@ -1246,6 +1280,8 @@ struct FmhaBwdOGradDotOKernel
struct FmhaBwdOGradDotOGroupModeKargs : FmhaBwdOGradDotOCommonKargs
{
const int32_t* seqstart_q_ptr;
const int32_t* seqlen_q_ptr; // per-batch actual length [batch]
const int32_t* cu_seqlen_q_ptr; // cumulative seqlen [batch+1], optional
};
using Kargs = std::
@@ -1293,6 +1329,8 @@ struct FmhaBwdOGradDotOKernel
void* d_ptr,
float p_undrop,
const void* seqstart_q_ptr,
const void* seqlen_q_ptr,
const void* cu_seqlen_q_ptr,
ck_tile::index_t hdim_v,
ck_tile::index_t stride_do,
ck_tile::index_t stride_o,
@@ -1311,7 +1349,9 @@ struct FmhaBwdOGradDotOKernel
nhead_stride_do,
nhead_stride_o,
nhead_stride_d},
reinterpret_cast<const int32_t*>(seqstart_q_ptr)};
reinterpret_cast<const int32_t*>(seqstart_q_ptr),
reinterpret_cast<const int32_t*>(seqlen_q_ptr),
reinterpret_cast<const int32_t*>(cu_seqlen_q_ptr)};
return kargs;
}
@@ -1355,9 +1395,23 @@ struct FmhaBwdOGradDotOKernel
batch_offset_do = query_start * kargs.stride_do;
batch_offset_d = query_start;
// get real # queries & # keys under group mode
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
// Priority: cu_seqlen_q_ptr > seqlen_q_ptr > physical_seqlen_q
if(kargs.cu_seqlen_q_ptr != nullptr)
{
kargs.seqlen_q =
kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
}
else
{
// get real # queries & # keys under group mode
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
const ck_tile::index_t physical_seqlen_q =
adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
kargs.seqlen_q = kargs.seqlen_q_ptr
? static_cast<ck_tile::index_t>(kargs.seqlen_q_ptr[i_batch])
: physical_seqlen_q;
}
// # of required blocks is different in each groups, terminate unnecessary blocks
// earlier
if(kargs.seqlen_q <= i_m0)
@@ -1521,6 +1575,10 @@ struct FmhaBwdConvertQGradKernel
{
const int32_t* seqstart_q_ptr;
const int32_t* seqstart_k_ptr;
const int32_t* seqlen_q_ptr; // per-batch actual length [batch]
const int32_t* seqlen_k_ptr; // per-batch actual length [batch]
const int32_t* cu_seqlen_q_ptr; // cumulative seqlen [batch+1], optional
const int32_t* cu_seqlen_k_ptr; // cumulative seqlen [batch+1], optional
};
using Kargs = std::conditional_t<kIsGroupMode,
@@ -1569,6 +1627,10 @@ struct FmhaBwdConvertQGradKernel
void* dq_ptr,
const void* seqstart_q_ptr,
const void* seqstart_k_ptr,
const void* seqlen_q_ptr,
const void* seqlen_k_ptr,
const void* cu_seqlen_q_ptr,
const void* cu_seqlen_k_ptr,
ck_tile::index_t hdim_q,
ck_tile::index_t stride_dq,
ck_tile::index_t stride_dq_acc,
@@ -1587,7 +1649,11 @@ struct FmhaBwdConvertQGradKernel
nhead_stride_dq_acc},
{},
reinterpret_cast<const int32_t*>(seqstart_q_ptr),
reinterpret_cast<const int32_t*>(seqstart_k_ptr)};
reinterpret_cast<const int32_t*>(seqstart_k_ptr),
reinterpret_cast<const int32_t*>(seqlen_q_ptr),
reinterpret_cast<const int32_t*>(seqlen_k_ptr),
reinterpret_cast<const int32_t*>(cu_seqlen_q_ptr),
reinterpret_cast<const int32_t*>(cu_seqlen_k_ptr)};
if constexpr(kIsDeterministic)
{
@@ -1632,13 +1698,41 @@ struct FmhaBwdConvertQGradKernel
batch_offset_dq = query_start * kargs.stride_dq;
batch_offset_dq_acc = query_start * kargs.stride_dq_acc;
// get real # queries & # keys under group mode
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
if(kargs.cu_seqlen_q_ptr != nullptr)
{
kargs.seqlen_q =
kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
}
else
{
// get real # queries & # keys under group mode
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
const ck_tile::index_t physical_seqlen_q =
adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
kargs.seqlen_q = kargs.seqlen_q_ptr
? static_cast<ck_tile::index_t>(kargs.seqlen_q_ptr[i_batch])
: physical_seqlen_q;
}
if constexpr(kIsDeterministic)
{
const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch;
kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0];
const ck_tile::index_t physical_seqlen_k =
adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0];
// Priority: cu_seqlen_k_ptr > seqlen_k_ptr > physical_seqlen_k
if(kargs.cu_seqlen_k_ptr != nullptr)
{
kargs.seqlen_k =
kargs.cu_seqlen_k_ptr[i_batch + 1] - kargs.cu_seqlen_k_ptr[i_batch];
}
else
{
kargs.seqlen_k =
kargs.seqlen_k_ptr
? static_cast<ck_tile::index_t>(kargs.seqlen_k_ptr[i_batch])
: physical_seqlen_k;
}
}
// # of required blocks is different in each groups, terminate unnecessary blocks
// earlier

View File

@@ -296,8 +296,8 @@ struct FmhaFwdKernel
// Optional cumulative sequence length pointers for batch mode
// If provided, they override seqlen_q / seqlen_k per-batch to skip tail padding.
const ck_tile::index_t* cu_seqlen_q_ptr = nullptr; // cumulative, length without PAD
const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr; // cumulative, length without PAD
const int32_t* cu_seqlen_q_ptr = nullptr; // cumulative, length without PAD
const int32_t* cu_seqlen_k_ptr = nullptr; // cumulative, length without PAD
};
struct FmhaFwdGroupModeKargs
@@ -316,12 +316,12 @@ struct FmhaFwdKernel
{
const int32_t* seqstart_q_ptr;
const int32_t* seqstart_k_ptr;
const int32_t* seqlen_q_ptr;
const int32_t* seqlen_k_ptr;
// Optional cumulative padded sequence starts (including PAD tokens)
// Used solely to compute memory offsets when sequences are physically padded.
const int32_t* seqstart_padded_q_ptr = nullptr;
const int32_t* seqstart_padded_k_ptr = nullptr;
// Optional per-sequence and cumulative logical (excluding padding) sequence length arrays
const int32_t* cu_seqlen_q_ptr = nullptr;
const int32_t* cu_seqlen_k_ptr = nullptr;
};
using Kargs = std::conditional_t<kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs>;
@@ -379,8 +379,8 @@ struct FmhaFwdKernel
bool s_randval,
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset,
const ck_tile::index_t* cu_seqlen_q_ptr = nullptr,
const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr)
const void* cu_seqlen_q_ptr = nullptr,
const void* cu_seqlen_k_ptr = nullptr)
{
Kargs kargs{{q_ptr,
k_ptr,
@@ -471,8 +471,8 @@ struct FmhaFwdKernel
kargs.init_logits_soft_cap(logits_soft_cap);
}
kargs.cu_seqlen_q_ptr = cu_seqlen_q_ptr;
kargs.cu_seqlen_kv_ptr = cu_seqlen_kv_ptr;
kargs.cu_seqlen_q_ptr = reinterpret_cast<const int32_t*>(cu_seqlen_q_ptr);
kargs.cu_seqlen_k_ptr = reinterpret_cast<const int32_t*>(cu_seqlen_k_ptr);
return kargs;
}
@@ -522,8 +522,8 @@ struct FmhaFwdKernel
float p_drop,
bool s_randval,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset,
const ck_tile::index_t* cu_seqlen_q_ptr = nullptr,
const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr)
const void* cu_seqlen_q_ptr = nullptr,
const void* cu_seqlen_k_ptr = nullptr)
{
return MakeKargsImpl(
q_ptr,
@@ -570,7 +570,7 @@ struct FmhaFwdKernel
s_randval,
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
cu_seqlen_q_ptr,
cu_seqlen_kv_ptr);
cu_seqlen_k_ptr);
}
// std::variant<> can't take in a list initializer, overload for backward compatibility
@@ -619,8 +619,8 @@ struct FmhaFwdKernel
float p_drop,
bool s_randval,
const std::tuple<const void*, const void*>& drop_seed_offset,
const ck_tile::index_t* cu_seqlen_q_ptr = nullptr,
const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr)
const void* cu_seqlen_q_ptr = nullptr,
const void* cu_seqlen_k_ptr = nullptr)
{
return MakeKargsImpl(
q_ptr,
@@ -667,7 +667,7 @@ struct FmhaFwdKernel
s_randval,
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
cu_seqlen_q_ptr,
cu_seqlen_kv_ptr);
cu_seqlen_k_ptr);
}
template <bool Cond = kIsGroupMode>
@@ -681,6 +681,7 @@ struct FmhaFwdKernel
void* o_ptr,
const void* seqstart_q_ptr,
const void* seqstart_k_ptr,
const void* seqlen_q_ptr,
const void* seqlen_k_ptr,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
@@ -711,8 +712,8 @@ struct FmhaFwdKernel
bool s_randval,
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset,
const void* seqstart_padded_q_ptr = nullptr,
const void* seqstart_padded_k_ptr = nullptr)
const void* cu_seqlen_q_ptr = nullptr,
const void* cu_seqlen_k_ptr = nullptr)
{
Kargs kargs{{q_ptr,
k_ptr,
@@ -746,6 +747,7 @@ struct FmhaFwdKernel
{}, // placeholder for min_seqlen_q
reinterpret_cast<const int32_t*>(seqstart_q_ptr),
reinterpret_cast<const int32_t*>(seqstart_k_ptr),
reinterpret_cast<const int32_t*>(seqlen_q_ptr),
reinterpret_cast<const int32_t*>(seqlen_k_ptr)};
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
@@ -804,8 +806,8 @@ struct FmhaFwdKernel
kargs.min_seqlen_q = min_seqlen_q;
}
kargs.seqstart_padded_q_ptr = reinterpret_cast<const int32_t*>(seqstart_padded_q_ptr);
kargs.seqstart_padded_k_ptr = reinterpret_cast<const int32_t*>(seqstart_padded_k_ptr);
kargs.cu_seqlen_q_ptr = reinterpret_cast<const int32_t*>(cu_seqlen_q_ptr);
kargs.cu_seqlen_k_ptr = reinterpret_cast<const int32_t*>(cu_seqlen_k_ptr);
return kargs;
}
@@ -821,6 +823,7 @@ struct FmhaFwdKernel
void* o_ptr,
const void* seqstart_q_ptr,
const void* seqstart_k_ptr,
const void* seqlen_q_ptr,
const void* seqlen_k_ptr,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
@@ -850,8 +853,8 @@ struct FmhaFwdKernel
float p_drop,
bool s_randval,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset,
const void* seqstart_padded_q_ptr = nullptr,
const void* seqstart_padded_k_ptr = nullptr)
const void* cu_seqlen_q_ptr = nullptr,
const void* cu_seqlen_k_ptr = nullptr)
{
return MakeKargsImpl(
q_ptr,
@@ -863,6 +866,7 @@ struct FmhaFwdKernel
o_ptr,
seqstart_q_ptr,
seqstart_k_ptr,
seqlen_q_ptr,
seqlen_k_ptr,
hdim_q,
hdim_v,
@@ -892,8 +896,8 @@ struct FmhaFwdKernel
p_drop,
s_randval,
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
seqstart_padded_q_ptr,
seqstart_padded_k_ptr);
cu_seqlen_q_ptr,
cu_seqlen_k_ptr);
}
// std::variant<> can't take in a list initializer, overload for backward compatibility
@@ -908,6 +912,7 @@ struct FmhaFwdKernel
void* o_ptr,
const void* seqstart_q_ptr,
const void* seqstart_k_ptr,
const void* seqlen_q_ptr,
const void* seqlen_k_ptr,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
@@ -937,8 +942,8 @@ struct FmhaFwdKernel
float p_drop,
bool s_randval,
const std::tuple<const void*, const void*>& drop_seed_offset,
const void* seqstart_padded_q_ptr = nullptr,
const void* seqstart_padded_k_ptr = nullptr)
const void* cu_seqlen_q_ptr = nullptr,
const void* cu_seqlen_k_ptr = nullptr)
{
return MakeKargsImpl(
q_ptr,
@@ -950,6 +955,7 @@ struct FmhaFwdKernel
o_ptr,
seqstart_q_ptr,
seqstart_k_ptr,
seqlen_q_ptr,
seqlen_k_ptr,
hdim_q,
hdim_v,
@@ -979,8 +985,8 @@ struct FmhaFwdKernel
p_drop,
s_randval,
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
seqstart_padded_q_ptr,
seqstart_padded_k_ptr);
cu_seqlen_q_ptr,
cu_seqlen_k_ptr);
}
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
@@ -1109,46 +1115,52 @@ struct FmhaFwdKernel
if constexpr(kIsGroupMode)
{
// logical and physical (padded) starts
const long_index_t query_start_unpadded = kargs.seqstart_q_ptr[i_batch];
const long_index_t key_start_unpadded = kargs.seqstart_k_ptr[i_batch];
// Use seqstart_q_ptr and seqstart_k_ptr for physical starts
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
const long_index_t query_start_padded = kargs.seqstart_padded_q_ptr
? kargs.seqstart_padded_q_ptr[i_batch]
: query_start_unpadded;
const long_index_t key_start_padded = kargs.seqstart_padded_k_ptr
? kargs.seqstart_padded_k_ptr[i_batch]
: key_start_unpadded;
// DRAM base offsets use physical padded starts
batch_offset_q = query_start_padded * kargs.stride_q;
batch_offset_k = key_start_padded * kargs.stride_k;
// DRAM base offsets use physical starts
batch_offset_q = query_start * kargs.stride_q;
batch_offset_k = key_start * kargs.stride_k;
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
batch_offset_v = key_start_padded * kargs.stride_v;
batch_offset_v = key_start * kargs.stride_v;
}
else
{
batch_offset_v = key_start_padded;
batch_offset_v = key_start;
}
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
batch_offset_bias = query_start_padded * kargs.stride_bias;
batch_offset_bias = query_start * kargs.stride_bias;
}
if constexpr(kStoreLSE)
{
// LSE stays indexed by unpadded starts
batch_offset_lse = query_start_unpadded;
// LSE follows the physical layout to stay consistent with other tensors
batch_offset_lse = query_start;
}
if constexpr(kHasDropout)
{
batch_offset_randval = query_start_padded * kargs.stride_randval;
batch_offset_randval = query_start * kargs.stride_randval;
}
batch_offset_o = query_start_padded * kargs.stride_o;
batch_offset_o = query_start * kargs.stride_o;
// real logical lengths (exclude PAD)
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
// Priority: seqlen_q_ptr > cu_seqlen_q_ptr > calculated from seqstart_q_ptr
if(kargs.seqlen_q_ptr != nullptr)
{
kargs.seqlen_q = kargs.seqlen_q_ptr[i_batch];
}
else if(kargs.cu_seqlen_q_ptr != nullptr)
{
kargs.seqlen_q =
kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
}
else
{
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
}
if constexpr(kSkipMinSeqlenQ)
{
@@ -1168,6 +1180,11 @@ struct FmhaFwdKernel
{
kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
}
else if(kargs.cu_seqlen_k_ptr != nullptr)
{
kargs.seqlen_k =
kargs.cu_seqlen_k_ptr[i_batch + 1] - kargs.cu_seqlen_k_ptr[i_batch];
}
else
{
const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch;
@@ -1201,10 +1218,10 @@ struct FmhaFwdKernel
kargs.seqlen_q =
kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
}
if(kargs.cu_seqlen_kv_ptr != nullptr)
if(kargs.cu_seqlen_k_ptr != nullptr)
{
kargs.seqlen_k =
kargs.cu_seqlen_kv_ptr[i_batch + 1] - kargs.cu_seqlen_kv_ptr[i_batch];
kargs.cu_seqlen_k_ptr[i_batch + 1] - kargs.cu_seqlen_k_ptr[i_batch];
}
}
@@ -1603,39 +1620,46 @@ struct FmhaFwdKernel
if constexpr(kIsGroupMode)
{
// get starting offset for each batch
const long_index_t query_start_unpadded = kargs.seqstart_q_ptr[i_batch];
const long_index_t key_start_unpadded = kargs.seqstart_k_ptr[i_batch];
// get starting offset for each batch - use seqstart_q_ptr/seqstart_k_ptr for
// physical starts
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
const long_index_t query_start_padded = kargs.seqstart_padded_q_ptr
? kargs.seqstart_padded_q_ptr[i_batch]
: query_start_unpadded;
const long_index_t key_start_padded = kargs.seqstart_padded_k_ptr
? kargs.seqstart_padded_k_ptr[i_batch]
: key_start_unpadded;
batch_offset_q = query_start_padded * kargs.stride_q;
batch_offset_k = key_start_padded * kargs.stride_k;
batch_offset_q = query_start * kargs.stride_q;
batch_offset_k = key_start * kargs.stride_k;
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
batch_offset_v = key_start_padded * kargs.stride_v;
batch_offset_v = key_start * kargs.stride_v;
}
else
{
// col-major V: offset along seqlen dimension is scalar index
batch_offset_v = key_start_padded;
batch_offset_v = key_start;
}
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
batch_offset_bias = query_start_padded * kargs.stride_bias;
batch_offset_bias = query_start * kargs.stride_bias;
}
// LSE layout is [nhead, total_seqlen], index by unpadded start
batch_offset_lse = query_start_unpadded;
batch_offset_o = query_start_padded * kargs.stride_o;
// LSE layout is [nhead, total_seqlen] following the physical layout for Q/O
batch_offset_lse = query_start;
batch_offset_o = query_start * kargs.stride_o;
// get real # queries & # keys under group mode
kargs.seqlen_q = kargs.seqstart_q_ptr[i_batch + 1] - kargs.seqstart_q_ptr[i_batch];
if(kargs.seqlen_q_ptr != nullptr)
{
kargs.seqlen_q = kargs.seqlen_q_ptr[i_batch];
}
else if(kargs.cu_seqlen_q_ptr != nullptr)
{
kargs.seqlen_q =
kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
}
else
{
kargs.seqlen_q =
kargs.seqstart_q_ptr[i_batch + 1] - kargs.seqstart_q_ptr[i_batch];
}
// # of required blocks is different in each groups, terminate unnecessary blocks
// earlier
@@ -1648,6 +1672,11 @@ struct FmhaFwdKernel
{
kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
}
else if(kargs.cu_seqlen_k_ptr != nullptr)
{
kargs.seqlen_k =
kargs.cu_seqlen_k_ptr[i_batch + 1] - kargs.cu_seqlen_k_ptr[i_batch];
}
else
{
kargs.seqlen_k =
@@ -1677,10 +1706,10 @@ struct FmhaFwdKernel
kargs.seqlen_q =
kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
}
if(kargs.cu_seqlen_kv_ptr != nullptr)
if(kargs.cu_seqlen_k_ptr != nullptr)
{
kargs.seqlen_k =
kargs.cu_seqlen_kv_ptr[i_batch + 1] - kargs.cu_seqlen_kv_ptr[i_batch];
kargs.cu_seqlen_k_ptr[i_batch + 1] - kargs.cu_seqlen_k_ptr[i_batch];
}
}

View File

@@ -5,6 +5,7 @@ endif()
set(FMHA_BWD_INSTANCES "tile_fmha_bwd_instances")
set(FMHA_FWD_INSTANCES "tile_fmha_fwd_instances")
set(TEST_NAME "test_ck_tile_fmha")
function(add_gtest_fwd test_group)

View File

@@ -77,6 +77,8 @@ void fmha_bwd_test(const FmhaBwdTestParam& param)
nhead_k,
{seqlen_q},
{seqlen_k},
{-1},
{-1},
hdim_q,
hdim_v,
i_perm,
@@ -246,3 +248,741 @@ INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaBwd,
Values(true) // deterministic
));
TEST_P(Deterministic, DataTypeConfig) { fmha_bwd_test(GetParam()); }
// ============================================================================
// Q/KV Padding Tests - High Priority
// ============================================================================
// 1. BasicQPadding: Test Q padding only (K/V have no padding)
// The physical (padded) Q length is computed in the TEST_P body by rounding
// seqlen_q up to a multiple of 64 (128 once seqlen_q exceeds 256).
class BasicQPadding : public TestWithParam<FmhaBwdTestParam>
{
};
INSTANTIATE_TEST_SUITE_P(
TestCkTileFmhaBwd,
BasicQPadding,
Combine(Values(mode_enum::group), // Only group mode supports padding
HDimValues,
Values(std::tuple{true, true}), // perm
Values("n"), // no bias for basic test
Values(false), // use_dbias
Values(0.0f), // no dropout
Values(std::tuple{0, 0, false}), // seed/offset/prefs
ValuesIn([]() {
// Define test cases with Q padding: seqlen_q < seqlen_qpad
// Format: {batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str}
// Note: Will set seqlen_qpad separately in the test
std::vector<FmhaBwdDimsMaskParam> test_cases;
// Small padding: logical length close to physical
test_cases.push_back(std::tuple{2, 2, 2, 127, 128, "0"}); // Q: 127->128
test_cases.push_back(std::tuple{3, 4, 2, 250, 256, "0"}); // Q: 250->256
// Moderate padding up to the next alignment boundary
test_cases.push_back(std::tuple{2, 2, 1, 180, 256, "0"}); // Q: 180->192
test_cases.push_back(std::tuple{3, 3, 3, 350, 512, "1"}); // Q: 350->384, causal (128-aligned)
// Boundary cases: already aligned, and a larger rounding step
test_cases.push_back(std::tuple{2, 4, 2, 128, 256, "0"}); // Q: 128->128 (already 64-aligned, no padding)
test_cases.push_back(std::tuple{2, 2, 2, 200, 512, "2"}); // Q: 200->256, causal
return test_cases;
}()),
Values(false) // deterministic
));
TEST_P(BasicQPadding, DataTypeConfig)
{
    // Unpack the parameter tuple produced by the Combine() generator above.
    auto [mode, hdims, perm, bias_str, use_dbias, p_drop, drop_misc, dims_mask, det] = GetParam();
    auto [hdim_q, hdim_v]                                      = hdims;
    auto [i_perm, o_perm]                                      = perm;
    auto [drop_seed, drop_offset, drop_prefs]                  = drop_misc;
    auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask;

    // Physical (padded) Q length: round the logical length up to a
    // 64-element boundary, or a 128-element boundary once seqlen_q > 256.
    auto round_up = [](ck_tile::index_t len, ck_tile::index_t align) {
        return ((len + align - 1) / align) * align;
    };
    const ck_tile::index_t q_alignment = (seqlen_q > 256 ? 128 : 64);
    const ck_tile::index_t seqlen_qpad = round_up(seqlen_q, q_alignment);

    // Per-batch sequence lengths; every batch uses the same configuration.
    const std::vector<ck_tile::index_t> seqlen_qs(batch, seqlen_q);
    const std::vector<ck_tile::index_t> seqlen_ks(batch, seqlen_k);
    const std::vector<ck_tile::index_t> seqlen_qpads(batch, seqlen_qpad);
    const std::vector<ck_tile::index_t> seqlen_kpads(batch, seqlen_k); // K/V stay unpadded

    const auto result = fmha_bwd_run<DataTypeConfig>(
        mode,
        batch,
        nhead,
        nhead_k,
        seqlen_qs,
        seqlen_ks,
        seqlen_qpads,
        seqlen_kpads,
        hdim_q,
        hdim_v,
        i_perm,
        o_perm,
        0, // scale
        bias_str,
        use_dbias,
        p_drop,
        drop_seed,
        drop_offset,
        drop_prefs,
        mask_str,
        det,
        init_method,
        static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
        1,
        stream_config);

    if(result == bwd_result::no_instance)
        GTEST_SKIP() << "No instance for Q padding with hdim_q=" << hdim_q;
    ASSERT_EQ(result, bwd_result::success);
}
// 2. BasicKVPadding: Test K/V padding only (Q has no padding)
// The physical (padded) K length is computed in the TEST_P body by rounding
// seqlen_k up to a multiple of 64 (128 once seqlen_k exceeds 256).
class BasicKVPadding : public TestWithParam<FmhaBwdTestParam>
{
};
INSTANTIATE_TEST_SUITE_P(
TestCkTileFmhaBwd,
BasicKVPadding,
Combine(Values(mode_enum::group),
HDimValues,
Values(std::tuple{true, true}),
Values("n"),
Values(false),
Values(0.0f),
Values(std::tuple{0, 0, false}),
ValuesIn([]() {
std::vector<FmhaBwdDimsMaskParam> test_cases;
// Small K/V padding
test_cases.push_back(std::tuple{2, 2, 2, 128, 127, "0"}); // K: 127->128
test_cases.push_back(std::tuple{3, 4, 2, 256, 250, "0"}); // K: 250->256
// Moderate K/V padding up to the next alignment boundary
test_cases.push_back(std::tuple{2, 2, 1, 256, 180, "0"}); // K: 180->192
test_cases.push_back(std::tuple{3, 3, 3, 512, 350, "1"}); // K: 350->384 (128-aligned)
// Boundary cases: already aligned, and a larger rounding step
test_cases.push_back(std::tuple{2, 4, 2, 256, 128, "0"}); // K: 128->128 (already 64-aligned, no padding)
test_cases.push_back(std::tuple{2, 2, 2, 512, 200, "2"}); // K: 200->256
return test_cases;
}()),
Values(false)));
TEST_P(BasicKVPadding, DataTypeConfig)
{
    // Unpack the parameter tuple produced by the Combine() generator above.
    auto [mode, hdims, perm, bias_str, use_dbias, p_drop, drop_misc, dims_mask, det] = GetParam();
    auto [hdim_q, hdim_v]                                      = hdims;
    auto [i_perm, o_perm]                                      = perm;
    auto [drop_seed, drop_offset, drop_prefs]                  = drop_misc;
    auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask;

    // Physical (padded) K/V length: round the logical length up to a
    // 64-element boundary, or a 128-element boundary once seqlen_k > 256.
    auto round_up = [](ck_tile::index_t len, ck_tile::index_t align) {
        return ((len + align - 1) / align) * align;
    };
    const ck_tile::index_t k_alignment = (seqlen_k > 256 ? 128 : 64);
    const ck_tile::index_t seqlen_kpad = round_up(seqlen_k, k_alignment);

    // Per-batch sequence lengths; every batch uses the same configuration.
    const std::vector<ck_tile::index_t> seqlen_qs(batch, seqlen_q);
    const std::vector<ck_tile::index_t> seqlen_ks(batch, seqlen_k);
    const std::vector<ck_tile::index_t> seqlen_qpads(batch, seqlen_q); // Q stays unpadded
    const std::vector<ck_tile::index_t> seqlen_kpads(batch, seqlen_kpad);

    const auto result = fmha_bwd_run<DataTypeConfig>(
        mode,
        batch,
        nhead,
        nhead_k,
        seqlen_qs,
        seqlen_ks,
        seqlen_qpads,
        seqlen_kpads,
        hdim_q,
        hdim_v,
        i_perm,
        o_perm,
        0, // scale
        bias_str,
        use_dbias,
        p_drop,
        drop_seed,
        drop_offset,
        drop_prefs,
        mask_str,
        det,
        init_method,
        static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
        1,
        stream_config);

    if(result == bwd_result::no_instance)
        GTEST_SKIP() << "No instance for K/V padding with hdim_q=" << hdim_q;
    ASSERT_EQ(result, bwd_result::success);
}
// 3. QKVPadding: Test both Q and K/V padding simultaneously
// Parameterized fixture; parameters come from the matching
// INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaBwd, QKVPadding, ...) below.
class QKVPadding : public TestWithParam<FmhaBwdTestParam>
{
};
INSTANTIATE_TEST_SUITE_P(
    TestCkTileFmhaBwd,
    QKVPadding,
    // Combine slots (see the structured binding in the TEST_P body): mode,
    // {hdim_q, hdim_v}, {i_perm, o_perm}, bias_str, use_dbias, p_drop,
    // {drop_seed, drop_offset, drop_prefs}, dims/mask tuple, det.
    Combine(Values(mode_enum::group),
            HDimValues,
            Values(std::tuple{true, true}),
            Values("n"),
            Values(false),
            Values(0.0f),
            Values(std::tuple{0, 0, false}),
            ValuesIn([]() {
                // Each tuple: {batch, nhead, nhead_k, seqlen_q, seqlen_k, mask}.
                // Padded lengths below follow the TEST_P rule: 64-aligned up
                // to 256, 128-aligned beyond 256.
                std::vector<FmhaBwdDimsMaskParam> test_cases;
                // Both Q and K have small padding
                test_cases.push_back(std::tuple{2, 2, 2, 120, 125, "0"}); // Q:120->128, K:125->128
                // Both Q and K have medium padding
                test_cases.push_back(std::tuple{2, 4, 2, 180, 200, "0"}); // Q:180->192, K:200->256
                test_cases.push_back(std::tuple{3, 3, 3, 300, 350, "1"}); // Q:300->384, K:350->384
                // Both Q and K have large padding
                test_cases.push_back(std::tuple{2, 2, 1, 150, 180, "0"}); // Q:150->192, K:180->192
                test_cases.push_back(std::tuple{2, 4, 2, 256, 300, "2"}); // Q:256 (exact), K:300->384
                // Asymmetric padding (Q more padded than K)
                test_cases.push_back(std::tuple{2, 2, 2, 100, 200, "0"}); // Q:100->128, K:200->256
                // Asymmetric padding (K more padded than Q)
                test_cases.push_back(std::tuple{2, 3, 1, 200, 100, "0"}); // Q:200->256, K:100->128
                return test_cases;
            }()),
            Values(false)));
// Exercises simultaneous Q and K/V padding in group mode: both logical
// sequence lengths are shorter than their physical (padded) lengths.
TEST_P(QKVPadding, DataTypeConfig)
{
    auto [mode, hdims, perm, bias_str, use_dbias, p_drop, drop_misc, dims_mask, det] = GetParam();
    auto [hdim_q, hdim_v]                                      = hdims;
    auto [i_perm, o_perm]                                      = perm;
    auto [drop_seed, drop_offset, drop_prefs]                  = drop_misc;
    auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask;

    // Physical length: 64-aligned up to 256, 128-aligned beyond.
    auto to_physical = [](ck_tile::index_t len) {
        const ck_tile::index_t alignment = (len > 256) ? 128 : 64;
        return ((len + alignment - 1) / alignment) * alignment;
    };

    std::vector<ck_tile::index_t> seqlen_qs(batch, seqlen_q);
    std::vector<ck_tile::index_t> seqlen_ks(batch, seqlen_k);
    std::vector<ck_tile::index_t> seqlen_qpads(batch, to_physical(seqlen_q));
    std::vector<ck_tile::index_t> seqlen_kpads(batch, to_physical(seqlen_k));

    const auto result = fmha_bwd_run<DataTypeConfig>(
        mode, batch, nhead, nhead_k,
        seqlen_qs, seqlen_ks, seqlen_qpads, seqlen_kpads,
        hdim_q, hdim_v, i_perm, o_perm,
        0, bias_str, use_dbias,
        p_drop, drop_seed, drop_offset, drop_prefs,
        mask_str, det, init_method,
        static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
        1, stream_config);

    if(result == bwd_result::no_instance)
        GTEST_SKIP() << "No instance for Q+K/V padding with hdim_q=" << hdim_q;
    ASSERT_EQ(result, bwd_result::success);
}
// 4. ZeroLengthPadding: Test zero-length sequences with padding
// Parameterized fixture; the test body zeroes selected batches' logical
// lengths (and their padded lengths) when seqlen_q == 0 or seqlen_k == 0.
class ZeroLengthPadding : public TestWithParam<FmhaBwdTestParam>
{
};
INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaBwd,
                         ZeroLengthPadding,
                         Combine(Values(mode_enum::group),
                                 Values(std::tuple{64, -1},
                                        std::tuple{128, -1}), // Limited hdim for edge cases
                                 Values(std::tuple{true, true}),
                                 Values("n"),
                                 Values(false),
                                 Values(0.0f),
                                 Values(std::tuple{0, 0, false}),
                                 // Tuples: {batch, nhead, nhead_k, seqlen_q, seqlen_k, mask}.
                                 // The TEST_P body only produces zero-length batches when
                                 // seqlen_q == 0 or seqlen_k == 0; all other cases get
                                 // varied non-zero lengths.
                                 Values(
                                     // Test case 1: seqlen_q == 0 — batches 0 and 1 get a
                                     // zero-length Q, batch 2 a small non-zero Q
                                     std::tuple{3, 2, 2, 0, 128, "0"},
                                     // Test case 2: varied non-zero lengths (multi-batch)
                                     // NOTE(review): previously labelled "middle batch zero Q",
                                     // but the body only zeroes Q when seqlen_q == 0 — verify
                                     // the intended scenario.
                                     std::tuple{3, 2, 1, 100, 128, "0"},
                                     // Test case 3: varied non-zero lengths
                                     // NOTE(review): previously labelled "last batch zero Q" —
                                     // same caveat as above.
                                     std::tuple{3, 3, 3, 150, 200, "0"},
                                     // Test case 4: Zero K length (first batch)
                                     std::tuple{3, 2, 2, 128, 0, "0"},
                                     // Test case 5: varied non-zero lengths with padding
                                     // NOTE(review): previously labelled "mixed zero lengths" —
                                     // same caveat as above.
                                     std::tuple{4, 2, 2, 80, 100, "0"}),
                                 Values(false)));
// Exercises zero-length sequences combined with padding: when a test case
// carries seqlen_q == 0 (or seqlen_k == 0), selected batches get a logical
// length of 0 — and a padded length of 0 as well — so the backward pass
// must skip them without touching their (empty) buffers.
TEST_P(ZeroLengthPadding, DataTypeConfig)
{
    auto [mode, hdims, perm, bias_str, use_dbias, p_drop, drop_misc, dims_mask, det] = GetParam();
    auto [hdim_q, hdim_v]                                      = hdims;
    auto [i_perm, o_perm]                                      = perm;
    auto [drop_seed, drop_offset, drop_prefs]                  = drop_misc;
    auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask;
    // Create varied sequence lengths with some zero-length sequences
    std::vector<ck_tile::index_t> seqlen_qs;
    std::vector<ck_tile::index_t> seqlen_ks;
    std::vector<ck_tile::index_t> seqlen_qpads;
    std::vector<ck_tile::index_t> seqlen_kpads;
    for(int b = 0; b < batch; ++b)
    {
        ck_tile::index_t q_len, k_len;
        if(seqlen_q == 0 && b == 1) // zero-Q scenario: middle batch
        {
            // b == 1 is guaranteed by the condition above, so the original
            // nested ternaries ((b == 1) ? 0 : ...) here were dead code:
            // Q is simply zero.
            q_len = 0;
            k_len = seqlen_k;
        }
        else if(seqlen_k == 0 && b == 0) // zero-K scenario: first batch
        {
            // Likewise b == 0 is guaranteed, so K is simply zero.
            q_len = seqlen_q;
            k_len = 0;
        }
        else
        {
            // Varied lengths; the first batch also gets a zero-length Q
            // when the case requests seqlen_q == 0.
            q_len = (b == 0 && seqlen_q == 0) ? 0 : (seqlen_q + b * 10);
            k_len = seqlen_k + b * 15;
        }
        seqlen_qs.push_back(q_len);
        seqlen_ks.push_back(k_len);
        // Zero-length sequences get zero padding; others are 64-aligned.
        ck_tile::index_t qpad = (q_len == 0) ? 0 : ((q_len + 63) / 64) * 64;
        ck_tile::index_t kpad = (k_len == 0) ? 0 : ((k_len + 63) / 64) * 64;
        seqlen_qpads.push_back(qpad);
        seqlen_kpads.push_back(kpad);
    }
    auto result = fmha_bwd_run<DataTypeConfig>(
        mode,
        batch,
        nhead,
        nhead_k,
        seqlen_qs,
        seqlen_ks,
        seqlen_qpads,
        seqlen_kpads,
        hdim_q,
        hdim_v,
        i_perm,
        o_perm,
        0,
        bias_str,
        use_dbias,
        p_drop,
        drop_seed,
        drop_offset,
        drop_prefs,
        mask_str,
        det,
        init_method,
        static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
        1,
        stream_config);
    if(result == bwd_result::no_instance)
        GTEST_SKIP() << "No instance for zero-length padding";
    ASSERT_EQ(result, bwd_result::success);
}
// ============================================================================
// Q/KV Padding Tests - Medium Priority
// ============================================================================
// 5. VariedPaddingRatios: Test different padding ratios (waste ratios)
// Parameterized fixture; the instantiation below sweeps padding waste from
// roughly 1% up to roughly 100% of the logical sequence length.
class VariedPaddingRatios : public TestWithParam<FmhaBwdTestParam>
{
};
INSTANTIATE_TEST_SUITE_P(
    TestCkTileFmhaBwd,
    VariedPaddingRatios,
    Combine(Values(mode_enum::group),
            HDimValues,
            Values(std::tuple{true, true}),
            Values("n"),
            Values(false),
            Values(0.0f),
            Values(std::tuple{0, 0, false}),
            ValuesIn([]() {
                // Tuples: {batch, nhead, nhead_k, seqlen_q, seqlen_k, mask}.
                // Padded lengths follow the TEST_P calc_pad buckets
                // (64/128/256/384/512, then 128-aligned above 512).
                std::vector<FmhaBwdDimsMaskParam> test_cases;
                // Minimal waste: ~1-5% padding (logical ≈ physical - small delta)
                test_cases.push_back(
                    std::tuple{2, 2, 2, 127, 127, "0"}); // Q:127->128 (~0.8%), K:127->128
                test_cases.push_back(
                    std::tuple{2, 4, 2, 252, 250, "0"}); // Q:252->256 (~1.6%), K:250->256
                test_cases.push_back(std::tuple{2, 2, 1, 509, 505, "1"}); // Q:509->512, K:505->512
                // Low waste: ~10-20% padding
                test_cases.push_back(
                    std::tuple{2, 3, 3, 220, 210, "0"}); // Q:220->256 (~16%), K:210->256
                test_cases.push_back(
                    std::tuple{3, 2, 2, 440, 420, "0"}); // Q:440->512 (~16%), K:420->512
                test_cases.push_back(std::tuple{2, 4, 2, 350, 340, "1"}); // Q:350->384, K:340->384
                // Medium waste: ~30-40% padding
                test_cases.push_back(
                    std::tuple{2, 2, 2, 180, 170, "0"}); // Q:180->256 (~42%), K:170->256
                test_cases.push_back(
                    std::tuple{2, 3, 1, 320, 310, "0"}); // Q:320->384 (~20%), K:310->384
                test_cases.push_back(std::tuple{3, 2, 2, 350, 340, "2"}); // Q:350->384, K:340->384
                // High waste: ~50%+ padding
                test_cases.push_back(
                    std::tuple{2, 2, 2, 130, 130, "0"}); // Q:130->256 (~97%), K:130->256
                test_cases.push_back(
                    std::tuple{2, 4, 2, 260, 260, "0"}); // Q:260->384 (~48%), K:260->384
                test_cases.push_back(
                    std::tuple{2, 2, 1, 200, 200, "1"}); // Q:200->256 (~28%), K:200->256
                // Extreme waste: very small logical vs large physical
                test_cases.push_back(std::tuple{2, 2, 2, 65, 70, "0"}); // Q:65->128, K:70->128
                test_cases.push_back(std::tuple{2, 3, 3, 100, 90, "0"}); // Q:100->128, K:90->128
                return test_cases;
            }()),
            Values(false)));
// Exercises a spread of padding "waste" ratios, from ~1% up to roughly
// doubling the logical length, using fixed physical size buckets.
TEST_P(VariedPaddingRatios, DataTypeConfig)
{
    auto [mode, hdims, perm, bias_str, use_dbias, p_drop, drop_misc, dims_mask, det] = GetParam();
    auto [hdim_q, hdim_v]                                      = hdims;
    auto [i_perm, o_perm]                                      = perm;
    auto [drop_seed, drop_offset, drop_prefs]                  = drop_misc;
    auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask;

    // Snap a logical length to the next size bucket (64/128/256/384/512);
    // anything larger is 128-aligned.
    auto snap_to_bucket = [](ck_tile::index_t len) -> ck_tile::index_t {
        for(ck_tile::index_t bucket : {64, 128, 256, 384, 512})
        {
            if(len <= bucket)
                return bucket;
        }
        return ((len + 127) / 128) * 128;
    };

    std::vector<ck_tile::index_t> seqlen_qs(batch, seqlen_q);
    std::vector<ck_tile::index_t> seqlen_ks(batch, seqlen_k);
    std::vector<ck_tile::index_t> seqlen_qpads(batch, snap_to_bucket(seqlen_q));
    std::vector<ck_tile::index_t> seqlen_kpads(batch, snap_to_bucket(seqlen_k));

    const auto result = fmha_bwd_run<DataTypeConfig>(
        mode, batch, nhead, nhead_k,
        seqlen_qs, seqlen_ks, seqlen_qpads, seqlen_kpads,
        hdim_q, hdim_v, i_perm, o_perm,
        0, bias_str, use_dbias,
        p_drop, drop_seed, drop_offset, drop_prefs,
        mask_str, det, init_method,
        static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
        1, stream_config);

    if(result == bwd_result::no_instance)
        GTEST_SKIP() << "No instance for varied padding ratios";
    ASSERT_EQ(result, bwd_result::success);
}
// 6. PaddingWithMask: Test padding combined with various mask types
// Parameterized fixture; the instantiation below combines padding with
// no mask, causal (top-left "1"/"t", bottom-right "2"/"b"), and
// sliding-window mask strings — see the per-case comments there.
class PaddingWithMask : public TestWithParam<FmhaBwdTestParam>
{
};
INSTANTIATE_TEST_SUITE_P(
    TestCkTileFmhaBwd,
    PaddingWithMask,
    // Combine slots: mode, {hdim_q, hdim_v}, {i_perm, o_perm}, bias_str,
    // use_dbias, p_drop, {drop_seed, drop_offset, drop_prefs}, dims/mask, det.
    Combine(Values(mode_enum::group),
            Values(std::tuple{64, -1}, std::tuple{128, -1}), // Focus on common sizes
            Values(std::tuple{true, true}),
            Values("n"),
            Values(false),
            Values(0.0f),
            Values(std::tuple{0, 0, false}),
            ValuesIn([]() {
                // Tuples: {batch, nhead, nhead_k, seqlen_q, seqlen_k, mask}.
                std::vector<FmhaBwdDimsMaskParam> test_cases;
                // No mask with padding (baseline)
                test_cases.push_back(std::tuple{2, 2, 2, 200, 180, "0"});
                // Causal mask (top-left) with Q padding
                test_cases.push_back(std::tuple{2, 2, 2, 200, 256, "1"}); // Q padded, K exact
                test_cases.push_back(std::tuple{2, 4, 2, 180, 200, "t"}); // Both padded, causal
                // Causal mask (bottom-right) with K/V padding
                test_cases.push_back(std::tuple{2, 2, 1, 256, 180, "2"}); // K padded, Q exact
                test_cases.push_back(
                    std::tuple{2, 3, 3, 200, 180, "b"}); // Both padded, bottom-right
                // Sliding window attention with padding
                test_cases.push_back(std::tuple{2, 2, 2, 200, 190, "t:64,32"}); // SWA + padding
                test_cases.push_back(std::tuple{2, 4, 2, 180, 170, "b:32,64"}); // SWA + padding
                test_cases.push_back(std::tuple{3, 2, 1, 220, 210, "t:100,50"}); // Larger window
                // Sliding window with asymmetric padding
                test_cases.push_back(std::tuple{2, 2, 2, 150, 250, "t:80,40"}); // Q more padded
                test_cases.push_back(std::tuple{2, 3, 3, 250, 150, "b:50,70"}); // K more padded
                // Mixed scenarios
                test_cases.push_back(std::tuple{2, 4, 2, 190, 185, "t:50,50"}); // Symmetric window
                test_cases.push_back(std::tuple{3, 2, 2, 300, 280, "1"}); // Multi-batch causal
                return test_cases;
            }()),
            Values(false)));
// Exercises Q and K/V padding combined with the mask variants listed in
// the instantiation (none, causal, sliding-window); padded lengths follow
// the 64-below-256 / 128-above-256 alignment rule.
TEST_P(PaddingWithMask, DataTypeConfig)
{
    auto [mode, hdims, perm, bias_str, use_dbias, p_drop, drop_misc, dims_mask, det] = GetParam();
    auto [hdim_q, hdim_v]                                      = hdims;
    auto [i_perm, o_perm]                                      = perm;
    auto [drop_seed, drop_offset, drop_prefs]                  = drop_misc;
    auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask;

    // Physical length: 64-aligned up to 256, 128-aligned beyond.
    auto to_physical = [](ck_tile::index_t len) {
        const ck_tile::index_t alignment = (len > 256) ? 128 : 64;
        return ((len + alignment - 1) / alignment) * alignment;
    };

    std::vector<ck_tile::index_t> seqlen_qs(batch, seqlen_q);
    std::vector<ck_tile::index_t> seqlen_ks(batch, seqlen_k);
    std::vector<ck_tile::index_t> seqlen_qpads(batch, to_physical(seqlen_q));
    std::vector<ck_tile::index_t> seqlen_kpads(batch, to_physical(seqlen_k));

    const auto result = fmha_bwd_run<DataTypeConfig>(
        mode, batch, nhead, nhead_k,
        seqlen_qs, seqlen_ks, seqlen_qpads, seqlen_kpads,
        hdim_q, hdim_v, i_perm, o_perm,
        0, bias_str, use_dbias,
        p_drop, drop_seed, drop_offset, drop_prefs,
        mask_str, det, init_method,
        static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
        1, stream_config);

    if(result == bwd_result::no_instance)
        GTEST_SKIP() << "No instance for padding with mask";
    ASSERT_EQ(result, bwd_result::success);
}
// 7. MultiBatchPadding: Test multiple batches with different padding configurations
// Parameterized fixture; the test body derives per-batch logical lengths and
// a per-batch alignment strategy from the base dimensions in the parameters.
class MultiBatchPadding : public TestWithParam<FmhaBwdTestParam>
{
};
INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaBwd,
                         MultiBatchPadding,
                         Combine(Values(mode_enum::group),
                                 Values(std::tuple{64, -1}, std::tuple{128, -1}),
                                 Values(std::tuple{true, true}),
                                 Values("n"),
                                 Values(false),
                                 Values(0.0f),
                                 Values(std::tuple{0, 0, false}),
                                 // Tuples: {batch, nhead, nhead_k, base_seqlen_q,
                                 // base_seqlen_k, mask}; the TEST_P body varies the
                                 // actual per-batch lengths around these base values.
                                 Values(
                                     // 3 batches with varied Q/K lengths and padding
                                     std::tuple{3, 2, 2, 150, 200, "0"},
                                     // 4 batches with different patterns
                                     std::tuple{4, 3, 3, 180, 220, "0"},
                                     // 5 batches with mixed scenarios
                                     std::tuple{5, 2, 1, 120, 160, "1"},
                                     // 3 batches with causal mask
                                     std::tuple{3, 4, 2, 200, 180, "t"},
                                     // 4 batches with sliding window
                                     std::tuple{4, 2, 2, 160, 140, "t:50,30"}),
                                 Values(false)));
// Exercises heterogeneous per-batch padding: every batch gets its own
// logical length (derived from the base dims) and its own alignment
// granularity (32/64/128, plus a mixed Q-tight/K-loose variant).
TEST_P(MultiBatchPadding, DataTypeConfig)
{
    auto [mode, hdims, perm, bias_str, use_dbias, p_drop, drop_misc, dims_mask, det] = GetParam();
    auto [hdim_q, hdim_v]                                                = hdims;
    auto [i_perm, o_perm]                                                = perm;
    auto [drop_seed, drop_offset, drop_prefs]                            = drop_misc;
    auto [batch, nhead, nhead_k, base_seqlen_q, base_seqlen_k, mask_str] = dims_mask;

    // Round `len` up to a multiple of `granularity`.
    auto align_to = [](ck_tile::index_t len, ck_tile::index_t granularity) {
        return ((len + granularity - 1) / granularity) * granularity;
    };

    std::vector<ck_tile::index_t> seqlen_qs;
    std::vector<ck_tile::index_t> seqlen_ks;
    std::vector<ck_tile::index_t> seqlen_qpads;
    std::vector<ck_tile::index_t> seqlen_kpads;
    for(int b = 0; b < batch; ++b)
    {
        // Logical lengths cycle through decreasing / increasing / mixed
        // patterns so batches differ from one another.
        ck_tile::index_t q_len = 0;
        ck_tile::index_t k_len = 0;
        const int length_pattern = b % 3;
        if(length_pattern == 0) // decreasing
        {
            q_len = base_seqlen_q - b * 20;
            k_len = base_seqlen_k - b * 25;
        }
        else if(length_pattern == 1) // increasing
        {
            q_len = base_seqlen_q + b * 15;
            k_len = base_seqlen_k + b * 20;
        }
        else // mixed: Q and K drift in opposite directions
        {
            const ck_tile::index_t sign = (b % 2 == 0) ? 1 : -1;
            q_len = base_seqlen_q + sign * 10 * b;
            k_len = base_seqlen_k - sign * 15 * b;
        }
        // Clamp to a sane minimum so lengths stay positive.
        q_len = std::max<ck_tile::index_t>(64, q_len);
        k_len = std::max<ck_tile::index_t>(64, k_len);
        seqlen_qs.push_back(q_len);
        seqlen_ks.push_back(k_len);

        // Padding strategy also cycles per batch: tight (32), medium (64),
        // loose (128), then mixed (Q tight / K loose).
        ck_tile::index_t qpad = 0;
        ck_tile::index_t kpad = 0;
        const int pad_strategy = b % 4;
        if(pad_strategy == 0)
        {
            qpad = align_to(q_len, 32);
            kpad = align_to(k_len, 32);
        }
        else if(pad_strategy == 1)
        {
            qpad = align_to(q_len, 64);
            kpad = align_to(k_len, 64);
        }
        else if(pad_strategy == 2)
        {
            qpad = align_to(q_len, 128);
            kpad = align_to(k_len, 128);
        }
        else
        {
            qpad = align_to(q_len, 32);
            kpad = align_to(k_len, 128);
        }
        seqlen_qpads.push_back(qpad);
        seqlen_kpads.push_back(kpad);
    }

    const auto result = fmha_bwd_run<DataTypeConfig>(
        mode, batch, nhead, nhead_k,
        seqlen_qs, seqlen_ks, seqlen_qpads, seqlen_kpads,
        hdim_q, hdim_v, i_perm, o_perm,
        0, bias_str, use_dbias,
        p_drop, drop_seed, drop_offset, drop_prefs,
        mask_str, det, init_method,
        static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
        1, stream_config);

    if(result == bwd_result::no_instance)
        GTEST_SKIP() << "No instance for multi-batch padding";
    ASSERT_EQ(result, bwd_result::success);
}