# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch

from . import Dictionary, FairseqDataset, data_utils


def collate(
    samples,
    pad_idx,
    eos_idx,
    vocab,
    left_pad_source=False,
    left_pad_target=False,
    input_feeding=True,
    pad_to_length=None,
):
    assert input_feeding
    if len(samples) == 0:
        return {}

    def merge(key, left_pad, move_eos_to_beginning=False, pad_to_length=None):
        return data_utils.collate_tokens(
            [s[key] for s in samples],
            pad_idx,
            eos_idx=None,  # use eos_idx of each sample instead of vocab.eos()
            left_pad=left_pad,
            move_eos_to_beginning=move_eos_to_beginning,
            pad_to_length=pad_to_length,
        )

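    # data_utils.collate_tokens pads each per-sample 1-D tensor on the left or
    # right with pad_idx and stacks them into one [batch, max_len] tensor.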
    id = torch.LongTensor([s["id"] for s in samples])
    src_tokens = merge(
        "source",
        left_pad=left_pad_source,
        pad_to_length=pad_to_length["source"] if pad_to_length is not None else None,
    )
    # sort by descending source length
    src_lengths = torch.LongTensor([s["source"].numel() for s in samples])
    src_lengths, sort_order = src_lengths.sort(descending=True)
    id = id.index_select(0, sort_order)
    src_tokens = src_tokens.index_select(0, sort_order)

    prev_output_tokens = None
    target = None
    if samples[0].get("target", None) is not None:
        target = merge(
            "target",
            left_pad=left_pad_target,
            pad_to_length=pad_to_length["target"]
            if pad_to_length is not None
            else None,
        )
        target = target.index_select(0, sort_order)
        ntokens = sum(len(s["target"]) for s in samples)

        if input_feeding:
            # we create a shifted version of targets for feeding the
            # previous output token(s) into the next decoder step
            prev_output_tokens = merge(
                "target",
                left_pad=left_pad_target,
                move_eos_to_beginning=True,
                pad_to_length=pad_to_length["target"]
                if pad_to_length is not None
                else None,
            )
            prev_output_tokens = prev_output_tokens.index_select(0, sort_order)
    else:
        ntokens = sum(len(s["source"]) for s in samples)

    batch = {
        "id": id,
        "ntokens": ntokens,
        "net_input": {
            "src_tokens": src_tokens,
            "src_lengths": src_lengths,
        },
        "target": target,
        # true (unpadded) target lengths, reordered to match the sorted batch;
        # iterating over the padded `target` tensor would report padded widths
        # and would crash when no target is present
        "target_lengths": (
            torch.LongTensor([s["target"].numel() for s in samples]).index_select(
                0, sort_order
            )
            if target is not None
            else None
        ),
        # number of sentences in the batch, not the length of the first source
        "nsentences": len(samples),
        "sort_order": sort_order,
    }
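    # Shape sketch: src_tokens (and target, when present) are [bsz, max_len]
    # LongTensors padded with pad_idx; src_lengths/target_lengths hold the
    # unpadded lengths, all in descending-source-length order.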
    if prev_output_tokens is not None:
        batch["net_input"]["prev_output_tokens"] = prev_output_tokens

    return batch


class SpanMaskedTokensDataset(FairseqDataset):
    """
    A wrapper around TokenBlockDataset for T5-style span-masked denoising.

    Args:
        dataset (~torch.utils.data.Dataset): dataset to wrap
        vocab (~fairseq.data.Dictionary): vocabulary
        noise_density (float): fraction of the tokens to select as noise.
        mean_noise_span_length (float): mean noise span length.
        shuffle (bool): shuffle the elements before batching.
        seed (int, optional): seed for the random number generator, for
            reproducibility. Default: 1
    """

    def __init__(
        self,
        dataset: torch.utils.data.Dataset,
        vocab: Dictionary,
        noise_density: float,
        mean_noise_span_length: float,
        shuffle: bool,
        seed: int = 1,
    ):
        self.dataset = dataset
        self.vocab = vocab
        self.seed = seed
        self.noise_density = noise_density
        self.mean_noise_span_length = mean_noise_span_length
        self.shuffle = shuffle
        self.epoch = 0

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        return True  # only the noise changes, not item sizes

    def set_epoch(self, epoch, **unused):
        self.epoch = epoch

    def __getitem__(self, index):
        with data_utils.numpy_seed(self.seed, self.epoch, index):
            item = self.dataset[index]
            assert item[-1] == self.vocab.eos()

            noise_mask = self.random_spans_noise_mask(len(item))

            source_sentinel_ids = self.create_sentinel_ids(noise_mask.astype(np.int8))
            source = self.filter_input_ids(item, source_sentinel_ids)

            target_sentinel_ids = self.create_sentinel_ids(
                (~noise_mask).astype(np.int8)
            )
            target = self.filter_input_ids(item, target_sentinel_ids)

        return {
            "id": index,
            "source": torch.from_numpy(source),
            "target": torch.from_numpy(target),
        }

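    # Worked example (hypothetical token ids): for item [a, b, c, d] with
    # noise mask [0, 0, 1, 1], `source` collapses the masked span into one
    # sentinel, [a, b, <sent_0>], while `target` keeps the masked tokens
    # behind the matching sentinel and collapses the rest, [<sent_0>, c, d].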
    def random_spans_noise_mask(self, length):
        """
        This function is a copy of `random_spans_helper <https://github.com/google-research/text-to-text-transfer-transformer/blob/84f8bcc14b5f2c03de51bd3587609ba8f6bbd1cd/t5/data/preprocessors.py#L2682>`__ .

        Returns a noise mask consisting of random spans of noise tokens.
        The number of noise tokens and the number of noise spans and non-noise
        spans are determined deterministically as follows:
            num_noise_tokens = round(length * noise_density)
            num_nonnoise_spans = num_noise_spans = round(num_noise_tokens / mean_noise_span_length)
        Spans alternate between non-noise and noise, beginning with non-noise.
        Subject to the above restrictions, all masks are equally likely.

        Args:
            length: an int32 scalar (length of the incoming token sequence)
        Returns:
            a boolean tensor with shape [length]
        """

        orig_length = length

        num_noise_tokens = int(np.round(length * self.noise_density))
        # avoid degeneracy by ensuring positive numbers of noise and nonnoise tokens.
        num_noise_tokens = min(max(num_noise_tokens, 1), length - 1)
        num_noise_spans = int(np.round(num_noise_tokens / self.mean_noise_span_length))

        # avoid degeneracy by ensuring a positive number of noise spans
        num_noise_spans = max(num_noise_spans, 1)
        num_nonnoise_tokens = length - num_noise_tokens

        # pick the lengths of the noise spans and the non-noise spans
        def _random_segmentation(num_items, num_segments):
            """
            Partition a sequence of items randomly into non-empty segments.
            Args:
                num_items: an integer scalar > 0
                num_segments: an integer scalar in [1, num_items]
            Returns:
                a Tensor with shape [num_segments] containing positive integers that add up to num_items
            """
            mask_indices = np.arange(num_items - 1) < (num_segments - 1)
            np.random.shuffle(mask_indices)
            first_in_segment = np.pad(mask_indices, [[1, 0]])
            segment_id = np.cumsum(first_in_segment)
            # count the length of each segment, relying on np.unique's sorted output
            _, segment_length = np.unique(segment_id, return_counts=True)
            return segment_length

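        # e.g. _random_segmentation(5, 2) shuffles a single True through the
        # four possible boundary positions, yielding [1, 4], [2, 3], [3, 2],
        # or [4, 1] with equal probability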
        noise_span_lengths = _random_segmentation(num_noise_tokens, num_noise_spans)
        nonnoise_span_lengths = _random_segmentation(
            num_nonnoise_tokens, num_noise_spans
        )

        interleaved_span_lengths = np.reshape(
            np.stack([nonnoise_span_lengths, noise_span_lengths], axis=1),
            [num_noise_spans * 2],
        )
        span_starts = np.cumsum(interleaved_span_lengths)[:-1]
        span_start_indicator = np.zeros((length,), dtype=np.int8)
        span_start_indicator[span_starts] = True
        span_num = np.cumsum(span_start_indicator)
        is_noise = np.equal(span_num % 2, 1)

        return is_noise[:orig_length]

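    # Example with hypothetical settings: length=10, noise_density=0.3,
    # mean_noise_span_length=2 gives num_noise_tokens=3 and num_noise_spans=2;
    # span lengths such as [4, 1, 3, 2] (non-noise/noise interleaved) yield
    # is_noise = [0, 0, 0, 0, 1, 0, 0, 0, 1, 1].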
    def create_sentinel_ids(self, mask_indices):
        """
        Create sentinel ids given the indices that should be masked.
        The start index of each masked span is replaced by a sentinel id, in
        increasing order; the remaining (consecutive) masked indices are
        replaced with `-1` so they can be deleted.
        """
        start_indices = mask_indices - np.roll(mask_indices, 1, axis=-1) * mask_indices

        sentinel_ids = np.where(
            start_indices != 0, np.cumsum(start_indices, axis=-1), start_indices
        )
        # make sure all sentinel tokens are unique over the example
        sentinel_ids = np.where(sentinel_ids != 0, len(self.vocab) - sentinel_ids, 0)
        sentinel_ids -= mask_indices - start_indices
        return sentinel_ids

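    # e.g. for mask_indices [0, 1, 1, 0, 1] and vocab size V:
    # start_indices = [0, 1, 0, 0, 1] and sentinel_ids = [0, V-1, -1, 0, V-2];
    # span starts count down from the top of the vocabulary, continuations
    # become -1.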
    @staticmethod
    def filter_input_ids(input_ids, sentinel_ids):
        """
        Put the sentinel mask on `input_ids` and fuse each run of consecutive
        mask tokens into a single sentinel token by deleting the rest.
        This will reduce the sequence length from `expanded_inputs_length` to
        `input_length`.
        """
        input_ids_full = np.where(sentinel_ids != 0, sentinel_ids, input_ids)

        # input_ids tokens and sentinel tokens are >= 0; tokens < 0 are
        # masked tokens coming after sentinel tokens and should be removed
        return input_ids_full[input_ids_full >= 0]

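    # e.g. input_ids [a, b, c, d] with sentinel_ids [0, V-1, -1, 0] merge to
    # [a, V-1, -1, d]; dropping the negatives leaves [a, V-1, d].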
    def __len__(self):
        return len(self.dataset)

    def collater(self, samples, pad_to_length=None):
        """
        Merge a list of samples to form a mini-batch.
        Args:
            samples (List[dict]): samples to collate
        Returns:
            dict: a mini-batch of data
        """
        return collate(
            samples,
            self.vocab.pad(),
            self.vocab.eos(),
            self.vocab,
            pad_to_length=pad_to_length,
        )

    def num_tokens(self, index):
        """Return the number of tokens in a sample. This value is used to
        enforce ``--max-tokens`` during batching."""
        return self.dataset.sizes[index]

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used
        when filtering a dataset with ``--max-positions``."""
        return self.dataset.sizes[index]

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed
        based on this order."""
        if self.shuffle:
            indices = np.random.permutation(len(self))
        else:
            indices = np.arange(len(self))
        return indices[np.argsort(self.dataset.sizes[indices], kind="mergesort")]

    def prefetch(self, indices):
        # this class wraps a single dataset (there are no separate src/tgt
        # datasets), so delegate prefetching to the wrapped dataset
        self.dataset.prefetch(indices)

    @property
    def supports_prefetch(self):
        return (
            hasattr(self.dataset, "supports_prefetch")
            and self.dataset.supports_prefetch
        )
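

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the upstream module; run with
    # `python -m <package>.<this_module>` because of the relative import).
    # It assumes only what this file already requires: the wrapped dataset
    # yields 1-D LongTensors ending in vocab.eos() and exposes a `sizes`
    # array, as fairseq's TokenBlockDataset does. The toy dataset and token
    # ids below are made up for illustration.
    class _ToyDataset(torch.utils.data.Dataset):
        def __init__(self, items):
            self.items = items
            self.sizes = np.array([len(t) for t in items])

        def __getitem__(self, index):
            return self.items[index]

        def __len__(self):
            return len(self.items)

    vocab = Dictionary()
    for sym in "abcdefgh":
        vocab.add_symbol(sym)

    items = [
        torch.LongTensor([4, 5, 6, 7, 8, 9, vocab.eos()]),
        torch.LongTensor([5, 6, 7, 8, 9, 10, vocab.eos()]),
    ]
    dataset = SpanMaskedTokensDataset(
        _ToyDataset(items),
        vocab,
        noise_density=0.3,
        mean_noise_span_length=2.0,
        shuffle=False,
    )
    batch = dataset.collater([dataset[0], dataset[1]])
    # expect [2, src_len] source tokens, sentinel-bearing targets, and an
    # eos-rotated prev_output_tokens for input feeding
    print(batch["net_input"]["src_tokens"].shape, batch["target"].shape)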