diff --git a/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp b/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp
index bb6bb9050f..bafeef361e 100644
--- a/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp
+++ b/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp
@@ -279,6 +279,31 @@ bool run(const ck_tile::ArgParser& arg_parser)
         };
     };
 
+    // Estimate of the floating-point operations in one run of the attention operator,
+    // used only for the TFLOPS figure printed with the timing. Scaling and SiLU are
+    // ignored, and every multiply-add of the hdim_qk and hdim_v terms counts as 2 flops
+    // in both the jagged and the batched case. Operands are widened to 64 bits before
+    // multiplying so the products cannot overflow a narrower integer type.
+    long long total_flops = 0;
+
+    if(is_jagged)
+    {
+        // each sequence contributes according to its own length
+        for(auto len : seq_lengths)
+        {
+            const long long len_ll = static_cast<long long>(len);
+            total_flops += 2 * len_ll * len_ll * (hdim_qk + hdim_v);
+        }
+
+        total_flops *= num_head;
+    }
+    else
+    {
+        const long long seqlen_ll = static_cast<long long>(seqlen);
+        total_flops = 2 * static_cast<long long>(num_batch) * num_head * seqlen_ll * seqlen_ll *
+                      (hdim_qk + hdim_v);
+    }
+
     int batches_for_alloc = is_jagged ? 1 : num_batch;
 
     ck_tile::HostTensor q_host(
@@ -476,7 +501,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
         auto ms = timer.duration() / 10.f;
 
+        // flops per millisecond divided by 1e9 equals flops per second divided by 1e12 (TFLOPS)
         std::cout << "Average execution time of the hstu_attention operation is " << ms
-                  << " milli-seconds" << std::endl;
+                  << " milli-seconds, estimated TFLOPS is "
+                  << (static_cast<double>(total_flops) / ms) / 1.0e9 << std::endl;
     }
 
     return res;