from functools import partial

import torch
from optimum.quanto.tensor import QTensor


def hacked_state_dict(self, *args, **kwargs):
    # Build a state dict with quanto's quantized tensors collapsed back into
    # plain float tensors, keyed by their original parameter names.
    orig_state_dict = self.orig_state_dict(*args, **kwargs)
    new_state_dict = {}
    for key, value in orig_state_dict.items():
        # scale entries are consumed together with their "._data" entry below
        if key.endswith("._scale"):
            continue
        if key.endswith(".input_scale"):
            continue
        if key.endswith(".output_scale"):
            continue
        if key.endswith("._data"):
            # strip the "._data" suffix to recover the parameter name
            key = key[:-6]
            scale = orig_state_dict[key + "._scale"]
            # the scale carries the original (pre-quantization) dtype
            dtype = scale.dtype
            scale = scale.float()
            value = value.float()
            dequantized = value * scale

            # handle input and output scales if they exist
            input_scale = orig_state_dict.get(key + ".input_scale")
            if input_scale is not None:
                # only a no-op scale of exactly 1.0 can be dropped safely
                if input_scale.item() != 1.0:
                    raise ValueError("Input scale is not 1.0, cannot dequantize")

            output_scale = orig_state_dict.get(key + ".output_scale")
            if output_scale is not None:
                # only a no-op scale of exactly 1.0 can be dropped safely
                if output_scale.item() != 1.0:
                    raise ValueError("Output scale is not 1.0, cannot dequantize")

            new_state_dict[key] = dequantized.to('cpu', dtype=dtype)
        else:
            new_state_dict[key] = value
    return new_state_dict
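
# A minimal before/after sketch of the key handling above (the "fc" module
# name, shapes, and dtypes are illustrative assumptions, not taken from a
# real checkpoint):
#
#   quanto state dict                        dequantized state dict
#   "fc.weight._data"   int8 [out, in]  ->   "fc.weight"  bf16 [out, in]
#   "fc.weight._scale"  bf16 [out, 1]
#   "fc.bias"           bf16 [out]      ->   "fc.bias"    bf16 [out]
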
# hacks the state dict so we can dequantize before saving
def patch_dequantization_on_save(model):
    model.orig_state_dict = model.state_dict
    model.state_dict = partial(hacked_state_dict, model)
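
# Usage sketch for patch_dequantization_on_save (a minimal example; assumes
# `model` was already quantized and frozen with optimum-quanto, and `path`
# is an arbitrary checkpoint location):
def _example_patched_save(model: torch.nn.Module, path: str) -> None:
    patch_dequantization_on_save(model)
    # state_dict() now yields plain float tensors on CPU, so the checkpoint
    # loads without quanto installed
    torch.save(model.state_dict(), path)
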
def dequantize_parameter(module: torch.nn.Module, param_name: str) -> bool:
    """
    Convert a quantized parameter back to a regular Parameter with floating
    point values.

    Args:
        module: The module containing the parameter to dequantize
        param_name: Name of the parameter to dequantize (e.g., 'weight', 'bias')

    Returns:
        bool: True if the parameter was dequantized, False if it was already
            a regular floating point Parameter
    """
    # Check that the parameter exists
    if not hasattr(module, param_name):
        raise AttributeError(f"Module has no parameter named '{param_name}'")

    param = getattr(module, param_name)

    # If it's not a Parameter, fail loudly; if it's not quantized, nothing to do
    if not isinstance(param, torch.nn.Parameter):
        raise TypeError(f"'{param_name}' is not a Parameter")
    if not isinstance(param, QTensor):
        return False

    # Convert to a float tensor while preserving device and requires_grad
    with torch.no_grad():
        float_tensor = param.float()
        new_param = torch.nn.Parameter(
            float_tensor,
            requires_grad=param.requires_grad
        )

    # Replace the parameter on the module
    setattr(module, param_name, new_param)

    return True
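
# Usage sketch for dequantize_parameter (a minimal example; trying "weight"
# and "bias" on every submodule is an assumption about what was quantized,
# not something this module prescribes):
def _example_dequantize_model(model: torch.nn.Module) -> None:
    for module in model.modules():
        for name in ("weight", "bias"):
            param = getattr(module, name, None)
            if isinstance(param, torch.nn.Parameter):
                # returns False and leaves the parameter untouched if it was
                # never quantized
                dequantize_parameter(module, name)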