mirror of
https://github.com/SillyTavern/SillyTavern-Extras.git
synced 2026-02-27 08:34:10 +00:00
19 lines
485 B
Python
19 lines
485 B
Python
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
|
|
|
|
import torch
|
|
|
|
from . import BaseWrapperDataset
|
|
|
|
|
|
class RollDataset(BaseWrapperDataset):
    """Wrap a dataset so that every item is cyclically shifted via ``torch.roll``.

    Args:
        dataset: the dataset to wrap. Each item is passed straight to
            ``torch.roll``, so items are presumably tensors -- TODO confirm
            against callers.
        shifts: number of places (int or tuple of ints) by which elements are
            shifted; forwarded unchanged to ``torch.roll``.
        dims: optional dimension (int or tuple of ints) along which to roll.
            The default ``None`` keeps the original behaviour: ``torch.roll``
            flattens the item, rolls it, and restores its original shape.
    """

    def __init__(self, dataset, shifts, dims=None):
        # The base wrapper is expected to store ``dataset`` as
        # ``self.dataset`` (read in ``__getitem__``) -- assumed from usage.
        super().__init__(dataset)
        self.shifts = shifts
        self.dims = dims

    def __getitem__(self, index):
        """Return item ``index`` of the wrapped dataset, rolled by ``shifts``.

        When ``dims`` is None the roll is applied to the flattened item;
        otherwise it is applied along the given dimension(s).
        """
        item = self.dataset[index]
        return torch.roll(item, self.shifts, dims=self.dims)
|