mirror of
https://github.com/SillyTavern/SillyTavern-Extras.git
synced 2026-03-12 06:50:09 +00:00
502 lines
17 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
#
|
|
# This source code is licensed under the MIT license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
import logging
|
|
import os
|
|
import os.path as op
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
import numpy as np
|
|
|
|
from fairseq.data.audio.text_to_speech_dataset import TextToSpeechDatasetCreator
|
|
from fairseq.tasks import register_task
|
|
from fairseq.tasks.speech_to_text import SpeechToTextTask
|
|
from fairseq.speech_generator import (
|
|
AutoRegressiveSpeechGenerator,
|
|
NonAutoregressiveSpeechGenerator,
|
|
TeacherForcingAutoRegressiveSpeechGenerator,
|
|
)
|
|
|
|
# Configure root logging once at import time so task-level messages are
# timestamped and tagged with the emitting module name.
logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
|
|
|
|
|
|
# tensorboardX is an optional dependency, only needed for logging validation
# figures/audio. Fall back to None so the rest of the module still imports;
# __init__ below checks "SummaryWriter is not None" before enabling logging.
try:
    from tensorboardX import SummaryWriter
except ImportError:
    logger.info("Please install tensorboardX: pip install tensorboardX")
    SummaryWriter = None
|
|
|
|
|
|
@register_task("text_to_speech")
class TextToSpeechTask(SpeechToTextTask):
    """Fairseq task for training and evaluating text-to-speech models.

    Reuses the SpeechToTextTask data pipeline (TSV manifests + YAML data
    config) but swaps in the TTS dataset creator and speech generators.
    """

    @staticmethod
    def add_args(parser):
        """Register task-specific command-line options."""
        parser.add_argument("data", help="manifest root path")
        parser.add_argument(
            "--config-yaml",
            type=str,
            default="config.yaml",
            help="Configuration YAML filename (under manifest root)",
        )
        parser.add_argument(
            "--max-source-positions",
            default=1024,
            type=int,
            metavar="N",
            help="max number of tokens in the source sequence",
        )
        parser.add_argument(
            "--max-target-positions",
            default=1200,
            type=int,
            metavar="N",
            help="max number of tokens in the target sequence",
        )
        # Output frames predicted per decoder step (a.k.a. reduction factor).
        parser.add_argument("--n-frames-per-step", type=int, default=1)
        # AR decoding stops once predicted EOS probability exceeds this.
        parser.add_argument("--eos-prob-threshold", type=float, default=0.5)
        # If set, run full inference during validation and log MCD metrics.
        parser.add_argument("--eval-inference", action="store_true")
        # Max number of validation samples rendered to tensorboard.
        parser.add_argument("--eval-tb-nsample", type=int, default=8)
        parser.add_argument("--vocoder", type=str, default="griffin_lim")
        # Iteration budget for spectrogram inversion (e.g. Griffin-Lim).
        parser.add_argument("--spec-bwd-max-iter", type=int, default=8)

    def __init__(self, args, src_dict):
        super().__init__(args, src_dict)
        self.src_dict = src_dict
        # Sample rate comes from the data config's "features" section.
        self.sr = self.data_cfg.config.get("features").get("sample_rate")

        # SummaryWriter is created lazily in log_tensorboard; an empty
        # tensorboard_dir disables the extra validation logging entirely.
        self.tensorboard_writer = None
        self.tensorboard_dir = ""
        if args.tensorboard_logdir and SummaryWriter is not None:
            self.tensorboard_dir = os.path.join(args.tensorboard_logdir, "valid_extra")

    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load one dataset split from TSV manifests into self.datasets."""
        is_train_split = split.startswith("train")
        pre_tokenizer = self.build_tokenizer(self.args)
        bpe_tokenizer = self.build_bpe(self.args)
        self.datasets[split] = TextToSpeechDatasetCreator.from_tsv(
            self.args.data,
            self.data_cfg,
            split,
            self.src_dict,
            pre_tokenizer,
            bpe_tokenizer,
            is_train_split=is_train_split,
            epoch=epoch,
            seed=self.args.seed,
            n_frames_per_step=self.args.n_frames_per_step,
            speaker_to_id=self.speaker_to_id,
        )

    @property
    def target_dictionary(self):
        # TTS targets are continuous spectrogram frames, not tokens.
        return None

    @property
    def source_dictionary(self):
        return self.src_dict

    def get_speaker_embeddings_path(self):
        """Return the path to a pretrained speaker-embedding file, or None
        when the data config does not declare one."""
        speaker_emb_path = None
        if self.data_cfg.config.get("speaker_emb_filename") is not None:
            speaker_emb_path = op.join(
                self.args.data, self.data_cfg.config.get("speaker_emb_filename")
            )
        return speaker_emb_path

    @classmethod
    def get_speaker_embeddings(cls, args):
        """Build the speaker-embedding table for multi-speaker models.

        Trained from scratch when no pretrained matrix is configured,
        otherwise loaded frozen from ``args.speaker_emb_path``.
        """
        embed_speaker = None
        if args.speaker_to_id is not None:
            if args.speaker_emb_path is None:
                embed_speaker = torch.nn.Embedding(
                    len(args.speaker_to_id), args.speaker_embed_dim
                )
            else:
                speaker_emb_mat = np.load(args.speaker_emb_path)
                assert speaker_emb_mat.shape[1] == args.speaker_embed_dim
                embed_speaker = torch.nn.Embedding.from_pretrained(
                    torch.from_numpy(speaker_emb_mat),
                    freeze=True,
                )
                logger.info(
                    f"load speaker embeddings from {args.speaker_emb_path}. "
                    f"train embedding? {embed_speaker.weight.requires_grad}\n"
                    f"embeddings:\n{speaker_emb_mat}"
                )
        return embed_speaker

    def build_model(self, cfg, from_checkpoint=False):
        # Propagate feature-normalization stats from the data config into
        # the model config before construction.
        cfg.pitch_min = self.data_cfg.config["features"].get("pitch_min", None)
        cfg.pitch_max = self.data_cfg.config["features"].get("pitch_max", None)
        cfg.energy_min = self.data_cfg.config["features"].get("energy_min", None)
        cfg.energy_max = self.data_cfg.config["features"].get("energy_max", None)
        cfg.speaker_emb_path = self.get_speaker_embeddings_path()
        model = super().build_model(cfg, from_checkpoint)
        self.generator = None
        if getattr(cfg, "eval_inference", False):
            # Pre-build the generator used by valid_step_with_inference.
            self.generator = self.build_generator([model], cfg)
        return model

    def build_generator(self, models, cfg, vocoder=None, **unused):
        """Pick the speech generator matching the model type:
        non-autoregressive, autoregressive, or teacher-forced AR."""
        if vocoder is None:
            vocoder = self.build_default_vocoder()
        model = models[0]
        if getattr(model, "NON_AUTOREGRESSIVE", False):
            return NonAutoregressiveSpeechGenerator(model, vocoder, self.data_cfg)
        else:
            generator = AutoRegressiveSpeechGenerator
            if getattr(cfg, "teacher_forcing", False):
                generator = TeacherForcingAutoRegressiveSpeechGenerator
                logger.info("Teacher forcing mode for generation")
            return generator(
                model,
                vocoder,
                self.data_cfg,
                max_iter=self.args.max_target_positions,
                eos_prob_threshold=self.args.eos_prob_threshold,
            )

    def build_default_vocoder(self):
        """Instantiate the configured vocoder, on GPU when available."""
        from fairseq.models.text_to_speech.vocoder import get_vocoder

        vocoder = get_vocoder(self.args, self.data_cfg)
        if torch.cuda.is_available() and not self.args.cpu:
            vocoder = vocoder.cuda()
        else:
            vocoder = vocoder.cpu()
        return vocoder

    def valid_step(self, sample, model, criterion):
        """Standard validation step, optionally extended with inference
        metrics (MCD, frame counts) and tensorboard sample logging."""
        loss, sample_size, logging_output = super().valid_step(sample, model, criterion)

        if getattr(self.args, "eval_inference", False):
            hypos, inference_losses = self.valid_step_with_inference(
                sample, model, self.generator
            )
            for k, v in inference_losses.items():
                # Inference metrics must not clobber criterion outputs.
                assert k not in logging_output
                logging_output[k] = v

            # Only render figures for the batch containing sample id 0 so a
            # validation pass logs at most one batch of images/audio.
            picked_id = 0
            if self.tensorboard_dir and (sample["id"] == picked_id).any():
                self.log_tensorboard(
                    sample,
                    hypos[: self.args.eval_tb_nsample],
                    model._num_updates,
                    is_na_model=getattr(model, "NON_AUTOREGRESSIVE", False),
                )
        return loss, sample_size, logging_output

    def valid_step_with_inference(self, sample, model, generator):
        """Run full inference on the batch and accumulate MCD and DTW
        alignment statistics over all hypotheses."""
        hypos = generator.generate(model, sample, has_targ=True)

        losses = {
            "mcd_loss": 0.0,
            "targ_frames": 0.0,
            "pred_frames": 0.0,
            "nins": 0.0,
            "ndel": 0.0,
        }
        rets = batch_mel_cepstral_distortion(
            [hypo["targ_waveform"] for hypo in hypos],
            [hypo["waveform"] for hypo in hypos],
            self.sr,
            normalize_type=None,
        )
        for d, extra in rets:
            pathmap = extra[-1]
            losses["mcd_loss"] += d.item()
            losses["targ_frames"] += pathmap.size(0)
            losses["pred_frames"] += pathmap.size(1)
            # DTW path-mask row/column sums above 1 count how many predicted
            # frames were inserted / target frames deleted by the alignment.
            losses["nins"] += (pathmap.sum(dim=1) - 1).sum().item()
            losses["ndel"] += (pathmap.sum(dim=0) - 1).sum().item()

        return hypos, losses

    def log_tensorboard(self, sample, hypos, num_updates, is_na_model=False):
        """Write spectrogram/attention figures and audio clips for the given
        hypotheses to tensorboard."""
        if self.tensorboard_writer is None:
            # Created lazily so runs without logging never touch tensorboardX.
            self.tensorboard_writer = SummaryWriter(self.tensorboard_dir)
        tb_writer = self.tensorboard_writer
        for b in range(len(hypos)):
            idx = sample["id"][b]
            text = sample["src_texts"][b]
            targ = hypos[b]["targ_feature"]
            pred = hypos[b]["feature"]
            attn = hypos[b]["attn"]

            if is_na_model:
                # NAR models have no EOS probability; the attention/alignment
                # is passed as the 1-D panel instead.
                data = plot_tts_output(
                    [targ.transpose(0, 1), pred.transpose(0, 1)],
                    [f"target (idx={idx})", "output"],
                    attn,
                    "alignment",
                    ret_np=True,
                    suptitle=text,
                )
            else:
                # AR models: attention shown as a 2-D map, EOS prob as 1-D.
                eos_prob = hypos[b]["eos_prob"]
                data = plot_tts_output(
                    [targ.transpose(0, 1), pred.transpose(0, 1), attn],
                    [f"target (idx={idx})", "output", "alignment"],
                    eos_prob,
                    "eos prob",
                    ret_np=True,
                    suptitle=text,
                )

            tb_writer.add_image(
                f"inference_sample_{b}", data, num_updates, dataformats="HWC"
            )

            if hypos[b]["waveform"] is not None:
                targ_wave = hypos[b]["targ_waveform"].detach().cpu().float()
                pred_wave = hypos[b]["waveform"].detach().cpu().float()
                tb_writer.add_audio(
                    f"inference_targ_{b}", targ_wave, num_updates, sample_rate=self.sr
                )
                tb_writer.add_audio(
                    f"inference_pred_{b}", pred_wave, num_updates, sample_rate=self.sr
                )
|
|
|
|
|
|
def save_figure_to_numpy(fig):
    """Render a drawn matplotlib figure into an (H, W, 3) uint8 RGB array.

    The figure's canvas must already be drawn (``fig.canvas.draw()``).
    """
    # np.fromstring is deprecated (and removed for binary input in modern
    # NumPy); np.frombuffer is the supported equivalent for raw bytes.
    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    # get_width_height() returns (W, H); reshape to (H, W, 3) image layout.
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    return data
|
|
|
|
|
|
# Default lower clamp for spectrogram color scale: log of a small positive
# floor. NOTE(review): presumably matches the feature-extraction floor used
# when taking log-mels — confirm against the data pipeline.
DEFAULT_V_MIN = np.log(1e-5)
|
|
|
|
|
|
def plot_tts_output(
    data_2d,
    title_2d,
    data_1d,
    title_1d,
    figsize=(24, 4),
    v_min=DEFAULT_V_MIN,
    v_max=3,
    ret_np=False,
    suptitle="",
):
    """Plot a row of 2-D heatmaps (spectrograms/attention) plus one 1-D curve.

    Args:
        data_2d: list of 2-D tensors/arrays, one heatmap per entry.
        title_2d: per-heatmap titles (paired with data_2d by position).
        data_1d: data for the rightmost line plot (tensor or array-like).
        title_1d: title for the line plot.
        figsize: overall figure size in inches.
        v_min, v_max: color-scale clamps; each panel further tightens them
            to its own data min/max.
        ret_np: if True, render the figure, close it, and return it as an
            RGB numpy array; otherwise return None and leave the figure open.
        suptitle: optional figure title, truncated to 400 characters.

    Raises:
        ImportError: if matplotlib is not installed.
    """
    try:
        import matplotlib.pyplot as plt
        from mpl_toolkits.axes_grid1 import make_axes_locatable
    except ImportError:
        raise ImportError("Please install Matplotlib: pip install matplotlib")

    # Move any tensors to CPU numpy before handing them to matplotlib.
    data_2d = [
        x.detach().cpu().float().numpy() if isinstance(x, torch.Tensor) else x
        for x in data_2d
    ]
    # One subplot per heatmap plus one extra for the 1-D curve.
    fig, axes = plt.subplots(1, len(data_2d) + 1, figsize=figsize)
    if suptitle:
        fig.suptitle(suptitle[:400])  # capped at 400 chars
    # With a single subplot, plt.subplots returns a bare Axes object;
    # normalize to a list so axes[-1] below always works.
    axes = [axes] if len(data_2d) == 0 else axes
    for ax, x, name in zip(axes, data_2d, title_2d):
        ax.set_title(name)
        # Attach a slim colorbar axis to the right of each heatmap.
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="5%", pad=0.05)
        im = ax.imshow(
            x,
            origin="lower",
            aspect="auto",
            vmin=max(x.min(), v_min),
            vmax=min(x.max(), v_max),
        )
        fig.colorbar(im, cax=cax, orientation="vertical")

    if isinstance(data_1d, torch.Tensor):
        data_1d = data_1d.detach().cpu().numpy()
    axes[-1].plot(data_1d)
    axes[-1].set_title(title_1d)
    plt.tight_layout()

    if ret_np:
        fig.canvas.draw()
        data = save_figure_to_numpy(fig)
        plt.close(fig)
        return data
|
|
|
|
|
|
def antidiag_indices(offset, min_i=0, max_i=None, min_j=0, max_j=None):
    """Return the (row, col) index pairs lying on anti-diagonal ``offset``.

    The result is a (2, K) LongTensor: row 0 holds the i indices, row 1 the
    j indices, with i + j == offset and min_i <= i < max_i,
    min_j <= j < max_j.

    Example — for a (3, 4) matrix with min_i=1, max_i=3, min_j=1, max_j=4:

        offset=2 -> (1, 1)
        offset=3 -> (2, 1), (1, 2)
        offset=4 -> (2, 2), (1, 3)
        offset=5 -> (2, 3)
    """
    # Open bounds default to the full diagonal length.
    full = offset + 1
    bound_i = full if max_i is None else max_i
    bound_j = full if max_j is None else max_j
    # Intersect the j-range with what the i-constraints allow
    # (i = offset - j, so min_i/bound_i translate into j limits).
    lo_j = max(min_j, offset - bound_i + 1, 0)
    hi_j = min(bound_j, offset - min_i + 1, full)
    cols = torch.arange(lo_j, hi_j)
    rows = offset - cols
    return torch.stack([rows, cols])
|
|
|
|
|
|
def batch_dynamic_time_warping(distance, shapes=None):
    """full batched DTW without any constraints

    distance: (batchsize, max_M, max_N) matrix
    shapes: (batchsize,) vector specifying (M, N) for each entry

    Returns (cumdist, backptr, pathmap), each shaped like ``distance``:
    cumulative path cost, back-pointer codes, and a 0/1 mask marking the
    optimal warping path for each batch entry.
    """
    # ptr: 0=left, 1=up-left, 2=up
    ptr2dij = {0: (0, -1), 1: (-1, -1), 2: (-1, 0)}

    bsz, m, n = distance.size()
    cumdist = torch.zeros_like(distance)
    # -1 marks cells whose back-pointer is never assigned (cell (0, 0)).
    backptr = torch.zeros_like(distance).type(torch.int32) - 1

    # initialize: first row/column are reachable only by straight runs.
    cumdist[:, 0, :] = distance[:, 0, :].cumsum(dim=-1)
    cumdist[:, :, 0] = distance[:, :, 0].cumsum(dim=-1)
    backptr[:, 0, :] = 0
    backptr[:, :, 0] = 2

    # DP with optimized anti-diagonal parallelization, O(M+N) steps:
    # all cells on one anti-diagonal depend only on earlier diagonals,
    # so each diagonal is filled in one vectorized update.
    for offset in range(2, m + n - 1):
        ind = antidiag_indices(offset, 1, m, 1, n)
        # Candidate predecessor costs, stacked in the same order as the
        # ptr2dij codes: left, up-left, up.
        c = torch.stack(
            [
                cumdist[:, ind[0], ind[1] - 1],
                cumdist[:, ind[0] - 1, ind[1] - 1],
                cumdist[:, ind[0] - 1, ind[1]],
            ],
            dim=2,
        )
        v, b = c.min(axis=-1)
        backptr[:, ind[0], ind[1]] = b.int()
        cumdist[:, ind[0], ind[1]] = v + distance[:, ind[0], ind[1]]

    # backtrace: follow back-pointers from each entry's true bottom-right
    # corner (per `shapes`, or the padded corner when shapes is None).
    pathmap = torch.zeros_like(backptr)
    for b in range(bsz):
        i = m - 1 if shapes is None else (shapes[b][0] - 1).item()
        j = n - 1 if shapes is None else (shapes[b][1] - 1).item()
        dtwpath = [(i, j)]
        # 10000-step cap guards against looping on a malformed table.
        while (i != 0 or j != 0) and len(dtwpath) < 10000:
            assert i >= 0 and j >= 0
            di, dj = ptr2dij[backptr[b, i, j].item()]
            i, j = i + di, j + dj
            dtwpath.append((i, j))
        dtwpath = dtwpath[::-1]
        indices = torch.from_numpy(np.array(dtwpath))
        pathmap[b, indices[:, 0], indices[:, 1]] = 1

    return cumdist, backptr, pathmap
|
|
|
|
|
|
def compute_l2_dist(x1, x2):
    """Return the (m, n) squared-L2 distance matrix between the rows of an
    (m, d) matrix and an (n, d) matrix."""
    pairwise = torch.cdist(x1[None], x2[None], p=2)
    return pairwise.squeeze(0) ** 2
|
|
|
|
|
|
def compute_rms_dist(x1, x2):
    """Return the (m, n) RMS distance matrix between the rows of an (m, d)
    matrix and an (n, d) matrix: sqrt(squared-L2 / d)."""
    sq_dist = torch.cdist(x1.unsqueeze(0), x2.unsqueeze(0), p=2).squeeze(0).pow(2)
    dim = x1.size(1)
    return (sq_dist / dim).sqrt()
|
|
|
|
|
|
def get_divisor(pathmap, normalize_type):
    """Return the normalization constant for a DTW path cost.

    None -> 1 (no normalization); "len1"/"len2" -> length of the first /
    second sequence; "path" -> number of cells on the warping path.
    Raises ValueError for any other normalize_type.
    """
    dispatch = {
        None: lambda: 1,
        "len1": lambda: pathmap.size(0),
        "len2": lambda: pathmap.size(1),
        "path": lambda: pathmap.sum().item(),
    }
    if normalize_type not in dispatch:
        raise ValueError(f"normalize_type {normalize_type} not supported")
    return dispatch[normalize_type]()
|
|
|
|
|
|
def batch_compute_distortion(y1, y2, sr, feat_fn, dist_fn, normalize_type):
    """Generic batched distortion between two lists of 1-D waveforms.

    For each pair: extract features with ``feat_fn``, build a pairwise
    frame-distance matrix with ``dist_fn``, DTW-align, and report the
    cumulative cost along the optimal path divided by the ``normalize_type``
    divisor (see get_divisor).

    Returns a list of
    (distortion, (x1, x2, dist, cumdist, backptr, pathmap)) tuples,
    one per waveform pair.
    """
    d, s, x1, x2 = [], [], [], []
    for cur_y1, cur_y2 in zip(y1, y2):
        assert cur_y1.ndim == 1 and cur_y2.ndim == 1
        cur_x1 = feat_fn(cur_y1)
        cur_x2 = feat_fn(cur_y2)
        x1.append(cur_x1)
        x2.append(cur_x2)

        cur_d = dist_fn(cur_x1, cur_x2)
        d.append(cur_d)
        s.append(d[-1].size())
    # Zero-pad every distance matrix to a shared (max_m, max_n) so DTW runs
    # as a single batched computation; true sizes are carried in `s`.
    max_m = max(ss[0] for ss in s)
    max_n = max(ss[1] for ss in s)
    d = torch.stack(
        [F.pad(dd, (0, max_n - dd.size(1), 0, max_m - dd.size(0))) for dd in d]
    )
    s = torch.LongTensor(s).to(d.device)
    cumdists, backptrs, pathmaps = batch_dynamic_time_warping(d, s)

    rets = []
    itr = zip(s, x1, x2, d, cumdists, backptrs, pathmaps)
    for (m, n), cur_x1, cur_x2, dist, cumdist, backptr, pathmap in itr:
        # Crop the padded DP tables back to this pair's true size.
        cumdist = cumdist[:m, :n]
        backptr = backptr[:m, :n]
        pathmap = pathmap[:m, :n]
        divisor = get_divisor(pathmap, normalize_type)

        # Total cost at the bottom-right corner = full optimal-path cost.
        distortion = cumdist[-1, -1] / divisor
        ret = distortion, (cur_x1, cur_x2, dist, cumdist, backptr, pathmap)
        rets.append(ret)
    return rets
|
|
|
|
|
|
def batch_mel_cepstral_distortion(y1, y2, sr, normalize_type="path", mfcc_fn=None):
    """
    https://arxiv.org/pdf/2011.03568.pdf

    The root mean squared error computed on 13-dimensional MFCC using DTW for
    alignment. MFCC features are computed from an 80-channel log-mel
    spectrogram using a 50ms Hann window and hop of 12.5ms.

    y1: list of waveforms
    y2: list of waveforms
    sr: sampling rate
    normalize_type: DTW path-cost normalization mode (see get_divisor)
    mfcc_fn: optional cached torchaudio MFCC transform; rebuilt whenever it
        is absent or its sample rate does not match sr

    Raises:
        ImportError: if torchaudio is not installed.
    """

    try:
        import torchaudio
    except ImportError:
        raise ImportError("Please install torchaudio: pip install torchaudio")

    if mfcc_fn is None or mfcc_fn.sample_rate != sr:
        # Window/hop sizes in samples for 50ms / 12.5ms at this sample rate.
        melkwargs = {
            "n_fft": int(0.05 * sr),
            "win_length": int(0.05 * sr),
            "hop_length": int(0.0125 * sr),
            "f_min": 20,
            "n_mels": 80,
            "window_fn": torch.hann_window,
        }
        mfcc_fn = torchaudio.transforms.MFCC(
            sr, n_mfcc=13, log_mels=True, melkwargs=melkwargs
        ).to(y1[0].device)
    return batch_compute_distortion(
        y1,
        y2,
        sr,
        # MFCC output is (n_mfcc, time); transpose so frames are rows.
        lambda y: mfcc_fn(y).transpose(-1, -2),
        compute_rms_dist,
        normalize_type,
    )
|