# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

try:
    from collections.abc import Iterable
except ImportError:
    from collections import Iterable
import contextlib
import itertools
import logging
import re
import warnings
from typing import Optional, Tuple

import numpy as np
import torch

from fairseq.file_io import PathManager
from fairseq import utils
import os

logger = logging.getLogger(__name__)


def infer_language_pair(path):
    """Infer language pair from filename: <split>.<lang1>-<lang2>.(...).idx"""
    src, dst = None, None
    for filename in PathManager.ls(path):
        parts = filename.split(".")
        if len(parts) >= 3 and len(parts[1].split("-")) == 2:
            return parts[1].split("-")
    return src, dst


def collate_tokens(
    values,
    pad_idx,
    eos_idx=None,
    left_pad=False,
    move_eos_to_beginning=False,
    pad_to_length=None,
    pad_to_multiple=1,
    pad_to_bsz=None,
):
    """Convert a list of 1d tensors into a padded 2d tensor."""
    size = max(v.size(0) for v in values)
    size = size if pad_to_length is None else max(size, pad_to_length)
    if pad_to_multiple != 1 and size % pad_to_multiple != 0:
        size = int(((size - 0.1) // pad_to_multiple + 1) * pad_to_multiple)

    batch_size = len(values) if pad_to_bsz is None else max(len(values), pad_to_bsz)
    res = values[0].new(batch_size, size).fill_(pad_idx)

    def copy_tensor(src, dst):
        assert dst.numel() == src.numel()
        if move_eos_to_beginning:
            if eos_idx is None:
                # if no eos_idx is specified, then use the last token in src
                dst[0] = src[-1]
            else:
                dst[0] = eos_idx
            dst[1:] = src[:-1]
        else:
            dst.copy_(src)

    for i, v in enumerate(values):
        copy_tensor(v, res[i][size - len(v) :] if left_pad else res[i][: len(v)])
    return res


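# Illustrative usage sketch (not part of the upstream module): pads two
# sequences of different lengths into a single 2d batch. The tensors and the
# pad_idx value below are made up for demonstration; the helper is never called
# at import time.
def _example_collate_tokens():
    values = [torch.tensor([5, 6, 7, 2]), torch.tensor([8, 9, 2])]
    batch = collate_tokens(values, pad_idx=1)
    # batch is a 2x4 tensor; the shorter sequence is right-padded with 1:
    # [[5, 6, 7, 2],
    #  [8, 9, 2, 1]]
    return batch

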
def load_indexed_dataset(
    path, dictionary=None, dataset_impl=None, combine=False, default="cached"
):
    """A helper function for loading indexed datasets.

    Args:
        path (str): path to indexed dataset (e.g., 'data-bin/train')
        dictionary (~fairseq.data.Dictionary): data dictionary
        dataset_impl (str, optional): which dataset implementation to use. If
            not provided, it will be inferred automatically. For legacy indexed
            data we use the 'cached' implementation by default.
        combine (bool, optional): automatically load and combine multiple
            datasets. For example, if *path* is 'data-bin/train', then we will
            combine 'data-bin/train', 'data-bin/train1', ... and return a
            single ConcatDataset instance.
    """
    import fairseq.data.indexed_dataset as indexed_dataset
    from fairseq.data.concat_dataset import ConcatDataset

    datasets = []
    for k in itertools.count():
        path_k = path + (str(k) if k > 0 else "")
        try:
            path_k = indexed_dataset.get_indexed_dataset_to_local(path_k)
        except Exception as e:
            if "StorageException: [404] Path not found" in str(e):
                logger.warning(f"path_k: {e} not found")
            else:
                raise e

        dataset_impl_k = dataset_impl
        if dataset_impl_k is None:
            dataset_impl_k = indexed_dataset.infer_dataset_impl(path_k)
        dataset = indexed_dataset.make_dataset(
            path_k,
            impl=dataset_impl_k or default,
            fix_lua_indexing=True,
            dictionary=dictionary,
        )
        if dataset is None:
            break
        logger.info("loaded {:,} examples from: {}".format(len(dataset), path_k))
        datasets.append(dataset)
        if not combine:
            break
    if len(datasets) == 0:
        return None
    elif len(datasets) == 1:
        return datasets[0]
    else:
        return ConcatDataset(datasets)


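# Illustrative sketch (not part of the upstream module): loading a binarized
# split by prefix. It assumes data previously written by fairseq-preprocess
# under the hypothetical prefix 'data-bin/train' and a matching dictionary, so
# it only shows the call shape.
def _example_load_indexed_dataset(dictionary):
    dataset = load_indexed_dataset("data-bin/train", dictionary, combine=True)
    if dataset is None:
        raise FileNotFoundError("no indexed dataset found at data-bin/train")
    return dataset

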
@contextlib.contextmanager
def numpy_seed(seed, *addl_seeds):
    """Context manager which seeds the NumPy PRNG with the specified seed and
    restores the state afterward"""
    if seed is None:
        yield
        return
    if len(addl_seeds) > 0:
        seed = int(hash((seed, *addl_seeds)) % 1e6)
    state = np.random.get_state()
    np.random.seed(seed)
    try:
        yield
    finally:
        np.random.set_state(state)


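# Illustrative sketch (not part of the upstream module): draws inside the
# context are reproducible, and the global NumPy RNG state is restored on exit.
def _example_numpy_seed():
    with numpy_seed(42, 0):  # extra seeds (e.g. an epoch index) are hashed in
        first = np.random.randint(0, 100, size=3)
    with numpy_seed(42, 0):
        second = np.random.randint(0, 100, size=3)
    assert (first == second).all()

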
def collect_filtered(function, iterable, filtered):
    """
    Similar to :func:`filter` but collects filtered elements in ``filtered``.

    Args:
        function (callable): function that returns ``False`` for elements that
            should be filtered
        iterable (iterable): iterable to filter
        filtered (list): list to store filtered elements
    """
    for el in iterable:
        if function(el):
            yield el
        else:
            filtered.append(el)


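# Illustrative sketch (not part of the upstream module): elements rejected by
# the predicate are collected in `dropped` instead of being discarded silently.
def _example_collect_filtered():
    dropped = []
    kept = list(collect_filtered(lambda x: x % 2 == 0, range(6), dropped))
    # kept == [0, 2, 4]; dropped == [1, 3, 5]
    return kept, dropped

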
def _filter_by_size_dynamic(indices, size_fn, max_positions, raise_exception=False):
    def compare_leq(a, b):
        return a <= b if not isinstance(a, tuple) else max(a) <= b

    def check_size(idx):
        if isinstance(max_positions, float) or isinstance(max_positions, int):
            return size_fn(idx) <= max_positions
        elif isinstance(max_positions, dict):
            idx_size = size_fn(idx)
            assert isinstance(idx_size, dict)
            intersect_keys = set(max_positions.keys()) & set(idx_size.keys())
            return all(
                all(
                    a is None or b is None or a <= b
                    for a, b in zip(idx_size[key], max_positions[key])
                )
                for key in intersect_keys
            )
        else:
            # For MultiCorpusSampledDataset, will generalize it later
            if not isinstance(size_fn(idx), Iterable):
                return all(size_fn(idx) <= b for b in max_positions)
            return all(
                a is None or b is None or a <= b
                for a, b in zip(size_fn(idx), max_positions)
            )

    ignored = []
    itr = collect_filtered(check_size, indices, ignored)
    indices = np.fromiter(itr, dtype=np.int64, count=-1)
    return indices, ignored


def filter_by_size(indices, dataset, max_positions, raise_exception=False):
    """
    [deprecated] Filter indices based on their size.
    Use `FairseqDataset::filter_indices_by_size` instead.

    Args:
        indices (List[int]): ordered list of dataset indices
        dataset (FairseqDataset): fairseq dataset instance
        max_positions (tuple): filter elements larger than this size.
            Comparisons are done component-wise.
        raise_exception (bool, optional): if ``True``, raise an exception if
            any elements are filtered (default: False).
    """
    warnings.warn(
        "data_utils.filter_by_size is deprecated. "
        "Use `FairseqDataset::filter_indices_by_size` instead.",
        stacklevel=2,
    )
    if isinstance(max_positions, float) or isinstance(max_positions, int):
        if hasattr(dataset, "sizes") and isinstance(dataset.sizes, np.ndarray):
            ignored = indices[dataset.sizes[indices] > max_positions].tolist()
            indices = indices[dataset.sizes[indices] <= max_positions]
        elif (
            hasattr(dataset, "sizes")
            and isinstance(dataset.sizes, list)
            and len(dataset.sizes) == 1
        ):
            ignored = indices[dataset.sizes[0][indices] > max_positions].tolist()
            indices = indices[dataset.sizes[0][indices] <= max_positions]
        else:
            indices, ignored = _filter_by_size_dynamic(
                indices, dataset.size, max_positions
            )
    else:
        indices, ignored = _filter_by_size_dynamic(indices, dataset.size, max_positions)

    if len(ignored) > 0 and raise_exception:
        raise Exception(
            (
                "Size of sample #{} is invalid (={}) since max_positions={}, "
                "skip this example with --skip-invalid-size-inputs-valid-test"
            ).format(ignored[0], dataset.size(ignored[0]), max_positions)
        )
    if len(ignored) > 0:
        logger.warning(
            (
                "{} samples have invalid sizes and will be skipped, "
                "max_positions={}, first few sample ids={}"
            ).format(len(ignored), max_positions, ignored[:10])
        )
    return indices


def filter_paired_dataset_indices_by_size(src_sizes, tgt_sizes, indices, max_sizes):
    """Filter a list of sample indices. Remove those that are longer
    than specified in max_sizes.

    Args:
        indices (np.array): original array of sample indices
        max_sizes (int or list[int] or tuple[int]): max sample size,
            can be defined separately for src and tgt (then list or tuple)

    Returns:
        np.array: filtered sample array
        list: list of removed indices
    """
    if max_sizes is None:
        return indices, []
    if type(max_sizes) in (int, float):
        max_src_size, max_tgt_size = max_sizes, max_sizes
    else:
        max_src_size, max_tgt_size = max_sizes
    if tgt_sizes is None:
        ignored = indices[src_sizes[indices] > max_src_size]
    else:
        ignored = indices[
            (src_sizes[indices] > max_src_size) | (tgt_sizes[indices] > max_tgt_size)
        ]
    if len(ignored) > 0:
        if tgt_sizes is None:
            indices = indices[src_sizes[indices] <= max_src_size]
        else:
            indices = indices[
                (src_sizes[indices] <= max_src_size)
                & (tgt_sizes[indices] <= max_tgt_size)
            ]
    return indices, ignored.tolist()


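# Illustrative sketch (not part of the upstream module): filtering paired
# source/target indices against per-side limits; the size arrays are made up.
def _example_filter_paired_indices():
    src_sizes = np.array([3, 12, 5, 9])
    tgt_sizes = np.array([4, 6, 15, 8])
    indices = np.arange(4)
    kept, removed = filter_paired_dataset_indices_by_size(
        src_sizes, tgt_sizes, indices, max_sizes=(10, 10)
    )
    # kept == array([0, 3]); removed == [1, 2] (too long on src and tgt side)
    return kept, removed

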
def batch_by_size(
    indices,
    num_tokens_fn,
    num_tokens_vec=None,
    max_tokens=None,
    max_sentences=None,
    required_batch_size_multiple=1,
    fixed_shapes=None,
):
    """
    Yield mini-batches of indices bucketed by size. Batches may contain
    sequences of different lengths.

    Args:
        indices (List[int]): ordered list of dataset indices
        num_tokens_fn (callable): function that returns the number of tokens at
            a given index
        num_tokens_vec (List[int], optional): precomputed vector of the number
            of tokens for each index in indices (to enable faster batch generation)
        max_tokens (int, optional): max number of tokens in each batch
            (default: None).
        max_sentences (int, optional): max number of sentences in each
            batch (default: None).
        required_batch_size_multiple (int, optional): require batch size to
            be less than N or a multiple of N (default: 1).
        fixed_shapes (List[Tuple[int, int]], optional): if given, batches will
            only be created with the given shapes. *max_sentences* and
            *required_batch_size_multiple* will be ignored (default: None).
    """
    try:
        from fairseq.data.data_utils_fast import (
            batch_by_size_fn,
            batch_by_size_vec,
            batch_fixed_shapes_fast,
        )
    except ImportError:
        raise ImportError(
            "Please build Cython components with: "
            "`python setup.py build_ext --inplace`"
        )
    except ValueError:
        raise ValueError(
            "Please build (or rebuild) Cython components with `python setup.py build_ext --inplace`."
        )

    # added int() to avoid TypeError: an integer is required
    max_tokens = int(max_tokens) if max_tokens is not None else -1
    max_sentences = max_sentences if max_sentences is not None else -1
    bsz_mult = required_batch_size_multiple

    if not isinstance(indices, np.ndarray):
        indices = np.fromiter(indices, dtype=np.int64, count=-1)

    if num_tokens_vec is not None and not isinstance(num_tokens_vec, np.ndarray):
        num_tokens_vec = np.fromiter(num_tokens_vec, dtype=np.int64, count=-1)

    if fixed_shapes is None:
        if num_tokens_vec is None:
            return batch_by_size_fn(
                indices,
                num_tokens_fn,
                max_tokens,
                max_sentences,
                bsz_mult,
            )
        else:
            return batch_by_size_vec(
                indices,
                num_tokens_vec,
                max_tokens,
                max_sentences,
                bsz_mult,
            )

    else:
        fixed_shapes = np.array(fixed_shapes, dtype=np.int64)
        sort_order = np.lexsort(
            [
                fixed_shapes[:, 1].argsort(),  # length
                fixed_shapes[:, 0].argsort(),  # bsz
            ]
        )
        fixed_shapes_sorted = fixed_shapes[sort_order]
        return batch_fixed_shapes_fast(indices, num_tokens_fn, fixed_shapes_sorted)


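# Illustrative sketch (not part of the upstream module): token-budgeted
# batching over a toy length table. It requires the compiled Cython helpers in
# fairseq.data.data_utils_fast, so it only demonstrates the call shape.
def _example_batch_by_size():
    lengths = np.array([7, 3, 5, 9, 2], dtype=np.int64)
    batches = batch_by_size(
        indices=np.arange(len(lengths)),
        num_tokens_fn=lambda i: int(lengths[i]),
        num_tokens_vec=lengths,
        max_tokens=12,
    )
    # `batches` is a list of index arrays; each batch is sized so that
    # (number of sentences) x (longest sentence in the batch) stays within
    # the 12-token budget.
    return batches

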
def post_process(sentence: str, symbol: str):
    if symbol == "sentencepiece":
        sentence = sentence.replace(" ", "").replace("\u2581", " ").strip()
    elif symbol == "wordpiece":
        sentence = sentence.replace(" ", "").replace("_", " ").strip()
    elif symbol == "letter":
        sentence = sentence.replace(" ", "").replace("|", " ").strip()
    elif symbol == "silence":
        import re

        sentence = sentence.replace("<SIL>", "")
        sentence = re.sub(" +", " ", sentence).strip()
    elif symbol == "_EOW":
        sentence = sentence.replace(" ", "").replace("_EOW", " ").strip()
    elif symbol in {"subword_nmt", "@@ ", "@@"}:
        if symbol == "subword_nmt":
            symbol = "@@ "
        sentence = (sentence + " ").replace(symbol, "").rstrip()
    elif symbol == "none":
        pass
    elif symbol is not None:
        raise NotImplementedError(f"Unknown post_process option: {symbol}")
    return sentence


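# Illustrative sketch (not part of the upstream module): undoing sentencepiece
# and subword-nmt (BPE) segmentation on made-up tokenized strings.
def _example_post_process():
    sp = post_process("\u2581hello \u2581wor ld", "sentencepiece")
    # sp == "hello world"
    bpe = post_process("he@@ llo wor@@ ld", "subword_nmt")
    # bpe == "hello world"
    return sp, bpe

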
def compute_mask_indices(
    shape: Tuple[int, int],
    padding_mask: Optional[torch.Tensor],
    mask_prob: float,
    mask_length: int,
    mask_type: str = "static",
    mask_other: float = 0.0,
    min_masks: int = 0,
    no_overlap: bool = False,
    min_space: int = 0,
    require_same_masks: bool = True,
    mask_dropout: float = 0.0,
) -> np.ndarray:
    """
    Computes random mask spans for a given shape

    Args:
        shape: the shape for which to compute masks.
            should be of size 2 where the first element is batch size and the 2nd is timesteps
        padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
        mask_prob: probability for each token to be chosen as start of the span to be masked. This will be multiplied by
            number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
            However, due to overlaps, the actual number will be smaller (unless no_overlap is True)
        mask_type: how to compute mask lengths
            static = fixed size
            uniform = sample from uniform distribution [mask_other, mask_length*2]
            normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
            poisson = sample from poisson distribution with lambda = mask_length
        min_masks: minimum number of masked spans
        no_overlap: if true, will switch to an alternative recursive algorithm that prevents spans from overlapping
        min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
        require_same_masks: if true, will randomly drop out masks until the same number of masks remains in each sample
        mask_dropout: randomly dropout this percentage of masks in each example
    """

    bsz, all_sz = shape
    mask = np.full((bsz, all_sz), False)

    all_num_mask = int(
        # add a random number for probabilistic rounding
        mask_prob * all_sz / float(mask_length)
        + np.random.rand()
    )

    all_num_mask = max(min_masks, all_num_mask)

    mask_idcs = []
    for i in range(bsz):
        if padding_mask is not None:
            sz = all_sz - padding_mask[i].long().sum().item()
            num_mask = int(
                # add a random number for probabilistic rounding
                mask_prob * sz / float(mask_length)
                + np.random.rand()
            )
            num_mask = max(min_masks, num_mask)
        else:
            sz = all_sz
            num_mask = all_num_mask

        if mask_type == "static":
            lengths = np.full(num_mask, mask_length)
        elif mask_type == "uniform":
            lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask)
        elif mask_type == "normal":
            lengths = np.random.normal(mask_length, mask_other, size=num_mask)
            lengths = [max(1, int(round(x))) for x in lengths]
        elif mask_type == "poisson":
            lengths = np.random.poisson(mask_length, size=num_mask)
            lengths = [int(round(x)) for x in lengths]
        else:
            raise Exception("unknown mask selection " + mask_type)

        if sum(lengths) == 0:
            lengths[0] = min(mask_length, sz - 1)

        if no_overlap:
            mask_idc = []

            def arrange(s, e, length, keep_length):
                span_start = np.random.randint(s, e - length)
                mask_idc.extend(span_start + i for i in range(length))

                new_parts = []
                if span_start - s - min_space >= keep_length:
                    new_parts.append((s, span_start - min_space + 1))
                if e - span_start - length - min_space > keep_length:
                    new_parts.append((span_start + length + min_space, e))
                return new_parts

            parts = [(0, sz)]
            min_length = min(lengths)
            for length in sorted(lengths, reverse=True):
                lens = np.fromiter(
                    (e - s if e - s >= length + min_space else 0 for s, e in parts),
                    int,  # np.int was removed in NumPy >= 1.24; the builtin int is equivalent here
                )
                l_sum = np.sum(lens)
                if l_sum == 0:
                    break
                probs = lens / np.sum(lens)
                c = np.random.choice(len(parts), p=probs)
                s, e = parts.pop(c)
                parts.extend(arrange(s, e, length, min_length))
            mask_idc = np.asarray(mask_idc)
        else:
            min_len = min(lengths)
            if sz - min_len <= num_mask:
                min_len = sz - num_mask - 1

            mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)

            mask_idc = np.asarray(
                [
                    mask_idc[j] + offset
                    for j in range(len(mask_idc))
                    for offset in range(lengths[j])
                ]
            )

        mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))

    min_len = min([len(m) for m in mask_idcs])
    for i, mask_idc in enumerate(mask_idcs):
        if len(mask_idc) > min_len and require_same_masks:
            mask_idc = np.random.choice(mask_idc, min_len, replace=False)
        if mask_dropout > 0:
            num_holes = np.rint(len(mask_idc) * mask_dropout).astype(int)
            mask_idc = np.random.choice(
                mask_idc, len(mask_idc) - num_holes, replace=False
            )

        mask[i, mask_idc] = True

    return mask


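# Illustrative sketch (not part of the upstream module): span masking over a
# toy batch of 2 sequences with 10 timesteps each. The numbers are made up;
# the result is a boolean (2, 10) numpy array.
def _example_compute_mask_indices():
    mask = compute_mask_indices(
        shape=(2, 10),
        padding_mask=None,
        mask_prob=0.5,
        mask_length=2,
        mask_type="static",
        min_masks=1,
    )
    # Roughly half of the 10 timesteps per row are covered by length-2 spans
    # (fewer when spans overlap), and every row masks the same number of steps
    # because require_same_masks defaults to True.
    return mask

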
def get_mem_usage():
    try:
        import psutil

        mb = 1024 * 1024
        return f"used={psutil.virtual_memory().used / mb}Mb; avail={psutil.virtual_memory().available / mb}Mb"
    except ImportError:
        return "N/A"


# lens: torch.LongTensor
# returns: torch.BoolTensor
def lengths_to_padding_mask(lens):
    bsz, max_lens = lens.size(0), torch.max(lens).item()
    mask = torch.arange(max_lens).to(lens.device).view(1, max_lens)
    mask = mask.expand(bsz, -1) >= lens.view(bsz, 1).expand(-1, max_lens)
    return mask


# lens: torch.LongTensor
# returns: torch.BoolTensor
def lengths_to_mask(lens):
    return ~lengths_to_padding_mask(lens)


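# Illustrative sketch (not part of the upstream module): converting a vector of
# sequence lengths into a boolean padding mask (True marks padded positions).
def _example_lengths_to_padding_mask():
    lens = torch.tensor([3, 1])
    pad_mask = lengths_to_padding_mask(lens)
    # pad_mask == tensor([[False, False, False],
    #                     [False, True,  True ]])
    assert lengths_to_mask(lens).equal(~pad_mask)
    return pad_mask

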
def get_buckets(sizes, num_buckets):
    buckets = np.unique(
        np.percentile(
            sizes,
            np.linspace(0, 100, num_buckets + 1),
            interpolation="lower",
        )[1:]
    )
    return buckets


def get_bucketed_sizes(orig_sizes, buckets):
    sizes = np.copy(orig_sizes)
    assert np.min(sizes) >= 0
    start_val = -1
    for end_val in buckets:
        mask = (sizes > start_val) & (sizes <= end_val)
        sizes[mask] = end_val
        start_val = end_val
    return sizes


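# Illustrative sketch (not part of the upstream module): quantizing a made-up
# set of sequence lengths into a small number of bucket boundaries, then
# snapping each length up to its bucket's upper edge.
def _example_bucketing():
    sizes = np.array([3, 7, 7, 12, 18, 25])
    buckets = get_buckets(sizes, num_buckets=3)
    bucketed = get_bucketed_sizes(sizes, buckets)
    # Every entry of `bucketed` is one of the values in `buckets` and is
    # always >= the original size it replaces.
    return buckets, bucketed

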
def _find_extra_valid_paths(dataset_path: str) -> set:
    paths = utils.split_paths(dataset_path)
    all_valid_paths = set()
    for sub_dir in paths:
        contents = PathManager.ls(sub_dir)
        valid_paths = [c for c in contents if re.match("valid*[0-9].*", c) is not None]
        all_valid_paths |= {os.path.basename(p) for p in valid_paths}
    # Remove .bin, .idx etc
    roots = {os.path.splitext(p)[0] for p in all_valid_paths}
    return roots


def raise_if_valid_subsets_unintentionally_ignored(train_cfg) -> None:
    """Raises if there are paths matching 'valid*[0-9].*' which are not combined or ignored."""
    if (
        train_cfg.dataset.ignore_unused_valid_subsets
        or train_cfg.dataset.combine_valid_subsets
        or train_cfg.dataset.disable_validation
        or not hasattr(train_cfg.task, "data")
    ):
        return
    other_paths = _find_extra_valid_paths(train_cfg.task.data)
    specified_subsets = train_cfg.dataset.valid_subset.split(",")
    ignored_paths = [p for p in other_paths if p not in specified_subsets]
    if ignored_paths:
        advice = "Set --combine-val to combine them or --ignore-unused-valid-subsets to ignore them."
        msg = f"Valid paths {ignored_paths} will be ignored. {advice}"
        raise ValueError(msg)