Always add fmha_fwd() api

This commit is contained in:
PoYen, Chen
2024-08-07 13:43:14 +00:00
parent eda78d1a10
commit 55ce2948a9

View File

@@ -262,15 +262,11 @@ float fmha_fwd_dispatch(fmha_fwd_traits traits,
{
return fmha_fwd_splitkv(traits, args, config);
}
#endif
#if 0
else
#endif
{
return fmha_fwd(traits, args, config);
}
#else
return 0;
#endif
}
template <typename DataType>
@@ -546,9 +542,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
(mode == mode_enum::batch ? seqlen_ks[0]
: (seqlen_kpads[0] < 0 ? seqstart_k_host.back()
: seqstart_k_with_padding_host.back()));
#ifdef ENABLE_HOST_DEBUG_MSG
std::cerr << "[HOST] num_blocks: " << max_num_blocks << std::endl;
#endif
ck_tile::HostTensor<QDataType> q_host(
get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q));
ck_tile::HostTensor<KDataType> k_host(
@@ -1032,35 +1026,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
std::cout << std::flush << std::endl;
return true;
}
#if defined(ENABLE_HOST_DEBUG_MSG)
k_buf.FromDevice(k_host.data());
printf("\n");
// batch, nhead_k, seqlen_knew, hdim_q/hdim_v
for(int row = 0; row < seqlen_knew; ++row)
{
printf("[HOST] vnew_host[%d] = ", row);
for(int col = 0; col < 32; ++col)
{
printf("%11.7f", ck_tile::type_convert<float>(vnew_host(0, 0, row, col)));
}
printf("\n");
}
// max_num_blocks, nhead_k, page_block_size, hdim_q/hdim_v
int block_index = 1;
int psychical_block_index = block_table_host(0, block_index);
for(int row = 0; row < min(seqlen_knew, page_block_size); ++row)
{
printf("[HOST] v_host[%d] = ", row);
for(int col = 0; col < 32; ++col)
{
printf("%11.7f",
ck_tile::type_convert<float>(v_host(psychical_block_index, 0, row, col)));
}
printf("\n");
}
#endif
o_buf.FromDevice(o_host.data());
lse_buf.FromDevice(lse_host.data());