mirror of
https://github.com/SillyTavern/SillyTavern-Extras.git
synced 2026-04-29 10:51:19 +00:00
79 lines
2.3 KiB
Python
79 lines
2.3 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
#
|
|
# This source code is licensed under the MIT license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
import numpy as np
|
|
import torch.nn.functional as F
|
|
from fairseq.data import BaseWrapperDataset
|
|
from fairseq.data.data_utils import get_buckets, get_bucketed_sizes
|
|
|
|
|
|
class BucketPadLengthDataset(BaseWrapperDataset):
    """
    Bucket and pad item lengths to the nearest bucket size. This can be used to
    reduce the number of unique batch shapes, which is important on TPUs since
    each new batch shape requires a recompilation.

    Args:
        dataset (FairseqDataset): dataset to bucket
        sizes (List[int]): all item sizes
        num_buckets (int): number of buckets to create; must be > 0
        pad_idx (int): padding symbol
        left_pad (bool): if True, pad on the left; otherwise right pad
        tensor_key (optional): if None, each item returned by ``dataset`` is
            itself the tensor to pad. Otherwise the tensor is read from
            ``item[tensor_key]`` and the padded tensor is written back under
            the same key.
    """

    def __init__(
        self,
        dataset,
        sizes,
        num_buckets,
        pad_idx,
        left_pad,
        tensor_key=None,
    ):
        super().__init__(dataset)
        self.pad_idx = pad_idx
        self.left_pad = left_pad

        assert num_buckets > 0, "num_buckets must be positive, got {}".format(
            num_buckets
        )
        self.buckets = get_buckets(sizes, num_buckets)
        # Per-item target length: each entry of ``sizes`` mapped onto a bucket.
        self._bucketed_sizes = get_bucketed_sizes(sizes, self.buckets)
        self._tensor_key = tensor_key

    def _set_tensor(self, item, val):
        """Put the padded tensor ``val`` back into ``item`` and return the result.

        When ``tensor_key`` is set, this assigns into ``item`` in place.
        NOTE(review): that mutates the object returned by the wrapped dataset;
        confirm the wrapped dataset does not cache and reuse its items.
        """
        if self._tensor_key is None:
            return val
        item[self._tensor_key] = val
        return item

    def _get_tensor(self, item):
        """Extract the tensor to pad from ``item`` (see ``tensor_key``)."""
        if self._tensor_key is None:
            return item
        return item[self._tensor_key]

    def _pad(self, tensor, bucket_size, dim=-1):
        """Pad ``tensor`` along ``dim`` with ``pad_idx`` up to ``bucket_size``.

        Padding goes on the left or the right according to ``left_pad``.
        """
        num_pad = bucket_size - tensor.size(dim)
        # F.pad takes (before, after) pairs starting from the LAST dimension
        # and moving backwards, so emit a zero pair for every dimension after
        # ``dim`` before the pair that actually pads ``dim``. For the default
        # dim=-1 this reduces to a single (before, after) pair.
        dim = dim % tensor.dim()
        num_trailing_dims = tensor.dim() - 1 - dim
        before, after = (num_pad, 0) if self.left_pad else (0, num_pad)
        return F.pad(
            tensor,
            (0, 0) * num_trailing_dims + (before, after),
            value=self.pad_idx,
        )

    def __getitem__(self, index):
        """Return item ``index`` padded along its last dim to its bucket size."""
        item = self.dataset[index]
        bucket_size = self._bucketed_sizes[index]
        tensor = self._get_tensor(item)
        padded = self._pad(tensor, bucket_size)
        return self._set_tensor(item, padded)

    @property
    def sizes(self):
        """Bucketed (padded) sizes of all items, not the original sizes."""
        return self._bucketed_sizes

    def num_tokens(self, index):
        """Token count of item ``index`` after padding (its bucket size)."""
        return self._bucketed_sizes[index]

    def size(self, index):
        """Size of item ``index`` after padding (its bucket size)."""
        return self._bucketed_sizes[index]
|