mirror of
https://github.com/SillyTavern/SillyTavern-Extras.git
synced 2026-03-11 22:40:03 +00:00
317 lines
9.8 KiB
Python
317 lines
9.8 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
#
|
|
# This source code is licensed under the MIT license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
"""
|
|
A standalone module for aggregating metrics.
|
|
|
|
Metrics can be logged from anywhere using the `log_*` functions defined
|
|
in this module. The logged values will be aggregated dynamically based
|
|
on the aggregation context in which the logging occurs. See the
|
|
:func:`aggregate` context manager for more details.
|
|
"""
|
|
|
|
import contextlib
|
|
import uuid
|
|
from collections import defaultdict
|
|
from typing import Callable, List, Optional
|
|
|
|
from .meters import *
|
|
|
|
|
|
# Aggregation contexts are considered "active" when inside the scope
|
|
# created by the :func:`aggregate` context manager.
|
|
_aggregators = OrderedDict()
|
|
_active_aggregators = OrderedDict()
|
|
_active_aggregators_cnt = defaultdict(lambda: 0)
|
|
|
|
|
|
def reset() -> None:
|
|
"""Reset all metrics aggregators."""
|
|
_aggregators.clear()
|
|
_active_aggregators.clear()
|
|
_active_aggregators_cnt.clear()
|
|
|
|
# The "default" aggregator observes all logged values.
|
|
_aggregators["default"] = MetersDict()
|
|
_active_aggregators["default"] = _aggregators["default"]
|
|
_active_aggregators_cnt["default"] = 1
|
|
|
|
|
|
reset()
|
|
|
|
|
|
@contextlib.contextmanager
|
|
def aggregate(name: Optional[str] = None, new_root: bool = False):
|
|
"""Context manager to aggregate metrics under a given name.
|
|
|
|
Aggregations can be nested. If *new_root* is ``False``, then logged
|
|
metrics will be recorded along the entire stack of nested
|
|
aggregators, including a global "default" aggregator. If *new_root*
|
|
is ``True``, then this aggregator will be the root of a new
|
|
aggregation stack, thus bypassing any parent aggregators.
|
|
|
|
Note that aggregation contexts are uniquely identified by their
|
|
*name* (e.g., train, valid). Creating a context with an existing
|
|
name will reuse the corresponding :class:`MetersDict` instance.
|
|
If no name is given, then a temporary aggregator will be created.
|
|
|
|
Usage::
|
|
|
|
with metrics.aggregate("train"):
|
|
for step, batch in enumerate(epoch):
|
|
with metrics.aggregate("train_inner") as agg:
|
|
metrics.log_scalar("loss", get_loss(batch))
|
|
if step % log_interval == 0:
|
|
print(agg.get_smoothed_value("loss"))
|
|
agg.reset()
|
|
print(metrics.get_smoothed_values("train")["loss"])
|
|
|
|
Args:
|
|
name (str): name of the aggregation. Defaults to a
|
|
random/temporary name if not given explicitly.
|
|
new_root (bool): make this aggregation the root of a new
|
|
aggregation stack.
|
|
"""
|
|
if name is None:
|
|
# generate a temporary name
|
|
name = str(uuid.uuid4())
|
|
assert name not in _aggregators
|
|
agg = MetersDict()
|
|
else:
|
|
assert name != "default"
|
|
agg = _aggregators.setdefault(name, MetersDict())
|
|
|
|
if new_root:
|
|
backup_aggregators = _active_aggregators.copy()
|
|
_active_aggregators.clear()
|
|
backup_aggregators_cnt = _active_aggregators_cnt.copy()
|
|
_active_aggregators_cnt.clear()
|
|
|
|
_active_aggregators[name] = agg
|
|
_active_aggregators_cnt[name] += 1
|
|
|
|
yield agg
|
|
|
|
_active_aggregators_cnt[name] -= 1
|
|
if _active_aggregators_cnt[name] == 0 and name in _active_aggregators:
|
|
del _active_aggregators[name]
|
|
|
|
if new_root:
|
|
_active_aggregators.clear()
|
|
_active_aggregators.update(backup_aggregators)
|
|
_active_aggregators_cnt.clear()
|
|
_active_aggregators_cnt.update(backup_aggregators_cnt)
|
|
|
|
|
|
def get_active_aggregators() -> List[MetersDict]:
|
|
return list(_active_aggregators.values())
|
|
|
|
|
|
def log_scalar(
|
|
key: str,
|
|
value: float,
|
|
weight: float = 1,
|
|
priority: int = 10,
|
|
round: Optional[int] = None,
|
|
):
|
|
"""Log a scalar value.
|
|
|
|
Args:
|
|
key (str): name of the field to log
|
|
value (float): value to log
|
|
weight (float): weight that this value contributes to the average.
|
|
A weight of 0 will always log the latest value.
|
|
priority (int): smaller values are logged earlier in the output
|
|
round (Optional[int]): number of digits to round to when displaying
|
|
"""
|
|
for agg in get_active_aggregators():
|
|
if key not in agg:
|
|
agg.add_meter(key, AverageMeter(round=round), priority)
|
|
agg[key].update(value, weight)
|
|
|
|
|
|
def log_scalar_sum(
|
|
key: str,
|
|
value: float,
|
|
priority: int = 10,
|
|
round: Optional[int] = None,
|
|
):
|
|
"""Log a scalar value that is summed for reporting.
|
|
|
|
Args:
|
|
key (str): name of the field to log
|
|
value (float): value to log
|
|
priority (int): smaller values are logged earlier in the output
|
|
round (Optional[int]): number of digits to round to when displaying
|
|
"""
|
|
for agg in get_active_aggregators():
|
|
if key not in agg:
|
|
agg.add_meter(key, SumMeter(round=round), priority)
|
|
agg[key].update(value)
|
|
|
|
|
|
def log_derived(key: str, fn: Callable[[MetersDict], float], priority: int = 20):
|
|
"""Log a scalar value derived from other meters.
|
|
|
|
Args:
|
|
key (str): name of the field to log
|
|
fn (Callable[[MetersDict], float]): function that takes a single
|
|
argument *meters* and returns the derived value
|
|
priority (int): smaller values are logged earlier in the output
|
|
"""
|
|
for agg in get_active_aggregators():
|
|
if key not in agg:
|
|
agg.add_meter(key, MetersDict._DerivedMeter(fn), priority)
|
|
|
|
|
|
def log_speed(
|
|
key: str,
|
|
value: float,
|
|
priority: int = 30,
|
|
round: Optional[int] = None,
|
|
):
|
|
"""Log the rate of some quantity per second.
|
|
|
|
Args:
|
|
key (str): name of the field to log
|
|
value (float): value to log
|
|
priority (int): smaller values are logged earlier in the output
|
|
round (Optional[int]): number of digits to round to when displaying
|
|
"""
|
|
for agg in get_active_aggregators():
|
|
if key not in agg:
|
|
agg.add_meter(key, TimeMeter(round=round), priority)
|
|
agg[key].reset() # reset meter on the first call
|
|
else:
|
|
agg[key].update(value)
|
|
|
|
|
|
def log_start_time(key: str, priority: int = 40, round: Optional[int] = None):
|
|
"""Log the duration of some event in seconds.
|
|
|
|
The duration will be computed once :func:`log_stop_time` is called.
|
|
|
|
Args:
|
|
key (str): name of the field to log
|
|
priority (int): smaller values are logged earlier in the output
|
|
round (Optional[int]): number of digits to round to when displaying
|
|
"""
|
|
for agg in get_active_aggregators():
|
|
if key not in agg:
|
|
agg.add_meter(key, StopwatchMeter(round=round), priority)
|
|
agg[key].start()
|
|
|
|
|
|
def log_stop_time(key: str, weight: float = 0.0, prehook=None):
|
|
"""Log the duration of some event in seconds.
|
|
|
|
The duration will be computed since :func:`log_start_time` was called.
|
|
Set weight > 0 to report the average time instead of the sum.
|
|
|
|
Args:
|
|
key (str): name of the field to log
|
|
weight (float): weight that this time contributes to the average
|
|
prehook (function, no arguments): will be called before the timer
|
|
is stopped. For example, use prehook=torch.cuda.synchronize to
|
|
make sure all gpu operations are done before timer is stopped.
|
|
"""
|
|
for agg in get_active_aggregators():
|
|
if key in agg:
|
|
agg[key].stop(weight, prehook)
|
|
|
|
|
|
def log_custom(
|
|
new_meter_fn: Callable[[], Meter],
|
|
key: str,
|
|
*args,
|
|
priority: int = 50,
|
|
**kwargs,
|
|
):
|
|
"""Log using a custom Meter.
|
|
|
|
Any extra *args* or *kwargs* will be passed through to the Meter's
|
|
*update* method.
|
|
|
|
Args:
|
|
new_meter_fn (Callable[[], Meter]): function that returns a new
|
|
Meter instance
|
|
key (str): name of the field to log
|
|
priority (int): smaller values are logged earlier in the output
|
|
"""
|
|
for agg in get_active_aggregators():
|
|
if key not in agg:
|
|
agg.add_meter(key, new_meter_fn(), priority)
|
|
agg[key].update(*args, **kwargs)
|
|
|
|
|
|
def reset_meter(name: str, key: str) -> None:
|
|
"""Reset Meter instance aggregated under a given *name* and *key*."""
|
|
meter = get_meter(name, key)
|
|
if meter is not None:
|
|
meter.reset()
|
|
|
|
|
|
def reset_meters(name: str) -> None:
|
|
"""Reset Meter instances aggregated under a given *name*."""
|
|
meters = get_meters(name)
|
|
if meters is not None:
|
|
meters.reset()
|
|
|
|
|
|
def get_meter(name: str, key: str) -> Meter:
|
|
"""Get a single Meter instance aggregated under *name* and *key*.
|
|
|
|
Returns:
|
|
Meter or None if no metrics have been logged under *name* and *key*.
|
|
"""
|
|
if name not in _aggregators:
|
|
return None
|
|
return _aggregators[name].get(key, None)
|
|
|
|
|
|
def get_meters(name: str) -> MetersDict:
|
|
"""Get Meter instances aggregated under a given *name*.
|
|
|
|
Returns:
|
|
MetersDict or None if no metrics have been logged under *name*.
|
|
"""
|
|
return _aggregators.get(name, None)
|
|
|
|
|
|
def get_smoothed_value(name: str, key: str) -> float:
|
|
"""Get a single smoothed value.
|
|
|
|
Raises:
|
|
KeyError: if no metrics have been logged under *name* and *key*.
|
|
"""
|
|
return _aggregators[name].get_smoothed_value(key)
|
|
|
|
|
|
def get_smoothed_values(name: str) -> Dict[str, float]:
|
|
"""Get smoothed values aggregated under a given *name*.
|
|
|
|
Raises:
|
|
KeyError: if no metrics have been logged under *name*.
|
|
"""
|
|
return _aggregators[name].get_smoothed_values()
|
|
|
|
|
|
def state_dict():
|
|
return OrderedDict([(name, agg.state_dict()) for name, agg in _aggregators.items()])
|
|
|
|
|
|
def load_state_dict(state_dict):
|
|
for name, agg_state in state_dict.items():
|
|
_aggregators[name] = MetersDict()
|
|
_aggregators[name].load_state_dict(agg_state)
|
|
|
|
|
|
def xla_metrics_report():
|
|
try:
|
|
import torch_xla.debug.metrics as met
|
|
|
|
print(met.metrics_report())
|
|
except ImportError:
|
|
return
|