Various bug fixes and optimizations for quantized training. Added an untested custom adam8bit optimizer. Did some work on LoRM (don't use).

Jaret Burkett
2024-11-20 09:16:55 -07:00
parent 6509ba4484
commit 894374b2e9
7 changed files with 241 additions and 18 deletions


@@ -5,6 +5,7 @@ from typing import Iterable, Optional
 import weakref
 import copy
 import contextlib
+from toolkit.optimizers.optimizer_utils import copy_stochastic
 import torch
@@ -43,7 +44,7 @@ class ExponentialMovingAverage:
         self,
         parameters: Iterable[torch.nn.Parameter] = None,
         decay: float = 0.995,
-        use_num_updates: bool = True,
+        use_num_updates: bool = False,
         # feeds back the decay to the parameter
         use_feedback: bool = False,
         param_multiplier: float = 1.0
@@ -123,16 +124,32 @@ class ExponentialMovingAverage:
         one_minus_decay = 1.0 - decay
         with torch.no_grad():
             for s_param, param in zip(self.shadow_params, parameters):
-                tmp = (s_param - param)
+                s_param_float = s_param.float()
+                if s_param.dtype != torch.float32:
+                    s_param_float = s_param_float.to(torch.float32)
+                param_float = param
+                if param.dtype != torch.float32:
+                    param_float = param_float.to(torch.float32)
+                tmp = (s_param_float - param_float)
                 # tmp will be a new tensor so we can do in-place
                 tmp.mul_(one_minus_decay)
-                s_param.sub_(tmp)
+                s_param_float.sub_(tmp)
+                update_param = False
                 if self.use_feedback:
-                    param.add_(tmp)
+                    param_float.add_(tmp)
+                    update_param = True
                 if self.param_multiplier != 1.0:
-                    param.mul_(self.param_multiplier)
+                    param_float.mul_(self.param_multiplier)
+                    update_param = True
+                if s_param.dtype != torch.float32:
+                    copy_stochastic(s_param, s_param_float)
+                if update_param and param.dtype != torch.float32:
+                    copy_stochastic(param, param_float)

     def copy_to(
         self,
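
The hunk above upcasts the EMA shadow parameters (and, with feedback enabled, the live parameters) to float32 for the update, then writes the result back through copy_stochastic from toolkit.optimizers.optimizer_utils, which this diff imports but does not show. For reference, here is a minimal sketch of a typical stochastic-rounding copy; it assumes the low-precision tensors are bfloat16 (whose bit pattern is the upper 16 bits of a float32), and the actual body in optimizer_utils.py may differ.

import torch

def copy_stochastic(target: torch.Tensor, source: torch.Tensor):
    # Sketch: copy a float32 source into a bfloat16 target with stochastic rounding.
    # Draw 16 random bits to add below the bfloat16 mantissa cutoff.
    result = torch.randint_like(
        source,
        dtype=torch.int32,
        low=0,
        high=(1 << 16),
    )
    # Add the noise to the raw float32 bit pattern of the source.
    result.add_(source.view(dtype=torch.int32))
    # Zero the lower 16 bits so the value lands exactly on the bfloat16 grid.
    result.bitwise_and_(-65536)  # -65536 == 0xFFFF0000 as a signed int32
    # Reinterpret as float32 and copy into the bfloat16 target.
    target.copy_(result.view(dtype=torch.float32))

Rounding up or down with probability proportional to the discarded bits keeps the expected value of the stored bfloat16 parameter equal to the float32 result, so small EMA updates are not silently lost to round-to-nearest truncation.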