mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 11:47:48 +00:00
[CK_TILE] fmha: Add query padding support to backward pass
Introduces support for query sequence padding (q_padding) in the FMHA backward pass kernels. - Passing `seqlen_q_ptr` to the backward kernels to distinguish logical from physical sequence lengths. - Updating `OGradDotO`, `ConvertQGrad`, and `DQDKDV` kernels to respect logical lengths and handle zero-length sequences. - Aligning LSE indexing in the forward kernel with the padded layout for consistency. - Adding a new GTest suite (`test_fmha_bwd_kernel_padding.cpp`) with comprehensive tests for various padding scenarios, including zero-length sequences and deterministic mode.
This commit is contained in:
@@ -24,11 +24,19 @@ auto create_args(int argc, char* argv[])
|
||||
"total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary\n"
|
||||
"also with \"-s=s0,s1,s2...\" comma-separated ints to set seqlen per batch "
|
||||
"(group mode)")
|
||||
.insert("s_qpad",
|
||||
"-1",
|
||||
"padded seqlen_q per batch (group mode only). "
|
||||
"Use \"-s_qpad=p0,p1,...\"; -1 disables explicit padding")
|
||||
.insert("s_k",
|
||||
"-1",
|
||||
"seqlen_k, -1 means equal to s\n"
|
||||
"also with \"-s_k=s0,s1,s2...\" comma-separated ints to set seqlen per batch "
|
||||
"(group mode)")
|
||||
.insert("s_kpad",
|
||||
"-1",
|
||||
"padded seqlen_k per batch (group mode only). "
|
||||
"Use \"-s_kpad=k0,k1,...\"; -1 disables explicit padding")
|
||||
.insert("d", "128", "head dim for q, k")
|
||||
.insert("d_v", "-1", "head dim for v, -1 means equal to d")
|
||||
.insert("scale", "0", "scale factor. 0 means equal to 1/sqrt(hdim)")
|
||||
@@ -96,7 +104,9 @@ auto run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::index_t nhead = arg_parser.get_int("h");
|
||||
ck_tile::index_t nhead_k = arg_parser.get_int("h_k");
|
||||
auto seqlen_qs = arg_parser.get_int_vec("s");
|
||||
auto seqlen_qpads = arg_parser.get_int_vec("s_qpad");
|
||||
auto seqlen_ks = arg_parser.get_int_vec("s_k");
|
||||
auto seqlen_kpads = arg_parser.get_int_vec("s_kpad");
|
||||
ck_tile::index_t hdim_q = arg_parser.get_int("d");
|
||||
ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
|
||||
bool i_perm = arg_parser.get_bool("iperm");
|
||||
@@ -130,6 +140,8 @@ auto run(const ck_tile::ArgParser& arg_parser)
|
||||
nhead_k,
|
||||
seqlen_qs,
|
||||
seqlen_ks,
|
||||
seqlen_qpads,
|
||||
seqlen_kpads,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
i_perm,
|
||||
|
||||
@@ -116,6 +116,7 @@ struct fmha_bwd_args
|
||||
void* dq_acc_ptr;
|
||||
const void* seqstart_q_ptr;
|
||||
const void* seqstart_k_ptr;
|
||||
const void* seqlen_q_ptr;
|
||||
const void* seqlen_k_ptr;
|
||||
ck_tile::index_t seqlen_q;
|
||||
ck_tile::index_t seqlen_k;
|
||||
@@ -203,6 +204,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
|
||||
dq_ptr,
|
||||
args.seqstart_q_ptr,
|
||||
args.seqstart_k_ptr,
|
||||
args.seqlen_q_ptr,
|
||||
args.seqlen_k_ptr,
|
||||
args.hdim_q,
|
||||
args.hdim_v,
|
||||
@@ -315,6 +317,7 @@ auto fmha_bwd_dot_do_o_create_kargs_and_grids(fmha_bwd_args args)
|
||||
args.d_ptr,
|
||||
args.p_undrop,
|
||||
args.seqstart_q_ptr,
|
||||
args.seqlen_q_ptr,
|
||||
args.hdim_v,
|
||||
args.stride_do,
|
||||
args.stride_o,
|
||||
@@ -356,6 +359,8 @@ auto fmha_bwd_convert_dq_create_kargs_and_grids(fmha_bwd_args args)
|
||||
args.dq_ptr,
|
||||
args.seqstart_q_ptr,
|
||||
args.seqstart_k_ptr,
|
||||
args.seqlen_q_ptr,
|
||||
args.seqlen_k_ptr,
|
||||
args.hdim_q,
|
||||
args.stride_dq,
|
||||
args.stride_dq_acc,
|
||||
|
||||
@@ -65,6 +65,8 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
ck_tile::index_t nhead_k,
|
||||
std::vector<ck_tile::index_t> seqlen_qs,
|
||||
std::vector<ck_tile::index_t> seqlen_ks,
|
||||
std::vector<ck_tile::index_t> seqlen_qpads,
|
||||
std::vector<ck_tile::index_t> seqlen_kpads,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
bool i_perm,
|
||||
@@ -119,10 +121,15 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
std::cerr << "dbias only exists when bias type is elementwise" << std::endl;
|
||||
return bwd_result::invalid_args;
|
||||
}
|
||||
std::vector<ck_tile::index_t> seqlen_kpads;
|
||||
std::tie(seqlen_qs, seqlen_ks, seqlen_kpads) =
|
||||
generate_missing_seqlens(mode, batch, seqlen_qs, seqlen_ks, {}, 0, false, random_engine);
|
||||
ck_tile::ignore = seqlen_kpads;
|
||||
|
||||
std::tie(seqlen_qs, seqlen_ks, seqlen_qpads, seqlen_kpads) = generate_missing_seqlens(
|
||||
mode, batch, seqlen_qs, seqlen_ks, seqlen_qpads, seqlen_kpads, 0, false, random_engine);
|
||||
|
||||
bool use_qpadding =
|
||||
mode == mode_enum::group && (!seqlen_qpads.empty() && seqlen_qpads[0] != -1);
|
||||
bool use_kpadding =
|
||||
mode == mode_enum::group && (!seqlen_kpads.empty() && seqlen_kpads[0] != -1);
|
||||
|
||||
#if 0
|
||||
std::cout << "seqlen_qs: " << seqlen_qs << std::endl;
|
||||
std::cout << "seqlen_ks: " << seqlen_ks << std::endl;
|
||||
@@ -146,8 +153,10 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
s_randval = true;
|
||||
}
|
||||
|
||||
const auto seqstart_q_host = to_seqstarts(seqlen_qs);
|
||||
const auto seqstart_k_host = to_seqstarts(seqlen_ks);
|
||||
const auto seqstart_q_host =
|
||||
(use_qpadding ? to_seqstarts(seqlen_qpads) : to_seqstarts(seqlen_qs));
|
||||
const auto seqstart_k_host =
|
||||
(use_kpadding ? to_seqstarts(seqlen_kpads) : to_seqstarts(seqlen_ks));
|
||||
|
||||
using TypeConfig = FmhaBwdTypeConfig<DataTypeConfig>;
|
||||
|
||||
@@ -336,6 +345,10 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
ck_tile::DeviceMem do_buf(do_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem dbias_buf(dbias_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
|
||||
ck_tile::DeviceMem seqlen_q_dev(mode == mode_enum::batch ? 0
|
||||
: seqlen_qs.size() * sizeof(int32_t));
|
||||
ck_tile::DeviceMem seqlen_k_dev(mode == mode_enum::batch ? 0
|
||||
: seqlen_ks.size() * sizeof(int32_t));
|
||||
ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t));
|
||||
ck_tile::DeviceMem drop_seed_buf(drop_prefs ? sizeof(uint64_t) : 0);
|
||||
ck_tile::DeviceMem drop_offset_buf(drop_prefs ? sizeof(uint64_t) : 0);
|
||||
@@ -349,6 +362,13 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
do_buf.ToDevice(do_host.data());
|
||||
seqstart_q.ToDevice(seqstart_q_host.data());
|
||||
seqstart_k.ToDevice(seqstart_k_host.data());
|
||||
if(mode == mode_enum::group)
|
||||
{
|
||||
std::vector<int32_t> seqlen_q_host(seqlen_qs.begin(), seqlen_qs.end());
|
||||
seqlen_q_dev.ToDevice(seqlen_q_host.data());
|
||||
std::vector<int32_t> seqlen_k_host(seqlen_ks.begin(), seqlen_ks.end());
|
||||
seqlen_k_dev.ToDevice(seqlen_k_host.data());
|
||||
}
|
||||
drop_seed_buf.ToDevice(drop_prefs ? &drop_seed : nullptr);
|
||||
drop_offset_buf.ToDevice(drop_prefs ? &drop_offset : nullptr);
|
||||
alibi_slope_buf.ToDevice(alibi_slope_host.data());
|
||||
@@ -440,11 +460,14 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
}
|
||||
}();
|
||||
|
||||
const void* seqlen_q_ptr_dev = use_qpadding ? seqlen_q_dev.GetDeviceBuffer() : nullptr;
|
||||
const void* seqlen_k_ptr_dev = use_kpadding ? seqlen_k_dev.GetDeviceBuffer() : nullptr;
|
||||
|
||||
return fmha_bwd_args{q_buf.GetDeviceBuffer(),
|
||||
k_buf.GetDeviceBuffer(),
|
||||
v_buf.GetDeviceBuffer(),
|
||||
bias.type == bias_enum::alibi ? alibi_slope_buf.GetDeviceBuffer()
|
||||
: bias_buf.GetDeviceBuffer(),
|
||||
: bias_buf.GetDeviceBuffer(),
|
||||
o_buf.GetDeviceBuffer(),
|
||||
lse_buf.GetDeviceBuffer(),
|
||||
do_buf.GetDeviceBuffer(),
|
||||
@@ -457,7 +480,8 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
dq_acc_buf.GetDeviceBuffer(),
|
||||
seqstart_q.GetDeviceBuffer(),
|
||||
seqstart_k.GetDeviceBuffer(),
|
||||
nullptr,
|
||||
seqlen_q_ptr_dev,
|
||||
seqlen_k_ptr_dev,
|
||||
shape_seqlen_q,
|
||||
shape_seqlen_k,
|
||||
batch,
|
||||
@@ -472,7 +496,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
stride_k,
|
||||
stride_v,
|
||||
bias.type == bias_enum::alibi ? (bias.rank_info == 0 ? 0 : nhead)
|
||||
: stride_bias,
|
||||
: stride_bias,
|
||||
stride_o,
|
||||
stride_randval,
|
||||
stride_do,
|
||||
|
||||
@@ -330,11 +330,12 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
return fwd_result::invalid_args;
|
||||
}
|
||||
|
||||
std::tie(seqlen_qs, seqlen_ks, seqlen_kpads) =
|
||||
std::tie(seqlen_qs, seqlen_ks, seqlen_qpads, seqlen_kpads) =
|
||||
generate_missing_seqlens(mode,
|
||||
batch,
|
||||
seqlen_qs,
|
||||
seqlen_ks,
|
||||
seqlen_qpads,
|
||||
seqlen_kpads,
|
||||
/*seqlen_k_min=*/0 < seqlen_knew ? seqlen_knew : 0,
|
||||
need_append_kvcache,
|
||||
@@ -346,7 +347,13 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
std::cerr << "kpad must be greater than or equal to seqlen for k" << std::endl;
|
||||
return fwd_result::invalid_args;
|
||||
}
|
||||
if(seqlen_qpads[wb] > 0 && seqlen_qpads[wb] < seqlen_qs[wb])
|
||||
{
|
||||
std::cerr << "qpad must be greater than or equal to seqlen for q" << std::endl;
|
||||
return fwd_result::invalid_args;
|
||||
}
|
||||
}
|
||||
|
||||
// compute kvcache seqlen_k (before appending knew/vnew)
|
||||
auto cache_seqlen_ks = seqlen_ks;
|
||||
std::transform(cache_seqlen_ks.begin(),
|
||||
@@ -357,6 +364,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
#if 0
|
||||
std::cout << "seqlen_qs: " << seqlen_qs << std::endl;
|
||||
std::cout << "seqlen_ks: " << seqlen_ks << std::endl;
|
||||
std::cout << "seqlen_qpads: " << seqlen_qpads << std::endl;
|
||||
std::cout << "seqlen_kpads: " << seqlen_kpads << std::endl;
|
||||
std::cout << "cache_seqlen_ks: " << cache_seqlen_ks << std::endl;
|
||||
#endif
|
||||
@@ -514,9 +522,6 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
|
||||
// host memory for storing all the tensor elements
|
||||
const ck_tile::index_t shape_batch = (mode == mode_enum::batch ? batch : 1);
|
||||
// logical(unpadded) total seqlen_q for group; batch uses fixed seqlen
|
||||
const ck_tile::index_t shape_seqlen_q_lse =
|
||||
(mode == mode_enum::batch ? seqlen_qs[0] : seqstart_q_host.back());
|
||||
// physical(padded) total seqlen_q for group when s_qpad is provided; else use logical
|
||||
const ck_tile::index_t shape_seqlen_q =
|
||||
(mode == mode_enum::batch
|
||||
@@ -580,7 +585,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
// batch mode of lse data layout is [batch, nhead, seqlen_q]
|
||||
// group mode of lse data layout is [nhead, total_seqlen_q]
|
||||
ck_tile::HostTensor<LSEDataType> lse_host(
|
||||
lse ? std::array<ck_tile::index_t, 3>{shape_batch, nhead, shape_seqlen_q_lse}
|
||||
lse ? std::array<ck_tile::index_t, 3>{shape_batch, nhead, shape_seqlen_q}
|
||||
: std::array<ck_tile::index_t, 3>{1, 1, 1} /* dummy shape for simplifying code */);
|
||||
|
||||
ck_tile::HostTensor<ODataType> o_host(
|
||||
@@ -970,8 +975,8 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
const ck_tile::index_t nhead_stride_bias =
|
||||
(i_perm ? 0 * shape_seqlen_q * max_seqlen_k : 0 * max_seqlen_k);
|
||||
const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k);
|
||||
const ck_tile::index_t nhead_stride_lse = shape_seqlen_q_lse;
|
||||
const ck_tile::index_t nhead_stride_lse_acc = (num_splits * shape_seqlen_q_lse);
|
||||
const ck_tile::index_t nhead_stride_lse = shape_seqlen_q;
|
||||
const ck_tile::index_t nhead_stride_lse_acc = (num_splits * shape_seqlen_q);
|
||||
const ck_tile::index_t nhead_stride_o_acc = (num_splits * shape_seqlen_q * hdim_v);
|
||||
const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
|
||||
// setup batch_stride_* arguments
|
||||
@@ -986,8 +991,8 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
const ck_tile::index_t batch_stride_vnew = (nhead_k * hdim_v * seqlen_knew);
|
||||
const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * max_seqlen_k);
|
||||
const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k);
|
||||
const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q_lse);
|
||||
const ck_tile::index_t batch_stride_lse_acc = (nhead * num_splits * shape_seqlen_q_lse);
|
||||
const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q);
|
||||
const ck_tile::index_t batch_stride_lse_acc = (nhead * num_splits * shape_seqlen_q);
|
||||
const ck_tile::index_t batch_stride_o_acc = (nhead * num_splits * shape_seqlen_q * hdim_v);
|
||||
const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);
|
||||
const ck_tile::index_t batch_stride_block_table = (max_num_page_blocks / batch);
|
||||
@@ -1727,12 +1732,11 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
if(lse)
|
||||
{
|
||||
ck_tile::HostTensor<SMPLComputeDataType> lse_host_result({nhead, real_seqlen_q});
|
||||
const ck_tile::index_t query_offset_lse =
|
||||
(mode == mode_enum::batch ? 0 : seqstart_q_host[wb]);
|
||||
lse_host_result.ForEach([&](auto& self, auto idx) {
|
||||
self(idx) = lse_host(b_idx, idx[0], idx[1] + query_offset_lse);
|
||||
self(idx) = lse_host(b_idx, idx[0], idx[1] + query_offset);
|
||||
});
|
||||
|
||||
std::cout << "lse_host_result shape: " << shape_batch << ", " << nhead << ", " << shape_seqlen_q << std::endl;
|
||||
|
||||
cur_pass = ck_tile::check_err(lse_host_result,
|
||||
lse_host_ref,
|
||||
"LSE Error: Incorrect results!",
|
||||
|
||||
@@ -142,12 +142,14 @@ auto randints(ForwardIterator first,
|
||||
*/
|
||||
template <typename RandomEngine>
|
||||
std::tuple<std::vector<ck_tile::index_t>,
|
||||
std::vector<ck_tile::index_t>,
|
||||
std::vector<ck_tile::index_t>,
|
||||
std::vector<ck_tile::index_t>>
|
||||
generate_missing_seqlens(mode_enum mode,
|
||||
ck_tile::index_t batch,
|
||||
const std::vector<ck_tile::index_t>& q_val,
|
||||
const std::vector<ck_tile::index_t>& k_val,
|
||||
const std::vector<ck_tile::index_t>& q_pad_val,
|
||||
const std::vector<ck_tile::index_t>& k_pad_val,
|
||||
ck_tile::index_t seqlen_k_min,
|
||||
bool need_append_kvcache,
|
||||
@@ -177,7 +179,7 @@ generate_missing_seqlens(mode_enum mode,
|
||||
return seqlen_ks;
|
||||
}();
|
||||
auto s_kpad = std::vector<ck_tile::index_t>(batch, -1); // TODO: batch not support k_padding
|
||||
|
||||
auto s_qpad = std::vector<ck_tile::index_t>(batch, -1);
|
||||
// s_k should be greater than or equal to seqlen_k_min if provided
|
||||
if(s_k.back() < seqlen_k_min)
|
||||
{
|
||||
@@ -187,13 +189,14 @@ generate_missing_seqlens(mode_enum mode,
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
return std::make_tuple(s_q, s_k, s_kpad);
|
||||
return std::make_tuple(s_q, s_k, s_qpad, s_kpad);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::vector<ck_tile::index_t> s_q;
|
||||
std::vector<ck_tile::index_t> s_k;
|
||||
std::vector<ck_tile::index_t> s_kpad;
|
||||
std::vector<ck_tile::index_t> s_qpad;
|
||||
ck_tile::index_t idx = 0;
|
||||
for(; idx < std::min(static_cast<ck_tile::index_t>(q_val.size()), batch); ++idx)
|
||||
{
|
||||
@@ -205,9 +208,15 @@ generate_missing_seqlens(mode_enum mode,
|
||||
? -1
|
||||
: k_pad_val[std::min(idx, static_cast<ck_tile::index_t>(k_pad_val.size()) - 1)];
|
||||
|
||||
ck_tile::index_t qp =
|
||||
q_pad_val.empty()
|
||||
? -1
|
||||
: q_pad_val[std::min(idx, static_cast<ck_tile::index_t>(q_pad_val.size()) - 1)];
|
||||
|
||||
s_q.push_back(q);
|
||||
s_k.push_back(k < 0 ? q : k);
|
||||
s_kpad.push_back(kp);
|
||||
s_qpad.push_back(qp);
|
||||
|
||||
// s_k should be greater than or equal to seqlen_k_min
|
||||
if(s_k.back() < seqlen_k_min)
|
||||
@@ -228,8 +237,9 @@ generate_missing_seqlens(mode_enum mode,
|
||||
s_q.insert(s_q.end(), rem_q.begin(), rem_q.end());
|
||||
s_k.insert(s_k.end(), rem_k.begin(), rem_k.end());
|
||||
s_kpad.insert(s_kpad.end(), batch - idx, s_kpad.back());
|
||||
s_qpad.insert(s_qpad.end(), batch - idx, s_qpad.back());
|
||||
}
|
||||
return std::make_tuple(s_q, s_k, s_kpad);
|
||||
return std::make_tuple(s_q, s_k, s_qpad, s_kpad);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -313,6 +313,7 @@ struct FmhaBwdDQDKDVKernel
|
||||
{
|
||||
const int32_t* seqstart_q_ptr;
|
||||
const int32_t* seqstart_k_ptr;
|
||||
const int32_t* seqlen_q_ptr;
|
||||
const int32_t* seqlen_k_ptr;
|
||||
};
|
||||
|
||||
@@ -520,6 +521,7 @@ struct FmhaBwdDQDKDVKernel
|
||||
void* dq_acc_ptr,
|
||||
const void* seqstart_q_ptr,
|
||||
const void* seqstart_k_ptr,
|
||||
const void* seqlen_q_ptr,
|
||||
const void* seqlen_k_ptr,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
@@ -594,6 +596,7 @@ struct FmhaBwdDQDKDVKernel
|
||||
{}, // placeholder for deterministic
|
||||
reinterpret_cast<const int32_t*>(seqstart_q_ptr),
|
||||
reinterpret_cast<const int32_t*>(seqstart_k_ptr),
|
||||
reinterpret_cast<const int32_t*>(seqlen_q_ptr),
|
||||
reinterpret_cast<const int32_t*>(seqlen_k_ptr)};
|
||||
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
@@ -738,7 +741,10 @@ struct FmhaBwdDQDKDVKernel
|
||||
|
||||
// get real # queries & # keys under group mode
|
||||
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
|
||||
kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
|
||||
const ck_tile::index_t physical_seqlen_q =
|
||||
adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
|
||||
kargs.seqlen_q = kargs.seqlen_q_ptr ? kargs.seqlen_q_ptr[i_batch] : physical_seqlen_q;
|
||||
|
||||
if(kargs.seqlen_k_ptr != nullptr)
|
||||
{
|
||||
kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
|
||||
@@ -749,6 +755,12 @@ struct FmhaBwdDQDKDVKernel
|
||||
kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0];
|
||||
}
|
||||
|
||||
// skip if logical lengths are zero
|
||||
if(kargs.seqlen_q == 0 || kargs.seqlen_k == 0)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
// # of required blocks is different in each groups, terminate unnecessary blocks
|
||||
// earlier
|
||||
if constexpr(!kUseQrQtrDorPipeline)
|
||||
@@ -1246,6 +1258,7 @@ struct FmhaBwdOGradDotOKernel
|
||||
struct FmhaBwdOGradDotOGroupModeKargs : FmhaBwdOGradDotOCommonKargs
|
||||
{
|
||||
const int32_t* seqstart_q_ptr;
|
||||
const int32_t* seqlen_q_ptr;
|
||||
};
|
||||
|
||||
using Kargs = std::
|
||||
@@ -1293,6 +1306,7 @@ struct FmhaBwdOGradDotOKernel
|
||||
void* d_ptr,
|
||||
float p_undrop,
|
||||
const void* seqstart_q_ptr,
|
||||
const void* seqlen_q_ptr,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t stride_do,
|
||||
ck_tile::index_t stride_o,
|
||||
@@ -1311,7 +1325,8 @@ struct FmhaBwdOGradDotOKernel
|
||||
nhead_stride_do,
|
||||
nhead_stride_o,
|
||||
nhead_stride_d},
|
||||
reinterpret_cast<const int32_t*>(seqstart_q_ptr)};
|
||||
reinterpret_cast<const int32_t*>(seqstart_q_ptr),
|
||||
reinterpret_cast<const int32_t*>(seqlen_q_ptr)};
|
||||
|
||||
return kargs;
|
||||
}
|
||||
@@ -1357,7 +1372,12 @@ struct FmhaBwdOGradDotOKernel
|
||||
|
||||
// get real # queries & # keys under group mode
|
||||
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
|
||||
kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
|
||||
const ck_tile::index_t physical_seqlen_q =
|
||||
adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
|
||||
const ck_tile::index_t logical_seqlen_q =
|
||||
kargs.seqlen_q_ptr ? static_cast<ck_tile::index_t>(kargs.seqlen_q_ptr[i_batch])
|
||||
: physical_seqlen_q;
|
||||
kargs.seqlen_q = logical_seqlen_q;
|
||||
// # of required blocks is different in each groups, terminate unnecessary blocks
|
||||
// earlier
|
||||
if(kargs.seqlen_q <= i_m0)
|
||||
@@ -1521,6 +1541,8 @@ struct FmhaBwdConvertQGradKernel
|
||||
{
|
||||
const int32_t* seqstart_q_ptr;
|
||||
const int32_t* seqstart_k_ptr;
|
||||
const int32_t* seqlen_q_ptr;
|
||||
const int32_t* seqlen_k_ptr;
|
||||
};
|
||||
|
||||
using Kargs = std::conditional_t<kIsGroupMode,
|
||||
@@ -1569,6 +1591,8 @@ struct FmhaBwdConvertQGradKernel
|
||||
void* dq_ptr,
|
||||
const void* seqstart_q_ptr,
|
||||
const void* seqstart_k_ptr,
|
||||
const void* seqlen_q_ptr,
|
||||
const void* seqlen_k_ptr,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t stride_dq,
|
||||
ck_tile::index_t stride_dq_acc,
|
||||
@@ -1587,7 +1611,9 @@ struct FmhaBwdConvertQGradKernel
|
||||
nhead_stride_dq_acc},
|
||||
{},
|
||||
reinterpret_cast<const int32_t*>(seqstart_q_ptr),
|
||||
reinterpret_cast<const int32_t*>(seqstart_k_ptr)};
|
||||
reinterpret_cast<const int32_t*>(seqstart_k_ptr),
|
||||
reinterpret_cast<const int32_t*>(seqlen_q_ptr),
|
||||
reinterpret_cast<const int32_t*>(seqlen_k_ptr)};
|
||||
|
||||
if constexpr(kIsDeterministic)
|
||||
{
|
||||
@@ -1634,11 +1660,20 @@ struct FmhaBwdConvertQGradKernel
|
||||
|
||||
// get real # queries & # keys under group mode
|
||||
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
|
||||
kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
|
||||
const ck_tile::index_t physical_seqlen_q =
|
||||
adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
|
||||
const ck_tile::index_t logical_seqlen_q =
|
||||
kargs.seqlen_q_ptr ? static_cast<ck_tile::index_t>(kargs.seqlen_q_ptr[i_batch])
|
||||
: physical_seqlen_q;
|
||||
kargs.seqlen_q = logical_seqlen_q;
|
||||
if constexpr(kIsDeterministic)
|
||||
{
|
||||
const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch;
|
||||
kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0];
|
||||
const ck_tile::index_t physical_seqlen_k =
|
||||
adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0];
|
||||
kargs.seqlen_k = kargs.seqlen_k_ptr
|
||||
? static_cast<ck_tile::index_t>(kargs.seqlen_k_ptr[i_batch])
|
||||
: physical_seqlen_k;
|
||||
}
|
||||
// # of required blocks is different in each groups, terminate unnecessary blocks
|
||||
// earlier
|
||||
|
||||
@@ -1137,8 +1137,8 @@ struct FmhaFwdKernel
|
||||
}
|
||||
if constexpr(kStoreLSE)
|
||||
{
|
||||
// LSE stays indexed by unpadded starts
|
||||
batch_offset_lse = query_start_unpadded;
|
||||
// LSE follows the padded layout to stay consistent with other tensors
|
||||
batch_offset_lse = query_start_padded;
|
||||
}
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
@@ -1630,8 +1630,8 @@ struct FmhaFwdKernel
|
||||
batch_offset_bias = query_start_padded * kargs.stride_bias;
|
||||
}
|
||||
|
||||
// LSE layout is [nhead, total_seqlen], index by unpadded start
|
||||
batch_offset_lse = query_start_unpadded;
|
||||
// LSE layout is [nhead, total_seqlen] following the padded layout for Q/O
|
||||
batch_offset_lse = query_start_padded;
|
||||
batch_offset_o = query_start_padded * kargs.stride_o;
|
||||
|
||||
// get real # queries & # keys under group mode
|
||||
|
||||
@@ -5,6 +5,11 @@ endif()
|
||||
|
||||
set(FMHA_BWD_INSTANCES "tile_fmha_bwd_instances")
|
||||
set(FMHA_FWD_INSTANCES "tile_fmha_fwd_instances")
|
||||
|
||||
add_gtest_executable(test_ck_tile_fmha_bwd_kernels test_fmha_bwd_kernel_padding.cpp)
|
||||
target_link_libraries(test_ck_tile_fmha_bwd_kernels PRIVATE ${FMHA_BWD_INSTANCES})
|
||||
|
||||
|
||||
set(TEST_NAME "test_ck_tile_fmha")
|
||||
|
||||
function(add_gtest_fwd test_group)
|
||||
|
||||
@@ -77,6 +77,8 @@ void fmha_bwd_test(const FmhaBwdTestParam& param)
|
||||
nhead_k,
|
||||
{seqlen_q},
|
||||
{seqlen_k},
|
||||
{-1},
|
||||
{-1},
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
i_perm,
|
||||
|
||||
707
test/ck_tile/fmha/test_fmha_bwd_kernel_padding.cpp
Normal file
707
test/ck_tile/fmha/test_fmha_bwd_kernel_padding.cpp
Normal file
@@ -0,0 +1,707 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <cmath>
|
||||
#include <functional>
|
||||
#include <initializer_list>
|
||||
#include <vector>
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/host/device_memory.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "example/ck_tile/01_fmha/fmha_bwd.hpp"
|
||||
#include "example/ck_tile/01_fmha/fmha_bwd_runner.hpp" // for get_elimit
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
namespace {
|
||||
|
||||
using bf16 = ck_tile::bf16_t;
|
||||
using ck_tile::DeviceMem;
|
||||
|
||||
const ck_tile::stream_config kStreamConfig{
|
||||
nullptr, // stream_id_
|
||||
false, // time_kernel_
|
||||
1, // log_level_
|
||||
0, // cold_niters_
|
||||
1, // nrepeat_
|
||||
true, // is_gpu_timer_
|
||||
false, // flush_cache_
|
||||
1, // rotating_count_
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
std::vector<T> MakeVectorFromFunction(size_t count, std::function<float(size_t)> fn)
|
||||
{
|
||||
std::vector<T> data(count);
|
||||
for(size_t i = 0; i < count; ++i)
|
||||
{
|
||||
data[i] = static_cast<T>(fn(i));
|
||||
}
|
||||
return data;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::vector<float> ToFloatVector(const std::vector<T>& src)
|
||||
{
|
||||
std::vector<float> dst(src.size());
|
||||
for(size_t i = 0; i < src.size(); ++i)
|
||||
{
|
||||
dst[i] = ck_tile::type_convert<float>(src[i]);
|
||||
}
|
||||
return dst;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::vector<T> CopyDeviceToHost(const ck_tile::DeviceMem& dev, size_t element_count)
|
||||
{
|
||||
std::vector<T> host(element_count);
|
||||
if(element_count > 0)
|
||||
{
|
||||
dev.FromDevice(host.data());
|
||||
}
|
||||
return host;
|
||||
}
|
||||
|
||||
float SentinelValue() { return -999.f; }
|
||||
|
||||
} // namespace
|
||||
|
||||
// Typed tests over {fp32, fp16, bf16}
|
||||
template <typename DataTypeConfig>
|
||||
class FmhaBwdKernelPaddingTyped : public ::testing::Test
|
||||
{
|
||||
};
|
||||
|
||||
using KernelPaddingTypes = ::testing::Types<FmhaBwdFp32, FmhaBwdFp16, FmhaBwdBf16>;
|
||||
TYPED_TEST_SUITE(FmhaBwdKernelPaddingTyped, KernelPaddingTypes);
|
||||
|
||||
TYPED_TEST(FmhaBwdKernelPaddingTyped, OGradDotO_GroupPaddingRespectsLogicalLengths)
|
||||
{
|
||||
constexpr ck_tile::index_t batch = 2;
|
||||
constexpr ck_tile::index_t nhead = 1;
|
||||
constexpr ck_tile::index_t hdim = 128;
|
||||
constexpr ck_tile::index_t phys_rows0 = 8;
|
||||
constexpr ck_tile::index_t phys_rows1 = 8;
|
||||
constexpr ck_tile::index_t max_phys = phys_rows0; // both batches equal
|
||||
|
||||
const std::vector<int32_t> seqstart_q_host{0, phys_rows0, phys_rows0 + phys_rows1};
|
||||
const std::vector<int32_t> seqlen_q_host{5, 3};
|
||||
|
||||
const ck_tile::index_t total_rows = seqstart_q_host.back();
|
||||
|
||||
// Types per config
|
||||
using TypeConfig = FmhaBwdTypeConfig<TypeParam>;
|
||||
using OType = typename TypeConfig::ODataType;
|
||||
using DOType = typename TypeConfig::OGradDataType;
|
||||
using DType = typename TypeConfig::DDataType; // float under bf16 config
|
||||
using AccType = typename TypeConfig::AccDataType; // float
|
||||
|
||||
// Host tensors laid out as [b, h, s, d] with b=1, h=1 under group mode
|
||||
ck_tile::HostTensor<OType> o_host({1, nhead, total_rows, hdim});
|
||||
ck_tile::HostTensor<DOType> do_host({1, nhead, total_rows, hdim});
|
||||
ck_tile::HostTensor<DType> d_init_host({1, nhead, total_rows});
|
||||
|
||||
// Initialize O/dO with constants using FillConstant (no manual bf16 casts)
|
||||
const float o_const = 0.25f;
|
||||
const float do_const = 0.5f;
|
||||
ck_tile::FillConstant<OType>{ck_tile::type_convert<OType>(o_const)}(o_host);
|
||||
ck_tile::FillConstant<DOType>{ck_tile::type_convert<DOType>(do_const)}(do_host);
|
||||
ck_tile::FillConstant<DType>{ck_tile::type_convert<DType>(SentinelValue())}(d_init_host);
|
||||
|
||||
// Prepare expected D via runner-style CPU reference, sentinel elsewhere
|
||||
std::vector<float> expected(static_cast<size_t>(total_rows), SentinelValue());
|
||||
for(ck_tile::index_t b = 0; b < batch; ++b)
|
||||
{
|
||||
const ck_tile::index_t start = seqstart_q_host[b];
|
||||
const ck_tile::index_t len = seqlen_q_host[b];
|
||||
for(ck_tile::index_t row = 0; row < len; ++row)
|
||||
{
|
||||
AccType acc = 0;
|
||||
for(ck_tile::index_t c = 0; c < hdim; ++c)
|
||||
{
|
||||
// o_host/do_host are [1, nhead, s, d]
|
||||
const auto o_val = ck_tile::type_convert<AccType>(o_host(0, 0, start + row, c));
|
||||
const auto do_val = ck_tile::type_convert<AccType>(do_host(0, 0, start + row, c));
|
||||
acc += do_val * o_val;
|
||||
}
|
||||
expected[start + row] = ck_tile::type_convert<float>(acc);
|
||||
}
|
||||
}
|
||||
std::vector<float> sentinel_ref(static_cast<size_t>(total_rows), SentinelValue());
|
||||
|
||||
// Device buffers
|
||||
ck_tile::DeviceMem o_dev(o_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem do_dev(do_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem d_dev(d_init_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem seqstart_dev(seqstart_q_host.size() * sizeof(int32_t));
|
||||
ck_tile::DeviceMem seqlen_dev(seqlen_q_host.size() * sizeof(int32_t));
|
||||
|
||||
o_dev.ToDevice(o_host.data());
|
||||
do_dev.ToDevice(do_host.data());
|
||||
d_dev.ToDevice(d_init_host.data());
|
||||
seqstart_dev.ToDevice(seqstart_q_host.data());
|
||||
seqlen_dev.ToDevice(seqlen_q_host.data());
|
||||
|
||||
fmha_bwd_args args{};
|
||||
args.q_ptr = nullptr;
|
||||
args.k_ptr = nullptr;
|
||||
args.v_ptr = nullptr;
|
||||
args.bias_ptr = nullptr;
|
||||
args.o_ptr = o_dev.GetDeviceBuffer();
|
||||
args.lse_ptr = nullptr;
|
||||
args.do_ptr = do_dev.GetDeviceBuffer();
|
||||
args.d_ptr = d_dev.GetDeviceBuffer();
|
||||
args.rand_val_ptr = nullptr;
|
||||
args.dq_ptr = nullptr;
|
||||
args.dk_ptr = nullptr;
|
||||
args.dv_ptr = nullptr;
|
||||
args.dbias_ptr = nullptr;
|
||||
args.dq_acc_ptr = nullptr;
|
||||
args.seqstart_q_ptr = seqstart_dev.GetDeviceBuffer();
|
||||
args.seqstart_k_ptr = nullptr;
|
||||
args.seqlen_k_ptr = nullptr;
|
||||
args.seqlen_q_ptr = seqlen_dev.GetDeviceBuffer();
|
||||
args.seqlen_q = 0;
|
||||
args.seqlen_k = 0;
|
||||
args.batch = batch;
|
||||
args.max_seqlen_q = max_phys;
|
||||
args.max_seqlen_k = 0;
|
||||
args.hdim_q = hdim;
|
||||
args.hdim_v = hdim;
|
||||
args.nhead_q = nhead;
|
||||
args.nhead_k = nhead;
|
||||
args.scale = 1.0f;
|
||||
args.stride_q = 0;
|
||||
args.stride_k = 0;
|
||||
args.stride_v = 0;
|
||||
args.stride_bias = 0;
|
||||
args.stride_o = hdim;
|
||||
args.stride_randval = 0;
|
||||
args.stride_do = hdim;
|
||||
args.stride_dq_acc = 0;
|
||||
args.stride_dq = 0;
|
||||
args.stride_dk = 0;
|
||||
args.stride_dv = 0;
|
||||
args.stride_dbias = 0;
|
||||
args.nhead_stride_q = 0;
|
||||
args.nhead_stride_k = 0;
|
||||
args.nhead_stride_v = 0;
|
||||
args.nhead_stride_bias = 0;
|
||||
args.nhead_stride_o = max_phys * hdim;
|
||||
args.nhead_stride_randval = 0;
|
||||
args.nhead_stride_do = max_phys * hdim;
|
||||
args.nhead_stride_lsed = max_phys;
|
||||
args.nhead_stride_dq_acc = 0;
|
||||
args.nhead_stride_dq = 0;
|
||||
args.nhead_stride_dk = 0;
|
||||
args.nhead_stride_dv = 0;
|
||||
args.nhead_stride_dbias = 0;
|
||||
args.batch_stride_q = 0;
|
||||
args.batch_stride_k = 0;
|
||||
args.batch_stride_v = 0;
|
||||
args.batch_stride_bias = 0;
|
||||
args.batch_stride_o = 0;
|
||||
args.batch_stride_randval = 0;
|
||||
args.batch_stride_do = 0;
|
||||
args.batch_stride_lsed = 0;
|
||||
args.batch_stride_dq_acc = 0;
|
||||
args.batch_stride_dq = 0;
|
||||
args.batch_stride_dk = 0;
|
||||
args.batch_stride_dv = 0;
|
||||
args.batch_stride_dbias = 0;
|
||||
args.split_stride_dq_acc = 0;
|
||||
args.window_size_left = -1;
|
||||
args.window_size_right = 0;
|
||||
args.mask_type = static_cast<ck_tile::index_t>(mask_enum::no_mask);
|
||||
args.p_drop = 0.0f;
|
||||
args.p_undrop = 1.0f;
|
||||
args.drop_seed_offset = std::make_pair(uint64_t{0}, uint64_t{0});
|
||||
|
||||
using DotTileTraits = ck_tile::TileFmhaBwdOGradDotOTraits<true, true, 2>;
|
||||
using DotProblem = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem<
|
||||
typename TypeConfig::ODataType,
|
||||
typename TypeConfig::OGradDataType,
|
||||
typename TypeConfig::DDataType,
|
||||
64,
|
||||
hdim,
|
||||
true,
|
||||
DotTileTraits>;
|
||||
using DotPipeline = ck_tile::BlockFmhaBwdOGradDotO<DotProblem>;
|
||||
using DotKernel = ck_tile::FmhaBwdOGradDotOKernel<DotPipeline>;
|
||||
|
||||
auto [dot_kargs, dot_grids] = fmha_bwd_dot_do_o_create_kargs_and_grids<DotKernel>(args);
|
||||
const dim3 dot_blocks = DotKernel::BlockSize();
|
||||
constexpr ck_tile::index_t kDotBlockPerCu = DotKernel::kBlockPerCu;
|
||||
auto dot_kernel = ck_tile::make_kernel<kDotBlockPerCu>(
|
||||
DotKernel{}, dot_grids, dot_blocks, 0, dot_kargs);
|
||||
dot_kernel(kStreamConfig);
|
||||
ASSERT_EQ(hipDeviceSynchronize(), hipSuccess);
|
||||
|
||||
auto d_result_host = CopyDeviceToHost<float>(d_dev, total_rows);
|
||||
|
||||
auto [rtol_doto, atol_doto] = get_elimit<TypeParam>(hdim, hdim);
|
||||
for(size_t i = 0; i < d_result_host.size(); ++i)
|
||||
{
|
||||
SCOPED_TRACE(::testing::Message() << "index=" << i);
|
||||
if(std::fabs(expected[i] - sentinel_ref[i]) < 1e-6f)
|
||||
{
|
||||
EXPECT_FLOAT_EQ(d_result_host[i], sentinel_ref[i]);
|
||||
}
|
||||
else
|
||||
{
|
||||
EXPECT_NEAR(d_result_host[i], expected[i], static_cast<float>(atol_doto));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TYPED_TEST(FmhaBwdKernelPaddingTyped, OGradDotO_VariedPhysicalAndZeroLogical)
|
||||
{
|
||||
constexpr ck_tile::index_t batch = 3;
|
||||
constexpr ck_tile::index_t nhead = 1;
|
||||
constexpr ck_tile::index_t hdim = 64;
|
||||
constexpr ck_tile::index_t phys_r0 = 5;
|
||||
constexpr ck_tile::index_t phys_r1 = 7;
|
||||
constexpr ck_tile::index_t phys_r2 = 4;
|
||||
constexpr ck_tile::index_t max_phys = phys_r1;
|
||||
|
||||
const std::vector<int32_t> seqstart_q_host{0, phys_r0, phys_r0 + phys_r1, phys_r0 + phys_r1 + phys_r2};
|
||||
const std::vector<int32_t> seqlen_q_host{3, 0, 4};
|
||||
const ck_tile::index_t total_rows = seqstart_q_host.back();
|
||||
|
||||
using TypeConfig = FmhaBwdTypeConfig<TypeParam>;
|
||||
using OType = typename TypeConfig::ODataType;
|
||||
using DOType = typename TypeConfig::OGradDataType;
|
||||
using DType = typename TypeConfig::DDataType;
|
||||
|
||||
ck_tile::HostTensor<OType> o_host({1, nhead, total_rows, hdim});
|
||||
ck_tile::HostTensor<DOType> do_host({1, nhead, total_rows, hdim});
|
||||
ck_tile::HostTensor<DType> d_init_host({1, nhead, total_rows});
|
||||
ck_tile::FillConstant<OType>{ck_tile::type_convert<OType>(1.0f)}(o_host);
|
||||
ck_tile::FillConstant<DOType>{ck_tile::type_convert<DOType>(2.0f)}(do_host);
|
||||
ck_tile::FillConstant<DType>{ck_tile::type_convert<DType>(SentinelValue())}(d_init_host);
|
||||
|
||||
std::vector<float> expected(static_cast<size_t>(total_rows), SentinelValue());
|
||||
const float dot = 2.0f * 1.0f * static_cast<float>(hdim);
|
||||
for(ck_tile::index_t b = 0; b < batch; ++b)
|
||||
{
|
||||
const ck_tile::index_t start = seqstart_q_host[b];
|
||||
const ck_tile::index_t len = seqlen_q_host[b];
|
||||
for(ck_tile::index_t row = 0; row < len; ++row) expected[start + row] = dot;
|
||||
}
|
||||
std::vector<float> sentinel_ref(static_cast<size_t>(total_rows), SentinelValue());
|
||||
|
||||
ck_tile::DeviceMem o_dev(o_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem do_dev(do_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem d_dev(d_init_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem seqstart_dev(seqstart_q_host.size() * sizeof(int32_t));
|
||||
ck_tile::DeviceMem seqlen_dev(seqlen_q_host.size() * sizeof(int32_t));
|
||||
o_dev.ToDevice(o_host.data());
|
||||
do_dev.ToDevice(do_host.data());
|
||||
d_dev.ToDevice(d_init_host.data());
|
||||
seqstart_dev.ToDevice(seqstart_q_host.data());
|
||||
seqlen_dev.ToDevice(seqlen_q_host.data());
|
||||
|
||||
fmha_bwd_args args{};
|
||||
args.o_ptr = o_dev.GetDeviceBuffer();
|
||||
args.do_ptr = do_dev.GetDeviceBuffer();
|
||||
args.d_ptr = d_dev.GetDeviceBuffer();
|
||||
args.seqstart_q_ptr = seqstart_dev.GetDeviceBuffer();
|
||||
args.seqlen_q_ptr = seqlen_dev.GetDeviceBuffer();
|
||||
args.batch = batch;
|
||||
args.max_seqlen_q = max_phys;
|
||||
args.hdim_v = hdim;
|
||||
args.nhead_q = nhead;
|
||||
args.nhead_k = nhead;
|
||||
args.stride_o = hdim;
|
||||
args.stride_do = hdim;
|
||||
args.nhead_stride_o = max_phys * hdim;
|
||||
args.nhead_stride_do = max_phys * hdim;
|
||||
args.nhead_stride_lsed = max_phys;
|
||||
args.p_undrop = 1.0f;
|
||||
|
||||
using DotTileTraits = ck_tile::TileFmhaBwdOGradDotOTraits<true, true, 2>;
|
||||
using DotProblem = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem<
|
||||
typename TypeConfig::ODataType,
|
||||
typename TypeConfig::OGradDataType,
|
||||
typename TypeConfig::DDataType,
|
||||
64,
|
||||
hdim,
|
||||
true,
|
||||
DotTileTraits>;
|
||||
using DotPipeline = ck_tile::BlockFmhaBwdOGradDotO<DotProblem>;
|
||||
using DotKernel = ck_tile::FmhaBwdOGradDotOKernel<DotPipeline>;
|
||||
|
||||
auto [dot_kargs, dot_grids] = fmha_bwd_dot_do_o_create_kargs_and_grids<DotKernel>(args);
|
||||
const dim3 dot_blocks = DotKernel::BlockSize();
|
||||
constexpr ck_tile::index_t kDotBlockPerCu = DotKernel::kBlockPerCu;
|
||||
auto dot_kernel = ck_tile::make_kernel<kDotBlockPerCu>(DotKernel{}, dot_grids, dot_blocks, 0, dot_kargs);
|
||||
dot_kernel(kStreamConfig);
|
||||
ASSERT_EQ(hipDeviceSynchronize(), hipSuccess);
|
||||
|
||||
auto d_result_host = CopyDeviceToHost<float>(d_dev, total_rows);
|
||||
auto [rtol, atol] = get_elimit<TypeParam>(hdim, hdim);
|
||||
for(size_t i = 0; i < d_result_host.size(); ++i)
|
||||
{
|
||||
SCOPED_TRACE(::testing::Message() << "index=" << i);
|
||||
if(std::fabs(expected[i] - sentinel_ref[i]) < 1e-6f)
|
||||
EXPECT_FLOAT_EQ(d_result_host[i], sentinel_ref[i]);
|
||||
else
|
||||
EXPECT_NEAR(d_result_host[i], expected[i], static_cast<float>(atol));
|
||||
}
|
||||
}
|
||||
|
||||
TYPED_TEST(FmhaBwdKernelPaddingTyped, OGradDotO_VariedPhysical_NoLogicalPtr)
|
||||
{
|
||||
constexpr ck_tile::index_t batch = 3;
|
||||
constexpr ck_tile::index_t nhead = 1;
|
||||
constexpr ck_tile::index_t hdim = 64;
|
||||
constexpr ck_tile::index_t phys_r0 = 5;
|
||||
constexpr ck_tile::index_t phys_r1 = 7;
|
||||
constexpr ck_tile::index_t phys_r2 = 4;
|
||||
constexpr ck_tile::index_t max_phys = phys_r1;
|
||||
|
||||
const std::vector<int32_t> seqstart_q_host{0, phys_r0, phys_r0 + phys_r1, phys_r0 + phys_r1 + phys_r2};
|
||||
const ck_tile::index_t total_rows = seqstart_q_host.back();
|
||||
|
||||
using TypeConfig = FmhaBwdTypeConfig<TypeParam>;
|
||||
using OType = typename TypeConfig::ODataType;
|
||||
using DOType = typename TypeConfig::OGradDataType;
|
||||
using DType = typename TypeConfig::DDataType;
|
||||
|
||||
ck_tile::HostTensor<OType> o_host({1, nhead, total_rows, hdim});
|
||||
ck_tile::HostTensor<DOType> do_host({1, nhead, total_rows, hdim});
|
||||
ck_tile::HostTensor<DType> d_init_host({1, nhead, total_rows});
|
||||
ck_tile::FillConstant<OType>{ck_tile::type_convert<OType>(1.0f)}(o_host);
|
||||
ck_tile::FillConstant<DOType>{ck_tile::type_convert<DOType>(2.0f)}(do_host);
|
||||
ck_tile::FillConstant<DType>{ck_tile::type_convert<DType>(SentinelValue())}(d_init_host);
|
||||
|
||||
std::vector<float> expected(static_cast<size_t>(total_rows), SentinelValue());
|
||||
const float dot = 2.0f * 1.0f * static_cast<float>(hdim);
|
||||
// seqlen_q_ptr is null; logical lengths equal physical lengths per group
|
||||
for(int r = 0; r < phys_r0; ++r) expected[0 + r] = dot;
|
||||
for(int r = 0; r < phys_r1; ++r) expected[phys_r0 + r] = dot;
|
||||
for(int r = 0; r < phys_r2; ++r) expected[phys_r0 + phys_r1 + r] = dot;
|
||||
std::vector<float> sentinel_ref(static_cast<size_t>(total_rows), SentinelValue());
|
||||
|
||||
ck_tile::DeviceMem o_dev(o_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem do_dev(do_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem d_dev(d_init_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem seqstart_dev(seqstart_q_host.size() * sizeof(int32_t));
|
||||
o_dev.ToDevice(o_host.data());
|
||||
do_dev.ToDevice(do_host.data());
|
||||
d_dev.ToDevice(d_init_host.data());
|
||||
seqstart_dev.ToDevice(seqstart_q_host.data());
|
||||
|
||||
fmha_bwd_args args{};
|
||||
args.o_ptr = o_dev.GetDeviceBuffer();
|
||||
args.do_ptr = do_dev.GetDeviceBuffer();
|
||||
args.d_ptr = d_dev.GetDeviceBuffer();
|
||||
args.seqstart_q_ptr = seqstart_dev.GetDeviceBuffer();
|
||||
args.seqlen_q_ptr = nullptr; // no logical len ptr
|
||||
args.batch = batch;
|
||||
args.max_seqlen_q = max_phys;
|
||||
args.hdim_v = hdim;
|
||||
args.nhead_q = nhead;
|
||||
args.nhead_k = nhead;
|
||||
args.stride_o = hdim;
|
||||
args.stride_do = hdim;
|
||||
args.nhead_stride_o = max_phys * hdim;
|
||||
args.nhead_stride_do = max_phys * hdim;
|
||||
args.nhead_stride_lsed = max_phys;
|
||||
args.p_undrop = 1.0f;
|
||||
|
||||
using DotTileTraits = ck_tile::TileFmhaBwdOGradDotOTraits<true, true, 2>;
|
||||
using DotProblem = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem<
|
||||
typename TypeConfig::ODataType,
|
||||
typename TypeConfig::OGradDataType,
|
||||
typename TypeConfig::DDataType,
|
||||
64,
|
||||
hdim,
|
||||
true,
|
||||
DotTileTraits>;
|
||||
using DotPipeline = ck_tile::BlockFmhaBwdOGradDotO<DotProblem>;
|
||||
using DotKernel = ck_tile::FmhaBwdOGradDotOKernel<DotPipeline>;
|
||||
|
||||
auto [dot_kargs, dot_grids] = fmha_bwd_dot_do_o_create_kargs_and_grids<DotKernel>(args);
|
||||
const dim3 dot_blocks = DotKernel::BlockSize();
|
||||
constexpr ck_tile::index_t kDotBlockPerCu = DotKernel::kBlockPerCu;
|
||||
auto dot_kernel = ck_tile::make_kernel<kDotBlockPerCu>(DotKernel{}, dot_grids, dot_blocks, 0, dot_kargs);
|
||||
dot_kernel(kStreamConfig);
|
||||
ASSERT_EQ(hipDeviceSynchronize(), hipSuccess);
|
||||
|
||||
auto d_result_host = CopyDeviceToHost<float>(d_dev, total_rows);
|
||||
auto [rtol, atol] = get_elimit<TypeParam>(hdim, hdim);
|
||||
for(size_t i = 0; i < d_result_host.size(); ++i)
|
||||
{
|
||||
SCOPED_TRACE(::testing::Message() << "index=" << i);
|
||||
if(std::fabs(expected[i] - sentinel_ref[i]) < 1e-6f)
|
||||
EXPECT_FLOAT_EQ(d_result_host[i], sentinel_ref[i]);
|
||||
else
|
||||
EXPECT_NEAR(d_result_host[i], expected[i], static_cast<float>(atol));
|
||||
}
|
||||
}
|
||||
|
||||
TYPED_TEST(FmhaBwdKernelPaddingTyped, ConvertQGrad_GroupPaddingAndZeroLength)
|
||||
{
|
||||
constexpr ck_tile::index_t batch = 3;
|
||||
constexpr ck_tile::index_t nhead = 1;
|
||||
constexpr ck_tile::index_t hdim = 128;
|
||||
|
||||
const std::vector<int32_t> seqstart_q_host{0, 6, 6, 10}; // physical lengths: 6,0,4
|
||||
const std::vector<int32_t> seqlen_q_host{4, 0, 3};
|
||||
const std::vector<int32_t> seqstart_k_host{0, 7, 15, 18};
|
||||
const std::vector<int32_t> seqlen_k_host{5, 8, 3};
|
||||
|
||||
const ck_tile::index_t total_rows_q = seqstart_q_host.back();
|
||||
|
||||
using TypeConfigC = FmhaBwdTypeConfig<TypeParam>;
|
||||
using AccType = typename TypeConfigC::AccDataType; // float
|
||||
using QGradType = typename TypeConfigC::QGradDataType; // bf16
|
||||
|
||||
ck_tile::HostTensor<AccType> dq_acc_host({1, nhead, total_rows_q, hdim});
|
||||
ck_tile::HostTensor<QGradType> dq_host_init({1, nhead, total_rows_q, hdim});
|
||||
|
||||
const float dq_acc_const = 1.25f;
|
||||
ck_tile::FillConstant<AccType>{ck_tile::type_convert<AccType>(dq_acc_const)}(dq_acc_host);
|
||||
ck_tile::FillConstant<QGradType>{ck_tile::type_convert<QGradType>(SentinelValue())}(dq_host_init);
|
||||
|
||||
const float dq_sentinel_val = ck_tile::type_convert<float>(
|
||||
ck_tile::type_convert<QGradType>(SentinelValue()));
|
||||
std::vector<float> dq_sentinel_ref(static_cast<size_t>(total_rows_q * hdim),
|
||||
dq_sentinel_val);
|
||||
std::vector<float> expected = dq_sentinel_ref;
|
||||
for(ck_tile::index_t b = 0; b < batch; ++b)
|
||||
{
|
||||
const ck_tile::index_t q_start = seqstart_q_host[b];
|
||||
const ck_tile::index_t q_len = seqlen_q_host[b];
|
||||
for(ck_tile::index_t row = 0; row < q_len; ++row)
|
||||
{
|
||||
for(ck_tile::index_t c = 0; c < hdim; ++c)
|
||||
{
|
||||
const size_t idx = (q_start + row) * hdim + c;
|
||||
// dq_acc_host is [1, nhead, s, d]
|
||||
expected[idx] = ck_tile::type_convert<float>(dq_acc_host(0, 0, q_start + row, c));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ck_tile::DeviceMem dq_acc_dev(dq_acc_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem dq_dev(dq_host_init.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
|
||||
ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t));
|
||||
ck_tile::DeviceMem seqlen_q_dev(seqlen_q_host.size() * sizeof(int32_t));
|
||||
ck_tile::DeviceMem seqlen_k_dev(seqlen_k_host.size() * sizeof(int32_t));
|
||||
|
||||
dq_acc_dev.ToDevice(dq_acc_host.data());
|
||||
dq_dev.ToDevice(dq_host_init.data());
|
||||
seqstart_q.ToDevice(seqstart_q_host.data());
|
||||
seqstart_k.ToDevice(seqstart_k_host.data());
|
||||
seqlen_q_dev.ToDevice(seqlen_q_host.data());
|
||||
seqlen_k_dev.ToDevice(seqlen_k_host.data());
|
||||
|
||||
fmha_bwd_args args{};
|
||||
args.dq_acc_ptr = dq_acc_dev.GetDeviceBuffer();
|
||||
args.dq_ptr = dq_dev.GetDeviceBuffer();
|
||||
args.seqstart_q_ptr = seqstart_q.GetDeviceBuffer();
|
||||
args.seqstart_k_ptr = seqstart_k.GetDeviceBuffer();
|
||||
args.seqlen_q_ptr = seqlen_q_dev.GetDeviceBuffer();
|
||||
args.seqlen_k_ptr = seqlen_k_dev.GetDeviceBuffer();
|
||||
args.batch = batch;
|
||||
args.nhead_q = nhead;
|
||||
args.nhead_k = nhead;
|
||||
args.hdim_q = hdim;
|
||||
args.hdim_v = hdim;
|
||||
args.max_seqlen_q = 6;
|
||||
args.max_seqlen_k = 8;
|
||||
args.stride_dq_acc = hdim;
|
||||
args.stride_dq = hdim;
|
||||
args.nhead_stride_dq_acc = hdim * args.max_seqlen_q;
|
||||
args.nhead_stride_dq = hdim * args.max_seqlen_q;
|
||||
args.nhead_stride_q = 0;
|
||||
args.nhead_stride_k = 0;
|
||||
args.nhead_stride_v = 0;
|
||||
args.nhead_stride_o = 0;
|
||||
args.batch_stride_dq_acc = 0;
|
||||
args.batch_stride_dq = 0;
|
||||
args.split_stride_dq_acc = args.max_seqlen_q * args.stride_dq_acc;
|
||||
args.window_size_left = -1;
|
||||
args.window_size_right = 0;
|
||||
args.mask_type = static_cast<ck_tile::index_t>(mask_enum::no_mask);
|
||||
args.p_drop = 0.0f;
|
||||
args.p_undrop = 1.0f;
|
||||
args.drop_seed_offset = std::make_pair(uint64_t{0}, uint64_t{0});
|
||||
|
||||
using TypeConfig = FmhaBwdTypeConfig<TypeParam>;
|
||||
using ConvertTileTraits = ck_tile::TileFmhaBwdConvertQGradTraits<true, true, 2>;
|
||||
using ConvertProblem = ck_tile::BlockFmhaBwdConvertQGradPipelineProblem<
|
||||
typename TypeConfig::AccDataType,
|
||||
typename TypeConfig::QGradDataType,
|
||||
256,
|
||||
64,
|
||||
0,
|
||||
hdim,
|
||||
true,
|
||||
false,
|
||||
ConvertTileTraits>;
|
||||
using ConvertPipeline = ck_tile::BlockFmhaBwdConvertQGrad<ConvertProblem>;
|
||||
using ConvertKernel = ck_tile::FmhaBwdConvertQGradKernel<ConvertPipeline>;
|
||||
|
||||
auto [convert_kargs, convert_grids] =
|
||||
fmha_bwd_convert_dq_create_kargs_and_grids<ConvertKernel>(args);
|
||||
const dim3 convert_blocks = ConvertKernel::BlockSize();
|
||||
constexpr ck_tile::index_t kConvertBlockPerCu = ConvertKernel::kBlockPerCu;
|
||||
auto convert_kernel = ck_tile::make_kernel<kConvertBlockPerCu>(
|
||||
ConvertKernel{}, convert_grids, convert_blocks, 0, convert_kargs);
|
||||
convert_kernel(kStreamConfig);
|
||||
ASSERT_EQ(hipDeviceSynchronize(), hipSuccess);
|
||||
|
||||
using QGradOutT = typename TypeConfigC::QGradDataType;
|
||||
auto dq_result_host_t = CopyDeviceToHost<QGradOutT>(dq_dev, total_rows_q * hdim);
|
||||
auto dq_result_host = ToFloatVector(dq_result_host_t);
|
||||
|
||||
auto [rtol_gpad, atol_gpad] = get_elimit<TypeParam>(hdim, hdim);
|
||||
for(size_t i = 0; i < dq_result_host.size(); ++i)
|
||||
{
|
||||
SCOPED_TRACE(::testing::Message() << "index=" << i);
|
||||
if(std::fabs(expected[i] - dq_sentinel_ref[i]) < 1e-6f)
|
||||
{
|
||||
EXPECT_FLOAT_EQ(dq_result_host[i], dq_sentinel_ref[i]);
|
||||
}
|
||||
else
|
||||
{
|
||||
EXPECT_NEAR(dq_result_host[i], expected[i], static_cast<float>(atol_gpad));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TYPED_TEST(FmhaBwdKernelPaddingTyped, ConvertQGrad_DeterministicPaddingUsesLogicalLengths)
|
||||
{
|
||||
constexpr ck_tile::index_t batch = 1;
|
||||
constexpr ck_tile::index_t nhead = 1;
|
||||
constexpr ck_tile::index_t hdim = 128;
|
||||
constexpr ck_tile::index_t phys_rows = 8;
|
||||
constexpr ck_tile::index_t logical_rows = 5;
|
||||
constexpr ck_tile::index_t phys_k = 24;
|
||||
constexpr ck_tile::index_t logical_k = 20;
|
||||
constexpr ck_tile::index_t kN0 = 16;
|
||||
constexpr ck_tile::index_t nsplits = (logical_k + kN0 - 1) / kN0;
|
||||
|
||||
const std::vector<int32_t> seqstart_q_host{0, phys_rows};
|
||||
const std::vector<int32_t> seqlen_q_host{logical_rows};
|
||||
const std::vector<int32_t> seqstart_k_host{0, phys_k};
|
||||
const std::vector<int32_t> seqlen_k_host{logical_k};
|
||||
const ck_tile::index_t total_rows_q = seqstart_q_host.back();
|
||||
|
||||
using TypeConfigD = FmhaBwdTypeConfig<TypeParam>;
|
||||
using AccTypeDet = typename TypeConfigD::AccDataType; // float
|
||||
using QGradTypeD = typename TypeConfigD::QGradDataType; // bf16
|
||||
|
||||
ck_tile::HostTensor<AccTypeDet> dq_acc_host({nsplits, 1, nhead, phys_rows, hdim});
|
||||
dq_acc_host.ForEach([&](auto& self, auto idx) {
|
||||
const float s = static_cast<float>(idx[0]);
|
||||
// Use split-dependent constant to avoid per-element variance and rounding interplay
|
||||
self(idx) = ck_tile::type_convert<AccTypeDet>(1.0f + 0.1f * s);
|
||||
});
|
||||
|
||||
const float dq_sentinel_val_det = ck_tile::type_convert<float>(
|
||||
ck_tile::type_convert<QGradTypeD>(SentinelValue()));
|
||||
std::vector<float> expected(total_rows_q * hdim, dq_sentinel_val_det);
|
||||
// Expected is the sum over splits of the constant (1.0 + 0.1*s)
|
||||
for(ck_tile::index_t row = 0; row < logical_rows; ++row)
|
||||
for(ck_tile::index_t c = 0; c < hdim; ++c)
|
||||
{
|
||||
float acc = 0.f;
|
||||
for(ck_tile::index_t s = 0; s < nsplits; ++s)
|
||||
{
|
||||
acc += (1.0f + 0.1f * static_cast<float>(s));
|
||||
}
|
||||
expected[row * hdim + c] = acc;
|
||||
}
|
||||
|
||||
ck_tile::HostTensor<QGradTypeD> dq_init({1, nhead, total_rows_q, hdim});
|
||||
ck_tile::FillConstant<QGradTypeD>{ck_tile::type_convert<QGradTypeD>(SentinelValue())}(dq_init);
|
||||
|
||||
DeviceMem dq_acc_dev(dq_acc_host.get_element_space_size_in_bytes());
|
||||
DeviceMem dq_dev(dq_init.get_element_space_size_in_bytes());
|
||||
DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
|
||||
DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t));
|
||||
DeviceMem seqlen_q_dev(seqlen_q_host.size() * sizeof(int32_t));
|
||||
DeviceMem seqlen_k_dev(seqlen_k_host.size() * sizeof(int32_t));
|
||||
|
||||
dq_acc_dev.ToDevice(dq_acc_host.data());
|
||||
dq_dev.ToDevice(dq_init.data());
|
||||
seqstart_q.ToDevice(seqstart_q_host.data());
|
||||
seqstart_k.ToDevice(seqstart_k_host.data());
|
||||
seqlen_q_dev.ToDevice(seqlen_q_host.data());
|
||||
seqlen_k_dev.ToDevice(seqlen_k_host.data());
|
||||
|
||||
fmha_bwd_args args{};
|
||||
args.dq_acc_ptr = dq_acc_dev.GetDeviceBuffer();
|
||||
args.dq_ptr = dq_dev.GetDeviceBuffer();
|
||||
args.seqstart_q_ptr = seqstart_q.GetDeviceBuffer();
|
||||
args.seqstart_k_ptr = seqstart_k.GetDeviceBuffer();
|
||||
args.seqlen_q_ptr = seqlen_q_dev.GetDeviceBuffer();
|
||||
args.seqlen_k_ptr = seqlen_k_dev.GetDeviceBuffer();
|
||||
args.batch = batch;
|
||||
args.nhead_q = nhead;
|
||||
args.nhead_k = nhead;
|
||||
args.hdim_q = hdim;
|
||||
args.hdim_v = hdim;
|
||||
args.max_seqlen_q = phys_rows;
|
||||
args.max_seqlen_k = phys_k;
|
||||
args.stride_dq_acc = hdim;
|
||||
args.stride_dq = hdim;
|
||||
args.nhead_stride_dq_acc = phys_rows * hdim;
|
||||
args.nhead_stride_dq = phys_rows * hdim;
|
||||
args.split_stride_dq_acc = phys_rows * hdim;
|
||||
args.window_size_left = -1;
|
||||
args.window_size_right = 0;
|
||||
args.mask_type = static_cast<ck_tile::index_t>(mask_enum::no_mask);
|
||||
args.p_drop = 0.0f;
|
||||
args.p_undrop = 1.0f;
|
||||
args.drop_seed_offset = std::make_pair(uint64_t{0}, uint64_t{0});
|
||||
|
||||
using TypeConfig = FmhaBwdTypeConfig<TypeParam>;
|
||||
using TileTraitsDet = ck_tile::TileFmhaBwdConvertQGradTraits<true, true, 2>;
|
||||
using PipelineProblemDet = ck_tile::BlockFmhaBwdConvertQGradPipelineProblem<
|
||||
typename TypeConfig::AccDataType,
|
||||
typename TypeConfig::QGradDataType,
|
||||
256,
|
||||
64,
|
||||
kN0,
|
||||
hdim,
|
||||
true,
|
||||
true,
|
||||
TileTraitsDet>;
|
||||
using PipelineDet = ck_tile::BlockFmhaBwdConvertQGrad<PipelineProblemDet>;
|
||||
using ConvertKernelDet = ck_tile::FmhaBwdConvertQGradKernel<PipelineDet>;
|
||||
|
||||
auto [convert_kargs, convert_grids] =
|
||||
fmha_bwd_convert_dq_create_kargs_and_grids<ConvertKernelDet>(args);
|
||||
const dim3 convert_blocks = ConvertKernelDet::BlockSize();
|
||||
constexpr ck_tile::index_t kConvertBlockPerCu = ConvertKernelDet::kBlockPerCu;
|
||||
auto convert_kernel = ck_tile::make_kernel<kConvertBlockPerCu>(
|
||||
ConvertKernelDet{}, convert_grids, convert_blocks, 0, convert_kargs);
|
||||
convert_kernel(kStreamConfig);
|
||||
ASSERT_EQ(hipDeviceSynchronize(), hipSuccess);
|
||||
|
||||
using QGradOutTD = typename TypeConfigD::QGradDataType;
|
||||
auto dq_result_host_t = CopyDeviceToHost<QGradOutTD>(dq_dev, total_rows_q * hdim);
|
||||
auto dq_result_host = ToFloatVector(dq_result_host_t);
|
||||
|
||||
const float dq_sentinel_val2 = dq_sentinel_val_det;
|
||||
|
||||
auto [rtol_det, atol_det] = get_elimit<TypeParam>(hdim, hdim);
|
||||
for(size_t i = 0; i < dq_result_host.size(); ++i)
|
||||
{
|
||||
SCOPED_TRACE(::testing::Message() << "index=" << i);
|
||||
if(std::fabs(expected[i] - dq_sentinel_val2) < 1e-6f)
|
||||
{
|
||||
EXPECT_FLOAT_EQ(dq_result_host[i], dq_sentinel_val2);
|
||||
}
|
||||
else
|
||||
{
|
||||
EXPECT_NEAR(dq_result_host[i], expected[i], static_cast<float>(atol_det));
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user