# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from collections import OrderedDict

import torch
from torch.utils.data.dataloader import default_collate

from . import FairseqDataset


def _flatten(dico, prefix=None):
    """Flatten a nested dictionary."""
    new_dico = OrderedDict()
    if isinstance(dico, dict):
        prefix = prefix + "." if prefix is not None else ""
        for k, v in dico.items():
            if v is None:
                continue
            new_dico.update(_flatten(v, prefix + k))
    elif isinstance(dico, list):
        for i, v in enumerate(dico):
            new_dico.update(_flatten(v, prefix + ".[" + str(i) + "]"))
    else:
        new_dico = OrderedDict({prefix: dico})
    return new_dico
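
# A minimal sketch of what ``_flatten`` produces; ``ds_a`` and ``ds_b`` are
# illustrative stand-ins for leaf datasets (``None`` leaves are dropped):
#
#   _flatten({"a": {"b": ds_a, "c": None}, "d": [ds_b]})
#   # -> OrderedDict([("a.b", ds_a), ("d.[0]", ds_b)])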


def _unflatten(dico):
    """Unflatten a flattened dictionary into a nested dictionary."""
    new_dico = OrderedDict()
    for full_k, v in dico.items():
        full_k = full_k.split(".")
        node = new_dico
        for k in full_k[:-1]:
            if k.startswith("[") and k.endswith("]"):
                k = int(k[1:-1])
            if k not in node:
                node[k] = OrderedDict()
            node = node[k]
        node[full_k[-1]] = v
    return new_dico
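
# The inverse mapping: keys are split on "." and bracketed components along
# the path become integer indices. Illustrative values:
#
#   _unflatten(OrderedDict([("a.b", 1), ("a.c", 2)]))
#   # -> OrderedDict([("a", OrderedDict([("b", 1), ("c", 2)]))])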


class NestedDictionaryDataset(FairseqDataset):
    def __init__(self, defn, sizes=None):
        super().__init__()
        self.defn = _flatten(defn)
        self.sizes = [sizes] if not isinstance(sizes, (list, tuple)) else sizes

        first = None
        for v in self.defn.values():
            if not isinstance(
                v,
                (
                    FairseqDataset,
                    torch.utils.data.Dataset,
                ),
            ):
                raise ValueError("Expected Dataset but found: {}".format(v.__class__))
            first = first or v
            if len(v) > 0:
                assert len(v) == len(first), "dataset lengths must match"

        self._len = len(first)

    def __getitem__(self, index):
        return OrderedDict((k, ds[index]) for k, ds in self.defn.items())

    def __len__(self):
        return self._len

    def collater(self, samples):
        """Merge a list of samples to form a mini-batch.

        Args:
            samples (List[dict]): samples to collate

        Returns:
            dict: a mini-batch suitable for forwarding with a Model
        """
        if len(samples) == 0:
            return {}
        sample = OrderedDict()
        for k, ds in self.defn.items():
            try:
                sample[k] = ds.collater([s[k] for s in samples])
            except NotImplementedError:
                sample[k] = default_collate([s[k] for s in samples])
        return _unflatten(sample)
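
    # Each leaf dataset collates its own column; leaves whose ``collater``
    # raises NotImplementedError fall back to ``default_collate``. The flat
    # keys are then re-nested, e.g. {"net_input.src_tokens": t} becomes
    # {"net_input": {"src_tokens": t}}.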

    def num_tokens(self, index):
        """Return the number of tokens in a sample. This value is used to
        enforce ``--max-tokens`` during batching."""
        return max(s[index] for s in self.sizes)

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used when
        filtering a dataset with ``--max-positions``."""
        if len(self.sizes) == 1:
            return self.sizes[0][index]
        else:
            return tuple(s[index] for s in self.sizes)
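
    # Illustrative: with ``sizes=[src_sizes, tgt_sizes]``, ``num_tokens(i)``
    # is ``max(src_sizes[i], tgt_sizes[i])`` and ``size(i)`` is the pair
    # ``(src_sizes[i], tgt_sizes[i])``.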

    @property
    def supports_prefetch(self):
        """Whether this dataset supports prefetching."""
        return any(ds.supports_prefetch for ds in self.defn.values())

    def prefetch(self, indices):
        """Prefetch the data required for this epoch."""
        for ds in self.defn.values():
            if getattr(ds, "supports_prefetch", False):
                ds.prefetch(indices)

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        return all(ds.can_reuse_epoch_itr_across_epochs for ds in self.defn.values())

    def set_epoch(self, epoch):
        super().set_epoch(epoch)
        for ds in self.defn.values():
            ds.set_epoch(epoch)
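

# Usage sketch: ``id_dataset`` and ``token_dataset`` are illustrative
# stand-ins for any equal-length FairseqDataset instances:
#
#   dataset = NestedDictionaryDataset(
#       {
#           "id": id_dataset,
#           "net_input": {"src_tokens": token_dataset},
#       },
#       sizes=[token_dataset.sizes],
#   )
#   batch = dataset.collater([dataset[i] for i in range(8)])
#   # batch["net_input"]["src_tokens"] holds the collated tokens.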