diff --git a/example/ck_tile/01_fmha/example_fmha_bwd.cpp b/example/ck_tile/01_fmha/example_fmha_bwd.cpp index 73b3c1e619..3f8071be32 100644 --- a/example/ck_tile/01_fmha/example_fmha_bwd.cpp +++ b/example/ck_tile/01_fmha/example_fmha_bwd.cpp @@ -24,11 +24,19 @@ auto create_args(int argc, char* argv[]) "total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary\n" "also with \"-s=s0,s1,s2...\" comma-separated ints to set seqlen per batch " "(group mode)") + .insert("s_qpad", + "-1", + "padded seqlen_q per batch (group mode only). " + "Use \"-s_qpad=p0,p1,...\"; -1 disables explicit padding") .insert("s_k", "-1", "seqlen_k, -1 means equal to s\n" "also with \"-s_k=s0,s1,s2...\" comma-separated ints to set seqlen per batch " "(group mode)") + .insert("s_kpad", + "-1", + "padded seqlen_k per batch (group mode only). " + "Use \"-s_kpad=k0,k1,...\"; -1 disables explicit padding") .insert("d", "128", "head dim for q, k") .insert("d_v", "-1", "head dim for v, -1 means equal to d") .insert("scale", "0", "scale factor. 
0 means equal to 1/sqrt(hdim)") @@ -96,7 +104,9 @@ auto run(const ck_tile::ArgParser& arg_parser) ck_tile::index_t nhead = arg_parser.get_int("h"); ck_tile::index_t nhead_k = arg_parser.get_int("h_k"); auto seqlen_qs = arg_parser.get_int_vec("s"); + auto seqlen_qpads = arg_parser.get_int_vec("s_qpad"); auto seqlen_ks = arg_parser.get_int_vec("s_k"); + auto seqlen_kpads = arg_parser.get_int_vec("s_kpad"); ck_tile::index_t hdim_q = arg_parser.get_int("d"); ck_tile::index_t hdim_v = arg_parser.get_int("d_v"); bool i_perm = arg_parser.get_bool("iperm"); @@ -130,6 +140,8 @@ auto run(const ck_tile::ArgParser& arg_parser) nhead_k, seqlen_qs, seqlen_ks, + seqlen_qpads, + seqlen_kpads, hdim_q, hdim_v, i_perm, diff --git a/example/ck_tile/01_fmha/fmha_bwd.hpp b/example/ck_tile/01_fmha/fmha_bwd.hpp index 6cd1cd94fa..eac1840a19 100644 --- a/example/ck_tile/01_fmha/fmha_bwd.hpp +++ b/example/ck_tile/01_fmha/fmha_bwd.hpp @@ -116,6 +116,7 @@ struct fmha_bwd_args void* dq_acc_ptr; const void* seqstart_q_ptr; const void* seqstart_k_ptr; + const void* seqlen_q_ptr; const void* seqlen_k_ptr; ck_tile::index_t seqlen_q; ck_tile::index_t seqlen_k; @@ -203,6 +204,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args) dq_ptr, args.seqstart_q_ptr, args.seqstart_k_ptr, + args.seqlen_q_ptr, args.seqlen_k_ptr, args.hdim_q, args.hdim_v, @@ -315,6 +317,7 @@ auto fmha_bwd_dot_do_o_create_kargs_and_grids(fmha_bwd_args args) args.d_ptr, args.p_undrop, args.seqstart_q_ptr, + args.seqlen_q_ptr, args.hdim_v, args.stride_do, args.stride_o, @@ -356,6 +359,8 @@ auto fmha_bwd_convert_dq_create_kargs_and_grids(fmha_bwd_args args) args.dq_ptr, args.seqstart_q_ptr, args.seqstart_k_ptr, + args.seqlen_q_ptr, + args.seqlen_k_ptr, args.hdim_q, args.stride_dq, args.stride_dq_acc, diff --git a/example/ck_tile/01_fmha/fmha_bwd_runner.hpp b/example/ck_tile/01_fmha/fmha_bwd_runner.hpp index b6f2c8ca30..01513e06ed 100644 --- a/example/ck_tile/01_fmha/fmha_bwd_runner.hpp +++ 
b/example/ck_tile/01_fmha/fmha_bwd_runner.hpp @@ -65,6 +65,8 @@ bwd_result fmha_bwd_run(mode_enum mode, ck_tile::index_t nhead_k, std::vector seqlen_qs, std::vector seqlen_ks, + std::vector seqlen_qpads, + std::vector seqlen_kpads, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, bool i_perm, @@ -119,10 +121,15 @@ bwd_result fmha_bwd_run(mode_enum mode, std::cerr << "dbias only exists when bias type is elementwise" << std::endl; return bwd_result::invalid_args; } - std::vector seqlen_kpads; - std::tie(seqlen_qs, seqlen_ks, seqlen_kpads) = - generate_missing_seqlens(mode, batch, seqlen_qs, seqlen_ks, {}, 0, false, random_engine); - ck_tile::ignore = seqlen_kpads; + + std::tie(seqlen_qs, seqlen_ks, seqlen_qpads, seqlen_kpads) = generate_missing_seqlens( + mode, batch, seqlen_qs, seqlen_ks, seqlen_qpads, seqlen_kpads, 0, false, random_engine); + + bool use_qpadding = + mode == mode_enum::group && (!seqlen_qpads.empty() && seqlen_qpads[0] != -1); + bool use_kpadding = + mode == mode_enum::group && (!seqlen_kpads.empty() && seqlen_kpads[0] != -1); + #if 0 std::cout << "seqlen_qs: " << seqlen_qs << std::endl; std::cout << "seqlen_ks: " << seqlen_ks << std::endl; @@ -146,8 +153,10 @@ bwd_result fmha_bwd_run(mode_enum mode, s_randval = true; } - const auto seqstart_q_host = to_seqstarts(seqlen_qs); - const auto seqstart_k_host = to_seqstarts(seqlen_ks); + const auto seqstart_q_host = + (use_qpadding ? to_seqstarts(seqlen_qpads) : to_seqstarts(seqlen_qs)); + const auto seqstart_k_host = + (use_kpadding ? to_seqstarts(seqlen_kpads) : to_seqstarts(seqlen_ks)); using TypeConfig = FmhaBwdTypeConfig; @@ -336,6 +345,10 @@ bwd_result fmha_bwd_run(mode_enum mode, ck_tile::DeviceMem do_buf(do_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem dbias_buf(dbias_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t)); + ck_tile::DeviceMem seqlen_q_dev(mode == mode_enum::batch ? 
0 + : seqlen_qs.size() * sizeof(int32_t)); + ck_tile::DeviceMem seqlen_k_dev(mode == mode_enum::batch ? 0 + : seqlen_ks.size() * sizeof(int32_t)); ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t)); ck_tile::DeviceMem drop_seed_buf(drop_prefs ? sizeof(uint64_t) : 0); ck_tile::DeviceMem drop_offset_buf(drop_prefs ? sizeof(uint64_t) : 0); @@ -349,6 +362,13 @@ bwd_result fmha_bwd_run(mode_enum mode, do_buf.ToDevice(do_host.data()); seqstart_q.ToDevice(seqstart_q_host.data()); seqstart_k.ToDevice(seqstart_k_host.data()); + if(mode == mode_enum::group) + { + std::vector seqlen_q_host(seqlen_qs.begin(), seqlen_qs.end()); + seqlen_q_dev.ToDevice(seqlen_q_host.data()); + std::vector seqlen_k_host(seqlen_ks.begin(), seqlen_ks.end()); + seqlen_k_dev.ToDevice(seqlen_k_host.data()); + } drop_seed_buf.ToDevice(drop_prefs ? &drop_seed : nullptr); drop_offset_buf.ToDevice(drop_prefs ? &drop_offset : nullptr); alibi_slope_buf.ToDevice(alibi_slope_host.data()); @@ -440,11 +460,14 @@ bwd_result fmha_bwd_run(mode_enum mode, } }(); + const void* seqlen_q_ptr_dev = use_qpadding ? seqlen_q_dev.GetDeviceBuffer() : nullptr; + const void* seqlen_k_ptr_dev = use_kpadding ? seqlen_k_dev.GetDeviceBuffer() : nullptr; + return fmha_bwd_args{q_buf.GetDeviceBuffer(), k_buf.GetDeviceBuffer(), v_buf.GetDeviceBuffer(), bias.type == bias_enum::alibi ? alibi_slope_buf.GetDeviceBuffer() - : bias_buf.GetDeviceBuffer(), + : bias_buf.GetDeviceBuffer(), o_buf.GetDeviceBuffer(), lse_buf.GetDeviceBuffer(), do_buf.GetDeviceBuffer(), @@ -457,7 +480,8 @@ bwd_result fmha_bwd_run(mode_enum mode, dq_acc_buf.GetDeviceBuffer(), seqstart_q.GetDeviceBuffer(), seqstart_k.GetDeviceBuffer(), - nullptr, + seqlen_q_ptr_dev, + seqlen_k_ptr_dev, shape_seqlen_q, shape_seqlen_k, batch, @@ -472,7 +496,7 @@ bwd_result fmha_bwd_run(mode_enum mode, stride_k, stride_v, bias.type == bias_enum::alibi ? (bias.rank_info == 0 ? 
0 : nhead) - : stride_bias, + : stride_bias, stride_o, stride_randval, stride_do, diff --git a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp index 0703af71e3..de9fcc713f 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp @@ -330,11 +330,12 @@ fwd_result fmha_fwd_run(mode_enum mode, return fwd_result::invalid_args; } - std::tie(seqlen_qs, seqlen_ks, seqlen_kpads) = + std::tie(seqlen_qs, seqlen_ks, seqlen_qpads, seqlen_kpads) = generate_missing_seqlens(mode, batch, seqlen_qs, seqlen_ks, + seqlen_qpads, seqlen_kpads, /*seqlen_k_min=*/0 < seqlen_knew ? seqlen_knew : 0, need_append_kvcache, @@ -346,7 +347,13 @@ fwd_result fmha_fwd_run(mode_enum mode, std::cerr << "kpad must be greater than or equal to seqlen for k" << std::endl; return fwd_result::invalid_args; } + if(seqlen_qpads[wb] > 0 && seqlen_qpads[wb] < seqlen_qs[wb]) + { + std::cerr << "qpad must be greater than or equal to seqlen for q" << std::endl; + return fwd_result::invalid_args; + } } + // compute kvcache seqlen_k (before appending knew/vnew) auto cache_seqlen_ks = seqlen_ks; std::transform(cache_seqlen_ks.begin(), @@ -357,6 +364,7 @@ fwd_result fmha_fwd_run(mode_enum mode, #if 0 std::cout << "seqlen_qs: " << seqlen_qs << std::endl; std::cout << "seqlen_ks: " << seqlen_ks << std::endl; + std::cout << "seqlen_qpads: " << seqlen_qpads << std::endl; std::cout << "seqlen_kpads: " << seqlen_kpads << std::endl; std::cout << "cache_seqlen_ks: " << cache_seqlen_ks << std::endl; #endif @@ -514,9 +522,6 @@ fwd_result fmha_fwd_run(mode_enum mode, // host memory for storing all the tensor elements const ck_tile::index_t shape_batch = (mode == mode_enum::batch ? batch : 1); - // logical(unpadded) total seqlen_q for group; batch uses fixed seqlen - const ck_tile::index_t shape_seqlen_q_lse = - (mode == mode_enum::batch ? 
seqlen_qs[0] : seqstart_q_host.back()); // physical(padded) total seqlen_q for group when s_qpad is provided; else use logical const ck_tile::index_t shape_seqlen_q = (mode == mode_enum::batch @@ -580,7 +585,7 @@ fwd_result fmha_fwd_run(mode_enum mode, // batch mode of lse data layout is [batch, nhead, seqlen_q] // group mode of lse data layout is [nhead, total_seqlen_q] ck_tile::HostTensor lse_host( - lse ? std::array{shape_batch, nhead, shape_seqlen_q_lse} + lse ? std::array{shape_batch, nhead, shape_seqlen_q} : std::array{1, 1, 1} /* dummy shape for simplifying code */); ck_tile::HostTensor o_host( @@ -970,8 +975,8 @@ fwd_result fmha_fwd_run(mode_enum mode, const ck_tile::index_t nhead_stride_bias = (i_perm ? 0 * shape_seqlen_q * max_seqlen_k : 0 * max_seqlen_k); const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k); - const ck_tile::index_t nhead_stride_lse = shape_seqlen_q_lse; - const ck_tile::index_t nhead_stride_lse_acc = (num_splits * shape_seqlen_q_lse); + const ck_tile::index_t nhead_stride_lse = shape_seqlen_q; + const ck_tile::index_t nhead_stride_lse_acc = (num_splits * shape_seqlen_q); const ck_tile::index_t nhead_stride_o_acc = (num_splits * shape_seqlen_q * hdim_v); const ck_tile::index_t nhead_stride_o = (o_perm ? 
shape_seqlen_q * hdim_v : hdim_v); // setup batch_stride_* arguments @@ -986,8 +991,8 @@ fwd_result fmha_fwd_run(mode_enum mode, const ck_tile::index_t batch_stride_vnew = (nhead_k * hdim_v * seqlen_knew); const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * max_seqlen_k); const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k); - const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q_lse); - const ck_tile::index_t batch_stride_lse_acc = (nhead * num_splits * shape_seqlen_q_lse); + const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q); + const ck_tile::index_t batch_stride_lse_acc = (nhead * num_splits * shape_seqlen_q); const ck_tile::index_t batch_stride_o_acc = (nhead * num_splits * shape_seqlen_q * hdim_v); const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); const ck_tile::index_t batch_stride_block_table = (max_num_page_blocks / batch); @@ -1727,12 +1732,11 @@ fwd_result fmha_fwd_run(mode_enum mode, if(lse) { ck_tile::HostTensor lse_host_result({nhead, real_seqlen_q}); - const ck_tile::index_t query_offset_lse = - (mode == mode_enum::batch ? 
0 : seqstart_q_host[wb]); lse_host_result.ForEach([&](auto& self, auto idx) { - self(idx) = lse_host(b_idx, idx[0], idx[1] + query_offset_lse); + self(idx) = lse_host(b_idx, idx[0], idx[1] + query_offset); }); - + std::cout << "lse_host_result shape: " << shape_batch << ", " << nhead << ", " << shape_seqlen_q << std::endl; + cur_pass = ck_tile::check_err(lse_host_result, lse_host_ref, "LSE Error: Incorrect results!", diff --git a/example/ck_tile/01_fmha/utils.hpp b/example/ck_tile/01_fmha/utils.hpp index 7f44d87180..0303ded238 100644 --- a/example/ck_tile/01_fmha/utils.hpp +++ b/example/ck_tile/01_fmha/utils.hpp @@ -142,12 +142,14 @@ auto randints(ForwardIterator first, */ template std::tuple, + std::vector, std::vector, std::vector> generate_missing_seqlens(mode_enum mode, ck_tile::index_t batch, const std::vector& q_val, const std::vector& k_val, + const std::vector& q_pad_val, const std::vector& k_pad_val, ck_tile::index_t seqlen_k_min, bool need_append_kvcache, @@ -177,7 +179,7 @@ generate_missing_seqlens(mode_enum mode, return seqlen_ks; }(); auto s_kpad = std::vector(batch, -1); // TODO: batch not support k_padding - + auto s_qpad = std::vector(batch, -1); // s_k should be greater than or equal to seqlen_k_min if provided if(s_k.back() < seqlen_k_min) { @@ -187,13 +189,14 @@ generate_missing_seqlens(mode_enum mode, throw std::runtime_error(msg.str()); } - return std::make_tuple(s_q, s_k, s_kpad); + return std::make_tuple(s_q, s_k, s_qpad, s_kpad); } else { std::vector s_q; std::vector s_k; std::vector s_kpad; + std::vector s_qpad; ck_tile::index_t idx = 0; for(; idx < std::min(static_cast(q_val.size()), batch); ++idx) { @@ -205,9 +208,15 @@ generate_missing_seqlens(mode_enum mode, ? -1 : k_pad_val[std::min(idx, static_cast(k_pad_val.size()) - 1)]; + ck_tile::index_t qp = + q_pad_val.empty() + ? -1 + : q_pad_val[std::min(idx, static_cast(q_pad_val.size()) - 1)]; + s_q.push_back(q); s_k.push_back(k < 0 ? 
q : k); s_kpad.push_back(kp); + s_qpad.push_back(qp); // s_k should be greater than or equal to seqlen_k_min if(s_k.back() < seqlen_k_min) @@ -228,8 +237,9 @@ generate_missing_seqlens(mode_enum mode, s_q.insert(s_q.end(), rem_q.begin(), rem_q.end()); s_k.insert(s_k.end(), rem_k.begin(), rem_k.end()); s_kpad.insert(s_kpad.end(), batch - idx, s_kpad.back()); + s_qpad.insert(s_qpad.end(), batch - idx, s_qpad.back()); } - return std::make_tuple(s_q, s_k, s_kpad); + return std::make_tuple(s_q, s_k, s_qpad, s_kpad); } } diff --git a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp index 980dfb06ae..e6cd9c0b7b 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp @@ -313,6 +313,7 @@ struct FmhaBwdDQDKDVKernel { const int32_t* seqstart_q_ptr; const int32_t* seqstart_k_ptr; + const int32_t* seqlen_q_ptr; const int32_t* seqlen_k_ptr; }; @@ -520,6 +521,7 @@ struct FmhaBwdDQDKDVKernel void* dq_acc_ptr, const void* seqstart_q_ptr, const void* seqstart_k_ptr, + const void* seqlen_q_ptr, const void* seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, @@ -594,6 +596,7 @@ struct FmhaBwdDQDKDVKernel {}, // placeholder for deterministic reinterpret_cast(seqstart_q_ptr), reinterpret_cast(seqstart_k_ptr), + reinterpret_cast(seqlen_q_ptr), reinterpret_cast(seqlen_k_ptr)}; if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) @@ -738,7 +741,10 @@ struct FmhaBwdDQDKDVKernel // get real # queries & # keys under group mode const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; - kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0]; + const ck_tile::index_t physical_seqlen_q = + adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0]; + kargs.seqlen_q = kargs.seqlen_q_ptr ? 
kargs.seqlen_q_ptr[i_batch] : physical_seqlen_q; + if(kargs.seqlen_k_ptr != nullptr) { kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch]; @@ -749,6 +755,12 @@ struct FmhaBwdDQDKDVKernel kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0]; } + // skip if logical lengths are zero + if(kargs.seqlen_q == 0 || kargs.seqlen_k == 0) + { + return; + } + // # of required blocks is different in each groups, terminate unnecessary blocks // earlier if constexpr(!kUseQrQtrDorPipeline) @@ -1246,6 +1258,7 @@ struct FmhaBwdOGradDotOKernel struct FmhaBwdOGradDotOGroupModeKargs : FmhaBwdOGradDotOCommonKargs { const int32_t* seqstart_q_ptr; + const int32_t* seqlen_q_ptr; }; using Kargs = std:: @@ -1293,6 +1306,7 @@ struct FmhaBwdOGradDotOKernel void* d_ptr, float p_undrop, const void* seqstart_q_ptr, + const void* seqlen_q_ptr, ck_tile::index_t hdim_v, ck_tile::index_t stride_do, ck_tile::index_t stride_o, @@ -1311,7 +1325,8 @@ struct FmhaBwdOGradDotOKernel nhead_stride_do, nhead_stride_o, nhead_stride_d}, - reinterpret_cast(seqstart_q_ptr)}; + reinterpret_cast(seqstart_q_ptr), + reinterpret_cast(seqlen_q_ptr)}; return kargs; } @@ -1357,7 +1372,12 @@ struct FmhaBwdOGradDotOKernel // get real # queries & # keys under group mode const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; - kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0]; + const ck_tile::index_t physical_seqlen_q = + adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0]; + const ck_tile::index_t logical_seqlen_q = + kargs.seqlen_q_ptr ? 
static_cast(kargs.seqlen_q_ptr[i_batch]) + : physical_seqlen_q; + kargs.seqlen_q = logical_seqlen_q; // # of required blocks is different in each groups, terminate unnecessary blocks // earlier if(kargs.seqlen_q <= i_m0) @@ -1521,6 +1541,8 @@ struct FmhaBwdConvertQGradKernel { const int32_t* seqstart_q_ptr; const int32_t* seqstart_k_ptr; + const int32_t* seqlen_q_ptr; + const int32_t* seqlen_k_ptr; }; using Kargs = std::conditional_t(seqstart_q_ptr), - reinterpret_cast(seqstart_k_ptr)}; + reinterpret_cast(seqstart_k_ptr), + reinterpret_cast(seqlen_q_ptr), + reinterpret_cast(seqlen_k_ptr)}; if constexpr(kIsDeterministic) { @@ -1634,11 +1660,20 @@ struct FmhaBwdConvertQGradKernel // get real # queries & # keys under group mode const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; - kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0]; + const ck_tile::index_t physical_seqlen_q = + adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0]; + const ck_tile::index_t logical_seqlen_q = + kargs.seqlen_q_ptr ? static_cast(kargs.seqlen_q_ptr[i_batch]) + : physical_seqlen_q; + kargs.seqlen_q = logical_seqlen_q; if constexpr(kIsDeterministic) { const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch; - kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0]; + const ck_tile::index_t physical_seqlen_k = + adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0]; + kargs.seqlen_k = kargs.seqlen_k_ptr + ? 
static_cast(kargs.seqlen_k_ptr[i_batch]) + : physical_seqlen_k; } // # of required blocks is different in each groups, terminate unnecessary blocks // earlier diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index dafe99febe..1b2554d0a2 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -1137,8 +1137,8 @@ struct FmhaFwdKernel } if constexpr(kStoreLSE) { - // LSE stays indexed by unpadded starts - batch_offset_lse = query_start_unpadded; + // LSE follows the padded layout to stay consistent with other tensors + batch_offset_lse = query_start_padded; } if constexpr(kHasDropout) { @@ -1630,8 +1630,8 @@ struct FmhaFwdKernel batch_offset_bias = query_start_padded * kargs.stride_bias; } - // LSE layout is [nhead, total_seqlen], index by unpadded start - batch_offset_lse = query_start_unpadded; + // LSE layout is [nhead, total_seqlen] following the padded layout for Q/O + batch_offset_lse = query_start_padded; batch_offset_o = query_start_padded * kargs.stride_o; // get real # queries & # keys under group mode diff --git a/test/ck_tile/fmha/CMakeLists.txt b/test/ck_tile/fmha/CMakeLists.txt index ca7b7b6324..bbd9fc3d14 100644 --- a/test/ck_tile/fmha/CMakeLists.txt +++ b/test/ck_tile/fmha/CMakeLists.txt @@ -5,6 +5,11 @@ endif() set(FMHA_BWD_INSTANCES "tile_fmha_bwd_instances") set(FMHA_FWD_INSTANCES "tile_fmha_fwd_instances") + +add_gtest_executable(test_ck_tile_fmha_bwd_kernels test_fmha_bwd_kernel_padding.cpp) +target_link_libraries(test_ck_tile_fmha_bwd_kernels PRIVATE ${FMHA_BWD_INSTANCES}) + + set(TEST_NAME "test_ck_tile_fmha") function(add_gtest_fwd test_group) diff --git a/test/ck_tile/fmha/test_fmha_bwd.cpp b/test/ck_tile/fmha/test_fmha_bwd.cpp index 1279b98383..710069febe 100644 --- a/test/ck_tile/fmha/test_fmha_bwd.cpp +++ b/test/ck_tile/fmha/test_fmha_bwd.cpp @@ -77,6 +77,8 @@ void fmha_bwd_test(const FmhaBwdTestParam& 
param) nhead_k, {seqlen_q}, {seqlen_k}, + {-1}, + {-1}, hdim_q, hdim_v, i_perm, diff --git a/test/ck_tile/fmha/test_fmha_bwd_kernel_padding.cpp b/test/ck_tile/fmha/test_fmha_bwd_kernel_padding.cpp new file mode 100644 index 0000000000..fa3704a2a0 --- /dev/null +++ b/test/ck_tile/fmha/test_fmha_bwd_kernel_padding.cpp @@ -0,0 +1,707 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include "ck_tile/host.hpp" +#include "ck_tile/host/device_memory.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "example/ck_tile/01_fmha/fmha_bwd.hpp" +#include "example/ck_tile/01_fmha/fmha_bwd_runner.hpp" // for get_elimit +#include "gtest/gtest.h" + +namespace { + +using bf16 = ck_tile::bf16_t; +using ck_tile::DeviceMem; + +const ck_tile::stream_config kStreamConfig{ + nullptr, // stream_id_ + false, // time_kernel_ + 1, // log_level_ + 0, // cold_niters_ + 1, // nrepeat_ + true, // is_gpu_timer_ + false, // flush_cache_ + 1, // rotating_count_ +}; + +template +std::vector MakeVectorFromFunction(size_t count, std::function fn) +{ + std::vector data(count); + for(size_t i = 0; i < count; ++i) + { + data[i] = static_cast(fn(i)); + } + return data; +} + +template +std::vector ToFloatVector(const std::vector& src) +{ + std::vector dst(src.size()); + for(size_t i = 0; i < src.size(); ++i) + { + dst[i] = ck_tile::type_convert(src[i]); + } + return dst; +} + +template +std::vector CopyDeviceToHost(const ck_tile::DeviceMem& dev, size_t element_count) +{ + std::vector host(element_count); + if(element_count > 0) + { + dev.FromDevice(host.data()); + } + return host; +} + +float SentinelValue() { return -999.f; } + +} // namespace + +// Typed tests over {fp32, fp16, bf16} +template +class FmhaBwdKernelPaddingTyped : public ::testing::Test +{ +}; + +using KernelPaddingTypes = ::testing::Types; +TYPED_TEST_SUITE(FmhaBwdKernelPaddingTyped, KernelPaddingTypes); + 
+TYPED_TEST(FmhaBwdKernelPaddingTyped, OGradDotO_GroupPaddingRespectsLogicalLengths) +{ + constexpr ck_tile::index_t batch = 2; + constexpr ck_tile::index_t nhead = 1; + constexpr ck_tile::index_t hdim = 128; + constexpr ck_tile::index_t phys_rows0 = 8; + constexpr ck_tile::index_t phys_rows1 = 8; + constexpr ck_tile::index_t max_phys = phys_rows0; // both batches equal + + const std::vector seqstart_q_host{0, phys_rows0, phys_rows0 + phys_rows1}; + const std::vector seqlen_q_host{5, 3}; + + const ck_tile::index_t total_rows = seqstart_q_host.back(); + + // Types per config + using TypeConfig = FmhaBwdTypeConfig; + using OType = typename TypeConfig::ODataType; + using DOType = typename TypeConfig::OGradDataType; + using DType = typename TypeConfig::DDataType; // float under bf16 config + using AccType = typename TypeConfig::AccDataType; // float + + // Host tensors laid out as [b, h, s, d] with b=1, h=1 under group mode + ck_tile::HostTensor o_host({1, nhead, total_rows, hdim}); + ck_tile::HostTensor do_host({1, nhead, total_rows, hdim}); + ck_tile::HostTensor d_init_host({1, nhead, total_rows}); + + // Initialize O/dO with constants using FillConstant (no manual bf16 casts) + const float o_const = 0.25f; + const float do_const = 0.5f; + ck_tile::FillConstant{ck_tile::type_convert(o_const)}(o_host); + ck_tile::FillConstant{ck_tile::type_convert(do_const)}(do_host); + ck_tile::FillConstant{ck_tile::type_convert(SentinelValue())}(d_init_host); + + // Prepare expected D via runner-style CPU reference, sentinel elsewhere + std::vector expected(static_cast(total_rows), SentinelValue()); + for(ck_tile::index_t b = 0; b < batch; ++b) + { + const ck_tile::index_t start = seqstart_q_host[b]; + const ck_tile::index_t len = seqlen_q_host[b]; + for(ck_tile::index_t row = 0; row < len; ++row) + { + AccType acc = 0; + for(ck_tile::index_t c = 0; c < hdim; ++c) + { + // o_host/do_host are [1, nhead, s, d] + const auto o_val = ck_tile::type_convert(o_host(0, 0, start + row, c)); + 
const auto do_val = ck_tile::type_convert(do_host(0, 0, start + row, c)); + acc += do_val * o_val; + } + expected[start + row] = ck_tile::type_convert(acc); + } + } + std::vector sentinel_ref(static_cast(total_rows), SentinelValue()); + + // Device buffers + ck_tile::DeviceMem o_dev(o_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem do_dev(do_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem d_dev(d_init_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem seqstart_dev(seqstart_q_host.size() * sizeof(int32_t)); + ck_tile::DeviceMem seqlen_dev(seqlen_q_host.size() * sizeof(int32_t)); + + o_dev.ToDevice(o_host.data()); + do_dev.ToDevice(do_host.data()); + d_dev.ToDevice(d_init_host.data()); + seqstart_dev.ToDevice(seqstart_q_host.data()); + seqlen_dev.ToDevice(seqlen_q_host.data()); + + fmha_bwd_args args{}; + args.q_ptr = nullptr; + args.k_ptr = nullptr; + args.v_ptr = nullptr; + args.bias_ptr = nullptr; + args.o_ptr = o_dev.GetDeviceBuffer(); + args.lse_ptr = nullptr; + args.do_ptr = do_dev.GetDeviceBuffer(); + args.d_ptr = d_dev.GetDeviceBuffer(); + args.rand_val_ptr = nullptr; + args.dq_ptr = nullptr; + args.dk_ptr = nullptr; + args.dv_ptr = nullptr; + args.dbias_ptr = nullptr; + args.dq_acc_ptr = nullptr; + args.seqstart_q_ptr = seqstart_dev.GetDeviceBuffer(); + args.seqstart_k_ptr = nullptr; + args.seqlen_k_ptr = nullptr; + args.seqlen_q_ptr = seqlen_dev.GetDeviceBuffer(); + args.seqlen_q = 0; + args.seqlen_k = 0; + args.batch = batch; + args.max_seqlen_q = max_phys; + args.max_seqlen_k = 0; + args.hdim_q = hdim; + args.hdim_v = hdim; + args.nhead_q = nhead; + args.nhead_k = nhead; + args.scale = 1.0f; + args.stride_q = 0; + args.stride_k = 0; + args.stride_v = 0; + args.stride_bias = 0; + args.stride_o = hdim; + args.stride_randval = 0; + args.stride_do = hdim; + args.stride_dq_acc = 0; + args.stride_dq = 0; + args.stride_dk = 0; + args.stride_dv = 0; + args.stride_dbias = 0; + args.nhead_stride_q = 0; + args.nhead_stride_k = 
0; + args.nhead_stride_v = 0; + args.nhead_stride_bias = 0; + args.nhead_stride_o = max_phys * hdim; + args.nhead_stride_randval = 0; + args.nhead_stride_do = max_phys * hdim; + args.nhead_stride_lsed = max_phys; + args.nhead_stride_dq_acc = 0; + args.nhead_stride_dq = 0; + args.nhead_stride_dk = 0; + args.nhead_stride_dv = 0; + args.nhead_stride_dbias = 0; + args.batch_stride_q = 0; + args.batch_stride_k = 0; + args.batch_stride_v = 0; + args.batch_stride_bias = 0; + args.batch_stride_o = 0; + args.batch_stride_randval = 0; + args.batch_stride_do = 0; + args.batch_stride_lsed = 0; + args.batch_stride_dq_acc = 0; + args.batch_stride_dq = 0; + args.batch_stride_dk = 0; + args.batch_stride_dv = 0; + args.batch_stride_dbias = 0; + args.split_stride_dq_acc = 0; + args.window_size_left = -1; + args.window_size_right = 0; + args.mask_type = static_cast(mask_enum::no_mask); + args.p_drop = 0.0f; + args.p_undrop = 1.0f; + args.drop_seed_offset = std::make_pair(uint64_t{0}, uint64_t{0}); + + using DotTileTraits = ck_tile::TileFmhaBwdOGradDotOTraits; + using DotProblem = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename TypeConfig::ODataType, + typename TypeConfig::OGradDataType, + typename TypeConfig::DDataType, + 64, + hdim, + true, + DotTileTraits>; + using DotPipeline = ck_tile::BlockFmhaBwdOGradDotO; + using DotKernel = ck_tile::FmhaBwdOGradDotOKernel; + + auto [dot_kargs, dot_grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(args); + const dim3 dot_blocks = DotKernel::BlockSize(); + constexpr ck_tile::index_t kDotBlockPerCu = DotKernel::kBlockPerCu; + auto dot_kernel = ck_tile::make_kernel( + DotKernel{}, dot_grids, dot_blocks, 0, dot_kargs); + dot_kernel(kStreamConfig); + ASSERT_EQ(hipDeviceSynchronize(), hipSuccess); + + auto d_result_host = CopyDeviceToHost(d_dev, total_rows); + + auto [rtol_doto, atol_doto] = get_elimit(hdim, hdim); + for(size_t i = 0; i < d_result_host.size(); ++i) + { + SCOPED_TRACE(::testing::Message() << "index=" << i); + 
if(std::fabs(expected[i] - sentinel_ref[i]) < 1e-6f) + { + EXPECT_FLOAT_EQ(d_result_host[i], sentinel_ref[i]); + } + else + { + EXPECT_NEAR(d_result_host[i], expected[i], static_cast(atol_doto)); + } + } +} + +TYPED_TEST(FmhaBwdKernelPaddingTyped, OGradDotO_VariedPhysicalAndZeroLogical) +{ + constexpr ck_tile::index_t batch = 3; + constexpr ck_tile::index_t nhead = 1; + constexpr ck_tile::index_t hdim = 64; + constexpr ck_tile::index_t phys_r0 = 5; + constexpr ck_tile::index_t phys_r1 = 7; + constexpr ck_tile::index_t phys_r2 = 4; + constexpr ck_tile::index_t max_phys = phys_r1; + + const std::vector seqstart_q_host{0, phys_r0, phys_r0 + phys_r1, phys_r0 + phys_r1 + phys_r2}; + const std::vector seqlen_q_host{3, 0, 4}; + const ck_tile::index_t total_rows = seqstart_q_host.back(); + + using TypeConfig = FmhaBwdTypeConfig; + using OType = typename TypeConfig::ODataType; + using DOType = typename TypeConfig::OGradDataType; + using DType = typename TypeConfig::DDataType; + + ck_tile::HostTensor o_host({1, nhead, total_rows, hdim}); + ck_tile::HostTensor do_host({1, nhead, total_rows, hdim}); + ck_tile::HostTensor d_init_host({1, nhead, total_rows}); + ck_tile::FillConstant{ck_tile::type_convert(1.0f)}(o_host); + ck_tile::FillConstant{ck_tile::type_convert(2.0f)}(do_host); + ck_tile::FillConstant{ck_tile::type_convert(SentinelValue())}(d_init_host); + + std::vector expected(static_cast(total_rows), SentinelValue()); + const float dot = 2.0f * 1.0f * static_cast(hdim); + for(ck_tile::index_t b = 0; b < batch; ++b) + { + const ck_tile::index_t start = seqstart_q_host[b]; + const ck_tile::index_t len = seqlen_q_host[b]; + for(ck_tile::index_t row = 0; row < len; ++row) expected[start + row] = dot; + } + std::vector sentinel_ref(static_cast(total_rows), SentinelValue()); + + ck_tile::DeviceMem o_dev(o_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem do_dev(do_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem 
d_dev(d_init_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem seqstart_dev(seqstart_q_host.size() * sizeof(int32_t)); + ck_tile::DeviceMem seqlen_dev(seqlen_q_host.size() * sizeof(int32_t)); + o_dev.ToDevice(o_host.data()); + do_dev.ToDevice(do_host.data()); + d_dev.ToDevice(d_init_host.data()); + seqstart_dev.ToDevice(seqstart_q_host.data()); + seqlen_dev.ToDevice(seqlen_q_host.data()); + + fmha_bwd_args args{}; + args.o_ptr = o_dev.GetDeviceBuffer(); + args.do_ptr = do_dev.GetDeviceBuffer(); + args.d_ptr = d_dev.GetDeviceBuffer(); + args.seqstart_q_ptr = seqstart_dev.GetDeviceBuffer(); + args.seqlen_q_ptr = seqlen_dev.GetDeviceBuffer(); + args.batch = batch; + args.max_seqlen_q = max_phys; + args.hdim_v = hdim; + args.nhead_q = nhead; + args.nhead_k = nhead; + args.stride_o = hdim; + args.stride_do = hdim; + args.nhead_stride_o = max_phys * hdim; + args.nhead_stride_do = max_phys * hdim; + args.nhead_stride_lsed = max_phys; + args.p_undrop = 1.0f; + + using DotTileTraits = ck_tile::TileFmhaBwdOGradDotOTraits; + using DotProblem = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename TypeConfig::ODataType, + typename TypeConfig::OGradDataType, + typename TypeConfig::DDataType, + 64, + hdim, + true, + DotTileTraits>; + using DotPipeline = ck_tile::BlockFmhaBwdOGradDotO; + using DotKernel = ck_tile::FmhaBwdOGradDotOKernel; + + auto [dot_kargs, dot_grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(args); + const dim3 dot_blocks = DotKernel::BlockSize(); + constexpr ck_tile::index_t kDotBlockPerCu = DotKernel::kBlockPerCu; + auto dot_kernel = ck_tile::make_kernel(DotKernel{}, dot_grids, dot_blocks, 0, dot_kargs); + dot_kernel(kStreamConfig); + ASSERT_EQ(hipDeviceSynchronize(), hipSuccess); + + auto d_result_host = CopyDeviceToHost(d_dev, total_rows); + auto [rtol, atol] = get_elimit(hdim, hdim); + for(size_t i = 0; i < d_result_host.size(); ++i) + { + SCOPED_TRACE(::testing::Message() << "index=" << i); + if(std::fabs(expected[i] - sentinel_ref[i]) < 
1e-6f) + EXPECT_FLOAT_EQ(d_result_host[i], sentinel_ref[i]); + else + EXPECT_NEAR(d_result_host[i], expected[i], static_cast(atol)); + } +} + /* OGradDotO with packed variable-length groups and no logical-length pointer: three groups of physical length 5, 7 and 4 are laid out back to back (seqstart_q_host = {0, 5, 12, 16}) and args.seqlen_q_ptr is nullptr. O is filled with 1.0 and dO with 2.0, so the test expects each of the 16 rows of D to be close to 2 * 1 * hdim = 128 (EXPECT_NEAR with atol). */ +TYPED_TEST(FmhaBwdKernelPaddingTyped, OGradDotO_VariedPhysical_NoLogicalPtr) +{ + constexpr ck_tile::index_t batch = 3; + constexpr ck_tile::index_t nhead = 1; + constexpr ck_tile::index_t hdim = 64; + constexpr ck_tile::index_t phys_r0 = 5; + constexpr ck_tile::index_t phys_r1 = 7; + constexpr ck_tile::index_t phys_r2 = 4; + constexpr ck_tile::index_t max_phys = phys_r1; + /* max_phys is the largest of the three physical lengths (7); total_rows below is their sum (16). */ + const std::vector seqstart_q_host{0, phys_r0, phys_r0 + phys_r1, phys_r0 + phys_r1 + phys_r2}; + const ck_tile::index_t total_rows = seqstart_q_host.back(); + + using TypeConfig = FmhaBwdTypeConfig; + using OType = typename TypeConfig::ODataType; + using DOType = typename TypeConfig::OGradDataType; + using DType = typename TypeConfig::DDataType; + + ck_tile::HostTensor o_host({1, nhead, total_rows, hdim}); + ck_tile::HostTensor do_host({1, nhead, total_rows, hdim}); + ck_tile::HostTensor d_init_host({1, nhead, total_rows}); + ck_tile::FillConstant{ck_tile::type_convert(1.0f)}(o_host); + ck_tile::FillConstant{ck_tile::type_convert(2.0f)}(do_host); + ck_tile::FillConstant{ck_tile::type_convert(SentinelValue())}(d_init_host); + + std::vector expected(static_cast(total_rows), SentinelValue()); + const float dot = 2.0f * 1.0f * static_cast(hdim); + // seqlen_q_ptr is null; logical lengths equal physical lengths per group + for(int r = 0; r < phys_r0; ++r) expected[0 + r] = dot; + for(int r = 0; r < phys_r1; ++r) expected[phys_r0 + r] = dot; + for(int r = 0; r < phys_r2; ++r) expected[phys_r0 + phys_r1 + r] = dot; + /* The three loops above overwrite all total_rows entries of expected with dot, so the sentinel branch of the final check can only trigger if SentinelValue() lies within 1e-6 of dot. */ std::vector sentinel_ref(static_cast(total_rows), SentinelValue()); + + ck_tile::DeviceMem o_dev(o_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem do_dev(do_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem d_dev(d_init_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem seqstart_dev(seqstart_q_host.size() * 
sizeof(int32_t)); + o_dev.ToDevice(o_host.data()); + do_dev.ToDevice(do_host.data()); + d_dev.ToDevice(d_init_host.data()); + seqstart_dev.ToDevice(seqstart_q_host.data()); + /* Only o, do, d and seqstart_q are uploaded; no seqlen buffer is allocated in this test. */ + fmha_bwd_args args{}; + args.o_ptr = o_dev.GetDeviceBuffer(); + args.do_ptr = do_dev.GetDeviceBuffer(); + args.d_ptr = d_dev.GetDeviceBuffer(); + args.seqstart_q_ptr = seqstart_dev.GetDeviceBuffer(); + args.seqlen_q_ptr = nullptr; // no logical len ptr + args.batch = batch; + args.max_seqlen_q = max_phys; + args.hdim_v = hdim; + args.nhead_q = nhead; + args.nhead_k = nhead; + args.stride_o = hdim; + args.stride_do = hdim; + /* NOTE(review): the per-head strides below are derived from max_phys (7 rows), while the host tensors hold total_rows (16) rows per head; presumably harmless only because nhead == 1 - confirm before raising nhead. */ args.nhead_stride_o = max_phys * hdim; + args.nhead_stride_do = max_phys * hdim; + args.nhead_stride_lsed = max_phys; + args.p_undrop = 1.0f; + + using DotTileTraits = ck_tile::TileFmhaBwdOGradDotOTraits; + using DotProblem = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename TypeConfig::ODataType, + typename TypeConfig::OGradDataType, + typename TypeConfig::DDataType, + 64, + hdim, + true, + DotTileTraits>; + using DotPipeline = ck_tile::BlockFmhaBwdOGradDotO; + using DotKernel = ck_tile::FmhaBwdOGradDotOKernel; + /* Build kernel arguments and grid from args, launch, then block on hipDeviceSynchronize before reading D back. */ + auto [dot_kargs, dot_grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(args); + const dim3 dot_blocks = DotKernel::BlockSize(); + constexpr ck_tile::index_t kDotBlockPerCu = DotKernel::kBlockPerCu; + auto dot_kernel = ck_tile::make_kernel(DotKernel{}, dot_grids, dot_blocks, 0, dot_kargs); + dot_kernel(kStreamConfig); + ASSERT_EQ(hipDeviceSynchronize(), hipSuccess); + + auto d_result_host = CopyDeviceToHost(d_dev, total_rows); + /* Only atol is read by the comparison loop below; rtol is bound but not used in this test. */ auto [rtol, atol] = get_elimit(hdim, hdim); + for(size_t i = 0; i < d_result_host.size(); ++i) + { + SCOPED_TRACE(::testing::Message() << "index=" << i); + if(std::fabs(expected[i] - sentinel_ref[i]) < 1e-6f) + EXPECT_FLOAT_EQ(d_result_host[i], sentinel_ref[i]); + else + EXPECT_NEAR(d_result_host[i], expected[i], static_cast(atol)); + } +} + +TYPED_TEST(FmhaBwdKernelPaddingTyped, ConvertQGrad_GroupPaddingAndZeroLength) +{ + 
/* Body of ConvertQGrad_GroupPaddingAndZeroLength: runs the ConvertQGrad kernel over three packed groups whose physical q lengths are 6, 0 and 4 (seqstart_q_host = {0, 6, 6, 10}) and whose logical lengths are {4, 0, 3}. dq is pre-filled with the sentinel and dq_acc with 1.25; the test expects only the first seqlen_q_host[b] rows of each group to receive the converted dq_acc value and every other row to keep the sentinel. */ constexpr ck_tile::index_t batch = 3; + constexpr ck_tile::index_t nhead = 1; + constexpr ck_tile::index_t hdim = 128; + + const std::vector seqstart_q_host{0, 6, 6, 10}; // physical lengths: 6,0,4 + const std::vector seqlen_q_host{4, 0, 3}; + const std::vector seqstart_k_host{0, 7, 15, 18}; + const std::vector seqlen_k_host{5, 8, 3}; + + const ck_tile::index_t total_rows_q = seqstart_q_host.back(); + + using TypeConfigC = FmhaBwdTypeConfig; + using AccType = typename TypeConfigC::AccDataType; // float + using QGradType = typename TypeConfigC::QGradDataType; // bf16 + + ck_tile::HostTensor dq_acc_host({1, nhead, total_rows_q, hdim}); + ck_tile::HostTensor dq_host_init({1, nhead, total_rows_q, hdim}); + + const float dq_acc_const = 1.25f; + ck_tile::FillConstant{ck_tile::type_convert(dq_acc_const)}(dq_acc_host); + ck_tile::FillConstant{ck_tile::type_convert(SentinelValue())}(dq_host_init); + /* The sentinel goes through two nested type_convert calls before being stored as float; presumably float -> QGradType -> float so it matches what is read back from dq - the template arguments are not visible here, confirm. */ + const float dq_sentinel_val = ck_tile::type_convert( + ck_tile::type_convert(SentinelValue())); + std::vector dq_sentinel_ref(static_cast(total_rows_q * hdim), + dq_sentinel_val); + std::vector expected = dq_sentinel_ref; + /* Overwrite only rows [q_start, q_start + q_len) of each group; padded rows and the zero-length group keep the sentinel. */ for(ck_tile::index_t b = 0; b < batch; ++b) + { + const ck_tile::index_t q_start = seqstart_q_host[b]; + const ck_tile::index_t q_len = seqlen_q_host[b]; + for(ck_tile::index_t row = 0; row < q_len; ++row) + { + for(ck_tile::index_t c = 0; c < hdim; ++c) + { + const size_t idx = (q_start + row) * hdim + c; + // dq_acc_host is [1, nhead, s, d] + expected[idx] = ck_tile::type_convert(dq_acc_host(0, 0, q_start + row, c)); + } + } + } + + ck_tile::DeviceMem dq_acc_dev(dq_acc_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem dq_dev(dq_host_init.get_element_space_size_in_bytes()); + ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t)); + ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t)); + ck_tile::DeviceMem seqlen_q_dev(seqlen_q_host.size() * sizeof(int32_t)); + ck_tile::DeviceMem 
seqlen_k_dev(seqlen_k_host.size() * sizeof(int32_t)); + + dq_acc_dev.ToDevice(dq_acc_host.data()); + dq_dev.ToDevice(dq_host_init.data()); + seqstart_q.ToDevice(seqstart_q_host.data()); + seqstart_k.ToDevice(seqstart_k_host.data()); + seqlen_q_dev.ToDevice(seqlen_q_host.data()); + seqlen_k_dev.ToDevice(seqlen_k_host.data()); + /* Both the seqstart (physical offsets) and seqlen (logical lengths) buffers are passed for q and k. */ + fmha_bwd_args args{}; + args.dq_acc_ptr = dq_acc_dev.GetDeviceBuffer(); + args.dq_ptr = dq_dev.GetDeviceBuffer(); + args.seqstart_q_ptr = seqstart_q.GetDeviceBuffer(); + args.seqstart_k_ptr = seqstart_k.GetDeviceBuffer(); + args.seqlen_q_ptr = seqlen_q_dev.GetDeviceBuffer(); + args.seqlen_k_ptr = seqlen_k_dev.GetDeviceBuffer(); + args.batch = batch; + args.nhead_q = nhead; + args.nhead_k = nhead; + args.hdim_q = hdim; + args.hdim_v = hdim; + /* 6 and 8 are the largest per-group q and k lengths implied by seqstart_q_host and seqstart_k_host. */ args.max_seqlen_q = 6; + args.max_seqlen_k = 8; + args.stride_dq_acc = hdim; + args.stride_dq = hdim; + /* NOTE(review): the per-head and split strides below use max_seqlen_q (6 rows), while the host tensors hold total_rows_q (10) rows per head; presumably harmless only because nhead == 1 - confirm before raising nhead. */ args.nhead_stride_dq_acc = hdim * args.max_seqlen_q; + args.nhead_stride_dq = hdim * args.max_seqlen_q; + args.nhead_stride_q = 0; + args.nhead_stride_k = 0; + args.nhead_stride_v = 0; + args.nhead_stride_o = 0; + args.batch_stride_dq_acc = 0; + args.batch_stride_dq = 0; + args.split_stride_dq_acc = args.max_seqlen_q * args.stride_dq_acc; + args.window_size_left = -1; + args.window_size_right = 0; + args.mask_type = static_cast(mask_enum::no_mask); + args.p_drop = 0.0f; + args.p_undrop = 1.0f; + args.drop_seed_offset = std::make_pair(uint64_t{0}, uint64_t{0}); + /* NOTE(review): the positional arguments (256, 64, 0, hdim, true, false) get their meaning from BlockFmhaBwdConvertQGradPipelineProblem, which is not visible here - confirm against its declaration. */ + using TypeConfig = FmhaBwdTypeConfig; + using ConvertTileTraits = ck_tile::TileFmhaBwdConvertQGradTraits; + using ConvertProblem = ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename TypeConfig::AccDataType, + typename TypeConfig::QGradDataType, + 256, + 64, + 0, + hdim, + true, + false, + ConvertTileTraits>; + using ConvertPipeline = ck_tile::BlockFmhaBwdConvertQGrad; + using ConvertKernel = ck_tile::FmhaBwdConvertQGradKernel; + + auto [convert_kargs, convert_grids] = + fmha_bwd_convert_dq_create_kargs_and_grids(args); + const dim3 
convert_blocks = ConvertKernel::BlockSize(); + constexpr ck_tile::index_t kConvertBlockPerCu = ConvertKernel::kBlockPerCu; + auto convert_kernel = ck_tile::make_kernel( + ConvertKernel{}, convert_grids, convert_blocks, 0, convert_kargs); + convert_kernel(kStreamConfig); + ASSERT_EQ(hipDeviceSynchronize(), hipSuccess); + /* Read dq back and widen it to float for comparison. */ + using QGradOutT = typename TypeConfigC::QGradDataType; + auto dq_result_host_t = CopyDeviceToHost(dq_dev, total_rows_q * hdim); + auto dq_result_host = ToFloatVector(dq_result_host_t); + /* Only atol_gpad is read below; rtol_gpad is bound but not used. Elements whose expectation equals the sentinel must match it exactly (EXPECT_FLOAT_EQ); all others are compared with EXPECT_NEAR. */ + auto [rtol_gpad, atol_gpad] = get_elimit(hdim, hdim); + for(size_t i = 0; i < dq_result_host.size(); ++i) + { + SCOPED_TRACE(::testing::Message() << "index=" << i); + if(std::fabs(expected[i] - dq_sentinel_ref[i]) < 1e-6f) + { + EXPECT_FLOAT_EQ(dq_result_host[i], dq_sentinel_ref[i]); + } + else + { + EXPECT_NEAR(dq_result_host[i], expected[i], static_cast(atol_gpad)); + } + } +} + /* Single-group ConvertQGrad test with per-split accumulation: q has 8 physical / 5 logical rows, k has 24 physical / 20 logical columns, and kN0 = 16. Every split slab of dq_acc is filled with 1.0 + 0.1 * split index; the test expects the first logical_rows rows of dq to hold the sum over splits and the remaining rows to keep the sentinel. */ +TYPED_TEST(FmhaBwdKernelPaddingTyped, ConvertQGrad_DeterministicPaddingUsesLogicalLengths) +{ + constexpr ck_tile::index_t batch = 1; + constexpr ck_tile::index_t nhead = 1; + constexpr ck_tile::index_t hdim = 128; + constexpr ck_tile::index_t phys_rows = 8; + constexpr ck_tile::index_t logical_rows = 5; + constexpr ck_tile::index_t phys_k = 24; + constexpr ck_tile::index_t logical_k = 20; + constexpr ck_tile::index_t kN0 = 16; + /* nsplits = ceil(20 / 16) = 2. NOTE(review): phys_k = 24 gives the same ceil(24 / 16) = 2, so these sizes do not distinguish logical from physical k length in the split count. */ constexpr ck_tile::index_t nsplits = (logical_k + kN0 - 1) / kN0; + + const std::vector seqstart_q_host{0, phys_rows}; + const std::vector seqlen_q_host{logical_rows}; + const std::vector seqstart_k_host{0, phys_k}; + const std::vector seqlen_k_host{logical_k}; + const ck_tile::index_t total_rows_q = seqstart_q_host.back(); + + using TypeConfigD = FmhaBwdTypeConfig; + using AccTypeDet = typename TypeConfigD::AccDataType; // float + using QGradTypeD = typename TypeConfigD::QGradDataType; // bf16 + + ck_tile::HostTensor dq_acc_host({nsplits, 1, nhead, phys_rows, hdim}); + /* idx[0] is the leading (split) index of the 5-D dq_acc_host tensor. */ dq_acc_host.ForEach([&](auto& self, auto idx) { + const float s = static_cast(idx[0]); + 
/* Continuation of ConvertQGrad_DeterministicPaddingUsesLogicalLengths: finishes the dq_acc fill lambda, builds the expected values, launches the kernel and checks dq. */ // Use split-dependent constant to avoid per-element variance and rounding interplay + self(idx) = ck_tile::type_convert(1.0f + 0.1f * s); + }); + + const float dq_sentinel_val_det = ck_tile::type_convert( + ck_tile::type_convert(SentinelValue())); + std::vector expected(total_rows_q * hdim, dq_sentinel_val_det); + // Expected is the sum over splits of the constant (1.0 + 0.1*s) + /* Only rows below logical_rows are overwritten; rows from logical_rows up to total_rows_q keep the sentinel. */ for(ck_tile::index_t row = 0; row < logical_rows; ++row) + for(ck_tile::index_t c = 0; c < hdim; ++c) + { + float acc = 0.f; + for(ck_tile::index_t s = 0; s < nsplits; ++s) + { + acc += (1.0f + 0.1f * static_cast(s)); + } + /* NOTE(review): acc is stored as a plain float sum, without the nested type_convert applied to the sentinel above, so the comparison relies on atol_det to absorb any rounding in the kernel output type. */ expected[row * hdim + c] = acc; + } + + ck_tile::HostTensor dq_init({1, nhead, total_rows_q, hdim}); + ck_tile::FillConstant{ck_tile::type_convert(SentinelValue())}(dq_init); + /* NOTE(review): DeviceMem is written unqualified here, while the preceding tests spell ck_tile::DeviceMem; presumably a using-declaration outside this view makes both valid - confirm. */ + DeviceMem dq_acc_dev(dq_acc_host.get_element_space_size_in_bytes()); + DeviceMem dq_dev(dq_init.get_element_space_size_in_bytes()); + DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t)); + DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t)); + DeviceMem seqlen_q_dev(seqlen_q_host.size() * sizeof(int32_t)); + DeviceMem seqlen_k_dev(seqlen_k_host.size() * sizeof(int32_t)); + + dq_acc_dev.ToDevice(dq_acc_host.data()); + dq_dev.ToDevice(dq_init.data()); + seqstart_q.ToDevice(seqstart_q_host.data()); + seqstart_k.ToDevice(seqstart_k_host.data()); + seqlen_q_dev.ToDevice(seqlen_q_host.data()); + seqlen_k_dev.ToDevice(seqlen_k_host.data()); + + fmha_bwd_args args{}; + args.dq_acc_ptr = dq_acc_dev.GetDeviceBuffer(); + args.dq_ptr = dq_dev.GetDeviceBuffer(); + args.seqstart_q_ptr = seqstart_q.GetDeviceBuffer(); + args.seqstart_k_ptr = seqstart_k.GetDeviceBuffer(); + args.seqlen_q_ptr = seqlen_q_dev.GetDeviceBuffer(); + args.seqlen_k_ptr = seqlen_k_dev.GetDeviceBuffer(); + args.batch = batch; + args.nhead_q = nhead; + args.nhead_k = nhead; + args.hdim_q = hdim; + args.hdim_v = hdim; + /* The max lengths are the physical sizes (phys_rows, phys_k), not the logical ones. */ args.max_seqlen_q = phys_rows; + args.max_seqlen_k = phys_k; + args.stride_dq_acc = hdim; + 
args.stride_dq = hdim; + args.nhead_stride_dq_acc = phys_rows * hdim; + args.nhead_stride_dq = phys_rows * hdim; + /* With dq_acc_host shaped {nsplits, 1, nhead, phys_rows, hdim}, one split slab holds nhead * phys_rows * hdim elements, which equals phys_rows * hdim here because nhead == 1. */ args.split_stride_dq_acc = phys_rows * hdim; + args.window_size_left = -1; + args.window_size_right = 0; + args.mask_type = static_cast(mask_enum::no_mask); + args.p_drop = 0.0f; + args.p_undrop = 1.0f; + args.drop_seed_offset = std::make_pair(uint64_t{0}, uint64_t{0}); + /* NOTE(review): the positional arguments are (256, 64, kN0, hdim, true, true); presumably the final bool enables the deterministic path this test is named after - confirm against BlockFmhaBwdConvertQGradPipelineProblem. */ + using TypeConfig = FmhaBwdTypeConfig; + using TileTraitsDet = ck_tile::TileFmhaBwdConvertQGradTraits; + using PipelineProblemDet = ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename TypeConfig::AccDataType, + typename TypeConfig::QGradDataType, + 256, + 64, + kN0, + hdim, + true, + true, + TileTraitsDet>; + using PipelineDet = ck_tile::BlockFmhaBwdConvertQGrad; + using ConvertKernelDet = ck_tile::FmhaBwdConvertQGradKernel; + + auto [convert_kargs, convert_grids] = + fmha_bwd_convert_dq_create_kargs_and_grids(args); + const dim3 convert_blocks = ConvertKernelDet::BlockSize(); + constexpr ck_tile::index_t kConvertBlockPerCu = ConvertKernelDet::kBlockPerCu; + auto convert_kernel = ck_tile::make_kernel( + ConvertKernelDet{}, convert_grids, convert_blocks, 0, convert_kargs); + convert_kernel(kStreamConfig); + ASSERT_EQ(hipDeviceSynchronize(), hipSuccess); + /* Read dq back and widen it to float for comparison. */ + using QGradOutTD = typename TypeConfigD::QGradDataType; + auto dq_result_host_t = CopyDeviceToHost(dq_dev, total_rows_q * hdim); + auto dq_result_host = ToFloatVector(dq_result_host_t); + /* dq_sentinel_val2 is a plain copy of dq_sentinel_val_det. */ + const float dq_sentinel_val2 = dq_sentinel_val_det; + /* Only atol_det is read below; rtol_det is bound but not used. */ + auto [rtol_det, atol_det] = get_elimit(hdim, hdim); + for(size_t i = 0; i < dq_result_host.size(); ++i) + { + SCOPED_TRACE(::testing::Message() << "index=" << i); + if(std::fabs(expected[i] - dq_sentinel_val2) < 1e-6f) + { + EXPECT_FLOAT_EQ(dq_result_host[i], dq_sentinel_val2); + } + else + { + EXPECT_NEAR(dq_result_host[i], expected[i], static_cast(atol_det)); + } + } +}