# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
from omegaconf import II

from fairseq import utils
from fairseq.data.data_utils import compute_mask_indices
from fairseq.data.dictionary import Dictionary
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
from fairseq.models import BaseFairseqModel, register_model
from fairseq.models.wav2vec.wav2vec2 import (
    EXTRACTOR_MODE_CHOICES,
    MASKING_DISTRIBUTION_CHOICES,
    LAYER_TYPE_CHOICES,
    ConvFeatureExtractionModel,
    TransformerEncoder,
)
from fairseq.modules import GradMultiply, LayerNorm
from fairseq.tasks.hubert_pretraining import (
    HubertPretrainingConfig,
    HubertPretrainingTask,
)

logger = logging.getLogger(__name__)

@dataclass
class HubertConfig(FairseqDataclass):
    label_rate: float = II("task.label_rate")

    extractor_mode: EXTRACTOR_MODE_CHOICES = field(
        default="default",
        metadata={
            "help": "mode for feature extractor. default has a single group "
            "norm with d groups in the first conv block, whereas layer_norm "
            "has layer norms in every block (meant to use with normalize=True)"
        },
    )
    encoder_layers: int = field(
        default=12, metadata={"help": "num encoder layers in the transformer"}
    )
    encoder_embed_dim: int = field(
        default=768, metadata={"help": "encoder embedding dimension"}
    )
    encoder_ffn_embed_dim: int = field(
        default=3072, metadata={"help": "encoder embedding dimension for FFN"}
    )
    encoder_attention_heads: int = field(
        default=12, metadata={"help": "num encoder attention heads"}
    )
    activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field(
        default="gelu", metadata={"help": "activation function to use"}
    )
    layer_type: LAYER_TYPE_CHOICES = field(
        default="transformer", metadata={"help": "layer type in encoder"}
    )

    # dropouts
    dropout: float = field(
        default=0.1,
        metadata={"help": "dropout probability for the transformer"},
    )
    attention_dropout: float = field(
        default=0.1,
        metadata={"help": "dropout probability for attention weights"},
    )
    activation_dropout: float = field(
        default=0.0,
        metadata={"help": "dropout probability after activation in FFN"},
    )
    encoder_layerdrop: float = field(
        default=0.0,
metadata={"help": "probability of dropping a tarnsformer layer"},
    )
    dropout_input: float = field(
        default=0.0,
        metadata={"help": "dropout to apply to the input (after feat extr)"},
    )
    dropout_features: float = field(
        default=0.0,
        metadata={"help": "dropout to apply to the features (after feat extr)"},
    )

    final_dim: int = field(
        default=0,
        metadata={
            "help": "project final representations and targets to this many "
"dimensions. set to encoder_embed_dim is <= 0"
        },
    )
    untie_final_proj: bool = field(
        default=False,
        metadata={"help": "use separate projection for each target"},
    )
    layer_norm_first: bool = field(
        default=False,
        metadata={"help": "apply layernorm first in the transformer"},
    )
    conv_feature_layers: str = field(
        default="[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2",
        metadata={
            "help": "string describing convolutional feature extraction "
            "layers in form of a python list that contains "
            "[(dim, kernel_size, stride), ...]"
        },
    )
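    # NOTE: the default string above evals to seven conv layers with strides
    # (5, 2, 2, 2, 2, 2, 2), so raw 16 kHz audio is downsampled by
    # 5 * 2**6 = 320x -- one feature frame per 20 ms, i.e. a 50 Hz frame rate.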
    conv_bias: bool = field(
        default=False, metadata={"help": "include bias in conv encoder"}
    )
    logit_temp: float = field(
        default=0.1, metadata={"help": "temperature to divide logits by"}
    )
    target_glu: bool = field(
        default=False, metadata={"help": "adds projection + glu to targets"}
    )
    feature_grad_mult: float = field(
        default=1.0,
        metadata={"help": "multiply feature extractor var grads by this"},
    )

    # masking
    mask_length: int = field(default=10, metadata={"help": "mask length"})
    mask_prob: float = field(
        default=0.65,
        metadata={"help": "probability of replacing a token with mask"},
    )
    mask_selection: MASKING_DISTRIBUTION_CHOICES = field(
        default="static", metadata={"help": "how to choose mask length"}
    )
    mask_other: float = field(
        default=0,
        metadata={
            "help": "secondary mask argument "
            "(used for more complex distributions), "
"see help in compute_mask_indicesh"
        },
    )
    no_mask_overlap: bool = field(
        default=False, metadata={"help": "whether to allow masks to overlap"}
    )
    mask_min_space: int = field(
        default=1,
        metadata={"help": "min space between spans (if no overlap is enabled)"},
    )

    # channel masking
    mask_channel_length: int = field(
        default=10,
        metadata={"help": "length of the mask for features (channels)"},
    )
    mask_channel_prob: float = field(
        default=0.0,
        metadata={"help": "probability of replacing a feature with 0"},
    )
    mask_channel_selection: MASKING_DISTRIBUTION_CHOICES = field(
        default="static",
        metadata={"help": "how to choose mask length for channel masking"},
    )
    mask_channel_other: float = field(
        default=0,
        metadata={
            "help": "secondary mask argument "
            "(used for more complex distributions), "
"see help in compute_mask_indicesh"
        },
    )
    no_mask_channel_overlap: bool = field(
        default=False,
        metadata={"help": "whether to allow channel masks to overlap"},
    )
    mask_channel_min_space: int = field(
        default=1,
        metadata={"help": "min space between spans (if no overlap is enabled)"},
    )

    # positional embeddings
    conv_pos: int = field(
        default=128,
        metadata={"help": "number of filters for convolutional positional embeddings"},
    )
    conv_pos_groups: int = field(
        default=16,
        metadata={"help": "number of groups for convolutional positional embedding"},
    )

    latent_temp: Tuple[float, float, float] = field(
        default=(2, 0.5, 0.999995),
        metadata={"help": "legacy (to be removed)"},
    )

    # loss computation
    skip_masked: bool = field(
        default=False,
        metadata={"help": "skip computing losses over masked frames"},
    )
    skip_nomask: bool = field(
        default=False,
        metadata={"help": "skip computing losses over unmasked frames"},
    )

    checkpoint_activations: bool = field(
        default=False,
        metadata={"help": "recompute activations and save memory for extra compute"},
    )

    # FP16 optimization
    required_seq_len_multiple: int = field(
        default=2,
        metadata={
            "help": "pad the input to encoder such that the sequence length is divisible by multiple"
        },
    )

    # Conformer
    depthwise_conv_kernel_size: int = field(
        default=31,
        metadata={
            "help": "depthwise-conv-kernel-size for convolution in conformer layer"
        },
    )
    attn_type: str = field(
        default="",
        metadata={"help": "if espnet use ESPNET MHA"},
    )
    pos_enc_type: str = field(
        default="abs",
        metadata={"help": "Positional encoding type to use in conformer"},
    )
    fp16: bool = field(default=False, metadata={"help": "If fp16 is being used"})
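# Minimal configuration sketch (for illustration only): HubertConfig is a
# regular dataclass, so defaults can be overridden directly. label_rate
# normally interpolates from the task config via II("task.label_rate") and
# is set by hand here only to make the example self-contained.
#
#   cfg = HubertConfig(label_rate=50.0, encoder_layers=6)
#   assert cfg.final_dim == 0  # 0 means "fall back to encoder_embed_dim"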


@register_model("hubert", dataclass=HubertConfig)
class HubertModel(BaseFairseqModel):
    def __init__(
        self,
        cfg: HubertConfig,
        task_cfg: HubertPretrainingConfig,
        dictionaries: List[Dictionary],
    ) -> None:
        super().__init__()
        logger.info(f"HubertModel Config: {cfg}")

        feature_enc_layers = eval(cfg.conv_feature_layers)  # noqa
        self.embed = feature_enc_layers[-1][0]

        self.feature_extractor = ConvFeatureExtractionModel(
            conv_layers=feature_enc_layers,
            dropout=0.0,
            mode=cfg.extractor_mode,
            conv_bias=cfg.conv_bias,
        )
        feature_ds_rate = np.prod([s for _, _, s in feature_enc_layers])
        self.feat2tar_ratio = cfg.label_rate * feature_ds_rate / task_cfg.sample_rate
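        # e.g. with the default 320x extractor on 16 kHz audio and 50 Hz
        # labels: feat2tar_ratio = 50 * 320 / 16000 = 1.0 (one label per
        # feature frame); other label rates give fractional ratios that
        # forward_targets() resolves by rounding per-frame label indices.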

        self.post_extract_proj = (
            nn.Linear(self.embed, cfg.encoder_embed_dim)
            if self.embed != cfg.encoder_embed_dim
            else None
        )

        self.mask_prob = cfg.mask_prob
        self.mask_selection = cfg.mask_selection
        self.mask_other = cfg.mask_other
        self.mask_length = cfg.mask_length
        self.no_mask_overlap = cfg.no_mask_overlap
        self.mask_min_space = cfg.mask_min_space

        self.mask_channel_prob = cfg.mask_channel_prob
        self.mask_channel_selection = cfg.mask_channel_selection
        self.mask_channel_other = cfg.mask_channel_other
        self.mask_channel_length = cfg.mask_channel_length
        self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
        self.mask_channel_min_space = cfg.mask_channel_min_space

        self.dropout_input = nn.Dropout(cfg.dropout_input)
        self.dropout_features = nn.Dropout(cfg.dropout_features)

        self.feature_grad_mult = cfg.feature_grad_mult
        self.logit_temp = cfg.logit_temp
        self.skip_masked = cfg.skip_masked
        self.skip_nomask = cfg.skip_nomask
        final_dim = cfg.final_dim if cfg.final_dim > 0 else cfg.encoder_embed_dim

        self.mask_emb = nn.Parameter(
            torch.FloatTensor(cfg.encoder_embed_dim).uniform_()
        )

        self.encoder = TransformerEncoder(cfg)
        self.layer_norm = LayerNorm(self.embed)

        self.target_glu = None
        if cfg.target_glu:
            self.target_glu = nn.Sequential(
                nn.Linear(final_dim, final_dim * 2), nn.GLU()
            )

        self.untie_final_proj = cfg.untie_final_proj
        if self.untie_final_proj:
            self.final_proj = nn.Linear(
                cfg.encoder_embed_dim, final_dim * len(dictionaries)
            )
        else:
            self.final_proj = nn.Linear(cfg.encoder_embed_dim, final_dim)

        # modules below are not needed during fine-tuning
        if any([d is None for d in dictionaries]):
            logger.info("cannot find dictionary. assume will be used for fine-tuning")
        else:
            self.num_classes = [len(d) for d in dictionaries]
            self.label_embs_concat = nn.Parameter(
                torch.FloatTensor(sum(self.num_classes), final_dim)
            )
            nn.init.uniform_(self.label_embs_concat)

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq."""

        super().upgrade_state_dict_named(state_dict, name)
        return state_dict

    @classmethod
    def build_model(cls, cfg: HubertConfig, task: HubertPretrainingTask):
        """Build a new model instance."""

        model = HubertModel(cfg, task.cfg, task.dictionaries)
        return model

    def apply_mask(self, x, padding_mask, target_list):
        B, T, C = x.shape
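        # mask_indices is a (B, T) boolean array in which roughly mask_prob
        # of the frames fall inside spans of mask_length frames; each masked
        # position is overwritten with the single learned mask_emb vector,
        # analogous to BERT's [MASK] token.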
        if self.mask_prob > 0:
            mask_indices = compute_mask_indices(
                (B, T),
                padding_mask,
                self.mask_prob,
                self.mask_length,
                self.mask_selection,
                self.mask_other,
                min_masks=2,
                no_overlap=self.no_mask_overlap,
                min_space=self.mask_min_space,
            )
            mask_indices = torch.from_numpy(mask_indices).to(x.device)
            x[mask_indices] = self.mask_emb
        else:
            mask_indices = None

        if self.mask_channel_prob > 0:
            mask_channel_indices = compute_mask_indices(
                (B, C),
                None,
                self.mask_channel_prob,
                self.mask_channel_length,
                self.mask_channel_selection,
                self.mask_channel_other,
                no_overlap=self.no_mask_channel_overlap,
                min_space=self.mask_channel_min_space,
            )
            mask_channel_indices = (
                torch.from_numpy(mask_channel_indices)
                .to(x.device)
                .unsqueeze(1)
                .expand(-1, T, -1)
            )
            x[mask_channel_indices] = 0

        return x, mask_indices

    def compute_nce(self, x, pos, negs):
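        # Shapes: x and pos are (S, D); negs is (Neg, S, D). The positive is
        # concatenated as row 0 of the targets, so after the final transpose
        # the logits are (S, Neg + 1) with the correct class always at index
        # 0 -- which is why get_targets() returns all-zero label vectors.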
        neg_is_pos = (pos == negs).all(-1)
        pos = pos.unsqueeze(0)
        targets = torch.cat([pos, negs], dim=0)

        logits = torch.cosine_similarity(x.float(), targets.float(), dim=-1).type_as(x)
        logits /= self.logit_temp
        if neg_is_pos.any():
            logits[1:][neg_is_pos] = float("-inf")
        logits = logits.transpose(0, 1)  # (num_x, num_cls+1)
        return logits

    def forward_features(self, source: torch.Tensor) -> torch.Tensor:
        if self.feature_grad_mult > 0:
            features = self.feature_extractor(source)
            if self.feature_grad_mult != 1.0:
                features = GradMultiply.apply(features, self.feature_grad_mult)
        else:
            with torch.no_grad():
                features = self.feature_extractor(source)
        return features

    def forward_targets(
        self,
        features: torch.Tensor,
        target_list: List[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Trim features to ensure labels exist and then get aligned labels
        feat_tsz = features.size(2)
        targ_tsz = min([t.size(1) for t in target_list])
        if self.feat2tar_ratio * feat_tsz > targ_tsz:
            feat_tsz = int(targ_tsz / self.feat2tar_ratio)
            features = features[..., :feat_tsz]
        target_inds = torch.arange(feat_tsz).float() * self.feat2tar_ratio
        target_list = [t[:, target_inds.long()] for t in target_list]
        return features, target_list

    def forward_padding_mask(
        self,
        features: torch.Tensor,
        padding_mask: torch.Tensor,
    ) -> torch.Tensor:
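        # Downsample the sample-level padding mask to the feature frame rate:
        # samples are grouped into features.size(1) equal chunks, and a frame
        # counts as padding only if all samples mapped to it are padding.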
        extra = padding_mask.size(1) % features.size(1)
        if extra > 0:
            padding_mask = padding_mask[:, :-extra]
        padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1)
        padding_mask = padding_mask.all(-1)
        return padding_mask

    def forward(
        self,
        source: torch.Tensor,
        target_list: Optional[List[torch.Tensor]] = None,
        padding_mask: Optional[torch.Tensor] = None,
        mask: bool = True,
        features_only: bool = False,
        output_layer: Optional[int] = None,
    ) -> Dict[str, torch.Tensor]:
        """output layer is 1-based"""
        features = self.forward_features(source)
        if target_list is not None:
            features, target_list = self.forward_targets(features, target_list)

        features_pen = features.float().pow(2).mean()
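        # L2 magnitude penalty on the conv features; returned as
        # "features_pen" and consumed by the criterion via get_extra_losses().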

        features = features.transpose(1, 2)
        features = self.layer_norm(features)
        unmasked_features = features.clone()

        if padding_mask is not None:
            padding_mask = self.forward_padding_mask(features, padding_mask)

        if self.post_extract_proj is not None:
            features = self.post_extract_proj(features)

        features = self.dropout_input(features)
        unmasked_features = self.dropout_features(unmasked_features)

        if mask:
            x, mask_indices = self.apply_mask(features, padding_mask, target_list)
        else:
            x = features
            mask_indices = None

        # feature: (B, T, D), float
        # target: (B, T), long
        # x: (B, T, D), float
        # padding_mask: (B, T), bool
        # mask_indices: (B, T), bool
        x, _ = self.encoder(
            x,
            padding_mask=padding_mask,
            layer=None if output_layer is None else output_layer - 1,
        )

        if features_only:
            return {"x": x, "padding_mask": padding_mask, "features": features}

        def compute_pred(proj_x, target, label_embs):
            # compute logits for the i-th label set
            y = torch.index_select(label_embs, 0, target.long())
            negs = label_embs.unsqueeze(1).expand(-1, proj_x.size(0), -1)
            if self.target_glu:
                y = self.target_glu(y)
                negs = self.target_glu(negs)
            # proj_x: (S, D)
            # y: (S, D)
            # negs: (Neg, S, D)
            return self.compute_nce(proj_x, y, negs)

        label_embs_list = self.label_embs_concat.split(self.num_classes, 0)

        if not self.skip_masked:
            masked_indices = torch.logical_and(~padding_mask, mask_indices)
            proj_x_m = self.final_proj(x[masked_indices])
            if self.untie_final_proj:
                proj_x_m_list = proj_x_m.chunk(len(target_list), dim=-1)
            else:
                proj_x_m_list = [proj_x_m for _ in range(len(target_list))]
            logit_m_list = [
                compute_pred(proj_x_m, t[masked_indices], label_embs_list[i])
                for i, (proj_x_m, t) in enumerate(zip(proj_x_m_list, target_list))
            ]
        else:
            logit_m_list = [None for _ in target_list]

        if not self.skip_nomask:
            nomask_indices = torch.logical_and(~padding_mask, ~mask_indices)
            proj_x_u = self.final_proj(x[nomask_indices])
            if self.untie_final_proj:
                proj_x_u_list = proj_x_u.chunk(len(target_list), dim=-1)
            else:
                proj_x_u_list = [proj_x_u for _ in range(len(target_list))]

            logit_u_list = [
                compute_pred(proj_x_u, t[nomask_indices], label_embs_list[i])
                for i, (proj_x_u, t) in enumerate(zip(proj_x_u_list, target_list))
            ]
        else:
            logit_u_list = [None for _ in target_list]

        result = {
            "logit_m_list": logit_m_list,
            "logit_u_list": logit_u_list,
            "padding_mask": padding_mask,
            "features_pen": features_pen,
        }
        return result

    def extract_features(
        self,
        source: torch.Tensor,
        padding_mask: Optional[torch.Tensor] = None,
        mask: bool = False,
        ret_conv: bool = False,
        output_layer: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        res = self.forward(
            source,
            padding_mask=padding_mask,
            mask=mask,
            features_only=True,
            output_layer=output_layer,
        )
        feature = res["features"] if ret_conv else res["x"]
        return feature, res["padding_mask"]
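    # Minimal feature-extraction sketch (the checkpoint filename is an
    # assumption, not part of this file):
    #
    #   from fairseq import checkpoint_utils
    #   models, _, _ = checkpoint_utils.load_model_ensemble_and_task(
    #       ["hubert_base_ls960.pt"]
    #   )
    #   model = models[0].eval()
    #   wav = torch.randn(1, 16000)  # one second of 16 kHz audio
    #   with torch.no_grad():
    #       feats, _ = model.extract_features(wav, output_layer=9)
    #   # feats: (1, ~49, 768) for the base configuration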

    def get_logits(self, net_output, is_masked=True):
        if is_masked:
            logits_list = net_output["logit_m_list"]
        else:
            logits_list = net_output["logit_u_list"]
        logits_list = [x.float() for x in logits_list if x is not None]
        return logits_list

    def get_targets(self, net_output, is_masked=True):
        logits_list = self.get_logits(net_output, is_masked)
        targets_list = [x.new_zeros(x.size(0), dtype=torch.long) for x in logits_list]
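        # All-zero targets: compute_nce() always places the positive example
        # at index 0 of each logit row, so class 0 is the correct label.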
        return targets_list

    def get_extra_losses(self, net_output):
        extra_losses = []
        names = []

        if "features_pen" in net_output:
            extra_losses.append(net_output["features_pen"])
            names.append("features_pen")

        return extra_losses, names

    def remove_pretraining_modules(self):
        self.target_glu = None
        self.final_proj = None