Added the ability for Adafactor to fully fine-tune a quantized model.

Jaret Burkett
2024-10-30 16:38:07 -06:00
parent 58f9d01c2b
commit 025ee3dd3d
2 changed files with 97 additions and 1 deletion
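
In practice, the change lets the optimizer step directly over quanto-quantized weights. A minimal usage sketch, assuming the toolkit's Adafactor is importable from toolkit.optimizers.adafactor (the module path is not shown in this diff) and using optimum.quanto's documented quantize/freeze API:

    import torch
    from optimum.quanto import quantize, freeze, qint8
    from toolkit.optimizers.adafactor import Adafactor  # assumed module path

    model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU(), torch.nn.Linear(64, 8))
    quantize(model, weights=qint8)  # replace float weights with quantized tensors
    freeze(model)                   # materialize the weights as QBytesTensor

    # The patched Adafactor dequantizes each parameter before the update and
    # re-quantizes on write-back; the constructor arguments here are illustrative.
    optimizer = Adafactor(model.parameters(), lr=1e-4)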


@@ -2,6 +2,7 @@ import math
from typing import List
import torch
from toolkit.optimizers.optimizer_utils import copy_stochastic, stochastic_grad_accummulation
from optimum.quanto import QBytesTensor


class Adafactor(torch.optim.Optimizer):
@@ -251,6 +252,9 @@ class Adafactor(torch.optim.Optimizer):
state["exp_avg_sq"] = state["exp_avg_sq"].to(grad)
p_data_fp32 = p
if isinstance(p_data_fp32, QBytesTensor):
p_data_fp32 = p_data_fp32.dequantize()
if p.dtype != torch.float32:
p_data_fp32 = p_data_fp32.clone().float()
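
Taken in isolation, the added branch builds a float32 working copy of the parameter while leaving the quantized storage untouched until write-back. The same pattern as a standalone sketch (fp32_workspace is an illustrative name, not part of the commit):

    import torch
    from optimum.quanto import QBytesTensor

    def fp32_workspace(p):
        # Dequantize quantized parameters to plain floats; anything else is
        # cloned up to float32, so the update math always runs in full precision.
        work = p.dequantize() if isinstance(p, QBytesTensor) else p
        return work if work.dtype == torch.float32 else work.clone().float()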


@@ -1,6 +1,96 @@
import torch
from torch import Tensor
from typing import Optional
from optimum.quanto import QBytesTensor


def compute_scale_for_dtype(tensor, dtype):
    """
    Compute an appropriate scale for the given tensor and target dtype.

    Args:
        tensor: Input tensor to be quantized
        dtype: Target dtype for quantization
    Returns:
        Appropriate scale factor for the quantization
    """
    if dtype == torch.int8:
        abs_max = torch.max(torch.abs(tensor))
        return abs_max / 127.0 if abs_max > 0 else 1.0
    elif dtype == torch.uint8:
        max_val = torch.max(tensor)
        min_val = torch.min(tensor)
        range_val = max_val - min_val
        return range_val / 255.0 if range_val > 0 else 1.0
    elif dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
        # For float8, we typically want to preserve the magnitude of the values
        # while fitting within the representable range of the format
        abs_max = torch.max(torch.abs(tensor))
        if dtype == torch.float8_e4m3fn:
            # e4m3fn has range [-448, 448] with no infinities
            max_representable = 448.0
        else:  # torch.float8_e5m2
            # e5m2 has range [-57344, 57344] with infinities
            max_representable = 57344.0
        return abs_max / max_representable if abs_max > 0 else 1.0
    else:
        raise ValueError(f"Unsupported dtype for quantization: {dtype}")
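
# Worked example (illustrative): for values [-2.0, 0.5, 1.5] quantized to
# torch.int8, abs_max is 2.0, so the returned scale is 2.0 / 127 ≈ 0.01575
# and the representable grid spans roughly ±2.0.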

def quantize_tensor(tensor, dtype):
    """
    Quantize a floating-point tensor to the target dtype with appropriate scaling.

    Args:
        tensor: Input tensor (float)
        dtype: Target dtype for quantization
    Returns:
        quantized_data: Quantized tensor
        scale: Scale factor used
    """
    scale = compute_scale_for_dtype(tensor, dtype)
    if dtype == torch.int8:
        quantized_data = torch.clamp(torch.round(tensor / scale), -128, 127).to(dtype)
    elif dtype == torch.uint8:
        # Note: no zero-point offset is applied here, so negative inputs are
        # clamped to 0 by the uint8 branch.
        quantized_data = torch.clamp(torch.round(tensor / scale), 0, 255).to(dtype)
    elif dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
        # For float8, we scale and then cast directly to the target type;
        # the casting operation handles the appropriate rounding
        scaled_tensor = tensor / scale
        quantized_data = scaled_tensor.to(dtype)
    else:
        raise ValueError(f"Unsupported dtype for quantization: {dtype}")
    return quantized_data, scale
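
# Roundtrip property (illustrative): with q, s = quantize_tensor(w, torch.int8),
# w is recovered as q.float() * s with per-element error of at most s / 2;
# the float8 branches instead inherit the rounding behavior of the dtype cast.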

def update_parameter(target, result_float):
    """
    Updates a parameter tensor, handling both regular torch.Tensor and
    QBytesTensor cases, with proper rescaling for quantized tensors.

    Args:
        target: The parameter to update (either torch.Tensor or QBytesTensor)
        result_float: The new values to assign (torch.Tensor)
    """
    if isinstance(target, QBytesTensor):
        # Get the target dtype from the existing quantized tensor
        target_dtype = target._data.dtype
        # Handle device placement
        device = target._data.device
        result_float = result_float.to(device)
        # Compute new quantized values and scale
        quantized_data, new_scale = quantize_tensor(result_float, target_dtype)
        # Update the internal tensors with the newly computed values
        target._data.copy_(quantized_data)
        target._scale.copy_(new_scale)
    else:
        # Regular tensor update
        target.copy_(result_float)
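
# Usage note: update_parameter mutates target in place. For a QBytesTensor it
# refreshes both the packed payload (_data) and its scale (_scale) on every
# write, keeping the quantized parameter in sync with the float update.
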
def get_format_params(dtype: torch.dtype) -> tuple[int, int]:
@@ -98,7 +188,9 @@ def copy_stochastic(
    elif target.dtype == torch.float8_e5m2:
        result_float.clamp_(-57344.0, 57344.0)
    target.copy_(result_float)
    # Copy the result to the target tensor
    update_parameter(target, result_float)
    # target.copy_(result_float)
    del result, rand, source_int
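
A quick self-check of the helpers, as an illustrative sketch (assumes quantize_tensor and update_parameter from this file are importable; the tolerance is loose):

    import torch

    # Plain-tensor path: update_parameter reduces to copy_.
    target = torch.zeros(4)
    update_parameter(target, torch.arange(4.0))
    assert torch.equal(target, torch.arange(4.0))

    # int8 path, checked without quanto: quantize, rescale, compare.
    w = torch.randn(256)
    q, s = quantize_tensor(w, torch.int8)
    assert (w - q.float() * s).abs().max() <= s / 2 + 1e-6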