mirror of
https://github.com/SillyTavern/SillyTavern-Extras.git
synced 2026-04-30 11:21:28 +00:00
Add monkey patched fairseq package to run on python 3.11 (what is needed for our use of RVC at least)
This commit is contained in:
905
modules/voice_conversion/fairseq/checkpoint_utils.py
Normal file
905
modules/voice_conversion/fairseq/checkpoint_utils.py
Normal file
@@ -0,0 +1,905 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import ast
|
||||
import collections
|
||||
import contextlib
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
import traceback
|
||||
from collections import OrderedDict
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from fairseq.data import data_utils
|
||||
from fairseq.dataclass.configs import CheckpointConfig
|
||||
from fairseq.dataclass.utils import (
|
||||
convert_namespace_to_omegaconf,
|
||||
overwrite_args_by_name,
|
||||
)
|
||||
from fairseq.distributed.fully_sharded_data_parallel import FSDP, has_FSDP
|
||||
from fairseq.file_io import PathManager
|
||||
from fairseq.models import FairseqDecoder, FairseqEncoder
|
||||
from omegaconf import DictConfig, OmegaConf, open_dict
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss):
|
||||
from fairseq import meters
|
||||
|
||||
# only one worker should attempt to create the required dir
|
||||
if trainer.data_parallel_rank == 0:
|
||||
os.makedirs(cfg.save_dir, exist_ok=True)
|
||||
|
||||
prev_best = getattr(save_checkpoint, "best", val_loss)
|
||||
if val_loss is not None:
|
||||
best_function = max if cfg.maximize_best_checkpoint_metric else min
|
||||
save_checkpoint.best = best_function(val_loss, prev_best)
|
||||
|
||||
if cfg.no_save:
|
||||
return
|
||||
|
||||
trainer.consolidate_optimizer() # TODO(SS): do we need this if no_save_optimizer_state
|
||||
|
||||
if not trainer.should_save_checkpoint_on_current_rank:
|
||||
if trainer.always_call_state_dict_during_save_checkpoint:
|
||||
trainer.state_dict()
|
||||
return
|
||||
|
||||
write_timer = meters.StopwatchMeter()
|
||||
write_timer.start()
|
||||
|
||||
epoch = epoch_itr.epoch
|
||||
end_of_epoch = epoch_itr.end_of_epoch()
|
||||
updates = trainer.get_num_updates()
|
||||
|
||||
logger.info(f"Preparing to save checkpoint for epoch {epoch} @ {updates} updates")
|
||||
|
||||
def is_better(a, b):
|
||||
return a >= b if cfg.maximize_best_checkpoint_metric else a <= b
|
||||
|
||||
suffix = trainer.checkpoint_suffix
|
||||
checkpoint_conds = collections.OrderedDict()
|
||||
checkpoint_conds["checkpoint{}{}.pt".format(epoch, suffix)] = (
|
||||
end_of_epoch and not cfg.no_epoch_checkpoints and epoch % cfg.save_interval == 0
|
||||
)
|
||||
checkpoint_conds["checkpoint_{}_{}{}.pt".format(epoch, updates, suffix)] = (
|
||||
not end_of_epoch
|
||||
and cfg.save_interval_updates > 0
|
||||
and updates % cfg.save_interval_updates == 0
|
||||
)
|
||||
checkpoint_conds["checkpoint_best{}.pt".format(suffix)] = val_loss is not None and (
|
||||
not hasattr(save_checkpoint, "best")
|
||||
or is_better(val_loss, save_checkpoint.best)
|
||||
)
|
||||
if val_loss is not None and cfg.keep_best_checkpoints > 0:
|
||||
worst_best = getattr(save_checkpoint, "best", None)
|
||||
chkpts = checkpoint_paths(
|
||||
cfg.save_dir,
|
||||
pattern=r"checkpoint\.best_{}_(\d+\.?\d*){}\.pt".format(
|
||||
cfg.best_checkpoint_metric, suffix
|
||||
),
|
||||
)
|
||||
if len(chkpts) > 0:
|
||||
p = chkpts[-1] if cfg.maximize_best_checkpoint_metric else chkpts[0]
|
||||
worst_best = float(p.rsplit("_")[-1].replace("{}.pt".format(suffix), ""))
|
||||
# add random digits to resolve ties
|
||||
with data_utils.numpy_seed(epoch, updates, val_loss):
|
||||
rand_sfx = np.random.randint(0, cfg.keep_best_checkpoints)
|
||||
|
||||
checkpoint_conds[
|
||||
"checkpoint.best_{}_{:.3f}{}{}.pt".format(
|
||||
cfg.best_checkpoint_metric, val_loss, rand_sfx, suffix
|
||||
)
|
||||
] = worst_best is None or is_better(val_loss, worst_best)
|
||||
checkpoint_conds[
|
||||
"checkpoint_last{}.pt".format(suffix)
|
||||
] = not cfg.no_last_checkpoints
|
||||
|
||||
extra_state = {"train_iterator": epoch_itr.state_dict(), "val_loss": val_loss}
|
||||
if hasattr(save_checkpoint, "best"):
|
||||
extra_state.update({"best": save_checkpoint.best})
|
||||
|
||||
checkpoints = [
|
||||
os.path.join(cfg.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond
|
||||
]
|
||||
if len(checkpoints) > 0 and trainer.should_save_checkpoint_on_current_rank:
|
||||
trainer.save_checkpoint(checkpoints[0], extra_state)
|
||||
for cp in checkpoints[1:]:
|
||||
if cfg.write_checkpoints_asynchronously:
|
||||
# TODO[ioPath]: Need to implement a delayed asynchronous
|
||||
# file copying/moving feature.
|
||||
logger.warning(
|
||||
f"ioPath is not copying {checkpoints[0]} to {cp} "
|
||||
"since async write mode is on."
|
||||
)
|
||||
else:
|
||||
assert PathManager.copy(
|
||||
checkpoints[0], cp, overwrite=True
|
||||
), f"Failed to copy {checkpoints[0]} to {cp}"
|
||||
|
||||
write_timer.stop()
|
||||
logger.info(
|
||||
"Saved checkpoint {} (epoch {} @ {} updates, score {}) (writing took {} seconds)".format(
|
||||
checkpoints[0], epoch, updates, val_loss, write_timer.sum
|
||||
)
|
||||
)
|
||||
|
||||
if not end_of_epoch and cfg.keep_interval_updates > 0:
|
||||
# remove old checkpoints; checkpoints are sorted in descending order
|
||||
if cfg.keep_interval_updates_pattern == -1:
|
||||
checkpoints = checkpoint_paths(
|
||||
cfg.save_dir, pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix)
|
||||
)
|
||||
else:
|
||||
checkpoints = checkpoint_paths(
|
||||
cfg.save_dir,
|
||||
pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix),
|
||||
keep_match=True,
|
||||
)
|
||||
checkpoints = [
|
||||
x[0]
|
||||
for x in checkpoints
|
||||
if x[1] % cfg.keep_interval_updates_pattern != 0
|
||||
]
|
||||
|
||||
for old_chk in checkpoints[cfg.keep_interval_updates :]:
|
||||
if os.path.lexists(old_chk):
|
||||
os.remove(old_chk)
|
||||
elif PathManager.exists(old_chk):
|
||||
PathManager.rm(old_chk)
|
||||
|
||||
if cfg.keep_last_epochs > 0:
|
||||
# remove old epoch checkpoints; checkpoints are sorted in descending order
|
||||
checkpoints = checkpoint_paths(
|
||||
cfg.save_dir, pattern=r"checkpoint(\d+){}\.pt".format(suffix)
|
||||
)
|
||||
for old_chk in checkpoints[cfg.keep_last_epochs :]:
|
||||
if os.path.lexists(old_chk):
|
||||
os.remove(old_chk)
|
||||
elif PathManager.exists(old_chk):
|
||||
PathManager.rm(old_chk)
|
||||
|
||||
if cfg.keep_best_checkpoints > 0:
|
||||
# only keep the best N checkpoints according to validation metric
|
||||
checkpoints = checkpoint_paths(
|
||||
cfg.save_dir,
|
||||
pattern=r"checkpoint\.best_{}_(\d+\.?\d*){}\.pt".format(
|
||||
cfg.best_checkpoint_metric, suffix
|
||||
),
|
||||
)
|
||||
if not cfg.maximize_best_checkpoint_metric:
|
||||
checkpoints = checkpoints[::-1]
|
||||
for old_chk in checkpoints[cfg.keep_best_checkpoints :]:
|
||||
if os.path.lexists(old_chk):
|
||||
os.remove(old_chk)
|
||||
elif PathManager.exists(old_chk):
|
||||
PathManager.rm(old_chk)
|
||||
|
||||
|
||||
def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args):
|
||||
"""
|
||||
Load a checkpoint and restore the training iterator.
|
||||
|
||||
*passthrough_args* will be passed through to
|
||||
``trainer.get_train_iterator``.
|
||||
"""
|
||||
|
||||
reset_optimizer = cfg.reset_optimizer
|
||||
reset_lr_scheduler = cfg.reset_lr_scheduler
|
||||
optimizer_overrides = ast.literal_eval(cfg.optimizer_overrides)
|
||||
reset_meters = cfg.reset_meters
|
||||
reset_dataloader = cfg.reset_dataloader
|
||||
|
||||
if cfg.finetune_from_model is not None and (
|
||||
reset_optimizer or reset_lr_scheduler or reset_meters or reset_dataloader
|
||||
):
|
||||
raise ValueError(
|
||||
"--finetune-from-model can not be set together with either --reset-optimizer"
|
||||
" or reset_lr_scheduler or reset_meters or reset_dataloader"
|
||||
)
|
||||
|
||||
suffix = trainer.checkpoint_suffix
|
||||
if (
|
||||
cfg.restore_file == "checkpoint_last.pt"
|
||||
): # default value of restore_file is 'checkpoint_last.pt'
|
||||
checkpoint_path = os.path.join(
|
||||
cfg.save_dir, "checkpoint_last{}.pt".format(suffix)
|
||||
)
|
||||
first_launch = not PathManager.exists(checkpoint_path)
|
||||
if first_launch and getattr(cfg, "continue_once", None) is not None:
|
||||
checkpoint_path = cfg.continue_once
|
||||
elif cfg.finetune_from_model is not None and first_launch:
|
||||
# if there is no last checkpoint to restore, start the finetune from pretrained model
|
||||
# else just use usual logic to load checkpoint, e.g. restart from last checkpoint and etc.
|
||||
if PathManager.exists(cfg.finetune_from_model):
|
||||
checkpoint_path = cfg.finetune_from_model
|
||||
reset_optimizer = True
|
||||
reset_lr_scheduler = True
|
||||
reset_meters = True
|
||||
reset_dataloader = True
|
||||
logger.info(
|
||||
f"loading pretrained model from {checkpoint_path}: "
|
||||
"optimizer, lr scheduler, meters, dataloader will be reset"
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"--finetune-from-model {cfg.finetune_from_model} does not exist"
|
||||
)
|
||||
elif suffix is not None:
|
||||
checkpoint_path = cfg.restore_file.replace(".pt", suffix + ".pt")
|
||||
else:
|
||||
checkpoint_path = cfg.restore_file
|
||||
|
||||
if cfg.restore_file != "checkpoint_last.pt" and cfg.finetune_from_model:
|
||||
raise ValueError(
|
||||
"--finetune-from-model and --restore-file (non-default value) "
|
||||
"can not be specified together: " + str(cfg)
|
||||
)
|
||||
|
||||
extra_state = trainer.load_checkpoint(
|
||||
checkpoint_path,
|
||||
reset_optimizer,
|
||||
reset_lr_scheduler,
|
||||
optimizer_overrides,
|
||||
reset_meters=reset_meters,
|
||||
)
|
||||
|
||||
if (
|
||||
extra_state is not None
|
||||
and "best" in extra_state
|
||||
and not reset_optimizer
|
||||
and not reset_meters
|
||||
):
|
||||
save_checkpoint.best = extra_state["best"]
|
||||
|
||||
if extra_state is not None and not reset_dataloader:
|
||||
# restore iterator from checkpoint
|
||||
itr_state = extra_state["train_iterator"]
|
||||
epoch_itr = trainer.get_train_iterator(
|
||||
epoch=itr_state["epoch"], load_dataset=True, **passthrough_args
|
||||
)
|
||||
epoch_itr.load_state_dict(itr_state)
|
||||
else:
|
||||
epoch_itr = trainer.get_train_iterator(
|
||||
epoch=1, load_dataset=True, **passthrough_args
|
||||
)
|
||||
|
||||
trainer.lr_step(epoch_itr.epoch)
|
||||
|
||||
return extra_state, epoch_itr
|
||||
|
||||
|
||||
def load_checkpoint_to_cpu(path, arg_overrides=None, load_on_all_ranks=False):
|
||||
"""Loads a checkpoint to CPU (with upgrading for backward compatibility).
|
||||
|
||||
If doing single-GPU training or if the checkpoint is only being loaded by at
|
||||
most one process on each node (current default behavior is for only rank 0
|
||||
to read the checkpoint from disk), load_on_all_ranks should be False to
|
||||
avoid errors from torch.distributed not having been initialized or
|
||||
torch.distributed.barrier() hanging.
|
||||
|
||||
If all processes on each node may be loading the checkpoint
|
||||
simultaneously, load_on_all_ranks should be set to True to avoid I/O
|
||||
conflicts.
|
||||
|
||||
There's currently no support for > 1 but < all processes loading the
|
||||
checkpoint on each node.
|
||||
"""
|
||||
local_path = PathManager.get_local_path(path)
|
||||
# The locally cached file returned by get_local_path() may be stale for
|
||||
# remote files that are periodically updated/overwritten (ex:
|
||||
# checkpoint_last.pt) - so we remove the local copy, sync across processes
|
||||
# (if needed), and then download a fresh copy.
|
||||
if local_path != path and PathManager.path_requires_pathmanager(path):
|
||||
try:
|
||||
os.remove(local_path)
|
||||
except FileNotFoundError:
|
||||
# With potentially multiple processes removing the same file, the
|
||||
# file being missing is benign (missing_ok isn't available until
|
||||
# Python 3.8).
|
||||
pass
|
||||
if load_on_all_ranks:
|
||||
torch.distributed.barrier()
|
||||
local_path = PathManager.get_local_path(path)
|
||||
|
||||
with open(local_path, "rb") as f:
|
||||
state = torch.load(f, map_location=torch.device("cpu"))
|
||||
|
||||
if "args" in state and state["args"] is not None and arg_overrides is not None:
|
||||
args = state["args"]
|
||||
for arg_name, arg_val in arg_overrides.items():
|
||||
setattr(args, arg_name, arg_val)
|
||||
|
||||
if "cfg" in state and state["cfg"] is not None:
|
||||
|
||||
# hack to be able to set Namespace in dict config. this should be removed when we update to newer
|
||||
# omegaconf version that supports object flags, or when we migrate all existing models
|
||||
from omegaconf import __version__ as oc_version
|
||||
from omegaconf import _utils
|
||||
|
||||
if oc_version < "2.2":
|
||||
old_primitive = _utils.is_primitive_type
|
||||
_utils.is_primitive_type = lambda _: True
|
||||
|
||||
state["cfg"] = OmegaConf.create(state["cfg"])
|
||||
|
||||
_utils.is_primitive_type = old_primitive
|
||||
OmegaConf.set_struct(state["cfg"], True)
|
||||
else:
|
||||
state["cfg"] = OmegaConf.create(state["cfg"], flags={"allow_objects": True})
|
||||
|
||||
if arg_overrides is not None:
|
||||
overwrite_args_by_name(state["cfg"], arg_overrides)
|
||||
|
||||
state = _upgrade_state_dict(state)
|
||||
return state
|
||||
|
||||
|
||||
def load_model_ensemble(
|
||||
filenames,
|
||||
arg_overrides: Optional[Dict[str, Any]] = None,
|
||||
task=None,
|
||||
strict=True,
|
||||
suffix="",
|
||||
num_shards=1,
|
||||
state=None,
|
||||
):
|
||||
"""Loads an ensemble of models.
|
||||
|
||||
Args:
|
||||
filenames (List[str]): checkpoint files to load
|
||||
arg_overrides (Dict[str,Any], optional): override model args that
|
||||
were used during model training
|
||||
task (fairseq.tasks.FairseqTask, optional): task to use for loading
|
||||
"""
|
||||
assert not (
|
||||
strict and num_shards > 1
|
||||
), "Cannot load state dict with strict=True and checkpoint shards > 1"
|
||||
ensemble, args, _task = load_model_ensemble_and_task(
|
||||
filenames,
|
||||
arg_overrides,
|
||||
task,
|
||||
strict,
|
||||
suffix,
|
||||
num_shards,
|
||||
state,
|
||||
)
|
||||
return ensemble, args
|
||||
|
||||
|
||||
def get_maybe_sharded_checkpoint_filename(
|
||||
filename: str, suffix: str, shard_idx: int, num_shards: int
|
||||
) -> str:
|
||||
orig_filename = filename
|
||||
filename = filename.replace(".pt", suffix + ".pt")
|
||||
fsdp_filename = filename[:-3] + f"-shard{shard_idx}.pt"
|
||||
model_parallel_filename = orig_filename[:-3] + f"_part{shard_idx}.pt"
|
||||
if PathManager.exists(fsdp_filename):
|
||||
return fsdp_filename
|
||||
elif num_shards > 1:
|
||||
return model_parallel_filename
|
||||
else:
|
||||
return filename
|
||||
|
||||
|
||||
def load_model_ensemble_and_task(
|
||||
filenames,
|
||||
arg_overrides: Optional[Dict[str, Any]] = None,
|
||||
task=None,
|
||||
strict=True,
|
||||
suffix="",
|
||||
num_shards=1,
|
||||
state=None,
|
||||
):
|
||||
assert state is None or len(filenames) == 1
|
||||
|
||||
from fairseq import tasks
|
||||
|
||||
assert not (
|
||||
strict and num_shards > 1
|
||||
), "Cannot load state dict with strict=True and checkpoint shards > 1"
|
||||
ensemble = []
|
||||
cfg = None
|
||||
for filename in filenames:
|
||||
orig_filename = filename
|
||||
model_shard_state = {"shard_weights": [], "shard_metadata": []}
|
||||
assert num_shards > 0
|
||||
st = time.time()
|
||||
for shard_idx in range(num_shards):
|
||||
filename = get_maybe_sharded_checkpoint_filename(
|
||||
orig_filename, suffix, shard_idx, num_shards
|
||||
)
|
||||
|
||||
if not PathManager.exists(filename):
|
||||
raise IOError("Model file not found: {}".format(filename))
|
||||
if state is None:
|
||||
state = load_checkpoint_to_cpu(filename, arg_overrides)
|
||||
if "args" in state and state["args"] is not None:
|
||||
cfg = convert_namespace_to_omegaconf(state["args"])
|
||||
elif "cfg" in state and state["cfg"] is not None:
|
||||
cfg = state["cfg"]
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Neither args nor cfg exist in state keys = {state.keys()}"
|
||||
)
|
||||
|
||||
if task is None:
|
||||
task = tasks.setup_task(cfg.task)
|
||||
|
||||
if "task_state" in state:
|
||||
task.load_state_dict(state["task_state"])
|
||||
|
||||
if "fsdp_metadata" in state and num_shards > 1:
|
||||
model_shard_state["shard_weights"].append(state["model"])
|
||||
model_shard_state["shard_metadata"].append(state["fsdp_metadata"])
|
||||
# check FSDP import before the code goes too far
|
||||
if not has_FSDP:
|
||||
raise ImportError(
|
||||
"Cannot find FullyShardedDataParallel. "
|
||||
"Please install fairscale with: pip install fairscale"
|
||||
)
|
||||
if shard_idx == num_shards - 1:
|
||||
consolidated_model_state = FSDP.consolidate_shard_weights(
|
||||
shard_weights=model_shard_state["shard_weights"],
|
||||
shard_metadata=model_shard_state["shard_metadata"],
|
||||
)
|
||||
model = task.build_model(cfg.model)
|
||||
if (
|
||||
"optimizer_history" in state
|
||||
and len(state["optimizer_history"]) > 0
|
||||
and "num_updates" in state["optimizer_history"][-1]
|
||||
):
|
||||
model.set_num_updates(
|
||||
state["optimizer_history"][-1]["num_updates"]
|
||||
)
|
||||
model.load_state_dict(
|
||||
consolidated_model_state, strict=strict, model_cfg=cfg.model
|
||||
)
|
||||
else:
|
||||
# model parallel checkpoint or unsharded checkpoint
|
||||
# support old external tasks
|
||||
|
||||
argspec = inspect.getfullargspec(task.build_model)
|
||||
if "from_checkpoint" in argspec.args:
|
||||
model = task.build_model(cfg.model, from_checkpoint=True)
|
||||
else:
|
||||
model = task.build_model(cfg.model)
|
||||
if (
|
||||
"optimizer_history" in state
|
||||
and len(state["optimizer_history"]) > 0
|
||||
and "num_updates" in state["optimizer_history"][-1]
|
||||
):
|
||||
model.set_num_updates(state["optimizer_history"][-1]["num_updates"])
|
||||
model.load_state_dict(
|
||||
state["model"], strict=strict, model_cfg=cfg.model
|
||||
)
|
||||
|
||||
# reset state so it gets loaded for the next model in ensemble
|
||||
state = None
|
||||
if shard_idx % 10 == 0 and shard_idx > 0:
|
||||
elapsed = time.time() - st
|
||||
logger.info(
|
||||
f"Loaded {shard_idx} shards in {elapsed:.2f}s, {elapsed / (shard_idx+1):.2f}s/shard"
|
||||
)
|
||||
|
||||
# build model for ensemble
|
||||
ensemble.append(model)
|
||||
return ensemble, cfg, task
|
||||
|
||||
|
||||
def load_model_ensemble_and_task_from_hf_hub(
|
||||
model_id,
|
||||
cache_dir: Optional[str] = None,
|
||||
arg_overrides: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
try:
|
||||
from huggingface_hub import snapshot_download
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"You need to install huggingface_hub to use `load_from_hf_hub`. "
|
||||
"See https://pypi.org/project/huggingface-hub/ for installation."
|
||||
)
|
||||
|
||||
library_name = "fairseq"
|
||||
cache_dir = cache_dir or (Path.home() / ".cache" / library_name).as_posix()
|
||||
cache_dir = snapshot_download(
|
||||
model_id, cache_dir=cache_dir, library_name=library_name, **kwargs
|
||||
)
|
||||
|
||||
_arg_overrides = arg_overrides or {}
|
||||
_arg_overrides["data"] = cache_dir
|
||||
return load_model_ensemble_and_task(
|
||||
[p.as_posix() for p in Path(cache_dir).glob("*.pt")],
|
||||
arg_overrides=_arg_overrides,
|
||||
)
|
||||
|
||||
|
||||
def checkpoint_paths(path, pattern=r"checkpoint(\d+)\.pt", keep_match=False):
|
||||
"""Retrieves all checkpoints found in `path` directory.
|
||||
|
||||
Checkpoints are identified by matching filename to the specified pattern. If
|
||||
the pattern contains groups, the result will be sorted by the first group in
|
||||
descending order.
|
||||
"""
|
||||
pt_regexp = re.compile(pattern)
|
||||
files = PathManager.ls(path)
|
||||
|
||||
entries = []
|
||||
for i, f in enumerate(files):
|
||||
m = pt_regexp.fullmatch(f)
|
||||
if m is not None:
|
||||
idx = float(m.group(1)) if len(m.groups()) > 0 else i
|
||||
entries.append((idx, m.group(0)))
|
||||
if keep_match:
|
||||
return [(os.path.join(path, x[1]), x[0]) for x in sorted(entries, reverse=True)]
|
||||
else:
|
||||
return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)]
|
||||
|
||||
|
||||
def torch_persistent_save(obj, filename, async_write: bool = False):
|
||||
if async_write:
|
||||
with PathManager.opena(filename, "wb") as f:
|
||||
_torch_persistent_save(obj, f)
|
||||
else:
|
||||
if PathManager.supports_rename(filename):
|
||||
# do atomic save
|
||||
with PathManager.open(filename + ".tmp", "wb") as f:
|
||||
_torch_persistent_save(obj, f)
|
||||
PathManager.rename(filename + ".tmp", filename)
|
||||
else:
|
||||
# fallback to non-atomic save
|
||||
with PathManager.open(filename, "wb") as f:
|
||||
_torch_persistent_save(obj, f)
|
||||
|
||||
|
||||
def _torch_persistent_save(obj, f):
|
||||
if isinstance(f, str):
|
||||
with PathManager.open(f, "wb") as h:
|
||||
torch_persistent_save(obj, h)
|
||||
return
|
||||
for i in range(3):
|
||||
try:
|
||||
return torch.save(obj, f)
|
||||
except Exception:
|
||||
if i == 2:
|
||||
logger.error(traceback.format_exc())
|
||||
raise
|
||||
|
||||
|
||||
def _upgrade_state_dict(state):
|
||||
"""Helper for upgrading old model checkpoints."""
|
||||
|
||||
# add optimizer_history
|
||||
if "optimizer_history" not in state:
|
||||
state["optimizer_history"] = [
|
||||
{"criterion_name": "CrossEntropyCriterion", "best_loss": state["best_loss"]}
|
||||
]
|
||||
state["last_optimizer_state"] = state["optimizer"]
|
||||
del state["optimizer"]
|
||||
del state["best_loss"]
|
||||
# move extra_state into sub-dictionary
|
||||
if "epoch" in state and "extra_state" not in state:
|
||||
state["extra_state"] = {
|
||||
"epoch": state["epoch"],
|
||||
"batch_offset": state["batch_offset"],
|
||||
"val_loss": state["val_loss"],
|
||||
}
|
||||
del state["epoch"]
|
||||
del state["batch_offset"]
|
||||
del state["val_loss"]
|
||||
# reduce optimizer history's memory usage (only keep the last state)
|
||||
if "optimizer" in state["optimizer_history"][-1]:
|
||||
state["last_optimizer_state"] = state["optimizer_history"][-1]["optimizer"]
|
||||
for optim_hist in state["optimizer_history"]:
|
||||
del optim_hist["optimizer"]
|
||||
# record the optimizer class name
|
||||
if "optimizer_name" not in state["optimizer_history"][-1]:
|
||||
state["optimizer_history"][-1]["optimizer_name"] = "FairseqNAG"
|
||||
# move best_loss into lr_scheduler_state
|
||||
if "lr_scheduler_state" not in state["optimizer_history"][-1]:
|
||||
state["optimizer_history"][-1]["lr_scheduler_state"] = {
|
||||
"best": state["optimizer_history"][-1]["best_loss"]
|
||||
}
|
||||
del state["optimizer_history"][-1]["best_loss"]
|
||||
# keep track of number of updates
|
||||
if "num_updates" not in state["optimizer_history"][-1]:
|
||||
state["optimizer_history"][-1]["num_updates"] = 0
|
||||
# use stateful training data iterator
|
||||
if "train_iterator" not in state["extra_state"]:
|
||||
state["extra_state"]["train_iterator"] = {
|
||||
"epoch": state["extra_state"].get("epoch", 0),
|
||||
"iterations_in_epoch": state["extra_state"].get("batch_offset", 0),
|
||||
}
|
||||
|
||||
# backward compatibility, cfg updates
|
||||
if "args" in state and state["args"] is not None:
|
||||
# old model checkpoints may not have separate source/target positions
|
||||
if hasattr(state["args"], "max_positions") and not hasattr(
|
||||
state["args"], "max_source_positions"
|
||||
):
|
||||
state["args"].max_source_positions = state["args"].max_positions
|
||||
state["args"].max_target_positions = state["args"].max_positions
|
||||
# default to translation task
|
||||
if not hasattr(state["args"], "task"):
|
||||
state["args"].task = "translation"
|
||||
# --raw-text and --lazy-load are deprecated
|
||||
if getattr(state["args"], "raw_text", False):
|
||||
state["args"].dataset_impl = "raw"
|
||||
elif getattr(state["args"], "lazy_load", False):
|
||||
state["args"].dataset_impl = "lazy"
|
||||
# epochs start at 1
|
||||
if state["extra_state"]["train_iterator"] is not None:
|
||||
state["extra_state"]["train_iterator"]["epoch"] = max(
|
||||
state["extra_state"]["train_iterator"].get("epoch", 1), 1
|
||||
)
|
||||
# --remove-bpe ==> --postprocess
|
||||
if hasattr(state["args"], "remove_bpe"):
|
||||
state["args"].post_process = state["args"].remove_bpe
|
||||
# --min-lr ==> --stop-min-lr
|
||||
if hasattr(state["args"], "min_lr"):
|
||||
state["args"].stop_min_lr = state["args"].min_lr
|
||||
del state["args"].min_lr
|
||||
# binary_cross_entropy / kd_binary_cross_entropy => wav2vec criterion
|
||||
if hasattr(state["args"], "criterion") and state["args"].criterion in [
|
||||
"binary_cross_entropy",
|
||||
"kd_binary_cross_entropy",
|
||||
]:
|
||||
state["args"].criterion = "wav2vec"
|
||||
# remove log_keys if it's None (criteria will supply a default value of [])
|
||||
if hasattr(state["args"], "log_keys") and state["args"].log_keys is None:
|
||||
delattr(state["args"], "log_keys")
|
||||
# speech_pretraining => audio pretraining
|
||||
if (
|
||||
hasattr(state["args"], "task")
|
||||
and state["args"].task == "speech_pretraining"
|
||||
):
|
||||
state["args"].task = "audio_pretraining"
|
||||
# audio_cpc => wav2vec
|
||||
if hasattr(state["args"], "arch") and state["args"].arch == "audio_cpc":
|
||||
state["args"].arch = "wav2vec"
|
||||
# convert legacy float learning rate to List[float]
|
||||
if hasattr(state["args"], "lr") and isinstance(state["args"].lr, float):
|
||||
state["args"].lr = [state["args"].lr]
|
||||
# convert task data arg to a string instead of List[string]
|
||||
if (
|
||||
hasattr(state["args"], "data")
|
||||
and isinstance(state["args"].data, list)
|
||||
and len(state["args"].data) > 0
|
||||
):
|
||||
state["args"].data = state["args"].data[0]
|
||||
|
||||
state["cfg"] = convert_namespace_to_omegaconf(state["args"])
|
||||
|
||||
if "cfg" in state and state["cfg"] is not None:
|
||||
cfg = state["cfg"]
|
||||
with open_dict(cfg):
|
||||
# any upgrades for Hydra-based configs
|
||||
if (
|
||||
"task" in cfg
|
||||
and "eval_wer_config" in cfg.task
|
||||
and isinstance(cfg.task.eval_wer_config.print_alignment, bool)
|
||||
):
|
||||
cfg.task.eval_wer_config.print_alignment = "hard"
|
||||
if "generation" in cfg and isinstance(cfg.generation.print_alignment, bool):
|
||||
cfg.generation.print_alignment = (
|
||||
"hard" if cfg.generation.print_alignment else None
|
||||
)
|
||||
if (
|
||||
"model" in cfg
|
||||
and "w2v_args" in cfg.model
|
||||
and cfg.model.w2v_args is not None
|
||||
and (
|
||||
hasattr(cfg.model.w2v_args, "task") or "task" in cfg.model.w2v_args
|
||||
)
|
||||
and hasattr(cfg.model.w2v_args.task, "eval_wer_config")
|
||||
and cfg.model.w2v_args.task.eval_wer_config is not None
|
||||
and isinstance(
|
||||
cfg.model.w2v_args.task.eval_wer_config.print_alignment, bool
|
||||
)
|
||||
):
|
||||
cfg.model.w2v_args.task.eval_wer_config.print_alignment = "hard"
|
||||
|
||||
return state
|
||||
|
||||
|
||||
def prune_state_dict(state_dict, model_cfg: Optional[DictConfig]):
|
||||
"""Prune the given state_dict if desired for LayerDrop
|
||||
(https://arxiv.org/abs/1909.11556).
|
||||
|
||||
Training with LayerDrop allows models to be robust to pruning at inference
|
||||
time. This function prunes state_dict to allow smaller models to be loaded
|
||||
from a larger model and re-maps the existing state_dict for this to occur.
|
||||
|
||||
It's called by functions that load models from checkpoints and does not
|
||||
need to be called directly.
|
||||
"""
|
||||
arch = None
|
||||
if model_cfg is not None:
|
||||
arch = (
|
||||
model_cfg._name
|
||||
if isinstance(model_cfg, DictConfig)
|
||||
else getattr(model_cfg, "arch", None)
|
||||
)
|
||||
|
||||
if not model_cfg or arch is None or arch == "ptt_transformer":
|
||||
# args should not be none, but don't crash if it is.
|
||||
return state_dict
|
||||
|
||||
encoder_layers_to_keep = getattr(model_cfg, "encoder_layers_to_keep", None)
|
||||
decoder_layers_to_keep = getattr(model_cfg, "decoder_layers_to_keep", None)
|
||||
|
||||
if not encoder_layers_to_keep and not decoder_layers_to_keep:
|
||||
return state_dict
|
||||
|
||||
# apply pruning
|
||||
logger.info(
|
||||
"Pruning model to specified layer configuration - this works best if the model was trained with LayerDrop"
|
||||
)
|
||||
|
||||
def create_pruning_pass(layers_to_keep, layer_name):
|
||||
keep_layers = sorted(
|
||||
int(layer_string) for layer_string in layers_to_keep.split(",")
|
||||
)
|
||||
mapping_dict = {}
|
||||
for i in range(len(keep_layers)):
|
||||
mapping_dict[str(keep_layers[i])] = str(i)
|
||||
|
||||
regex = re.compile(r"^{layer}.*\.layers\.(\d+)".format(layer=layer_name))
|
||||
return {"substitution_regex": regex, "mapping_dict": mapping_dict}
|
||||
|
||||
pruning_passes = []
|
||||
if encoder_layers_to_keep:
|
||||
pruning_passes.append(create_pruning_pass(encoder_layers_to_keep, "encoder"))
|
||||
if decoder_layers_to_keep:
|
||||
pruning_passes.append(create_pruning_pass(decoder_layers_to_keep, "decoder"))
|
||||
|
||||
new_state_dict = {}
|
||||
for layer_name in state_dict.keys():
|
||||
match = re.search(r"\.layers\.(\d+)\.", layer_name)
|
||||
# if layer has no number in it, it is a supporting layer, such as an
|
||||
# embedding
|
||||
if not match:
|
||||
new_state_dict[layer_name] = state_dict[layer_name]
|
||||
continue
|
||||
|
||||
# otherwise, layer should be pruned.
|
||||
original_layer_number = match.group(1)
|
||||
# figure out which mapping dict to replace from
|
||||
for pruning_pass in pruning_passes:
|
||||
if original_layer_number in pruning_pass["mapping_dict"] and pruning_pass[
|
||||
"substitution_regex"
|
||||
].search(layer_name):
|
||||
new_layer_number = pruning_pass["mapping_dict"][original_layer_number]
|
||||
substitution_match = pruning_pass["substitution_regex"].search(
|
||||
layer_name
|
||||
)
|
||||
new_state_key = (
|
||||
layer_name[: substitution_match.start(1)]
|
||||
+ new_layer_number
|
||||
+ layer_name[substitution_match.end(1) :]
|
||||
)
|
||||
new_state_dict[new_state_key] = state_dict[layer_name]
|
||||
|
||||
# Since layers are now pruned, *_layers_to_keep are no longer needed.
|
||||
# This is more of "It would make it work fix" rather than a proper fix.
|
||||
if isinstance(model_cfg, DictConfig):
|
||||
context = open_dict(model_cfg)
|
||||
else:
|
||||
context = contextlib.ExitStack()
|
||||
with context:
|
||||
if hasattr(model_cfg, "encoder_layers_to_keep"):
|
||||
model_cfg.encoder_layers_to_keep = None
|
||||
if hasattr(model_cfg, "decoder_layers_to_keep"):
|
||||
model_cfg.decoder_layers_to_keep = None
|
||||
|
||||
return new_state_dict
|
||||
|
||||
|
||||
def load_pretrained_component_from_model(
|
||||
component: Union[FairseqEncoder, FairseqDecoder],
|
||||
checkpoint: str,
|
||||
strict: bool = True,
|
||||
):
|
||||
"""
|
||||
Load a pretrained FairseqEncoder or FairseqDecoder from checkpoint into the
|
||||
provided `component` object. If state_dict fails to load, there may be a
|
||||
mismatch in the architecture of the corresponding `component` found in the
|
||||
`checkpoint` file.
|
||||
"""
|
||||
if not PathManager.exists(checkpoint):
|
||||
raise IOError("Model file not found: {}".format(checkpoint))
|
||||
state = load_checkpoint_to_cpu(checkpoint)
|
||||
if isinstance(component, FairseqEncoder):
|
||||
component_type = "encoder"
|
||||
elif isinstance(component, FairseqDecoder):
|
||||
component_type = "decoder"
|
||||
else:
|
||||
raise ValueError(
|
||||
"component to load must be either a FairseqEncoder or "
|
||||
"FairseqDecoder. Loading other component types are not supported."
|
||||
)
|
||||
component_state_dict = OrderedDict()
|
||||
for key in state["model"].keys():
|
||||
if key.startswith(component_type):
|
||||
# encoder.input_layers.0.0.weight --> input_layers.0.0.weight
|
||||
component_subkey = key[len(component_type) + 1 :]
|
||||
component_state_dict[component_subkey] = state["model"][key]
|
||||
component.load_state_dict(component_state_dict, strict=strict)
|
||||
return component
|
||||
|
||||
|
||||
def verify_checkpoint_directory(save_dir: str) -> None:
|
||||
if not os.path.exists(save_dir):
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
temp_file_path = os.path.join(save_dir, "dummy")
|
||||
try:
|
||||
with open(temp_file_path, "w"):
|
||||
pass
|
||||
except OSError as e:
|
||||
logger.warning(
|
||||
"Unable to access checkpoint save directory: {}".format(save_dir)
|
||||
)
|
||||
raise e
|
||||
else:
|
||||
os.remove(temp_file_path)
|
||||
|
||||
|
||||
def save_ema_as_checkpoint(src_path, dst_path):
|
||||
state = load_ema_from_checkpoint(src_path)
|
||||
torch_persistent_save(state, dst_path)
|
||||
|
||||
|
||||
def load_ema_from_checkpoint(fpath):
|
||||
"""Loads exponential moving averaged (EMA) checkpoint from input and
|
||||
returns a model with ema weights.
|
||||
|
||||
Args:
|
||||
fpath: A string path of checkpoint to load from.
|
||||
|
||||
Returns:
|
||||
A dict of string keys mapping to various values. The 'model' key
|
||||
from the returned dict should correspond to an OrderedDict mapping
|
||||
string parameter names to torch Tensors.
|
||||
"""
|
||||
params_dict = collections.OrderedDict()
|
||||
new_state = None
|
||||
|
||||
with PathManager.open(fpath, "rb") as f:
|
||||
new_state = torch.load(
|
||||
f,
|
||||
map_location=(
|
||||
lambda s, _: torch.serialization.default_restore_location(s, "cpu")
|
||||
),
|
||||
)
|
||||
|
||||
# EMA model is stored in a separate "extra state"
|
||||
model_params = new_state["extra_state"]["ema"]
|
||||
|
||||
for key in list(model_params.keys()):
|
||||
p = model_params[key]
|
||||
if isinstance(p, torch.HalfTensor):
|
||||
p = p.float()
|
||||
if key not in params_dict:
|
||||
params_dict[key] = p.clone()
|
||||
# NOTE: clone() is needed in case of p is a shared parameter
|
||||
else:
|
||||
raise ValueError("Key {} is repeated in EMA model params.".format(key))
|
||||
|
||||
if len(params_dict) == 0:
|
||||
raise ValueError(
|
||||
f"Input checkpoint path '{fpath}' does not contain "
|
||||
"ema model weights, is this model trained with EMA?"
|
||||
)
|
||||
|
||||
new_state["model"] = params_dict
|
||||
return new_state
|
||||
Reference in New Issue
Block a user