Added ability to quantize with torchao

This commit is contained in:
Jaret Burkett
2025-03-20 16:28:54 -06:00
parent 3a6b24f4c8
commit f5aa4232fa
7 changed files with 57 additions and 26 deletions

View File

@@ -1,17 +1,48 @@
from fnmatch import fnmatch
from typing import Any, Dict, List, Optional, Union
import torch
from dataclasses import dataclass
from optimum.quanto.quantize import _quantize_submodule
from optimum.quanto.tensor import Optimizer, qtype
from optimum.quanto.tensor import Optimizer, qtype, qtypes
from torchao.quantization.quant_api import (
quantize_ as torchao_quantize_,
Float8WeightOnlyConfig,
UIntXWeightOnlyConfig
)
# The quantize function in quanto had a bug where it used exclude instead of include.
# Class names that quantize() skips: a submodule whose __class__.__name__
# appears here is passed over instead of being quantized.
# NOTE(review): presumably these are quanto's already-quantized wrappers - confirm.
Q_MODULES = [
    'QLinear',
    'QConv2d',
    'QEmbedding',
    'QBatchNorm2d',
    'QLayerNorm',
    'QConvTranspose2d',
    'QEmbeddingBag',
]
# Maps a weight-type name to the torchao config object that is handed to
# torchao's quantize_ when that name is selected.
torchao_qtypes = {
    # "int4": Int4WeightOnlyConfig(),
    # uint2 .. uint8 all use the same weight-only config, differing only in
    # the torch dtype, so they are generated rather than listed one by one.
    **{
        f"uint{bits}": UIntXWeightOnlyConfig(getattr(torch, f"uint{bits}"))
        for bits in range(2, 9)
    },
    "float8": Float8WeightOnlyConfig(),
}
class aotype:
    """A torchao weight type: a name paired with its torchao config.

    Attributes are ``name`` (the key into ``torchao_qtypes``) and ``config``
    (the config object stored under that key).

    Raises:
        KeyError: if ``name`` is not a key of ``torchao_qtypes``.
    """

    def __init__(self, name: str):
        self.name = name
        # Look the config up once at construction so callers can use
        # ``.config`` directly.
        ao_config = torchao_qtypes[name]
        self.config = ao_config
def get_qtype(qtype: Union[str, qtype]) -> qtype:
    """Resolve a weight-type specifier to a type object.

    Args:
        qtype: either a type name or an already-resolved type object.

    Returns:
        An ``aotype`` when ``qtype`` is a key of ``torchao_qtypes``; the
        matching entry of quanto's ``qtypes`` when it is any other string;
        otherwise ``qtype`` itself, unchanged.

    Raises:
        KeyError: if ``qtype`` is a string found in neither table.
    """
    # torchao names take priority over quanto names.
    if qtype in torchao_qtypes:
        return aotype(qtype)
    # Remaining strings are quanto names; non-strings pass straight through.
    return qtypes[qtype] if isinstance(qtype, str) else qtype
def quantize(
model: torch.nn.Module,
weights: Optional[Union[str, qtype]] = None,
weights: Optional[Union[str, qtype, aotype]] = None,
activations: Optional[Union[str, qtype]] = None,
optimizer: Optional[Optimizer] = None,
include: Optional[Union[str, List[str]]] = None,
@@ -57,8 +88,11 @@ def quantize(
if m.__class__.__name__ in Q_MODULES:
continue
else:
_quantize_submodule(model, name, m, weights=weights,
activations=activations, optimizer=optimizer)
if isinstance(weights, aotype):
torchao_quantize_(m, weights.config)
else:
_quantize_submodule(model, name, m, weights=weights,
activations=activations, optimizer=optimizer)
except Exception as e:
print(f"Failed to quantize {name}: {e}")
raise e
raise e