diff --git a/toolkit/optimizers/adafactor.py b/toolkit/optimizers/adafactor.py
index 3f4a738c..b98f4590 100644
--- a/toolkit/optimizers/adafactor.py
+++ b/toolkit/optimizers/adafactor.py
@@ -2,6 +2,7 @@ import math
 from typing import List
 import torch
 from toolkit.optimizers.optimizer_utils import copy_stochastic, stochastic_grad_accummulation
+from optimum.quanto import QBytesTensor
 
 
 class Adafactor(torch.optim.Optimizer):
@@ -251,6 +252,9 @@ class Adafactor(torch.optim.Optimizer):
                     state["exp_avg_sq"] = state["exp_avg_sq"].to(grad)
 
                 p_data_fp32 = p
+                # Dequantize quantized parameters before running the float32 update math
+                if isinstance(p_data_fp32, QBytesTensor):
+                    p_data_fp32 = p_data_fp32.dequantize()
                 if p.dtype != torch.float32:
                     p_data_fp32 = p_data_fp32.clone().float()
 
diff --git a/toolkit/optimizers/optimizer_utils.py b/toolkit/optimizers/optimizer_utils.py
index f895244c..28ae280d 100644
--- a/toolkit/optimizers/optimizer_utils.py
+++ b/toolkit/optimizers/optimizer_utils.py
@@ -1,6 +1,96 @@
 import torch
 from torch import Tensor
 from typing import Optional
+from optimum.quanto import QBytesTensor
+
+
+def compute_scale_for_dtype(tensor, dtype):
+    """
+    Compute an appropriate scale for the given tensor and target dtype.
+
+    Args:
+        tensor: Input tensor to be quantized
+        dtype: Target dtype for quantization
+    Returns:
+        Appropriate scale factor for the quantization, as a 0-dim tensor
+    """
+    if dtype == torch.int8:
+        abs_max = torch.max(torch.abs(tensor))
+        return abs_max / 127.0 if abs_max > 0 else torch.ones_like(abs_max)  # keep the scale a tensor so _scale.copy_() works
+    elif dtype == torch.uint8:
+        max_val = torch.max(tensor)
+        min_val = torch.min(tensor)
+        range_val = max_val - min_val
+        return range_val / 255.0 if range_val > 0 else torch.ones_like(range_val)  # no zero-point is applied; assumes non-negative values
+    elif dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
+        # For float8, we typically want to preserve the magnitude of the values
+        # while fitting within the representable range of the format
+        abs_max = torch.max(torch.abs(tensor))
+        if dtype == torch.float8_e4m3fn:
+            # e4m3fn has range [-448, 448] with no infinities
+            max_representable = 448.0
+        else:  # torch.float8_e5m2
+            # e5m2 has range [-57344, 57344] with infinities
+            max_representable = 57344.0
+
+        return abs_max / max_representable if abs_max > 0 else torch.ones_like(abs_max)
+    else:
+        raise ValueError(f"Unsupported dtype for quantization: {dtype}")
+
+def quantize_tensor(tensor, dtype):
+    """
+    Quantize a floating-point tensor to the target dtype with appropriate scaling.
+
+    Args:
+        tensor: Input tensor (float)
+        dtype: Target dtype for quantization
+    Returns:
+        quantized_data: Quantized tensor
+        scale: Scale factor used
+    """
+    scale = compute_scale_for_dtype(tensor, dtype)
+
+    if dtype == torch.int8:
+        quantized_data = torch.clamp(torch.round(tensor / scale), -128, 127).to(dtype)
+    elif dtype == torch.uint8:
+        quantized_data = torch.clamp(torch.round(tensor / scale), 0, 255).to(dtype)
+    elif dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
+        # For float8, we scale and then cast directly to the target type;
+        # the cast handles the appropriate rounding
+        scaled_tensor = tensor / scale
+        quantized_data = scaled_tensor.to(dtype)
+    else:
+        raise ValueError(f"Unsupported dtype for quantization: {dtype}")
+
+    return quantized_data, scale
+
+
+def update_parameter(target, result_float):
+    """
+    Update a parameter tensor, handling both regular torch.Tensor and QBytesTensor
+    targets, re-quantizing with a freshly computed scale in the quantized case.
+
+    Args:
+        target: The parameter to update (either torch.Tensor or QBytesTensor)
+        result_float: The new values to assign (torch.Tensor)
+    """
+    if isinstance(target, QBytesTensor):
+        # Get the target dtype from the existing quantized tensor
+        target_dtype = target._data.dtype
+
+        # Handle device placement
+        device = target._data.device
+        result_float = result_float.to(device)
+
+        # Compute new quantized values and scale
+        quantized_data, new_scale = quantize_tensor(result_float, target_dtype)
+
+        # Update the internal tensors with the newly computed values
+        target._data.copy_(quantized_data)
+        target._scale.copy_(new_scale)
+    else:
+        # Regular tensor update
+        target.copy_(result_float)
 
 
 def get_format_params(dtype: torch.dtype) -> tuple[int, int]:
@@ -98,7 +188,9 @@
     elif target.dtype == torch.float8_e5m2:
         result_float.clamp_(-57344.0, 57344.0)
 
-    target.copy_(result_float)
+    # Copy the result back into the parameter, re-quantizing
+    # in place when the target is a QBytesTensor
+    update_parameter(target, result_float)
 
     del result, rand, source_int
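
For reviewers, a minimal sketch of the intended round-trip behavior of the new helpers. This is illustrative only, not part of the patch; it assumes the patched toolkit.optimizers.optimizer_utils is importable and uses an arbitrary random weight:

import torch
from toolkit.optimizers.optimizer_utils import quantize_tensor, update_parameter

# Round-trip a random weight through int8 quantization.
w = torch.randn(4, 4)
q, scale = quantize_tensor(w, torch.int8)
w_restored = q.float() * scale  # decode: data * scale
# Round-to-nearest error is bounded by half a quantization step.
assert (w - w_restored).abs().max() <= scale / 2 + 1e-6

# For a plain (non-quantized) tensor, update_parameter() reduces to copy_().
p = torch.zeros(4, 4)
update_parameter(p, w)
assert torch.equal(p, w)

The QBytesTensor branch is only exercised when the optimizer holds quanto-quantized parameters; constructing one by hand depends on quanto internals, so it is omitted here.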