Mirror of https://github.com/ostris/ai-toolkit.git
Improvements for full tuning of Flux. Added a debugging launch config for VS Code.
toolkit/dequantize.py (new file, 88 lines)
@@ -0,0 +1,88 @@
from functools import partial

from optimum.quanto.tensor import QTensor
import torch

def hacked_state_dict(self, *args, **kwargs):
    orig_state_dict = self.orig_state_dict(*args, **kwargs)
    new_state_dict = {}
    for key, value in orig_state_dict.items():
        # skip quanto's bookkeeping tensors; they get folded into the weights below
        if key.endswith("._scale"):
            continue
        if key.endswith(".input_scale"):
            continue
        if key.endswith(".output_scale"):
            continue
        if key.endswith("._data"):
            # strip the "._data" suffix to recover the original parameter name
            key = key[:-6]
            scale = orig_state_dict[key + "._scale"]
            # the scale tensor carries the original (pre-quantization) dtype
            dtype = scale.dtype
            scale = scale.float()
            value = value.float()
            dequantized = value * scale

            # handle input and output scaling if they exist
            input_scale = orig_state_dict.get(key + ".input_scale")
            if input_scale is not None:
                # an input scale other than 1.0 cannot simply be folded away
                if input_scale.item() != 1.0:
                    raise ValueError("Input scale is not 1.0, cannot dequantize")

            output_scale = orig_state_dict.get(key + ".output_scale")
            if output_scale is not None:
                # an output scale other than 1.0 cannot simply be folded away
                if output_scale.item() != 1.0:
                    raise ValueError("Output scale is not 1.0, cannot dequantize")

            new_state_dict[key] = dequantized.to('cpu', dtype=dtype)
        else:
            new_state_dict[key] = value
    return new_state_dict
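
# Illustrative note (hypothetical key names, not from the original file):
# optimum.quanto serializes a quantized weight as paired entries such as
#     "proj.weight._data"   -> packed integer data
#     "proj.weight._scale"  -> scales stored in the original dtype
# and hacked_state_dict above folds each pair back into a single
# floating point "proj.weight" entry.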


# hacks the state dict so we can dequantize before saving
def patch_dequantization_on_save(model):
    # keep a handle to the original state_dict, then shadow it with the
    # dequantizing version bound to this model instance
    model.orig_state_dict = model.state_dict
    model.state_dict = partial(hacked_state_dict, model)
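
# A minimal usage sketch, not part of the original file. It assumes `model`
# is a torch.nn.Module quantized with optimum.quanto and that safetensors
# is installed; both are assumptions for illustration only.
#
#     from safetensors.torch import save_file
#
#     patch_dequantization_on_save(model)
#     state_dict = model.state_dict()  # now plain floating point tensors on CPU
#     save_file(state_dict, "model.safetensors")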


def dequantize_parameter(module: torch.nn.Module, param_name: str) -> bool:
    """
    Convert a quantized parameter back to a regular Parameter with floating point values.

    Args:
        module: The module containing the parameter to dequantize
        param_name: Name of the parameter to dequantize (e.g., 'weight', 'bias')

    Returns:
        bool: True if the parameter was dequantized, False if it was not quantized
    """
    # Check that the parameter exists
    if not hasattr(module, param_name):
        raise AttributeError(f"Module has no parameter named '{param_name}'")

    param = getattr(module, param_name)

    # If it's not a parameter, fail loudly; if it's not quantized, nothing to do
    if not isinstance(param, torch.nn.Parameter):
        raise TypeError(f"'{param_name}' is not a Parameter")
    if not isinstance(param, QTensor):
        return False

    # Convert to a float tensor while preserving device and requires_grad
    with torch.no_grad():
        float_tensor = param.float()
        new_param = torch.nn.Parameter(
            float_tensor,
            requires_grad=param.requires_grad
        )

    # Replace the parameter
    setattr(module, param_name, new_param)

    return True
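
# A minimal usage sketch, not part of the original file, assuming `model` is
# a module whose Linear weights were quantized with optimum.quanto:
#
#     for submodule in model.modules():
#         if isinstance(submodule, torch.nn.Linear):
#             dequantize_parameter(submodule, "weight")  # returns False if not quantized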