mirror of
https://github.com/SillyTavern/SillyTavern-Extras.git
synced 2026-03-10 22:10:22 +00:00
254 lines
8.6 KiB
Python
254 lines
8.6 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
#
|
|
# This source code is licensed under the MIT license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
import numpy as np
|
|
import torch
|
|
|
|
from . import FairseqDataset, data_utils
|
|
|
|
|
|
def collate(samples, pad_idx, eos_idx, fixed_pad_length=None, pad_to_bsz=None):
    """Merge a list of samples into a single padded mini-batch dictionary.

    Args:
        samples (List[dict]): items with the keys ``"id"``, ``"source"`` and
            ``"target"``. ``"target"`` may be ``None``, a single sequence or
            a list of sequences; the first sample decides which case applies
            to the whole batch.
        pad_idx: padding index forwarded to ``data_utils.collate_tokens``.
        eos_idx: end-of-sentence index forwarded to
            ``data_utils.collate_tokens``.
        fixed_pad_length (optional): forwarded as ``pad_to_length``.
        pad_to_bsz (optional): forwarded as ``pad_to_bsz``.

    Returns:
        dict: ``{}`` when *samples* is empty; otherwise a batch with the keys
        ``id``, ``nsentences``, ``ntokens``, ``net_input`` (holding
        ``src_tokens`` and ``src_lengths``) and ``target``.
    """
    if len(samples) == 0:
        return {}

    def pad_batch(sequences):
        # Single place where the padding options are applied, so the list and
        # non-list code paths cannot drift apart.
        return data_utils.collate_tokens(
            sequences,
            pad_idx,
            eos_idx,
            left_pad=False,
            pad_to_length=fixed_pad_length,
            pad_to_bsz=pad_to_bsz,
        )

    def merge(key, is_list=False):
        if is_list:
            # One padded batch per position of the per-sample list.
            return [
                pad_batch([s[key][i] for s in samples])
                for i in range(len(samples[0][key]))
            ]
        return pad_batch([s[key] for s in samples])

    src_tokens = merge("source")
    if samples[0]["target"] is not None:
        is_target_list = isinstance(samples[0]["target"], list)
        target = merge("target", is_target_list)
    else:
        # Without explicit targets the collated source doubles as the target.
        target = src_tokens

    return {
        "id": torch.LongTensor([s["id"] for s in samples]),
        "nsentences": len(samples),
        "ntokens": sum(len(s["source"]) for s in samples),
        "net_input": {
            "src_tokens": src_tokens,
            "src_lengths": torch.LongTensor([s["source"].numel() for s in samples]),
        },
        "target": target,
    }
|
|
|
|
|
|
class MonolingualDataset(FairseqDataset):
    """
    A wrapper around torch.utils.data.Dataset for monolingual data.

    Args:
        dataset (torch.utils.data.Dataset): dataset to wrap
        sizes (List[int]): sentence lengths
        src_vocab (~fairseq.data.Dictionary): source vocabulary
        tgt_vocab (~fairseq.data.Dictionary, optional): target vocabulary
            (default: same as *src_vocab*).
        add_eos_for_other_targets (bool, optional): when 'self' or 'past' is
            among the targets and the source does not already end in eos,
            append eos to the source (default: False).
        shuffle (bool, optional): shuffle the elements before batching
            (default: False).
        targets (List[str], optional): any of 'self', 'future', 'past'.
            ``None`` or an empty list means items carry no target
            (default: None).
        add_bos_token (bool, optional): prepend bos to the source and to the
            target(s) (default: False).
        fixed_pad_length (optional): passed to :func:`collate` by
            :meth:`collater` (default: None).
        pad_to_bsz (optional): passed to :func:`collate` by
            :meth:`collater` (default: None).
        src_lang_idx (optional): stored on the instance; not used by any
            method of this class (default: None).
        tgt_lang_idx (optional): stored on the instance; not used by any
            method of this class (default: None).
    """

    def __init__(
        self,
        dataset,
        sizes,
        src_vocab,
        tgt_vocab=None,
        add_eos_for_other_targets=False,
        shuffle=False,
        targets=None,
        add_bos_token=False,
        fixed_pad_length=None,
        pad_to_bsz=None,
        src_lang_idx=None,
        tgt_lang_idx=None,
    ):
        self.dataset = dataset
        self.sizes = np.array(sizes)
        self.vocab = src_vocab
        self.tgt_vocab = tgt_vocab or src_vocab
        self.add_eos_for_other_targets = add_eos_for_other_targets
        self.shuffle = shuffle
        self.add_bos_token = add_bos_token
        self.fixed_pad_length = fixed_pad_length
        self.pad_to_bsz = pad_to_bsz
        self.src_lang_idx = src_lang_idx
        self.tgt_lang_idx = tgt_lang_idx

        assert targets is None or all(
            t in {"self", "future", "past"} for t in targets
        ), "targets must be none or one of 'self', 'future', 'past'"
        # An empty target list is treated exactly like "no targets".
        if targets is not None and len(targets) == 0:
            targets = None
        self.targets = targets

    def __getitem__(self, index):
        """Return ``{"id", "source", "target"}`` for the example at *index*."""
        if self.targets is not None:
            # *future_target* is the original sentence
            # *source* is shifted right by 1 (maybe left-padded with eos)
            # *past_target* is shifted right by 2 (left-padded as needed)
            #
            # Left-to-right language models should condition on *source* and
            # predict *future_target*.
            # Right-to-left language models should condition on *source* and
            # predict *past_target*.
            source, future_target, past_target = self.dataset[index]
            source, target = self._make_source_target(
                source, future_target, past_target
            )
        else:
            source = self.dataset[index]
            target = None
        source, target = self._maybe_add_bos(source, target)
        return {"id": index, "source": source, "target": target}

    def __len__(self):
        return len(self.dataset)

    def _make_source_target(self, source, future_target, past_target):
        """Build the ``(source, target)`` pair requested by ``self.targets``.

        The target is a single tensor when exactly one target type is
        configured and a list of tensors (in ``self.targets`` order)
        otherwise. The result is passed through :meth:`_filter_vocab`.
        """
        if self.targets is not None:
            target = []

            if (
                self.add_eos_for_other_targets
                and (("self" in self.targets) or ("past" in self.targets))
                and source[-1] != self.vocab.eos()
            ):
                # append eos at the end of source
                source = torch.cat([source, source.new([self.vocab.eos()])])

                # The source grew by one token, so the other targets are
                # extended by one token as well.
                if "future" in self.targets:
                    future_target = torch.cat(
                        [future_target, future_target.new([self.vocab.pad()])]
                    )
                if "past" in self.targets:
                    # first token is before the start of sentence which is only used in "none" break mode when
                    # add_eos_for_other_targets is False
                    past_target = torch.cat(
                        [
                            past_target.new([self.vocab.pad()]),
                            past_target[1:],
                            source[-2, None],
                        ]
                    )

            for t in self.targets:
                if t == "self":
                    target.append(source)
                elif t == "future":
                    target.append(future_target)
                elif t == "past":
                    target.append(past_target)
                else:
                    raise Exception("invalid target " + t)

            if len(target) == 1:
                target = target[0]
        else:
            target = future_target

        return source, self._filter_vocab(target)

    def _maybe_add_bos(self, source, target):
        """Prepend bos to *source* and *target* when ``add_bos_token`` is set.

        *target* may be ``None``, a single tensor, or a list of tensors (the
        multi-target case); each tensor gets its own bos prefix.
        """
        if self.add_bos_token:
            source = torch.cat([source.new([self.vocab.bos()]), source])
            if target is not None:
                tgt_bos = self.tgt_vocab.bos()
                if isinstance(target, list):
                    # Multiple targets arrive as a list; a list has no
                    # ``.new``, so handle every element individually.
                    target = [torch.cat([t.new([tgt_bos]), t]) for t in target]
                else:
                    target = torch.cat([target.new([tgt_bos]), target])
        return source, target

    def num_tokens_vec(self, indices):
        """Return the number of tokens for a set of positions defined by indices.
        This value is used to enforce ``--max-tokens`` during batching."""
        return self.sizes[indices]

    def _filter_vocab(self, target):
        """Replace ids outside the target vocabulary with the target unk id.

        Only active when the source and target vocabularies differ in size.
        Accepts a single tensor or a list of tensors and returns the same
        shape of container.
        """
        if len(self.tgt_vocab) != len(self.vocab):

            def _filter(tokens):
                mask = tokens.ge(len(self.tgt_vocab))
                if mask.any():
                    # Out-of-place on purpose: with a "self" target ``tokens``
                    # is the very same tensor object as the source, so an
                    # in-place write would also overwrite the model input.
                    tokens = tokens.masked_fill(mask, self.tgt_vocab.unk())
                return tokens

            if isinstance(target, list):
                return [_filter(t) for t in target]
            return _filter(target)
        return target

    def collater(self, samples):
        """Merge a list of samples to form a mini-batch.

        Args:
            samples (List[dict]): samples to collate

        Returns:
            dict: a mini-batch with the following keys:

                - `id` (LongTensor): example IDs in the original input order
                - `nsentences` (int): number of samples in the batch
                - `ntokens` (int): total number of tokens in the batch
                - `net_input` (dict): the input to the Model, containing keys:

                  - `src_tokens` (LongTensor): a padded 2D Tensor of tokens in
                    the source sentence of shape `(bsz, src_len)`. Padding will
                    appear on the right.
                  - `src_lengths` (LongTensor): number of elements of each
                    source sentence.

                - `target` (LongTensor): a padded 2D Tensor of tokens in the
                  target sentence of shape `(bsz, tgt_len)`. Padding will appear
                  on the right.
        """
        return collate(
            samples,
            self.vocab.pad(),
            self.vocab.eos(),
            self.fixed_pad_length,
            self.pad_to_bsz,
        )

    def num_tokens(self, index):
        """Return the number of tokens in a sample. This value is used to
        enforce ``--max-tokens`` during batching."""
        return self.sizes[index]

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used when
        filtering a dataset with ``--max-positions``."""
        return self.sizes[index]

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        if self.shuffle:
            order = [np.random.permutation(len(self))]
        else:
            order = [np.arange(len(self))]
        # np.lexsort treats the LAST key as the primary one: indices are
        # sorted by size, with ties broken by the (possibly shuffled) order.
        order.append(self.sizes)
        return np.lexsort(order)

    @property
    def supports_prefetch(self):
        return getattr(self.dataset, "supports_prefetch", False)

    def prefetch(self, indices):
        self.dataset.prefetch(indices)
|