From ac8cd9c0004fcc3b12be62a1507ffeb7cf9a44dd Mon Sep 17 00:00:00 2001 From: danthe3rd Date: Thu, 13 Oct 2022 08:55:30 +0000 Subject: [PATCH] Update FLOPS counter --- .../fused_multihead_attention.cu | 29 ++++++++++++++----- 1 file changed, 22 insertions(+), 7 deletions(-) diff --git a/examples/42_fused_multi_head_attention/fused_multihead_attention.cu b/examples/42_fused_multi_head_attention/fused_multihead_attention.cu index 254fd0956..61388c242 100644 --- a/examples/42_fused_multi_head_attention/fused_multihead_attention.cu +++ b/examples/42_fused_multi_head_attention/fused_multihead_attention.cu @@ -252,15 +252,30 @@ struct Options { double gflops(double runtime_s) const { // Number of real-valued multiply-adds - int64_t fmas = int64_t(); + int64_t fops = int64_t(); - for (auto const & problem : problem_sizes0) { - // Two flops per multiply-add - fmas += problem.product() * 2; + for (int i = 0; i < problem_sizes0.size(); ++i) { + auto const& problem0 = problem_sizes0[i]; + auto const& problem1 = problem_sizes1[i]; + for (int row = 0; row < problem0.m(); ++row) { + int num_cols0 = problem0.n(); + if (causal) { + num_cols0 = std::min(row + 1, num_cols0); + } + // P <- Q . K_t + fops += 2 * num_cols0 * problem0.k(); + // P <- exp(P - max(P)) + fops += 2 * num_cols0; + // S <- sum(P) + fops += num_cols0 - 1; + // O <- P . V + fops += 2 * num_cols0 * problem1.n(); + // O <- O / S + fops += num_cols0 * problem1.n(); + } } - - // Multiply another '2' because of the back-to-back GEMM problems in attention - return 2.0 * double(fmas) / double(1.0e9) / runtime_s; + + return double(fops) / double(1.0e9) / runtime_s; } };