mirror of
https://github.com/SillyTavern/SillyTavern-Extras.git
synced 2026-03-07 20:40:03 +00:00
32 lines
941 B
Python
32 lines
941 B
Python
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
#
|
|
# This source code is licensed under the MIT license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
from fairseq.data import data_utils
|
|
|
|
from . import BaseWrapperDataset
|
|
|
|
|
|
class PadDataset(BaseWrapperDataset):
    """Wrapper dataset whose collater delegates to ``data_utils.collate_tokens``.

    The constructor records the padding configuration; ``collater`` forwards
    it, together with the incoming samples, to ``data_utils.collate_tokens``.
    """

    def __init__(self, dataset, pad_idx, left_pad, pad_length=None):
        """Wrap *dataset* and remember the padding settings.

        Args:
            dataset: the dataset handed to ``BaseWrapperDataset.__init__``.
            pad_idx: forwarded as the second positional argument of
                ``collate_tokens`` (presumably the padding symbol index --
                confirm against ``data_utils``).
            left_pad: forwarded as ``left_pad=`` to ``collate_tokens``.
            pad_length: forwarded as ``pad_to_length=`` to
                ``collate_tokens``; defaults to ``None``.
        """
        super().__init__(dataset)
        self.pad_idx, self.left_pad, self.pad_length = pad_idx, left_pad, pad_length

    def collater(self, samples):
        """Return ``data_utils.collate_tokens`` applied to *samples*."""
        collate_options = {
            "left_pad": self.left_pad,
            "pad_to_length": self.pad_length,
        }
        return data_utils.collate_tokens(samples, self.pad_idx, **collate_options)
|
|
|
|
|
|
class LeftPadDataset(PadDataset):
    """``PadDataset`` specialisation that always passes ``left_pad=True``."""

    def __init__(self, dataset, pad_idx, pad_length=None):
        """Wrap *dataset* with ``left_pad`` fixed to ``True``.

        Args:
            dataset: the dataset to wrap.
            pad_idx: forwarded unchanged to ``PadDataset``.
            pad_length: optional value forwarded to ``PadDataset`` as its
                ``pad_length``. Defaults to ``None``, which matches the
                behaviour before this parameter existed.
        """
        # Forward pad_length so the parent's optional setting is reachable
        # through this convenience subclass as well.
        super().__init__(dataset, pad_idx, left_pad=True, pad_length=pad_length)
|
|
|
|
|
|
class RightPadDataset(PadDataset):
    """``PadDataset`` specialisation that always passes ``left_pad=False``."""

    def __init__(self, dataset, pad_idx, pad_length=None):
        """Wrap *dataset* with ``left_pad`` fixed to ``False``.

        Args:
            dataset: the dataset to wrap.
            pad_idx: forwarded unchanged to ``PadDataset``.
            pad_length: optional value forwarded to ``PadDataset`` as its
                ``pad_length``. Defaults to ``None``, which matches the
                behaviour before this parameter existed.
        """
        # Forward pad_length so the parent's optional setting is reachable
        # through this convenience subclass as well.
        super().__init__(dataset, pad_idx, left_pad=False, pad_length=pad_length)
|