Fix o_acc memory error

This commit is contained in:
PoYen, Chen
2024-06-06 10:45:59 +00:00
parent ffd2768000
commit 238fde80a6

View File

@@ -387,25 +387,26 @@ bool run(const ck_tile::ArgParser& arg_parser)
: std::array<ck_tile::index_t, 2>{1, 1});
ck_tile::HostTensor<LSEDataType> lse_acc_host(
1 < num_splits
? std::array<ck_tile::index_t, 4>{num_splits, batch, nhead, max_seqlen_q}
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
APP_DEBUG_STMTS {
std::cout << "lse_acc_host shape: " << num_splits << ", " << batch << ", "
<< nhead << ", " << max_seqlen_q << std::endl;
1 < num_splits ? std::array<ck_tile::index_t, 4>{num_splits, batch, nhead, max_seqlen_q}
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
APP_DEBUG_STMTS
{
std::cout << "lse_acc_host shape: " << num_splits << ", " << batch << ", " << nhead << ", "
<< max_seqlen_q << std::endl;
}
ck_tile::HostTensor<OaccDataType> o_acc_host(
1 < num_splits ? std::array<ck_tile::index_t, 5>{num_splits,
shape_batch,
batch,
nhead,
shape_seqlen_q,
max_seqlen_q,
hdim_v}
: std::array<ck_tile::index_t, 5>{1, 1, 1, 1, 1});
APP_DEBUG_STMTS{
std::cout << "o_acc_host shape: " << num_splits << ", " << shape_batch << ", "
<< nhead << ", " << shape_seqlen_q << ", " << hdim_v << std::endl;
APP_DEBUG_STMTS
{
std::cout << "o_acc_host shape: " << num_splits << ", " << shape_batch << ", " << nhead
<< ", " << shape_seqlen_q << ", " << hdim_v << std::endl;
}
// self define lse data layout as [shape_batch, nhead, shape_seqlen_q]
// self define lse data layout as [batch, nhead, max_seqlen_q]
ck_tile::HostTensor<LSEDataType> lse_host(
lse ? std::array<ck_tile::index_t, 3>{batch, nhead, max_seqlen_q}
: std::array<ck_tile::index_t, 3>{1, 1, 1} /* dummy shape for simplifying code */);
@@ -525,7 +526,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
: (std::string("(") + std::to_string(seqlen_kpads[0]) + ")"))
<< ", d:" << hdim_q << "/" << hdim_v << ", scale_s:" << scale_s << ", bias:" << bias
<< ", p_drop:" << p_drop << ", lse:" << lse << ", squant:" << squant
<< ", mask:" << mask << ", v:" << vlayout << std::flush;
<< ", mask:" << mask << ", v:" << vlayout << ", num_splits: " << num_splits
<< std::flush;
auto fmha_traits = fmha_fwd_traits{hdim_q,
hdim_v,
@@ -990,7 +992,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
// clang-format on
auto [rtol, atol] = get_elimit<DataType>(init_method);
bool cur_pass = true;
bool cur_pass = true;
if(lse)
{
ck_tile::HostTensor<SMPLComputeDataType> lse_host_result({nhead, real_seqlen_q});
@@ -998,11 +1000,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
[&](auto& self, auto idx) { self(idx) = lse_host(wb, idx[0], idx[1]); });
cur_pass = ck_tile::check_err(lse_host_result,
lse_host_ref,
"LSE Error: Incorrect results!",
rtol,
atol,
/* allow_infinity_ref = */ true);
lse_host_ref,
"LSE Error: Incorrect results!",
rtol,
atol,
/* allow_infinity_ref = */ true);
// if (cur_pass) std::cout << "LSE pass" << std::endl;
pass &= cur_pass;
if(!cur_pass)
@@ -1014,7 +1016,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
<< "\tseqstart_k: " << seqstart_k_host << std::endl;
}
}
cur_pass = ck_tile::check_err(
#if 1
cur_pass = ck_tile::check_err(
o_host_result, o_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol);
// if (cur_pass) std::cout << "OUT pass" << std::endl;
pass &= cur_pass;
@@ -1028,6 +1031,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
break;
}
#endif
}
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;