mirror of
https://github.com/SillyTavern/SillyTavern-Extras.git
synced 2026-04-21 15:09:01 +00:00
322 lines
8.5 KiB
Python
322 lines
8.5 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
#
|
|
# This source code is licensed under the MIT license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
import bisect
|
|
import time
|
|
from collections import OrderedDict
|
|
from typing import Dict, Optional
|
|
|
|
try:
    import torch
except ImportError:
    # torch is optional: without it there are no tensors to convert.
    torch = None

    def type_as(a, b):
        """Return ``a`` unchanged (torch is unavailable)."""
        return a

else:

    def type_as(a, b):
        """Return ``a`` converted with ``a.to(b)`` when both are tensors.

        If either argument is not a torch tensor, ``a`` is returned as is.
        """
        both_tensors = torch.is_tensor(a) and torch.is_tensor(b)
        return a.to(b) if both_tensors else a
|
|
|
|
|
|
try:
|
|
import numpy as np
|
|
except ImportError:
|
|
np = None
|
|
|
|
|
|
class Meter:
    """Base class for Meters."""

    def __init__(self):
        """Nothing to initialise in the base class."""

    def state_dict(self):
        """Return the serialisable state; the base class has none."""
        return dict()

    def load_state_dict(self, state_dict):
        """Restore state from ``state_dict``; a no-op in the base class."""
        return None

    def reset(self):
        """Clear accumulated statistics; subclasses must implement this."""
        raise NotImplementedError

    @property
    def smoothed_value(self) -> float:
        """Smoothed value used for logging."""
        raise NotImplementedError
|
|
|
|
|
|
def safe_round(number, ndigits):
    """Round ``number`` to ``ndigits`` digits when it can be rounded.

    Objects exposing ``__round__`` are passed to the builtin ``round``.
    Single-element torch tensors and zero-dimensional values with an
    ``item`` method are unwrapped with ``.item()`` and rounded
    recursively. Anything else is returned unchanged.
    """
    if hasattr(number, "__round__"):
        return round(number, ndigits)

    # Short-circuiting keeps the torch check ahead of the numpy check and
    # skips either one when the corresponding library is not installed.
    if (
        torch is not None and torch.is_tensor(number) and number.numel() == 1
    ) or (
        np is not None and np.ndim(number) == 0 and hasattr(number, "item")
    ):
        return safe_round(number.item(), ndigits)

    return number
|
|
|
|
|
|
class AverageMeter(Meter):
    """Computes and stores the average and current value"""

    def __init__(self, round: Optional[int] = None):
        """Create an empty meter.

        Args:
            round: number of digits handed to ``safe_round`` when reporting
                ``smoothed_value``; ``None`` disables rounding.
        """
        self.round = round
        self.reset()

    def reset(self):
        """Forget every update seen so far."""
        self.val = None  # most recent update
        self.sum = 0  # sum from all updates
        self.count = 0  # total n from all updates

    def update(self, val, n=1):
        """Record ``val`` as the latest value, weighted by ``n``.

        A ``None`` value is ignored entirely. When ``n`` is not positive
        only ``self.val`` is refreshed; the sum and count stay untouched.
        """
        if val is None:
            return
        self.val = val
        if n > 0:
            weighted = val * n
            self.sum = type_as(self.sum, val) + weighted
            self.count = type_as(self.count, n) + n

    def state_dict(self):
        """Return the meter state as a plain dict."""
        state = {
            "val": self.val,
            "sum": self.sum,
            "count": self.count,
            "round": self.round,
        }
        return state

    def load_state_dict(self, state_dict):
        """Restore the meter from a dict produced by ``state_dict``.

        ``"round"`` is optional and falls back to ``None`` when absent.
        """
        self.val = state_dict["val"]
        self.sum = state_dict["sum"]
        self.count = state_dict["count"]
        self.round = state_dict.get("round")

    @property
    def avg(self):
        """Mean of all updates, or the latest value when the count is zero."""
        if self.count > 0:
            return self.sum / self.count
        return self.val

    @property
    def smoothed_value(self) -> float:
        """Average, rounded with ``safe_round`` when ``round`` is set."""
        value = self.avg
        if self.round is None or value is None:
            return value
        return safe_round(value, self.round)
|
|
|
|
|
|
class SumMeter(Meter):
    """Computes and stores the sum"""

    def __init__(self, round: Optional[int] = None):
        """Create an empty meter.

        Args:
            round: number of digits handed to ``safe_round`` when reporting
                ``smoothed_value``; ``None`` disables rounding.
        """
        self.round = round
        self.reset()

    def reset(self):
        """Set the running sum back to zero."""
        self.sum = 0  # sum from all updates

    def update(self, val):
        """Add ``val`` to the running sum; ``None`` is ignored."""
        if val is None:
            return
        self.sum = type_as(self.sum, val) + val

    def state_dict(self):
        """Return the meter state as a plain dict."""
        state = {"sum": self.sum, "round": self.round}
        return state

    def load_state_dict(self, state_dict):
        """Restore the meter from a dict produced by ``state_dict``.

        ``"round"`` is optional and falls back to ``None`` when absent.
        """
        self.sum = state_dict["sum"]
        self.round = state_dict.get("round")

    @property
    def smoothed_value(self) -> float:
        """Running sum, rounded with ``safe_round`` when ``round`` is set."""
        total = self.sum
        if self.round is None or total is None:
            return total
        return safe_round(total, self.round)
|
|
|
|
|
|
class TimeMeter(Meter):
    """Computes the average occurrence of some event per second"""

    def __init__(self, init: int = 0, n: int = 0, round: Optional[int] = None):
        """Create a meter and start its clock.

        Args:
            init: elapsed time already accumulated before this meter started.
            n: number of events already counted.
            round: number of digits handed to ``safe_round`` when reporting
                ``smoothed_value``; ``None`` disables rounding.
        """
        self.round = round
        self.reset(init, n)

    def reset(self, init=0, n=0):
        """Restart the clock, seeding elapsed time and event count."""
        self.init = init
        self.start = time.perf_counter()
        self.n = n
        self.i = 0  # number of update() calls since the last reset

    def update(self, val=1):
        """Count ``val`` more events and one more update call."""
        self.n = type_as(self.n, val) + val
        self.i += 1

    def state_dict(self):
        """Return the meter state; elapsed time is stored under ``"init"``."""
        state = {
            "init": self.elapsed_time,
            "n": self.n,
            "round": self.round,
        }
        return state

    def load_state_dict(self, state_dict):
        """Restore the meter from a dict produced by ``state_dict``.

        State dicts that contain a ``"start"`` key are treated as the old
        format: only ``"init"`` is restored and the event count restarts
        at zero.
        """
        reset_args = {"init": state_dict["init"]}
        if "start" not in state_dict:
            reset_args["n"] = state_dict["n"]
        # else: backwards compatibility for old state_dicts
        self.reset(**reset_args)
        self.round = state_dict.get("round")

    @property
    def avg(self):
        """Events counted divided by elapsed time."""
        return self.n / self.elapsed_time

    @property
    def elapsed_time(self):
        """Seeded time plus the time since the last reset."""
        since_start = time.perf_counter() - self.start
        return self.init + since_start

    @property
    def smoothed_value(self) -> float:
        """Rate, rounded with ``safe_round`` when ``round`` is set."""
        rate = self.avg
        if self.round is None or rate is None:
            return rate
        return safe_round(rate, self.round)
|
|
|
|
|
|
class StopwatchMeter(Meter):
    """Computes the sum/avg duration of some event in seconds"""

    def __init__(self, round: Optional[int] = None):
        """Create a stopped stopwatch with nothing accumulated.

        Args:
            round: number of digits handed to ``safe_round`` when reporting
                ``smoothed_value``; ``None`` disables rounding.
        """
        self.round = round
        self.sum = 0
        self.n = 0
        self.start_time = None

    def start(self):
        """Record the current time as the start of a measurement."""
        self.start_time = time.perf_counter()

    def stop(self, n=1, prehook=None):
        """Add the time since ``start`` to the total and count ``n`` events.

        Does nothing when the stopwatch was never started. ``prehook``, if
        given, is called before the elapsed time is read.
        """
        if self.start_time is None:
            return
        if prehook is not None:
            prehook()
        elapsed = time.perf_counter() - self.start_time
        self.sum = self.sum + elapsed
        self.n = type_as(self.n, n) + n

    def reset(self):
        """Clear the totals and start timing again."""
        self.sum = 0  # cumulative time during which stopwatch was active
        self.n = 0  # total n across all start/stop
        self.start()

    def state_dict(self):
        """Return the meter state as a plain dict."""
        state = {
            "sum": self.sum,
            "n": self.n,
            "round": self.round,
        }
        return state

    def load_state_dict(self, state_dict):
        """Restore the totals from ``state_dict`` and leave the watch stopped.

        ``"round"`` is optional and falls back to ``None`` when absent.
        """
        self.sum = state_dict["sum"]
        self.n = state_dict["n"]
        self.start_time = None
        self.round = state_dict.get("round")

    @property
    def avg(self):
        """Total time divided by ``n``, or the total itself when ``n`` is zero."""
        if self.n > 0:
            return self.sum / self.n
        return self.sum

    @property
    def elapsed_time(self):
        """Time since the last ``start``, or ``0.0`` if never started."""
        return (
            0.0
            if self.start_time is None
            else time.perf_counter() - self.start_time
        )

    @property
    def smoothed_value(self) -> float:
        """Average duration, or the running elapsed time if nothing is summed yet."""
        if self.sum > 0:
            value = self.avg
        else:
            value = self.elapsed_time
        if self.round is None or value is None:
            return value
        return safe_round(value, self.round)
|
|
|
|
|
|
class MetersDict(OrderedDict):
    """A sorted dictionary of :class:`Meters`.

    Meters are sorted according to a priority that is given when the
    meter is first added to the dictionary.
    """

    def __init__(self, *args, **kwargs):
        """Create the dictionary, optionally seeded with initial items.

        Initial items must use the same ``(priority, meter)`` value form
        that ``__setitem__`` expects.
        """
        # OrderedDict.__init__ inserts initial items through the overridden
        # __setitem__, which needs ``priorities`` to exist already.
        self.priorities = []
        super().__init__(*args, **kwargs)

    def __setitem__(self, key, value):
        """Insert a meter under ``key``.

        Args:
            key: name of the meter; must not already be present.
            value: a ``(priority, meter)`` pair. Only the meter is stored
                as the dictionary value.
        """
        assert key not in self, "MetersDict doesn't support reassignment"
        priority, meter = value
        # The insertion index breaks ties so equal priorities keep their
        # insertion order.
        bisect.insort(self.priorities, (priority, len(self.priorities), key))
        super().__setitem__(key, meter)
        for _, _, ordered_key in self.priorities:  # reorder dict to match priorities
            self.move_to_end(ordered_key)

    def add_meter(self, key, meter, priority):
        """Add ``meter`` under ``key`` with the given ``priority``."""
        self.__setitem__(key, (priority, meter))

    def state_dict(self):
        """Return ``(priority, key, class name, meter state)`` tuples.

        Entries follow priority order. Derived meters are skipped.
        """
        return [
            (pri, key, self[key].__class__.__name__, self[key].state_dict())
            for pri, _, key in self.priorities
            # can't serialize DerivedMeter instances
            if not isinstance(self[key], MetersDict._DerivedMeter)
        ]

    def load_state_dict(self, state_dict):
        """Replace the contents with meters rebuilt from ``state_dict``.

        Args:
            state_dict: iterable of ``(priority, key, class name, meter
                state)`` tuples, as produced by ``state_dict``.
        """
        self.clear()
        self.priorities.clear()
        for pri, key, meter_cls, meter_state in state_dict:
            # Meter classes are looked up by name in this module's globals
            # and instantiated with no arguments.
            meter = globals()[meter_cls]()
            meter.load_state_dict(meter_state)
            self.add_meter(key, meter, pri)

    def get_smoothed_value(self, key: str) -> float:
        """Get a single smoothed value."""
        meter = self[key]
        if isinstance(meter, MetersDict._DerivedMeter):
            # Derived meters compute their value from the whole dictionary.
            return meter.fn(self)
        return meter.smoothed_value

    def get_smoothed_values(self) -> Dict[str, float]:
        """Get all smoothed values.

        Keys starting with an underscore are left out.
        """
        return OrderedDict(
            (key, self.get_smoothed_value(key))
            for key in self
            if not key.startswith("_")
        )

    def reset(self):
        """Reset Meter instances."""
        for meter in self.values():
            if isinstance(meter, MetersDict._DerivedMeter):
                continue
            meter.reset()

    class _DerivedMeter(Meter):
        """A Meter whose values are derived from other Meters."""

        def __init__(self, fn):
            """Store ``fn``, which is called with the owning MetersDict."""
            self.fn = fn

        def reset(self):
            """Derived meters keep no state, so there is nothing to reset."""
            pass
|