# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch

from fairseq.data import data_utils


class WordNoising(object):
    """Generate a noisy version of a sentence, without changing words themselves."""

    def __init__(self, dictionary, bpe_cont_marker="@@", bpe_end_marker=None):
        self.dictionary = dictionary
        self.bpe_end = None
        if bpe_cont_marker:
            self.bpe_end = np.array(
                [
                    not self.dictionary[i].endswith(bpe_cont_marker)
                    for i in range(len(self.dictionary))
                ]
            )
        elif bpe_end_marker:
            self.bpe_end = np.array(
                [
                    self.dictionary[i].endswith(bpe_end_marker)
                    for i in range(len(self.dictionary))
                ]
            )

        self.get_word_idx = (
            self._get_bpe_word_idx if self.bpe_end is not None else self._get_token_idx
        )
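
    # Illustrative sketch (added for clarity, not part of the upstream code):
    # with bpe_cont_marker="@@" and dictionary entries ["how", "are", "y@@", "ou"],
    # `bpe_end` marks the tokens that END a word, i.e. those WITHOUT the "@@"
    # continuation marker:
    #
    #     token:    "how"  "are"  "y@@"  "ou"
    #     bpe_end:   True   True  False  True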

    def noising(self, x, lengths, noising_prob=0.0):
        raise NotImplementedError()

    def _get_bpe_word_idx(self, x):
        """
        Given a list of BPE tokens, for every index in the tokens list,
        return the index of the word grouping that it belongs to.
        For example, for input x corresponding to ["how", "are", "y@@", "ou"],
        return [[0], [1], [2], [2]].
        """
        # x: (T x B)
        bpe_end = self.bpe_end[x]

        if x.size(0) == 1 and x.size(1) == 1:
            # Special case when we only have one word in x. If x = [[N]],
            # bpe_end is a scalar (bool) instead of a 2-dim array of bools,
            # which makes the sum operation below fail.
            return np.array([[0]])

        # do a reduce front sum to generate word ids
        word_idx = bpe_end[::-1].cumsum(0)[::-1]
        word_idx = word_idx.max(0)[None, :] - word_idx
        return word_idx
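
    # Worked example of the reverse-cumsum trick above (illustrative): for a
    # single column ["how", "are", "y@@", "ou"],
    #
    #     bpe_end (as ints)        = [1, 1, 0, 1]
    #     reversed cumsum reversed = [3, 2, 1, 1]  # words from here to the end
    #     max(0) - that            = [0, 1, 2, 2]  # word index of each token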

    def _get_token_idx(self, x):
        """
        Extends the noising functions to non-BPE tokens, e.g. words or
        characters, by treating every token as a word of its own.
        """
        x = torch.t(x)
        word_idx = np.array([range(len(x_i)) for x_i in x])
        return np.transpose(word_idx)
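
    # Illustrative: for plain (non-BPE) tokens every token is its own word, so
    # a 3-token, single-sentence batch yields word_idx [[0], [1], [2]].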


class WordDropout(WordNoising):
    """Randomly drop input words. If blank_idx is not passed (default None),
    dropped words are removed; otherwise they are replaced by blank_idx."""

    def __init__(
        self,
        dictionary,
        default_dropout_prob=0.1,
        bpe_cont_marker="@@",
        bpe_end_marker=None,
    ):
        super().__init__(dictionary, bpe_cont_marker, bpe_end_marker)
        self.default_dropout_prob = default_dropout_prob

    def noising(self, x, lengths, dropout_prob=None, blank_idx=None):
        if dropout_prob is None:
            dropout_prob = self.default_dropout_prob
        # x: (T x B), lengths: B
        if dropout_prob == 0:
            return x, lengths

        assert 0 < dropout_prob < 1

        # be sure to drop entire words
        word_idx = self.get_word_idx(x)
        sentences = []
        modified_lengths = []
        for i in range(lengths.size(0)):
            # Since dropout probabilities need to apply over non-pad tokens,
            # it is not trivial to generate the keep mask without considering
            # input lengths; otherwise, this could be done outside the loop.

            # We want to drop whole words based on word_idx grouping
            num_words = max(word_idx[:, i]) + 1

            # ith example: [x0, x1, ..., eos, pad, ..., pad]
            # We should only generate keep probs for non-EOS tokens. Thus if the
            # input sentence ends in EOS, the last word idx is not included in
            # the dropout mask generation and we append True to always keep EOS.
            # Otherwise, just generate the dropout mask for all word idx
            # positions.
            has_eos = x[lengths[i] - 1, i] == self.dictionary.eos()
            if has_eos:
                keep = np.random.rand(num_words - 1) >= dropout_prob
                keep = np.append(keep, [True])  # keep EOS symbol
            else:
                keep = np.random.rand(num_words) >= dropout_prob

            words = x[: lengths[i], i].tolist()

            # TODO: speed up the following loop
            # drop words from the input according to keep
            new_s = [
                w if keep[word_idx[j, i]] else blank_idx for j, w in enumerate(words)
            ]
            new_s = [w for w in new_s if w is not None]
            # we need to have at least one word in the sentence (more than the
            # start / end sentence symbols)
            if len(new_s) <= 1:
                # insert at beginning in case the only token left is EOS
                # EOS should be at end of list.
                new_s.insert(0, words[np.random.randint(0, len(words))])
            assert len(new_s) >= 1 and (
                not has_eos  # Either don't have EOS at end or last token is EOS
                or (len(new_s) >= 2 and new_s[-1] == self.dictionary.eos())
            ), "New sentence is invalid."
            sentences.append(new_s)
            modified_lengths.append(len(new_s))
        # re-construct input
        modified_lengths = torch.LongTensor(modified_lengths)
        modified_x = torch.LongTensor(
            modified_lengths.max(), modified_lengths.size(0)
        ).fill_(self.dictionary.pad())
        for i in range(modified_lengths.size(0)):
            modified_x[: modified_lengths[i], i].copy_(torch.LongTensor(sentences[i]))

        return modified_x, modified_lengths
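
    # A minimal usage sketch (illustrative; `d` is assumed to be a fairseq
    # Dictionary and `x` a (T x B) LongTensor of token ids with lengths `lengths`):
    #
    #     dropout = WordDropout(d)
    #     # remove roughly 10% of the words outright
    #     x2, lengths2 = dropout.noising(x, lengths, dropout_prob=0.1)
    #     # or blank them with <unk> instead of removing them
    #     x3, lengths3 = dropout.noising(x, lengths, dropout_prob=0.1, blank_idx=d.unk())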


class WordShuffle(WordNoising):
    """Shuffle words by no more than k positions."""

    def __init__(
        self,
        dictionary,
        default_max_shuffle_distance=3,
        bpe_cont_marker="@@",
        bpe_end_marker=None,
    ):
        super().__init__(dictionary, bpe_cont_marker, bpe_end_marker)
        self.default_max_shuffle_distance = default_max_shuffle_distance

    def noising(self, x, lengths, max_shuffle_distance=None):
        if max_shuffle_distance is None:
            max_shuffle_distance = self.default_max_shuffle_distance
        # x: (T x B), lengths: B
        if max_shuffle_distance == 0:
            return x, lengths

        # max_shuffle_distance <= 1 would return the same sequence, since noise
        # drawn from [0, max_shuffle_distance) could never swap adjacent words
        assert max_shuffle_distance > 1

        # define noise word scores
        noise = np.random.uniform(
            0,
            max_shuffle_distance,
            size=(x.size(0), x.size(1)),
        )
        noise[0] = -1  # do not move start sentence symbol
        # be sure to shuffle entire words
        word_idx = self.get_word_idx(x)
        x2 = x.clone()
        for i in range(lengths.size(0)):
            length_no_eos = lengths[i]
            if x[lengths[i] - 1, i] == self.dictionary.eos():
                length_no_eos = lengths[i] - 1
            # generate a random permutation
            scores = word_idx[:length_no_eos, i] + noise[word_idx[:length_no_eos, i], i]
            # ensure no reordering inside a word
            scores += 1e-6 * np.arange(length_no_eos.item())
            permutation = scores.argsort()
            # shuffle words
            x2[:length_no_eos, i].copy_(
                x2[:length_no_eos, i][torch.from_numpy(permutation)]
            )
        return x2, lengths
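
    # Illustrative run of the scoring above with max_shuffle_distance=3: each
    # word's sort key is its position plus uniform noise in [0, 3), so argsort
    # moves a word by no more than 3 positions.
    #
    #     positions: [0,   1,   2,   3  ]
    #     noise:     [1.7, 0.2, 2.9, 0.1]
    #     scores:    [1.7, 1.2, 4.9, 3.1]  -> permutation [1, 0, 3, 2]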


class UnsupervisedMTNoising(WordNoising):
    """
    Implements the default configuration for noising in UnsupervisedMT
    (github.com/facebookresearch/UnsupervisedMT).
    """

    def __init__(
        self,
        dictionary,
        max_word_shuffle_distance,
        word_dropout_prob,
        word_blanking_prob,
        bpe_cont_marker="@@",
        bpe_end_marker=None,
    ):
        super().__init__(dictionary, bpe_cont_marker, bpe_end_marker)
        self.max_word_shuffle_distance = max_word_shuffle_distance
        self.word_dropout_prob = word_dropout_prob
        self.word_blanking_prob = word_blanking_prob

        self.word_dropout = WordDropout(
            dictionary=dictionary,
            bpe_cont_marker=bpe_cont_marker,
            bpe_end_marker=bpe_end_marker,
        )
        self.word_shuffle = WordShuffle(
            dictionary=dictionary,
            bpe_cont_marker=bpe_cont_marker,
            bpe_end_marker=bpe_end_marker,
        )

    def noising(self, x, lengths):
        # 1. Word Shuffle
        noisy_src_tokens, noisy_src_lengths = self.word_shuffle.noising(
            x=x,
            lengths=lengths,
            max_shuffle_distance=self.max_word_shuffle_distance,
        )
        # 2. Word Dropout
        noisy_src_tokens, noisy_src_lengths = self.word_dropout.noising(
            x=noisy_src_tokens,
            lengths=noisy_src_lengths,
            dropout_prob=self.word_dropout_prob,
        )
        # 3. Word Blanking
        noisy_src_tokens, noisy_src_lengths = self.word_dropout.noising(
            x=noisy_src_tokens,
            lengths=noisy_src_lengths,
            dropout_prob=self.word_blanking_prob,
            blank_idx=self.dictionary.unk(),
        )

        return noisy_src_tokens
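
    # Sketch of the composed pipeline (illustrative; `d`, `x`, `lengths` as in
    # the earlier sketches, parameter values chosen only as an example):
    #
    #     noiser = UnsupervisedMTNoising(
    #         dictionary=d,
    #         max_word_shuffle_distance=3,
    #         word_dropout_prob=0.1,
    #         word_blanking_prob=0.1,
    #     )
    #     noisy_x = noiser.noising(x, lengths)  # (T x B) tensor of noisy ids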


class NoisingDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        src_dataset,
        src_dict,
        seed,
        noiser=None,
        noising_class=UnsupervisedMTNoising,
        **kwargs
    ):
        """
        Wrap a :class:`~torch.utils.data.Dataset` and apply noise to the
        samples based on the supplied noising configuration.

        Args:
            src_dataset (~torch.utils.data.Dataset): dataset to wrap; used to
                build self.src_dataset, a LanguagePairDataset with the wrapped
                dataset as the source and None as the target. It should NOT
                have padding, so that src_lengths are accurately calculated by
                the language_pair_dataset collate function. We use
                language_pair_dataset here to encapsulate the tgt_dataset so
                that we can re-use the LanguagePairDataset collater to format
                the batches in the structure that SequenceGenerator expects.
            src_dict (~fairseq.data.Dictionary): source dictionary
            seed (int): seed to use when generating random noise
            noiser (WordNoising): a pre-initialized :class:`WordNoising`
                instance. If this is None, a new instance will be created using
                *noising_class* and *kwargs*.
            noising_class (class, optional): class to use to initialize a
                default :class:`WordNoising` instance.
            kwargs (dict, optional): arguments to initialize the default
                :class:`WordNoising` instance given by *noiser*.
        """
        self.src_dataset = src_dataset
        self.src_dict = src_dict
        self.seed = seed
        self.noiser = (
            noiser
            if noiser is not None
            else noising_class(
                dictionary=src_dict,
                **kwargs,
            )
        )
        self.sizes = src_dataset.sizes

    def __getitem__(self, index):
        """
        Returns a single noisy sample. Multiple samples are fed to the collater
        to create a noising dataset batch.
        """
        src_tokens = self.src_dataset[index]
        src_lengths = torch.LongTensor([len(src_tokens)])
        src_tokens = src_tokens.unsqueeze(0)

        # Transpose src tokens to fit expected shape of x in noising function
        # (batch size, sequence length) -> (sequence length, batch size)
        src_tokens_t = torch.t(src_tokens)

        with data_utils.numpy_seed(self.seed + index):
            noisy_src_tokens = self.noiser.noising(src_tokens_t, src_lengths)

        # Transpose back to expected src_tokens format
        # (sequence length, 1) -> (1, sequence length)
        noisy_src_tokens = torch.t(noisy_src_tokens)
        return noisy_src_tokens[0]
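
    # Typical usage (illustrative; `dataset` is assumed to be a fairseq token
    # dataset exposing `.sizes`, and `d` its Dictionary; extra kwargs are
    # forwarded to UnsupervisedMTNoising):
    #
    #     noising_dataset = NoisingDataset(
    #         src_dataset=dataset,
    #         src_dict=d,
    #         seed=42,
    #         max_word_shuffle_distance=3,
    #         word_dropout_prob=0.1,
    #         word_blanking_prob=0.1,
    #     )
    #     noisy_sample = noising_dataset[0]  # deterministic given seed + index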

    def __len__(self):
        """
        The length of the noising dataset is the length of src.
        """
        return len(self.src_dataset)

    @property
    def supports_prefetch(self):
        return self.src_dataset.supports_prefetch

    def prefetch(self, indices):
        if self.src_dataset.supports_prefetch:
            self.src_dataset.prefetch(indices)