Remove debug statements in example

This commit is contained in:
PoYen, Chen
2024-06-11 14:02:53 +00:00
parent 912a6cb2ea
commit 5efb80347e

View File

@@ -162,10 +162,7 @@ float fmha_fwd_dispatch(fmha_fwd_traits traits,
fmha_fwd_args args,
const ck_tile::stream_config& config)
{
#if defined(ALWAYS_INVOKE_SPLITKV_KERNEL)
return fmha_fwd_splitkv(traits, args, config);
#else
if(1 < args.num_splits)
if(1 < args.num_splits && args.p_drop == 0.0f)
{
return fmha_fwd_splitkv(traits, args, config);
}
@@ -173,8 +170,7 @@ float fmha_fwd_dispatch(fmha_fwd_traits traits,
{
return fmha_fwd(traits, args, config);
}
#endif
};
}
template <typename DataType>
bool run(const ck_tile::ArgParser& arg_parser)
@@ -385,22 +381,13 @@ bool run(const ck_tile::ArgParser& arg_parser)
: std::array<ck_tile::index_t, 2>{1, 1});
ck_tile::HostTensor<LSEDataType> lse_acc_host(
1 <= num_splits ? std::array<ck_tile::index_t, 4>{num_splits, batch, nhead, max_seqlen_q}
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
APP_DEBUG_STMTS
{
std::cout << "lse_acc_host shape: " << num_splits << ", " << batch << ", " << nhead << ", "
<< max_seqlen_q << std::endl;
}
1 < num_splits ? std::array<ck_tile::index_t, 4>{num_splits, batch, nhead, max_seqlen_q}
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
ck_tile::HostTensor<OaccDataType> o_acc_host(
1 <= num_splits
1 < num_splits
? std::array<ck_tile::index_t, 5>{num_splits, batch, nhead, max_seqlen_q, hdim_v}
: std::array<ck_tile::index_t, 5>{1, 1, 1, 1, 1});
APP_DEBUG_STMTS
{
std::cout << "o_acc_host shape: " << num_splits << ", " << shape_batch << ", " << nhead
<< ", " << shape_seqlen_q << ", " << hdim_v << std::endl;
}
// self define lse data layout as [batch, nhead, max_seqlen_q]
ck_tile::HostTensor<LSEDataType> lse_host(
lse ? std::array<ck_tile::index_t, 3>{batch, nhead, max_seqlen_q}
@@ -669,7 +656,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
o_buf.FromDevice(o_host.data());
lse_buf.FromDevice(lse_host.data());
lse_acc_buf.FromDevice(lse_acc_host.data());
randval_buf.FromDevice(randval_host.data());
float p_undrop = 1.0 - p_drop;
uint8_t p_undrop_in_uint8_t =
@@ -678,124 +664,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
bool pass = true;
APP_DEBUG_STMTS
{
printf("\n");
printf("[POYENC][HOST] lse shape: %d, %d, %d, %d\n",
num_splits,
shape_batch,
nhead,
shape_seqlen_q);
}
for(ck_tile::index_t wb = 0; wb < batch; ++wb)
{
APP_DEBUG_STMTS { printf("[POYENC][HOST] wb: %d\n", wb); }
const ck_tile::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb];
const ck_tile::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb];
APP_DEBUG_STMTS
{
// lse_acc_host shape: num_splits, shape_batch, nhead, shape_seqlen_q
for(int i_split = 0; i_split < num_splits; ++i_split)
{
printf("[POYENC][HOST] i_split: %d\n", i_split);
printf("[POYENC][HOST] lse_acc_host(%2d,%2d, 0) = ", i_split, wb);
for(int row = 0; row < real_seqlen_q; ++row)
{
printf(
"%11.7f",
// printf("[POYENC][HOST] lse_acc_host(%2d,%2d, 0): %11.7f\n", i_split, wb,
lse_acc_host(i_split, wb, 0, row));
}
printf("\n");
}
}
APP_DEBUG_STMTS
{
ck_tile::HostTensor<LSEDataType> lse_max({real_seqlen_q});
for(int row = 0; row < real_seqlen_q; ++row)
{
lse_max(row) = -ck_tile::numeric<LSEDataType>::infinity();
for(int i_split = 0; i_split < num_splits; ++i_split)
{
if(lse_max(row) < lse_acc_host(i_split, wb, 0, row))
{
lse_max(row) = lse_acc_host(i_split, wb, 0, row);
}
}
}
printf("[POYENC][HOST] lse_max: ");
for(int row = 0; row < real_seqlen_q; ++row)
{
printf("%11.7f", lse_max(row));
}
printf("\n");
static const auto get_validated_m = [](LSEDataType raw_m) {
/// NOTICE: bias might be materialized mask including -inf values, need
/// consideration
return raw_m == -ck_tile::numeric<LSEDataType>::infinity()
? ck_tile::type_convert<LSEDataType>(0.f)
: raw_m;
};
ck_tile::HostTensor<LSEDataType> lse_sum({shape_seqlen_q});
for(int row = 0; row < real_seqlen_q; ++row)
{
lse_sum(row) = 0;
for(int i_split = 0; i_split < num_splits; ++i_split)
{
lse_sum(row) += ck_tile::exp(lse_acc_host(i_split, wb, 0, row) -
get_validated_m(lse_max(row)));
}
}
printf("[POYENC][HOST] lse_sum: ");
for(int row = 0; row < real_seqlen_q; ++row)
{
printf("%11.7f", lse_sum(row));
}
printf("\n");
ck_tile::HostTensor<LSEDataType> lse_logsum({real_seqlen_q});
for(int row = 0; row < real_seqlen_q; ++row)
{
if(lse_sum(row) == 0.f || lse_sum(row) != lse_sum(row))
{
lse_logsum(row) = ck_tile::numeric<LSEDataType>::infinity();
}
else
{
lse_logsum(row) = ck_tile::log(lse_sum(row)) + get_validated_m(lse_max(row));
}
}
for(int row = 0; row < real_seqlen_q; ++row)
{
if(lse_logsum(row) == ck_tile::numeric<LSEDataType>::infinity())
{
lse_logsum(row) = -ck_tile::numeric<LSEDataType>::infinity();
}
}
// lse_host shape: [batch, nhead, max_seqlen_q]
printf("[POYENC][DEVICE] lse_host: ");
for(int row = 0; row < real_seqlen_q; ++row)
{
printf("%11.7f", lse_host(wb, 0, row));
}
printf("\n");
printf("[POYENC][HOST] lse_logsum: ");
for(int row = 0; row < real_seqlen_q; ++row)
{
printf("%11.7f", lse_logsum(row));
}
printf("\n");
}
// adjust matrix index according to the mode
const ck_tile::index_t b = (mode == mode_enum::batch ? wb : 0);
const ck_tile::index_t query_offset = (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]);
@@ -990,34 +863,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
// clang-format on
auto [rtol, atol] = get_elimit<DataType>(init_method);
bool cur_pass = true;
if(lse)
{
ck_tile::HostTensor<SMPLComputeDataType> lse_host_result({nhead, real_seqlen_q});
lse_host_result.ForEach(
[&](auto& self, auto idx) { self(idx) = lse_host(wb, idx[0], idx[1]); });
cur_pass = ck_tile::check_err(lse_host_result,
lse_host_ref,
"LSE Error: Incorrect results!",
rtol,
atol,
/* allow_infinity_ref = */ true);
// if (cur_pass) std::cout << "LSE pass" << std::endl;
pass &= cur_pass;
if(!cur_pass)
{
std::cerr << "LSE mismatch found at batch: " << wb << std::endl
<< "\tseqlen_q: " << real_seqlen_q << std::endl
<< "\tseqlen_k: " << real_seqlen_k << std::endl
<< "\tseqstart_q: " << seqstart_q_host << std::endl
<< "\tseqstart_k: " << seqstart_k_host << std::endl;
}
}
#if 1
cur_pass = ck_tile::check_err(
bool cur_pass = ck_tile::check_err(
o_host_result, o_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol);
// if (cur_pass) std::cout << "OUT pass" << std::endl;
pass &= cur_pass;
if(!cur_pass)
{
@@ -1029,7 +876,32 @@ bool run(const ck_tile::ArgParser& arg_parser)
break;
}
#endif
if(lse)
{
ck_tile::HostTensor<SMPLComputeDataType> lse_host_result({nhead, real_seqlen_q});
lse_host_result.ForEach(
[&](auto& self, auto idx) { self(idx) = lse_host(wb, idx[0], idx[1]); });
bool lse_pass = ck_tile::check_err(lse_host_result,
lse_host_ref,
"LSE Error: Incorrect results!",
rtol,
atol,
/* allow_infinity_ref = */ true);
pass &= lse_pass;
if(!cur_pass)
{
std::cerr << "LSE mismatch found at batch: " << wb << std::endl
<< "\tseqlen_q: " << real_seqlen_q << std::endl
<< "\tseqlen_k: " << real_seqlen_k << std::endl
<< "\tseqstart_q: " << seqstart_q_host << std::endl
<< "\tseqstart_k: " << seqstart_k_host << std::endl;
break;
}
}
}
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;