Add support for training with an accuracy recovery adapter for Qwen Image

Jaret Burkett
2025-08-12 08:21:36 -06:00
parent 4ad18f3d00
commit 77b10d884d
8 changed files with 292 additions and 36 deletions

View File

@@ -601,6 +601,16 @@ class ModelConfig:
        # 20 different model variants
        self.extras_name_or_path = kwargs.get("extras_name_or_path", self.name_or_path)

        # path to an accuracy recovery adapter, either local or remote
        self.accuracy_recovery_adapter = kwargs.get("accuracy_recovery_adapter", None)
        # parse ARA from qtype
        if self.qtype is not None and "|" in self.qtype:
            self.qtype, self.accuracy_recovery_adapter = self.qtype.split('|')

        # compile the model with torch compile
        self.compile = kwargs.get("compile", False)

        # kwargs to pass to the model
        self.model_kwargs = kwargs.get("model_kwargs", {})
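As a quick illustration of the new parsing (a sketch only: the adapter path is a placeholder and the other ModelConfig kwargs are omitted):

    cfg = ModelConfig(qtype="qfloat8|someuser/somerepo/ara.safetensors")  # other required kwargs omitted
    # the "qtype|adapter" form is split into the quantization type and the ARA location
    assert cfg.qtype == "qfloat8"
    assert cfg.accuracy_recovery_adapter == "someuser/somerepo/ara.safetensors"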

View File

@@ -168,8 +168,10 @@ class BaseModel:
        self._after_sample_img_hooks = []
        self._status_update_hooks = []
        self.is_transformer = False
        self.sample_prompts_cache = None
        self.accuracy_recovery_adapter: Union[None, 'LoRASpecialNetwork'] = None

    # properties for old arch for backwards compatibility
    @property

View File

@@ -1,19 +1,36 @@
from fnmatch import fnmatch
from typing import Any, Dict, List, Optional, Union
from typing import List, Optional, Union, TYPE_CHECKING
import torch
from dataclasses import dataclass
from optimum.quanto.quantize import _quantize_submodule
from optimum.quanto.tensor import Optimizer, qtype, qtypes
from torchao.quantization.quant_api import (
    quantize_ as torchao_quantize_,
    Float8WeightOnlyConfig,
    UIntXWeightOnlyConfig,
)
from optimum.quanto import freeze
from tqdm import tqdm
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download
from toolkit.print import print_acc
import os
if TYPE_CHECKING:
    from toolkit.models.base_model import BaseModel
# the quantize function in quanto had a bug where it was using exclude instead of include
Q_MODULES = [
    "QLinear",
    "QConv2d",
    "QEmbedding",
    "QBatchNorm2d",
    "QLayerNorm",
    "QConvTranspose2d",
    "QEmbeddingBag",
]

torchao_qtypes = {
    # "int4": Int4WeightOnlyConfig(),
@@ -27,11 +44,13 @@ torchao_qtypes = {
"float8": Float8WeightOnlyConfig(),
}
class aotype:
def __init__(self, name: str):
self.name = name
self.config = torchao_qtypes[name]
def get_qtype(qtype: Union[str, qtype]) -> qtype:
if qtype in torchao_qtypes:
return aotype(qtype)
@@ -40,6 +59,7 @@ def get_qtype(qtype: Union[str, qtype]) -> qtype:
    else:
        return qtype


def quantize(
    model: torch.nn.Module,
    weights: Optional[Union[str, qtype, aotype]] = None,
@@ -79,7 +99,9 @@ def quantize(
    if exclude is not None:
        exclude = [exclude] if isinstance(exclude, str) else exclude
    for name, m in model.named_modules():
        if include is not None and not any(
            fnmatch(name, pattern) for pattern in include
        ):
            continue
        if exclude is not None and any(fnmatch(name, pattern) for pattern in exclude):
            continue
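For reference, a usage sketch of the patched helper (the module-name patterns are hypothetical, and "qfloat8" assumes optimum-quanto's qtype registry):

    # quantize only the transformer blocks, skip norm layers, then freeze
    quantize(
        model,
        weights=get_qtype("qfloat8"),
        include=["transformer_blocks*"],
        exclude=["*norm*"],
    )
    freeze(model)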
@@ -91,8 +113,191 @@ def quantize(
            if isinstance(weights, aotype):
                torchao_quantize_(m, weights.config)
            else:
                _quantize_submodule(
                    model,
                    name,
                    m,
                    weights=weights,
                    activations=activations,
                    optimizer=optimizer,
                )
        except Exception as e:
            print(f"Failed to quantize {name}: {e}")
            # raise e
def quantize_model(
    base_model: "BaseModel",
    model_to_quantize: torch.nn.Module,
):
    from toolkit.dequantize import patch_dequantization_on_save

    if not hasattr(base_model, "get_transformer_block_names"):
        raise ValueError(
            "The model to quantize must have a method `get_transformer_block_names`."
        )

    # patch the state dict method
    patch_dequantization_on_save(model_to_quantize)

    if base_model.model_config.accuracy_recovery_adapter is not None:
        from toolkit.config_modules import NetworkConfig
        from toolkit.lora_special import LoRASpecialNetwork

        # we need to load and quantize with an accuracy recovery adapter
        # todo handle hf repos
        load_lora_path = base_model.model_config.accuracy_recovery_adapter
        if not os.path.exists(load_lora_path):
            # not local file, grab from the hub
            path_split = load_lora_path.split("/")
            if len(path_split) > 3:
                raise ValueError(
                    "The accuracy recovery adapter path must be a local path or for a hf repo, 'username/repo_name/filename.safetensors'."
                )
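            # e.g. a hypothetical "someuser/ara-repo/qwen_image_ara.safetensors" resolves below
            # to repo_id "someuser/ara-repo" and filename "qwen_image_ara.safetensors"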
repo_id = f"{path_split[0]}/{path_split[1]}"
print_acc(f"Grabbing lora from the hub: {load_lora_path}")
new_lora_path = hf_hub_download(
repo_id,
filename=path_split[-1],
)
# replace the path
load_lora_path = new_lora_path
        # build the lora config based on the lora weights
        lora_state_dict = load_file(load_lora_path)

        if hasattr(base_model, "convert_lora_weights_before_load"):
            lora_state_dict = base_model.convert_lora_weights_before_load(lora_state_dict)

        network_config = {
            "type": "lora",
            "network_kwargs": {"only_if_contains": []},
            "transformer_only": False,
        }

        first_key = list(lora_state_dict.keys())[0]
        first_weight = lora_state_dict[first_key]

        # if it starts with lycoris and includes lokr
        if first_key.startswith("lycoris") and any(
            "lokr" in key for key in lora_state_dict.keys()
        ):
            network_config["type"] = "lokr"

        network_kwargs = {}

        # find the first lora_A weight to determine the rank
        if network_config["type"] == "lora":
            linear_dim = None
            for key, value in lora_state_dict.items():
                if "lora_A" in key:
                    linear_dim = int(value.shape[0])
                    break
            linear_alpha = linear_dim
            network_config["linear"] = linear_dim
            network_config["linear_alpha"] = linear_alpha

            # we build the keys to match every key
            only_if_contains = []
            for key in lora_state_dict.keys():
                contains_key = key.split(".lora_")[0]
                if contains_key not in only_if_contains:
                    only_if_contains.append(contains_key)
            network_kwargs["only_if_contains"] = only_if_contains
        elif network_config["type"] == "lokr":
            # find the factor
            largest_factor = 0
            for key, value in lora_state_dict.items():
                if "lokr_w1" in key:
                    factor = int(value.shape[0])
                    if factor > largest_factor:
                        largest_factor = factor
            network_config["lokr_full_rank"] = True
            network_config["lokr_factor"] = largest_factor

            only_if_contains = []
            for key in lora_state_dict.keys():
                if "lokr_w1" in key:
                    contains_key = key.split(".lokr_w1")[0]
                    contains_key = contains_key.replace("lycoris_", "")
                    if contains_key not in only_if_contains:
                        only_if_contains.append(contains_key)
            network_kwargs["only_if_contains"] = only_if_contains

        if hasattr(base_model, 'target_lora_modules'):
            network_kwargs['target_lin_modules'] = base_model.target_lora_modules

        # todo auto grab these
        # get dim and scale
        network_config = NetworkConfig(**network_config)

        network = LoRASpecialNetwork(
            text_encoder=None,
            unet=model_to_quantize,
            lora_dim=network_config.linear,
            multiplier=1.0,
            alpha=network_config.linear_alpha,
            # conv_lora_dim=self.network_config.conv,
            # conv_alpha=self.network_config.conv_alpha,
            train_unet=True,
            train_text_encoder=False,
            network_config=network_config,
            network_type=network_config.type,
            transformer_only=network_config.transformer_only,
            is_transformer=base_model.is_transformer,
            base_model=base_model,
            **network_kwargs
        )
        network.apply_to(
            None, model_to_quantize, apply_text_encoder=False, apply_unet=True
        )
        network.force_to(base_model.device_torch, dtype=base_model.torch_dtype)
        network._update_torch_multiplier()
        network.load_weights(lora_state_dict)
        network.eval()
        network.is_active = True
        network.can_merge_in = False
        base_model.accuracy_recovery_adapter = network
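        # note: the adapter network stays active (and is never merged), so its unquantized
        # deltas run on top of the base weights that get quantized and frozen below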
        # quantize it
        quantization_type = get_qtype(base_model.model_config.qtype)
        for lora_module in tqdm(network.unet_loras, desc="Attaching quantization"):
            # the lora has already hijacked the original module
            orig_module = lora_module.org_module[0]
            orig_module.to(base_model.torch_dtype)
            # make the params not require gradients
            for param in orig_module.parameters():
                param.requires_grad = False
            quantize(orig_module, weights=quantization_type)
            freeze(orig_module)
            if base_model.model_config.low_vram:
                # move it back to cpu
                orig_module.to("cpu")
    else:
        # quantize model the original way without an accuracy recovery adapter
        # move and quantize only certain pieces at a time.
        quantization_type = get_qtype(base_model.model_config.qtype)
        # all_blocks = list(model_to_quantize.transformer_blocks)
        all_blocks: List[torch.nn.Module] = []
        transformer_block_names = base_model.get_transformer_block_names()
        for name in transformer_block_names:
            block = getattr(model_to_quantize, name, None)
            if block is not None:
                all_blocks.append(block)

        base_model.print_and_status_update(
            f" - quantizing {len(all_blocks)} transformer blocks"
        )
        for block in tqdm(all_blocks):
            block.to(base_model.device_torch, dtype=base_model.torch_dtype)
            quantize(block, weights=quantization_type)
            freeze(block)
            block.to("cpu")

        # todo, on extras find a universal way to quantize them on device and move them back to their original
        # device without having to move the transformer blocks to the device first
        base_model.print_and_status_update(" - quantizing extras")
        model_to_quantize.to(base_model.device_torch, dtype=base_model.torch_dtype)
        quantize(model_to_quantize, weights=quantization_type)
        freeze(model_to_quantize)
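Finally, a rough sketch of how this entry point would be driven (the caller-side handle below is an assumption for illustration, not part of this diff):

    transformer = base_model.model  # hypothetical reference to the module being quantized
    quantize_model(base_model, transformer)
    # afterwards base_model.accuracy_recovery_adapter holds the loaded LoRA/LoKr network,
    # or stays None when no adapter was configured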