mirror of
https://github.com/SillyTavern/SillyTavern-Extras.git
synced 2026-03-11 14:30:03 +00:00
55 lines
1.5 KiB
Python
55 lines
1.5 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
#
|
|
# This source code is licensed under the MIT license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
import torch
|
|
|
|
from . import FairseqDataset
|
|
|
|
|
|
class ConcatSentencesDataset(FairseqDataset):
|
|
def __init__(self, *datasets):
|
|
super().__init__()
|
|
self.datasets = datasets
|
|
assert all(
|
|
len(ds) == len(datasets[0]) for ds in datasets
|
|
), "datasets must have the same length"
|
|
|
|
def __getitem__(self, index):
|
|
return torch.cat([ds[index] for ds in self.datasets])
|
|
|
|
def __len__(self):
|
|
return len(self.datasets[0])
|
|
|
|
def collater(self, samples):
|
|
return self.datasets[0].collater(samples)
|
|
|
|
@property
|
|
def sizes(self):
|
|
return sum(ds.sizes for ds in self.datasets)
|
|
|
|
def num_tokens(self, index):
|
|
return sum(ds.num_tokens(index) for ds in self.datasets)
|
|
|
|
def size(self, index):
|
|
return sum(ds.size(index) for ds in self.datasets)
|
|
|
|
def ordered_indices(self):
|
|
return self.datasets[0].ordered_indices()
|
|
|
|
@property
|
|
def supports_prefetch(self):
|
|
return any(getattr(ds, "supports_prefetch", False) for ds in self.datasets)
|
|
|
|
def prefetch(self, indices):
|
|
for ds in self.datasets:
|
|
if getattr(ds, "supports_prefetch", False):
|
|
ds.prefetch(indices)
|
|
|
|
def set_epoch(self, epoch):
|
|
super().set_epoch(epoch)
|
|
for ds in self.datasets:
|
|
if hasattr(ds, "set_epoch"):
|
|
ds.set_epoch(epoch)
|