mirror of
https://github.com/SillyTavern/SillyTavern-Extras.git
synced 2026-03-11 14:30:03 +00:00
444 lines
15 KiB
Python
444 lines
15 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
#
|
|
# This source code is licensed under the MIT license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
import math
|
|
|
|
import numpy as np
|
|
import torch
|
|
|
|
from . import FairseqDataset, data_utils
|
|
|
|
|
|
def collate(
    samples,
    pad_idx,
    eos_idx,
    vocab,
    left_pad_source=False,
    left_pad_target=False,
    input_feeding=True,
    pad_to_length=None,
):
    """Merge a list of denoising samples into a padded mini-batch.

    Rows of the batch are sorted by descending source length; the applied
    permutation is returned as ``sort_order``.

    Args:
        samples (List[dict]): items with ``"id"``, ``"source"`` and optionally
            ``"target"`` (presumably 1-D token tensors).
        pad_idx (int): padding index passed to ``collate_tokens``.
        eos_idx (int): unused; ``collate_tokens`` is called with
            ``eos_idx=None``.
        vocab: unused; kept for interface compatibility.
        left_pad_source (bool): left-pad the sources.
        left_pad_target (bool): left-pad the targets and
            ``prev_output_tokens``.
        input_feeding (bool): must be ``True``; a shifted copy of the target
            is added as ``net_input["prev_output_tokens"]``.
        pad_to_length (dict, optional): per-key (``"source"``/``"target"``)
            padding lengths forwarded to ``collate_tokens``.

    Returns:
        dict: the mini-batch, or ``{}`` when ``samples`` is empty.
    """
    assert input_feeding
    if not samples:
        return {}

    def merge(key, left_pad, move_eos_to_beginning=False):
        # Pad/stack field ``key`` of every sample into one 2-D tensor.
        return data_utils.collate_tokens(
            [s[key] for s in samples],
            pad_idx,
            eos_idx=None,  # use eos_idx of each sample instead of vocab.eos()
            left_pad=left_pad,
            move_eos_to_beginning=move_eos_to_beginning,
            pad_to_length=pad_to_length[key] if pad_to_length is not None else None,
        )

    ids = torch.LongTensor([s["id"] for s in samples])
    src_tokens = merge("source", left_pad=left_pad_source)
    # sort by descending source length
    src_lengths = torch.LongTensor([s["source"].numel() for s in samples])
    src_lengths, sort_order = src_lengths.sort(descending=True)
    ids = ids.index_select(0, sort_order)
    src_tokens = src_tokens.index_select(0, sort_order)

    prev_output_tokens = None
    target = None
    if samples[0].get("target", None) is not None:
        target = merge("target", left_pad=left_pad_target)
        target = target.index_select(0, sort_order)
        ntokens = sum(len(s["target"]) for s in samples)

        if input_feeding:
            # we create a shifted version of targets for feeding the
            # previous output token(s) into the next decoder step
            prev_output_tokens = merge(
                "target",
                left_pad=left_pad_target,
                move_eos_to_beginning=True,
            )
            prev_output_tokens = prev_output_tokens.index_select(0, sort_order)
    else:
        ntokens = sum(len(s["source"]) for s in samples)

    batch = {
        "id": ids,
        "ntokens": ntokens,
        "net_input": {
            "src_tokens": src_tokens,
            "src_lengths": src_lengths,
        },
        "target": target,
        # Number of sentences (rows) in the batch, not the token count of
        # the first sample.
        "nsentences": len(samples),
        "sort_order": sort_order,
    }
    if prev_output_tokens is not None:
        batch["net_input"]["prev_output_tokens"] = prev_output_tokens

    return batch
|
|
|
|
|
|
class DenoisingDataset(FairseqDataset):
    """
    A wrapper around TokenBlockDataset for BART dataset.

    Each item is noised on the fly (sentence permutation, word/span masking,
    random insertion, rotation) to build ``source``; ``target`` is a clone of
    the original tokens.

    Args:
        dataset (TokenBlockDataset): dataset to wrap
        sizes (List[int] or np.ndarray): sentence lengths
        vocab (~fairseq.data.Dictionary): vocabulary
        mask_idx (int): dictionary index used for masked token
        mask_whole_words: only mask whole words. This should be a byte mask
            over vocab indices, indicating whether it is the beginning of a
            word. We will extend any mask to encompass the whole word.
            If ``None``, every token is treated as a word start.
        shuffle (bool): shuffle the elements before batching.
        seed: Seed for random number generator for reproducibility.
        mask (float): fraction of word starts to mask.
        mask_random (float): probability that a masked (or inserted) token
            becomes a random vocabulary id instead of ``mask_idx``.
        insert (float): number of inserted noise tokens relative to the
            sequence length.
        rotate (float): probability of rotating the document.
        permute_sentences (float): fraction of sentences to shuffle.
        bpe (str): if ``"gpt2"``, sentences are split on the symbol ``"13"``;
            otherwise on EOS.
        replace_length (int): ``0`` deletes a masked span, ``1`` keeps one
            mask (or random) token for it, ``-1`` masks every token of it.
        mask_length (str): ``"subword"``, ``"word"`` or ``"span-poisson"``;
            the latter samples span lengths from a truncated Poisson
            distribution (length-0 spans become insertions).
        poisson_lambda (float): rate of the Poisson span-length distribution.
        eos (int, optional): end-of-sentence index; defaults to ``vocab.eos()``.
        item_transform_func (callable, optional): called as
            ``source, target = item_transform_func(source, target)`` after
            noising.
    """

    def __init__(
        self,
        dataset,
        sizes,
        vocab,
        mask_idx,
        mask_whole_words,
        shuffle,
        seed,
        mask,
        mask_random,
        insert,
        rotate,
        permute_sentences,
        bpe,
        replace_length,
        mask_length,
        poisson_lambda,
        eos=None,
        item_transform_func=None,
    ):
        self.dataset = dataset

        self.sizes = sizes

        self.vocab = vocab
        self.shuffle = shuffle
        self.seed = seed
        self.mask_idx = mask_idx
        self.mask_whole_word = mask_whole_words
        self.mask_ratio = mask
        self.random_ratio = mask_random
        self.insert_ratio = insert
        self.rotate_ratio = rotate
        self.permute_sentence_ratio = permute_sentences
        self.eos = eos if eos is not None else vocab.eos()
        self.item_transform_func = item_transform_func

        # Token that terminates a sentence for permute_sentences().
        if bpe == "gpt2":
            # NOTE(review): "13" is assumed to be the GPT-2 BPE symbol for "."
            self.full_stop_index = self.vocab.index("13")
        else:
            self.full_stop_index = self.vocab.eos()

        self.replace_length = replace_length
        if self.replace_length not in [-1, 0, 1]:
            raise ValueError(f"invalid arg: replace_length={self.replace_length}")
        if mask_length not in ["subword", "word", "span-poisson"]:
            raise ValueError(f"invalid arg: mask-length={mask_length}")
        if mask_length == "subword" and replace_length not in [0, 1]:
            raise ValueError("if using subwords, use replace-length=1 or 0")

        self.mask_span_distribution = None
        if mask_length == "span-poisson":
            # Truncated Poisson(lambda) pmf over span lengths k = 0..127,
            # stopping once a term falls below 1e-7. Categorical normalises
            # the (slightly sub-unit) probabilities.
            _lambda = poisson_lambda

            lambda_to_the_k = 1
            e_to_the_minus_lambda = math.exp(-_lambda)
            k_factorial = 1
            ps = []
            for k in range(0, 128):
                ps.append(e_to_the_minus_lambda * lambda_to_the_k / k_factorial)
                lambda_to_the_k *= _lambda
                k_factorial *= k + 1
                if ps[-1] < 0.0000001:
                    break
            ps = torch.FloatTensor(ps)
            self.mask_span_distribution = torch.distributions.Categorical(ps)

        self.epoch = 0

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        return True  # only the noise changes, not item sizes

    def set_epoch(self, epoch, **unused):
        """Record the epoch; it is part of the per-item NumPy seed."""
        self.epoch = epoch

    def __getitem__(self, index):
        """Return ``{"id", "source", "target"}`` for item ``index``.

        NumPy randomness is seeded from ``(seed, epoch, index)``.
        NOTE(review): the noise functions also draw from torch's RNG
        (randperm/randint/uniform_/Categorical.sample), which
        ``numpy_seed`` presumably does not seed — confirm if exact
        reproducibility matters.
        """
        with data_utils.numpy_seed(self.seed, self.epoch, index):
            tokens = self.dataset[index]
            assert tokens[-1] == self.eos
            # NOTE(review): ``source`` aliases the tensor returned by the
            # wrapped dataset; add_whole_word_mask writes into it in place
            # (when no sentence permutation ran first), which could modify
            # shared storage if the dataset returns cached tensors.
            source, target = tokens, tokens.clone()

            if self.permute_sentence_ratio > 0.0:
                source = self.permute_sentences(source, self.permute_sentence_ratio)

            if self.mask_ratio > 0:
                source = self.add_whole_word_mask(source, self.mask_ratio)

            if self.insert_ratio > 0:
                source = self.add_insertion_noise(source, self.insert_ratio)

            if self.rotate_ratio > 0.0 and np.random.random() < self.rotate_ratio:
                source = self.add_rolling_noise(source)
        # there can be additional changes to make:
        if self.item_transform_func is not None:
            source, target = self.item_transform_func(source, target)

        assert (source >= 0).all()
        assert (source[1:-1] >= 1).all()
        # Valid vocabulary ids are 0 .. len(vocab) - 1.
        assert (source < len(self.vocab)).all()
        assert source[0] == self.vocab.bos()
        assert source[-1] == self.eos
        return {
            "id": index,
            "source": source,
            "target": target,
        }

    def __len__(self):
        return len(self.dataset)

    def permute_sentences(self, source, p=1.0):
        """Shuffle a fraction ``p`` of the sentences of ``source``.

        Sentences end at ``self.full_stop_index``; position ``-2`` is forced
        to count as a full stop so the last span is a sentence. The first
        token (``<bos>``) and the last token stay in place. Returns a new
        tensor.
        """
        full_stops = source == self.full_stop_index
        # Pretend it ends with a full stop so last span is a sentence
        full_stops[-2] = 1

        # Tokens that are full stops, where the previous token is not
        sentence_ends = (full_stops[1:] & ~full_stops[:-1]).nonzero(as_tuple=False) + 2
        result = source.clone()

        num_sentences = sentence_ends.size(0)
        num_to_permute = math.ceil((num_sentences * 2 * p) / 2.0)
        substitutions = torch.randperm(num_sentences)[:num_to_permute]
        ordering = torch.arange(0, num_sentences)
        # Permute only as many positions as were actually selected (guards
        # against num_to_permute > num_sentences when p > 1).
        ordering[substitutions] = substitutions[torch.randperm(substitutions.size(0))]

        # Ignore <bos> at start
        index = 1
        for i in ordering:
            sentence = source[(sentence_ends[i - 1] if i > 0 else 1) : sentence_ends[i]]
            result[index : index + sentence.size(0)] = sentence
            index += sentence.size(0)
        return result

    def word_starts(self, source):
        """Return a per-token word-start indicator for ``source``.

        Uses ``mask_whole_word`` when given, otherwise marks every token.
        The first and last positions are always cleared.
        """
        if self.mask_whole_word is not None:
            is_word_start = self.mask_whole_word.gather(0, source)
        else:
            is_word_start = torch.ones(source.size())
        is_word_start[0] = 0
        is_word_start[-1] = 0
        return is_word_start

    def _replace_with_mask(self, source, indices, mask_random):
        """Write ``mask_idx`` at ``indices`` of ``source`` (in place), then
        overwrite the entries selected by ``mask_random`` with random ids in
        ``[1, len(vocab))``."""
        source[indices] = self.mask_idx
        source[indices[mask_random]] = torch.randint(
            1, len(self.vocab), size=(mask_random.sum(),)
        )

    def add_whole_word_mask(self, source, p):
        """Mask about ``ceil(p * #word_starts)`` words or spans of ``source``.

        Each selected span starts at a word start and is extended over the
        rest of its word(s). ``replace_length`` decides whether span tokens
        are deleted or replaced (see class docstring). With span-poisson
        masking, zero-length spans are turned into insertions. May modify
        ``source`` in place; returns the resulting tensor.
        """
        is_word_start = self.word_starts(source)
        num_to_mask = int(math.ceil(is_word_start.float().sum() * p))
        num_inserts = 0
        if num_to_mask == 0:
            return source

        if self.mask_span_distribution is not None:
            lengths = self.mask_span_distribution.sample(sample_shape=(num_to_mask,))

            # Make sure we have enough to mask
            cum_length = torch.cumsum(lengths, 0)
            while cum_length[-1] < num_to_mask:
                lengths = torch.cat(
                    [
                        lengths,
                        self.mask_span_distribution.sample(sample_shape=(num_to_mask,)),
                    ],
                    dim=0,
                )
                cum_length = torch.cumsum(lengths, 0)

            # Trim to masking budget
            i = 0
            while cum_length[i] < num_to_mask:
                i += 1
            lengths[i] = num_to_mask - (0 if i == 0 else cum_length[i - 1])
            num_to_mask = i + 1
            lengths = lengths[:num_to_mask]

            # Handle 0-length mask (inserts) separately
            lengths = lengths[lengths > 0]
            num_inserts = num_to_mask - lengths.size(0)
            num_to_mask -= num_inserts
            if num_to_mask == 0:
                return self.add_insertion_noise(source, num_inserts / source.size(0))

            assert (lengths > 0).all()
        else:
            lengths = torch.ones((num_to_mask,)).long()
        assert is_word_start[-1] == 0
        word_starts = is_word_start.nonzero(as_tuple=False)
        indices = word_starts[
            torch.randperm(word_starts.size(0))[:num_to_mask]
        ].squeeze(1)
        mask_random = torch.FloatTensor(num_to_mask).uniform_() < self.random_ratio

        source_length = source.size(0)
        assert source_length - 1 not in indices
        to_keep = torch.ones(source_length, dtype=torch.bool)
        # acts as a long length, so spans don't go over the end of doc
        is_word_start[-1] = 255
        if self.replace_length == 0:
            to_keep[indices] = 0
        else:
            # keep index, but replace it with [MASK]
            self._replace_with_mask(source, indices, mask_random)

        if self.mask_span_distribution is not None:
            assert len(lengths.size()) == 1
            assert lengths.size() == indices.size()
            lengths -= 1
            # Walk every span forward one token at a time until its length
            # budget (counted in word starts) is used up.
            while indices.size(0) > 0:
                assert lengths.size() == indices.size()
                lengths -= is_word_start[indices + 1].long()
                uncompleted = lengths >= 0
                indices = indices[uncompleted] + 1
                mask_random = mask_random[uncompleted]
                lengths = lengths[uncompleted]
                if self.replace_length != -1:
                    # delete token
                    to_keep[indices] = 0
                else:
                    # keep index, but replace it with [MASK]
                    self._replace_with_mask(source, indices, mask_random)
        else:
            # A bit faster when all lengths are 1
            while indices.size(0) > 0:
                uncompleted = is_word_start[indices + 1] == 0
                indices = indices[uncompleted] + 1
                mask_random = mask_random[uncompleted]
                if self.replace_length != -1:
                    # delete token
                    to_keep[indices] = 0
                else:
                    # keep index, but replace it with [MASK]
                    self._replace_with_mask(source, indices, mask_random)

                assert source_length - 1 not in indices

        source = source[to_keep]

        if num_inserts > 0:
            source = self.add_insertion_noise(source, num_inserts / source.size(0))

        return source

    def add_permuted_noise(self, tokens, p):
        """Shuffle about ``ceil(p * len(tokens))`` interior tokens in place.

        The first and last tokens are never moved. Returns ``tokens``.
        """
        num_words = len(tokens)
        num_to_permute = math.ceil(((num_words * 2) * p) / 2.0)
        substitutions = torch.randperm(max(num_words - 2, 0))[:num_to_permute] + 1
        # Permute only the positions actually selected; num_to_permute can
        # exceed the num_words - 2 interior candidates (e.g. p == 1.0).
        tokens[substitutions] = tokens[
            substitutions[torch.randperm(substitutions.size(0))]
        ]
        return tokens

    def add_rolling_noise(self, tokens):
        """Rotate the tokens between the first and last position by a random
        offset drawn from NumPy's RNG; returns a new tensor."""
        offset = np.random.randint(1, max(1, tokens.size(-1) - 1) + 1)
        tokens = torch.cat(
            (tokens[0:1], tokens[offset:-1], tokens[1:offset], tokens[-1:]),
            dim=0,
        )
        return tokens

    def add_insertion_noise(self, tokens, p):
        """Insert ``ceil(len(tokens) * p)`` noise tokens at random interior
        positions and return the new tensor.

        ``ceil(n * random_ratio)`` of the inserted tokens are random ids in
        ``[1, len(vocab))``; the rest are ``mask_idx``.
        """
        if p == 0.0:
            return tokens

        num_tokens = len(tokens)
        n = int(math.ceil(num_tokens * p))

        noise_indices = torch.randperm(num_tokens + n - 2)[:n] + 1
        noise_mask = torch.zeros(size=(num_tokens + n,), dtype=torch.bool)
        noise_mask[noise_indices] = 1
        result = torch.LongTensor(n + len(tokens)).fill_(-1)

        num_random = int(math.ceil(n * self.random_ratio))
        result[noise_indices[num_random:]] = self.mask_idx
        result[noise_indices[:num_random]] = torch.randint(
            low=1, high=len(self.vocab), size=(num_random,)
        )

        result[~noise_mask] = tokens

        assert (result >= 0).all()
        return result

    def collater(self, samples, pad_to_length=None):
        """Merge a list of samples to form a mini-batch.
        Args:
            samples (List[dict]): samples to collate
            pad_to_length (dict, optional): per-key padding lengths
        Returns:
            dict: a mini-batch of data
        """
        return collate(
            samples, self.vocab.pad(), self.eos, self.vocab, pad_to_length=pad_to_length
        )

    def num_tokens(self, index):
        """Return the number of tokens in a sample. This value is used to
        enforce ``--max-tokens`` during batching."""
        return self.sizes[index]

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used when
        filtering a dataset with ``--max-positions``."""
        return self.sizes[index]

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order.

        Indices are sorted by size with a stable sort; when ``shuffle`` is set,
        equal sizes keep the order of a random permutation.
        """
        if self.shuffle:
            indices = np.random.permutation(len(self))
        else:
            indices = np.arange(len(self))
        # np.asarray lets ``sizes`` be a plain list as well as an ndarray.
        return indices[np.argsort(np.asarray(self.sizes)[indices], kind="mergesort")]

    def prefetch(self, indices):
        """Delegate prefetching to the single wrapped dataset."""
        self.dataset.prefetch(indices)

    @property
    def supports_prefetch(self):
        """Whether the wrapped dataset supports prefetching."""
        return (
            hasattr(self.dataset, "supports_prefetch")
            and self.dataset.supports_prefetch
        )
|