mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-27 09:44:02 +00:00
Added ability to quantize with torchao
This commit is contained in:
@@ -1,17 +1,48 @@
|
||||
from fnmatch import fnmatch
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
import torch
|
||||
from dataclasses import dataclass
|
||||
|
||||
from optimum.quanto.quantize import _quantize_submodule
|
||||
from optimum.quanto.tensor import Optimizer, qtype
|
||||
from optimum.quanto.tensor import Optimizer, qtype, qtypes
|
||||
from torchao.quantization.quant_api import (
|
||||
quantize_ as torchao_quantize_,
|
||||
Float8WeightOnlyConfig,
|
||||
UIntXWeightOnlyConfig
|
||||
)
|
||||
|
||||
# the quantize function in quanto had a bug where it was using exclude instead of include
|
||||
|
||||
Q_MODULES = ['QLinear', 'QConv2d', 'QEmbedding', 'QBatchNorm2d', 'QLayerNorm', 'QConvTranspose2d', 'QEmbeddingBag']
|
||||
|
||||
torchao_qtypes = {
|
||||
# "int4": Int4WeightOnlyConfig(),
|
||||
"uint2": UIntXWeightOnlyConfig(torch.uint2),
|
||||
"uint3": UIntXWeightOnlyConfig(torch.uint3),
|
||||
"uint4": UIntXWeightOnlyConfig(torch.uint4),
|
||||
"uint5": UIntXWeightOnlyConfig(torch.uint5),
|
||||
"uint6": UIntXWeightOnlyConfig(torch.uint6),
|
||||
"uint7": UIntXWeightOnlyConfig(torch.uint7),
|
||||
"uint8": UIntXWeightOnlyConfig(torch.uint8),
|
||||
"float8": Float8WeightOnlyConfig(),
|
||||
}
|
||||
|
||||
class aotype:
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
self.config = torchao_qtypes[name]
|
||||
|
||||
def get_qtype(qtype: Union[str, qtype]) -> qtype:
|
||||
if qtype in torchao_qtypes:
|
||||
return aotype(qtype)
|
||||
if isinstance(qtype, str):
|
||||
return qtypes[qtype]
|
||||
else:
|
||||
return qtype
|
||||
|
||||
def quantize(
|
||||
model: torch.nn.Module,
|
||||
weights: Optional[Union[str, qtype]] = None,
|
||||
weights: Optional[Union[str, qtype, aotype]] = None,
|
||||
activations: Optional[Union[str, qtype]] = None,
|
||||
optimizer: Optional[Optimizer] = None,
|
||||
include: Optional[Union[str, List[str]]] = None,
|
||||
@@ -57,8 +88,11 @@ def quantize(
|
||||
if m.__class__.__name__ in Q_MODULES:
|
||||
continue
|
||||
else:
|
||||
_quantize_submodule(model, name, m, weights=weights,
|
||||
activations=activations, optimizer=optimizer)
|
||||
if isinstance(weights, aotype):
|
||||
torchao_quantize_(m, weights.config)
|
||||
else:
|
||||
_quantize_submodule(model, name, m, weights=weights,
|
||||
activations=activations, optimizer=optimizer)
|
||||
except Exception as e:
|
||||
print(f"Failed to quantize {name}: {e}")
|
||||
raise e
|
||||
raise e
|
||||
Reference in New Issue
Block a user