mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-05-12 01:10:08 +00:00
Update FLOPS counter
This commit is contained in:
@@ -252,15 +252,30 @@ struct Options {
|
||||
double gflops(double runtime_s) const {
|
||||
|
||||
// Number of real-valued multiply-adds
|
||||
int64_t fmas = int64_t();
|
||||
int64_t fops = int64_t();
|
||||
|
||||
for (auto const & problem : problem_sizes0) {
|
||||
// Two flops per multiply-add
|
||||
fmas += problem.product() * 2;
|
||||
for (int i = 0; i < problem_sizes0.size(); ++i) {
|
||||
auto const& problem0 = problem_sizes0[i];
|
||||
auto const& problem1 = problem_sizes1[i];
|
||||
for (int row = 0; row < problem0.m(); ++row) {
|
||||
int num_cols0 = problem0.n();
|
||||
if (causal) {
|
||||
num_cols0 = std::min(row + 1, num_cols0);
|
||||
}
|
||||
// P <- Q . K_t
|
||||
fops += 2 * num_cols0 * problem0.k();
|
||||
// P <- exp(P - max(P))
|
||||
fops += 2 * num_cols0;
|
||||
// S <- sum(P)
|
||||
fops += num_cols0 - 1;
|
||||
// O <- P . V
|
||||
fops += 2 * num_cols0 * problem1.n();
|
||||
// O <- O / S
|
||||
fops += num_cols0 * problem1.n();
|
||||
}
|
||||
}
|
||||
|
||||
// Multiply another '2' because of the back-to-back GEMM problems in attention
|
||||
return 2.0 * double(fmas) / double(1.0e9) / runtime_s;
|
||||
|
||||
return double(fops) / double(1.0e9) / runtime_s;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user