mirror of
https://github.com/SillyTavern/SillyTavern-Extras.git
synced 2026-03-08 04:50:14 +00:00
33 lines
729 B
Python
33 lines
729 B
Python
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
#
|
|
# This source code is licensed under the MIT license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
from . import BaseWrapperDataset
|
|
|
|
|
|
class ListDataset(BaseWrapperDataset):
    """Wrapper dataset that stores an optional, caller-supplied sizes container.

    The wrapped ``dataset`` is handed to the base-class constructor; ``sizes``
    is kept as given and is indexed by :meth:`num_tokens` and :meth:`size`.
    """

    def __init__(self, dataset, sizes=None):
        """Wrap ``dataset`` and remember ``sizes`` (``None`` when not given)."""
        super().__init__(dataset)
        self._sizes = sizes

    def __iter__(self):
        """Yield every item produced by iterating ``self.dataset``."""
        yield from self.dataset

    def collater(self, samples):
        """Return ``samples`` unchanged; no batching or merging is applied."""
        batch = samples
        return batch

    @property
    def sizes(self):
        """The ``sizes`` object passed to the constructor."""
        stored_sizes = self._sizes
        return stored_sizes

    def num_tokens(self, index):
        """Return the ``sizes`` entry at ``index``."""
        all_sizes = self.sizes
        return all_sizes[index]

    def size(self, index):
        """Return the ``sizes`` entry at ``index`` (same lookup as ``num_tokens``)."""
        all_sizes = self.sizes
        return all_sizes[index]

    def set_epoch(self, epoch):
        """Accept an epoch number and do nothing with it."""
        return None
|