mirror of
https://github.com/SillyTavern/SillyTavern-Extras.git
synced 2026-04-30 19:31:20 +00:00
Add monkey patched fairseq package to run on python 3.11 (what is needed for our use of RVC at least)
This commit is contained in:
321
modules/voice_conversion/fairseq/logging/meters.py
Normal file
321
modules/voice_conversion/fairseq/logging/meters.py
Normal file
@@ -0,0 +1,321 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import bisect
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
from typing import Dict, Optional
|
||||
|
||||
# Prefer torch's tensor-aware conversion when torch is available; otherwise
# fall back to an identity function so the meters work without torch.
try:
    import torch

    def type_as(a, b):
        # Cast *a* to *b*'s dtype/device only when both are tensors;
        # plain Python numbers pass through unchanged.
        if torch.is_tensor(a) and torch.is_tensor(b):
            return a.to(b)
        else:
            return a

except ImportError:
    torch = None

    def type_as(a, b):
        # torch unavailable: no conversion is possible or needed.
        return a
# numpy is optional; safe_round() below checks `np is not None` before use.
try:
    import numpy as np
except ImportError:
    np = None
class Meter(object):
    """Abstract base class shared by all meter implementations.

    Subclasses must implement :meth:`reset` and the
    :attr:`smoothed_value` property; the (de)serialization hooks
    default to storing nothing.
    """

    def __init__(self):
        pass

    def state_dict(self):
        """Return a serializable snapshot of this meter (empty by default)."""
        return {}

    def load_state_dict(self, state_dict):
        """Restore this meter from a snapshot (no-op by default)."""
        pass

    def reset(self):
        """Clear all accumulated state; must be overridden."""
        raise NotImplementedError

    @property
    def smoothed_value(self) -> float:
        """Smoothed value used for logging."""
        raise NotImplementedError
def safe_round(number, ndigits):
    """Round *number* to *ndigits*, tolerating torch tensors, numpy
    scalars, and non-numeric values (which are returned unchanged)."""
    if hasattr(number, "__round__"):
        return round(number, ndigits)
    # Unwrap single-element torch tensors and zero-dim numpy values,
    # then retry on the plain Python scalar.
    if torch is not None and torch.is_tensor(number) and number.numel() == 1:
        return safe_round(number.item(), ndigits)
    if np is not None and np.ndim(number) == 0 and hasattr(number, "item"):
        return safe_round(number.item(), ndigits)
    return number
class AverageMeter(Meter):
    """Computes and stores the average and current value."""

    def __init__(self, round: Optional[int] = None):
        # Optional number of digits used when reporting smoothed_value.
        self.round = round
        self.reset()

    def reset(self):
        self.val = None  # most recent update
        self.sum = 0  # sum from all updates
        self.count = 0  # total n from all updates

    def update(self, val, n=1):
        # None values are ignored; n <= 0 refreshes `val` without
        # contributing to the running average.
        if val is None:
            return
        self.val = val
        if n > 0:
            self.sum = type_as(self.sum, val) + (val * n)
            self.count = type_as(self.count, n) + n

    def state_dict(self):
        return dict(
            val=self.val,
            sum=self.sum,
            count=self.count,
            round=self.round,
        )

    def load_state_dict(self, state_dict):
        self.val = state_dict["val"]
        self.sum = state_dict["sum"]
        self.count = state_dict["count"]
        # Older snapshots may predate the `round` field.
        self.round = state_dict.get("round", None)

    @property
    def avg(self):
        # Weighted mean; before any weighted update, fall back to the
        # latest raw value.
        if self.count > 0:
            return self.sum / self.count
        return self.val

    @property
    def smoothed_value(self) -> float:
        result = self.avg
        if self.round is None or result is None:
            return result
        return safe_round(result, self.round)
class SumMeter(Meter):
    """Accumulates a plain running sum of logged values."""

    def __init__(self, round: Optional[int] = None):
        # Optional number of digits used when reporting smoothed_value.
        self.round = round
        self.reset()

    def reset(self):
        # Running total across every update() since the last reset.
        self.sum = 0

    def update(self, val):
        # None values are ignored.
        if val is not None:
            self.sum = type_as(self.sum, val) + val

    def state_dict(self):
        return dict(sum=self.sum, round=self.round)

    def load_state_dict(self, state_dict):
        self.sum = state_dict["sum"]
        # Older snapshots may predate the `round` field.
        self.round = state_dict.get("round", None)

    @property
    def smoothed_value(self) -> float:
        result = self.sum
        if self.round is None or result is None:
            return result
        return safe_round(result, self.round)
class TimeMeter(Meter):
    """Computes the average occurrence of some event per second."""

    def __init__(
        self,
        init: int = 0,
        n: int = 0,
        round: Optional[int] = None,
    ):
        # round: optional number of digits for smoothed_value display.
        self.round = round
        self.reset(init, n)

    def reset(self, init=0, n=0):
        self.init = init  # seconds accumulated before this timing segment
        self.start = time.perf_counter()  # wall-clock anchor for this segment
        self.n = n  # total event count
        self.i = 0  # number of update() calls since reset

    def update(self, val=1):
        self.n = type_as(self.n, val) + val
        self.i += 1

    def state_dict(self):
        # Persist elapsed wall time (not the raw perf_counter anchor) so
        # the rate survives a save/load cycle across processes.
        return {
            "init": self.elapsed_time,
            "n": self.n,
            "round": self.round,
        }

    def load_state_dict(self, state_dict):
        if "start" in state_dict:
            # backwards compatibility for old state_dicts
            self.reset(init=state_dict["init"])
        else:
            self.reset(init=state_dict["init"], n=state_dict["n"])
        self.round = state_dict.get("round", None)

    @property
    def avg(self):
        # Events per second since (restored) start.
        return self.n / self.elapsed_time

    @property
    def elapsed_time(self):
        return self.init + (time.perf_counter() - self.start)

    @property
    def smoothed_value(self) -> float:
        val = self.avg
        if self.round is not None and val is not None:
            val = safe_round(val, self.round)
        return val
class StopwatchMeter(Meter):
    """Computes the sum/avg duration of some event in seconds."""

    def __init__(self, round: Optional[int] = None):
        # Optional number of digits used when reporting smoothed_value.
        self.round = round
        self.sum = 0
        self.n = 0
        self.start_time = None

    def start(self):
        self.start_time = time.perf_counter()

    def stop(self, n=1, prehook=None):
        # stop() without a matching start() is silently ignored.
        if self.start_time is None:
            return
        if prehook is not None:
            prehook()
        delta = time.perf_counter() - self.start_time
        self.sum = self.sum + delta
        self.n = type_as(self.n, n) + n

    def reset(self):
        self.sum = 0  # cumulative time during which stopwatch was active
        self.n = 0  # total n across all start/stop
        self.start()

    def state_dict(self):
        return dict(sum=self.sum, n=self.n, round=self.round)

    def load_state_dict(self, state_dict):
        self.sum = state_dict["sum"]
        self.n = state_dict["n"]
        self.start_time = None
        # Older snapshots may predate the `round` field.
        self.round = state_dict.get("round", None)

    @property
    def avg(self):
        if self.n > 0:
            return self.sum / self.n
        return self.sum

    @property
    def elapsed_time(self):
        # Time since the last start(); 0.0 when the watch never started.
        if self.start_time is None:
            return 0.0
        return time.perf_counter() - self.start_time

    @property
    def smoothed_value(self) -> float:
        # Report completed durations when any exist, otherwise the
        # still-running elapsed time.
        value = self.avg if self.sum > 0 else self.elapsed_time
        if self.round is not None and value is not None:
            value = safe_round(value, self.round)
        return value
class MetersDict(OrderedDict):
    """A sorted dictionary of :class:`Meters`.

    Meters are sorted according to a priority that is given when the
    meter is first added to the dictionary.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # (priority, insertion_index, key) tuples kept sorted; the
        # insertion index breaks priority ties so equal priorities stay FIFO.
        self.priorities = []

    def __setitem__(self, key, value):
        # *value* must be a (priority, meter) pair; use add_meter() rather
        # than plain assignment.
        assert key not in self, "MetersDict doesn't support reassignment"
        priority, value = value
        bisect.insort(self.priorities, (priority, len(self.priorities), key))
        super().__setitem__(key, value)
        for _, _, key in self.priorities:  # reorder dict to match priorities
            self.move_to_end(key)

    def add_meter(self, key, meter, priority):
        """Insert *meter* under *key*, ordered by *priority*."""
        self.__setitem__(key, (priority, meter))

    def state_dict(self):
        return [
            (pri, key, self[key].__class__.__name__, self[key].state_dict())
            for pri, _, key in self.priorities
            # can't serialize DerivedMeter instances
            if not isinstance(self[key], MetersDict._DerivedMeter)
        ]

    def load_state_dict(self, state_dict):
        self.clear()
        self.priorities.clear()
        for pri, key, meter_cls, meter_state in state_dict:
            # Meter classes are resolved by name from this module's globals.
            meter = globals()[meter_cls]()
            meter.load_state_dict(meter_state)
            self.add_meter(key, meter, pri)

    def get_smoothed_value(self, key: str) -> float:
        """Get a single smoothed value."""
        meter = self[key]
        if isinstance(meter, MetersDict._DerivedMeter):
            # Derived meters compute their value from the whole dict.
            return meter.fn(self)
        else:
            return meter.smoothed_value

    def get_smoothed_values(self) -> Dict[str, float]:
        """Get all smoothed values (keys starting with "_" are hidden)."""
        return OrderedDict(
            [
                (key, self.get_smoothed_value(key))
                for key in self.keys()
                if not key.startswith("_")
            ]
        )

    def reset(self):
        """Reset Meter instances."""
        for meter in self.values():
            if isinstance(meter, MetersDict._DerivedMeter):
                # Derived meters hold no state of their own.
                continue
            meter.reset()

    class _DerivedMeter(Meter):
        """A Meter whose values are derived from other Meters."""

        def __init__(self, fn):
            # *fn* maps the enclosing MetersDict to the derived value.
            self.fn = fn

        def reset(self):
            pass
316
modules/voice_conversion/fairseq/logging/metrics.py
Normal file
316
modules/voice_conversion/fairseq/logging/metrics.py
Normal file
@@ -0,0 +1,316 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
"""
|
||||
A standalone module for aggregating metrics.
|
||||
|
||||
Metrics can be logged from anywhere using the `log_*` functions defined
|
||||
in this module. The logged values will be aggregated dynamically based
|
||||
on the aggregation context in which the logging occurs. See the
|
||||
:func:`aggregate` context manager for more details.
|
||||
"""
|
||||
|
||||
import contextlib
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from typing import Callable, List, Optional
|
||||
|
||||
from .meters import *
|
||||
|
||||
|
||||
# Aggregation contexts are considered "active" when inside the scope
# created by the :func:`aggregate` context manager.
_aggregators = OrderedDict()  # every named aggregator ever created
_active_aggregators = OrderedDict()  # aggregators currently receiving logs
_active_aggregators_cnt = defaultdict(lambda: 0)  # nesting depth per name
def reset() -> None:
    """Reset all metrics aggregators."""
    _aggregators.clear()
    _active_aggregators.clear()
    _active_aggregators_cnt.clear()

    # The "default" aggregator observes all logged values.
    _aggregators["default"] = MetersDict()
    _active_aggregators["default"] = _aggregators["default"]
    _active_aggregators_cnt["default"] = 1


# Initialize the module-level aggregation state at import time.
reset()
@contextlib.contextmanager
def aggregate(name: Optional[str] = None, new_root: bool = False):
    """Context manager to aggregate metrics under a given name.

    Aggregations can be nested. If *new_root* is ``False``, then logged
    metrics will be recorded along the entire stack of nested
    aggregators, including a global "default" aggregator. If *new_root*
    is ``True``, then this aggregator will be the root of a new
    aggregation stack, thus bypassing any parent aggregators.

    Note that aggregation contexts are uniquely identified by their
    *name* (e.g., train, valid). Creating a context with an existing
    name will reuse the corresponding :class:`MetersDict` instance.
    If no name is given, then a temporary aggregator will be created.

    Usage::

        with metrics.aggregate("train"):
            for step, batch in enumerate(epoch):
                with metrics.aggregate("train_inner") as agg:
                    metrics.log_scalar("loss", get_loss(batch))
                    if step % log_interval == 0:
                        print(agg.get_smoothed_value("loss"))
                        agg.reset()
        print(metrics.get_smoothed_values("train")["loss"])

    Args:
        name (str): name of the aggregation. Defaults to a
            random/temporary name if not given explicitly.
        new_root (bool): make this aggregation the root of a new
            aggregation stack.
    """
    if name is None:
        # generate a temporary name
        name = str(uuid.uuid4())
        assert name not in _aggregators
        agg = MetersDict()
    else:
        assert name != "default"
        agg = _aggregators.setdefault(name, MetersDict())

    if new_root:
        backup_aggregators = _active_aggregators.copy()
        _active_aggregators.clear()
        backup_aggregators_cnt = _active_aggregators_cnt.copy()
        _active_aggregators_cnt.clear()

    _active_aggregators[name] = agg
    _active_aggregators_cnt[name] += 1

    try:
        yield agg
    finally:
        # Always unwind the aggregation stack, even when the with-body
        # raises; otherwise this context (and a bypassed parent stack,
        # for new_root=True) would stay "active" forever.
        _active_aggregators_cnt[name] -= 1
        if _active_aggregators_cnt[name] == 0 and name in _active_aggregators:
            del _active_aggregators[name]

        if new_root:
            _active_aggregators.clear()
            _active_aggregators.update(backup_aggregators)
            _active_aggregators_cnt.clear()
            _active_aggregators_cnt.update(backup_aggregators_cnt)
def get_active_aggregators() -> List[MetersDict]:
    """Return the aggregators currently in scope, in activation order."""
    return [agg for agg in _active_aggregators.values()]
def log_scalar(
    key: str,
    value: float,
    weight: float = 1,
    priority: int = 10,
    round: Optional[int] = None,
):
    """Log a scalar value.

    Args:
        key (str): name of the field to log
        value (float): value to log
        weight (float): weight that this value contributes to the average.
            A weight of 0 will always log the latest value.
        priority (int): smaller values are logged earlier in the output
        round (Optional[int]): number of digits to round to when displaying
    """
    for agg in get_active_aggregators():
        meter = agg.get(key)
        if meter is None:
            # Lazily create the meter the first time this key is seen.
            meter = AverageMeter(round=round)
            agg.add_meter(key, meter, priority)
        meter.update(value, weight)
def log_scalar_sum(
    key: str,
    value: float,
    priority: int = 10,
    round: Optional[int] = None,
):
    """Log a scalar value that is summed for reporting.

    Args:
        key (str): name of the field to log
        value (float): value to log
        priority (int): smaller values are logged earlier in the output
        round (Optional[int]): number of digits to round to when displaying
    """
    for agg in get_active_aggregators():
        meter = agg.get(key)
        if meter is None:
            # Lazily create the meter the first time this key is seen.
            meter = SumMeter(round=round)
            agg.add_meter(key, meter, priority)
        meter.update(value)
def log_derived(key: str, fn: Callable[[MetersDict], float], priority: int = 20):
    """Log a scalar value derived from other meters.

    Args:
        key (str): name of the field to log
        fn (Callable[[MetersDict], float]): function that takes a single
            argument *meters* and returns the derived value
        priority (int): smaller values are logged earlier in the output
    """
    for agg in get_active_aggregators():
        if key in agg:
            # Derived meters carry no state, so nothing to update.
            continue
        agg.add_meter(key, MetersDict._DerivedMeter(fn), priority)
def log_speed(
    key: str,
    value: float,
    priority: int = 30,
    round: Optional[int] = None,
):
    """Log the rate of some quantity per second.

    Args:
        key (str): name of the field to log
        value (float): value to log
        priority (int): smaller values are logged earlier in the output
        round (Optional[int]): number of digits to round to when displaying
    """
    for agg in get_active_aggregators():
        meter = agg.get(key)
        if meter is None:
            meter = TimeMeter(round=round)
            agg.add_meter(key, meter, priority)
            meter.reset()  # reset meter on the first call
        else:
            meter.update(value)
def log_start_time(key: str, priority: int = 40, round: Optional[int] = None):
    """Log the duration of some event in seconds.

    The duration will be computed once :func:`log_stop_time` is called.

    Args:
        key (str): name of the field to log
        priority (int): smaller values are logged earlier in the output
        round (Optional[int]): number of digits to round to when displaying
    """
    for agg in get_active_aggregators():
        meter = agg.get(key)
        if meter is None:
            meter = StopwatchMeter(round=round)
            agg.add_meter(key, meter, priority)
        meter.start()
def log_stop_time(key: str, weight: float = 0.0, prehook=None):
    """Log the duration of some event in seconds.

    The duration will be computed since :func:`log_start_time` was called.
    Set weight > 0 to report the average time instead of the sum.

    Args:
        key (str): name of the field to log
        weight (float): weight that this time contributes to the average
        prehook (function, no arguments): will be called before the timer
            is stopped. For example, use prehook=torch.cuda.synchronize to
            make sure all gpu operations are done before timer is stopped.
    """
    for agg in get_active_aggregators():
        meter = agg.get(key)
        if meter is not None:
            # Only keys that were started via log_start_time are stopped.
            meter.stop(weight, prehook)
def log_custom(
    new_meter_fn: Callable[[], Meter],
    key: str,
    *args,
    priority: int = 50,
    **kwargs,
):
    """Log using a custom Meter.

    Any extra *args* or *kwargs* will be passed through to the Meter's
    *update* method.

    Args:
        new_meter_fn (Callable[[], Meter]): function that returns a new
            Meter instance
        key (str): name of the field to log
        priority (int): smaller values are logged earlier in the output
    """
    for agg in get_active_aggregators():
        meter = agg.get(key)
        if meter is None:
            # Instantiate the custom meter on first use of this key.
            meter = new_meter_fn()
            agg.add_meter(key, meter, priority)
        meter.update(*args, **kwargs)
def reset_meter(name: str, key: str) -> None:
    """Reset the Meter aggregated under *name* and *key* (no-op if absent)."""
    target = get_meter(name, key)
    if target is not None:
        target.reset()
def reset_meters(name: str) -> None:
    """Reset every Meter aggregated under *name* (no-op if absent)."""
    target = get_meters(name)
    if target is not None:
        target.reset()
def get_meter(name: str, key: str) -> Meter:
    """Get a single Meter instance aggregated under *name* and *key*.

    Returns:
        Meter or None if no metrics have been logged under *name* and *key*.
    """
    agg = _aggregators.get(name)
    if agg is None:
        return None
    return agg.get(key, None)
def get_meters(name: str) -> MetersDict:
    """Get Meter instances aggregated under a given *name*.

    Returns:
        MetersDict or None if no metrics have been logged under *name*.
    """
    return _aggregators.get(name)
def get_smoothed_value(name: str, key: str) -> float:
    """Get a single smoothed value.

    Raises:
        KeyError: if no metrics have been logged under *name* and *key*.
    """
    agg = _aggregators[name]
    return agg.get_smoothed_value(key)
def get_smoothed_values(name: str) -> Dict[str, float]:
    """Get smoothed values aggregated under a given *name*.

    Raises:
        KeyError: if no metrics have been logged under *name*.
    """
    agg = _aggregators[name]
    return agg.get_smoothed_values()
def state_dict():
    """Snapshot every aggregator's state, keyed by aggregator name."""
    out = OrderedDict()
    for name, agg in _aggregators.items():
        out[name] = agg.state_dict()
    return out
def load_state_dict(state_dict):
    """Rebuild aggregators from a :func:`state_dict` snapshot."""
    for name, agg_state in state_dict.items():
        restored = MetersDict()
        restored.load_state_dict(agg_state)
        _aggregators[name] = restored
def xla_metrics_report():
    """Print torch_xla's debug metrics report; silent no-op when
    torch_xla is not installed."""
    try:
        import torch_xla.debug.metrics as met

        print(met.metrics_report())
    except ImportError:
        return
582
modules/voice_conversion/fairseq/logging/progress_bar.py
Normal file
582
modules/voice_conversion/fairseq/logging/progress_bar.py
Normal file
@@ -0,0 +1,582 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""
|
||||
Wrapper around various loggers and progress bars (e.g., tqdm).
|
||||
"""
|
||||
|
||||
import atexit
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from collections import OrderedDict
|
||||
from contextlib import contextmanager
|
||||
from numbers import Number
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from .meters import AverageMeter, StopwatchMeter, TimeMeter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def progress_bar(
    iterator,
    log_format: Optional[str] = None,
    log_interval: int = 100,
    log_file: Optional[str] = None,
    epoch: Optional[int] = None,
    prefix: Optional[str] = None,
    aim_repo: Optional[str] = None,
    aim_run_hash: Optional[str] = None,
    aim_param_checkpoint_dir: Optional[str] = None,
    tensorboard_logdir: Optional[str] = None,
    default_log_format: str = "tqdm",
    wandb_project: Optional[str] = None,
    wandb_run_name: Optional[str] = None,
    azureml_logging: Optional[bool] = False,
):
    """Build a progress bar around *iterator*, optionally wrapped with
    experiment loggers (Aim, TensorBoard, Weights & Biases, AzureML).

    Args:
        iterator: the (epoch) iterable to wrap.
        log_format: one of "json", "none", "simple", "tqdm"; defaults to
            *default_log_format* when None.
        log_interval: steps between intermediate log lines.
        log_file: if given, also mirror logger output to this file.
        epoch / prefix: used to build the bar's display prefix.
        aim_*, tensorboard_logdir, wandb_*, azureml_logging: enable the
            corresponding backend wrapper when set/truthy.

    Returns:
        a progress bar object (possibly wrapped by one or more backends).

    Raises:
        ValueError: if *log_format* is not a known format.
    """
    if log_format is None:
        log_format = default_log_format
    if log_file is not None:
        handler = logging.FileHandler(filename=log_file)
        logger.addHandler(handler)

    # tqdm needs an interactive terminal; otherwise degrade to plain lines.
    if log_format == "tqdm" and not sys.stderr.isatty():
        log_format = "simple"

    if log_format == "json":
        bar = JsonProgressBar(iterator, epoch, prefix, log_interval)
    elif log_format == "none":
        bar = NoopProgressBar(iterator, epoch, prefix)
    elif log_format == "simple":
        bar = SimpleProgressBar(iterator, epoch, prefix, log_interval)
    elif log_format == "tqdm":
        bar = TqdmProgressBar(iterator, epoch, prefix)
    else:
        raise ValueError("Unknown log format: {}".format(log_format))

    if aim_repo:
        bar = AimProgressBarWrapper(
            bar,
            aim_repo=aim_repo,
            aim_run_hash=aim_run_hash,
            aim_param_checkpoint_dir=aim_param_checkpoint_dir,
        )

    if tensorboard_logdir:
        try:
            # [FB only] custom wrapper for TensorBoard
            import palaas  # noqa

            from .fb_tbmf_wrapper import FbTbmfWrapper

            bar = FbTbmfWrapper(bar, log_interval)
        except ImportError:
            bar = TensorboardProgressBarWrapper(bar, tensorboard_logdir)

    if wandb_project:
        bar = WandBProgressBarWrapper(bar, wandb_project, run_name=wandb_run_name)

    if azureml_logging:
        bar = AzureMLProgressBarWrapper(bar)

    return bar
def build_progress_bar(
    args,
    iterator,
    epoch: Optional[int] = None,
    prefix: Optional[str] = None,
    default: str = "tqdm",
    no_progress_bar: str = "none",
):
    """Legacy wrapper that takes an argparse.Namespace.

    Reads ``log_format``/``log_interval`` (and, when present,
    ``no_progress_bar``, ``distributed_rank``, ``tensorboard_logdir``)
    off *args* and delegates to :func:`progress_bar`.
    """
    if getattr(args, "no_progress_bar", False):
        default = no_progress_bar
    # Only rank 0 writes TensorBoard logs in distributed runs.
    if getattr(args, "distributed_rank", 0) == 0:
        tensorboard_logdir = getattr(args, "tensorboard_logdir", None)
    else:
        tensorboard_logdir = None
    return progress_bar(
        iterator,
        log_format=args.log_format,
        log_interval=args.log_interval,
        epoch=epoch,
        prefix=prefix,
        tensorboard_logdir=tensorboard_logdir,
        default_log_format=default,
    )
def format_stat(stat):
    """Render *stat* in a log-friendly form: formatted string for numbers
    and meters, list for tensors, unchanged otherwise."""
    if isinstance(stat, Number):
        return "{:g}".format(stat)
    if isinstance(stat, AverageMeter):
        return "{:.3f}".format(stat.avg)
    if isinstance(stat, TimeMeter):
        return "{:g}".format(round(stat.avg))
    if isinstance(stat, StopwatchMeter):
        return "{:g}".format(round(stat.sum))
    if torch.is_tensor(stat):
        return stat.tolist()
    return stat
class BaseProgressBar(object):
    """Abstract class for progress bars."""

    def __init__(self, iterable, epoch=None, prefix=None):
        self.iterable = iterable
        # Some iterators carry an offset `n` (elements already consumed).
        self.n = getattr(iterable, "n", 0)
        self.epoch = epoch
        # Display prefix: "epoch NNN | prefix" from whichever parts exist.
        parts = []
        if epoch is not None:
            parts.append("epoch {:03d}".format(epoch))
        if prefix is not None:
            parts.append(prefix)
        self.prefix = " | ".join(parts)

    def __len__(self):
        return len(self.iterable)

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        return False

    def __iter__(self):
        raise NotImplementedError

    def log(self, stats, tag=None, step=None):
        """Log intermediate stats according to log_interval."""
        raise NotImplementedError

    def print(self, stats, tag=None, step=None):
        """Print end-of-epoch stats."""
        raise NotImplementedError

    def update_config(self, config):
        """Log latest configuration."""
        pass

    def _str_commas(self, stats):
        return ", ".join(
            "{}={}".format(key, stats[key].strip()) for key in stats.keys()
        )

    def _str_pipes(self, stats):
        return " | ".join(
            "{} {}".format(key, stats[key].strip()) for key in stats.keys()
        )

    def _format_stats(self, stats):
        # Preprocess stats according to datatype, preserving key order.
        return OrderedDict(
            (key, str(format_stat(value))) for key, value in stats.items()
        )
@contextmanager
def rename_logger(logger, new_name):
    """Temporarily rename *logger* to *new_name* for messages emitted
    inside the block.

    Args:
        logger: logger whose ``name`` is swapped in place.
        new_name: replacement name, or None to leave the logger untouched.
    """
    old_name = logger.name
    if new_name is not None:
        logger.name = new_name
    try:
        yield logger
    finally:
        # Restore the name even when the with-body raises, so the shared
        # logger never leaks a stale tag into later log records.
        logger.name = old_name
class JsonProgressBar(BaseProgressBar):
    """Log output in JSON format."""

    def __init__(self, iterable, epoch=None, prefix=None, log_interval=1000):
        super().__init__(iterable, epoch, prefix)
        self.log_interval = log_interval  # emit one JSON line every N steps
        self.i = None  # index of the most recently yielded element
        self.size = None  # len(iterable), filled in by __iter__

    def __iter__(self):
        self.size = len(self.iterable)
        for i, obj in enumerate(self.iterable, start=self.n):
            self.i = i
            yield obj

    def log(self, stats, tag=None, step=None):
        """Log intermediate stats according to log_interval."""
        step = step or self.i or 0
        if step > 0 and self.log_interval is not None and step % self.log_interval == 0:
            # Fractional epoch progress, e.g. 1.25 = one quarter through epoch 2.
            update = (
                self.epoch - 1 + (self.i + 1) / float(self.size)
                if self.epoch is not None
                else None
            )
            stats = self._format_stats(stats, epoch=self.epoch, update=update)
            with rename_logger(logger, tag):
                logger.info(json.dumps(stats))

    def print(self, stats, tag=None, step=None):
        """Print end-of-epoch stats."""
        self.stats = stats
        if tag is not None:
            # Namespace the keys with the tag, e.g. "valid_loss".
            self.stats = OrderedDict(
                [(tag + "_" + k, v) for k, v in self.stats.items()]
            )
        stats = self._format_stats(self.stats, epoch=self.epoch)
        with rename_logger(logger, tag):
            logger.info(json.dumps(stats))

    def _format_stats(self, stats, epoch=None, update=None):
        postfix = OrderedDict()
        if epoch is not None:
            postfix["epoch"] = epoch
        if update is not None:
            postfix["update"] = round(update, 3)
        # Preprocess stats according to datatype
        for key in stats.keys():
            postfix[key] = format_stat(stats[key])
        return postfix
class NoopProgressBar(BaseProgressBar):
    """Progress bar that discards all log output."""

    def __init__(self, iterable, epoch=None, prefix=None):
        super().__init__(iterable, epoch, prefix)

    def __iter__(self):
        yield from self.iterable

    def log(self, stats, tag=None, step=None):
        """Intermediate stats are intentionally ignored."""
        pass

    def print(self, stats, tag=None, step=None):
        """End-of-epoch stats are intentionally ignored."""
        pass
class SimpleProgressBar(BaseProgressBar):
    """A minimal logger for non-TTY environments."""

    def __init__(self, iterable, epoch=None, prefix=None, log_interval=1000):
        super().__init__(iterable, epoch, prefix)
        self.log_interval = log_interval  # steps between log lines
        self.i = None  # index of the most recently yielded element
        self.size = None  # len(iterable), filled in by __iter__

    def __iter__(self):
        self.size = len(self.iterable)
        for idx, obj in enumerate(self.iterable, start=self.n):
            self.i = idx
            yield obj

    def log(self, stats, tag=None, step=None):
        """Log intermediate stats according to log_interval."""
        step = step or self.i or 0
        # Skip step 0 and anything off the log_interval grid.
        if step <= 0 or self.log_interval is None or step % self.log_interval != 0:
            return
        formatted = self._str_commas(self._format_stats(stats))
        with rename_logger(logger, tag):
            logger.info(
                "{}: {:5d} / {:d} {}".format(
                    self.prefix, self.i + 1, self.size, formatted
                )
            )

    def print(self, stats, tag=None, step=None):
        """Print end-of-epoch stats."""
        formatted = self._str_pipes(self._format_stats(stats))
        with rename_logger(logger, tag):
            logger.info("{} | {}".format(self.prefix, formatted))
class TqdmProgressBar(BaseProgressBar):
    """Log to tqdm."""

    def __init__(self, iterable, epoch=None, prefix=None):
        super().__init__(iterable, epoch, prefix)
        # Imported lazily so tqdm is only required when this format is used.
        from tqdm import tqdm

        self.tqdm = tqdm(
            iterable,
            self.prefix,  # positional: tqdm's *desc* argument
            leave=False,
            # Hide the bar entirely when logging is quieter than INFO.
            disable=(logger.getEffectiveLevel() > logging.INFO),
        )

    def __iter__(self):
        return iter(self.tqdm)

    def log(self, stats, tag=None, step=None):
        """Log intermediate stats according to log_interval."""
        self.tqdm.set_postfix(self._format_stats(stats), refresh=False)

    def print(self, stats, tag=None, step=None):
        """Print end-of-epoch stats."""
        postfix = self._str_pipes(self._format_stats(stats))
        with rename_logger(logger, tag):
            logger.info("{} | {}".format(self.prefix, postfix))
# Aim experiment tracking is optional; when the `aim` package is missing,
# the sentinels below make AimProgressBarWrapper degrade to a warning.
try:
    import functools

    from aim import Repo as AimRepo

    @functools.lru_cache()
    def get_aim_run(repo, run_hash):
        # Cached so repeated calls with the same repo/hash reuse one Run.
        from aim import Run

        return Run(run_hash=run_hash, repo=repo)

except ImportError:
    get_aim_run = None
    AimRepo = None
class AimProgressBarWrapper(BaseProgressBar):
    """Log to Aim."""

    def __init__(self, wrapped_bar, aim_repo, aim_run_hash, aim_param_checkpoint_dir):
        self.wrapped_bar = wrapped_bar

        if get_aim_run is None:
            # `aim` is not installed: keep wrapping but track nothing.
            self.run = None
            logger.warning("Aim not found, please install with: pip install aim")
        else:
            logger.info(f"Storing logs at Aim repo: {aim_repo}")

            if not aim_run_hash:
                # Find run based on save_dir parameter
                query = f"run.checkpoint.save_dir == '{aim_param_checkpoint_dir}'"
                try:
                    runs_generator = AimRepo(aim_repo).query_runs(query)
                    run = next(runs_generator.iter_runs())
                    aim_run_hash = run.run.hash
                except Exception:
                    # Best-effort lookup; fall through and start a new run.
                    pass

            if aim_run_hash:
                logger.info(f"Appending to run: {aim_run_hash}")

            self.run = get_aim_run(aim_repo, aim_run_hash)

    def __iter__(self):
        return iter(self.wrapped_bar)

    def log(self, stats, tag=None, step=None):
        """Log intermediate stats to Aim."""
        self._log_to_aim(stats, tag, step)
        self.wrapped_bar.log(stats, tag=tag, step=step)

    def print(self, stats, tag=None, step=None):
        """Print end-of-epoch stats."""
        self._log_to_aim(stats, tag, step)
        self.wrapped_bar.print(stats, tag=tag, step=step)

    def update_config(self, config):
        """Log latest configuration."""
        if self.run is not None:
            for key in config:
                self.run.set(key, config[key], strict=False)
        self.wrapped_bar.update_config(config)

    def _log_to_aim(self, stats, tag=None, step=None):
        if self.run is None:
            return

        if step is None:
            step = stats["num_updates"]

        # Derive the subset context from the tag.
        # NOTE(review): assumes *tag* is a string here; a None tag would
        # raise on the membership test — confirm callers always pass one.
        if "train" in tag:
            context = {"tag": tag, "subset": "train"}
        elif "val" in tag:
            context = {"tag": tag, "subset": "val"}
        else:
            context = {"tag": tag}

        for key in stats.keys() - {"num_updates"}:
            self.run.track(stats[key], name=key, step=step, context=context)
# Optional tensorboard support: prefer torch's bundled SummaryWriter, fall
# back to tensorboardX, else disable tensorboard logging entirely.
try:
    # Cache of SummaryWriter instances keyed by tag; shared by
    # TensorboardProgressBarWrapper._writer() and closed via atexit below.
    # Assigned before the import so it exists on either import path.
    _tensorboard_writers = {}
    from torch.utils.tensorboard import SummaryWriter
except ImportError:
    try:
        from tensorboardX import SummaryWriter
    except ImportError:
        SummaryWriter = None
|
||||
|
||||
|
||||
def _close_writers():
    """Close every cached tensorboard SummaryWriter."""
    for writer in _tensorboard_writers.values():
        writer.close()


# Flush/close writers when the interpreter exits.
atexit.register(_close_writers)
|
||||
|
||||
|
||||
class TensorboardProgressBarWrapper(BaseProgressBar):
    """Log to tensorboard."""

    def __init__(self, wrapped_bar, tensorboard_logdir):
        self.wrapped_bar = wrapped_bar
        self.tensorboard_logdir = tensorboard_logdir

        if SummaryWriter is None:
            logger.warning(
                "tensorboard not found, please install with: pip install tensorboard"
            )

    def _writer(self, key):
        # One SummaryWriter per tag, cached in the module-level registry so
        # writers are shared across instances and closed at interpreter exit.
        if SummaryWriter is None:
            return None
        cache = _tensorboard_writers
        writer = cache.get(key)
        if writer is None:
            writer = SummaryWriter(os.path.join(self.tensorboard_logdir, key))
            writer.add_text("sys.argv", " ".join(sys.argv))
            cache[key] = writer
        return writer

    def __iter__(self):
        return iter(self.wrapped_bar)

    def log(self, stats, tag=None, step=None):
        """Log intermediate stats to tensorboard."""
        self._log_to_tensorboard(stats, tag, step)
        self.wrapped_bar.log(stats, tag=tag, step=step)

    def print(self, stats, tag=None, step=None):
        """Print end-of-epoch stats."""
        self._log_to_tensorboard(stats, tag, step)
        self.wrapped_bar.print(stats, tag=tag, step=step)

    def update_config(self, config):
        """Log latest configuration."""
        # TODO add hparams to Tensorboard
        self.wrapped_bar.update_config(config)

    def _log_to_tensorboard(self, stats, tag=None, step=None):
        # Write every scalar stat (except the step counter) to tensorboard.
        writer = self._writer(tag or "")
        if writer is None:
            return
        if step is None:
            step = stats["num_updates"]
        for key in stats.keys() - {"num_updates"}:
            value = stats[key]
            if isinstance(value, AverageMeter):
                writer.add_scalar(key, value.val, step)
            elif isinstance(value, Number):
                writer.add_scalar(key, value, step)
            elif torch.is_tensor(value) and value.numel() == 1:
                writer.add_scalar(key, value.item(), step)
        writer.flush()
|
||||
|
||||
|
||||
# Optional Weights & Biases support; ``wandb`` is None when not installed
# and WandBProgressBarWrapper degrades to a warning-only no-op.
try:
    import wandb
except ImportError:
    wandb = None
|
||||
|
||||
|
||||
class WandBProgressBarWrapper(BaseProgressBar):
    """Log to Weights & Biases."""

    def __init__(self, wrapped_bar, wandb_project, run_name=None):
        self.wrapped_bar = wrapped_bar
        if wandb is not None:
            # reinit=False keeps repeated wandb.init() calls within one
            # process pointed at the same underlying run.
            wandb.init(project=wandb_project, reinit=False, name=run_name)
        else:
            logger.warning("wandb not found, pip install wandb")

    def __iter__(self):
        return iter(self.wrapped_bar)

    def log(self, stats, tag=None, step=None):
        """Log intermediate stats to Weights & Biases."""
        self._log_to_wandb(stats, tag, step)
        self.wrapped_bar.log(stats, tag=tag, step=step)

    def print(self, stats, tag=None, step=None):
        """Print end-of-epoch stats."""
        self._log_to_wandb(stats, tag, step)
        self.wrapped_bar.print(stats, tag=tag, step=step)

    def update_config(self, config):
        """Log latest configuration."""
        if wandb is not None:
            wandb.config.update(config)
        self.wrapped_bar.update_config(config)

    def _log_to_wandb(self, stats, tag=None, step=None):
        # Push every scalar stat (except the step counter), tag-prefixed.
        if wandb is None:
            return
        if step is None:
            step = stats["num_updates"]

        prefix = "" if tag is None else tag + "/"

        for key in stats.keys() - {"num_updates"}:
            value = stats[key]
            if isinstance(value, AverageMeter):
                wandb.log({prefix + key: value.val}, step=step)
            elif isinstance(value, Number):
                wandb.log({prefix + key: value}, step=step)
|
||||
|
||||
|
||||
# Optional Azure ML support; ``Run`` is None when azureml-core is not
# installed and AzureMLProgressBarWrapper degrades to a warning-only no-op.
try:
    from azureml.core import Run
except ImportError:
    Run = None
|
||||
|
||||
|
||||
class AzureMLProgressBarWrapper(BaseProgressBar):
    """Log to Azure ML"""

    def __init__(self, wrapped_bar):
        self.wrapped_bar = wrapped_bar
        if Run is not None:
            # Bind to the ambient Azure ML run context for this process.
            self.run = Run.get_context()
        else:
            logger.warning("azureml.core not found, pip install azureml-core")

    def __exit__(self, *exc):
        # Mark the Azure ML run complete on context-manager exit; never
        # suppress exceptions.
        if Run is not None:
            self.run.complete()
        return False

    def __iter__(self):
        return iter(self.wrapped_bar)

    def log(self, stats, tag=None, step=None):
        """Log intermediate stats to AzureML"""
        self._log_to_azureml(stats, tag, step)
        self.wrapped_bar.log(stats, tag=tag, step=step)

    def print(self, stats, tag=None, step=None):
        """Print end-of-epoch stats"""
        self._log_to_azureml(stats, tag, step)
        self.wrapped_bar.print(stats, tag=tag, step=step)

    def update_config(self, config):
        """Log latest configuration."""
        self.wrapped_bar.update_config(config)

    def _log_to_azureml(self, stats, tag=None, step=None):
        # Emit every scalar stat (except the step counter) as a metric row.
        if Run is None:
            return
        if step is None:
            step = stats["num_updates"]

        prefix = "" if tag is None else tag + "/"

        for key in stats.keys() - {"num_updates"}:
            value = stats[key]
            metric = prefix + key
            if isinstance(value, AverageMeter):
                self.run.log_row(name=metric, **{"step": step, key: value.val})
            elif isinstance(value, Number):
                self.run.log_row(name=metric, **{"step": step, key: value})
|
||||
Reference in New Issue
Block a user