mirror of
https://github.com/SillyTavern/SillyTavern-Extras.git
synced 2026-03-12 06:50:09 +00:00
71 lines
2.6 KiB
Python
71 lines
2.6 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


class DynamicLossScaler:
    """Loss scaler that adapts ``loss_scale`` based on detected overflows.

    The scale is multiplied by ``scale_factor`` whenever the number of
    iterations since the last overflow is a multiple of ``scale_window``
    (see :meth:`update`). It is divided by ``scale_factor`` when
    :meth:`check_overflow` sees an inf/NaN gradient norm and the fraction
    of overflowing iterations since the last rescale is at least
    ``tolerance``.

    Args:
        init_scale: initial value of ``loss_scale``.
        scale_factor: multiplier used to grow the scale and divisor used to
            shrink it.
        scale_window: period, in iterations since the last overflow, at
            which the scale is grown.
        tolerance: minimum fraction of overflowing iterations (since the
            last rescale) required before the scale is decreased; with the
            default ``0.0`` every overflow decreases it.
        threshold: optional lower bound the scale is clamped to when it is
            decreased; ``None`` disables clamping.
        min_loss_scale: if, after handling an overflow, the scale is at or
            below this value, :meth:`check_overflow` raises
            ``FloatingPointError``.
    """

    def __init__(
        self,
        init_scale=2.0**15,
        scale_factor=2.0,
        scale_window=2000,
        tolerance=0.0,
        threshold=None,
        min_loss_scale=1e-4,
    ):
        self.loss_scale = init_scale
        self.scale_factor = scale_factor
        self.scale_window = scale_window
        self.tolerance = tolerance
        self.threshold = threshold
        # Iteration bookkeeping; -1 means "never happened yet".
        self._iter = 0
        self._last_overflow_iter = -1
        self._last_rescale_iter = -1
        self._overflows_since_rescale = 0
        self.min_loss_scale = min_loss_scale

    def scale(self, outputs):
        """Return ``outputs`` multiplied by the current loss scale."""
        return self.loss_scale * outputs

    def update(self):
        """Advance the iteration counter, growing the scale periodically.

        The scale is multiplied by ``scale_factor`` when the number of
        iterations since the last overflow is a multiple of ``scale_window``.
        """
        # NOTE(review): check_overflow also advances ``_iter`` on overflow;
        # presumably callers only call update() on overflow-free steps.
        if (self._iter - self._last_overflow_iter) % self.scale_window == 0:
            self.loss_scale *= self.scale_factor
            self._last_rescale_iter = self._iter
        self._iter += 1

    def _decrease_loss_scale(self):
        """Divide the scale by ``scale_factor``, clamping at ``threshold`` if set."""
        self.loss_scale /= self.scale_factor
        if self.threshold is not None:
            self.loss_scale = max(self.loss_scale, self.threshold)

    def check_overflow(self, grad_norm):
        """Detect an inf/NaN ``grad_norm`` and shrink the scale if needed.

        When ``grad_norm`` is neither ``+inf`` nor NaN, returns ``None``
        without changing any state.

        Raises:
            OverflowError: on overflow, after the scale has possibly been
                decreased; the message reports the resulting scale.
            FloatingPointError: when, after handling the overflow, the scale
                is at or below ``min_loss_scale``; the scale is restored to
                its value from before this call.
        """
        # ``x != x`` is true only for NaN. grad_norm is presumably a
        # non-negative norm, so -inf is not checked.
        if grad_norm == float("inf") or grad_norm != grad_norm:
            # overflow has occurred
            prev_scale = self.loss_scale
            # Clamp to at least 1: after a FloatingPointError that followed a
            # decrease, ``_last_rescale_iter`` equals ``_iter`` (which was not
            # advanced), so a repeated call would otherwise divide by zero.
            iter_since_rescale = max(self._iter - self._last_rescale_iter, 1)

            self._last_overflow_iter = self._iter
            self._overflows_since_rescale += 1
            # Fraction of iterations since the last rescale that overflowed.
            pct_overflow = self._overflows_since_rescale / iter_since_rescale
            if pct_overflow >= self.tolerance:
                self._decrease_loss_scale()
                self._last_rescale_iter = self._iter
                self._overflows_since_rescale = 0

            if self.loss_scale <= self.min_loss_scale:
                # Use FloatingPointError as an uncommon error that parent
                # functions can safely catch to stop training.
                self.loss_scale = prev_scale
                raise FloatingPointError(
                    (
                        "Minimum loss scale reached ({}). Your loss is probably exploding. "
                        "Try lowering the learning rate, using gradient clipping or "
                        "increasing the batch size."
                    ).format(self.min_loss_scale)
                )

            self._iter += 1
            raise OverflowError("setting loss scale to: " + str(self.loss_scale))