mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-05 02:29:49 +00:00
Added ability for adafactor to fully fine tune quantized model.
This commit is contained in:
@@ -2,6 +2,7 @@ import math
|
||||
from typing import List
|
||||
import torch
|
||||
from toolkit.optimizers.optimizer_utils import copy_stochastic, stochastic_grad_accummulation
|
||||
from optimum.quanto import QBytesTensor
|
||||
|
||||
|
||||
class Adafactor(torch.optim.Optimizer):
|
||||
@@ -251,6 +252,9 @@ class Adafactor(torch.optim.Optimizer):
|
||||
state["exp_avg_sq"] = state["exp_avg_sq"].to(grad)
|
||||
|
||||
p_data_fp32 = p
|
||||
|
||||
if isinstance(p_data_fp32, QBytesTensor):
|
||||
p_data_fp32 = p_data_fp32.dequantize()
|
||||
if p.dtype != torch.float32:
|
||||
p_data_fp32 = p_data_fp32.clone().float()
|
||||
|
||||
|
||||
@@ -1,6 +1,96 @@
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from typing import Optional
|
||||
from optimum.quanto import QBytesTensor
|
||||
|
||||
|
||||
def compute_scale_for_dtype(tensor, dtype):
|
||||
"""
|
||||
Compute appropriate scale for the given tensor and target dtype.
|
||||
|
||||
Args:
|
||||
tensor: Input tensor to be quantized
|
||||
dtype: Target dtype for quantization
|
||||
Returns:
|
||||
Appropriate scale factor for the quantization
|
||||
"""
|
||||
if dtype == torch.int8:
|
||||
abs_max = torch.max(torch.abs(tensor))
|
||||
return abs_max / 127.0 if abs_max > 0 else 1.0
|
||||
elif dtype == torch.uint8:
|
||||
max_val = torch.max(tensor)
|
||||
min_val = torch.min(tensor)
|
||||
range_val = max_val - min_val
|
||||
return range_val / 255.0 if range_val > 0 else 1.0
|
||||
elif dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
|
||||
# For float8, we typically want to preserve the magnitude of the values
|
||||
# while fitting within the representable range of the format
|
||||
abs_max = torch.max(torch.abs(tensor))
|
||||
if dtype == torch.float8_e4m3fn:
|
||||
# e4m3fn has range [-448, 448] with no infinities
|
||||
max_representable = 448.0
|
||||
else: # torch.float8_e5m2
|
||||
# e5m2 has range [-57344, 57344] with infinities
|
||||
max_representable = 57344.0
|
||||
|
||||
return abs_max / max_representable if abs_max > 0 else 1.0
|
||||
else:
|
||||
raise ValueError(f"Unsupported dtype for quantization: {dtype}")
|
||||
|
||||
def quantize_tensor(tensor, dtype):
|
||||
"""
|
||||
Quantize a floating-point tensor to the target dtype with appropriate scaling.
|
||||
|
||||
Args:
|
||||
tensor: Input tensor (float)
|
||||
dtype: Target dtype for quantization
|
||||
Returns:
|
||||
quantized_data: Quantized tensor
|
||||
scale: Scale factor used
|
||||
"""
|
||||
scale = compute_scale_for_dtype(tensor, dtype)
|
||||
|
||||
if dtype == torch.int8:
|
||||
quantized_data = torch.clamp(torch.round(tensor / scale), -128, 127).to(dtype)
|
||||
elif dtype == torch.uint8:
|
||||
quantized_data = torch.clamp(torch.round(tensor / scale), 0, 255).to(dtype)
|
||||
elif dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
|
||||
# For float8, we scale and then cast directly to the target type
|
||||
# The casting operation will handle the appropriate rounding
|
||||
scaled_tensor = tensor / scale
|
||||
quantized_data = scaled_tensor.to(dtype)
|
||||
else:
|
||||
raise ValueError(f"Unsupported dtype for quantization: {dtype}")
|
||||
|
||||
return quantized_data, scale
|
||||
|
||||
|
||||
def update_parameter(target, result_float):
|
||||
"""
|
||||
Updates a parameter tensor, handling both regular torch.Tensor and QBytesTensor cases
|
||||
with proper rescaling for quantized tensors.
|
||||
|
||||
Args:
|
||||
target: The parameter to update (either torch.Tensor or QBytesTensor)
|
||||
result_float: The new values to assign (torch.Tensor)
|
||||
"""
|
||||
if isinstance(target, QBytesTensor):
|
||||
# Get the target dtype from the existing quantized tensor
|
||||
target_dtype = target._data.dtype
|
||||
|
||||
# Handle device placement
|
||||
device = target._data.device
|
||||
result_float = result_float.to(device)
|
||||
|
||||
# Compute new quantized values and scale
|
||||
quantized_data, new_scale = quantize_tensor(result_float, target_dtype)
|
||||
|
||||
# Update the internal tensors with newly computed values
|
||||
target._data.copy_(quantized_data)
|
||||
target._scale.copy_(new_scale)
|
||||
else:
|
||||
# Regular tensor update
|
||||
target.copy_(result_float)
|
||||
|
||||
|
||||
def get_format_params(dtype: torch.dtype) -> tuple[int, int]:
|
||||
@@ -98,7 +188,9 @@ def copy_stochastic(
|
||||
elif target.dtype == torch.float8_e5m2:
|
||||
result_float.clamp_(-57344.0, 57344.0)
|
||||
|
||||
target.copy_(result_float)
|
||||
# Copy the result to the target tensor
|
||||
update_parameter(target, result_float)
|
||||
# target.copy_(result_float)
|
||||
del result, rand, source_int
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user