mirror of
https://github.com/SillyTavern/SillyTavern-Extras.git
synced 2026-04-25 08:59:01 +00:00
409 lines
14 KiB
Python
409 lines
14 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
#
|
|
# This source code is licensed under the MIT license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
|
|
from dataclasses import dataclass, field
|
|
from fairseq.models.fairseq_decoder import FairseqDecoder
|
|
import numpy as np
|
|
from typing import Optional, Dict, Any, List
|
|
import torch
|
|
from torch import nn
|
|
from fairseq.data.data_utils import compute_mask_indices
|
|
from fairseq.dataclass import ChoiceEnum
|
|
from fairseq.models import (
|
|
FairseqLanguageModel,
|
|
register_model,
|
|
register_model_architecture,
|
|
)
|
|
from fairseq.tasks.speech_ulm_task import SpeechUnitLanguageModelingTask
|
|
from fairseq.models.transformer import Embedding, TransformerDecoder, Linear
|
|
from fairseq.models.transformer_lm import TransformerLanguageModelConfig
|
|
from torch import Tensor
|
|
|
|
|
|
DEFAULT_MAX_TARGET_POSITIONS = 1024
|
|
MASKING_DISTRIBUTION_CHOICES = ChoiceEnum(["static", "uniform", "normal", "poisson"])
|
|
|
|
|
|
@dataclass
|
|
class SpeechUnitLanguageModelConfig(TransformerLanguageModelConfig):
|
|
mask_unit_seg_prob: float = field(
|
|
default=0.0, metadata={"help": "probability to mask a segment of unit sequence"}
|
|
)
|
|
mask_unit_seg_leng: int = field(
|
|
default=5, metadata={"help": "length of unit segment mask"}
|
|
)
|
|
mask_unit_seg_type: MASKING_DISTRIBUTION_CHOICES = field(
|
|
default="static", metadata={"help": "how to choose unit mask length"}
|
|
)
|
|
|
|
mask_dur_prob: float = field(
|
|
default=0.0, metadata={"help": "probability to mask entire duration sequence"}
|
|
)
|
|
mask_dur_seg_prob: float = field(
|
|
default=0.0,
|
|
metadata={"help": "probability to mask a segment of duration sequence"},
|
|
)
|
|
mask_dur_seg_leng: int = field(
|
|
default=5, metadata={"help": "length of duration segment mask"}
|
|
)
|
|
mask_dur_seg_type: MASKING_DISTRIBUTION_CHOICES = field(
|
|
default="static", metadata={"help": "how to choose duration mask length"}
|
|
)
|
|
|
|
mask_f0_prob: float = field(
|
|
default=0.0, metadata={"help": "probability to mask entire duration sequence"}
|
|
)
|
|
mask_f0_seg_prob: float = field(
|
|
default=0.0, metadata={"help": "probability to mask a segment of f0 sequence"}
|
|
)
|
|
mask_f0_seg_leng: int = field(
|
|
default=5, metadata={"help": "length of f0 segment mask"}
|
|
)
|
|
mask_f0_seg_type: MASKING_DISTRIBUTION_CHOICES = field(
|
|
default="static", metadata={"help": "how to choose f0 mask length"}
|
|
)
|
|
|
|
|
|
@register_model("transformer_ulm", dataclass=SpeechUnitLanguageModelConfig)
|
|
class TransformerUnitLanguageModel(FairseqLanguageModel):
|
|
def __init__(
|
|
self,
|
|
cfg: SpeechUnitLanguageModelConfig,
|
|
task: SpeechUnitLanguageModelingTask,
|
|
decoder: FairseqDecoder,
|
|
):
|
|
super().__init__(decoder)
|
|
self.cfg = cfg
|
|
|
|
self.channel_names = task.channel_names
|
|
self.channel_sizes = task.channel_sizes
|
|
|
|
self.unit_mask_val = task.source_dictionary.unk()
|
|
self.dur_mask_val = (
|
|
task.source_duration_dictionary.unk() if task.cfg.discrete_duration else 0
|
|
)
|
|
self.f0_mask_val = (
|
|
task.source_f0_dictionary.unk() if task.cfg.discrete_f0 else 0
|
|
)
|
|
|
|
self.ignore_duration_input = task.cfg.ignore_duration_input
|
|
self.ignore_f0_input = task.cfg.ignore_f0_input
|
|
|
|
@classmethod
|
|
def build_model(cls, args, task):
|
|
base_ulm_architecture(args)
|
|
|
|
if getattr(args, "max_target_positions", None) is None:
|
|
args.max_target_positions = getattr(
|
|
args, "tokens_per_sample", DEFAULT_MAX_TARGET_POSITIONS
|
|
)
|
|
|
|
embed_tokens = Embedding(
|
|
len(task.source_dictionary),
|
|
args.decoder_input_dim,
|
|
padding_idx=task.source_dictionary.pad(),
|
|
)
|
|
embed_duration = None
|
|
if task.cfg.discrete_duration:
|
|
embed_duration = Embedding(
|
|
len(task.source_duration_dictionary),
|
|
args.decoder_input_dim,
|
|
padding_idx=0, # duration uses 0 for padding
|
|
)
|
|
embed_f0 = None
|
|
if task.cfg.discrete_f0:
|
|
embed_f0 = Embedding(
|
|
len(task.source_f0_dictionary),
|
|
args.decoder_input_dim,
|
|
padding_idx=task.source_f0_dictionary.pad(),
|
|
)
|
|
|
|
decoder = MultiStreamTransformerDecoder(
|
|
args,
|
|
task.target_dictionary,
|
|
embed_tokens,
|
|
[embed_duration, embed_f0],
|
|
no_encoder_attn=True,
|
|
channel_sizes=task.channel_sizes,
|
|
)
|
|
|
|
return cls(args, task, decoder)
|
|
|
|
def apply_seg_dropout(self, inp, mask_prob, mask_leng, mask_type, mask_val):
|
|
B, T = inp.size()
|
|
if mask_prob > 0:
|
|
mask_indices = compute_mask_indices(
|
|
(B, T), None, mask_prob, mask_leng, mask_type # may mask padding
|
|
)
|
|
mask_indices = torch.from_numpy(mask_indices).to(inp.device)
|
|
inp[mask_indices] = mask_val
|
|
else:
|
|
mask_indices = torch.zeros_like(inp).bool()
|
|
return inp, mask_indices
|
|
|
|
def apply_seq_dropout(self, inp, mask_prob, mask_val):
|
|
B, T = inp.size()
|
|
if mask_prob > 0:
|
|
mask_indices = np.random.uniform(0, 1, (B,)) < mask_prob
|
|
mask_indices = (
|
|
torch.from_numpy(mask_indices).to(inp.device).unsqueeze(1).expand(-1, T)
|
|
)
|
|
inp[mask_indices] = mask_val
|
|
else:
|
|
mask_indices = torch.zeros_like(inp).bool()
|
|
return inp, mask_indices
|
|
|
|
def apply_dropout(self, src_tokens, dur_src, f0_src):
|
|
src_tokens, unit_mask = self.apply_seg_dropout(
|
|
src_tokens,
|
|
self.cfg.mask_unit_seg_prob,
|
|
self.cfg.mask_unit_seg_leng,
|
|
self.cfg.mask_unit_seg_type,
|
|
self.unit_mask_val,
|
|
)
|
|
|
|
dur_src, dur_mask = self.apply_seq_dropout(
|
|
dur_src, self.cfg.mask_dur_prob, self.dur_mask_val
|
|
)
|
|
dur_src, _dur_mask = self.apply_seg_dropout(
|
|
dur_src,
|
|
self.cfg.mask_dur_seg_prob,
|
|
self.cfg.mask_dur_seg_leng,
|
|
self.cfg.mask_dur_seg_type,
|
|
self.dur_mask_val,
|
|
)
|
|
dur_mask = dur_mask.logical_or(_dur_mask)
|
|
|
|
f0_src, f0_mask = self.apply_seq_dropout(
|
|
f0_src, self.cfg.mask_f0_prob, self.f0_mask_val
|
|
)
|
|
f0_src, _f0_mask = self.apply_seg_dropout(
|
|
f0_src,
|
|
self.cfg.mask_f0_seg_prob,
|
|
self.cfg.mask_f0_seg_leng,
|
|
self.cfg.mask_f0_seg_type,
|
|
self.f0_mask_val,
|
|
)
|
|
f0_mask = f0_mask.logical_or(_f0_mask)
|
|
|
|
return src_tokens, unit_mask, dur_src, dur_mask, f0_src, f0_mask
|
|
|
|
def forward(
|
|
self,
|
|
src_tokens: torch.Tensor,
|
|
dur_src: torch.Tensor,
|
|
f0_src: torch.Tensor,
|
|
src_lengths: Optional[Any] = None,
|
|
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
|
|
):
|
|
if self.ignore_duration_input:
|
|
dur_src = torch.zeros_like(dur_src)
|
|
|
|
if self.ignore_f0_input:
|
|
f0_src = torch.zeros_like(f0_src)
|
|
|
|
if self.training:
|
|
(
|
|
src_tokens,
|
|
unit_mask,
|
|
dur_src,
|
|
dur_mask,
|
|
f0_src,
|
|
f0_mask,
|
|
) = self.apply_dropout(src_tokens, dur_src, f0_src)
|
|
else:
|
|
unit_masks = dur_mask = f0_mask = None
|
|
|
|
prediction, _ = self.decoder(
|
|
prev_output_tokens=(src_tokens, dur_src, f0_src),
|
|
incremental_state=incremental_state,
|
|
src_lengths=src_lengths,
|
|
features_only=True,
|
|
)
|
|
|
|
result = dict(zip(self.channel_names, prediction))
|
|
|
|
return result
|
|
|
|
|
|
def base_ulm_architecture(args):
|
|
from .transformer_lm import base_lm_architecture
|
|
|
|
base_lm_architecture(args)
|
|
|
|
|
|
@register_model_architecture("transformer_ulm", "transformer_ulm_big")
|
|
def transformer_ulm_big(args):
|
|
from .transformer_lm import transformer_lm_big
|
|
|
|
transformer_lm_big(args)
|
|
base_ulm_architecture(args)
|
|
|
|
|
|
@register_model_architecture("transformer_ulm", "transformer_ulm_tiny")
|
|
def transformer_ulm_tiny(args):
|
|
from .transformer_lm import transformer_lm_gpt2_tiny
|
|
|
|
transformer_lm_gpt2_tiny(args)
|
|
base_ulm_architecture(args)
|
|
|
|
|
|
class MultiStreamTransformerDecoder(TransformerDecoder):
|
|
def __init__(
|
|
self,
|
|
args,
|
|
dictionary,
|
|
embed_tokens,
|
|
embed_other_list,
|
|
no_encoder_attn,
|
|
channel_sizes,
|
|
):
|
|
super().__init__(
|
|
args, dictionary, embed_tokens, no_encoder_attn=no_encoder_attn
|
|
)
|
|
|
|
# embed each channel and project if dimensions do not match
|
|
self.embed_other_list = torch.nn.ModuleList(embed_other_list)
|
|
self.proj_other_list = torch.nn.ModuleList()
|
|
dim = embed_tokens.embedding_dim
|
|
for embed_other in embed_other_list:
|
|
other_dim = 1 if embed_other is None else embed_other.embedding_dim
|
|
self.proj_other_list.append(
|
|
nn.Linear(other_dim, dim) if other_dim != dim else None
|
|
)
|
|
|
|
# tranformer output to prediction
|
|
self.channel_sizes = channel_sizes
|
|
self.project_out_dim = Linear(
|
|
embed_tokens.embedding_dim, sum(channel_sizes), bias=False
|
|
)
|
|
|
|
def extract_features_scriptable(
|
|
self,
|
|
prev_output_tokens,
|
|
encoder_out: Optional[Dict[str, List[Tensor]]],
|
|
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
|
|
full_context_alignment: bool = False,
|
|
alignment_layer: Optional[int] = None,
|
|
alignment_heads: Optional[int] = None,
|
|
):
|
|
if alignment_layer is None:
|
|
alignment_layer = self.num_layers - 1
|
|
|
|
# XXX: first multi-channel change start
|
|
prev_output_tokens, *other_channels = prev_output_tokens
|
|
# XXX: first multi-channel change end
|
|
|
|
# embed positions
|
|
positions = None
|
|
if self.embed_positions is not None:
|
|
positions = self.embed_positions(
|
|
prev_output_tokens, incremental_state=incremental_state
|
|
)
|
|
|
|
if incremental_state is not None:
|
|
prev_output_tokens = prev_output_tokens[:, -1:]
|
|
other_channels = [o[:, -1:] for o in other_channels]
|
|
if positions is not None:
|
|
positions = positions[:, -1:]
|
|
|
|
# embed tokens and positions
|
|
x = self.embed_scale * self.embed_tokens(prev_output_tokens)
|
|
|
|
# XXX: second multi-channel change start
|
|
other_channels = [
|
|
o.unsqueeze(-1).to(dtype=x.dtype) if emb is None else emb(o)
|
|
for o, emb in zip(other_channels, self.embed_other_list)
|
|
]
|
|
other_channels = [
|
|
o if proj_other is None else proj_other(o)
|
|
for o, proj_other in zip(other_channels, self.proj_other_list)
|
|
]
|
|
for o in other_channels:
|
|
x = x + o
|
|
# XXX: second multi-channel change end
|
|
|
|
if self.quant_noise is not None:
|
|
x = self.quant_noise(x)
|
|
|
|
if self.project_in_dim is not None:
|
|
x = self.project_in_dim(x)
|
|
|
|
if positions is not None:
|
|
x += positions
|
|
|
|
if self.layernorm_embedding is not None:
|
|
x = self.layernorm_embedding(x)
|
|
|
|
x = self.dropout_module(x)
|
|
|
|
# B x T x C -> T x B x C
|
|
x = x.transpose(0, 1)
|
|
|
|
self_attn_padding_mask: Optional[Tensor] = None
|
|
if self.cross_self_attention or prev_output_tokens.eq(self.padding_idx).any():
|
|
self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx)
|
|
|
|
# decoder layers
|
|
attn: Optional[Tensor] = None
|
|
inner_states: List[Optional[Tensor]] = [x]
|
|
for idx, layer in enumerate(self.layers):
|
|
if incremental_state is None and not full_context_alignment:
|
|
self_attn_mask = self.buffered_future_mask(x)
|
|
else:
|
|
self_attn_mask = None
|
|
|
|
x, layer_attn, _ = layer(
|
|
x,
|
|
encoder_out["encoder_out"][0]
|
|
if (encoder_out is not None and len(encoder_out["encoder_out"]) > 0)
|
|
else None,
|
|
encoder_out["encoder_padding_mask"][0]
|
|
if (
|
|
encoder_out is not None
|
|
and len(encoder_out["encoder_padding_mask"]) > 0
|
|
)
|
|
else None,
|
|
incremental_state,
|
|
self_attn_mask=self_attn_mask,
|
|
self_attn_padding_mask=self_attn_padding_mask,
|
|
need_attn=bool((idx == alignment_layer)),
|
|
need_head_weights=bool((idx == alignment_layer)),
|
|
)
|
|
inner_states.append(x)
|
|
if layer_attn is not None and idx == alignment_layer:
|
|
attn = layer_attn.float().to(x)
|
|
|
|
if attn is not None:
|
|
if alignment_heads is not None:
|
|
attn = attn[:alignment_heads]
|
|
|
|
# average probabilities over heads
|
|
attn = attn.mean(dim=0)
|
|
|
|
if self.layer_norm is not None:
|
|
x = self.layer_norm(x)
|
|
|
|
# T x B x C -> B x T x C
|
|
x = x.transpose(0, 1)
|
|
|
|
if self.project_out_dim is not None:
|
|
x = self.project_out_dim(x)
|
|
else:
|
|
assert False
|
|
|
|
# XXX: the last change start
|
|
result = []
|
|
start = 0
|
|
for channel_size in self.channel_sizes:
|
|
end = start + channel_size
|
|
result.append(x[:, :, start:end])
|
|
start = end
|
|
assert end == x.size(-1)
|
|
# XXX: the last change end
|
|
|
|
return result, {"attn": [attn], "inner_states": inner_states}
|