mirror of
https://github.com/SillyTavern/SillyTavern-Extras.git
synced 2026-02-22 06:04:26 +00:00
144 lines
5.4 KiB
Python
144 lines
5.4 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
#
|
|
# This source code is licensed under the MIT license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
import logging
|
|
|
|
from fairseq.modules.quantization import pq, quantization_options, scalar
|
|
from omegaconf import DictConfig
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def quantize_model_scalar(model, model_cfg: DictConfig):
    """Apply scalar quantization noise to ``model`` when the config asks for it.

    The rate is read from ``model_cfg.quant_noise_scalar``; a missing or falsy
    value (e.g. ``None``) is treated as 0, in which case nothing is done.

    Args:
        model: the model to quantize.
        model_cfg: config object that may carry a ``quant_noise_scalar`` attribute.

    Returns:
        The same ``model`` object that was passed in.
    """
    noise_rate = getattr(model_cfg, "quant_noise_scalar", 0)
    if not noise_rate:
        noise_rate = 0

    # Guard clause: only a strictly positive rate triggers quantization.
    if not (noise_rate > 0):
        return model

    # quantize_model_ edits the model in place, so its result is not needed
    scalar.quantize_model_(model, p=noise_rate, bits=8, update_step=1000)
    return model
|
|
|
|
|
|
class Quantizer(object):
    """Schedule and run iterative product quantization (PQ) during training.

    The quantizer walks through ``layers_to_quantize`` one entry per call to
    :meth:`step`, spacing the calls evenly over either the training epochs or
    the training updates. A trainer must be attached with :meth:`set_trainer`
    before :meth:`step` can quantize anything.
    """

    def __init__(self, config_path, max_epoch, max_update):
        """Load the quantization config and derive the step schedule.

        Args:
            config_path: path to a YAML config file. When falsy, an empty
                dict is handed to ``quantization_options.parse_config_yaml``.
            max_epoch: total number of training epochs; a value <= 0 disables
                the epoch-based schedule.
            max_update: total number of training updates; a value <= 0
                disables the update-based schedule.

        Raises:
            ImportError: if the ``yaml`` module cannot be imported.
            AssertionError: if the active limit is not evenly divisible by
                the number of quantization iterations, or if not exactly one
                of ``max_epoch`` / ``max_update`` is positive.
        """
        try:
            import yaml
        except ImportError as err:
            # The module is imported as ``yaml`` but the PyPI distribution
            # that provides it is named ``pyyaml``.
            raise ImportError(
                "Please install yaml with: pip install pyyaml"
            ) from err

        # parse config
        if config_path:
            with open(config_path) as config_file:
                config = quantization_options.parse_config_yaml(
                    yaml.safe_load(config_file)
                )
        else:
            config = quantization_options.parse_config_yaml({})

        self.n_centroids_config = config["n_centroids"]
        self.block_sizes_config = config["block_sizes"]
        self.layers_to_quantize = config["layers_to_quantize"]

        # We assume that training will run for a fixed number of epochs
        # (or updates) and that we should train for equal durations
        # between iterations of PQ.
        num_iterations = len(self.layers_to_quantize)
        if max_epoch > 0:
            assert max_epoch % num_iterations == 0, (
                "for iterative PQ, --max-epoch (={}) must be evenly divisible by "
                "len(layers_to_quantize) (={})".format(max_epoch, num_iterations)
            )
            self.epoch_schedule = max_epoch // num_iterations
        else:
            self.epoch_schedule = None
        if max_update > 0:
            assert max_update % num_iterations == 0, (
                "for iterative PQ, --max-update (={}) must be evenly divisible by "
                "len(layers_to_quantize) (={})".format(max_update, num_iterations)
            )
            self.update_schedule = max_update // num_iterations
        else:
            self.update_schedule = None

        # Exactly one schedule must be active: the XOR is false both when the
        # two limits are set together and when neither of them is set.
        assert (self.epoch_schedule is not None) ^ (
            self.update_schedule is not None
        ), (
            "for iterative PQ, exactly one of --max-update and --max-epoch "
            "must be set (cannot specify both or neither)"
        )

        # 0 is a special value for quantization step, which will force
        # the first call to begin_epoch() to call step()
        self.quantization_step = 0

    def set_trainer(self, trainer):
        """Attach the trainer and create a size tracker for its model."""
        self.trainer = trainer
        self.size_tracker = pq.SizeTracker(self.trainer.get_model())

    def step(self):
        """Move to the next stage of quantization.

        Does nothing once every entry of ``layers_to_quantize`` has been
        processed. Otherwise quantizes the trainer's model in place for the
        current step, advances ``quantization_step`` and reinitializes the
        trainer.
        """
        if self.quantization_step >= len(self.layers_to_quantize):
            # Maybe we just finished the last training step or we loaded
            # a checkpoint for an iterative PQ model which previously
            # finished training. Either way, don't quantize again.
            return

        logger.info(
            "quantizing model (step={}; layers_to_quantize[step]={})".format(
                self.quantization_step, self.layers_to_quantize[self.quantization_step]
            )
        )
        quantized_layers = pq.quantize_model_(
            self.trainer.get_model(),
            self.size_tracker,
            self.layers_to_quantize,
            self.block_sizes_config,
            self.n_centroids_config,
            step=self.quantization_step,
        )
        logger.info("quantized layers: {}".format(quantized_layers))
        logger.info(self.size_tracker)

        self.quantization_step += 1

        # reinitialize the Trainer since model parameters have changed
        self.trainer.reinitialize()

    def begin_epoch(self, epoch):
        """Called at the beginning of each epoch (epochs start at 1).

        Triggers :meth:`step` on every ``epoch_schedule``-th epoch when the
        epoch-based schedule is active, and unconditionally while
        ``quantization_step`` is still 0.
        """
        if (
            (
                self.epoch_schedule is not None
                and epoch > 0
                and (epoch - 1) % self.epoch_schedule == 0
            )
            # we always step once in the beginning, even if using
            # update-based quantization
            or self.quantization_step == 0
        ):
            self.step()

    def step_update(self, num_updates):
        """Called at the end of each step.

        Triggers :meth:`step` whenever ``num_updates`` is a positive multiple
        of ``update_schedule`` and the update-based schedule is active.
        """
        if (
            self.update_schedule is not None
            and num_updates > 0
            and num_updates % self.update_schedule == 0
        ):
            self.step()

    def state_dict(self):
        """Return the quantizer's config, schedule and progress as a dict."""
        return {
            "n_centroids_config": self.n_centroids_config,
            "block_sizes_config": self.block_sizes_config,
            "layers_to_quantize": self.layers_to_quantize,
            "epoch_schedule": self.epoch_schedule,
            "update_schedule": self.update_schedule,
            "quantization_step": self.quantization_step,
        }

    def load_state_dict(self, state_dict):
        """Restore config, schedule and progress from a :meth:`state_dict` result.

        Raises:
            KeyError: if any of the expected keys is missing.
        """
        self.n_centroids_config = state_dict["n_centroids_config"]
        self.block_sizes_config = state_dict["block_sizes_config"]
        self.layers_to_quantize = state_dict["layers_to_quantize"]
        self.epoch_schedule = state_dict["epoch_schedule"]
        self.update_schedule = state_dict["update_schedule"]
        self.quantization_step = state_dict["quantization_step"]
|