mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 07:51:52 +00:00
Remove debug statements in example
This commit is contained in:
@@ -162,10 +162,7 @@ float fmha_fwd_dispatch(fmha_fwd_traits traits,
|
||||
fmha_fwd_args args,
|
||||
const ck_tile::stream_config& config)
|
||||
{
|
||||
#if defined(ALWAYS_INVOKE_SPLITKV_KERNEL)
|
||||
return fmha_fwd_splitkv(traits, args, config);
|
||||
#else
|
||||
if(1 < args.num_splits)
|
||||
if(1 < args.num_splits && args.p_drop == 0.0f)
|
||||
{
|
||||
return fmha_fwd_splitkv(traits, args, config);
|
||||
}
|
||||
@@ -173,8 +170,7 @@ float fmha_fwd_dispatch(fmha_fwd_traits traits,
|
||||
{
|
||||
return fmha_fwd(traits, args, config);
|
||||
}
|
||||
#endif
|
||||
};
|
||||
}
|
||||
|
||||
template <typename DataType>
|
||||
bool run(const ck_tile::ArgParser& arg_parser)
|
||||
@@ -385,22 +381,13 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
: std::array<ck_tile::index_t, 2>{1, 1});
|
||||
|
||||
ck_tile::HostTensor<LSEDataType> lse_acc_host(
|
||||
1 <= num_splits ? std::array<ck_tile::index_t, 4>{num_splits, batch, nhead, max_seqlen_q}
|
||||
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
|
||||
APP_DEBUG_STMTS
|
||||
{
|
||||
std::cout << "lse_acc_host shape: " << num_splits << ", " << batch << ", " << nhead << ", "
|
||||
<< max_seqlen_q << std::endl;
|
||||
}
|
||||
1 < num_splits ? std::array<ck_tile::index_t, 4>{num_splits, batch, nhead, max_seqlen_q}
|
||||
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
|
||||
ck_tile::HostTensor<OaccDataType> o_acc_host(
|
||||
1 <= num_splits
|
||||
1 < num_splits
|
||||
? std::array<ck_tile::index_t, 5>{num_splits, batch, nhead, max_seqlen_q, hdim_v}
|
||||
: std::array<ck_tile::index_t, 5>{1, 1, 1, 1, 1});
|
||||
APP_DEBUG_STMTS
|
||||
{
|
||||
std::cout << "o_acc_host shape: " << num_splits << ", " << shape_batch << ", " << nhead
|
||||
<< ", " << shape_seqlen_q << ", " << hdim_v << std::endl;
|
||||
}
|
||||
|
||||
// self define lse data layout as [batch, nhead, max_seqlen_q]
|
||||
ck_tile::HostTensor<LSEDataType> lse_host(
|
||||
lse ? std::array<ck_tile::index_t, 3>{batch, nhead, max_seqlen_q}
|
||||
@@ -669,7 +656,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
o_buf.FromDevice(o_host.data());
|
||||
lse_buf.FromDevice(lse_host.data());
|
||||
lse_acc_buf.FromDevice(lse_acc_host.data());
|
||||
randval_buf.FromDevice(randval_host.data());
|
||||
float p_undrop = 1.0 - p_drop;
|
||||
uint8_t p_undrop_in_uint8_t =
|
||||
@@ -678,124 +664,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
bool pass = true;
|
||||
|
||||
APP_DEBUG_STMTS
|
||||
{
|
||||
printf("\n");
|
||||
printf("[POYENC][HOST] lse shape: %d, %d, %d, %d\n",
|
||||
num_splits,
|
||||
shape_batch,
|
||||
nhead,
|
||||
shape_seqlen_q);
|
||||
}
|
||||
|
||||
for(ck_tile::index_t wb = 0; wb < batch; ++wb)
|
||||
{
|
||||
APP_DEBUG_STMTS { printf("[POYENC][HOST] wb: %d\n", wb); }
|
||||
|
||||
const ck_tile::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb];
|
||||
const ck_tile::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb];
|
||||
|
||||
APP_DEBUG_STMTS
|
||||
{
|
||||
// lse_acc_host shape: num_splits, shape_batch, nhead, shape_seqlen_q
|
||||
for(int i_split = 0; i_split < num_splits; ++i_split)
|
||||
{
|
||||
printf("[POYENC][HOST] i_split: %d\n", i_split);
|
||||
printf("[POYENC][HOST] lse_acc_host(%2d,%2d, 0) = ", i_split, wb);
|
||||
for(int row = 0; row < real_seqlen_q; ++row)
|
||||
{
|
||||
printf(
|
||||
"%11.7f",
|
||||
// printf("[POYENC][HOST] lse_acc_host(%2d,%2d, 0): %11.7f\n", i_split, wb,
|
||||
lse_acc_host(i_split, wb, 0, row));
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
}
|
||||
|
||||
APP_DEBUG_STMTS
|
||||
{
|
||||
ck_tile::HostTensor<LSEDataType> lse_max({real_seqlen_q});
|
||||
for(int row = 0; row < real_seqlen_q; ++row)
|
||||
{
|
||||
lse_max(row) = -ck_tile::numeric<LSEDataType>::infinity();
|
||||
for(int i_split = 0; i_split < num_splits; ++i_split)
|
||||
{
|
||||
if(lse_max(row) < lse_acc_host(i_split, wb, 0, row))
|
||||
{
|
||||
lse_max(row) = lse_acc_host(i_split, wb, 0, row);
|
||||
}
|
||||
}
|
||||
}
|
||||
printf("[POYENC][HOST] lse_max: ");
|
||||
for(int row = 0; row < real_seqlen_q; ++row)
|
||||
{
|
||||
printf("%11.7f", lse_max(row));
|
||||
}
|
||||
printf("\n");
|
||||
|
||||
static const auto get_validated_m = [](LSEDataType raw_m) {
|
||||
/// NOTICE: bias might be materialized mask including -inf values, need
|
||||
/// consideration
|
||||
return raw_m == -ck_tile::numeric<LSEDataType>::infinity()
|
||||
? ck_tile::type_convert<LSEDataType>(0.f)
|
||||
: raw_m;
|
||||
};
|
||||
|
||||
ck_tile::HostTensor<LSEDataType> lse_sum({shape_seqlen_q});
|
||||
for(int row = 0; row < real_seqlen_q; ++row)
|
||||
{
|
||||
lse_sum(row) = 0;
|
||||
for(int i_split = 0; i_split < num_splits; ++i_split)
|
||||
{
|
||||
lse_sum(row) += ck_tile::exp(lse_acc_host(i_split, wb, 0, row) -
|
||||
get_validated_m(lse_max(row)));
|
||||
}
|
||||
}
|
||||
printf("[POYENC][HOST] lse_sum: ");
|
||||
for(int row = 0; row < real_seqlen_q; ++row)
|
||||
{
|
||||
printf("%11.7f", lse_sum(row));
|
||||
}
|
||||
printf("\n");
|
||||
|
||||
ck_tile::HostTensor<LSEDataType> lse_logsum({real_seqlen_q});
|
||||
for(int row = 0; row < real_seqlen_q; ++row)
|
||||
{
|
||||
if(lse_sum(row) == 0.f || lse_sum(row) != lse_sum(row))
|
||||
{
|
||||
lse_logsum(row) = ck_tile::numeric<LSEDataType>::infinity();
|
||||
}
|
||||
else
|
||||
{
|
||||
lse_logsum(row) = ck_tile::log(lse_sum(row)) + get_validated_m(lse_max(row));
|
||||
}
|
||||
}
|
||||
|
||||
for(int row = 0; row < real_seqlen_q; ++row)
|
||||
{
|
||||
if(lse_logsum(row) == ck_tile::numeric<LSEDataType>::infinity())
|
||||
{
|
||||
lse_logsum(row) = -ck_tile::numeric<LSEDataType>::infinity();
|
||||
}
|
||||
}
|
||||
|
||||
// lse_host shape: [batch, nhead, max_seqlen_q]
|
||||
printf("[POYENC][DEVICE] lse_host: ");
|
||||
for(int row = 0; row < real_seqlen_q; ++row)
|
||||
{
|
||||
printf("%11.7f", lse_host(wb, 0, row));
|
||||
}
|
||||
printf("\n");
|
||||
|
||||
printf("[POYENC][HOST] lse_logsum: ");
|
||||
for(int row = 0; row < real_seqlen_q; ++row)
|
||||
{
|
||||
printf("%11.7f", lse_logsum(row));
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
|
||||
// adjust matrix index according to the mode
|
||||
const ck_tile::index_t b = (mode == mode_enum::batch ? wb : 0);
|
||||
const ck_tile::index_t query_offset = (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]);
|
||||
@@ -990,34 +863,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
// clang-format on
|
||||
|
||||
auto [rtol, atol] = get_elimit<DataType>(init_method);
|
||||
bool cur_pass = true;
|
||||
if(lse)
|
||||
{
|
||||
ck_tile::HostTensor<SMPLComputeDataType> lse_host_result({nhead, real_seqlen_q});
|
||||
lse_host_result.ForEach(
|
||||
[&](auto& self, auto idx) { self(idx) = lse_host(wb, idx[0], idx[1]); });
|
||||
|
||||
cur_pass = ck_tile::check_err(lse_host_result,
|
||||
lse_host_ref,
|
||||
"LSE Error: Incorrect results!",
|
||||
rtol,
|
||||
atol,
|
||||
/* allow_infinity_ref = */ true);
|
||||
// if (cur_pass) std::cout << "LSE pass" << std::endl;
|
||||
pass &= cur_pass;
|
||||
if(!cur_pass)
|
||||
{
|
||||
std::cerr << "LSE mismatch found at batch: " << wb << std::endl
|
||||
<< "\tseqlen_q: " << real_seqlen_q << std::endl
|
||||
<< "\tseqlen_k: " << real_seqlen_k << std::endl
|
||||
<< "\tseqstart_q: " << seqstart_q_host << std::endl
|
||||
<< "\tseqstart_k: " << seqstart_k_host << std::endl;
|
||||
}
|
||||
}
|
||||
#if 1
|
||||
cur_pass = ck_tile::check_err(
|
||||
bool cur_pass = ck_tile::check_err(
|
||||
o_host_result, o_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol);
|
||||
// if (cur_pass) std::cout << "OUT pass" << std::endl;
|
||||
pass &= cur_pass;
|
||||
if(!cur_pass)
|
||||
{
|
||||
@@ -1029,7 +876,32 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
break;
|
||||
}
|
||||
#endif
|
||||
|
||||
if(lse)
|
||||
{
|
||||
ck_tile::HostTensor<SMPLComputeDataType> lse_host_result({nhead, real_seqlen_q});
|
||||
lse_host_result.ForEach(
|
||||
[&](auto& self, auto idx) { self(idx) = lse_host(wb, idx[0], idx[1]); });
|
||||
|
||||
bool lse_pass = ck_tile::check_err(lse_host_result,
|
||||
lse_host_ref,
|
||||
"LSE Error: Incorrect results!",
|
||||
rtol,
|
||||
atol,
|
||||
/* allow_infinity_ref = */ true);
|
||||
|
||||
pass &= lse_pass;
|
||||
if(!cur_pass)
|
||||
{
|
||||
std::cerr << "LSE mismatch found at batch: " << wb << std::endl
|
||||
<< "\tseqlen_q: " << real_seqlen_q << std::endl
|
||||
<< "\tseqlen_k: " << real_seqlen_k << std::endl
|
||||
<< "\tseqstart_q: " << seqstart_q_host << std::endl
|
||||
<< "\tseqstart_k: " << seqstart_k_host << std::endl;
|
||||
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
|
||||
|
||||
Reference in New Issue
Block a user