# Mirror of https://github.com/SillyTavern/SillyTavern-Extras.git
# synced 2026-03-07 20:40:03 +00:00 (207 lines, 7.5 KiB, Python)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import torch

from fairseq.data import FairseqDataset, plasma_utils
from fairseq.data.indexed_dataset import best_fitting_int_dtype
from typing import Tuple


class TokenBlockDataset(FairseqDataset):
    """Break a Dataset of tokens into blocks.

    Args:
        dataset (~torch.utils.data.Dataset): dataset to break into blocks
        sizes (List[int]): sentence lengths (required for 'complete' and 'eos')
        block_size (int): maximum block size (ignored in 'eos' break mode)
        pad (int): padding token id (used to left-pad ``past_target`` when
            ``include_targets`` is set)
        eos (int): end-of-sentence token id (used to left-pad the shifted
            ``source``/``past_target`` sequences)
        break_mode (str, optional): Mode used for breaking tokens. Values can
            be one of:

            - 'none': break tokens into equally sized blocks (up to block_size)
            - 'complete': break tokens into blocks (up to block_size) such that
              blocks contains complete sentences, although block_size may be
              exceeded if some sentences exceed block_size
            - 'complete_doc': similar to 'complete' mode, but do not
              cross document boundaries
            - 'eos': each block contains one sentence (block_size is ignored)
        include_targets (bool, optional): return next tokens as targets
            (default: False).
        document_sep_len (int, optional): document separator size (required for
            'complete_doc' break mode). Typically 1 if the sentences have eos
            and 0 otherwise.
        use_plasma_view (bool, optional): store the index arrays in a Plasma
            store so they can be shared across processes (default: False).
        split_path (str, optional): identifier used to key the Plasma views.
        plasma_path (str, optional): path to the Plasma store socket.
    """

    def __init__(
        self,
        dataset,
        sizes,
        block_size,
        pad,
        eos,
        break_mode=None,
        include_targets=False,
        document_sep_len=1,
        use_plasma_view=False,
        split_path=None,
        plasma_path=None,
    ):
        super().__init__()
        self.dataset = dataset
        self.pad = pad
        self.eos = eos
        self.include_targets = include_targets

        assert len(dataset) > 0
        assert len(dataset) == len(sizes)
        _sizes, block_to_dataset_index, slice_indices = self._build_slice_indices(
            sizes, break_mode, document_sep_len, block_size
        )
        if use_plasma_view:
            # Key the shared views on everything that determines their
            # content, so different configurations do not collide in the store.
            plasma_id = (block_size, document_sep_len, str(break_mode), len(dataset))
            self._slice_indices = plasma_utils.PlasmaView(
                slice_indices, split_path, (plasma_id, 0), plasma_path=plasma_path
            )
            self._sizes = plasma_utils.PlasmaView(
                _sizes, split_path, (plasma_id, 1), plasma_path=plasma_path
            )
            self._block_to_dataset_index = plasma_utils.PlasmaView(
                block_to_dataset_index,
                split_path,
                (plasma_id, 2),
                plasma_path=plasma_path,
            )
        else:
            self._slice_indices = plasma_utils.PlasmaArray(slice_indices)
            self._sizes = plasma_utils.PlasmaArray(_sizes)
            self._block_to_dataset_index = plasma_utils.PlasmaArray(
                block_to_dataset_index
            )

    @staticmethod
    def _build_slice_indices(
        sizes, break_mode, document_sep_len, block_size
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Use token_block_utils_fast to build arrays for indexing into self.dataset.

        Returns:
            ``(_sizes, block_to_dataset_index, slice_indices)`` where
            ``slice_indices[i] = (start, end)`` are token offsets of block *i*
            in the flattened token stream, ``_sizes[i]`` is the length of
            block *i*, and ``block_to_dataset_index[i] = (start_ds_idx,
            start_offset, end_ds_idx)`` locates the block in the dataset.

        Raises:
            ImportError: if the Cython helpers have not been built.
        """
        try:
            from fairseq.data.token_block_utils_fast import (
                _get_slice_indices_fast,
                _get_block_to_dataset_index_fast,
            )
        except ImportError:
            raise ImportError(
                "Please build Cython components with: `pip install --editable .` "
                "or `python setup.py build_ext --inplace`"
            )

        # Normalize `sizes` to an int64 numpy array for the Cython helpers.
        if isinstance(sizes, list):
            sizes = np.array(sizes, dtype=np.int64)
        else:
            if torch.is_tensor(sizes):
                sizes = sizes.numpy()
            sizes = sizes.astype(np.int64)

        break_mode = break_mode if break_mode is not None else "none"

        # For "eos" break-mode, block_size is not a required parameter.
        if break_mode == "eos" and block_size is None:
            block_size = 0

        slice_indices = _get_slice_indices_fast(
            sizes, str(break_mode), block_size, document_sep_len
        )
        _sizes = slice_indices[:, 1] - slice_indices[:, 0]

        # build index mapping block indices to the underlying dataset indices
        if break_mode == "eos":
            # much faster version for eos break mode: each block is exactly
            # one sentence, so the mapping is the identity with zero offsets.
            block_to_dataset_index = np.stack(
                [
                    np.arange(len(sizes)),  # starting index in dataset
                    # `int` replaces `np.compat.long`, which was removed in
                    # NumPy 2.0; both resolve to the platform default integer.
                    np.zeros(
                        len(sizes), dtype=int
                    ),  # starting offset within starting index
                    np.arange(len(sizes)),  # ending index in dataset
                ],
                1,
            )
        else:
            block_to_dataset_index = _get_block_to_dataset_index_fast(
                sizes,
                slice_indices,
            )

        # Downcast to compact dtypes to save memory on large corpora.
        # NOTE(review): with break_mode == "eos", block_size may be 0 here, so
        # uint16 is selected even though an individual sentence could exceed
        # 65535 tokens — presumably inputs are shorter than that; verify.
        size_dtype = np.uint16 if block_size < 65535 else np.uint32
        num_tokens = slice_indices[-1].max()
        slice_indices_dtype = best_fitting_int_dtype(num_tokens)
        slice_indices = slice_indices.astype(slice_indices_dtype)
        _sizes = _sizes.astype(size_dtype)
        block_to_dataset_index = block_to_dataset_index.astype(slice_indices_dtype)
        return _sizes, block_to_dataset_index, slice_indices

    @property
    def slice_indices(self):
        # (num_blocks, 2) array of [start, end) token offsets per block.
        return self._slice_indices.array

    @property
    def sizes(self):
        # Per-block lengths in tokens.
        return self._sizes.array

    @property
    def block_to_dataset_index(self):
        # (num_blocks, 3) array of (start_ds_idx, start_offset, end_ds_idx).
        return self._block_to_dataset_index.array

    def attr(self, attr: str, index: int):
        # Block attributes are delegated to the first sentence in the block.
        start_ds_idx, _, _ = self.block_to_dataset_index[index]
        return self.dataset.attr(attr, start_ds_idx)

    def __getitem__(self, index):
        start_ds_idx, start_offset, end_ds_idx = self.block_to_dataset_index[index]

        # Concatenate every sentence this block spans, then slice the block out.
        buffer = torch.cat(
            [self.dataset[idx] for idx in range(start_ds_idx, end_ds_idx + 1)]
        )
        slice_s, slice_e = self.slice_indices[index]
        length = slice_e - slice_s
        s, e = start_offset, start_offset + length
        item = buffer[s:e]

        if self.include_targets:
            # *target* is the original sentence (=item)
            # *source* is shifted right by 1 (maybe left-padded with eos)
            # *past_target* is shifted right by 2 (left-padded as needed)
            if s == 0:
                source = torch.cat([item.new([self.eos]), buffer[0 : e - 1]])
                past_target = torch.cat(
                    [item.new([self.pad, self.eos]), buffer[0 : e - 2]]
                )
            else:
                source = buffer[s - 1 : e - 1]
                if s == 1:
                    past_target = torch.cat([item.new([self.eos]), buffer[0 : e - 2]])
                else:
                    past_target = buffer[s - 2 : e - 2]

            return source, item, past_target

        return item

    def __len__(self):
        return len(self.slice_indices)

    @property
    def supports_prefetch(self):
        return getattr(self.dataset, "supports_prefetch", False)

    def prefetch(self, indices):
        # Prefetch the union of all underlying dataset indices that the
        # requested blocks touch.
        self.dataset.prefetch(
            {
                ds_idx
                for index in indices
                for start_ds_idx, _, end_ds_idx in [self.block_to_dataset_index[index]]
                for ds_idx in range(start_ds_idx, end_ds_idx + 1)
            }
        )
|