mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 07:51:52 +00:00
Always add fmha_fwd() api
This commit is contained in:
@@ -262,15 +262,11 @@ float fmha_fwd_dispatch(fmha_fwd_traits traits,
|
||||
{
|
||||
return fmha_fwd_splitkv(traits, args, config);
|
||||
}
|
||||
#endif
|
||||
#if 0
|
||||
else
|
||||
#endif
|
||||
{
|
||||
return fmha_fwd(traits, args, config);
|
||||
}
|
||||
#else
|
||||
return 0;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename DataType>
|
||||
@@ -546,9 +542,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
(mode == mode_enum::batch ? seqlen_ks[0]
|
||||
: (seqlen_kpads[0] < 0 ? seqstart_k_host.back()
|
||||
: seqstart_k_with_padding_host.back()));
|
||||
#ifdef ENABLE_HOST_DEBUG_MSG
|
||||
std::cerr << "[HOST] num_blocks: " << max_num_blocks << std::endl;
|
||||
#endif
|
||||
|
||||
ck_tile::HostTensor<QDataType> q_host(
|
||||
get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q));
|
||||
ck_tile::HostTensor<KDataType> k_host(
|
||||
@@ -1032,35 +1026,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
std::cout << std::flush << std::endl;
|
||||
return true;
|
||||
}
|
||||
#if defined(ENABLE_HOST_DEBUG_MSG)
|
||||
k_buf.FromDevice(k_host.data());
|
||||
|
||||
printf("\n");
|
||||
// batch, nhead_k, seqlen_knew, hdim_q/hdim_v
|
||||
for(int row = 0; row < seqlen_knew; ++row)
|
||||
{
|
||||
printf("[HOST] vnew_host[%d] = ", row);
|
||||
for(int col = 0; col < 32; ++col)
|
||||
{
|
||||
printf("%11.7f", ck_tile::type_convert<float>(vnew_host(0, 0, row, col)));
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
|
||||
// max_num_blocks, nhead_k, page_block_size, hdim_q/hdim_v
|
||||
int block_index = 1;
|
||||
int psychical_block_index = block_table_host(0, block_index);
|
||||
for(int row = 0; row < min(seqlen_knew, page_block_size); ++row)
|
||||
{
|
||||
printf("[HOST] v_host[%d] = ", row);
|
||||
for(int col = 0; col < 32; ++col)
|
||||
{
|
||||
printf("%11.7f",
|
||||
ck_tile::type_convert<float>(v_host(psychical_block_index, 0, row, col)));
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
#endif
|
||||
|
||||
o_buf.FromDevice(o_host.data());
|
||||
lse_buf.FromDevice(lse_host.data());
|
||||
|
||||
Reference in New Issue
Block a user