Add monkey-patched fairseq package to run on Python 3.11 (needed at least for our use of RVC)
@@ -0,0 +1,284 @@
# Copyright (c) 2021-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import logging
import math
from typing import Any, List, NamedTuple, Optional

import numpy as np
import torch

from fairseq.data import (
    ConcatDataset,
    FairseqDataset,
    FileAudioDataset,
    LanguagePairDataset,
    data_utils,
)
from fairseq.data.resampling_dataset import ResamplingDataset

logger = logging.getLogger(__name__)
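

# Per-modality configuration record: one entry for each sub-dataset that
# MultiModalityDataset (below) concatenates.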
class ModalityDatasetItem(NamedTuple):
    datasetname: str
    dataset: Any
    max_positions: List[int]
    max_tokens: Optional[int] = None
    max_sentences: Optional[int] = None
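

# Return True if a ResamplingDataset is present anywhere inside `ds`, looking
# recursively through ConcatDataset members and single-`dataset` wrappers.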
def resampling_dataset_present(ds):
    if isinstance(ds, ResamplingDataset):
        return True
    if isinstance(ds, ConcatDataset):
        return any(resampling_dataset_present(d) for d in ds.datasets)
    if hasattr(ds, "dataset"):
        return resampling_dataset_present(ds.dataset)
    return False


# MultiModalityDataset concatenates multiple datasets with different modalities.
# Compared with ConcatDataset, it can 1) sample data given per-dataset ratios and
# 2) add a "mode" field indicating which type of dataset each sample comes from.
# It is used together with GroupedEpochBatchIterator to generate mini-batches
# whose samples all come from the same type of dataset.
# If only one dataset is used, it behaves like the original dataset, with mode
# added.
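#
# A minimal usage sketch (illustrative only; the item names and the speech_ds /
# text_ds dataset objects are hypothetical placeholders):
#
#     items = [
#         ModalityDatasetItem("speech", speech_ds, [1024], max_tokens=4096),
#         ModalityDatasetItem("text", text_ds, [256, 256], max_sentences=32),
#     ]
#     dataset = MultiModalityDataset(items)
#     batch_samplers = dataset.get_batch_samplers(
#         mult_ratios=[1.0, 0.5], required_batch_size_multiple=8, seed=1
#     )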
class MultiModalityDataset(ConcatDataset):
    def __init__(self, datasets: List[ModalityDatasetItem]):
        id_to_mode = []
        dsets = []
        max_tokens = []
        max_sentences = []
        max_positions = []
        for dset in datasets:
            id_to_mode.append(dset.datasetname)
            dsets.append(dset.dataset)
            max_tokens.append(dset.max_tokens)
            max_positions.append(dset.max_positions)
            max_sentences.append(dset.max_sentences)
        weights = [1.0 for _ in dsets]
        super().__init__(dsets, weights)
        self.max_tokens = max_tokens
        self.max_positions = max_positions
        self.max_sentences = max_sentences
        self.id_to_mode = id_to_mode
        self.raw_sub_batch_samplers = []
        self._cur_epoch = 0

    def set_epoch(self, epoch):
        super().set_epoch(epoch)
        self._cur_epoch = epoch

    def __getitem__(self, idx):
        dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx)
        sample = self.datasets[dataset_idx][sample_idx]
        return (dataset_idx, sample)
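
    # Delegate collation to the sub-dataset that produced these samples, then
    # tag the resulting batch with that dataset's mode string.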
    def collater(self, samples):
        if len(samples) == 0:
            return {}
        dataset_idx = samples[0][0]
        # make sure all samples are from the same dataset
        assert sum([0 if dataset_idx == s[0] else 1 for s in samples]) == 0
        samples = self.datasets[dataset_idx].collater([x[1] for x in samples])
        # add mode
        samples["net_input"]["mode"] = self.id_to_mode[dataset_idx]

        return samples

    def size(self, index: int):
        if len(self.datasets) == 1:
            return self.datasets[0].size(index)
        return super().size(index)

    @property
    def sizes(self):
        if len(self.datasets) == 1:
            return self.datasets[0].sizes
        return super().sizes

    def ordered_indices(self):
        """
        Returns indices sorted by length, so that less padding is needed.
        """
        if len(self.datasets) == 1:
            return self.datasets[0].ordered_indices()
        indices_group = []
        for d_idx, ds in enumerate(self.datasets):
            sample_num = self.cumulative_sizes[d_idx]
            if d_idx > 0:
                sample_num = sample_num - self.cumulative_sizes[d_idx - 1]
            assert sample_num == len(ds)
            indices_group.append(ds.ordered_indices())
        return indices_group
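
    # Build (and cache) one raw batch sampler per sub-dataset. Cached samplers
    # are reused unless the sub-dataset contains a ResamplingDataset, in which
    # case the id-to-sample mapping changes every epoch and the sampler must be
    # rebuilt.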
    def get_raw_batch_samplers(self, required_batch_size_multiple, seed):
        with data_utils.numpy_seed(seed):
            indices = self.ordered_indices()
        for i, ds in enumerate(self.datasets):
            # If we have a ResamplingDataset, the same id can correspond to a
            # different sample in the next epoch, so we need to rebuild this at
            # every epoch
            if i < len(self.raw_sub_batch_samplers) and not resampling_dataset_present(
                ds
            ):
                logger.info(f"dataset {i} is valid and it is not re-sampled")
                continue
            indices[i] = ds.filter_indices_by_size(
                indices[i],
                self.max_positions[i],
            )[0]
            sub_batch_sampler = ds.batch_by_size(
                indices[i],
                max_tokens=self.max_tokens[i],
                max_sentences=self.max_sentences[i],
                required_batch_size_multiple=required_batch_size_multiple,
            )
            if i < len(self.raw_sub_batch_samplers):
                self.raw_sub_batch_samplers[i] = sub_batch_sampler
            else:
                self.raw_sub_batch_samplers.append(sub_batch_sampler)
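
    # Scale each sub-dataset's batch list by mult_ratios[i]: whole multiples of
    # the list are repeated, and a fractional remainder is drawn from a shuffled
    # copy. Batches of datasets after the first are shifted by the cumulative
    # size offset so their indices address the concatenated dataset.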
    def get_batch_samplers(self, mult_ratios, required_batch_size_multiple, seed):
        self.get_raw_batch_samplers(required_batch_size_multiple, seed)
        batch_samplers = []
        for i, _ in enumerate(self.datasets):
            if i > 0:
                sub_batch_sampler = [
                    [y + self.cumulative_sizes[i - 1] for y in x]
                    for x in self.raw_sub_batch_samplers[i]
                ]
            else:
                sub_batch_sampler = list(self.raw_sub_batch_samplers[i])
            smp_r = mult_ratios[i]
            if smp_r != 1:
                is_increase = "increased" if smp_r > 1 else "decreased"
                logger.info(
                    "number of batches for dataset {} is {} from {} to {}".format(
                        self.id_to_mode[i],
                        is_increase,
                        len(sub_batch_sampler),
                        int(len(sub_batch_sampler) * smp_r),
                    )
                )
                mul_samplers = []
                for _ in range(math.floor(smp_r)):
                    mul_samplers = mul_samplers + sub_batch_sampler
                if math.floor(smp_r) != smp_r:
                    with data_utils.numpy_seed(seed + self._cur_epoch):
                        np.random.shuffle(sub_batch_sampler)
                    smp_num = int(
                        (smp_r - math.floor(smp_r)) * len(sub_batch_sampler)
                    )
                    mul_samplers = mul_samplers + sub_batch_sampler[:smp_num]
                sub_batch_sampler = mul_samplers
            else:
                logger.info(
                    "dataset {} batch number is {}".format(
                        self.id_to_mode[i], len(sub_batch_sampler)
                    )
                )
            batch_samplers.append(sub_batch_sampler)

        return batch_samplers


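# Wrap a LanguagePairDataset and replace a fraction (mask_ratio) of the source
# tokens with noise_id, either at random positions or at the tail of the
# sequence; BOS/EOS positions are never masked. A minimal construction sketch
# (pair_ds and the token ids shown are hypothetical placeholders):
#
#     masked = LangPairMaskDataset(
#         pair_ds, src_eos=2, src_bos=0, noise_id=3,
#         mask_ratio=0.3, mask_type="random",
#     )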
class LangPairMaskDataset(FairseqDataset):
    def __init__(
        self,
        dataset: LanguagePairDataset,
        src_eos: int,
        src_bos: Optional[int] = None,
        noise_id: Optional[int] = -1,
        mask_ratio: Optional[float] = 0,
        mask_type: Optional[str] = "random",
    ):
        self.dataset = dataset
        self.src_eos = src_eos
        self.src_bos = src_bos
        self.noise_id = noise_id
        self.mask_ratio = mask_ratio
        self.mask_type = mask_type
        assert mask_type in ("random", "tail")

    @property
    def src_sizes(self):
        return self.dataset.src_sizes

    @property
    def tgt_sizes(self):
        return self.dataset.tgt_sizes

    @property
    def sizes(self):
        # dataset.sizes can be dynamically computed
        return self.dataset.sizes

    def get_batch_shapes(self):
        if hasattr(self.dataset, "get_batch_shapes"):
            return self.dataset.get_batch_shapes()
        return self.dataset.buckets

    def num_tokens_vec(self, indices):
        return self.dataset.num_tokens_vec(indices)

    def __len__(self):
        return len(self.dataset)

    def num_tokens(self, index):
        return self.dataset.num_tokens(index)

    def size(self, index):
        return self.dataset.size(index)

    def ordered_indices(self):
        return self.dataset.ordered_indices()

    @property
    def supports_prefetch(self):
        return getattr(self.dataset, "supports_prefetch", False)

    def prefetch(self, indices):
        return self.dataset.prefetch(indices)
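
    # Build a boolean mask over the source sequence: "random" masks each token
    # independently with probability mask_ratio, while "tail" masks the final
    # mask_ratio fraction of tokens. BOS/EOS positions are always unmasked.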
    def mask_src_tokens(self, sample):
        src_item = sample["source"]
        mask = None
        if self.mask_type == "random":
            mask = torch.rand(len(src_item)).le(self.mask_ratio)
        else:
            mask = torch.ones(len(src_item))
            mask[: int(len(src_item) * (1 - self.mask_ratio))] = 0
            mask = mask.eq(1)
        if src_item[0] == self.src_bos:
            mask[0] = False
        if src_item[-1] == self.src_eos:
            mask[-1] = False
        mask_src_item = src_item.masked_fill(mask, self.noise_id)
        smp = {"id": sample["id"], "source": mask_src_item, "target": sample["target"]}
        return smp

    def __getitem__(self, index):
        sample = self.dataset[index]
        if self.mask_ratio > 0:
            sample = self.mask_src_tokens(sample)
        return sample

    def collater(self, samples, pad_to_length=None):
        return self.dataset.collater(samples, pad_to_length)


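# Adapt FileAudioDataset batches to the seq2seq-style net_input convention used
# downstream: expose "source" as "src_tokens" and stub out the fields an audio
# batch does not provide.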
class FileAudioDatasetWrapper(FileAudioDataset):
    def collater(self, samples):
        samples = super().collater(samples)
        if len(samples) == 0:
            return {}
        samples["net_input"]["src_tokens"] = samples["net_input"]["source"]
        samples["net_input"]["prev_output_tokens"] = None
        del samples["net_input"]["source"]
        samples["net_input"]["src_lengths"] = None
        samples["net_input"]["alignment"] = None
        return samples