mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-24 23:05:54 +00:00
Add output of estimated TFLOPS
This commit is contained in:
@@ -279,6 +279,21 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
};
|
||||
};
|
||||
|
||||
long total_flops = 0;
|
||||
|
||||
// estimate the total flops occurred, ignoring the scaling and SILu
|
||||
if(is_jagged)
|
||||
{
|
||||
for(auto len : seq_lengths)
|
||||
total_flops += 2 * (len * len * hdim_qk + len * hdim_v * len);
|
||||
|
||||
total_flops *= num_head;
|
||||
}
|
||||
else
|
||||
{
|
||||
total_flops = num_batch * num_head * (seqlen * seqlen * hdim_qk + seqlen * hdim_v * seqlen);
|
||||
};
|
||||
|
||||
int batches_for_alloc = is_jagged ? 1 : num_batch;
|
||||
|
||||
ck_tile::HostTensor<InOutDataType> q_host(
|
||||
@@ -476,7 +491,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
auto ms = timer.duration() / 10.f;
|
||||
|
||||
std::cout << "Average execution time of the hstu_attention operation is " << ms
|
||||
<< " milli-seconds" << std::endl;
|
||||
<< " milli-seconds, estimated TFLOPS is "
|
||||
<< (static_cast<float>(total_flops) / ms) / 1.0e9 << std::endl;
|
||||
}
|
||||
|
||||
return res;
|
||||
|
||||
Reference in New Issue
Block a user