Various bug fixes and optimizations for quantized training. Added untested custom adam8bit optimizer. Did some work on LoRM (dont use)

This commit is contained in:
Jaret Burkett
2024-11-20 09:16:55 -07:00
parent 6509ba4484
commit 894374b2e9
7 changed files with 241 additions and 18 deletions

View File

@@ -196,13 +196,15 @@ def copy_stochastic(
class Auto8bitTensor:
def __init__(self, data: Tensor, *args, **kwargs):
if isinstance(data, dict): # Add constructor from state dict
self._load_from_state_dict(data)
else:
abs_max = data.abs().max().item()
scale = abs_max / 127.0 if abs_max > 0 else 1.0
abs_max = data.abs().max().item()
scale = abs_max / 127.0 if abs_max > 0 else 1.0
self.quantized = (data / scale).round().clamp(-127, 127).to(torch.int8)
self.scale = scale
self.orig_dtype = data.dtype
self.quantized = (data / scale).round().clamp(-127, 127).to(torch.int8)
self.scale = scale
self.orig_dtype = data.dtype
def dequantize(self) -> Tensor:
return self.quantized.to(dtype=torch.float32) * self.scale
@@ -224,6 +226,23 @@ class Auto8bitTensor:
# If no dtype specified, just pass through to parent
return self.dequantize().to(*args, **kwargs)
def state_dict(self):
"""Returns a dictionary containing the current state of the tensor."""
return {
'quantized': self.quantized,
'scale': self.scale,
'orig_dtype': self.orig_dtype
}
def _load_from_state_dict(self, state_dict):
"""Loads the tensor state from a state dictionary."""
self.quantized = state_dict['quantized']
self.scale = state_dict['scale']
self.orig_dtype = state_dict['orig_dtype']
def __str__(self):
return f"Auto8bitTensor(scale={self.scale}, orig_dtype={self.orig_dtype})"
def stochastic_grad_accummulation(param):
if hasattr(param, "_accum_grad"):