update the way to compute fmha fwd tflop, include mask type (#2386)

* update the way to compute fwd tflop, include mask type

Signed-off-by: JL-underdog <Jun.Lin@amd.com>

* remove unneccessary comment

* add necessary comment

* remove some comment

---------

Signed-off-by: JL-underdog <Jun.Lin@amd.com>
Co-authored-by: root <root@GT-SC-DI16-08.dh144.dcgpu>

[ROCm/composable_kernel commit: 61eb622e85]
This commit is contained in:
Linjun-AMD
2025-06-23 15:53:58 +08:00
committed by GitHub
parent 7001322416
commit 17346f2c91
2 changed files with 22 additions and 3 deletions

4
example/ck_tile/01_fmha/fmha_fwd.cpp Normal file → Executable file
View File

@@ -542,8 +542,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
max_seqlen_k = real_seqlen_k;
}
flop += nhead * (static_cast<std::size_t>(2) * real_seqlen_q * real_seqlen_k * hdim_q +
static_cast<std::size_t>(2) * real_seqlen_q * hdim_v * real_seqlen_k);
flop += nhead * (static_cast<std::size_t>(2) * mask.get_unmaskarea() * hdim_q +
static_cast<std::size_t>(2) * mask.get_unmaskarea() * hdim_v);
num_byte += nhead * (sizeof(QDataType) * real_seqlen_q * hdim_q +
sizeof(KDataType) * real_seqlen_k * hdim_q +

21
example/ck_tile/01_fmha/mask.hpp Normal file → Executable file
View File

@@ -21,6 +21,8 @@ enum class mask_enum
struct mask_info
{
mask_enum type;
ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_k;
ck_tile::index_t y, x;
ck_tile::index_t left, right; // FA style SWA left/right
@@ -42,6 +44,8 @@ struct mask_info
ck_tile::index_t x_total = seqlen_k;
ck_tile::index_t y_total = seqlen_q;
mask_info tmp;
tmp.seqlen_q = seqlen_q;
tmp.seqlen_k = seqlen_k;
auto found_0 = str.find(':');
if(found_0 != std::string::npos)
{
@@ -148,7 +152,22 @@ struct mask_info
}
return tmp;
}
ck_tile::index_t get_unmaskarea() const
{
if(type == mask_enum::no_mask)
return seqlen_q * seqlen_k;
ck_tile::index_t area = 0;
for(ck_tile::index_t i_y = 0; i_y < seqlen_q; ++i_y)
{
ck_tile::index_t x_start = std::max(-y + i_y + 1, static_cast<ck_tile::index_t>(0));
ck_tile::index_t x_end = std::min(i_y + x, seqlen_k);
if(x_end > x_start)
{
area += (x_end - x_start);
}
}
return area;
}
friend std::ostream& operator<<(std::ostream& os, const mask_info& mi)
{
mi.serialize(os);