mirror of
https://github.com/SillyTavern/SillyTavern-Extras.git
synced 2026-03-02 18:10:04 +00:00
29 lines
953 B
Python
29 lines
953 B
Python
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
#
|
|
# This source code is licensed under the MIT license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
import copy

import numpy as np
import torch
|
|
|
|
from . import BaseWrapperDataset
|
|
|
|
|
|
class PrependDataset(BaseWrapperDataset):
    """Overwrite the first source token of every item with a per-index value.

    Despite its name, this wrapper does not lengthen the sequence: it replaces
    position 0 of the source with the value returned by
    ``prepend_getter(dataset, idx)``.

    Args:
        dataset: the wrapped dataset. Each item is either an indexable,
            item-assignable sequence (e.g. a tensor) or a tuple whose first
            element is such a sequence; remaining tuple elements are passed
            through unchanged.
        prepend_getter: callable ``(dataset, idx) -> int`` giving the token
            to write at position 0 of item ``idx``.
        ensure_first_token_is: if not ``None``, the original first token of
            every item is asserted to equal this value before it is replaced.
    """

    def __init__(self, dataset, prepend_getter, ensure_first_token_is=None):
        super().__init__(dataset)
        self.prepend_getter = prepend_getter
        self.ensure_first_token = ensure_first_token_is

    @staticmethod
    def _copy_src(src):
        """Return a copy of ``src`` whose first element can be written safely."""
        if torch.is_tensor(src):
            # clone() returns a tensor backed by its own, newly allocated storage.
            return src.clone()
        # ndarray.__copy__ copies the data buffer; lists get a shallow copy.
        return copy.copy(src)

    def __getitem__(self, idx):
        """Return item ``idx`` with its first source token replaced.

        Raises:
            AssertionError: if ``ensure_first_token_is`` was given and the
                item's first token differs from it, or if ``prepend_getter``
                does not return an integer.
        """
        item = self.dataset[idx]
        is_tuple = isinstance(item, tuple)
        src = item[0] if is_tuple else item

        assert (
            self.ensure_first_token is None or src[0] == self.ensure_first_token
        ), f"expected first token {self.ensure_first_token}, got {src[0]}"

        prepend_idx = self.prepend_getter(self.dataset, idx)
        # numpy integer scalars (e.g. np.int64) are not ``int`` subclasses,
        # but are valid token ids; accept them too.
        assert isinstance(prepend_idx, (int, np.integer)), (
            "prepend_getter must return an integer, "
            f"got {type(prepend_idx).__name__}"
        )

        # Write into a copy: the wrapped dataset may return cached buffers or
        # views, and mutating them in place would corrupt later reads (and
        # trip the ensure_first_token check on the next access to idx).
        src = self._copy_src(src)
        src[0] = int(prepend_idx)
        return (src,) + item[1:] if is_tuple else src
|