Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-03-13 14:39:50 +00:00
Various bug fixes and optimizations for quantized training. Added untested custom adam8bit optimizer. Did some work on LoRM (don't use).
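The adam8bit optimizer itself is not shown in this diff. Purely as an illustration of the idea, here is a minimal sketch of an Adam step that keeps its moment buffers in `Auto8bitTensor` objects (the class touched below); the function name, `state` layout, and hyperparameters are all hypothetical, not the repo's implementation:

```python
import torch

def adam8bit_step_sketch(param, state, lr=1e-3, betas=(0.9, 0.999), eps=1e-8):
    """Hypothetical Adam step with moments held in Auto8bitTensor buffers."""
    grad = param.grad
    # Dequantize the 8-bit moment buffers into float32 working copies.
    exp_avg = state["exp_avg"].dequantize()
    exp_avg_sq = state["exp_avg_sq"].dequantize()

    # Standard Adam moment updates in full precision.
    exp_avg.mul_(betas[0]).add_(grad, alpha=1 - betas[0])
    exp_avg_sq.mul_(betas[1]).addcmul_(grad, grad, value=1 - betas[1])

    state["step"] = state.get("step", 0) + 1
    bias1 = 1 - betas[0] ** state["step"]
    bias2 = 1 - betas[1] ** state["step"]

    # Parameter update with bias-corrected moments.
    denom = (exp_avg_sq / bias2).sqrt().add_(eps)
    param.data.addcdiv_(exp_avg / bias1, denom, value=-lr)

    # Requantize so optimizer state stays at one int8 byte per element.
    state["exp_avg"] = Auto8bitTensor(exp_avg)
    state["exp_avg_sq"] = Auto8bitTensor(exp_avg_sq)
```

The requantization is lossy on every step, which is one reason per-tensor 8-bit moments need care; the commit message itself calls the optimizer untested.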
```diff
@@ -196,13 +196,15 @@ def copy_stochastic(

 class Auto8bitTensor:
     def __init__(self, data: Tensor, *args, **kwargs):
+        if isinstance(data, dict):  # Add constructor from state dict
+            self._load_from_state_dict(data)
+        else:
-        abs_max = data.abs().max().item()
-        scale = abs_max / 127.0 if abs_max > 0 else 1.0
+            abs_max = data.abs().max().item()
+            scale = abs_max / 127.0 if abs_max > 0 else 1.0

-        self.quantized = (data / scale).round().clamp(-127, 127).to(torch.int8)
-        self.scale = scale
-        self.orig_dtype = data.dtype
+            self.quantized = (data / scale).round().clamp(-127, 127).to(torch.int8)
+            self.scale = scale
+            self.orig_dtype = data.dtype

     def dequantize(self) -> Tensor:
         return self.quantized.to(dtype=torch.float32) * self.scale
```
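A quick round-trip check of the scheme above (hypothetical usage; in the repo's file, `torch` and `Tensor` are already imported): with per-tensor abs-max scaling, round-to-nearest int8 reconstructs each element to within half a quantization step, i.e. `scale / 2`.

```python
import torch

# Hypothetical round-trip check, not part of the commit.
w = torch.randn(1024, dtype=torch.float32)

q = Auto8bitTensor(w)        # scale = |w|.max() / 127, payload is int8
w_hat = q.dequantize()       # float32 reconstruction

err = (w - w_hat).abs().max().item()
assert err <= q.scale / 2 + 1e-8
print(f"scale={q.scale:.6f}  max abs error={err:.6f}")
```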
```diff
@@ -224,6 +226,23 @@ class Auto8bitTensor:
         # If no dtype specified, just pass through to parent
         return self.dequantize().to(*args, **kwargs)

+    def state_dict(self):
+        """Returns a dictionary containing the current state of the tensor."""
+        return {
+            'quantized': self.quantized,
+            'scale': self.scale,
+            'orig_dtype': self.orig_dtype
+        }
+
+    def _load_from_state_dict(self, state_dict):
+        """Loads the tensor state from a state dictionary."""
+        self.quantized = state_dict['quantized']
+        self.scale = state_dict['scale']
+        self.orig_dtype = state_dict['orig_dtype']
+
+    def __str__(self):
+        return f"Auto8bitTensor(scale={self.scale}, orig_dtype={self.orig_dtype})"


 def stochastic_grad_accummulation(param):
     if hasattr(param, "_accum_grad"):
```
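The new methods give the class a save/load path. A sketch of the round trip (hypothetical usage; the filename is made up), relying on the dict-accepting `__init__` branch added above:

```python
import torch

q = Auto8bitTensor(torch.randn(16, 16))

# Persist the int8 payload plus its scale and original dtype...
torch.save(q.state_dict(), "quantized_state.pt")

# ...and rebuild through the new dict branch of __init__.
# weights_only=False because the dict stores a torch.dtype object.
loaded = torch.load("quantized_state.pt", weights_only=False)
q2 = Auto8bitTensor(loaded)

assert torch.equal(q.quantized, q2.quantized)
assert q.scale == q2.scale and q.orig_dtype == q2.orig_dtype
```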