Implement sample-count stopping criterion with parameter target-samples

--stopping-criterion sample-count --target-samples 100 would stop once
max(--min-samples, --target-samples) samples are collected
This commit is contained in:
Oleksandr Pavlyk
2026-04-30 17:33:35 -05:00
parent 84af4ea455
commit e8d7340696
15 changed files with 228 additions and 3 deletions

View File

@@ -157,6 +157,7 @@
* "stdrel": (default) Converges to a minimal relative standard deviation,
stdev / mean
* "entropy": Converges based on the cumulative entropy of all samples.
* "sample-count": Stops after a target number of samples.
* Each stopping criterion may provide additional parameters to customize
behavior, as detailed below:
@@ -192,3 +193,13 @@
* Default is 0.36.
* Applies to the most recent `--benchmark`, or all benchmarks if specified
before any `--benchmark` arguments.
### "sample-count" Stopping Criterion Parameters
* `--target-samples <count>`
* Stop after at least `<count>` samples are collected.
* Default is 100 samples.
* The total number of collected samples is
`max(--min-samples, --target-samples)`.
* Applies to the most recent `--benchmark`, or all benchmarks if specified
before any `--benchmark` arguments.

View File

@@ -27,6 +27,7 @@ set(srcs
detail/measure_cold.cu
detail/measure_cpu_only.cxx
detail/measure_hot.cu
detail/sample_count_criterion.cxx
detail/state_generator.cxx
detail/stdrel_criterion.cxx
detail/gpu_frequency.cxx

View File

@@ -29,6 +29,7 @@
#endif
#include <nvbench/detail/entropy_criterion.cuh>
#include <nvbench/detail/sample_count_criterion.cuh>
#include <nvbench/detail/stdrel_criterion.cuh>
#include <nvbench/stopping_criterion.cuh>
#include <nvbench/types.cuh>

View File

@@ -33,6 +33,7 @@ criterion_manager::criterion_manager()
{
m_map.emplace("stdrel", std::make_unique<nvbench::detail::stdrel_criterion>());
m_map.emplace("entropy", std::make_unique<nvbench::detail::entropy_criterion>());
m_map.emplace("sample-count", std::make_unique<nvbench::detail::sample_count_criterion>());
}
criterion_manager &criterion_manager::get()

View File

@@ -175,7 +175,7 @@ bool measure_cold_base::is_finished()
}
// Check that we've gathered enough samples:
if (m_total_samples > m_min_samples)
if (m_total_samples >= m_min_samples)
{
if (m_stopping_criterion.is_finished())
{

View File

@@ -93,7 +93,7 @@ bool measure_cpu_only_base::is_finished()
}
// Check that we've gathered enough samples:
if (m_total_samples > m_min_samples)
if (m_total_samples >= m_min_samples)
{
if (m_stopping_criterion.is_finished())
{

View File

@@ -186,7 +186,7 @@ private:
(m_total_cuda_time / static_cast<nvbench::float64_t>(m_total_samples)));
if (m_total_cuda_time > m_min_time && // min time okay
m_total_samples > m_min_samples) // min samples okay
m_total_samples >= m_min_samples) // min samples okay
{
break; // Stop iterating
}

View File

@@ -0,0 +1,50 @@
/*
* Copyright 2026 NVIDIA Corporation
*
* Licensed under the Apache License, Version 2.0 with the LLVM exception
* (the "License"); you may not use this file except in compliance with
* the License.
*
* You may obtain a copy of the License at
*
* http://llvm.org/foundation/relicensing/LICENSE.txt
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <nvbench/config.cuh>
#if defined(NVBENCH_IMPLICIT_SYSTEM_HEADER_GCC)
#pragma GCC system_header
#elif defined(NVBENCH_IMPLICIT_SYSTEM_HEADER_CLANG)
#pragma clang system_header
#elif defined(NVBENCH_IMPLICIT_SYSTEM_HEADER_MSVC)
#pragma system_header
#endif
#include <nvbench/stopping_criterion.cuh>
#include <nvbench/types.cuh>
namespace nvbench::detail
{
class sample_count_criterion final : public stopping_criterion_base
{
nvbench::int64_t m_total_samples{};
public:
sample_count_criterion();
protected:
virtual void do_initialize() override;
virtual void do_add_measurement(nvbench::float64_t measurement) override;
virtual bool do_is_finished() override;
};
} // namespace nvbench::detail

View File

@@ -0,0 +1,37 @@
/*
* Copyright 2026 NVIDIA Corporation
*
* Licensed under the Apache License, Version 2.0 with the LLVM exception
* (the "License"); you may not use this file except in compliance with
* the License.
*
* You may obtain a copy of the License at
*
* http://llvm.org/foundation/relicensing/LICENSE.txt
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <nvbench/detail/sample_count_criterion.cuh>
namespace nvbench::detail
{
sample_count_criterion::sample_count_criterion()
: stopping_criterion_base{"sample-count", {{"target-samples", nvbench::int64_t{100}}}}
{}
void sample_count_criterion::do_initialize() { m_total_samples = 0; }
void sample_count_criterion::do_add_measurement(nvbench::float64_t) { ++m_total_samples; }
bool sample_count_criterion::do_is_finished()
{
return m_total_samples >= m_params.get_int64("target-samples");
}
} // namespace nvbench::detail

View File

@@ -60,6 +60,11 @@ bool stdrel_criterion::do_is_finished()
return false;
}
if (m_noise_tracker.empty())
{
return false;
}
// Noise has dropped below threshold
if (m_noise_tracker.back() < m_params.get_float64("max-noise"))
{

View File

@@ -20,6 +20,7 @@ set(test_srcs
reset_error.cu
ring_buffer.cu
runner.cu
sample_count_criterion.cu
state.cu
statistics.cu
state_generator.cu

View File

@@ -25,6 +25,8 @@ void test_standard_criteria_exist()
{
ASSERT(nvbench::criterion_manager::get().get_criterion("stdrel").get_name() == "stdrel");
ASSERT(nvbench::criterion_manager::get().get_criterion("entropy").get_name() == "entropy");
ASSERT(nvbench::criterion_manager::get().get_criterion("sample-count").get_name() ==
"sample-count");
}
class custom_criterion : public nvbench::stopping_criterion_base

View File

@@ -1307,6 +1307,45 @@ void test_stopping_criterion()
ASSERT(criterion_params.get_float64("max-angle") == 0.42);
ASSERT(criterion_params.get_float64("min-r2") == 0.6);
}
{ // Sample-count criterion default params
nvbench::option_parser parser;
parser.parse({
"--benchmark",
"DummyBench",
"--stopping-criterion",
"sample-count",
});
const auto &states = parser_to_states(parser);
ASSERT(states.size() == 1);
ASSERT(states[0].get_stopping_criterion() == "sample-count");
const nvbench::criterion_params &criterion_params = states[0].get_criterion_params();
ASSERT(criterion_params.has_value("target-samples"));
ASSERT(criterion_params.get_int64("target-samples") == 100);
}
{ // Sample-count criterion params are independent from min_samples
nvbench::option_parser parser;
parser.parse({
"--benchmark",
"DummyBench",
"--min-samples",
"7",
"--stopping-criterion",
"sample-count",
"--target-samples",
"123",
});
const auto &states = parser_to_states(parser);
ASSERT(states.size() == 1);
ASSERT(states[0].get_min_samples() == 7);
ASSERT(states[0].get_stopping_criterion() == "sample-count");
const nvbench::criterion_params &criterion_params = states[0].get_criterion_params();
ASSERT(criterion_params.has_value("target-samples"));
ASSERT(criterion_params.get_int64("target-samples") == 123);
}
{ // Unknown stopping criterion should throw
bool exception_thrown = false;
try

View File

@@ -0,0 +1,61 @@
/*
* Copyright 2026 NVIDIA Corporation
*
* Licensed under the Apache License, Version 2.0 with the LLVM exception
* (the "License"); you may not use this file except in compliance with
* the License.
*
* You may obtain a copy of the License at
*
* http://llvm.org/foundation/relicensing/LICENSE.txt
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <nvbench/detail/sample_count_criterion.cuh>
#include <nvbench/stopping_criterion.cuh>
#include "test_asserts.cuh"
void test_default_target_samples()
{
nvbench::detail::sample_count_criterion criterion;
criterion.initialize(nvbench::criterion_params{});
for (int i = 0; i < 99; ++i)
{
criterion.add_measurement(1.0);
ASSERT(!criterion.is_finished());
}
criterion.add_measurement(1.0);
ASSERT(criterion.is_finished());
}
void test_custom_target_samples()
{
nvbench::criterion_params params;
params.set_int64("target-samples", 3);
nvbench::detail::sample_count_criterion criterion;
criterion.initialize(params);
criterion.add_measurement(1.0);
ASSERT(!criterion.is_finished());
criterion.add_measurement(1.0);
ASSERT(!criterion.is_finished());
criterion.add_measurement(1.0);
ASSERT(criterion.is_finished());
}
int main()
{
test_default_target_samples();
test_custom_target_samples();
}

View File

@@ -78,8 +78,24 @@ void test_stdrel()
ASSERT(!criterion.is_finished());
}
void test_stdrel_needs_enough_samples()
{
nvbench::criterion_params params;
params.set_float64("min-time", 0.0);
nvbench::detail::stdrel_criterion criterion;
criterion.initialize(params);
for (int i = 0; i < 4; ++i)
{
criterion.add_measurement(42.0);
}
ASSERT(!criterion.is_finished());
}
int main()
{
test_const();
test_stdrel();
test_stdrel_needs_enough_samples();
}