mirror of
https://github.com/SillyTavern/SillyTavern-Extras.git
synced 2026-02-22 06:04:26 +00:00
79 lines
2.1 KiB
Python
79 lines
2.1 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
|
|
|
|
from torch.utils.data.dataloader import default_collate
|
|
|
|
from . import FairseqDataset
|
|
|
|
|
|
class BaseWrapperDataset(FairseqDataset):
    """Dataset that wraps another dataset and forwards every call to it.

    Each method and property below delegates to ``self.dataset``. The only
    fallbacks are in :meth:`collater`, :attr:`supports_prefetch` and
    :meth:`set_epoch`, which tolerate a wrapped object that lacks the
    corresponding attribute.
    """

    def __init__(self, dataset):
        """Store *dataset* as the wrapped object after base-class init."""
        super().__init__()
        self.dataset = dataset

    def __getitem__(self, index):
        """Return ``self.dataset[index]``."""
        wrapped = self.dataset
        return wrapped[index]

    def __len__(self):
        """Return ``len(self.dataset)``."""
        wrapped = self.dataset
        return len(wrapped)

    def collater(self, samples):
        """Collate *samples* into a batch.

        Uses the wrapped dataset's own ``collater`` when it has one;
        otherwise falls back to ``default_collate``.
        """
        if not hasattr(self.dataset, "collater"):
            return default_collate(samples)
        return self.dataset.collater(samples)

    @property
    def sizes(self):
        """The wrapped dataset's ``sizes`` attribute."""
        wrapped = self.dataset
        return wrapped.sizes

    def num_tokens(self, index):
        """Return the wrapped dataset's ``num_tokens(index)``."""
        wrapped = self.dataset
        return wrapped.num_tokens(index)

    def size(self, index):
        """Return the wrapped dataset's ``size(index)``."""
        wrapped = self.dataset
        return wrapped.size(index)

    def ordered_indices(self):
        """Return the wrapped dataset's ``ordered_indices()``."""
        wrapped = self.dataset
        return wrapped.ordered_indices()

    @property
    def supports_prefetch(self):
        """The wrapped dataset's ``supports_prefetch``, or ``False`` if absent."""
        wrapped = self.dataset
        return getattr(wrapped, "supports_prefetch", False)

    def attr(self, attr: str, index: int):
        """Return the wrapped dataset's ``attr(attr, index)``."""
        wrapped = self.dataset
        return wrapped.attr(attr, index)

    def prefetch(self, indices):
        """Forward ``prefetch(indices)`` to the wrapped dataset; returns nothing."""
        wrapped = self.dataset
        wrapped.prefetch(indices)

    def get_batch_shapes(self):
        """Return the wrapped dataset's ``get_batch_shapes()``."""
        wrapped = self.dataset
        return wrapped.get_batch_shapes()

    def batch_by_size(
        self,
        indices,
        max_tokens=None,
        max_sentences=None,
        required_batch_size_multiple=1,
    ):
        """Forward batching to the wrapped dataset.

        *indices* is passed positionally; the three limits are passed
        through unchanged as keyword arguments.
        """
        wrapped = self.dataset
        batches = wrapped.batch_by_size(
            indices,
            max_tokens=max_tokens,
            max_sentences=max_sentences,
            required_batch_size_multiple=required_batch_size_multiple,
        )
        return batches

    def filter_indices_by_size(self, indices, max_sizes):
        """Return the wrapped dataset's ``filter_indices_by_size(indices, max_sizes)``."""
        wrapped = self.dataset
        return wrapped.filter_indices_by_size(indices, max_sizes)

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        """The wrapped dataset's ``can_reuse_epoch_itr_across_epochs``."""
        wrapped = self.dataset
        return wrapped.can_reuse_epoch_itr_across_epochs

    def set_epoch(self, epoch):
        """Notify the base class, then the wrapped dataset, of the new epoch.

        The wrapped object is only notified if it defines ``set_epoch``.
        """
        super().set_epoch(epoch)
        wrapped = self.dataset
        if hasattr(wrapped, "set_epoch"):
            wrapped.set_epoch(epoch)
|