entropy criterion optimizations (#286)

* entropy criterion optimizations

* online linear regression module

* online regression refactor

* revising ss_tot handling

---------

Co-authored-by: Jerry Hou <jerryhou@fb.com>
This commit is contained in:
Jerry Hou
2025-12-05 17:02:21 -08:00
committed by GitHub
parent a6995413ac
commit f651636501
3 changed files with 247 additions and 27 deletions

View File

@@ -18,6 +18,7 @@
#pragma once
#include <nvbench/detail/online_linear_regression.cuh>
#include <nvbench/detail/ring_buffer.cuh>
#include <nvbench/stopping_criterion.cuh>
#include <nvbench/types.cuh>
@@ -33,14 +34,15 @@ class entropy_criterion final : public stopping_criterion_base
nvbench::int64_t m_total_samples{};
nvbench::float64_t m_total_cuda_time{};
std::vector<std::pair<nvbench::float64_t, nvbench::int64_t>> m_freq_tracker;
nvbench::float64_t m_sum_count_log_counter{};
// TODO The window size should be user-configurable
nvbench::detail::ring_buffer<nvbench::float64_t> m_entropy_tracker{299};
// Used to avoid re-allocating temporary memory
std::vector<nvbench::float64_t> m_probabilities;
online_linear_regression m_regression;
nvbench::float64_t compute_entropy();
void update_entropy_sum(nvbench::float64_t old_count, nvbench::float64_t new_count);
public:
entropy_criterion();

View File

@@ -28,7 +28,6 @@ entropy_criterion::entropy_criterion()
: stopping_criterion_base{"entropy", {{"max-angle", 0.048}, {"min-r2", 0.36}}}
{
m_freq_tracker.reserve(m_entropy_tracker.capacity() * 2);
m_probabilities.reserve(m_entropy_tracker.capacity() * 2);
}
void entropy_criterion::do_initialize()
@@ -37,37 +36,44 @@ void entropy_criterion::do_initialize()
m_total_cuda_time = 0.0;
m_entropy_tracker.clear();
m_freq_tracker.clear();
m_sum_count_log_counter = 0.0;
m_regression.clear();
}
void entropy_criterion::update_entropy_sum(nvbench::float64_t old_count,
nvbench::float64_t new_count)
{
if (old_count > 0)
{
auto diff = new_count - old_count;
m_sum_count_log_counter += new_count * std::log2(1 + diff / old_count) +
diff * std::log2(old_count);
}
else
{
m_sum_count_log_counter += new_count * std::log2(new_count);
}
}
nvbench::float64_t entropy_criterion::compute_entropy()
{
const std::size_t n = m_freq_tracker.size();
if (n == 0)
if (m_total_samples == 0)
{
return 0.0;
}
m_probabilities.resize(n);
for (std::size_t i = 0; i < n; i++)
{
m_probabilities[i] = static_cast<nvbench::float64_t>(m_freq_tracker[i].second) /
static_cast<nvbench::float64_t>(m_total_samples);
}
const auto n = static_cast<nvbench::float64_t>(m_total_samples);
const nvbench::float64_t entropy = std::log2(n) - m_sum_count_log_counter / n;
nvbench::float64_t entropy{};
for (nvbench::float64_t p : m_probabilities)
{
entropy -= p * std::log2(p);
}
return entropy;
return std::copysign(std::max(0.0, entropy), 1.0);
}
void entropy_criterion::do_add_measurement(nvbench::float64_t measurement)
{
m_total_samples++;
m_total_cuda_time += measurement;
nvbench::int64_t old_count = 0;
{
auto key = measurement;
constexpr bool bin_keys = false;
@@ -88,15 +94,34 @@ void entropy_criterion::do_add_measurement(nvbench::float64_t measurement)
if (it != m_freq_tracker.end() && it->first == key)
{
old_count = it->second;
it->second += 1;
}
else
{
old_count = 0;
m_freq_tracker.insert(it, std::make_pair(key, nvbench::int64_t{1}));
}
}
m_entropy_tracker.push_back(compute_entropy());
update_entropy_sum(static_cast<nvbench::float64_t>(old_count),
static_cast<nvbench::float64_t>(old_count + 1));
const nvbench::float64_t entropy = compute_entropy();
const nvbench::float64_t n = static_cast<nvbench::float64_t>(m_entropy_tracker.size());
if (m_entropy_tracker.size() == m_entropy_tracker.capacity())
{
const nvbench::float64_t old_entropy = *m_entropy_tracker.cbegin();
m_regression.slide_window(old_entropy, entropy);
}
else
{
const nvbench::float64_t new_x = n;
m_regression.update({new_x, entropy});
}
m_entropy_tracker.push_back(entropy);
}
bool entropy_criterion::do_is_finished()
@@ -106,25 +131,30 @@ bool entropy_criterion::do_is_finished()
return false;
}
// Even number of samples is used to reduce the overhead and not required to compute entropy.
// This makes `is_finished()` about 20% faster than corresponding stdrel method.
if (m_total_samples % 2 != 0)
{
return false;
}
auto begin = m_entropy_tracker.cbegin();
auto end = m_entropy_tracker.cend();
auto mean = statistics::compute_mean(begin, end);
const nvbench::float64_t slope = m_regression.slope();
const auto [slope, intercept] = statistics::compute_linear_regression(begin, end, mean);
if (!std::isfinite(slope))
{
return false;
}
if (statistics::slope2deg(slope) > m_params.get_float64("max-angle"))
{
return false;
}
const auto r2 = statistics::compute_r2(begin, end, mean, slope, intercept);
const nvbench::float64_t r2 = m_regression.r_squared();
if (!std::isfinite(r2))
{
return false;
}
if (r2 < m_params.get_float64("min-r2"))
{
return false;

View File

@@ -0,0 +1,188 @@
/*
* Copyright 2025 NVIDIA Corporation
*
* Licensed under the Apache License, Version 2.0 with the LLVM exception
* (the "License"); you may not use this file except in compliance with
* the License.
*
* You may obtain a copy of the License at
*
* http://llvm.org/foundation/relicensing/LICENSE.txt
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <nvbench/types.cuh>

#include <algorithm>
#include <cmath>
#include <limits>
#include <utility>
namespace nvbench::detail
{
/// Incrementally maintained ordinary least-squares fit of `y = slope * x + intercept`.
///
/// All statistics are derived in O(1) from five running sums (Σx, Σy, Σxy, Σx², Σy²)
/// plus a point count, so samples can be added — or exchanged, for a sliding
/// window — without revisiting earlier points.
class online_linear_regression
{
  nvbench::float64_t m_sum_x{};  // Σ x_i
  nvbench::float64_t m_sum_y{};  // Σ y_i
  nvbench::float64_t m_sum_xy{}; // Σ x_i * y_i
  nvbench::float64_t m_sum_x2{}; // Σ x_i²
  nvbench::float64_t m_sum_y2{}; // Σ y_i²
  nvbench::int64_t m_count{};    // number of points currently in the fit

public:
  online_linear_regression() = default;

  /// Add the point `incoming` ({x, y}) to the fit.
  void update(std::pair<nvbench::float64_t, nvbench::float64_t> incoming)
  {
    const auto [x, y] = incoming;
    m_sum_x += x;
    m_sum_y += y;
    m_sum_xy += x * y;
    m_sum_x2 += x * x;
    m_sum_y2 += y * y;
    m_count++;
  }

  /// Replace the point `outgoing` with `incoming`. The point count is
  /// unchanged, so this is suitable for a fixed-size sliding window where
  /// both x and y of the evicted sample are known.
  void update(std::pair<nvbench::float64_t, nvbench::float64_t> outgoing,
              std::pair<nvbench::float64_t, nvbench::float64_t> incoming)
  {
    const auto [x_out, y_out] = outgoing;
    m_sum_x -= x_out;
    m_sum_y -= y_out;
    m_sum_xy -= x_out * y_out;
    m_sum_x2 -= x_out * x_out;
    m_sum_y2 -= y_out * y_out;
    const auto [x_in, y_in] = incoming;
    m_sum_x += x_in;
    m_sum_y += y_in;
    m_sum_xy += x_in * y_in;
    m_sum_x2 += x_in * x_in;
    m_sum_y2 += y_in * y_in;
  }

  /// Slide a full window one step, valid only when the x values are the fixed
  /// sequence 0, 1, ..., count-1 (so only the y values move): the oldest
  /// sample `y_out` (at x = 0) is dropped, every surviving y shifts from x to
  /// x-1, and `y_in` enters at x = count-1. Σx and Σx² are untouched because
  /// the multiset of x values does not change.
  ///
  /// Σxy update, with S_y = old Σy:
  ///   Σ'xy = Σ (i-1)·y_i for the survivors, plus (count-1)·y_in
  ///        = Σxy - (S_y - y_out) + (count-1)·y_in
  void slide_window(nvbench::float64_t y_out, nvbench::float64_t y_in)
  {
    m_sum_y -= y_out;
    m_sum_y += y_in;
    m_sum_y2 -= y_out * y_out;
    m_sum_y2 += y_in * y_in;
    // Statement order matters: m_sum_y was already updated above, so
    // `m_sum_y - y_in` here equals the *old* Σy minus y_out (= S_y - y_out).
    m_sum_xy -= m_sum_y - y_in;
    m_sum_xy += (static_cast<nvbench::float64_t>(m_count) - 1.0) * y_in;
  }

  /// Reset all running sums and the point count to zero.
  void clear()
  {
    m_sum_x = 0.0;
    m_sum_y = 0.0;
    m_sum_xy = 0.0;
    m_sum_x2 = 0.0;
    m_sum_y2 = 0.0;
    m_count = 0;
  }

  [[nodiscard]] nvbench::int64_t count() const { return m_count; }

  // Mean of the x values; 0.0 when the fit is empty.
  [[nodiscard]] nvbench::float64_t mean_x() const
  {
    return m_count > 0 ? m_sum_x / static_cast<nvbench::float64_t>(m_count) : 0.0;
  }

  // Mean of the y values; 0.0 when the fit is empty.
  [[nodiscard]] nvbench::float64_t mean_y() const
  {
    return m_count > 0 ? m_sum_y / static_cast<nvbench::float64_t>(m_count) : 0.0;
  }

  /// Least-squares slope: cov(x, y) / var(x), both scaled by 1/n to limit
  /// overflow. Returns NaN with fewer than two points or when the x values
  /// are (numerically) constant, i.e. |var(x)| < 1e-12.
  [[nodiscard]] nvbench::float64_t slope() const
  {
    static constexpr nvbench::float64_t q_nan =
      std::numeric_limits<nvbench::float64_t>::quiet_NaN();
    if (m_count < 2)
      return q_nan;
    const nvbench::float64_t n = static_cast<nvbench::float64_t>(m_count);
    const nvbench::float64_t mean_x = (m_sum_x / n);
    const nvbench::float64_t mean_y = (m_sum_y / n);
    const nvbench::float64_t numerator = (m_sum_xy / n) - mean_x * mean_y;
    const nvbench::float64_t denominator = (m_sum_x2 / n) - mean_x * mean_x;
    if (std::abs(denominator) < 1e-12)
      return q_nan;
    return numerator / denominator;
  }

  /// Least-squares intercept: mean_y - slope * mean_x. Returns NaN whenever
  /// the slope is undefined (fewer than two points or degenerate x).
  [[nodiscard]] nvbench::float64_t intercept() const
  {
    if (m_count < 2)
    {
      return std::numeric_limits<nvbench::float64_t>::quiet_NaN();
    }
    const nvbench::float64_t current_slope = slope();
    if (!std::isfinite(current_slope))
    {
      return std::numeric_limits<nvbench::float64_t>::quiet_NaN();
    }
    return mean_y() - current_slope * mean_x();
  }

  /// Coefficient of determination R² = 1 - ss_res / ss_tot, clamped to
  /// [0, 1]. Returns 1.0 when the y values are (numerically) constant
  /// (ss_tot below epsilon — any fit is perfect), and NaN when fewer than
  /// two points are present or the slope/intercept are undefined.
  [[nodiscard]] nvbench::float64_t r_squared() const
  {
    if (m_count < 2)
    {
      return std::numeric_limits<nvbench::float64_t>::quiet_NaN();
    }
    // ss_tot and ss_res scaled by 1/n to avoid overflow
    const nvbench::float64_t n = static_cast<nvbench::float64_t>(m_count);
    const nvbench::float64_t mean_y_v = mean_y();
    const nvbench::float64_t ss_tot = (m_sum_y2 / n) - mean_y_v * mean_y_v;
    if (ss_tot < std::numeric_limits<nvbench::float64_t>::epsilon())
    {
      return 1.0;
    }
    const nvbench::float64_t slope_v = slope();
    const nvbench::float64_t intercept_v = intercept();
    if (!std::isfinite(slope_v) || !std::isfinite(intercept_v))
    {
      return std::numeric_limits<nvbench::float64_t>::quiet_NaN();
    }
    else
    {
      const nvbench::float64_t mean_xy_v = m_sum_xy / n;
      const nvbench::float64_t mean_xx_v = m_sum_x2 / n;
      const nvbench::float64_t mean_x_v = m_sum_x / n;
      // (ss_tot - ss_res) / n, algebraically rearranged from
      //   ss_res/n = E[y²] - 2a·E[xy] - 2b·E[y] + a²·E[x²] + 2ab·E[x] + b²
      // (a = slope, b = intercept); expanding the grouping below reproduces
      // -E[y]² + 2a·E[xy] + 2b·E[y] - a²·E[x²] - 2ab·E[x] - b² exactly.
      const nvbench::float64_t ss_tot_m_res =
        slope_v * ((mean_xy_v - slope_v * mean_xx_v) + (mean_xy_v - intercept_v * mean_x_v)) +
        intercept_v * (mean_y_v - slope_v * mean_x_v - intercept_v) +
        mean_y_v * (intercept_v - mean_y_v);
      // Clamp: rounding can push the ratio slightly outside [0, 1].
      return std::min(std::max(ss_tot_m_res / ss_tot, 0.0), 1.0);
    }
  }

  // Raw running-sum accessors (useful for testing/diagnostics).
  [[nodiscard]] nvbench::float64_t sum_x() const { return m_sum_x; }
  [[nodiscard]] nvbench::float64_t sum_y() const { return m_sum_y; }
  [[nodiscard]] nvbench::float64_t sum_xy() const { return m_sum_xy; }
  [[nodiscard]] nvbench::float64_t sum_x2() const { return m_sum_x2; }
  [[nodiscard]] nvbench::float64_t sum_y2() const { return m_sum_y2; }
};
} // namespace nvbench::detail