Update FLOPS counter

This commit is contained in:
danthe3rd
2022-10-13 08:55:30 +00:00
parent 00b9af0b36
commit ac8cd9c000

View File

@@ -252,15 +252,30 @@ struct Options {
double gflops(double runtime_s) const {
// Number of real-valued multiply-adds
int64_t fmas = int64_t();
int64_t fops = int64_t();
for (auto const & problem : problem_sizes0) {
// Two flops per multiply-add
fmas += problem.product() * 2;
for (int i = 0; i < problem_sizes0.size(); ++i) {
auto const& problem0 = problem_sizes0[i];
auto const& problem1 = problem_sizes1[i];
for (int row = 0; row < problem0.m(); ++row) {
int num_cols0 = problem0.n();
if (causal) {
num_cols0 = std::min(row + 1, num_cols0);
}
// P <- Q . K_t
fops += 2 * num_cols0 * problem0.k();
// P <- exp(P - max(P))
fops += 2 * num_cols0;
// S <- sum(P)
fops += num_cols0 - 1;
// O <- P . V
fops += 2 * num_cols0 * problem1.n();
// O <- O / S
fops += num_cols0 * problem1.n();
}
}
// Multiply another '2' because of the back-to-back GEMM problems in attention
return 2.0 * double(fmas) / double(1.0e9) / runtime_s;
return double(fops) / double(1.0e9) / runtime_s;
}
};