Mirror of https://github.com/SillyTavern/SillyTavern-Extras.git, synced 2026-05-05 22:01:21 +00:00
Add a monkey-patched fairseq package to run on Python 3.11 (at least the parts needed for our use of RVC)
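The general idea behind a patch like this is to ship a locally modified copy of the affected fairseq modules and have Python import that copy ahead of the version installed from PyPI. Below is a minimal sketch of one way such wiring could look; the directory name fairseq_patched and the sys.path approach are illustrative assumptions, not necessarily how this repository actually hooks the patched package in.

import sys
from pathlib import Path

# Hypothetical layout: the patched copy lives in a local "fairseq_patched/fairseq"
# directory next to this file. Putting its parent at the front of sys.path makes
# "import fairseq" resolve to the patched package instead of the installed one.
sys.path.insert(0, str(Path(__file__).resolve().parent / "fairseq_patched"))

import fairseq  # noqa: E402  -- now the locally patched copy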
@@ -0,0 +1,468 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import datetime
import hashlib
import logging
import time
from bisect import bisect_right
from collections import OrderedDict, defaultdict
from enum import Enum
from typing import List

import numpy as np
import torch

from fairseq.data import FairseqDataset, data_utils
from fairseq.distributed import utils as distributed_utils


def get_time_gap(s, e):
    return (
        datetime.datetime.fromtimestamp(e) - datetime.datetime.fromtimestamp(s)
    ).__str__()


logger = logging.getLogger(__name__)


def default_virtual_size_func(datasets, ratios, max_scale_up=1.5):
    sizes = [len(d) for d in datasets]
    if ratios is None:
        return sum(sizes)
    largest_idx = np.argmax(sizes)
    largest_r = ratios[largest_idx]
    largest_s = sizes[largest_idx]
    # set virtual sizes relative to the largest dataset
    virtual_sizes = [(r / largest_r) * largest_s for r in ratios]
    vsize = sum(virtual_sizes)
    max_size = sum(sizes) * max_scale_up
    return int(vsize if vsize < max_size else max_size)
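
# Editorial note, not part of the upstream fairseq source: a worked example of
# default_virtual_size_func, assuming two sub-datasets and equal sampling ratios.
#
#   sizes = [100, 1000], ratios = [0.5, 0.5], max_scale_up = 1.5
#   the largest dataset is index 1, so each virtual size is (0.5 / 0.5) * 1000 = 1000
#   vsize = 2000, but max_size = (100 + 1000) * 1.5 = 1650, so 1650 is returned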


class CollateFormat(Enum):
    single = 1
    ordered_dict = 2


class SampledMultiDataset(FairseqDataset):
    """Samples from multiple sub-datasets according to given sampling ratios.
    Args:
        datasets (
            List[~torch.utils.data.Dataset]
            or OrderedDict[str, ~torch.utils.data.Dataset]
        ): datasets
        sampling_ratios (List[float]): list of the probability of each dataset being sampled
            (default: None, which corresponds to concatenating all datasets together).
        seed (int): RNG seed to use (default: 2).
        epoch (int): starting epoch number (default: 1).
        eval_key (str, optional): a key used at evaluation time that causes
            this instance to pass through batches from *datasets[eval_key]*.
        collate_format (CollateFormat): collater output format, either CollateFormat.ordered_dict or
            CollateFormat.single (default: CollateFormat.single), where CollateFormat.single configures
            the collater to output batches of data mixed from all sub-datasets,
            and CollateFormat.ordered_dict configures the collater to output a dictionary of batches indexed by keys
            of sub-datasets.
            Note that not all sub-datasets will be present in a single batch in either format.
        virtual_size (int, or callable): the expected virtual size of the dataset (default: default_virtual_size_func).
        split (str): the split of the data, e.g. 'train', 'valid' or 'test'.
        shared_collater (bool): whether or not all sub-datasets have the same collater.
        shuffle (bool): whether or not to shuffle data (default: True).
    """

    def __init__(
        self,
        datasets,
        sampling_ratios=None,
        seed=2,
        epoch=1,
        eval_key=None,
        collate_format=CollateFormat.single,
        virtual_size=default_virtual_size_func,
        split="",
        shared_collater=False,
        shuffle=True,
    ):
        super().__init__()
        self.shared_collater = shared_collater
        self.shuffle = shuffle

        if isinstance(datasets, OrderedDict):
            self.keys = list(datasets.keys())
            datasets = list(datasets.values())
        elif isinstance(datasets, List):
            self.keys = list(range(len(datasets)))
        else:
            raise AssertionError()
        self.datasets = datasets
        self.split = split

        self.eval_key = eval_key
        if self.eval_key is not None:
            self.collate_format = CollateFormat.single
        else:
            self.collate_format = collate_format

        self.seed = seed
        self._cur_epoch = None

        self.cumulated_sizes = None
        # self.datasets[k][self._cur_indices[i]] is the data item i in this sampled dataset
        # namely, data item i is sampled from the kth sub-dataset self.datasets[k]
        # where self.cumulated_sizes[k-1] <= i < self.cumulated_sizes[k]
        self._cur_indices = None

        self._sizes = None
        self.virtual_size_per_dataset = None
        # caching properties
        self._reset_cached_properties()
        self.setup_sampling(sampling_ratios, virtual_size)
        self.set_epoch(epoch)

    def _clean_if_not_none(self, var_list):
        for v in var_list:
            if v is not None:
                del v

    def _reset_cached_properties(self):
        self._clean_if_not_none([self._sizes, self._cur_indices])
        self._sizes = None
        self._cur_indices = None

    def setup_sampling(self, sample_ratios, virtual_size):
        sizes = [len(d) for d in self.datasets]
        if sample_ratios is None:
            # default back to concatenating datasets
            self.sample_ratios = None
            self.virtual_size = sum(sizes)
        else:
            if not isinstance(sample_ratios, np.ndarray):
                sample_ratios = np.array(sample_ratios)
            self.sample_ratios = sample_ratios
            virtual_size = (
                default_virtual_size_func if virtual_size is None else virtual_size
            )
            self.virtual_size = (
                virtual_size(self.datasets, self.sample_ratios)
                if callable(virtual_size)
                else virtual_size
            )

    def adjust_sampling(self, epoch, sampling_ratios, virtual_size):
        if sampling_ratios is not None:
            sampling_ratios = self._sync_sample_ratios(sampling_ratios)
            self.setup_sampling(sampling_ratios, virtual_size)

    def _sync_sample_ratios(self, ratios):
        # in case the ratios are not precisely the same across processes
        # also to ensure every process updates the ratios at the same pace
        ratios = torch.DoubleTensor(ratios)
        if torch.distributed.is_initialized():
            if torch.cuda.is_available():
                distributed_utils.all_reduce(
                    ratios.cuda(), group=distributed_utils.get_data_parallel_group()
                )
            else:
                distributed_utils.all_reduce(
                    ratios, group=distributed_utils.get_data_parallel_group()
                )
        ret = ratios.cpu()
        ret = ret.numpy()
        return ret

    def random_choice_in_dataset(self, rng, dataset, choice_size):
        if hasattr(dataset, "random_choice_in_dataset"):
            return dataset.random_choice_in_dataset(rng, choice_size)
        dataset_size = len(dataset)
        return rng.choice(
            dataset_size, choice_size, replace=(choice_size > dataset_size)
        )

    def get_virtual_indices(self, rng, datasets, sample_ratios, virtual_size):
        def get_counts(sample_ratios):
            counts = np.array([virtual_size * r for r in sample_ratios], dtype=np.int64)
            diff = virtual_size - counts.sum()
            assert diff >= 0
            # due to round-offs, the size might not match the desired sizes
            if diff > 0:
                dataset_indices = rng.choice(
                    len(sample_ratios), size=diff, p=sample_ratios
                )
                for i in dataset_indices:
                    counts[i] += 1
            return counts

        def get_in_dataset_indices(datasets, sizes, sample_ratios):
            counts = get_counts(sample_ratios)
            # uniformly sample desired counts for each dataset
            # if the desired counts are large, sample with replacement:
            indices = [
                self.random_choice_in_dataset(rng, d, c)
                for c, d in zip(counts, datasets)
            ]
            return indices

        sizes = [len(d) for d in datasets]
        if sample_ratios is None:
            # default back to concatenating datasets
            in_dataset_indices = [list(range(s)) for s in sizes]
            virtual_sizes_per_dataset = sizes
        else:
            ratios = sample_ratios / sample_ratios.sum()
            in_dataset_indices = get_in_dataset_indices(datasets, sizes, ratios)
            virtual_sizes_per_dataset = [len(d) for d in in_dataset_indices]
        virtual_sizes_per_dataset = np.array(virtual_sizes_per_dataset, np.int64)
        cumulative_sizes = np.cumsum(virtual_sizes_per_dataset)
        assert sum(virtual_sizes_per_dataset) == virtual_size
        assert cumulative_sizes[-1] == virtual_size
        if virtual_size < sum(sizes):
            logger.warning(
                f"virtual data size ({virtual_size}) is less than real data size ({sum(sizes)})."
                " If virtual size << real data size, there could be data coverage issue."
            )
        in_dataset_indices = np.hstack(in_dataset_indices)
        return in_dataset_indices, cumulative_sizes, virtual_sizes_per_dataset
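
    # Editorial note, not part of the upstream source: with virtual_size = 10 and
    # normalized ratios [0.75, 0.25], get_counts first takes the integer parts
    # [7, 2]; the remaining slot (10 - 9 = 1) is then assigned by a weighted draw
    # over the ratios, giving e.g. [8, 2] or [7, 3].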

    def _get_dataset_and_index(self, index):
        i = bisect_right(self.cumulated_sizes, index)
        return i, self._cur_indices[index]

    def __getitem__(self, index):
        # self.__getitem__(index) returns self.datasets[k][self._cur_indices[index]]
        # where k satisfies self.cumulated_sizes[k - 1] <= index < self.cumulated_sizes[k]
        ds_idx, ds_sample_idx = self._get_dataset_and_index(index)
        ret = (ds_idx, self.datasets[ds_idx][ds_sample_idx])
        return ret
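
    # Editorial note, not part of the upstream source: as an example, with
    # cumulated_sizes = [3, 7] the virtual indices 0-2 map to self.datasets[0]
    # and 3-6 map to self.datasets[1]; bisect_right([3, 7], 3) == 1, so index 3
    # is correctly routed to the second sub-dataset.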

    def num_tokens(self, index):
        return self.sizes[index].max()

    def num_tokens_vec(self, indices):
        sizes_vec = self.sizes[np.array(indices)]
        # max across all dimensions but first one
        return np.amax(sizes_vec, axis=tuple(range(1, len(sizes_vec.shape))))

    def size(self, index):
        return self.sizes[index]

    def __len__(self):
        return self.virtual_size

    def collater(self, samples, **extra_args):
        """Merge a list of samples to form a mini-batch."""
        if len(samples) == 0:
            return None
        if self.collate_format == "ordered_dict":
            collect_samples = [[] for _ in range(len(self.datasets))]
            for (i, sample) in samples:
                collect_samples[i].append(sample)
            batch = OrderedDict(
                [
                    (self.keys[i], dataset.collater(collect_samples[i]))
                    for i, (key, dataset) in enumerate(zip(self.keys, self.datasets))
                    if len(collect_samples[i]) > 0
                ]
            )
        elif self.shared_collater:
            batch = self.datasets[0].collater([s for _, s in samples])
        else:
            samples_dict = defaultdict(list)
            pad_to_length = (
                defaultdict(int)
                if "pad_to_length" not in extra_args
                else extra_args["pad_to_length"]
            )
            for ds_idx, s in samples:
                pad_to_length["source"] = max(
                    pad_to_length["source"], s["source"].size(0)
                )
                if s["target"] is not None:
                    pad_to_length["target"] = max(
                        pad_to_length["target"], s["target"].size(0)
                    )
                samples_dict[ds_idx].append(s)

            batches = [
                self.datasets[i].collater(samples_dict[i], pad_to_length=pad_to_length)
                for i in range(len(self.datasets))
                if len(samples_dict[i]) > 0
            ]

            def straight_data(tensors):
                batch = torch.cat(tensors, dim=0)
                return batch

            src_lengths = straight_data(
                [b["net_input"]["src_lengths"] for b in batches]
            )
            src_lengths, sort_order = src_lengths.sort(descending=True)

            def straight_order(tensors):
                batch = straight_data(tensors)
                return batch.index_select(0, sort_order)

            batch = {
                "id": straight_order([b["id"] for b in batches]),
                "nsentences": sum(b["nsentences"] for b in batches),
                "ntokens": sum(b["ntokens"] for b in batches),
                "net_input": {
                    "src_tokens": straight_order(
                        [b["net_input"]["src_tokens"] for b in batches]
                    ),
                    "src_lengths": src_lengths,
                },
                "target": straight_order([b["target"] for b in batches])
                if batches[0]["target"] is not None
                else None,
            }
            if "prev_output_tokens" in batches[0]["net_input"]:
                batch["net_input"]["prev_output_tokens"] = straight_order(
                    [b["net_input"]["prev_output_tokens"] for b in batches]
                )
            if "src_lang_id" in batches[0]["net_input"]:
                batch["net_input"]["src_lang_id"] = straight_order(
                    [b["net_input"]["src_lang_id"] for b in batches]
                )
            if "tgt_lang_id" in batches[0]:
                batch["tgt_lang_id"] = straight_order(
                    [b["tgt_lang_id"] for b in batches]
                )
        return batch

    @property
    def sizes(self):
        if self._sizes is not None:
            return self._sizes
        start_time = time.time()
        in_sub_dataset_indices = [
            self._cur_indices[
                0 if i == 0 else self.cumulated_sizes[i - 1] : self.cumulated_sizes[i]
            ]
            for i in range(len(self.datasets))
        ]
        sub_dataset_sizes = [
            d.sizes[indices]
            for d, indices in zip(self.datasets, in_sub_dataset_indices)
        ]
        self._sizes = np.vstack(sub_dataset_sizes)
        logger.info(f"sizes() calling time: {get_time_gap(start_time, time.time())}")
        return self._sizes

    def ordered_indices(self):
        if self.shuffle:
            indices = np.random.permutation(len(self))
        else:
            indices = np.arange(len(self))

        sizes = self.sizes
        tgt_sizes = sizes[:, 1] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else None
        src_sizes = (
            sizes[:, 0] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else sizes
        )

        # sort by target length, then source length
        if tgt_sizes is not None:
            indices = indices[np.argsort(tgt_sizes[indices], kind="mergesort")]
        sort_indices = indices[np.argsort(src_sizes[indices], kind="mergesort")]
        return sort_indices

    def prefetch(self, indices):
        prefetch_indices = [[] for _ in range(len(self.datasets))]
        for i in indices:
            ds_idx, ds_sample_idx = self._get_dataset_and_index(i)
            prefetch_indices[ds_idx].append(ds_sample_idx)
        for i in range(len(prefetch_indices)):
            self.datasets[i].prefetch(prefetch_indices[i])

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        return False

    def set_epoch(self, epoch):
        super().set_epoch(epoch)
        if epoch == self._cur_epoch:
            # re-enter so return
            return
        for d in self.datasets:
            if hasattr(d, "set_epoch"):
                d.set_epoch(epoch)
        self._cur_epoch = epoch
        self._establish_virtual_datasets()

    def _establish_virtual_datasets(self):
        if self.sample_ratios is None and self._cur_indices is not None:
            # not a sampling dataset, no need to resample if indices are already established
            return
        self._reset_cached_properties()

        start_time = time.time()
        # Generate a weighted sample of indices as a function of the
        # random seed and the current epoch.
        rng = np.random.RandomState(
            [
                int(
                    hashlib.sha1(
                        str(self.__class__.__name__).encode("utf-8")
                    ).hexdigest(),
                    16,
                )
                % (2**32),
                self.seed % (2**32),  # global seed
                self._cur_epoch,  # epoch index,
            ]
        )
        self._clean_if_not_none(
            [self.cumulated_sizes, self.virtual_size_per_dataset, self._sizes]
        )
        self._sizes = None

        indices, cumulated_sizes, virtual_size_per_dataset = self.get_virtual_indices(
            rng, self.datasets, self.sample_ratios, self.virtual_size
        )
        self._cur_indices = indices
        self.cumulated_sizes = cumulated_sizes
        self.virtual_size_per_dataset = virtual_size_per_dataset

        raw_sizes = [len(d) for d in self.datasets]
        sampled_sizes = self.virtual_size_per_dataset
        logger.info(
            f"[{self.split}] Raw sizes: {str(dict(zip(self.keys, raw_sizes)))}; "
            f"raw total size: {sum(raw_sizes)}"
        )
        logger.info(
            f"[{self.split}] Resampled sizes: {str(dict(zip(self.keys, sampled_sizes)))}; "
            f"resampled total size: {sum(sampled_sizes)}"
        )
        if self.sample_ratios is not None:
            logger.info(
                f"[{self.split}] Upsampling ratios: {str(dict(zip(self.keys, self.sample_ratios)))}"
            )
        else:
            logger.info(f"[{self.split}] A concat dataset")
        logger.info(
            f"[{self.split}] virtual dataset established time: {get_time_gap(start_time, time.time())}"
        )

    def filter_indices_by_size(self, indices, max_sizes):
        """Filter a list of sample indices. Remove those that are longer
            than specified in max_sizes.

        Args:
            indices (np.array): original array of sample indices
            max_sizes (int or list[int] or tuple[int]): max sample size,
                can be defined separately for src and tgt (then list or tuple)

        Returns:
            np.array: filtered sample array
            list: list of removed indices
        """
        sizes = self.sizes
        tgt_sizes = sizes[:, 1] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else None
        src_sizes = (
            sizes[:, 0] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else sizes
        )

        return data_utils.filter_paired_dataset_indices_by_size(
            src_sizes, tgt_sizes, indices, max_sizes
        )
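
For orientation, here is a minimal usage sketch of the class above. It is illustrative only: dataset_a and dataset_b stand in for real FairseqDataset instances (for example, language-pair datasets built elsewhere), and the import path assumes the patched module keeps fairseq's usual package layout.

from collections import OrderedDict

from fairseq.data.multilingual.sampled_multi_dataset import (
    CollateFormat,
    SampledMultiDataset,
)

# dataset_a / dataset_b: hypothetical FairseqDataset instances defined elsewhere.
sampled = SampledMultiDataset(
    OrderedDict([("en-de", dataset_a), ("en-fr", dataset_b)]),
    sampling_ratios=[0.75, 0.25],  # en-de is drawn three times as often as en-fr
    seed=2,
    epoch=1,
    collate_format=CollateFormat.single,
    virtual_size=None,  # None falls back to default_virtual_size_func
    split="train",
    shuffle=True,
)

# Each item is a (sub-dataset index, sample) pair; collater() merges a list of
# such pairs into one mixed mini-batch when CollateFormat.single is used.
ds_idx, sample = sampled[0]
batch = sampled.collater([sampled[i] for i in range(8)])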