from fnmatch import fnmatch
from typing import List, Optional, Union, TYPE_CHECKING

import torch

from optimum.quanto.quantize import _quantize_submodule
from optimum.quanto.tensor import Optimizer, qtype, qtypes
from torchao.quantization.quant_api import (
    quantize_ as torchao_quantize_,
    Float8WeightOnlyConfig,
    UIntXWeightOnlyConfig,
)
from optimum.quanto import freeze
from tqdm import tqdm
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download

from toolkit.print import print_acc
import os

if TYPE_CHECKING:
    from toolkit.models.base_model import BaseModel


# quanto's quantize() had a bug where it filtered with `exclude` where it should have
# used `include`, so this module defines a corrected copy of quantize() below.

# Names of optimum-quanto's quantized module classes; modules whose class name appears
# in this list are already quantized and are skipped by quantize().
Q_MODULES = [
    "QLinear",
    "QConv2d",
    "QEmbedding",
    "QBatchNorm2d",
    "QLayerNorm",
    "QConvTranspose2d",
    "QEmbeddingBag",
]

torchao_qtypes = {
    # "int4": Int4WeightOnlyConfig(),
    "uint2": UIntXWeightOnlyConfig(torch.uint2),
    "uint3": UIntXWeightOnlyConfig(torch.uint3),
    "uint4": UIntXWeightOnlyConfig(torch.uint4),
    "uint5": UIntXWeightOnlyConfig(torch.uint5),
    "uint6": UIntXWeightOnlyConfig(torch.uint6),
    "uint7": UIntXWeightOnlyConfig(torch.uint7),
    "uint8": UIntXWeightOnlyConfig(torch.uint8),
    "float8": Float8WeightOnlyConfig(),
}


class aotype:
    """Pairs a torchao dtype name with its weight-only quantization config."""

    def __init__(self, name: str):
        self.name = name
        self.config = torchao_qtypes[name]


def get_qtype(qtype: Union[str, qtype]) -> Union[qtype, aotype]:
    """Resolve a dtype name to a torchao `aotype` or a quanto `qtype` (qtype instances pass through)."""
    if qtype in torchao_qtypes:
        return aotype(qtype)
    if isinstance(qtype, str):
        return qtypes[qtype]
    else:
        return qtype
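

# Illustrative resolution (a sketch; the available quanto names depend on the installed
# optimum-quanto version):
#   get_qtype("uint4")   -> aotype wrapping UIntXWeightOnlyConfig(torch.uint4)  (torchao path)
#   get_qtype("qfloat8") -> qtypes["qfloat8"]                                   (quanto path)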


def quantize(
    model: torch.nn.Module,
    weights: Optional[Union[str, qtype, aotype]] = None,
    activations: Optional[Union[str, qtype]] = None,
    optimizer: Optional[Optimizer] = None,
    include: Optional[Union[str, List[str]]] = None,
    exclude: Optional[Union[str, List[str]]] = None,
):
    """Quantize the specified model submodules.

    Recursively quantize the submodules of the specified parent model.

    Only modules that have quantized counterparts will be quantized.

    If include patterns are specified, the submodule name must match one of them.

    If exclude patterns are specified, the submodule name must not match any of them.

    Include and exclude patterns are Unix shell-style wildcards, NOT regular expressions.
    See https://docs.python.org/3/library/fnmatch.html for more details.

    Note: quantization happens in-place and modifies the original model and its descendants.

    Args:
        model (`torch.nn.Module`): the model whose submodules will be quantized.
        weights (`Optional[Union[str, qtype, aotype]]`): the qtype for weights quantization.
        activations (`Optional[Union[str, qtype]]`): the qtype for activations quantization.
        optimizer (`Optional[Optimizer]`): optional quanto optimizer used when quantizing weights.
        include (`Optional[Union[str, List[str]]]`):
            Patterns constituting the allowlist. If provided, module names must match at
            least one pattern from the allowlist.
        exclude (`Optional[Union[str, List[str]]]`):
            Patterns constituting the denylist. If provided, module names must not match
            any patterns from the denylist.
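
    Example (illustrative; module names and patterns depend on the model being quantized):

        quantize(
            model,
            weights=get_qtype("uint8"),
            include=["transformer_blocks.*"],
            exclude=["*norm*"],
        )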
    """
    if include is not None:
        include = [include] if isinstance(include, str) else include
    if exclude is not None:
        exclude = [exclude] if isinstance(exclude, str) else exclude
    for name, m in model.named_modules():
        if include is not None and not any(
            fnmatch(name, pattern) for pattern in include
        ):
            continue
        if exclude is not None and any(fnmatch(name, pattern) for pattern in exclude):
            continue
        try:
            # skip modules that are already quantized (QLinear, QConv2d, etc.)
            if m.__class__.__name__ in Q_MODULES:
                continue
            else:
                if isinstance(weights, aotype):
                    # torchao weight-only quantization
                    torchao_quantize_(m, weights.config)
                else:
                    # quanto quantization
                    _quantize_submodule(
                        model,
                        name,
                        m,
                        weights=weights,
                        activations=activations,
                        optimizer=optimizer,
                    )
        except Exception as e:
            print(f"Failed to quantize {name}: {e}")
            # raise e


def quantize_model(
    base_model: "BaseModel",
    model_to_quantize: torch.nn.Module,
):
    """Quantize `model_to_quantize` in place using `base_model`'s quantization config,
    optionally attaching an accuracy recovery adapter first.
    """
    from toolkit.dequantize import patch_dequantization_on_save

    if not hasattr(base_model, "get_transformer_block_names"):
        raise ValueError(
            "The base model must have a method `get_transformer_block_names`."
        )

    # patch the state dict method so weights are dequantized on save
    patch_dequantization_on_save(model_to_quantize)

    if base_model.model_config.accuracy_recovery_adapter is not None:
        from toolkit.config_modules import NetworkConfig
        from toolkit.lora_special import LoRASpecialNetwork

        # we need to load and quantize with an accuracy recovery adapter
        # todo handle hf repos
        load_lora_path = base_model.model_config.accuracy_recovery_adapter
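
        # Illustrative values (placeholders, not real paths or repos):
        #   accuracy_recovery_adapter: "/path/to/adapter.safetensors"             (local file)
        #   accuracy_recovery_adapter: "some-user/some-repo/adapter.safetensors"  (HF Hub)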

        if not os.path.exists(load_lora_path):
            # not a local file, grab it from the hub
            path_split = load_lora_path.split("/")
            if len(path_split) > 3:
                raise ValueError(
                    "The accuracy recovery adapter path must be a local path or, for a HF repo, "
                    "'username/repo_name/filename.safetensors'."
                )
            repo_id = f"{path_split[0]}/{path_split[1]}"
            print_acc(f"Grabbing lora from the hub: {load_lora_path}")
            new_lora_path = hf_hub_download(
                repo_id,
                filename=path_split[-1],
            )
            # replace the path with the downloaded file
            load_lora_path = new_lora_path

        # build the lora config based on the lora weights
        lora_state_dict = load_file(load_lora_path)

        if hasattr(base_model, "convert_lora_weights_before_load"):
            lora_state_dict = base_model.convert_lora_weights_before_load(lora_state_dict)

        network_config = {
            "type": "lora",
            "network_kwargs": {"only_if_contains": []},
            "transformer_only": False,
        }
        first_key = list(lora_state_dict.keys())[0]
        first_weight = lora_state_dict[first_key]
        # if the keys start with "lycoris" and any of them contain "lokr", treat it as a LoKr
        if first_key.startswith("lycoris") and any(
            "lokr" in key for key in lora_state_dict.keys()
        ):
            network_config["type"] = "lokr"

        network_kwargs = {}

        # find the first lora_A weight to determine the rank
        if network_config["type"] == "lora":
            linear_dim = None
            for key, value in lora_state_dict.items():
                if "lora_A" in key:
                    # lora_A has shape [rank, in_features], so shape[0] is the rank
                    linear_dim = int(value.shape[0])
                    break
            linear_alpha = linear_dim
            network_config["linear"] = linear_dim
            network_config["linear_alpha"] = linear_alpha

            # build the keys so every module in the state dict gets matched
            only_if_contains = []
            for key in lora_state_dict.keys():
                contains_key = key.split(".lora_")[0]
                if contains_key not in only_if_contains:
                    only_if_contains.append(contains_key)

            network_kwargs["only_if_contains"] = only_if_contains
        elif network_config["type"] == "lokr":
            # find the factor
            largest_factor = 0
            for key, value in lora_state_dict.items():
                if "lokr_w1" in key:
                    factor = int(value.shape[0])
                    if factor > largest_factor:
                        largest_factor = factor
            network_config["lokr_full_rank"] = True
            network_config["lokr_factor"] = largest_factor

            only_if_contains = []
            for key in lora_state_dict.keys():
                if "lokr_w1" in key:
                    contains_key = key.split(".lokr_w1")[0]
                    contains_key = contains_key.replace("lycoris_", "")
                    if contains_key not in only_if_contains:
                        only_if_contains.append(contains_key)
            network_kwargs["only_if_contains"] = only_if_contains

        if hasattr(base_model, 'target_lora_modules'):
            network_kwargs['target_lin_modules'] = base_model.target_lora_modules

        # todo auto grab these
        # get dim and scale
        network_config = NetworkConfig(**network_config)

        network = LoRASpecialNetwork(
            text_encoder=None,
            unet=model_to_quantize,
            lora_dim=network_config.linear,
            multiplier=1.0,
            alpha=network_config.linear_alpha,
            # conv_lora_dim=self.network_config.conv,
            # conv_alpha=self.network_config.conv_alpha,
            train_unet=True,
            train_text_encoder=False,
            network_config=network_config,
            network_type=network_config.type,
            transformer_only=network_config.transformer_only,
            is_transformer=base_model.is_transformer,
            base_model=base_model,
            **network_kwargs
        )
        network.apply_to(
            None, model_to_quantize, apply_text_encoder=False, apply_unet=True
        )
        network.force_to(base_model.device_torch, dtype=base_model.torch_dtype)
        network._update_torch_multiplier()
        network.load_weights(lora_state_dict)
        network.eval()
        network.is_active = True
        network.can_merge_in = False
        base_model.accuracy_recovery_adapter = network

        # quantize the original modules that the adapter wraps
        lora_exclude_modules = []
        quantization_type = get_qtype(base_model.model_config.qtype)
        for lora_module in tqdm(network.unet_loras, desc="Attaching quantization"):
            # the lora has already hijacked the original module
            orig_module = lora_module.org_module[0]
            orig_module.to(base_model.torch_dtype)
            # make the params not require gradients
            for param in orig_module.parameters():
                param.requires_grad = False
            quantize(orig_module, weights=quantization_type)
            freeze(orig_module)
            module_name = lora_module.lora_name.replace('$$', '.').replace('transformer.', '')
            lora_exclude_modules.append(module_name)
            if base_model.model_config.low_vram:
                # move it back to cpu
                orig_module.to("cpu")
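
        # lora_exclude_modules now holds dotted module names matching
        # model_to_quantize.named_modules(); e.g. a (hypothetical) lora_name of
        # "transformer$$transformer_blocks$$0$$attn$$to_q" becomes
        # "transformer_blocks.0.attn.to_q". The quantize() call below uses it as the
        # exclude list so these already-quantized modules are skipped.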

        # quantize additional layers
        print_acc(" - quantizing additional layers")
        quantization_type = get_qtype('uint8')
        quantize(
            model_to_quantize,
            weights=quantization_type,
            exclude=lora_exclude_modules
        )

    else:
        # quantize the model the original way, without an accuracy recovery adapter.
        # move and quantize only certain pieces at a time to limit peak VRAM usage.
        quantization_type = get_qtype(base_model.model_config.qtype)
        # all_blocks = list(model_to_quantize.transformer_blocks)
        all_blocks: List[torch.nn.Module] = []
        transformer_block_names = base_model.get_transformer_block_names()
        for name in transformer_block_names:
            block_list = getattr(model_to_quantize, name, None)
            if block_list is not None:
                all_blocks += list(block_list)
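
        # get_transformer_block_names() returns the attribute names of the block lists
        # on the model, e.g. (hypothetical) ["transformer_blocks", "single_transformer_blocks"];
        # the blocks are gathered here so they can be quantized one at a time.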

        base_model.print_and_status_update(
            f" - quantizing {len(all_blocks)} transformer blocks"
        )
        for block in tqdm(all_blocks):
            block.to(base_model.device_torch, dtype=base_model.torch_dtype)
            quantize(block, weights=quantization_type)
            freeze(block)
            block.to("cpu")

        # todo, on extras find a universal way to quantize them on device and move them back to their original
        # device without having to move the transformer blocks to the device first
        base_model.print_and_status_update(" - quantizing extras")
        model_to_quantize.to(base_model.device_torch, dtype=base_model.torch_dtype)
        quantize(model_to_quantize, weights=quantization_type)
        freeze(model_to_quantize)