# Mirror of https://github.com/ostris/ai-toolkit.git
# Synced 2026-01-26 16:39:47 +00:00
from fnmatch import fnmatch
|
|
from typing import Any, Dict, List, Optional, Union
|
|
import torch
|
|
from dataclasses import dataclass
|
|
|
|
from optimum.quanto.quantize import _quantize_submodule
|
|
from optimum.quanto.tensor import Optimizer, qtype, qtypes
|
|
from torchao.quantization.quant_api import (
|
|
quantize_ as torchao_quantize_,
|
|
Float8WeightOnlyConfig,
|
|
UIntXWeightOnlyConfig
|
|
)
|
|
|
|
# the quantize function in quanto had a bug where it was using exclude instead of include

# Class names of quanto's already-quantized module replacements; modules of
# these types are skipped by quantize() below so they are not quantized twice.
Q_MODULES = ['QLinear', 'QConv2d', 'QEmbedding', 'QBatchNorm2d', 'QLayerNorm', 'QConvTranspose2d', 'QEmbeddingBag']

# Mapping of type name -> torchao weight-only quantization config.
torchao_qtypes = {
    # "int4": Int4WeightOnlyConfig(),
    # uint2 .. uint8 all use the same UIntX config, just with a different dtype.
    **{
        f"uint{bits}": UIntXWeightOnlyConfig(getattr(torch, f"uint{bits}"))
        for bits in range(2, 9)
    },
    "float8": Float8WeightOnlyConfig(),
}
class aotype:
    """Wrapper tying a torchao quantization-type name to its torchao config.

    Stands in for quanto's ``qtype`` just enough that ``quantize`` can branch
    on it and pull out the torchao weight-only config.
    """

    def __init__(self, name: str):
        # Resolve eagerly so unknown names fail fast with a KeyError.
        self.config = torchao_qtypes[name]
        self.name = name
def get_qtype(qtype: Union[str, "qtype", "aotype"]) -> Union["qtype", "aotype"]:
    """Resolve a quantization-type spec to a concrete type object.

    Args:
        qtype: a torchao type name (e.g. ``"uint4"``, ``"float8"``), a quanto
            qtype name (e.g. ``"qint8"``), or an already-resolved
            ``qtype``/``aotype`` instance, which is returned unchanged.

    Returns:
        An ``aotype`` for torchao names, a quanto ``qtype`` for quanto names,
        or the input itself when it is not a string.

    Raises:
        KeyError: if a string name matches neither a torchao nor a quanto type.
    """
    if isinstance(qtype, str):
        # torchao names take precedence over quanto names.
        if qtype in torchao_qtypes:
            return aotype(qtype)
        return qtypes[qtype]
    # Already a resolved type object; pass it through untouched.
    return qtype
def quantize(
    model: torch.nn.Module,
    weights: Optional[Union[str, "qtype", "aotype"]] = None,
    activations: Optional[Union[str, "qtype"]] = None,
    optimizer: Optional["Optimizer"] = None,
    include: Optional[Union[str, List[str]]] = None,
    exclude: Optional[Union[str, List[str]]] = None,
):
    """Quantize the specified model submodules.

    Recursively quantize the submodules of the specified parent model.
    Only modules that have quantized counterparts will be quantized.

    If include patterns are specified, the submodule name must match one of them.
    If exclude patterns are specified, the submodule must not match any of them.
    Include or exclude patterns are Unix shell-style wildcards which are NOT
    regular expressions. See https://docs.python.org/3/library/fnmatch.html
    for more details.

    Note: quantization happens in-place and modifies the original model and its
    descendants.

    Args:
        model (`torch.nn.Module`): the model whose submodules will be quantized.
        weights (`Optional[Union[str, qtype, aotype]]`): the qtype for weights
            quantization. An ``aotype`` routes the module through torchao's
            ``quantize_`` instead of quanto.
        activations (`Optional[Union[str, qtype]]`): the qtype for activations
            quantization (quanto path only).
        optimizer (`Optional[Optimizer]`): optional quanto optimizer forwarded
            to ``_quantize_submodule``.
        include (`Optional[Union[str, List[str]]]`):
            Patterns constituting the allowlist. If provided, module names must
            match at least one pattern from the allowlist.
        exclude (`Optional[Union[str, List[str]]]`):
            Patterns constituting the denylist. If provided, module names must
            not match any patterns from the denylist.
    """
    # Normalize single-string patterns to lists so the matching below is uniform.
    if isinstance(include, str):
        include = [include]
    if isinstance(exclude, str):
        exclude = [exclude]
    for name, m in model.named_modules():
        if include is not None and not any(fnmatch(name, pattern) for pattern in include):
            continue
        if exclude is not None and any(fnmatch(name, pattern) for pattern in exclude):
            continue
        # Skip modules quanto has already replaced with a quantized counterpart
        # (e.g. QLinear, QConv2d) so they are not quantized twice.
        if m.__class__.__name__ in Q_MODULES:
            continue
        try:
            if isinstance(weights, aotype):
                # torchao path: quantizes the module's weights in place.
                torchao_quantize_(m, weights.config)
            else:
                # quanto path: replaces the submodule on the parent by name.
                _quantize_submodule(model, name, m, weights=weights,
                                    activations=activations, optimizer=optimizer)
        except Exception as e:
            # Best-effort: leave modules that cannot be quantized (e.g.
            # unsupported layer types) in their original form rather than
            # aborting the whole model.
            print(f"Failed to quantize {name}: {e}")
            # raise e