# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging

import numpy as np
import torch.utils.data

from fairseq.data import data_utils

logger = logging.getLogger(__name__)


class EpochListening:
    """Mixin for receiving updates whenever the epoch increments."""

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        """
        Whether we can reuse the :class:`fairseq.data.EpochBatchIterator` for
        this dataset across epochs.

        This needs to return ``False`` if the sample sizes can change across
        epochs, in which case we may need to regenerate batches at each epoch.
        If your dataset relies on ``set_epoch`` then you should consider setting
        this to ``False``.
        """
        return True

    def set_epoch(self, epoch):
        """Will receive the updated epoch number at the beginning of the epoch."""
        pass


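# A minimal sketch (not part of the original module) of a dataset that relies
# on ``set_epoch``: it reshuffles its sample order with an epoch-dependent
# seed, so it also reports that epoch iterators cannot be reused across
# epochs. The class and attribute names below are hypothetical.
class _EpochShuffledDatasetExample(torch.utils.data.Dataset, EpochListening):
    def __init__(self, items, seed=1):
        self.items = items
        self.seed = seed
        self._order = np.arange(len(items))

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        # The sample order changes every epoch, so batches must be regenerated.
        return False

    def set_epoch(self, epoch):
        # Reseed deterministically from the epoch number, then reshuffle.
        rng = np.random.RandomState(self.seed + epoch)
        self._order = rng.permutation(len(self.items))

    def __getitem__(self, index):
        return self.items[self._order[index]]

    def __len__(self):
        return len(self.items)

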
class FairseqDataset(torch.utils.data.Dataset, EpochListening):
    """A dataset that provides helpers for batching."""

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def collater(self, samples):
        """Merge a list of samples to form a mini-batch.

        Args:
            samples (List[dict]): samples to collate

        Returns:
            dict: a mini-batch suitable for forwarding with a Model
        """
        raise NotImplementedError

    def num_tokens(self, index):
        """Return the number of tokens in a sample. This value is used to
        enforce ``--max-tokens`` during batching."""
        raise NotImplementedError

    def num_tokens_vec(self, indices):
        """Return the number of tokens for a set of positions defined by indices.
        This value is used to enforce ``--max-tokens`` during batching."""
        raise NotImplementedError

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used when
        filtering a dataset with ``--max-positions``."""
        raise NotImplementedError

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        return np.arange(len(self), dtype=np.int64)

    @property
    def supports_prefetch(self):
        """Whether this dataset supports prefetching."""
        return False

    def attr(self, attr: str, index: int):
        return getattr(self, attr, None)

    def prefetch(self, indices):
        """Prefetch the data required for this epoch."""
        raise NotImplementedError

    def get_batch_shapes(self):
        """
        Return a list of valid batch shapes, for example::

            [(8, 512), (16, 256), (32, 128)]

        The first dimension of each tuple is the batch size and can be ``None``
        to automatically infer the max batch size based on ``--max-tokens``.
        The second dimension of each tuple is the max supported length as given
        by :func:`fairseq.data.FairseqDataset.num_tokens`.

        This will be used by :func:`fairseq.data.FairseqDataset.batch_by_size`
        to restrict batch shapes. This is useful on TPUs to avoid too many
        dynamic shapes (and recompilations).
        """
        return None

    def batch_by_size(
        self,
        indices,
        max_tokens=None,
        max_sentences=None,
        required_batch_size_multiple=1,
    ):
        """
        Given an ordered set of indices, return batches according to
        *max_tokens*, *max_sentences* and *required_batch_size_multiple*.
        """
        from fairseq.data import data_utils

        fixed_shapes = self.get_batch_shapes()
        if fixed_shapes is not None:

            def adjust_bsz(bsz, num_tokens):
                if bsz is None:
                    assert max_tokens is not None, "Must specify --max-tokens"
                    bsz = max_tokens // num_tokens
                if max_sentences is not None:
                    bsz = min(bsz, max_sentences)
                elif (
                    bsz >= required_batch_size_multiple
                    and bsz % required_batch_size_multiple != 0
                ):
                    bsz -= bsz % required_batch_size_multiple
                return bsz

            fixed_shapes = np.array(
                [
                    [adjust_bsz(bsz, num_tokens), num_tokens]
                    for (bsz, num_tokens) in fixed_shapes
                ]
            )

        try:
            num_tokens_vec = self.num_tokens_vec(indices).astype("int64")
        except NotImplementedError:
            num_tokens_vec = None

        return data_utils.batch_by_size(
            indices,
            num_tokens_fn=self.num_tokens,
            num_tokens_vec=num_tokens_vec,
            max_tokens=max_tokens,
            max_sentences=max_sentences,
            required_batch_size_multiple=required_batch_size_multiple,
            fixed_shapes=fixed_shapes,
        )

    def filter_indices_by_size(self, indices, max_sizes):
        """
        Filter a list of sample indices. Remove those that are longer than
        specified in *max_sizes*.

        WARNING: don't update this method; override it in child classes instead.

        Args:
            indices (np.array): original array of sample indices
            max_sizes (int or list[int] or tuple[int]): max sample size,
                can be defined separately for src and tgt (then list or tuple)

        Returns:
            np.array: filtered sample array
            list: list of removed indices
        """
        if isinstance(max_sizes, float) or isinstance(max_sizes, int):
            if hasattr(self, "sizes") and isinstance(self.sizes, np.ndarray):
                ignored = indices[self.sizes[indices] > max_sizes].tolist()
                indices = indices[self.sizes[indices] <= max_sizes]
            elif (
                hasattr(self, "sizes")
                and isinstance(self.sizes, list)
                and len(self.sizes) == 1
            ):
                ignored = indices[self.sizes[0][indices] > max_sizes].tolist()
                indices = indices[self.sizes[0][indices] <= max_sizes]
            else:
                indices, ignored = data_utils._filter_by_size_dynamic(
                    indices, self.size, max_sizes
                )
        else:
            indices, ignored = data_utils._filter_by_size_dynamic(
                indices, self.size, max_sizes
            )
        return indices, ignored

    @property
    def supports_fetch_outside_dataloader(self):
        """Whether this dataset supports fetching outside the workers of the dataloader."""
        return True


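# A hypothetical end-to-end sketch (not part of fairseq) showing how the
# helpers above are typically filled in: ``collater`` pads variable-length
# token sequences, ``num_tokens``/``num_tokens_vec``/``size`` report lengths,
# ``ordered_indices`` sorts by length, and ``filter_indices_by_size`` /
# ``batch_by_size`` are then used to build batches. The names
# ``_ToyTokenDatasetExample``, ``pad_idx`` and ``_toy_batching_example`` are
# illustrative, not fairseq APIs.
class _ToyTokenDatasetExample(FairseqDataset):
    def __init__(self, sequences, pad_idx=1):
        # ``sequences`` is a list of 1-D LongTensors of token ids.
        self.sequences = sequences
        self.pad_idx = pad_idx
        self.sizes = np.array([len(seq) for seq in sequences], dtype=np.int64)

    def __getitem__(self, index):
        return {"id": index, "tokens": self.sequences[index]}

    def __len__(self):
        return len(self.sequences)

    def collater(self, samples):
        # Pad every sequence in the mini-batch to the length of the longest one.
        from torch.nn.utils.rnn import pad_sequence

        tokens = pad_sequence(
            [s["tokens"] for s in samples],
            batch_first=True,
            padding_value=self.pad_idx,
        )
        return {
            "id": torch.tensor([s["id"] for s in samples], dtype=torch.long),
            "net_input": {"src_tokens": tokens},
            "ntokens": int(sum(len(s["tokens"]) for s in samples)),
        }

    def num_tokens(self, index):
        return int(self.sizes[index])

    def num_tokens_vec(self, indices):
        # Vectorized variant; ``batch_by_size`` prefers this when implemented.
        return self.sizes[indices]

    def size(self, index):
        return int(self.sizes[index])

    def ordered_indices(self):
        # Sort by length so each batch contains similarly sized samples.
        return np.argsort(self.sizes, kind="mergesort")


def _toy_batching_example():
    """Illustrative only: filter by length, then batch by a token budget."""
    data = [torch.randint(2, 100, (n,)) for n in (5, 7, 3, 9, 4)]
    dataset = _ToyTokenDatasetExample(data)
    indices = dataset.ordered_indices()
    # Drop samples longer than 8 tokens, then build batches of at most 16 tokens.
    indices, ignored = dataset.filter_indices_by_size(indices, max_sizes=8)
    batches = dataset.batch_by_size(indices, max_tokens=16)
    return batches, ignored

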
class FairseqIterableDataset(torch.utils.data.IterableDataset, EpochListening):
    """
    For datasets that need to be read sequentially, usually because the data is
    being streamed or otherwise can't be manipulated on a single machine.
    """

    def __iter__(self):
        raise NotImplementedError


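# A minimal sketch (hypothetical, not a fairseq API) of an iterable dataset:
# examples are streamed line by line from a text file that may be too large to
# index randomly, and ``set_epoch`` is only used to tag log messages here.
class _StreamingTextDatasetExample(FairseqIterableDataset):
    def __init__(self, path):
        self.path = path
        self.epoch = 1

    def set_epoch(self, epoch):
        self.epoch = epoch

    def __iter__(self):
        logger.info("streaming %s (epoch %d)", self.path, self.epoch)
        with open(self.path, "r", encoding="utf-8") as f:
            for line in f:
                # Yield whitespace-tokenized examples one at a time.
                yield line.rstrip("\n").split()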