Mirror of https://github.com/ostris/ai-toolkit.git
Various bug fixes and optimizations for quantized training. Added untested custom adam8bit optimizer. Did some work on LoRM (don't use)
@@ -5,6 +5,7 @@ from typing import Iterable, Optional
 import weakref
 import copy
 import contextlib
+from toolkit.optimizers.optimizer_utils import copy_stochastic

 import torch

@@ -43,7 +44,7 @@ class ExponentialMovingAverage:
         self,
         parameters: Iterable[torch.nn.Parameter] = None,
         decay: float = 0.995,
-        use_num_updates: bool = True,
+        use_num_updates: bool = False,
         # feeds back the decay to the parameter
         use_feedback: bool = False,
         param_multiplier: float = 1.0
@@ -123,16 +124,32 @@ class ExponentialMovingAverage:
         one_minus_decay = 1.0 - decay
         with torch.no_grad():
             for s_param, param in zip(self.shadow_params, parameters):
-                tmp = (s_param - param)
+                s_param_float = s_param.float()
+                if s_param.dtype != torch.float32:
+                    s_param_float = s_param_float.to(torch.float32)
+                param_float = param
+                if param.dtype != torch.float32:
+                    param_float = param_float.to(torch.float32)
+                tmp = (s_param_float - param_float)
                 # tmp will be a new tensor so we can do in-place
                 tmp.mul_(one_minus_decay)
-                s_param.sub_(tmp)
+                s_param_float.sub_(tmp)

+                update_param = False
                 if self.use_feedback:
-                    param.add_(tmp)
+                    param_float.add_(tmp)
+                    update_param = True

                 if self.param_multiplier != 1.0:
-                    param.mul_(self.param_multiplier)
+                    param_float.mul_(self.param_multiplier)
+                    update_param = True

+                if s_param.dtype != torch.float32:
+                    copy_stochastic(s_param, s_param_float)
+
+                if update_param and param.dtype != torch.float32:
+                    copy_stochastic(param, param_float)

     def copy_to(
         self,
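The third hunk is the substance of the change: the EMA update s = s - (1 - decay) * (s - p), which is algebraically the same as s = decay * s + (1 - decay) * p, is now computed in float32 scratch tensors instead of directly on the (possibly bfloat16) storage, and the use_feedback / param_multiplier branches set update_param so the live parameter is written back through the same path. A minimal standalone sketch of this compute-in-float32, store-in-low-precision pattern (ema_step and the round-to-nearest write-back are illustrative, not code from the repo):

import torch

def ema_step(shadow: torch.Tensor, param: torch.Tensor, decay: float = 0.995) -> None:
    # Do the arithmetic in float32 scratch copies...
    s32 = shadow.to(torch.float32)
    p32 = param.to(torch.float32)
    s32 -= (1.0 - decay) * (s32 - p32)  # same as decay * s32 + (1 - decay) * p32
    # ...and only narrow back to the storage dtype at the end.
    shadow.copy_(s32)  # round-to-nearest; the commit uses copy_stochastic instead

shadow = torch.ones(4, dtype=torch.bfloat16)
param = torch.zeros(4, dtype=torch.bfloat16)
ema_step(shadow, param)
print(shadow)  # 0.995 rounded to the nearest bfloat16, i.e. 0.99609375

The final copy is the part that matters: once the shadow and the parameter are close, the decayed delta (1 - decay) * (s - p) can fall below one bfloat16 ULP, and a round-to-nearest write-back drops it entirely, stalling the EMA.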
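copy_stochastic itself, imported in the first hunk from toolkit.optimizers.optimizer_utils, does not appear in this diff. The common recipe for a stochastic-rounding copy from float32 into bfloat16 adds uniform random noise to the 16 low bits that truncation would discard, so the stored value rounds up with probability equal to the discarded fraction and sub-ULP updates survive in expectation. A sketch of that recipe, assuming a bfloat16 target; this is the standard technique, not necessarily the exact body of the repo's helper:

import torch

def copy_stochastic_(target: torch.Tensor, source: torch.Tensor) -> None:
    # Random integer in [0, 2^16) to perturb the bits that will be dropped.
    result = torch.randint_like(source, dtype=torch.int32, low=0, high=(1 << 16))
    # Add the noise to the raw bit pattern of the float32 source.
    result.add_(source.view(dtype=torch.int32))
    # Zero the low 16 bits (-65536 == 0xFFFF0000 as a signed int32 mask).
    result.bitwise_and_(-65536)
    # Reinterpret as float32; copy_ narrows exactly, since the low bits are zero.
    target.copy_(result.view(dtype=torch.float32))

Averaged over many steps, the bfloat16 shadow then tracks the float32 computation, which is the point of routing both s_param and, when use_feedback or param_multiplier fire, the live param through copy_stochastic.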