Implement DecisionReason, tracking and summarisation

- Add DecisionReason(code, message) and internal
  TimingDecision(status, reason).
- SummaryComparison now carries reason
- ComparisonStats now aggregates undecided reasons.
- Final summary prints a reason breakdown only when
  undecided reasons exist, e.g.:

  - Undecided   (comparison requires more evidence): 3
    - Reasons:
      - noise_too_high: 2 (relative dispersion is too
                           high to declare same)
      - weak_interval_overlap: 1 (timing intervals do not
                 overlap strongly enough to declare same)
This commit is contained in:
Oleksandr Pavlyk
2026-06-03 07:52:25 -05:00
parent 6de54fa07a
commit 65abfbcfb2
2 changed files with 178 additions and 25 deletions

View File

@@ -9,7 +9,7 @@ import os
import sys
import warnings
from collections import Counter
from dataclasses import dataclass
from dataclasses import dataclass, field
from enum import Enum
from functools import cached_property
from typing import Any, Callable, Mapping
@@ -128,6 +128,18 @@ class ComparisonStatus(str, Enum):
SLOW = "SLOW"
@dataclass(frozen=True)
class DecisionReason:
code: str
message: str
@dataclass(frozen=True)
class TimingDecision:
status: ComparisonStatus
reason: DecisionReason
@dataclass(frozen=True)
class SummaryComparison:
ref_estimate: TimeEstimate
@@ -140,6 +152,7 @@ class SummaryComparison:
frac_diff: float
max_noise: float | None
status: ComparisonStatus
reason: DecisionReason
@dataclass
@@ -150,13 +163,18 @@ class ComparisonStats:
regression_count: int = 0
undecided_count: int = 0
unknown_count: int = 0
undecided_reasons: Counter[DecisionReason] = field(default_factory=Counter)
def record(self, status: ComparisonStatus) -> None:
def record(
self, status: ComparisonStatus, reason: DecisionReason | None = None
) -> None:
self.config_count += 1
if status == ComparisonStatus.UNKNOWN:
self.unknown_count += 1
elif status == ComparisonStatus.UNDECIDED:
self.undecided_count += 1
if reason is not None:
self.undecided_reasons[reason] += 1
elif status == ComparisonStatus.SAME:
self.pass_count += 1
elif status == ComparisonStatus.FAST:
@@ -553,6 +571,12 @@ def compute_timing_interval(timing):
return None
def make_decision(status, code, message):
return TimingDecision(
status=status, reason=DecisionReason(code=code, message=message)
)
def compare_intervals_for_clear_gap(ref_interval, cmp_interval):
# These ratios are equivalent to log(ref/cmp) >= log(1 + delta), but avoid
# evaluating logarithms on every comparison.
@@ -624,28 +648,52 @@ def confirm_clear_gap_with_clock_rate(
status, ref_timing, cmp_timing, ref_interval, cmp_interval
):
if ref_timing.sm_clock_rate_mean is None or cmp_timing.sm_clock_rate_mean is None:
return ComparisonStatus.UNDECIDED
return make_decision(
ComparisonStatus.UNDECIDED,
"missing_clock_rate",
"clear timing gap was not confirmed because SM clock summaries are unavailable",
)
ref_cycles = scale_interval(ref_interval, ref_timing.sm_clock_rate_mean)
cmp_cycles = scale_interval(cmp_interval, cmp_timing.sm_clock_rate_mean)
if ref_cycles is None or cmp_cycles is None:
return ComparisonStatus.UNDECIDED
return make_decision(
ComparisonStatus.UNDECIDED,
"invalid_clock_rate",
"clear timing gap was not confirmed because SM clock summaries are invalid",
)
cycle_status = compare_intervals_for_clear_gap(ref_cycles, cmp_cycles)
if cycle_status == status:
return status
return ComparisonStatus.UNDECIDED
return make_decision(
status,
"clear_gap_confirmed_by_cycles",
"clear timing gap was confirmed by SM-clock-adjusted cycle intervals",
)
return make_decision(
ComparisonStatus.UNDECIDED,
"cycle_gap_not_confirmed",
"clear timing gap was not confirmed by SM-clock-adjusted cycle intervals",
)
def compare_timings_for_clear_gap(ref_timing, cmp_timing):
ref_interval = compute_timing_interval(ref_timing)
cmp_interval = compute_timing_interval(cmp_timing)
if ref_interval is None or cmp_interval is None:
return ComparisonStatus.UNDECIDED
return make_decision(
ComparisonStatus.UNDECIDED,
"missing_interval",
"could not construct comparable timing intervals",
)
status = compare_intervals_for_clear_gap(ref_interval, cmp_interval)
if status is None:
return ComparisonStatus.UNDECIDED
return make_decision(
ComparisonStatus.UNDECIDED,
"no_clear_gap",
"timing intervals do not have a sufficient clear gap",
)
return confirm_clear_gap_with_clock_rate(
status, ref_timing, cmp_timing, ref_interval, cmp_interval
@@ -654,38 +702,81 @@ def compare_timings_for_clear_gap(ref_timing, cmp_timing):
def compare_intervals_for_same(ref_interval, cmp_interval):
if not centers_are_close(ref_interval.center, cmp_interval.center):
return ComparisonStatus.UNDECIDED
return make_decision(
ComparisonStatus.UNDECIDED,
"centers_not_close",
"timing centers are not close enough to declare same",
)
if not intervals_overlap_strongly(ref_interval, cmp_interval):
return ComparisonStatus.UNDECIDED
return ComparisonStatus.SAME
return make_decision(
ComparisonStatus.UNDECIDED,
"weak_interval_overlap",
"timing intervals do not overlap strongly enough to declare same",
)
return make_decision(
ComparisonStatus.SAME,
"same_summary",
"timing centers are close and intervals overlap strongly",
)
def confirm_same_with_clock_rate(ref_timing, cmp_timing, ref_interval, cmp_interval):
if ref_timing.sm_clock_rate_mean is None or cmp_timing.sm_clock_rate_mean is None:
return ComparisonStatus.SAME
return make_decision(
ComparisonStatus.SAME,
"same_without_clock_rate",
"timing centers are close and intervals overlap strongly; SM clock summaries are unavailable",
)
ref_cycles = scale_interval(ref_interval, ref_timing.sm_clock_rate_mean)
cmp_cycles = scale_interval(cmp_interval, cmp_timing.sm_clock_rate_mean)
if ref_cycles is None or cmp_cycles is None:
return ComparisonStatus.UNDECIDED
return make_decision(
ComparisonStatus.UNDECIDED,
"invalid_clock_rate",
"same decision was not confirmed because SM clock summaries are invalid",
)
return compare_intervals_for_same(ref_cycles, cmp_cycles)
decision = compare_intervals_for_same(ref_cycles, cmp_cycles)
if decision.status == ComparisonStatus.SAME:
return make_decision(
ComparisonStatus.SAME,
"same_confirmed_by_cycles",
"timing and SM-clock-adjusted cycle intervals both support same",
)
return make_decision(
ComparisonStatus.UNDECIDED,
"cycle_same_not_confirmed",
"same decision was not confirmed by SM-clock-adjusted cycle intervals",
)
def compare_timings_for_same(ref_timing, cmp_timing, ref_noise, cmp_noise):
if not has_finite_noise(ref_noise) or not has_finite_noise(cmp_noise):
return ComparisonStatus.UNDECIDED
return make_decision(
ComparisonStatus.UNDECIDED,
"noise_unavailable",
"relative dispersion is unavailable or non-finite",
)
if max(ref_noise, cmp_noise) > SAME_RELATIVE_DISPERSION_CEILING:
return ComparisonStatus.UNDECIDED
return make_decision(
ComparisonStatus.UNDECIDED,
"noise_too_high",
"relative dispersion is too high to declare same",
)
ref_interval = compute_timing_interval(ref_timing)
cmp_interval = compute_timing_interval(cmp_timing)
if ref_interval is None or cmp_interval is None:
return ComparisonStatus.UNDECIDED
return make_decision(
ComparisonStatus.UNDECIDED,
"missing_interval",
"could not construct comparable timing intervals",
)
status = compare_intervals_for_same(ref_interval, cmp_interval)
if status != ComparisonStatus.SAME:
return status
decision = compare_intervals_for_same(ref_interval, cmp_interval)
if decision.status != ComparisonStatus.SAME:
return decision
return confirm_same_with_clock_rate(
ref_timing, cmp_timing, ref_interval, cmp_interval
@@ -790,9 +881,14 @@ def compare_gpu_timings(ref_timing, cmp_timing):
else:
max_noise = max(ref_noise, cmp_noise)
status = compare_timings_for_clear_gap(ref_timing, cmp_timing)
if status == ComparisonStatus.UNDECIDED:
status = compare_timings_for_same(ref_timing, cmp_timing, ref_noise, cmp_noise)
decision = compare_timings_for_clear_gap(ref_timing, cmp_timing)
if decision.status == ComparisonStatus.UNDECIDED and decision.reason.code in {
"no_clear_gap",
"missing_interval",
}:
decision = compare_timings_for_same(
ref_timing, cmp_timing, ref_noise, cmp_noise
)
return SummaryComparison(
ref_estimate=ref_estimate,
@@ -804,7 +900,8 @@ def compare_gpu_timings(ref_timing, cmp_timing):
diff=diff,
frac_diff=frac_diff,
max_noise=max_noise,
status=status,
status=decision.status,
reason=decision.reason,
)
@@ -1295,7 +1392,7 @@ def compare_benches(
comparison.ref_noise
)
run_data.stats.record(comparison.status)
run_data.stats.record(comparison.status, comparison.reason)
status = colorize_comparison_status(comparison.status, no_color)
if abs(comparison.frac_diff) >= threshold:
@@ -1629,6 +1726,10 @@ def main() -> int:
print(
f" - Undecided (comparison requires more evidence): {stats.undecided_count}"
)
if stats.undecided_reasons:
print(" - Reasons:")
for reason, count in stats.undecided_reasons.most_common():
print(f" - {reason.code}: {count} ({reason.message})")
print(f" - Unknown (infinite or unavailable noise): {stats.unknown_count}")
return 0