From 238fde80a6ffc80140dc611f16f8384e2df0c16a Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Thu, 6 Jun 2024 10:45:59 +0000 Subject: [PATCH] Fix o_acc memory error --- example/ck_tile/01_fmha/fmha_fwd.cpp | 44 +++++++++++++++------------- 1 file changed, 24 insertions(+), 20 deletions(-) diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index a278e31a6e..d9616c101e 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -387,25 +387,26 @@ bool run(const ck_tile::ArgParser& arg_parser) : std::array{1, 1}); ck_tile::HostTensor lse_acc_host( - 1 < num_splits - ? std::array{num_splits, batch, nhead, max_seqlen_q} - : std::array{1, 1, 1, 1}); - APP_DEBUG_STMTS { - std::cout << "lse_acc_host shape: " << num_splits << ", " << batch << ", " - << nhead << ", " << max_seqlen_q << std::endl; + 1 < num_splits ? std::array{num_splits, batch, nhead, max_seqlen_q} + : std::array{1, 1, 1, 1}); + APP_DEBUG_STMTS + { + std::cout << "lse_acc_host shape: " << num_splits << ", " << batch << ", " << nhead << ", " + << max_seqlen_q << std::endl; } ck_tile::HostTensor o_acc_host( 1 < num_splits ? std::array{num_splits, - shape_batch, + batch, nhead, - shape_seqlen_q, + max_seqlen_q, hdim_v} : std::array{1, 1, 1, 1, 1}); - APP_DEBUG_STMTS{ - std::cout << "o_acc_host shape: " << num_splits << ", " << shape_batch << ", " - << nhead << ", " << shape_seqlen_q << ", " << hdim_v << std::endl; + APP_DEBUG_STMTS + { + std::cout << "o_acc_host shape: " << num_splits << ", " << shape_batch << ", " << nhead + << ", " << shape_seqlen_q << ", " << hdim_v << std::endl; } - // self define lse data layout as [shape_batch, nhead, shape_seqlen_q] + // self define lse data layout as [batch, nhead, max_seqlen_q] ck_tile::HostTensor lse_host( lse ? std::array{batch, nhead, max_seqlen_q} : std::array{1, 1, 1} /* dummy shape for simplifying code */); @@ -525,7 +526,8 @@ bool run(const ck_tile::ArgParser& arg_parser) : (std::string("(") + std::to_string(seqlen_kpads[0]) + ")")) << ", d:" << hdim_q << "/" << hdim_v << ", scale_s:" << scale_s << ", bias:" << bias << ", p_drop:" << p_drop << ", lse:" << lse << ", squant:" << squant - << ", mask:" << mask << ", v:" << vlayout << std::flush; + << ", mask:" << mask << ", v:" << vlayout << ", num_splits: " << num_splits + << std::flush; auto fmha_traits = fmha_fwd_traits{hdim_q, hdim_v, @@ -990,7 +992,7 @@ bool run(const ck_tile::ArgParser& arg_parser) // clang-format on auto [rtol, atol] = get_elimit(init_method); - bool cur_pass = true; + bool cur_pass = true; if(lse) { ck_tile::HostTensor lse_host_result({nhead, real_seqlen_q}); @@ -998,11 +1000,11 @@ bool run(const ck_tile::ArgParser& arg_parser) [&](auto& self, auto idx) { self(idx) = lse_host(wb, idx[0], idx[1]); }); cur_pass = ck_tile::check_err(lse_host_result, - lse_host_ref, - "LSE Error: Incorrect results!", - rtol, - atol, - /* allow_infinity_ref = */ true); + lse_host_ref, + "LSE Error: Incorrect results!", + rtol, + atol, + /* allow_infinity_ref = */ true); // if (cur_pass) std::cout << "LSE pass" << std::endl; pass &= cur_pass; if(!cur_pass) @@ -1014,7 +1016,8 @@ bool run(const ck_tile::ArgParser& arg_parser) << "\tseqstart_k: " << seqstart_k_host << std::endl; } } - cur_pass = ck_tile::check_err( + #if 1 + cur_pass = ck_tile::check_err( o_host_result, o_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol); // if (cur_pass) std::cout << "OUT pass" << std::endl; pass &= cur_pass; @@ -1028,6 +1031,7 @@ bool run(const ck_tile::ArgParser& arg_parser) break; } + #endif } std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;