diff --git a/nvbench/detail/statistics.cuh b/nvbench/detail/statistics.cuh index 358bb6c..311a20d 100644 --- a/nvbench/detail/statistics.cuh +++ b/nvbench/detail/statistics.cuh @@ -62,11 +62,22 @@ ValueType standard_deviation(Iter first, Iter last, ValueType mean) return std::sqrt(variance); } +/** + * Computes and returns the mean. + * + * If the input has fewer than 1 sample, infinity is returned. + */ template nvbench::float64_t compute_mean(It first, It last) { - const auto n = std::distance(first, last); - return std::accumulate(first, last, 0.0) / static_cast(n); + const auto num = std::distance(first, last); + + if (num < 1) + { + return std::numeric_limits::infinity(); + } + + return std::accumulate(first, last, 0.0) / static_cast(num); } /** diff --git a/testing/statistics.cu b/testing/statistics.cu index e2e1463..50e1014 100644 --- a/testing/statistics.cu +++ b/testing/statistics.cu @@ -25,6 +25,22 @@ namespace statistics = nvbench::detail::statistics; +void test_mean() +{ + { + std::vector data{1.0, 2.0, 3.0, 4.0, 5.0}; + const nvbench::float64_t actual = statistics::compute_mean(std::begin(data), std::end(data)); + const nvbench::float64_t expected = 3.0; + ASSERT(std::abs(actual - expected) < 0.001); + } + + { + std::vector data; + const bool finite = std::isfinite(statistics::compute_mean(std::begin(data), std::end(data))); + ASSERT(!finite); + } +} + void test_std() { std::vector data{1.0, 2.0, 3.0, 4.0, 5.0}; @@ -104,6 +120,7 @@ void test_slope_conversion() int main() { + test_mean(); test_std(); test_lin_regression(); test_r2();