Mirror of https://github.com/SillyTavern/SillyTavern-Extras.git (synced 2026-04-30 03:11:26 +00:00)
Add monkey-patched fairseq package to run on Python 3.11 (at least what is needed for our use of RVC)
modules/voice_conversion/fairseq/utils.py (new file, +951)
@@ -0,0 +1,951 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import collections
import contextlib
import copy
import importlib
import logging
import os
import sys
import warnings
from itertools import accumulate
from typing import TYPE_CHECKING, Callable, Dict, List, Optional

import torch
import torch.nn.functional as F
from torch import Tensor

if TYPE_CHECKING:
    from fairseq.modules.multihead_attention import MultiheadAttention

try:
    from amp_C import multi_tensor_l2norm

    multi_tensor_l2norm_available = True
except ImportError:
    multi_tensor_l2norm_available = False

try:
    import torch_xla.core.xla_model as xm
except ImportError:
    xm = None


logger = logging.getLogger(__name__)


MANIFOLD_PATH_SEP = "|"


class FileContentsAction(argparse.Action):
    def __init__(self, option_strings, dest, nargs=None, **kwargs):
        if nargs is not None:
            raise ValueError("nargs not allowed")
        super(FileContentsAction, self).__init__(option_strings, dest, **kwargs)

    def __call__(self, parser, namespace, values, option_string=None):
        from fairseq.file_io import PathManager

        if PathManager.isfile(values):
            with PathManager.open(values) as f:
                argument = f.read().strip()
        else:
            argument = values
        setattr(namespace, self.dest, argument)


def split_paths(paths: str, separator=os.pathsep) -> List[str]:
    return (
        paths.split(separator) if "://" not in paths else paths.split(MANIFOLD_PATH_SEP)
    )
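
# Illustrative example (not part of the upstream file): split_paths splits on
# os.pathsep for local paths, but on "|" once a URI scheme is present.
#   split_paths("/data/a.pt:/data/b.pt")             # -> ["/data/a.pt", "/data/b.pt"] on a POSIX system
#   split_paths("s3://bucket/a.pt|s3://bucket/b.pt")  # -> ["s3://bucket/a.pt", "s3://bucket/b.pt"]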


def load_ensemble_for_inference(filenames, task, model_arg_overrides=None):
    from fairseq import checkpoint_utils

    deprecation_warning(
        "utils.load_ensemble_for_inference is deprecated. "
        "Please use checkpoint_utils.load_model_ensemble instead."
    )
    return checkpoint_utils.load_model_ensemble(
        filenames, arg_overrides=model_arg_overrides, task=task
    )


def apply_to_sample(f, sample):
    if hasattr(sample, "__len__") and len(sample) == 0:
        return {}

    def _apply(x):
        if torch.is_tensor(x):
            return f(x)
        elif isinstance(x, collections.OrderedDict):
            # OrderedDict has attributes that need to be preserved
            od = collections.OrderedDict(
                (key, _apply(value)) for key, value in x.items()
            )
            od.__dict__ = x.__dict__
            return od
        elif isinstance(x, dict):
            return {key: _apply(value) for key, value in x.items()}
        elif isinstance(x, list):
            return [_apply(x) for x in x]
        elif isinstance(x, tuple):
            return tuple(_apply(x) for x in x)
        elif isinstance(x, set):
            return {_apply(x) for x in x}
        else:
            return x

    return _apply(sample)


def move_to_cuda(sample, device=None):
    device = device or torch.cuda.current_device()

    def _move_to_cuda(tensor):
        # non_blocking is ignored if tensor is not pinned, so we can always set
        # to True (see github.com/PyTorchLightning/pytorch-lightning/issues/620)
        return tensor.to(device=device, non_blocking=True)

    return apply_to_sample(_move_to_cuda, sample)


def move_to_cpu(sample):
    def _move_to_cpu(tensor):
        # PyTorch has poor support for half tensors (float16) on CPU.
        # Move any such tensors to float32.
        if tensor.dtype in {torch.bfloat16, torch.float16}:
            tensor = tensor.to(dtype=torch.float32)
        return tensor.cpu()

    return apply_to_sample(_move_to_cpu, sample)
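
# Illustrative usage of apply_to_sample / move_to_* (the nested `sample` below is a
# hypothetical mini-batch, not taken from the upstream file): the helpers recurse
# through dicts, lists, tuples and sets, applying the function to every tensor.
#   sample = {"net_input": {"src_tokens": torch.ones(2, 5)}, "target": torch.ones(2, 5)}
#   sample = move_to_cuda(sample)  # every tensor now on the current CUDA device
#   sample = move_to_cpu(sample)   # back to CPU, fp16/bf16 tensors upcast to fp32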


def move_to_tpu(sample):

    import torch_xla.core.xla_model as xm

    device = xm.xla_device()

    def _move_to_tpu(tensor):
        return tensor.to(device)

    return apply_to_sample(_move_to_tpu, sample)


def get_incremental_state(
    module: "MultiheadAttention",
    incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
    key: str,
) -> Optional[Dict[str, Optional[Tensor]]]:
    """Helper for getting incremental state for an nn.Module."""
    return module.get_incremental_state(incremental_state, key)


def set_incremental_state(
    module: "MultiheadAttention",
    incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
    key: str,
    value: Dict[str, Optional[Tensor]],
) -> Optional[Dict[str, Dict[str, Optional[Tensor]]]]:
    """Helper for setting incremental state for an nn.Module."""
    if incremental_state is not None:
        result = module.set_incremental_state(incremental_state, key, value)
        if result is not None:
            incremental_state = result
    return incremental_state


def load_align_dict(replace_unk):
    if replace_unk is None:
        align_dict = None
    elif isinstance(replace_unk, str) and len(replace_unk) > 0:
        # Load alignment dictionary for unknown word replacement if it was passed as an argument.
        align_dict = {}
        with open(replace_unk, "r") as f:
            for line in f:
                cols = line.split()
                align_dict[cols[0]] = cols[1]
    else:
        # No alignment dictionary provided but we still want to perform unknown word replacement by copying the
        # original source word.
        align_dict = {}
    return align_dict


def print_embed_overlap(embed_dict, vocab_dict):
    embed_keys = set(embed_dict.keys())
    vocab_keys = set(vocab_dict.symbols)
    overlap = len(embed_keys & vocab_keys)
    logger.info("found {}/{} types in embedding file".format(overlap, len(vocab_dict)))


def parse_embedding(embed_path):
    """Parse embedding text file into a dictionary of word and embedding tensors.

    The first line can have vocabulary size and dimension. The following lines
    should contain word and embedding separated by spaces.

    Example:
        2 5
        the -0.0230 -0.0264 0.0287 0.0171 0.1403
        at -0.0395 -0.1286 0.0275 0.0254 -0.0932
    """
    embed_dict = {}
    with open(embed_path) as f_embed:
        next(f_embed)  # skip header
        for line in f_embed:
            pieces = line.rstrip().split(" ")
            embed_dict[pieces[0]] = torch.Tensor(
                [float(weight) for weight in pieces[1:]]
            )
    return embed_dict


def load_embedding(embed_dict, vocab, embedding):
    for idx in range(len(vocab)):
        token = vocab[idx]
        if token in embed_dict:
            embedding.weight.data[idx] = embed_dict[token]
    return embedding


def replace_unk(hypo_str, src_str, alignment, align_dict, unk):
    from fairseq import tokenizer

    # Tokens are strings here
    hypo_tokens = tokenizer.tokenize_line(hypo_str)
    # TODO: Very rare cases where the replacement is '<eos>' should be handled gracefully
    src_tokens = tokenizer.tokenize_line(src_str) + ["<eos>"]
    for i, ht in enumerate(hypo_tokens):
        if ht == unk:
            src_token = src_tokens[alignment[i]]
            # Either take the corresponding value in the aligned dictionary or just copy the original value.
            hypo_tokens[i] = align_dict.get(src_token, src_token)
    return " ".join(hypo_tokens)


def post_process_prediction(
    hypo_tokens,
    src_str,
    alignment,
    align_dict,
    tgt_dict,
    remove_bpe=None,
    extra_symbols_to_ignore=None,
):
    hypo_str = tgt_dict.string(
        hypo_tokens, remove_bpe, extra_symbols_to_ignore=extra_symbols_to_ignore
    )
    if align_dict is not None:
        hypo_str = replace_unk(
            hypo_str, src_str, alignment, align_dict, tgt_dict.unk_string()
        )
    if align_dict is not None or remove_bpe is not None:
        # Convert back to tokens for evaluating with unk replacement or without BPE
        # Note that the dictionary can be modified inside the method.
        hypo_tokens = tgt_dict.encode_line(hypo_str, add_if_not_exist=True)
    return hypo_tokens, hypo_str, alignment


def make_positions(tensor, padding_idx: int, onnx_trace: bool = False):
    """Replace non-padding symbols with their position numbers.

    Position numbers begin at padding_idx+1. Padding symbols are ignored.
    """
    # The series of casts and type-conversions here are carefully
    # balanced to both work with ONNX export and XLA. In particular XLA
    # prefers ints, cumsum defaults to output longs, and ONNX doesn't know
    # how to handle the dtype kwarg in cumsum.
    mask = tensor.ne(padding_idx).int()
    return (torch.cumsum(mask, dim=1).type_as(mask) * mask).long() + padding_idx
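
# Illustrative example (values computed from the function above, not from the
# upstream file): with padding_idx=1, positions start at padding_idx + 1 and pad
# slots keep the padding index.
#   make_positions(torch.tensor([[5, 6, 7, 1, 1]]), padding_idx=1)
#   # -> tensor([[2, 3, 4, 1, 1]])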


def strip_pad(tensor, pad):
    return tensor[tensor.ne(pad)]


def buffered_arange(max):
    if not hasattr(buffered_arange, "buf"):
        buffered_arange.buf = torch.LongTensor()
    if max > buffered_arange.buf.numel():
        buffered_arange.buf.resize_(max)
        torch.arange(max, out=buffered_arange.buf)
    return buffered_arange.buf[:max]


def convert_padding_direction(
    src_tokens, padding_idx, right_to_left: bool = False, left_to_right: bool = False
):
    assert right_to_left ^ left_to_right
    pad_mask = src_tokens.eq(padding_idx)
    if not pad_mask.any():
        # no padding, return early
        return src_tokens
    if left_to_right and not pad_mask[:, 0].any():
        # already right padded
        return src_tokens
    if right_to_left and not pad_mask[:, -1].any():
        # already left padded
        return src_tokens
    max_len = src_tokens.size(1)
    buffered = torch.empty(0).long()
    if max_len > 0:
        torch.arange(max_len, out=buffered)
    range = buffered.type_as(src_tokens).expand_as(src_tokens)
    num_pads = pad_mask.long().sum(dim=1, keepdim=True)
    if right_to_left:
        index = torch.remainder(range - num_pads, max_len)
    else:
        index = torch.remainder(range + num_pads, max_len)
    return src_tokens.gather(1, index)
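
# Illustrative example (computed from the functions above, not from the upstream
# file): with padding_idx=1, a left-padded row is rotated into right padding, and
# strip_pad drops the pad entries entirely.
#   convert_padding_direction(torch.tensor([[1, 1, 5, 6]]), padding_idx=1, left_to_right=True)
#   # -> tensor([[5, 6, 1, 1]])
#   strip_pad(torch.tensor([7, 8, 1, 1]), pad=1)
#   # -> tensor([7, 8])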


def item(tensor):
    # tpu-comment: making this a no-op for xla devices.
    if torch.is_tensor(tensor) and tensor.device.type == "xla":
        return tensor.detach()
    if hasattr(tensor, "item"):
        return tensor.item()
    if hasattr(tensor, "__getitem__"):
        return tensor[0]
    return tensor


def multi_tensor_total_norm(grads, chunk_size=2048 * 32) -> torch.Tensor:
    per_device_grads = {}
    norms = []
    for grad in grads:
        device = grad.device
        cur_device_grads = per_device_grads.get(device)
        if cur_device_grads is None:
            cur_device_grads = []
            per_device_grads[device] = cur_device_grads
        cur_device_grads.append(grad)
    for device in per_device_grads.keys():
        cur_device_grads = per_device_grads[device]
        if device.type == "cuda":
            # TODO(msb) return has_inf
            has_inf = torch.zeros((1, 1), dtype=torch.int, device=device)
            with torch.cuda.device(device):
                norm = multi_tensor_l2norm(
                    chunk_size, has_inf, [cur_device_grads], False
                )
                norms.append(norm[0].to(torch.cuda.current_device()))
        else:
            norms += [torch.norm(g, p=2, dtype=torch.float32) for g in cur_device_grads]
    total_norm = torch.norm(torch.stack(norms))
    return total_norm


@torch.no_grad()
def clip_grad_norm_(params, max_norm, aggregate_norm_fn=None) -> torch.Tensor:
    def grad_exists(p):
        return p is not None and getattr(p, "grad", None) is not None

    if isinstance(params, torch.Tensor):
        params = [params]
    params = list(params)
    grads = [
        p.grad.detach() for p in params if grad_exists(p) and not hasattr(p, "expert")
    ]
    expert_grads = [
        p.grad.detach() for p in params if grad_exists(p) and hasattr(p, "expert")
    ]

    if len(grads) == 0:
        if len(params) > 0:
            return params[0].new_tensor(0.0)
        else:
            return torch.tensor(0.0)

    if len(grads) == 1:
        total_norm = torch.norm(grads[0], p=2, dtype=torch.float32)
    else:
        if multi_tensor_l2norm_available:
            total_norm = multi_tensor_total_norm(grads)
        else:
            if torch.cuda.is_available():
                warnings.warn(
                    "amp_C fused kernels unavailable, disabling multi_tensor_l2norm; "
                    "you may get better performance by installing NVIDIA's apex library"
                )
                device = torch.cuda.current_device()
            elif grads[0].device.type == "xla":
                device = grads[0].device
            else:
                device = torch.device("cpu")
            total_norm = torch.norm(
                torch.stack(
                    [torch.norm(g, p=2, dtype=torch.float32).to(device) for g in grads]
                )
            )

    if aggregate_norm_fn is not None:
        total_norm = aggregate_norm_fn(total_norm)

    if max_norm > 0:
        max_norm = float(max_norm)
        clip_coef = (max_norm / (total_norm + 1e-6)).clamp_(max=1)
        for g in grads + expert_grads:
            g.mul_(clip_coef)
    return total_norm
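
# Illustrative training-loop usage (the `model`, `loss` and `optimizer` names are
# hypothetical, not taken from the upstream file): clip_grad_norm_ rescales the
# gradients in place and returns the pre-clipping total norm.
#   loss.backward()
#   grad_norm = clip_grad_norm_(model.parameters(), max_norm=1.0)
#   optimizer.step()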


def fill_with_neg_inf(t):
    """FP16-compatible function that fills a tensor with -inf."""
    return t.float().fill_(float("-inf")).type_as(t)


def _match_types(arg1, arg2):
    """Convert the numerical argument to the same type as the other argument"""

    def upgrade(arg_number, arg_structure):
        if isinstance(arg_structure, tuple):
            return tuple([arg_number] * len(arg_structure))
        elif isinstance(arg_structure, dict):
            arg = copy.deepcopy(arg_structure)
            for k in arg:
                arg[k] = upgrade(arg_number, arg_structure[k])
            return arg
        else:
            return arg_number

    if isinstance(arg1, float) or isinstance(arg1, int):
        return upgrade(arg1, arg2), arg2
    elif isinstance(arg2, float) or isinstance(arg2, int):
        return arg1, upgrade(arg2, arg1)

    return arg1, arg2


def resolve_max_positions(*args):
    """Resolve max position constraints from multiple sources."""

    def map_value_update(d1, d2):
        updated_value = copy.deepcopy(d1)
        for key in d2:
            if key not in updated_value:
                updated_value[key] = d2[key]
            else:
                updated_value[key] = min(d1[key], d2[key])
        return updated_value

    def nullsafe_min(l):
        minim = None
        for item in l:
            if minim is None:
                minim = item
            elif item is not None and item < minim:
                minim = item
        return minim

    max_positions = None
    for arg in args:
        if max_positions is None:
            max_positions = arg
        elif arg is not None:
            max_positions, arg = _match_types(max_positions, arg)
            if isinstance(arg, float) or isinstance(arg, int):
                max_positions = min(max_positions, arg)
            elif isinstance(arg, dict):
                max_positions = map_value_update(max_positions, arg)
            else:
                max_positions = tuple(map(nullsafe_min, zip(max_positions, arg)))

    return max_positions
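
# Illustrative examples (computed from the function above, not from the upstream
# file): constraints are combined element-wise, and scalars are broadcast to the
# structure of the other argument via _match_types.
#   resolve_max_positions((1024, 1024), (512, 1024))  # -> (512, 1024)
#   resolve_max_positions((1024, 1024), 512)          # -> (512, 512)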


def import_user_module(args):
    module_path = getattr(args, "user_dir", None)
    if module_path is not None:
        module_path = os.path.abspath(args.user_dir)
        if not os.path.exists(module_path) and not os.path.isfile(
            os.path.dirname(module_path)
        ):
            fairseq_rel_path = os.path.join(os.path.dirname(__file__), args.user_dir)
            if os.path.exists(fairseq_rel_path):
                module_path = fairseq_rel_path
            else:
                fairseq_rel_path = os.path.join(
                    os.path.dirname(__file__), "..", args.user_dir
                )
                if os.path.exists(fairseq_rel_path):
                    module_path = fairseq_rel_path
                else:
                    raise FileNotFoundError(module_path)

        # ensure that user modules are only imported once
        import_user_module.memo = getattr(import_user_module, "memo", set())
        if module_path not in import_user_module.memo:
            import_user_module.memo.add(module_path)

            module_parent, module_name = os.path.split(module_path)
            if module_name not in sys.modules:
                sys.path.insert(0, module_parent)
                importlib.import_module(module_name)

                tasks_path = os.path.join(module_path, "tasks")
                if os.path.exists(tasks_path):
                    from fairseq.tasks import import_tasks

                    import_tasks(tasks_path, f"{module_name}.tasks")

                models_path = os.path.join(module_path, "models")
                if os.path.exists(models_path):
                    from fairseq.models import import_models

                    import_models(models_path, f"{module_name}.models")
            elif module_path in sys.modules[module_name].__path__:
                logger.info(f"--user-dir={module_path} has already been imported.")
            else:
                raise ImportError(
                    "Failed to import --user-dir={} because the corresponding module name "
                    "({}) is not globally unique. Please rename the directory to "
                    "something unique and try again.".format(module_path, module_name)
                )


def softmax(x, dim: int, onnx_trace: bool = False):
    if onnx_trace:
        return F.softmax(x.float(), dim=dim)
    else:
        return F.softmax(x, dim=dim, dtype=torch.float32)


def log_softmax(x, dim: int, onnx_trace: bool = False):
    if onnx_trace:
        return F.log_softmax(x.float(), dim=dim)
    else:
        return F.log_softmax(x, dim=dim, dtype=torch.float32)


def get_perplexity(loss, round=2, base=2):
    from fairseq.logging.meters import safe_round

    if loss is None:
        return 0.0
    try:
        return safe_round(base**loss, round)
    except OverflowError:
        return float("inf")
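
# Illustrative example (computed from the function above, not from the upstream
# file): perplexity is base ** loss, so a base-2 cross-entropy of 1.0 gives
#   get_perplexity(1.0)   # -> 2.0
#   get_perplexity(None)  # -> 0.0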


def deprecation_warning(message, stacklevel=3):
    # don't use DeprecationWarning, since it's ignored by default
    warnings.warn(message, stacklevel=stacklevel)


def relu_squared(x: torch.Tensor):
    return F.relu(x).pow(2)


def get_activation_fn(activation: str) -> Callable:
    """Returns the activation function corresponding to `activation`"""
    from fairseq.modules import gelu, gelu_accurate

    if activation == "relu":
        return F.relu
    elif activation == "relu_squared":
        return relu_squared
    elif activation == "gelu":
        return gelu
    elif activation == "gelu_fast":
        deprecation_warning(
            "--activation-fn=gelu_fast has been renamed to gelu_accurate"
        )
        return gelu_accurate
    elif activation == "gelu_accurate":
        return gelu_accurate
    elif activation == "tanh":
        return torch.tanh
    elif activation == "linear":
        return lambda x: x
    elif activation == "swish":
        return torch.nn.SiLU
    else:
        raise RuntimeError("--activation-fn {} not supported".format(activation))
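
# Illustrative usage (not from the upstream file): the returned callable is applied
# directly to activations inside a model's forward pass.
#   act_fn = get_activation_fn("relu")
#   y = act_fn(torch.randn(4))  # element-wise ReLU
# Note that the "swish" branch above returns the torch.nn.SiLU class rather than an
# instance, matching the behaviour of the code as written.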


def get_available_activation_fns() -> List:
    return [
        "relu",
        "gelu",
        "gelu_fast",  # deprecated
        "gelu_accurate",
        "tanh",
        "linear",
    ]


@contextlib.contextmanager
def model_eval(model):
    is_training = model.training
    model.eval()
    yield
    model.train(is_training)


def has_parameters(module):
    try:
        next(module.parameters())
        return True
    except StopIteration:
        return False


def get_rng_state():
    state = {"torch_rng_state": torch.get_rng_state()}
    if xm is not None:
        state["xla_rng_state"] = xm.get_rng_state()
    if torch.cuda.is_available():
        state["cuda_rng_state"] = torch.cuda.get_rng_state()
    return state


def set_rng_state(state):
    torch.set_rng_state(state["torch_rng_state"])
    if xm is not None:
        xm.set_rng_state(state["xla_rng_state"])
    if torch.cuda.is_available():
        torch.cuda.set_rng_state(state["cuda_rng_state"])


class set_torch_seed(object):
    def __init__(self, seed):
        assert isinstance(seed, int)
        self.rng_state = get_rng_state()

        torch.manual_seed(seed)
        if xm is not None:
            xm.set_rng_state(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        set_rng_state(self.rng_state)
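
# Illustrative usage (not from the upstream file): set_torch_seed is a context
# manager that seeds torch (and CUDA/XLA when present) for the duration of the
# block, then restores the previous RNG state on exit.
#   with set_torch_seed(42):
#       noise = torch.randn(3)  # reproducible across runs
#   # prior RNG state is restored here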


def parse_alignment(line):
    """
    Parses a single line from the alignment file.

    Args:
        line (str): String containing the alignment of the format:
            <src_idx_1>-<tgt_idx_1> <src_idx_2>-<tgt_idx_2> ..
            <src_idx_m>-<tgt_idx_m>. All indices are 0 indexed.

    Returns:
        torch.IntTensor: packed alignments of shape (2 * m).
    """
    alignments = line.strip().split()
    parsed_alignment = torch.IntTensor(2 * len(alignments))
    for idx, alignment in enumerate(alignments):
        src_idx, tgt_idx = alignment.split("-")
        parsed_alignment[2 * idx] = int(src_idx)
        parsed_alignment[2 * idx + 1] = int(tgt_idx)
    return parsed_alignment
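
# Illustrative example (computed from the function above, not from the upstream
# file): each "src-tgt" pair is flattened into consecutive slots of the tensor.
#   parse_alignment("0-0 1-2 2-1")
#   # -> tensor([0, 0, 1, 2, 2, 1], dtype=torch.int32)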


def get_token_to_word_mapping(tokens, exclude_list):
    n = len(tokens)
    word_start = [int(token not in exclude_list) for token in tokens]
    word_idx = list(accumulate(word_start))
    token_to_word = {i: word_idx[i] for i in range(n)}
    return token_to_word


def extract_hard_alignment(attn, src_sent, tgt_sent, pad, eos):
    tgt_valid = (
        ((tgt_sent != pad) & (tgt_sent != eos)).nonzero(as_tuple=False).squeeze(dim=-1)
    )
    src_invalid = (
        ((src_sent == pad) | (src_sent == eos)).nonzero(as_tuple=False).squeeze(dim=-1)
    )
    src_token_to_word = get_token_to_word_mapping(src_sent, [eos, pad])
    tgt_token_to_word = get_token_to_word_mapping(tgt_sent, [eos, pad])
    alignment = []
    if len(tgt_valid) != 0 and len(src_invalid) < len(src_sent):
        attn_valid = attn[tgt_valid]
        attn_valid[:, src_invalid] = float("-inf")
        _, src_indices = attn_valid.max(dim=1)
        for tgt_idx, src_idx in zip(tgt_valid, src_indices):
            alignment.append(
                (
                    src_token_to_word[src_idx.item()] - 1,
                    tgt_token_to_word[tgt_idx.item()] - 1,
                )
            )
    return alignment


def extract_soft_alignment(attn, src_sent, tgt_sent, pad, eos):
    tgt_valid = ((tgt_sent != pad)).nonzero(as_tuple=False)
    src_valid = ((src_sent != pad)).nonzero(as_tuple=False).squeeze(dim=-1)
    alignment = []
    if len(tgt_valid) != 0 and len(src_valid) != 0:
        attn_valid = attn[tgt_valid, src_valid]
        alignment = [
            ["{:.6f}".format(p) for p in src_probs.tolist()] for src_probs in attn_valid
        ]
    return alignment


def new_arange(x, *size):
    """
    Return a Tensor of `size` filled with a range function on the device of x.
    If size is empty, using the size of the variable x.
    """
    if len(size) == 0:
        size = x.size()
    return torch.arange(size[-1], device=x.device).expand(*size).contiguous()
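
# Illustrative example (computed from the function above, not from the upstream
# file): the range is built over the last dimension and broadcast to the rest.
#   x = torch.zeros(2, 3)
#   new_arange(x, 2, 3)
#   # -> tensor([[0, 1, 2],
#   #            [0, 1, 2]])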


def get_tpu_device():
    return xm.xla_device()


def tpu_data_loader(itr):
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.parallel_loader as pl

    from fairseq.data import iterators

    xm.rendezvous("tpu_data_loader")  # wait for all workers
    xm.mark_step()
    device = xm.xla_device()
    return iterators.CountingIterator(
        pl.ParallelLoader(itr, [device]).per_device_loader(device),
        start=getattr(itr, "n", 0),
        total=len(itr),
    )


def is_xla_tensor(tensor):
    return torch.is_tensor(tensor) and tensor.device.type == "xla"


def index_put(tensor, indices, value):
    if is_xla_tensor(tensor):
        for _ in range(indices.dim(), tensor.dim()):
            indices = indices.unsqueeze(-1)
        if indices.size(-1) < tensor.size(-1):
            indices = indices.expand_as(tensor)
        tensor = torch.mul(tensor, ~indices) + torch.mul(value, indices)
    else:
        tensor[indices] = value
    return tensor


def xla_device_to_cpu(dat):
    import torch_xla.core.xla_model as xm

    return xm._maybe_convert_to_cpu(dat)


class CudaEnvironment(object):
    def __init__(self):
        cur_device = torch.cuda.current_device()
        prop = torch.cuda.get_device_properties("cuda:{}".format(cur_device))
        self.name = prop.name
        self.major = prop.major
        self.minor = prop.minor
        self.total_memory_in_GB = prop.total_memory / 1024 / 1024 / 1024

    @staticmethod
    def pretty_print_cuda_env_list(cuda_env_list):
        """
        Given a list of CudaEnvironments, pretty print them
        """
        num_workers = len(cuda_env_list)
        center = "CUDA environments for all {} workers".format(num_workers)
        banner_len = 40 - len(center) // 2
        first_line = "*" * banner_len + center + "*" * banner_len
        logger.info(first_line)
        for r, env in enumerate(cuda_env_list):
            logger.info(
                "rank {:3d}: ".format(r)
                + "capabilities = {:2d}.{:<2d} ; ".format(env.major, env.minor)
                + "total memory = {:.3f} GB ; ".format(env.total_memory_in_GB)
                + "name = {:40s}".format(env.name)
            )
        logger.info(first_line)


def csv_str_list(x):
    return x.split(",")


def eval_str_list(x, type=float):
    if x is None:
        return None
    if isinstance(x, str):
        x = eval(x)
    try:
        return list(map(type, x))
    except TypeError:
        return [type(x)]


def eval_str_dict(x, type=dict):
    if x is None:
        return None
    if isinstance(x, str):
        x = eval(x)
    return x


def eval_bool(x, default=False):
    if x is None:
        return default
    try:
        return bool(eval(x))
    except TypeError:
        return default
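
# Illustrative examples (computed from the functions above, not from the upstream
# file): these helpers eval() argparse-style string arguments, so scalars are
# wrapped into lists and booleans are parsed from Python literals.
#   eval_str_list("[0.1, 0.2]", type=float)  # -> [0.1, 0.2]
#   eval_str_list("10", type=int)            # -> [10]
#   eval_bool("True")                        # -> True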


def reset_logging():
    root = logging.getLogger()
    for handler in root.handlers:
        root.removeHandler(handler)
    root.setLevel(os.environ.get("LOGLEVEL", "INFO").upper())
    handler = logging.StreamHandler(sys.stdout)
    handler.setFormatter(
        logging.Formatter(
            fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )
    )
    root.addHandler(handler)


def safe_getattr(obj, k, default=None):
    """Returns obj[k] if it exists and is not None, otherwise returns default."""
    from omegaconf import OmegaConf

    if OmegaConf.is_config(obj):
        return obj[k] if k in obj and obj[k] is not None else default

    return getattr(obj, k, default)


def safe_hasattr(obj, k):
    """Returns True if the given key exists and is not None."""
    return getattr(obj, k, None) is not None


def hotreload_function(name=None):
    """
    Decorator to enable hot-reloading of a function for debugging.
    It allows you to debug a function without having to reload heavy models, dataset loading and
    preprocessing, allowing faster debugging.
    If you want to change model or dataset loading, consider relaunching your code.
    -----------------------------------
    This will run the decorated function func:
    if func runs successfully:
        It will pause, allow the user to edit code, and prompt the user to:
            Press enter to re-run the function with updated code
            Type "done" to finish the function and return the output
            Type "disable" to stop pausing this function and let code continue without pause
            Ctrl + C to terminate
    if func raises an error:
        it will prompt the user to
            1. Edit code, and press enter to retry
            2. Ctrl + C to terminate
            3. Type "raise" to raise that exception
    * Requirements:
        0. Fairseq was installed with `pip install --editable .`
        1. pip install jurigged[develoop]
        2. set the environment variables HOTRELOAD_PAUSE=1 CUDA_LAUNCH_BLOCKING=1
        3. Run on only 1 GPU (no distributed)
    * How to use:
        1. in python, import and decorate the top-level function to be re-run after code edits:
            ```python
            from fairseq.utils import hotreload_function
            ....
            @hotreload_function("train_step")
            def train_step(self, sample ....):
                ....
            ....
            ```
        2. in the bash run script:
            ```bash
            watch_dir=<home>/fairseq-py/fairseq/tasks # directory to watch for file changes
            export CUDA_VISIBLE_DEVICES=0 # single-gpu
            HOTRELOAD_PAUSE=1 CUDA_LAUNCH_BLOCKING=1 python -m jurigged -w ${watch_dir} --poll 2 -v train.py ......
            ```
    * NOTE:
        1. -w ${watch_dir} specifies all the files to be watched for changes
            once functions, classes, ... are changed, all instances in the process will be updated (hot-reload)
    * Limitations:
        * Distributed debugging does not currently work
        * Need to launch train.py locally (cannot submit jobs)
    """
    try:
        import jurigged
    except ImportError as e:
        logger.warning("Please install jurigged: pip install jurigged[develoop]")
        raise e
    from fairseq.distributed import utils as distributed_utils
    import traceback

    def hotreload_decorator(func):
        assert callable(func), f"not callable: {func}"
        jname = name or func.__name__
        logger.info(f"jurigged-hotreload:Apply jurigged on {jname}:{func.__name__}")
        HOTRELOAD_PAUSE = bool(os.environ.get("HOTRELOAD_PAUSE", 0))
        cublk = bool(os.environ.get("CUDA_LAUNCH_BLOCKING", 0))
        prefix = f"HOTRELOAD:{jname}:[cublk={cublk}]"
        hot_reload_state = {"disable": False}

        def func_wrapper(*args, **kwargs):
            if not HOTRELOAD_PAUSE or hot_reload_state["disable"]:
                return func(*args, **kwargs)
            world_size = distributed_utils.get_global_world_size()
            assert (
                world_size <= 1
            ), f"HOTRELOAD_PAUSE:{jname} currently cannot do distributed training"
            success = False
            while not success:
                try:
                    output = func(*args, **kwargs)
                    # success = True
                    end_action = input(
                        f"{prefix}: PAUSE, you may edit code now. Enter to re-run, ctrl+C to terminate, "
                        f'type "done" to continue (function still being watched), or type "disable" to stop pausing this function :'
                    )
                    if end_action.strip().lower() in ["disable", "done"]:
                        success = True
                    else:
                        logger.warning(
                            f"{prefix}: action={end_action} function will re-run now."
                        )
                except Exception as e:
                    action = input(
                        f"{prefix}:ERROR: \n{traceback.format_exc()}\n"
                        f'Edit code to try again: enter to continue, ctrl+C to terminate, or type "raise" to raise the exception: '
                    )
                    if action.strip().lower() == "raise":
                        raise e

            if end_action.strip().lower() == "disable":
                logger.warning(
                    f"{prefix}: Stop pausing {jname}. The function is still being watched and newly edited code will take effect "
                    f"if the {jname} is called again later."
                    f' "unset HOTRELOAD_PAUSE" before relaunch to disable hotreload and'
                    f" remove @hotreload_function decorator in the code."
                )
                hot_reload_state["disable"] = True
            return output

        return func_wrapper

    return hotreload_decorator