Files
ai-toolkit/toolkit/optimizers/optimizer_utils.py

146 lines
4.6 KiB
Python

import torch
from torch import Tensor
from typing import Optional
def get_format_params(dtype: torch.dtype) -> tuple[int, int]:
"""
Returns (mantissa_bits, total_bits) for each format.
mantissa_bits excludes the implicit leading 1.
"""
if dtype == torch.float32:
return 23, 32
elif dtype == torch.bfloat16:
return 7, 16
elif dtype == torch.float16:
return 10, 16
elif dtype == torch.float8_e4m3fn:
return 3, 8
elif dtype == torch.float8_e5m2:
return 2, 8
elif dtype == torch.int8:
return 0, 8 # Int8 doesn't have mantissa bits
else:
raise ValueError(f"Unsupported dtype: {dtype}")
def copy_stochastic(
target: torch.Tensor,
source: torch.Tensor,
eps: Optional[float] = None
) -> None:
"""
Performs stochastic rounding from source tensor to target tensor.
Args:
target: Destination tensor (determines the target format)
source: Source tensor (typically float32)
eps: Optional minimum value for stochastic rounding (for numerical stability)
"""
with torch.no_grad():
# If target is float32, just copy directly
if target.dtype == torch.float32:
target.copy_(source)
return
# Special handling for int8
if target.dtype == torch.int8:
# Scale the source values to utilize the full int8 range
scaled = source * 127.0 # Scale to [-127, 127]
# Add random noise for stochastic rounding
noise = torch.rand_like(scaled) - 0.5
rounded = torch.round(scaled + noise)
# Clamp to int8 range
clamped = torch.clamp(rounded, -127, 127)
target.copy_(clamped.to(torch.int8))
return
mantissa_bits, _ = get_format_params(target.dtype)
# Convert source to int32 view
source_int = source.view(dtype=torch.int32)
# Calculate number of bits to round
bits_to_round = 23 - mantissa_bits # 23 is float32 mantissa bits
# Create random integers for stochastic rounding
rand = torch.randint_like(
source,
dtype=torch.int32,
low=0,
high=(1 << bits_to_round),
)
# Add random values to the bits that will be rounded off
result = source_int.clone()
result.add_(rand)
# Mask to keep only the bits we want
# Create mask with 1s in positions we want to keep
mask = (-1) << bits_to_round
result.bitwise_and_(mask)
# Handle minimum value threshold if specified
if eps is not None:
eps_int = torch.tensor(
eps, dtype=torch.float32).view(dtype=torch.int32)
zero_mask = (result.abs() < eps_int)
result[zero_mask] = torch.sign(source_int[zero_mask]) * eps_int
# Convert back to float32 view
result_float = result.view(dtype=torch.float32)
# Special handling for float8 formats
if target.dtype == torch.float8_e4m3fn:
result_float.clamp_(-448.0, 448.0)
elif target.dtype == torch.float8_e5m2:
result_float.clamp_(-57344.0, 57344.0)
target.copy_(result_float)
del result, rand, source_int
class Auto8bitTensor:
def __init__(self, data: Tensor, *args, **kwargs):
abs_max = data.abs().max().item()
scale = abs_max / 127.0 if abs_max > 0 else 1.0
self.quantized = (data / scale).round().clamp(-127, 127).to(torch.int8)
self.scale = scale
self.orig_dtype = data.dtype
def dequantize(self) -> Tensor:
return self.quantized.to(dtype=torch.float32) * self.scale
def to(self, *args, **kwargs):
# Handle the dtype argument whether it's positional or keyword
dtype = None
if args and isinstance(args[0], torch.dtype):
dtype = args[0]
args = args[1:]
elif 'dtype' in kwargs:
dtype = kwargs['dtype']
del kwargs['dtype']
if dtype is not None:
# First dequantize then convert to requested dtype
return self.dequantize().to(dtype=dtype, *args, **kwargs)
# If no dtype specified, just pass through to parent
return self.dequantize().to(*args, **kwargs)
def stochastic_grad_accummulation(param):
if hasattr(param, "_accum_grad"):
grad_fp32 = param._accum_grad.clone().to(torch.float32)
grad_fp32.add_(param.grad.to(torch.float32))
copy_stochastic(param._accum_grad, grad_fp32)
del grad_fp32
del param.grad
else:
param._accum_grad = param.grad.clone()
del param.grad