mirror of
https://github.com/SillyTavern/SillyTavern-Extras.git
synced 2026-03-11 14:30:03 +00:00
357 lines
12 KiB
Python
357 lines
12 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
#
|
|
# This source code is licensed under the MIT license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
import itertools
|
|
import logging
|
|
import os
|
|
import sys
|
|
from typing import Any, List, Optional, Union
|
|
|
|
import numpy as np
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from fairseq.data import data_utils
|
|
from fairseq.data.fairseq_dataset import FairseqDataset
|
|
from fairseq.data.audio.audio_utils import (
|
|
parse_path,
|
|
read_from_stored_zip,
|
|
)
|
|
import io
|
|
|
|
# Module-level logger shared by the loaders and the dataset class below.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def load_audio(manifest_path, max_keep, min_keep):
    """Read an audio manifest and filter its entries by sample count.

    The first line of the manifest is the audio root; every following line
    holds two tab-separated fields: a path and a number of samples.

    Args:
        manifest_path: path of the manifest file.
        max_keep: entries with more samples than this are skipped
            (``None`` disables the upper bound).
        min_keep: entries with fewer samples than this are skipped
            (``None`` disables the lower bound).

    Returns:
        A tuple ``(root, names, inds, tot, sizes)``. ``names``, ``inds`` and
        ``sizes`` hold the path, zero-based entry index and sample count of
        every kept entry; ``tot`` is the number of entries in the manifest,
        kept or not.

    Raises:
        AssertionError: if an entry line does not have exactly two
            tab-separated fields.
    """
    n_long, n_short = 0, 0
    names, inds, sizes = [], [], []
    # Start at zero so a manifest that only holds the root line reports an
    # empty dataset instead of failing on an unbound loop variable.
    tot = 0
    with open(manifest_path) as f:
        root = f.readline().strip()
        for ind, line in enumerate(f):
            items = line.strip().split("\t")
            assert len(items) == 2, line
            sz = int(items[1])
            if min_keep is not None and sz < min_keep:
                n_short += 1
            elif max_keep is not None and sz > max_keep:
                n_long += 1
            else:
                names.append(items[0])
                inds.append(ind)
                sizes.append(sz)
            tot = ind + 1

    # max()/min() raise ValueError on an empty sequence; when every entry was
    # filtered out we still want the summary line rather than a crash.
    longest = max(sizes) if sizes else None
    shortest = min(sizes) if sizes else None
    logger.info(
        (
            f"max_keep={max_keep}, min_keep={min_keep}, "
            f"loaded {len(names)}, skipped {n_short} short and {n_long} long, "
            f"longest-loaded={longest}, shortest-loaded={shortest}"
        )
    )
    return root, names, inds, tot, sizes
|
|
|
|
|
|
def load_label(label_path, inds, tot):
    """Load the label lines selected by ``inds`` from ``label_path``.

    Args:
        label_path: text file with one label per line.
        inds: zero-based line indices to keep, in output order.
        tot: number of lines the file is required to contain.

    Returns:
        A list with the right-stripped label line for each index in ``inds``.

    Raises:
        AssertionError: if the file does not contain exactly ``tot`` lines.
    """
    with open(label_path) as handle:
        all_labels = list(map(str.rstrip, handle))
    n_found = len(all_labels)
    assert n_found == tot, f"number of labels does not match ({n_found} != {tot})"
    return [all_labels[k] for k in inds]
|
|
|
|
|
|
def load_label_offset(label_path, inds, tot):
    """Compute the byte range of each selected label line in ``label_path``.

    Args:
        label_path: file with one label per line.
        inds: zero-based line indices to keep, in output order.
        tot: number of lines the file is required to contain.

    Returns:
        A list of ``(start, end)`` byte offsets, one pair per index in
        ``inds``. Each range covers the whole line including its line
        terminator.

    Raises:
        AssertionError: if the file does not contain exactly ``tot`` lines.
    """
    # Read in binary mode so the lengths are the real on-disk byte counts.
    # Text mode translates "\r\n" to "\n" and decodes with the platform
    # default encoding, so re-encoding the decoded lines is not guaranteed
    # to reproduce the number of bytes each line occupies in the file.
    with open(label_path, "rb") as f:
        code_lengths = [len(line) for line in f]
    assert (
        len(code_lengths) == tot
    ), f"number of labels does not match ({len(code_lengths)} != {tot})"
    # Running sum of line lengths: offsets[i] is where line i starts and
    # offsets[i + 1] is where it ends.
    offsets = list(itertools.accumulate([0] + code_lengths))
    offsets = [(offsets[i], offsets[i + 1]) for i in inds]
    return offsets
|
|
|
|
|
|
def verify_label_lengths(
    audio_sizes,
    audio_rate,
    label_path,
    label_rate,
    inds,
    tot,
    tol=0.1,  # tolerance in seconds
):
    """Warn about entries whose audio and label durations disagree.

    A negative ``label_rate`` marks the file as a sequence label; it is
    logged and skipped. Otherwise the number of whitespace-separated tokens
    on each selected label line is turned into a duration through
    ``label_rate`` and compared with ``audio_sizes[i] / audio_rate``.
    Differences above ``tol`` are logged as warnings, followed by one summary
    warning with the number of mismatching pairs. Nothing is returned.

    Args:
        audio_sizes: sample counts, positionally aligned with ``inds``.
        audio_rate: audio sampling rate used to convert samples to seconds.
        label_path: label file with one line per entry.
        label_rate: label rate used to convert token counts to seconds.
        inds: zero-based label-line indices of the kept entries.
        tot: number of lines the label file is required to contain.
        tol: largest accepted duration difference, in seconds.

    Raises:
        AssertionError: if the label file does not contain ``tot`` lines.
    """
    if label_rate < 0:
        logger.info(f"{label_path} is sequence label. skipped")
        return

    with open(label_path) as handle:
        # split() without arguments already ignores trailing whitespace.
        token_counts = [len(row.split()) for row in handle]
    assert len(token_counts) == tot
    kept_counts = [token_counts[k] for k in inds]

    mismatches = 0
    for pos, line_idx in enumerate(inds):
        n_samples = audio_sizes[pos]
        n_tokens = kept_counts[pos]
        audio_secs = n_samples / audio_rate
        label_secs = n_tokens / label_rate
        if abs(audio_secs - label_secs) > tol:
            message = (
                f"audio and label duration differ too much "
                f"(|{audio_secs} - {label_secs}| > {tol}) "
                f"in line {line_idx+1} of {label_path}. Check if `label_rate` "
                f"is correctly set (currently {label_rate}). "
                f"num. of samples = {n_samples}; "
                f"label length = {n_tokens}"
            )
            logger.warning(message)
            mismatches += 1

    if mismatches > 0:
        logger.warning(
            f"total {mismatches} (audio, label) pairs with mismatched lengths"
        )
|
|
|
|
|
|
class HubertDataset(FairseqDataset):
    """Dataset pairing raw audio with one or more label streams.

    Audio entries come from a manifest read by ``load_audio``; every label
    file is expected to hold one line per manifest entry. Labels are either
    kept in memory (``store_labels=True``) or read from disk on demand through
    byte offsets computed by ``load_label_offset``.
    """

    def __init__(
        self,
        manifest_path: str,
        sample_rate: float,
        label_paths: List[str],
        label_rates: Union[List[float], float],  # -1 for sequence labels
        pad_list: List[str],
        eos_list: List[str],
        label_processors: Optional[List[Any]] = None,
        max_keep_sample_size: Optional[int] = None,
        min_keep_sample_size: Optional[int] = None,
        max_sample_size: Optional[int] = None,
        shuffle: bool = True,
        pad_audio: bool = False,
        normalize: bool = False,
        store_labels: bool = True,
        random_crop: bool = False,
        single_target: bool = False,
    ):
        """Load the manifest and labels and check that their lengths agree.

        Args:
            manifest_path: audio manifest passed to ``load_audio``.
            sample_rate: sample rate every audio file must have.
            label_paths: one label file per label stream.
            label_rates: rate of each label stream, or a single number
                applied to all of them; -1 for sequence labels.
            pad_list: padding value of each label stream, used when
                collating targets.
            eos_list: stored as ``self.eos_list``; not otherwise used by
                this class.
            label_processors: optional callables, one per label stream,
                applied to each raw label by ``get_label``.
            max_keep_sample_size: manifest entries longer than this are
                dropped.
            min_keep_sample_size: manifest entries shorter than this are
                dropped.
            max_sample_size: upper bound on the collated audio length;
                defaults to ``sys.maxsize``.
            shuffle: randomise the order among same-sized examples in
                ``ordered_indices``.
            pad_audio: pad a batch to its longest audio instead of cropping
                to its shortest.
            normalize: apply layer normalisation to each waveform.
            store_labels: keep all labels in memory instead of reading them
                from disk per item.
            random_crop: choose the crop start at random instead of 0.
            single_target: emit only the first label stream in batches.
        """
        self.audio_root, self.audio_names, inds, tot, self.sizes = load_audio(
            manifest_path, max_keep_sample_size, min_keep_sample_size
        )
        self.sample_rate = sample_rate
        self.shuffle = shuffle
        self.random_crop = random_crop

        self.num_labels = len(label_paths)
        self.pad_list = pad_list
        self.eos_list = eos_list
        self.label_processors = label_processors
        self.single_target = single_target
        # Broadcast a single rate to every label stream. ints are accepted as
        # well as floats: a bare `50` or `-1` would otherwise be stored as a
        # scalar and break the zip() over label streams below.
        self.label_rates = (
            [label_rates for _ in range(len(label_paths))]
            if isinstance(label_rates, (int, float))
            else label_rates
        )
        self.store_labels = store_labels
        if store_labels:
            self.label_list = [load_label(p, inds, tot) for p in label_paths]
        else:
            self.label_paths = label_paths
            self.label_offsets_list = [
                load_label_offset(p, inds, tot) for p in label_paths
            ]
        assert label_processors is None or len(label_processors) == self.num_labels
        for label_path, label_rate in zip(label_paths, self.label_rates):
            verify_label_lengths(
                self.sizes, sample_rate, label_path, label_rate, inds, tot
            )

        self.max_sample_size = (
            max_sample_size if max_sample_size is not None else sys.maxsize
        )
        self.pad_audio = pad_audio
        self.normalize = normalize
        logger.info(
            f"pad_audio={pad_audio}, random_crop={random_crop}, "
            f"normalize={normalize}, max_sample_size={self.max_sample_size}"
        )

    def get_audio(self, index):
        """Read the waveform of entry ``index`` and return it post-processed.

        The path is split by ``parse_path``; when it carries a slice, the
        audio bytes are read out of a ``.zip`` file via
        ``read_from_stored_zip``, otherwise the file is read directly.
        """
        # Imported lazily so the module can be imported without soundfile.
        import soundfile as sf

        wav_path = os.path.join(self.audio_root, self.audio_names[index])
        _path, slice_ptr = parse_path(wav_path)
        if len(slice_ptr) == 0:
            wav, cur_sample_rate = sf.read(_path)
        else:
            assert _path.endswith(".zip")
            data = read_from_stored_zip(_path, slice_ptr[0], slice_ptr[1])
            f = io.BytesIO(data)
            wav, cur_sample_rate = sf.read(f)
        wav = torch.from_numpy(wav).float()
        wav = self.postprocess(wav, cur_sample_rate)
        return wav

    def get_label(self, index, label_idx):
        """Return label stream ``label_idx`` of entry ``index``.

        The label comes from memory when labels are stored, otherwise from
        the label file through its precomputed byte range. The matching label
        processor, if any, is applied before returning.
        """
        if self.store_labels:
            label = self.label_list[label_idx][index]
        else:
            offset_s, offset_e = self.label_offsets_list[label_idx][index]
            # The offsets are UTF-8 byte positions. A text-mode file counts
            # characters in read(), which over-reads as soon as a label holds
            # a multi-byte character, so read raw bytes and decode them.
            with open(self.label_paths[label_idx], "rb") as f:
                f.seek(offset_s)
                label = f.read(offset_e - offset_s).decode("utf-8")

        if self.label_processors is not None:
            label = self.label_processors[label_idx](label)
        return label

    def get_labels(self, index):
        """Return the labels of entry ``index`` for every label stream."""
        return [self.get_label(index, i) for i in range(self.num_labels)]

    def __getitem__(self, index):
        """Return a dict with the id, waveform and label list of one entry."""
        wav = self.get_audio(index)
        labels = self.get_labels(index)
        return {"id": index, "source": wav, "label_list": labels}

    def __len__(self):
        """Return the number of entries kept from the manifest."""
        return len(self.sizes)

    def crop_to_max_size(self, wav, target_size):
        """Crop ``wav`` to ``target_size`` samples.

        Returns:
            ``(cropped, start)`` where ``start`` is the offset of the crop in
            the original waveform. The waveform is returned untouched with
            ``start == 0`` when it is not longer than ``target_size``.
        """
        size = len(wav)
        diff = size - target_size
        if diff <= 0:
            return wav, 0

        start, end = 0, target_size
        if self.random_crop:
            # Any start in [0, diff] keeps the crop inside the waveform.
            start = np.random.randint(0, diff + 1)
            end = size - diff + start
        return wav[start:end], start

    def collater(self, samples):
        """Merge a list of ``__getitem__`` results into one batch dict.

        Samples whose ``"source"`` is ``None`` are dropped; an empty dict is
        returned when nothing is left. With ``single_target`` only the first
        label stream is emitted (``target``, ``target_lengths``,
        ``ntokens``); otherwise the ``*_list`` keys hold every stream.
        """
        # target = max(sizes) -> random_crop not used
        # target = max_sample_size -> random_crop used for long
        samples = [s for s in samples if s["source"] is not None]
        if len(samples) == 0:
            return {}

        audios = [s["source"] for s in samples]
        audio_sizes = [len(s) for s in audios]
        if self.pad_audio:
            audio_size = min(max(audio_sizes), self.max_sample_size)
        else:
            audio_size = min(min(audio_sizes), self.max_sample_size)
        collated_audios, padding_mask, audio_starts = self.collater_audio(
            audios, audio_size
        )

        targets_by_label = [
            [s["label_list"][i] for s in samples] for i in range(self.num_labels)
        ]
        targets_list, lengths_list, ntokens_list = self.collater_label(
            targets_by_label, audio_size, audio_starts
        )

        net_input = {"source": collated_audios, "padding_mask": padding_mask}
        batch = {
            "id": torch.LongTensor([s["id"] for s in samples]),
            "net_input": net_input,
        }

        if self.single_target:
            batch["target_lengths"] = lengths_list[0]
            batch["ntokens"] = ntokens_list[0]
            batch["target"] = targets_list[0]
        else:
            batch["target_lengths_list"] = lengths_list
            batch["ntokens_list"] = ntokens_list
            batch["target_list"] = targets_list
        return batch

    def collater_audio(self, audios, audio_size):
        """Stack waveforms into a ``(len(audios), audio_size)`` tensor.

        Shorter waveforms are zero-padded (only allowed with ``pad_audio``)
        and longer ones are cropped.

        Returns:
            ``(collated_audios, padding_mask, audio_starts)``: the stacked
            audio, a boolean mask that is ``True`` on padded positions, and
            the crop start of every waveform.
        """
        collated_audios = audios[0].new_zeros(len(audios), audio_size)
        # The mask is always allocated, even when pad_audio is off and it
        # therefore stays all-False.
        padding_mask = torch.BoolTensor(collated_audios.shape).fill_(False)
        audio_starts = [0 for _ in audios]
        for i, audio in enumerate(audios):
            diff = len(audio) - audio_size
            if diff == 0:
                collated_audios[i] = audio
            elif diff < 0:
                assert self.pad_audio
                collated_audios[i] = torch.cat([audio, audio.new_full((-diff,), 0.0)])
                # diff is negative: this marks the last -diff positions.
                padding_mask[i, diff:] = True
            else:
                collated_audios[i], audio_starts[i] = self.crop_to_max_size(
                    audio, audio_size
                )
        return collated_audios, padding_mask, audio_starts

    def collater_frm_label(self, targets, audio_size, audio_starts, label_rate, pad):
        """Collate frame-level targets, cropped to match the audio crop.

        Audio crop starts and size are converted from samples to frames with
        ``label_rate / self.sample_rate``. Without ``pad_audio`` the frame
        count is further limited to what every target still has left after
        its start.

        Returns:
            ``(targets, lengths, ntokens)``: the padded target batch, the
            per-item lengths and their sum.
        """
        assert label_rate > 0
        s2f = label_rate / self.sample_rate
        frm_starts = [int(round(s * s2f)) for s in audio_starts]
        frm_size = int(round(audio_size * s2f))
        if not self.pad_audio:
            rem_size = [len(t) - s for t, s in zip(targets, frm_starts)]
            frm_size = min(frm_size, *rem_size)
        targets = [t[s : s + frm_size] for t, s in zip(targets, frm_starts)]
        logger.debug(f"audio_starts={audio_starts}")
        logger.debug(f"frame_starts={frm_starts}")
        logger.debug(f"frame_size={frm_size}")

        lengths = torch.LongTensor([len(t) for t in targets])
        ntokens = lengths.sum().item()
        targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False)
        return targets, lengths, ntokens

    def collater_seq_label(self, targets, pad):
        """Collate sequence-level targets without any cropping.

        Returns:
            ``(targets, lengths, ntokens)``: the padded target batch, the
            per-item lengths and their sum.
        """
        lengths = torch.LongTensor([len(t) for t in targets])
        ntokens = lengths.sum().item()
        targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False)
        return targets, lengths, ntokens

    def collater_label(self, targets_by_label, audio_size, audio_starts):
        """Collate every label stream with the collater matching its rate.

        Streams whose rate is -1 are treated as sequence labels; all others
        as frame labels aligned with the audio crop.

        Returns:
            Three lists (targets, lengths, token counts) with one element per
            label stream.
        """
        targets_list, lengths_list, ntokens_list = [], [], []
        itr = zip(targets_by_label, self.label_rates, self.pad_list)
        for targets, label_rate, pad in itr:
            if label_rate == -1.0:
                targets, lengths, ntokens = self.collater_seq_label(targets, pad)
            else:
                targets, lengths, ntokens = self.collater_frm_label(
                    targets, audio_size, audio_starts, label_rate, pad
                )
            targets_list.append(targets)
            lengths_list.append(lengths)
            ntokens_list.append(ntokens)
        return targets_list, lengths_list, ntokens_list

    def num_tokens(self, index):
        """Return the size of entry ``index``; same value as ``size``."""
        return self.size(index)

    def size(self, index):
        """Return the sample count of entry ``index``.

        The count is capped at ``max_sample_size`` unless ``pad_audio`` is
        set.
        """
        if self.pad_audio:
            return self.sizes[index]
        return min(self.sizes[index], self.max_sample_size)

    def ordered_indices(self):
        """Return entry indices ordered by decreasing sample count.

        The secondary key is a random permutation when ``shuffle`` is set and
        the entry position otherwise.
        """
        if self.shuffle:
            order = [np.random.permutation(len(self))]
        else:
            order = [np.arange(len(self))]

        # np.lexsort uses the LAST key as the primary one, so sizes drive the
        # ordering and the first key only breaks ties.
        order.append(self.sizes)
        return np.lexsort(order)[::-1]

    def postprocess(self, wav, cur_sample_rate):
        """Validate a waveform and optionally normalise it.

        A 2-D waveform is averaged over its last dimension. The result must
        be 1-D and its sample rate must equal ``self.sample_rate``.

        Raises:
            AssertionError: if the waveform is not 1-D after averaging.
            Exception: if ``cur_sample_rate`` differs from
                ``self.sample_rate``.
        """
        if wav.dim() == 2:
            wav = wav.mean(-1)
        assert wav.dim() == 1, wav.dim()

        if cur_sample_rate != self.sample_rate:
            raise Exception(f"sr {cur_sample_rate} != {self.sample_rate}")

        if self.normalize:
            with torch.no_grad():
                wav = F.layer_norm(wav, wav.shape)
        return wav
|