mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
228 lines
7.9 KiB
Python
228 lines
7.9 KiB
Python
import torch
|
|
from torch import Tensor
|
|
from typing import Optional
|
|
from optimum.quanto import QBytesTensor
|
|
|
|
|
|
def compute_scale_for_dtype(tensor, dtype):
|
|
"""
|
|
Compute appropriate scale for the given tensor and target dtype.
|
|
|
|
Args:
|
|
tensor: Input tensor to be quantized
|
|
dtype: Target dtype for quantization
|
|
Returns:
|
|
Appropriate scale factor for the quantization
|
|
"""
|
|
if dtype == torch.int8:
|
|
abs_max = torch.max(torch.abs(tensor))
|
|
return abs_max / 127.0 if abs_max > 0 else 1.0
|
|
elif dtype == torch.uint8:
|
|
max_val = torch.max(tensor)
|
|
min_val = torch.min(tensor)
|
|
range_val = max_val - min_val
|
|
return range_val / 255.0 if range_val > 0 else 1.0
|
|
elif dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
|
|
# For float8, we typically want to preserve the magnitude of the values
|
|
# while fitting within the representable range of the format
|
|
abs_max = torch.max(torch.abs(tensor))
|
|
if dtype == torch.float8_e4m3fn:
|
|
# e4m3fn has range [-448, 448] with no infinities
|
|
max_representable = 448.0
|
|
else: # torch.float8_e5m2
|
|
# e5m2 has range [-57344, 57344] with infinities
|
|
max_representable = 57344.0
|
|
|
|
return abs_max / max_representable if abs_max > 0 else 1.0
|
|
else:
|
|
raise ValueError(f"Unsupported dtype for quantization: {dtype}")
|
|
|
|
def quantize_tensor(tensor, dtype):
    """
    Scale ``tensor`` into the representable range of ``dtype`` and convert it.

    Args:
        tensor: Floating-point tensor to quantize.
        dtype: One of torch.int8, torch.uint8, torch.float8_e4m3fn, torch.float8_e5m2.

    Returns:
        ``(quantized_data, scale)``, where ``quantized_data`` is ``tensor / scale``
        converted to ``dtype`` and ``scale`` is exactly what
        ``compute_scale_for_dtype`` returned.

    Raises:
        ValueError: for an unsupported ``dtype`` (raised by ``compute_scale_for_dtype``).
    """
    scale = compute_scale_for_dtype(tensor, dtype)

    # Integer targets need explicit rounding and clamping; float8 targets do not.
    if dtype == torch.int8:
        bounds = (-128, 127)
    elif dtype == torch.uint8:
        bounds = (0, 255)
    elif dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
        bounds = None
    else:
        raise ValueError(f"Unsupported dtype for quantization: {dtype}")

    scaled = tensor / scale
    if bounds is None:
        # For float8 the cast itself performs the rounding.
        quantized_data = scaled.to(dtype)
    else:
        low, high = bounds
        quantized_data = scaled.round().clamp(low, high).to(dtype)

    return quantized_data, scale
|
|
|
|
|
|
def update_parameter(target, result_float):
    """
    Write ``result_float`` into ``target``, re-quantizing when ``target`` is a QBytesTensor.

    Plain tensors are updated with ``copy_``, which casts to the target's dtype
    and device. For a QBytesTensor, ``result_float`` is moved to the device of
    ``target._data``, re-quantized to ``target._data.dtype`` with
    ``quantize_tensor``, and both ``_data`` and ``_scale`` are overwritten in place.

    Args:
        target: The parameter to update (either torch.Tensor or QBytesTensor).
        result_float: The new values to assign (torch.Tensor).
    """
    if not isinstance(target, QBytesTensor):
        # Regular tensor update
        target.copy_(result_float)
        return

    data = target._data
    result_float = result_float.to(data.device)

    # Compute new quantized values and scale in the existing storage dtype.
    quantized_data, new_scale = quantize_tensor(result_float, data.dtype)

    data.copy_(quantized_data)
    # The scale may come back as a plain Python float (e.g. 1.0 for an all-zero
    # input), and Tensor.copy_ only accepts tensors, so wrap it first.
    # NOTE(review): a 0-dim scale broadcasts over a per-channel _scale, replacing it
    # with a single per-tensor value.
    scale = target._scale
    scale.copy_(torch.as_tensor(new_scale, dtype=scale.dtype, device=scale.device))
|
|
|
|
|
|
def get_format_params(dtype: torch.dtype) -> tuple[int, int]:
|
|
"""
|
|
Returns (mantissa_bits, total_bits) for each format.
|
|
mantissa_bits excludes the implicit leading 1.
|
|
"""
|
|
if dtype == torch.float32:
|
|
return 23, 32
|
|
elif dtype == torch.bfloat16:
|
|
return 7, 16
|
|
elif dtype == torch.float16:
|
|
return 10, 16
|
|
elif dtype == torch.float8_e4m3fn:
|
|
return 3, 8
|
|
elif dtype == torch.float8_e5m2:
|
|
return 2, 8
|
|
elif dtype == torch.int8:
|
|
return 0, 8 # Int8 doesn't have mantissa bits
|
|
else:
|
|
raise ValueError(f"Unsupported dtype: {dtype}")
|
|
|
|
def copy_stochastic_bf16(target: torch.Tensor, source: torch.Tensor):
    """
    Copy ``source`` into the bfloat16 tensor ``target`` with stochastic rounding.

    A bfloat16 value is the upper 16 bits of a float32. Adding a uniform random
    integer in [0, 2**16) to the float32 bit pattern and then truncating the low
    16 bits rounds up with probability equal to the discarded fraction.

    Args:
        target: Destination tensor (written with ``copy_``).
        source: Values to copy; converted to float32 first because the bit
            trick reinterprets float32 storage as int32.
    """
    # adapted from https://github.com/Nerogar/OneTrainer/blob/411532e85f3cf2b52baa37597f9c145073d54511/modules/util/bf16_stochastic_rounding.py#L5
    # view(int32) only lines up element-for-element with a 4-byte dtype; a bf16/fp16
    # source would otherwise change shape and break the add below. For a float32
    # source, .to() returns the tensor itself (no copy).
    source = source.to(torch.float32)

    # create a random 16 bit integer
    result = torch.randint_like(
        source,
        dtype=torch.int32,
        low=0,
        high=(1 << 16),
    )

    # add the random number to the lower 16 bit of the mantissa
    result.add_(source.view(dtype=torch.int32))

    # mask off the lower 16 bit of the mantissa
    result.bitwise_and_(-65536)  # -65536 = FFFF0000 as a signed int32

    # copy the higher 16 bit into the target tensor (exact, low bits are zero)
    target.copy_(result.view(dtype=torch.float32))
|
|
|
|
|
|
def copy_stochastic(target: torch.Tensor, source: torch.Tensor, eps: Optional[float] = None) -> None:
    """
    Copy ``source`` into ``target``, stochastically rounding to ``target``'s precision.

    float32 targets get a plain copy, and bfloat16 targets use
    ``copy_stochastic_bf16``. float16, float8 and int8 targets round each value
    to one of its two neighbouring representable values in the target format.
    The upper neighbour is chosen with probability equal to the fractional
    distance, so the rounding is unbiased in expectation. Results are clamped to
    the target's finite range and written via ``update_parameter``.

    Args:
        target: Destination tensor (or QBytesTensor); must not be on the CPU.
        source: Values to copy; must not be on the CPU.
        eps: Unused; kept for backward compatibility.

    Raises:
        AssertionError: if either tensor is on the CPU.
        ValueError: if ``target.dtype`` is not supported by ``get_format_params``.
    """
    with torch.no_grad():
        # assert if target is on cpu, throw error
        assert target.device.type != 'cpu', "Target is on cpu!"
        assert source.device.type != 'cpu', "Source is on cpu!"

        if target.dtype == torch.float32:
            target.copy_(source)
            return

        if target.dtype == torch.bfloat16:
            copy_stochastic_bf16(target, source)
            return

        mantissa_bits, _ = get_format_params(target.dtype)
        # Do the rounding arithmetic in float32 regardless of the source dtype.
        src = source.to(torch.float32)

        if target.dtype == torch.int8:
            # Integer grid: neighbouring values are always 1 apart.
            ulp = torch.ones_like(src)
            lo, hi = -128.0, 127.0
        else:
            # (exponent of the smallest normal value, largest finite value) per format.
            min_normal_exp, max_finite = {
                torch.float16: (-14, 65504.0),
                torch.float8_e4m3fn: (-6, 448.0),
                torch.float8_e5m2: (-14, 57344.0),
            }[target.dtype]
            # frexp gives |x| = m * 2**e with m in [0.5, 1), so floor(log2|x|) = e - 1.
            # The grid spacing depends on that exponent; a fixed 2**-(23 - m) grid is
            # only right for |x| in [1, 2). Below the smallest normal value the
            # spacing is constant (subnormals), hence the clamp.
            _, exp = torch.frexp(src)
            exponent = torch.clamp(exp - 1, min=min_normal_exp)
            ulp = torch.ldexp(torch.ones_like(src), exponent - mantissa_bits)
            lo, hi = -max_finite, max_finite

        # floor(x / ulp + U[0, 1)) rounds up with probability equal to the
        # fractional part: unbiased stochastic rounding onto the target grid.
        noise = torch.rand_like(src)
        result_float = torch.floor(src / ulp + noise) * ulp

        # Keep values inside the target's finite range so the final cast cannot
        # overflow (float8 NaN/inf, float16 inf, int8 wrap-around).
        result_float.clamp_(lo, hi)

        update_parameter(target, result_float)
|
|
|
|
|
|
class Auto8bitTensor:
    """
    Symmetric int8 container for a tensor.

    Stores ``round(data / scale)`` clamped to [-127, 127] as int8, together with a
    single Python-float ``scale`` (abs-max / 127, or 1.0 for all-zero or empty
    data) and the original dtype. It can also be rebuilt from the dict returned
    by ``state_dict``.
    """

    def __init__(self, data: Tensor, *args, **kwargs):
        # *args / **kwargs are accepted but ignored.
        if isinstance(data, dict):  # Add constructor from state dict
            self._load_from_state_dict(data)
            return

        # max() raises on an empty tensor, so treat "no elements" like "all zeros".
        abs_max = data.abs().max().item() if data.numel() > 0 else 0.0
        scale = abs_max / 127.0 if abs_max > 0 else 1.0

        self.quantized = (data / scale).round().clamp(-127, 127).to(torch.int8)
        self.scale = scale
        self.orig_dtype = data.dtype

    def dequantize(self) -> Tensor:
        """Return ``quantized * scale`` as a float32 tensor (not ``orig_dtype``)."""
        return self.quantized.to(dtype=torch.float32) * self.scale

    def to(self, *args, **kwargs):
        """
        Dequantize to float32 and forward all arguments to ``Tensor.to``.

        Returns a plain ``torch.Tensor``, not an ``Auto8bitTensor``.
        """
        # Forwarding the arguments untouched supports every Tensor.to signature,
        # including to(dtype, non_blocking). Pulling out a positional dtype and
        # re-passing it as a keyword turned that form into the invalid call
        # to(non_blocking, dtype=dtype).
        return self.dequantize().to(*args, **kwargs)

    def state_dict(self):
        """Returns a dictionary containing the current state of the tensor."""
        return {
            'quantized': self.quantized,
            'scale': self.scale,
            'orig_dtype': self.orig_dtype
        }

    def _load_from_state_dict(self, state_dict):
        """Loads the tensor state from a state dictionary."""
        self.quantized = state_dict['quantized']
        self.scale = state_dict['scale']
        self.orig_dtype = state_dict['orig_dtype']

    def __str__(self):
        return f"Auto8bitTensor({self.dequantize()})"
|
|
|
|
|
|
def stochastic_grad_accummulation(param):
    """
    Accumulate ``param.grad`` into ``param._accum_grad`` and clear ``param.grad``.

    The first call stores a clone of the gradient. Later calls sum the
    accumulator and the new gradient in float32 and write the result back with
    ``copy_stochastic``: a plain copy for float32 accumulators, stochastic
    rounding for lower precision. A parameter whose ``grad`` is ``None`` (no
    gradient this step) is left untouched.
    """
    if param.grad is None:
        return

    if hasattr(param, "_accum_grad"):
        # One float32 copy (clone + cast in a single step), so the add below never
        # mutates _accum_grad directly.
        grad_fp32 = param._accum_grad.to(torch.float32, copy=True)
        grad_fp32.add_(param.grad.to(torch.float32))
        copy_stochastic(param._accum_grad, grad_fp32)
    else:
        param._accum_grad = param.grad.clone()
    del param.grad
|