[CK_TILE] fmha: Add query padding support to backward pass

Introduces support for query sequence padding (q_padding) in the FMHA backward pass kernels.
- Passing `seqlen_q_ptr` to the backward kernels to distinguish logical from physical sequence lengths.
- Updating `OGradDotO`, `ConvertQGrad`, and `DQDKDV` kernels to respect logical lengths and handle zero-length sequences.
- Aligning LSE indexing in the forward kernel with the padded layout for consistency.
- Adding a new GTest suite (`test_fmha_bwd_kernel_padding.cpp`) with comprehensive tests for various padding scenarios, including zero-length
  sequences and deterministic mode.
This commit is contained in:
Jeff Huang
2025-10-08 15:35:02 +08:00
parent 86d542f663
commit 4e06eaa417
10 changed files with 839 additions and 35 deletions

View File

@@ -24,11 +24,19 @@ auto create_args(int argc, char* argv[])
"total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary\n"
"also with \"-s=s0,s1,s2...\" comma-separated ints to set seqlen per batch "
"(group mode)")
.insert("s_qpad",
"-1",
"padded seqlen_q per batch (group mode only). "
"Use \"-s_qpad=p0,p1,...\"; -1 disables explicit padding")
.insert("s_k",
"-1",
"seqlen_k, -1 means equal to s\n"
"also with \"-s_k=s0,s1,s2...\" comma-separated ints to set seqlen per batch "
"(group mode)")
.insert("s_kpad",
"-1",
"padded seqlen_k per batch (group mode only). "
"Use \"-s_kpad=k0,k1,...\"; -1 disables explicit padding")
.insert("d", "128", "head dim for q, k")
.insert("d_v", "-1", "head dim for v, -1 means equal to d")
.insert("scale", "0", "scale factor. 0 means equal to 1/sqrt(hdim)")
@@ -96,7 +104,9 @@ auto run(const ck_tile::ArgParser& arg_parser)
ck_tile::index_t nhead = arg_parser.get_int("h");
ck_tile::index_t nhead_k = arg_parser.get_int("h_k");
auto seqlen_qs = arg_parser.get_int_vec("s");
auto seqlen_qpads = arg_parser.get_int_vec("s_qpad");
auto seqlen_ks = arg_parser.get_int_vec("s_k");
auto seqlen_kpads = arg_parser.get_int_vec("s_kpad");
ck_tile::index_t hdim_q = arg_parser.get_int("d");
ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
bool i_perm = arg_parser.get_bool("iperm");
@@ -130,6 +140,8 @@ auto run(const ck_tile::ArgParser& arg_parser)
nhead_k,
seqlen_qs,
seqlen_ks,
seqlen_qpads,
seqlen_kpads,
hdim_q,
hdim_v,
i_perm,

View File

@@ -116,6 +116,7 @@ struct fmha_bwd_args
void* dq_acc_ptr;
const void* seqstart_q_ptr;
const void* seqstart_k_ptr;
const void* seqlen_q_ptr;
const void* seqlen_k_ptr;
ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_k;
@@ -203,6 +204,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
dq_ptr,
args.seqstart_q_ptr,
args.seqstart_k_ptr,
args.seqlen_q_ptr,
args.seqlen_k_ptr,
args.hdim_q,
args.hdim_v,
@@ -315,6 +317,7 @@ auto fmha_bwd_dot_do_o_create_kargs_and_grids(fmha_bwd_args args)
args.d_ptr,
args.p_undrop,
args.seqstart_q_ptr,
args.seqlen_q_ptr,
args.hdim_v,
args.stride_do,
args.stride_o,
@@ -356,6 +359,8 @@ auto fmha_bwd_convert_dq_create_kargs_and_grids(fmha_bwd_args args)
args.dq_ptr,
args.seqstart_q_ptr,
args.seqstart_k_ptr,
args.seqlen_q_ptr,
args.seqlen_k_ptr,
args.hdim_q,
args.stride_dq,
args.stride_dq_acc,

View File

@@ -65,6 +65,8 @@ bwd_result fmha_bwd_run(mode_enum mode,
ck_tile::index_t nhead_k,
std::vector<ck_tile::index_t> seqlen_qs,
std::vector<ck_tile::index_t> seqlen_ks,
std::vector<ck_tile::index_t> seqlen_qpads,
std::vector<ck_tile::index_t> seqlen_kpads,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
bool i_perm,
@@ -119,10 +121,15 @@ bwd_result fmha_bwd_run(mode_enum mode,
std::cerr << "dbias only exists when bias type is elementwise" << std::endl;
return bwd_result::invalid_args;
}
std::vector<ck_tile::index_t> seqlen_kpads;
std::tie(seqlen_qs, seqlen_ks, seqlen_kpads) =
generate_missing_seqlens(mode, batch, seqlen_qs, seqlen_ks, {}, 0, false, random_engine);
ck_tile::ignore = seqlen_kpads;
std::tie(seqlen_qs, seqlen_ks, seqlen_qpads, seqlen_kpads) = generate_missing_seqlens(
mode, batch, seqlen_qs, seqlen_ks, seqlen_qpads, seqlen_kpads, 0, false, random_engine);
bool use_qpadding =
mode == mode_enum::group && (!seqlen_qpads.empty() && seqlen_qpads[0] != -1);
bool use_kpadding =
mode == mode_enum::group && (!seqlen_kpads.empty() && seqlen_kpads[0] != -1);
#if 0
std::cout << "seqlen_qs: " << seqlen_qs << std::endl;
std::cout << "seqlen_ks: " << seqlen_ks << std::endl;
@@ -146,8 +153,10 @@ bwd_result fmha_bwd_run(mode_enum mode,
s_randval = true;
}
const auto seqstart_q_host = to_seqstarts(seqlen_qs);
const auto seqstart_k_host = to_seqstarts(seqlen_ks);
const auto seqstart_q_host =
(use_qpadding ? to_seqstarts(seqlen_qpads) : to_seqstarts(seqlen_qs));
const auto seqstart_k_host =
(use_kpadding ? to_seqstarts(seqlen_kpads) : to_seqstarts(seqlen_ks));
using TypeConfig = FmhaBwdTypeConfig<DataTypeConfig>;
@@ -336,6 +345,10 @@ bwd_result fmha_bwd_run(mode_enum mode,
ck_tile::DeviceMem do_buf(do_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem dbias_buf(dbias_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
ck_tile::DeviceMem seqlen_q_dev(mode == mode_enum::batch ? 0
: seqlen_qs.size() * sizeof(int32_t));
ck_tile::DeviceMem seqlen_k_dev(mode == mode_enum::batch ? 0
: seqlen_ks.size() * sizeof(int32_t));
ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t));
ck_tile::DeviceMem drop_seed_buf(drop_prefs ? sizeof(uint64_t) : 0);
ck_tile::DeviceMem drop_offset_buf(drop_prefs ? sizeof(uint64_t) : 0);
@@ -349,6 +362,13 @@ bwd_result fmha_bwd_run(mode_enum mode,
do_buf.ToDevice(do_host.data());
seqstart_q.ToDevice(seqstart_q_host.data());
seqstart_k.ToDevice(seqstart_k_host.data());
if(mode == mode_enum::group)
{
std::vector<int32_t> seqlen_q_host(seqlen_qs.begin(), seqlen_qs.end());
seqlen_q_dev.ToDevice(seqlen_q_host.data());
std::vector<int32_t> seqlen_k_host(seqlen_ks.begin(), seqlen_ks.end());
seqlen_k_dev.ToDevice(seqlen_k_host.data());
}
drop_seed_buf.ToDevice(drop_prefs ? &drop_seed : nullptr);
drop_offset_buf.ToDevice(drop_prefs ? &drop_offset : nullptr);
alibi_slope_buf.ToDevice(alibi_slope_host.data());
@@ -440,11 +460,14 @@ bwd_result fmha_bwd_run(mode_enum mode,
}
}();
const void* seqlen_q_ptr_dev = use_qpadding ? seqlen_q_dev.GetDeviceBuffer() : nullptr;
const void* seqlen_k_ptr_dev = use_kpadding ? seqlen_k_dev.GetDeviceBuffer() : nullptr;
return fmha_bwd_args{q_buf.GetDeviceBuffer(),
k_buf.GetDeviceBuffer(),
v_buf.GetDeviceBuffer(),
bias.type == bias_enum::alibi ? alibi_slope_buf.GetDeviceBuffer()
: bias_buf.GetDeviceBuffer(),
: bias_buf.GetDeviceBuffer(),
o_buf.GetDeviceBuffer(),
lse_buf.GetDeviceBuffer(),
do_buf.GetDeviceBuffer(),
@@ -457,7 +480,8 @@ bwd_result fmha_bwd_run(mode_enum mode,
dq_acc_buf.GetDeviceBuffer(),
seqstart_q.GetDeviceBuffer(),
seqstart_k.GetDeviceBuffer(),
nullptr,
seqlen_q_ptr_dev,
seqlen_k_ptr_dev,
shape_seqlen_q,
shape_seqlen_k,
batch,
@@ -472,7 +496,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
stride_k,
stride_v,
bias.type == bias_enum::alibi ? (bias.rank_info == 0 ? 0 : nhead)
: stride_bias,
: stride_bias,
stride_o,
stride_randval,
stride_do,

View File

@@ -330,11 +330,12 @@ fwd_result fmha_fwd_run(mode_enum mode,
return fwd_result::invalid_args;
}
std::tie(seqlen_qs, seqlen_ks, seqlen_kpads) =
std::tie(seqlen_qs, seqlen_ks, seqlen_qpads, seqlen_kpads) =
generate_missing_seqlens(mode,
batch,
seqlen_qs,
seqlen_ks,
seqlen_qpads,
seqlen_kpads,
/*seqlen_k_min=*/0 < seqlen_knew ? seqlen_knew : 0,
need_append_kvcache,
@@ -346,7 +347,13 @@ fwd_result fmha_fwd_run(mode_enum mode,
std::cerr << "kpad must be greater than or equal to seqlen for k" << std::endl;
return fwd_result::invalid_args;
}
if(seqlen_qpads[wb] > 0 && seqlen_qpads[wb] < seqlen_qs[wb])
{
std::cerr << "qpad must be greater than or equal to seqlen for q" << std::endl;
return fwd_result::invalid_args;
}
}
// compute kvcache seqlen_k (before appending knew/vnew)
auto cache_seqlen_ks = seqlen_ks;
std::transform(cache_seqlen_ks.begin(),
@@ -357,6 +364,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
#if 0
std::cout << "seqlen_qs: " << seqlen_qs << std::endl;
std::cout << "seqlen_ks: " << seqlen_ks << std::endl;
std::cout << "seqlen_qpads: " << seqlen_qpads << std::endl;
std::cout << "seqlen_kpads: " << seqlen_kpads << std::endl;
std::cout << "cache_seqlen_ks: " << cache_seqlen_ks << std::endl;
#endif
@@ -514,9 +522,6 @@ fwd_result fmha_fwd_run(mode_enum mode,
// host memory for storing all the tensor elements
const ck_tile::index_t shape_batch = (mode == mode_enum::batch ? batch : 1);
// logical(unpadded) total seqlen_q for group; batch uses fixed seqlen
const ck_tile::index_t shape_seqlen_q_lse =
(mode == mode_enum::batch ? seqlen_qs[0] : seqstart_q_host.back());
// physical(padded) total seqlen_q for group when s_qpad is provided; else use logical
const ck_tile::index_t shape_seqlen_q =
(mode == mode_enum::batch
@@ -580,7 +585,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
// batch mode of lse data layout is [batch, nhead, seqlen_q]
// group mode of lse data layout is [nhead, total_seqlen_q]
ck_tile::HostTensor<LSEDataType> lse_host(
lse ? std::array<ck_tile::index_t, 3>{shape_batch, nhead, shape_seqlen_q_lse}
lse ? std::array<ck_tile::index_t, 3>{shape_batch, nhead, shape_seqlen_q}
: std::array<ck_tile::index_t, 3>{1, 1, 1} /* dummy shape for simplifying code */);
ck_tile::HostTensor<ODataType> o_host(
@@ -970,8 +975,8 @@ fwd_result fmha_fwd_run(mode_enum mode,
const ck_tile::index_t nhead_stride_bias =
(i_perm ? 0 * shape_seqlen_q * max_seqlen_k : 0 * max_seqlen_k);
const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t nhead_stride_lse = shape_seqlen_q_lse;
const ck_tile::index_t nhead_stride_lse_acc = (num_splits * shape_seqlen_q_lse);
const ck_tile::index_t nhead_stride_lse = shape_seqlen_q;
const ck_tile::index_t nhead_stride_lse_acc = (num_splits * shape_seqlen_q);
const ck_tile::index_t nhead_stride_o_acc = (num_splits * shape_seqlen_q * hdim_v);
const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
// setup batch_stride_* arguments
@@ -986,8 +991,8 @@ fwd_result fmha_fwd_run(mode_enum mode,
const ck_tile::index_t batch_stride_vnew = (nhead_k * hdim_v * seqlen_knew);
const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q_lse);
const ck_tile::index_t batch_stride_lse_acc = (nhead * num_splits * shape_seqlen_q_lse);
const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q);
const ck_tile::index_t batch_stride_lse_acc = (nhead * num_splits * shape_seqlen_q);
const ck_tile::index_t batch_stride_o_acc = (nhead * num_splits * shape_seqlen_q * hdim_v);
const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);
const ck_tile::index_t batch_stride_block_table = (max_num_page_blocks / batch);
@@ -1727,12 +1732,11 @@ fwd_result fmha_fwd_run(mode_enum mode,
if(lse)
{
ck_tile::HostTensor<SMPLComputeDataType> lse_host_result({nhead, real_seqlen_q});
const ck_tile::index_t query_offset_lse =
(mode == mode_enum::batch ? 0 : seqstart_q_host[wb]);
lse_host_result.ForEach([&](auto& self, auto idx) {
self(idx) = lse_host(b_idx, idx[0], idx[1] + query_offset_lse);
self(idx) = lse_host(b_idx, idx[0], idx[1] + query_offset);
});
std::cout << "lse_host_result shape: " << shape_batch << ", " << nhead << ", " << shape_seqlen_q << std::endl;
cur_pass = ck_tile::check_err(lse_host_result,
lse_host_ref,
"LSE Error: Incorrect results!",

View File

@@ -142,12 +142,14 @@ auto randints(ForwardIterator first,
*/
template <typename RandomEngine>
std::tuple<std::vector<ck_tile::index_t>,
std::vector<ck_tile::index_t>,
std::vector<ck_tile::index_t>,
std::vector<ck_tile::index_t>>
generate_missing_seqlens(mode_enum mode,
ck_tile::index_t batch,
const std::vector<ck_tile::index_t>& q_val,
const std::vector<ck_tile::index_t>& k_val,
const std::vector<ck_tile::index_t>& q_pad_val,
const std::vector<ck_tile::index_t>& k_pad_val,
ck_tile::index_t seqlen_k_min,
bool need_append_kvcache,
@@ -177,7 +179,7 @@ generate_missing_seqlens(mode_enum mode,
return seqlen_ks;
}();
auto s_kpad = std::vector<ck_tile::index_t>(batch, -1); // TODO: batch not support k_padding
auto s_qpad = std::vector<ck_tile::index_t>(batch, -1);
// s_k should be greater than or equal to seqlen_k_min if provided
if(s_k.back() < seqlen_k_min)
{
@@ -187,13 +189,14 @@ generate_missing_seqlens(mode_enum mode,
throw std::runtime_error(msg.str());
}
return std::make_tuple(s_q, s_k, s_kpad);
return std::make_tuple(s_q, s_k, s_qpad, s_kpad);
}
else
{
std::vector<ck_tile::index_t> s_q;
std::vector<ck_tile::index_t> s_k;
std::vector<ck_tile::index_t> s_kpad;
std::vector<ck_tile::index_t> s_qpad;
ck_tile::index_t idx = 0;
for(; idx < std::min(static_cast<ck_tile::index_t>(q_val.size()), batch); ++idx)
{
@@ -205,9 +208,15 @@ generate_missing_seqlens(mode_enum mode,
? -1
: k_pad_val[std::min(idx, static_cast<ck_tile::index_t>(k_pad_val.size()) - 1)];
ck_tile::index_t qp =
q_pad_val.empty()
? -1
: q_pad_val[std::min(idx, static_cast<ck_tile::index_t>(q_pad_val.size()) - 1)];
s_q.push_back(q);
s_k.push_back(k < 0 ? q : k);
s_kpad.push_back(kp);
s_qpad.push_back(qp);
// s_k should be greater than or equal to seqlen_k_min
if(s_k.back() < seqlen_k_min)
@@ -228,8 +237,9 @@ generate_missing_seqlens(mode_enum mode,
s_q.insert(s_q.end(), rem_q.begin(), rem_q.end());
s_k.insert(s_k.end(), rem_k.begin(), rem_k.end());
s_kpad.insert(s_kpad.end(), batch - idx, s_kpad.back());
s_qpad.insert(s_qpad.end(), batch - idx, s_qpad.back());
}
return std::make_tuple(s_q, s_k, s_kpad);
return std::make_tuple(s_q, s_k, s_qpad, s_kpad);
}
}

View File

@@ -313,6 +313,7 @@ struct FmhaBwdDQDKDVKernel
{
const int32_t* seqstart_q_ptr;
const int32_t* seqstart_k_ptr;
const int32_t* seqlen_q_ptr;
const int32_t* seqlen_k_ptr;
};
@@ -520,6 +521,7 @@ struct FmhaBwdDQDKDVKernel
void* dq_acc_ptr,
const void* seqstart_q_ptr,
const void* seqstart_k_ptr,
const void* seqlen_q_ptr,
const void* seqlen_k_ptr,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
@@ -594,6 +596,7 @@ struct FmhaBwdDQDKDVKernel
{}, // placeholder for deterministic
reinterpret_cast<const int32_t*>(seqstart_q_ptr),
reinterpret_cast<const int32_t*>(seqstart_k_ptr),
reinterpret_cast<const int32_t*>(seqlen_q_ptr),
reinterpret_cast<const int32_t*>(seqlen_k_ptr)};
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
@@ -738,7 +741,10 @@ struct FmhaBwdDQDKDVKernel
// get real # queries & # keys under group mode
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
const ck_tile::index_t physical_seqlen_q =
adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
kargs.seqlen_q = kargs.seqlen_q_ptr ? kargs.seqlen_q_ptr[i_batch] : physical_seqlen_q;
if(kargs.seqlen_k_ptr != nullptr)
{
kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
@@ -749,6 +755,12 @@ struct FmhaBwdDQDKDVKernel
kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0];
}
// skip if logical lengths are zero
if(kargs.seqlen_q == 0 || kargs.seqlen_k == 0)
{
return;
}
// # of required blocks is different in each groups, terminate unnecessary blocks
// earlier
if constexpr(!kUseQrQtrDorPipeline)
@@ -1246,6 +1258,7 @@ struct FmhaBwdOGradDotOKernel
struct FmhaBwdOGradDotOGroupModeKargs : FmhaBwdOGradDotOCommonKargs
{
const int32_t* seqstart_q_ptr;
const int32_t* seqlen_q_ptr;
};
using Kargs = std::
@@ -1293,6 +1306,7 @@ struct FmhaBwdOGradDotOKernel
void* d_ptr,
float p_undrop,
const void* seqstart_q_ptr,
const void* seqlen_q_ptr,
ck_tile::index_t hdim_v,
ck_tile::index_t stride_do,
ck_tile::index_t stride_o,
@@ -1311,7 +1325,8 @@ struct FmhaBwdOGradDotOKernel
nhead_stride_do,
nhead_stride_o,
nhead_stride_d},
reinterpret_cast<const int32_t*>(seqstart_q_ptr)};
reinterpret_cast<const int32_t*>(seqstart_q_ptr),
reinterpret_cast<const int32_t*>(seqlen_q_ptr)};
return kargs;
}
@@ -1357,7 +1372,12 @@ struct FmhaBwdOGradDotOKernel
// get real # queries & # keys under group mode
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
const ck_tile::index_t physical_seqlen_q =
adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
const ck_tile::index_t logical_seqlen_q =
kargs.seqlen_q_ptr ? static_cast<ck_tile::index_t>(kargs.seqlen_q_ptr[i_batch])
: physical_seqlen_q;
kargs.seqlen_q = logical_seqlen_q;
// # of required blocks is different in each groups, terminate unnecessary blocks
// earlier
if(kargs.seqlen_q <= i_m0)
@@ -1521,6 +1541,8 @@ struct FmhaBwdConvertQGradKernel
{
const int32_t* seqstart_q_ptr;
const int32_t* seqstart_k_ptr;
const int32_t* seqlen_q_ptr;
const int32_t* seqlen_k_ptr;
};
using Kargs = std::conditional_t<kIsGroupMode,
@@ -1569,6 +1591,8 @@ struct FmhaBwdConvertQGradKernel
void* dq_ptr,
const void* seqstart_q_ptr,
const void* seqstart_k_ptr,
const void* seqlen_q_ptr,
const void* seqlen_k_ptr,
ck_tile::index_t hdim_q,
ck_tile::index_t stride_dq,
ck_tile::index_t stride_dq_acc,
@@ -1587,7 +1611,9 @@ struct FmhaBwdConvertQGradKernel
nhead_stride_dq_acc},
{},
reinterpret_cast<const int32_t*>(seqstart_q_ptr),
reinterpret_cast<const int32_t*>(seqstart_k_ptr)};
reinterpret_cast<const int32_t*>(seqstart_k_ptr),
reinterpret_cast<const int32_t*>(seqlen_q_ptr),
reinterpret_cast<const int32_t*>(seqlen_k_ptr)};
if constexpr(kIsDeterministic)
{
@@ -1634,11 +1660,20 @@ struct FmhaBwdConvertQGradKernel
// get real # queries & # keys under group mode
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
const ck_tile::index_t physical_seqlen_q =
adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
const ck_tile::index_t logical_seqlen_q =
kargs.seqlen_q_ptr ? static_cast<ck_tile::index_t>(kargs.seqlen_q_ptr[i_batch])
: physical_seqlen_q;
kargs.seqlen_q = logical_seqlen_q;
if constexpr(kIsDeterministic)
{
const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch;
kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0];
const ck_tile::index_t physical_seqlen_k =
adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0];
kargs.seqlen_k = kargs.seqlen_k_ptr
? static_cast<ck_tile::index_t>(kargs.seqlen_k_ptr[i_batch])
: physical_seqlen_k;
}
// # of required blocks is different in each groups, terminate unnecessary blocks
// earlier

View File

@@ -1137,8 +1137,8 @@ struct FmhaFwdKernel
}
if constexpr(kStoreLSE)
{
// LSE stays indexed by unpadded starts
batch_offset_lse = query_start_unpadded;
// LSE follows the padded layout to stay consistent with other tensors
batch_offset_lse = query_start_padded;
}
if constexpr(kHasDropout)
{
@@ -1630,8 +1630,8 @@ struct FmhaFwdKernel
batch_offset_bias = query_start_padded * kargs.stride_bias;
}
// LSE layout is [nhead, total_seqlen], index by unpadded start
batch_offset_lse = query_start_unpadded;
// LSE layout is [nhead, total_seqlen] following the padded layout for Q/O
batch_offset_lse = query_start_padded;
batch_offset_o = query_start_padded * kargs.stride_o;
// get real # queries & # keys under group mode

View File

@@ -5,6 +5,11 @@ endif()
set(FMHA_BWD_INSTANCES "tile_fmha_bwd_instances")
set(FMHA_FWD_INSTANCES "tile_fmha_fwd_instances")
add_gtest_executable(test_ck_tile_fmha_bwd_kernels test_fmha_bwd_kernel_padding.cpp)
target_link_libraries(test_ck_tile_fmha_bwd_kernels PRIVATE ${FMHA_BWD_INSTANCES})
set(TEST_NAME "test_ck_tile_fmha")
function(add_gtest_fwd test_group)

View File

@@ -77,6 +77,8 @@ void fmha_bwd_test(const FmhaBwdTestParam& param)
nhead_k,
{seqlen_q},
{seqlen_k},
{-1},
{-1},
hdim_q,
hdim_v,
i_perm,

View File

@@ -0,0 +1,707 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <hip/hip_runtime.h>
#include <cmath>
#include <functional>
#include <initializer_list>
#include <vector>
#include "ck_tile/host.hpp"
#include "ck_tile/host/device_memory.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "example/ck_tile/01_fmha/fmha_bwd.hpp"
#include "example/ck_tile/01_fmha/fmha_bwd_runner.hpp" // for get_elimit
#include "gtest/gtest.h"
namespace {
using bf16 = ck_tile::bf16_t;
using ck_tile::DeviceMem;
const ck_tile::stream_config kStreamConfig{
nullptr, // stream_id_
false, // time_kernel_
1, // log_level_
0, // cold_niters_
1, // nrepeat_
true, // is_gpu_timer_
false, // flush_cache_
1, // rotating_count_
};
template <typename T>
std::vector<T> MakeVectorFromFunction(size_t count, std::function<float(size_t)> fn)
{
std::vector<T> data(count);
for(size_t i = 0; i < count; ++i)
{
data[i] = static_cast<T>(fn(i));
}
return data;
}
template <typename T>
std::vector<float> ToFloatVector(const std::vector<T>& src)
{
std::vector<float> dst(src.size());
for(size_t i = 0; i < src.size(); ++i)
{
dst[i] = ck_tile::type_convert<float>(src[i]);
}
return dst;
}
template <typename T>
std::vector<T> CopyDeviceToHost(const ck_tile::DeviceMem& dev, size_t element_count)
{
std::vector<T> host(element_count);
if(element_count > 0)
{
dev.FromDevice(host.data());
}
return host;
}
float SentinelValue() { return -999.f; }
} // namespace
// Typed tests over {fp32, fp16, bf16}
template <typename DataTypeConfig>
class FmhaBwdKernelPaddingTyped : public ::testing::Test
{
};
using KernelPaddingTypes = ::testing::Types<FmhaBwdFp32, FmhaBwdFp16, FmhaBwdBf16>;
TYPED_TEST_SUITE(FmhaBwdKernelPaddingTyped, KernelPaddingTypes);
TYPED_TEST(FmhaBwdKernelPaddingTyped, OGradDotO_GroupPaddingRespectsLogicalLengths)
{
constexpr ck_tile::index_t batch = 2;
constexpr ck_tile::index_t nhead = 1;
constexpr ck_tile::index_t hdim = 128;
constexpr ck_tile::index_t phys_rows0 = 8;
constexpr ck_tile::index_t phys_rows1 = 8;
constexpr ck_tile::index_t max_phys = phys_rows0; // both batches equal
const std::vector<int32_t> seqstart_q_host{0, phys_rows0, phys_rows0 + phys_rows1};
const std::vector<int32_t> seqlen_q_host{5, 3};
const ck_tile::index_t total_rows = seqstart_q_host.back();
// Types per config
using TypeConfig = FmhaBwdTypeConfig<TypeParam>;
using OType = typename TypeConfig::ODataType;
using DOType = typename TypeConfig::OGradDataType;
using DType = typename TypeConfig::DDataType; // float under bf16 config
using AccType = typename TypeConfig::AccDataType; // float
// Host tensors laid out as [b, h, s, d] with b=1, h=1 under group mode
ck_tile::HostTensor<OType> o_host({1, nhead, total_rows, hdim});
ck_tile::HostTensor<DOType> do_host({1, nhead, total_rows, hdim});
ck_tile::HostTensor<DType> d_init_host({1, nhead, total_rows});
// Initialize O/dO with constants using FillConstant (no manual bf16 casts)
const float o_const = 0.25f;
const float do_const = 0.5f;
ck_tile::FillConstant<OType>{ck_tile::type_convert<OType>(o_const)}(o_host);
ck_tile::FillConstant<DOType>{ck_tile::type_convert<DOType>(do_const)}(do_host);
ck_tile::FillConstant<DType>{ck_tile::type_convert<DType>(SentinelValue())}(d_init_host);
// Prepare expected D via runner-style CPU reference, sentinel elsewhere
std::vector<float> expected(static_cast<size_t>(total_rows), SentinelValue());
for(ck_tile::index_t b = 0; b < batch; ++b)
{
const ck_tile::index_t start = seqstart_q_host[b];
const ck_tile::index_t len = seqlen_q_host[b];
for(ck_tile::index_t row = 0; row < len; ++row)
{
AccType acc = 0;
for(ck_tile::index_t c = 0; c < hdim; ++c)
{
// o_host/do_host are [1, nhead, s, d]
const auto o_val = ck_tile::type_convert<AccType>(o_host(0, 0, start + row, c));
const auto do_val = ck_tile::type_convert<AccType>(do_host(0, 0, start + row, c));
acc += do_val * o_val;
}
expected[start + row] = ck_tile::type_convert<float>(acc);
}
}
std::vector<float> sentinel_ref(static_cast<size_t>(total_rows), SentinelValue());
// Device buffers
ck_tile::DeviceMem o_dev(o_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem do_dev(do_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem d_dev(d_init_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem seqstart_dev(seqstart_q_host.size() * sizeof(int32_t));
ck_tile::DeviceMem seqlen_dev(seqlen_q_host.size() * sizeof(int32_t));
o_dev.ToDevice(o_host.data());
do_dev.ToDevice(do_host.data());
d_dev.ToDevice(d_init_host.data());
seqstart_dev.ToDevice(seqstart_q_host.data());
seqlen_dev.ToDevice(seqlen_q_host.data());
fmha_bwd_args args{};
args.q_ptr = nullptr;
args.k_ptr = nullptr;
args.v_ptr = nullptr;
args.bias_ptr = nullptr;
args.o_ptr = o_dev.GetDeviceBuffer();
args.lse_ptr = nullptr;
args.do_ptr = do_dev.GetDeviceBuffer();
args.d_ptr = d_dev.GetDeviceBuffer();
args.rand_val_ptr = nullptr;
args.dq_ptr = nullptr;
args.dk_ptr = nullptr;
args.dv_ptr = nullptr;
args.dbias_ptr = nullptr;
args.dq_acc_ptr = nullptr;
args.seqstart_q_ptr = seqstart_dev.GetDeviceBuffer();
args.seqstart_k_ptr = nullptr;
args.seqlen_k_ptr = nullptr;
args.seqlen_q_ptr = seqlen_dev.GetDeviceBuffer();
args.seqlen_q = 0;
args.seqlen_k = 0;
args.batch = batch;
args.max_seqlen_q = max_phys;
args.max_seqlen_k = 0;
args.hdim_q = hdim;
args.hdim_v = hdim;
args.nhead_q = nhead;
args.nhead_k = nhead;
args.scale = 1.0f;
args.stride_q = 0;
args.stride_k = 0;
args.stride_v = 0;
args.stride_bias = 0;
args.stride_o = hdim;
args.stride_randval = 0;
args.stride_do = hdim;
args.stride_dq_acc = 0;
args.stride_dq = 0;
args.stride_dk = 0;
args.stride_dv = 0;
args.stride_dbias = 0;
args.nhead_stride_q = 0;
args.nhead_stride_k = 0;
args.nhead_stride_v = 0;
args.nhead_stride_bias = 0;
args.nhead_stride_o = max_phys * hdim;
args.nhead_stride_randval = 0;
args.nhead_stride_do = max_phys * hdim;
args.nhead_stride_lsed = max_phys;
args.nhead_stride_dq_acc = 0;
args.nhead_stride_dq = 0;
args.nhead_stride_dk = 0;
args.nhead_stride_dv = 0;
args.nhead_stride_dbias = 0;
args.batch_stride_q = 0;
args.batch_stride_k = 0;
args.batch_stride_v = 0;
args.batch_stride_bias = 0;
args.batch_stride_o = 0;
args.batch_stride_randval = 0;
args.batch_stride_do = 0;
args.batch_stride_lsed = 0;
args.batch_stride_dq_acc = 0;
args.batch_stride_dq = 0;
args.batch_stride_dk = 0;
args.batch_stride_dv = 0;
args.batch_stride_dbias = 0;
args.split_stride_dq_acc = 0;
args.window_size_left = -1;
args.window_size_right = 0;
args.mask_type = static_cast<ck_tile::index_t>(mask_enum::no_mask);
args.p_drop = 0.0f;
args.p_undrop = 1.0f;
args.drop_seed_offset = std::make_pair(uint64_t{0}, uint64_t{0});
using DotTileTraits = ck_tile::TileFmhaBwdOGradDotOTraits<true, true, 2>;
using DotProblem = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem<
typename TypeConfig::ODataType,
typename TypeConfig::OGradDataType,
typename TypeConfig::DDataType,
64,
hdim,
true,
DotTileTraits>;
using DotPipeline = ck_tile::BlockFmhaBwdOGradDotO<DotProblem>;
using DotKernel = ck_tile::FmhaBwdOGradDotOKernel<DotPipeline>;
auto [dot_kargs, dot_grids] = fmha_bwd_dot_do_o_create_kargs_and_grids<DotKernel>(args);
const dim3 dot_blocks = DotKernel::BlockSize();
constexpr ck_tile::index_t kDotBlockPerCu = DotKernel::kBlockPerCu;
auto dot_kernel = ck_tile::make_kernel<kDotBlockPerCu>(
DotKernel{}, dot_grids, dot_blocks, 0, dot_kargs);
dot_kernel(kStreamConfig);
ASSERT_EQ(hipDeviceSynchronize(), hipSuccess);
auto d_result_host = CopyDeviceToHost<float>(d_dev, total_rows);
auto [rtol_doto, atol_doto] = get_elimit<TypeParam>(hdim, hdim);
for(size_t i = 0; i < d_result_host.size(); ++i)
{
SCOPED_TRACE(::testing::Message() << "index=" << i);
if(std::fabs(expected[i] - sentinel_ref[i]) < 1e-6f)
{
EXPECT_FLOAT_EQ(d_result_host[i], sentinel_ref[i]);
}
else
{
EXPECT_NEAR(d_result_host[i], expected[i], static_cast<float>(atol_doto));
}
}
}
TYPED_TEST(FmhaBwdKernelPaddingTyped, OGradDotO_VariedPhysicalAndZeroLogical)
{
constexpr ck_tile::index_t batch = 3;
constexpr ck_tile::index_t nhead = 1;
constexpr ck_tile::index_t hdim = 64;
constexpr ck_tile::index_t phys_r0 = 5;
constexpr ck_tile::index_t phys_r1 = 7;
constexpr ck_tile::index_t phys_r2 = 4;
constexpr ck_tile::index_t max_phys = phys_r1;
const std::vector<int32_t> seqstart_q_host{0, phys_r0, phys_r0 + phys_r1, phys_r0 + phys_r1 + phys_r2};
const std::vector<int32_t> seqlen_q_host{3, 0, 4};
const ck_tile::index_t total_rows = seqstart_q_host.back();
using TypeConfig = FmhaBwdTypeConfig<TypeParam>;
using OType = typename TypeConfig::ODataType;
using DOType = typename TypeConfig::OGradDataType;
using DType = typename TypeConfig::DDataType;
ck_tile::HostTensor<OType> o_host({1, nhead, total_rows, hdim});
ck_tile::HostTensor<DOType> do_host({1, nhead, total_rows, hdim});
ck_tile::HostTensor<DType> d_init_host({1, nhead, total_rows});
ck_tile::FillConstant<OType>{ck_tile::type_convert<OType>(1.0f)}(o_host);
ck_tile::FillConstant<DOType>{ck_tile::type_convert<DOType>(2.0f)}(do_host);
ck_tile::FillConstant<DType>{ck_tile::type_convert<DType>(SentinelValue())}(d_init_host);
std::vector<float> expected(static_cast<size_t>(total_rows), SentinelValue());
const float dot = 2.0f * 1.0f * static_cast<float>(hdim);
for(ck_tile::index_t b = 0; b < batch; ++b)
{
const ck_tile::index_t start = seqstart_q_host[b];
const ck_tile::index_t len = seqlen_q_host[b];
for(ck_tile::index_t row = 0; row < len; ++row) expected[start + row] = dot;
}
std::vector<float> sentinel_ref(static_cast<size_t>(total_rows), SentinelValue());
ck_tile::DeviceMem o_dev(o_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem do_dev(do_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem d_dev(d_init_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem seqstart_dev(seqstart_q_host.size() * sizeof(int32_t));
ck_tile::DeviceMem seqlen_dev(seqlen_q_host.size() * sizeof(int32_t));
o_dev.ToDevice(o_host.data());
do_dev.ToDevice(do_host.data());
d_dev.ToDevice(d_init_host.data());
seqstart_dev.ToDevice(seqstart_q_host.data());
seqlen_dev.ToDevice(seqlen_q_host.data());
fmha_bwd_args args{};
args.o_ptr = o_dev.GetDeviceBuffer();
args.do_ptr = do_dev.GetDeviceBuffer();
args.d_ptr = d_dev.GetDeviceBuffer();
args.seqstart_q_ptr = seqstart_dev.GetDeviceBuffer();
args.seqlen_q_ptr = seqlen_dev.GetDeviceBuffer();
args.batch = batch;
args.max_seqlen_q = max_phys;
args.hdim_v = hdim;
args.nhead_q = nhead;
args.nhead_k = nhead;
args.stride_o = hdim;
args.stride_do = hdim;
args.nhead_stride_o = max_phys * hdim;
args.nhead_stride_do = max_phys * hdim;
args.nhead_stride_lsed = max_phys;
args.p_undrop = 1.0f;
using DotTileTraits = ck_tile::TileFmhaBwdOGradDotOTraits<true, true, 2>;
using DotProblem = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem<
typename TypeConfig::ODataType,
typename TypeConfig::OGradDataType,
typename TypeConfig::DDataType,
64,
hdim,
true,
DotTileTraits>;
using DotPipeline = ck_tile::BlockFmhaBwdOGradDotO<DotProblem>;
using DotKernel = ck_tile::FmhaBwdOGradDotOKernel<DotPipeline>;
auto [dot_kargs, dot_grids] = fmha_bwd_dot_do_o_create_kargs_and_grids<DotKernel>(args);
const dim3 dot_blocks = DotKernel::BlockSize();
constexpr ck_tile::index_t kDotBlockPerCu = DotKernel::kBlockPerCu;
auto dot_kernel = ck_tile::make_kernel<kDotBlockPerCu>(DotKernel{}, dot_grids, dot_blocks, 0, dot_kargs);
dot_kernel(kStreamConfig);
ASSERT_EQ(hipDeviceSynchronize(), hipSuccess);
auto d_result_host = CopyDeviceToHost<float>(d_dev, total_rows);
auto [rtol, atol] = get_elimit<TypeParam>(hdim, hdim);
for(size_t i = 0; i < d_result_host.size(); ++i)
{
SCOPED_TRACE(::testing::Message() << "index=" << i);
if(std::fabs(expected[i] - sentinel_ref[i]) < 1e-6f)
EXPECT_FLOAT_EQ(d_result_host[i], sentinel_ref[i]);
else
EXPECT_NEAR(d_result_host[i], expected[i], static_cast<float>(atol));
}
}
TYPED_TEST(FmhaBwdKernelPaddingTyped, OGradDotO_VariedPhysical_NoLogicalPtr)
{
constexpr ck_tile::index_t batch = 3;
constexpr ck_tile::index_t nhead = 1;
constexpr ck_tile::index_t hdim = 64;
constexpr ck_tile::index_t phys_r0 = 5;
constexpr ck_tile::index_t phys_r1 = 7;
constexpr ck_tile::index_t phys_r2 = 4;
constexpr ck_tile::index_t max_phys = phys_r1;
const std::vector<int32_t> seqstart_q_host{0, phys_r0, phys_r0 + phys_r1, phys_r0 + phys_r1 + phys_r2};
const ck_tile::index_t total_rows = seqstart_q_host.back();
using TypeConfig = FmhaBwdTypeConfig<TypeParam>;
using OType = typename TypeConfig::ODataType;
using DOType = typename TypeConfig::OGradDataType;
using DType = typename TypeConfig::DDataType;
ck_tile::HostTensor<OType> o_host({1, nhead, total_rows, hdim});
ck_tile::HostTensor<DOType> do_host({1, nhead, total_rows, hdim});
ck_tile::HostTensor<DType> d_init_host({1, nhead, total_rows});
ck_tile::FillConstant<OType>{ck_tile::type_convert<OType>(1.0f)}(o_host);
ck_tile::FillConstant<DOType>{ck_tile::type_convert<DOType>(2.0f)}(do_host);
ck_tile::FillConstant<DType>{ck_tile::type_convert<DType>(SentinelValue())}(d_init_host);
std::vector<float> expected(static_cast<size_t>(total_rows), SentinelValue());
const float dot = 2.0f * 1.0f * static_cast<float>(hdim);
// seqlen_q_ptr is null; logical lengths equal physical lengths per group
for(int r = 0; r < phys_r0; ++r) expected[0 + r] = dot;
for(int r = 0; r < phys_r1; ++r) expected[phys_r0 + r] = dot;
for(int r = 0; r < phys_r2; ++r) expected[phys_r0 + phys_r1 + r] = dot;
std::vector<float> sentinel_ref(static_cast<size_t>(total_rows), SentinelValue());
ck_tile::DeviceMem o_dev(o_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem do_dev(do_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem d_dev(d_init_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem seqstart_dev(seqstart_q_host.size() * sizeof(int32_t));
o_dev.ToDevice(o_host.data());
do_dev.ToDevice(do_host.data());
d_dev.ToDevice(d_init_host.data());
seqstart_dev.ToDevice(seqstart_q_host.data());
fmha_bwd_args args{};
args.o_ptr = o_dev.GetDeviceBuffer();
args.do_ptr = do_dev.GetDeviceBuffer();
args.d_ptr = d_dev.GetDeviceBuffer();
args.seqstart_q_ptr = seqstart_dev.GetDeviceBuffer();
args.seqlen_q_ptr = nullptr; // no logical len ptr
args.batch = batch;
args.max_seqlen_q = max_phys;
args.hdim_v = hdim;
args.nhead_q = nhead;
args.nhead_k = nhead;
args.stride_o = hdim;
args.stride_do = hdim;
args.nhead_stride_o = max_phys * hdim;
args.nhead_stride_do = max_phys * hdim;
args.nhead_stride_lsed = max_phys;
args.p_undrop = 1.0f;
using DotTileTraits = ck_tile::TileFmhaBwdOGradDotOTraits<true, true, 2>;
using DotProblem = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem<
typename TypeConfig::ODataType,
typename TypeConfig::OGradDataType,
typename TypeConfig::DDataType,
64,
hdim,
true,
DotTileTraits>;
using DotPipeline = ck_tile::BlockFmhaBwdOGradDotO<DotProblem>;
using DotKernel = ck_tile::FmhaBwdOGradDotOKernel<DotPipeline>;
auto [dot_kargs, dot_grids] = fmha_bwd_dot_do_o_create_kargs_and_grids<DotKernel>(args);
const dim3 dot_blocks = DotKernel::BlockSize();
constexpr ck_tile::index_t kDotBlockPerCu = DotKernel::kBlockPerCu;
auto dot_kernel = ck_tile::make_kernel<kDotBlockPerCu>(DotKernel{}, dot_grids, dot_blocks, 0, dot_kargs);
dot_kernel(kStreamConfig);
ASSERT_EQ(hipDeviceSynchronize(), hipSuccess);
auto d_result_host = CopyDeviceToHost<float>(d_dev, total_rows);
auto [rtol, atol] = get_elimit<TypeParam>(hdim, hdim);
for(size_t i = 0; i < d_result_host.size(); ++i)
{
SCOPED_TRACE(::testing::Message() << "index=" << i);
if(std::fabs(expected[i] - sentinel_ref[i]) < 1e-6f)
EXPECT_FLOAT_EQ(d_result_host[i], sentinel_ref[i]);
else
EXPECT_NEAR(d_result_host[i], expected[i], static_cast<float>(atol));
}
}
TYPED_TEST(FmhaBwdKernelPaddingTyped, ConvertQGrad_GroupPaddingAndZeroLength)
{
constexpr ck_tile::index_t batch = 3;
constexpr ck_tile::index_t nhead = 1;
constexpr ck_tile::index_t hdim = 128;
const std::vector<int32_t> seqstart_q_host{0, 6, 6, 10}; // physical lengths: 6,0,4
const std::vector<int32_t> seqlen_q_host{4, 0, 3};
const std::vector<int32_t> seqstart_k_host{0, 7, 15, 18};
const std::vector<int32_t> seqlen_k_host{5, 8, 3};
const ck_tile::index_t total_rows_q = seqstart_q_host.back();
using TypeConfigC = FmhaBwdTypeConfig<TypeParam>;
using AccType = typename TypeConfigC::AccDataType; // float
using QGradType = typename TypeConfigC::QGradDataType; // bf16
ck_tile::HostTensor<AccType> dq_acc_host({1, nhead, total_rows_q, hdim});
ck_tile::HostTensor<QGradType> dq_host_init({1, nhead, total_rows_q, hdim});
const float dq_acc_const = 1.25f;
ck_tile::FillConstant<AccType>{ck_tile::type_convert<AccType>(dq_acc_const)}(dq_acc_host);
ck_tile::FillConstant<QGradType>{ck_tile::type_convert<QGradType>(SentinelValue())}(dq_host_init);
const float dq_sentinel_val = ck_tile::type_convert<float>(
ck_tile::type_convert<QGradType>(SentinelValue()));
std::vector<float> dq_sentinel_ref(static_cast<size_t>(total_rows_q * hdim),
dq_sentinel_val);
std::vector<float> expected = dq_sentinel_ref;
for(ck_tile::index_t b = 0; b < batch; ++b)
{
const ck_tile::index_t q_start = seqstart_q_host[b];
const ck_tile::index_t q_len = seqlen_q_host[b];
for(ck_tile::index_t row = 0; row < q_len; ++row)
{
for(ck_tile::index_t c = 0; c < hdim; ++c)
{
const size_t idx = (q_start + row) * hdim + c;
// dq_acc_host is [1, nhead, s, d]
expected[idx] = ck_tile::type_convert<float>(dq_acc_host(0, 0, q_start + row, c));
}
}
}
ck_tile::DeviceMem dq_acc_dev(dq_acc_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem dq_dev(dq_host_init.get_element_space_size_in_bytes());
ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t));
ck_tile::DeviceMem seqlen_q_dev(seqlen_q_host.size() * sizeof(int32_t));
ck_tile::DeviceMem seqlen_k_dev(seqlen_k_host.size() * sizeof(int32_t));
dq_acc_dev.ToDevice(dq_acc_host.data());
dq_dev.ToDevice(dq_host_init.data());
seqstart_q.ToDevice(seqstart_q_host.data());
seqstart_k.ToDevice(seqstart_k_host.data());
seqlen_q_dev.ToDevice(seqlen_q_host.data());
seqlen_k_dev.ToDevice(seqlen_k_host.data());
fmha_bwd_args args{};
args.dq_acc_ptr = dq_acc_dev.GetDeviceBuffer();
args.dq_ptr = dq_dev.GetDeviceBuffer();
args.seqstart_q_ptr = seqstart_q.GetDeviceBuffer();
args.seqstart_k_ptr = seqstart_k.GetDeviceBuffer();
args.seqlen_q_ptr = seqlen_q_dev.GetDeviceBuffer();
args.seqlen_k_ptr = seqlen_k_dev.GetDeviceBuffer();
args.batch = batch;
args.nhead_q = nhead;
args.nhead_k = nhead;
args.hdim_q = hdim;
args.hdim_v = hdim;
args.max_seqlen_q = 6;
args.max_seqlen_k = 8;
args.stride_dq_acc = hdim;
args.stride_dq = hdim;
args.nhead_stride_dq_acc = hdim * args.max_seqlen_q;
args.nhead_stride_dq = hdim * args.max_seqlen_q;
args.nhead_stride_q = 0;
args.nhead_stride_k = 0;
args.nhead_stride_v = 0;
args.nhead_stride_o = 0;
args.batch_stride_dq_acc = 0;
args.batch_stride_dq = 0;
args.split_stride_dq_acc = args.max_seqlen_q * args.stride_dq_acc;
args.window_size_left = -1;
args.window_size_right = 0;
args.mask_type = static_cast<ck_tile::index_t>(mask_enum::no_mask);
args.p_drop = 0.0f;
args.p_undrop = 1.0f;
args.drop_seed_offset = std::make_pair(uint64_t{0}, uint64_t{0});
using TypeConfig = FmhaBwdTypeConfig<TypeParam>;
using ConvertTileTraits = ck_tile::TileFmhaBwdConvertQGradTraits<true, true, 2>;
using ConvertProblem = ck_tile::BlockFmhaBwdConvertQGradPipelineProblem<
typename TypeConfig::AccDataType,
typename TypeConfig::QGradDataType,
256,
64,
0,
hdim,
true,
false,
ConvertTileTraits>;
using ConvertPipeline = ck_tile::BlockFmhaBwdConvertQGrad<ConvertProblem>;
using ConvertKernel = ck_tile::FmhaBwdConvertQGradKernel<ConvertPipeline>;
auto [convert_kargs, convert_grids] =
fmha_bwd_convert_dq_create_kargs_and_grids<ConvertKernel>(args);
const dim3 convert_blocks = ConvertKernel::BlockSize();
constexpr ck_tile::index_t kConvertBlockPerCu = ConvertKernel::kBlockPerCu;
auto convert_kernel = ck_tile::make_kernel<kConvertBlockPerCu>(
ConvertKernel{}, convert_grids, convert_blocks, 0, convert_kargs);
convert_kernel(kStreamConfig);
ASSERT_EQ(hipDeviceSynchronize(), hipSuccess);
using QGradOutT = typename TypeConfigC::QGradDataType;
auto dq_result_host_t = CopyDeviceToHost<QGradOutT>(dq_dev, total_rows_q * hdim);
auto dq_result_host = ToFloatVector(dq_result_host_t);
auto [rtol_gpad, atol_gpad] = get_elimit<TypeParam>(hdim, hdim);
for(size_t i = 0; i < dq_result_host.size(); ++i)
{
SCOPED_TRACE(::testing::Message() << "index=" << i);
if(std::fabs(expected[i] - dq_sentinel_ref[i]) < 1e-6f)
{
EXPECT_FLOAT_EQ(dq_result_host[i], dq_sentinel_ref[i]);
}
else
{
EXPECT_NEAR(dq_result_host[i], expected[i], static_cast<float>(atol_gpad));
}
}
}
TYPED_TEST(FmhaBwdKernelPaddingTyped, ConvertQGrad_DeterministicPaddingUsesLogicalLengths)
{
constexpr ck_tile::index_t batch = 1;
constexpr ck_tile::index_t nhead = 1;
constexpr ck_tile::index_t hdim = 128;
constexpr ck_tile::index_t phys_rows = 8;
constexpr ck_tile::index_t logical_rows = 5;
constexpr ck_tile::index_t phys_k = 24;
constexpr ck_tile::index_t logical_k = 20;
constexpr ck_tile::index_t kN0 = 16;
constexpr ck_tile::index_t nsplits = (logical_k + kN0 - 1) / kN0;
const std::vector<int32_t> seqstart_q_host{0, phys_rows};
const std::vector<int32_t> seqlen_q_host{logical_rows};
const std::vector<int32_t> seqstart_k_host{0, phys_k};
const std::vector<int32_t> seqlen_k_host{logical_k};
const ck_tile::index_t total_rows_q = seqstart_q_host.back();
using TypeConfigD = FmhaBwdTypeConfig<TypeParam>;
using AccTypeDet = typename TypeConfigD::AccDataType; // float
using QGradTypeD = typename TypeConfigD::QGradDataType; // bf16
ck_tile::HostTensor<AccTypeDet> dq_acc_host({nsplits, 1, nhead, phys_rows, hdim});
dq_acc_host.ForEach([&](auto& self, auto idx) {
const float s = static_cast<float>(idx[0]);
// Use split-dependent constant to avoid per-element variance and rounding interplay
self(idx) = ck_tile::type_convert<AccTypeDet>(1.0f + 0.1f * s);
});
const float dq_sentinel_val_det = ck_tile::type_convert<float>(
ck_tile::type_convert<QGradTypeD>(SentinelValue()));
std::vector<float> expected(total_rows_q * hdim, dq_sentinel_val_det);
// Expected is the sum over splits of the constant (1.0 + 0.1*s)
for(ck_tile::index_t row = 0; row < logical_rows; ++row)
for(ck_tile::index_t c = 0; c < hdim; ++c)
{
float acc = 0.f;
for(ck_tile::index_t s = 0; s < nsplits; ++s)
{
acc += (1.0f + 0.1f * static_cast<float>(s));
}
expected[row * hdim + c] = acc;
}
ck_tile::HostTensor<QGradTypeD> dq_init({1, nhead, total_rows_q, hdim});
ck_tile::FillConstant<QGradTypeD>{ck_tile::type_convert<QGradTypeD>(SentinelValue())}(dq_init);
DeviceMem dq_acc_dev(dq_acc_host.get_element_space_size_in_bytes());
DeviceMem dq_dev(dq_init.get_element_space_size_in_bytes());
DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t));
DeviceMem seqlen_q_dev(seqlen_q_host.size() * sizeof(int32_t));
DeviceMem seqlen_k_dev(seqlen_k_host.size() * sizeof(int32_t));
dq_acc_dev.ToDevice(dq_acc_host.data());
dq_dev.ToDevice(dq_init.data());
seqstart_q.ToDevice(seqstart_q_host.data());
seqstart_k.ToDevice(seqstart_k_host.data());
seqlen_q_dev.ToDevice(seqlen_q_host.data());
seqlen_k_dev.ToDevice(seqlen_k_host.data());
fmha_bwd_args args{};
args.dq_acc_ptr = dq_acc_dev.GetDeviceBuffer();
args.dq_ptr = dq_dev.GetDeviceBuffer();
args.seqstart_q_ptr = seqstart_q.GetDeviceBuffer();
args.seqstart_k_ptr = seqstart_k.GetDeviceBuffer();
args.seqlen_q_ptr = seqlen_q_dev.GetDeviceBuffer();
args.seqlen_k_ptr = seqlen_k_dev.GetDeviceBuffer();
args.batch = batch;
args.nhead_q = nhead;
args.nhead_k = nhead;
args.hdim_q = hdim;
args.hdim_v = hdim;
args.max_seqlen_q = phys_rows;
args.max_seqlen_k = phys_k;
args.stride_dq_acc = hdim;
args.stride_dq = hdim;
args.nhead_stride_dq_acc = phys_rows * hdim;
args.nhead_stride_dq = phys_rows * hdim;
args.split_stride_dq_acc = phys_rows * hdim;
args.window_size_left = -1;
args.window_size_right = 0;
args.mask_type = static_cast<ck_tile::index_t>(mask_enum::no_mask);
args.p_drop = 0.0f;
args.p_undrop = 1.0f;
args.drop_seed_offset = std::make_pair(uint64_t{0}, uint64_t{0});
using TypeConfig = FmhaBwdTypeConfig<TypeParam>;
using TileTraitsDet = ck_tile::TileFmhaBwdConvertQGradTraits<true, true, 2>;
using PipelineProblemDet = ck_tile::BlockFmhaBwdConvertQGradPipelineProblem<
typename TypeConfig::AccDataType,
typename TypeConfig::QGradDataType,
256,
64,
kN0,
hdim,
true,
true,
TileTraitsDet>;
using PipelineDet = ck_tile::BlockFmhaBwdConvertQGrad<PipelineProblemDet>;
using ConvertKernelDet = ck_tile::FmhaBwdConvertQGradKernel<PipelineDet>;
auto [convert_kargs, convert_grids] =
fmha_bwd_convert_dq_create_kargs_and_grids<ConvertKernelDet>(args);
const dim3 convert_blocks = ConvertKernelDet::BlockSize();
constexpr ck_tile::index_t kConvertBlockPerCu = ConvertKernelDet::kBlockPerCu;
auto convert_kernel = ck_tile::make_kernel<kConvertBlockPerCu>(
ConvertKernelDet{}, convert_grids, convert_blocks, 0, convert_kargs);
convert_kernel(kStreamConfig);
ASSERT_EQ(hipDeviceSynchronize(), hipSuccess);
using QGradOutTD = typename TypeConfigD::QGradDataType;
auto dq_result_host_t = CopyDeviceToHost<QGradOutTD>(dq_dev, total_rows_q * hdim);
auto dq_result_host = ToFloatVector(dq_result_host_t);
const float dq_sentinel_val2 = dq_sentinel_val_det;
auto [rtol_det, atol_det] = get_elimit<TypeParam>(hdim, hdim);
for(size_t i = 0; i < dq_result_host.size(); ++i)
{
SCOPED_TRACE(::testing::Message() << "index=" << i);
if(std::fabs(expected[i] - dq_sentinel_val2) < 1e-6f)
{
EXPECT_FLOAT_EQ(dq_result_host[i], dq_sentinel_val2);
}
else
{
EXPECT_NEAR(dq_result_host[i], expected[i], static_cast<float>(atol_det));
}
}
}