mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-04-30 19:21:39 +00:00
Added ability for adafactor to fully fine tune quantized model.
@@ -2,6 +2,7 @@ import math
 from typing import List
 import torch
 from toolkit.optimizers.optimizer_utils import copy_stochastic, stochastic_grad_accummulation
+from optimum.quanto import QBytesTensor
 
 
 class Adafactor(torch.optim.Optimizer):
@@ -251,6 +252,9 @@ class Adafactor(torch.optim.Optimizer):
                     state["exp_avg_sq"] = state["exp_avg_sq"].to(grad)
 
                 p_data_fp32 = p
+
+                if isinstance(p_data_fp32, QBytesTensor):
+                    p_data_fp32 = p_data_fp32.dequantize()
                 if p.dtype != torch.float32:
                     p_data_fp32 = p_data_fp32.clone().float()
 
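Editor's note: the hunk above makes Adafactor do its update math on an fp32 working copy of each parameter, dequantizing quanto-quantized weights (QBytesTensor) first. A minimal standalone sketch of that pattern; the helper name to_fp32_working_copy is illustrative, not from the commit:

import torch
from optimum.quanto import QBytesTensor

def to_fp32_working_copy(p):
    # Mirror of the diff's logic: dequantize quanto tensors first,
    # then promote anything that is not already fp32 to a float32 clone.
    p_data_fp32 = p
    if isinstance(p_data_fp32, QBytesTensor):
        p_data_fp32 = p_data_fp32.dequantize()
    if p.dtype != torch.float32:
        p_data_fp32 = p_data_fp32.clone().float()
    return p_data_fp32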
@@ -1,6 +1,96 @@
 import torch
 from torch import Tensor
 from typing import Optional
+from optimum.quanto import QBytesTensor
+
+
+def compute_scale_for_dtype(tensor, dtype):
+    """
+    Compute appropriate scale for the given tensor and target dtype.
+
+    Args:
+        tensor: Input tensor to be quantized
+        dtype: Target dtype for quantization
+    Returns:
+        Appropriate scale factor for the quantization
+    """
+    if dtype == torch.int8:
+        abs_max = torch.max(torch.abs(tensor))
+        return abs_max / 127.0 if abs_max > 0 else 1.0
+    elif dtype == torch.uint8:
+        max_val = torch.max(tensor)
+        min_val = torch.min(tensor)
+        range_val = max_val - min_val
+        return range_val / 255.0 if range_val > 0 else 1.0
+    elif dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
+        # For float8, we typically want to preserve the magnitude of the values
+        # while fitting within the representable range of the format
+        abs_max = torch.max(torch.abs(tensor))
+        if dtype == torch.float8_e4m3fn:
+            # e4m3fn has range [-448, 448] with no infinities
+            max_representable = 448.0
+        else:  # torch.float8_e5m2
+            # e5m2 has range [-57344, 57344] with infinities
+            max_representable = 57344.0
+
+        return abs_max / max_representable if abs_max > 0 else 1.0
+    else:
+        raise ValueError(f"Unsupported dtype for quantization: {dtype}")
+
+
+def quantize_tensor(tensor, dtype):
+    """
+    Quantize a floating-point tensor to the target dtype with appropriate scaling.
+
+    Args:
+        tensor: Input tensor (float)
+        dtype: Target dtype for quantization
+    Returns:
+        quantized_data: Quantized tensor
+        scale: Scale factor used
+    """
+    scale = compute_scale_for_dtype(tensor, dtype)
+
+    if dtype == torch.int8:
+        quantized_data = torch.clamp(torch.round(tensor / scale), -128, 127).to(dtype)
+    elif dtype == torch.uint8:
+        quantized_data = torch.clamp(torch.round(tensor / scale), 0, 255).to(dtype)
+    elif dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
+        # For float8, we scale and then cast directly to the target type
+        # The casting operation will handle the appropriate rounding
+        scaled_tensor = tensor / scale
+        quantized_data = scaled_tensor.to(dtype)
+    else:
+        raise ValueError(f"Unsupported dtype for quantization: {dtype}")
+
+    return quantized_data, scale
+
+
+def update_parameter(target, result_float):
+    """
+    Updates a parameter tensor, handling both regular torch.Tensor and QBytesTensor cases
+    with proper rescaling for quantized tensors.
+
+    Args:
+        target: The parameter to update (either torch.Tensor or QBytesTensor)
+        result_float: The new values to assign (torch.Tensor)
+    """
+    if isinstance(target, QBytesTensor):
+        # Get the target dtype from the existing quantized tensor
+        target_dtype = target._data.dtype
+
+        # Handle device placement
+        device = target._data.device
+        result_float = result_float.to(device)
+
+        # Compute new quantized values and scale
+        quantized_data, new_scale = quantize_tensor(result_float, target_dtype)
+
+        # Update the internal tensors with newly computed values
+        target._data.copy_(quantized_data)
+        target._scale.copy_(new_scale)
+    else:
+        # Regular tensor update
+        target.copy_(result_float)
+
+
 def get_format_params(dtype: torch.dtype) -> tuple[int, int]:
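Editor's note: a quick round-trip sketch of the helpers added above (values are illustrative; it assumes they are importable from toolkit.optimizers.optimizer_utils, the module path used in the first hunk). quantize_tensor returns the raw codes plus a scale, so an approximate reconstruction is just codes * scale:

import torch
from toolkit.optimizers.optimizer_utils import quantize_tensor

w = torch.randn(4, 4)
q, scale = quantize_tensor(w, torch.int8)  # codes in [-128, 127], scale = abs_max / 127
w_hat = q.float() * scale                  # dequantized approximation of w
# Rounding error is bounded by half a quantization step.
assert (w - w_hat).abs().max() <= scale / 2 + 1e-6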
@@ -98,7 +188,9 @@ def copy_stochastic(
     elif target.dtype == torch.float8_e5m2:
         result_float.clamp_(-57344.0, 57344.0)
 
-    target.copy_(result_float)
+    # Copy the result to the target tensor
+    update_parameter(target, result_float)
+    # target.copy_(result_float)
     del result, rand, source_int
 
 
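Editor's note: the copy_stochastic change relies on update_parameter degrading to a plain copy_ for ordinary tensors, so existing non-quantized paths are unchanged. A tiny illustration (values made up):

import torch
from toolkit.optimizers.optimizer_utils import update_parameter

p = torch.zeros(3)
update_parameter(p, torch.ones(3))   # plain tensor -> target.copy_(result_float)
assert torch.equal(p, torch.ones(3))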
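Editor's note: putting the pieces together, a hedged end-to-end sketch of what the commit enables, i.e. fully fine-tuning a quanto-quantized model with this Adafactor. The import path toolkit.optimizers.adafactor, the toy model, and the training step are assumptions, not taken from the commit; optimum-quanto's public quantize/freeze API is used for the quantization step:

import torch
from optimum.quanto import quantize, freeze, qint8
from toolkit.optimizers.adafactor import Adafactor  # assumed module path

model = torch.nn.Linear(16, 16)
quantize(model, weights=qint8)  # attach int8 weight quantization
freeze(model)                   # weights become quanto quantized tensors

# Per the diff, Adafactor dequantizes each QBytesTensor parameter before the
# update and re-quantizes the result via update_parameter() on write-back.
opt = Adafactor(model.parameters())
loss = model(torch.randn(8, 16)).sum()
loss.backward()
opt.step()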