# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import contextlib
import copy
import logging
import math
import re
from argparse import Namespace
from dataclasses import dataclass, field
from typing import Any, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from omegaconf import II, MISSING, open_dict

from fairseq import checkpoint_utils, tasks, utils
from fairseq.dataclass import FairseqDataclass
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.models import (
    BaseFairseqModel,
    FairseqEncoder,
    FairseqEncoderDecoderModel,
    FairseqIncrementalDecoder,
    register_model,
)
from fairseq.models.wav2vec.wav2vec2 import MASKING_DISTRIBUTION_CHOICES
from fairseq.modules import LayerNorm, PositionalEmbedding, TransformerDecoderLayer
from fairseq.tasks import FairseqTask

logger = logging.getLogger(__name__)

@dataclass
class Wav2Vec2AsrConfig(FairseqDataclass):
    w2v_path: str = field(
        default=MISSING, metadata={"help": "path to wav2vec 2.0 model"}
    )
    no_pretrained_weights: bool = field(
        default=False, metadata={"help": "if true, does not load pretrained weights"}
    )
    dropout_input: float = field(
        default=0.0,
        metadata={"help": "dropout to apply to the input (after feat extr)"},
    )
    final_dropout: float = field(
        default=0.0,
        metadata={"help": "dropout after transformer and before final projection"},
    )
    dropout: float = field(
        default=0.0, metadata={"help": "dropout probability inside wav2vec 2.0 model"}
    )
    attention_dropout: float = field(
        default=0.0,
        metadata={
            "help": "dropout probability for attention weights inside wav2vec 2.0 model"
        },
    )
    activation_dropout: float = field(
        default=0.0,
        metadata={
            "help": "dropout probability after activation in FFN inside wav2vec 2.0 model"
        },
    )
    conv_feature_layers: Optional[str] = field(
        default="[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] + [(512,2,2)]",
        metadata={
            "help": (
                "string describing convolutional feature extraction "
                "layers in form of a python list that contains "
                "[(dim, kernel_size, stride), ...]"
            ),
        },
    )
    encoder_embed_dim: Optional[int] = field(
        default=768, metadata={"help": "encoder embedding dimension"}
    )

    # masking
    apply_mask: bool = field(
        default=False, metadata={"help": "apply masking during fine-tuning"}
    )
    mask_length: int = field(
        default=10, metadata={"help": "length of each masked span of time steps"}
    )
    mask_prob: float = field(
        default=0.5,
        metadata={
            "help": "probability of replacing a token with mask (normalized by length)"
        },
    )
    mask_selection: MASKING_DISTRIBUTION_CHOICES = field(
        default="static", metadata={"help": "how to choose masks"}
    )
    mask_other: float = field(
        default=0,
        metadata={
            "help": "secondary mask argument (used for more complex distributions), "
            "see help in compute_mask_indices"
        },
    )
    no_mask_overlap: bool = field(
        default=False, metadata={"help": "whether to allow masks to overlap"}
    )
    mask_min_space: Optional[int] = field(
        default=1,
        metadata={"help": "min space between spans (if no overlap is enabled)"},
    )
    require_same_masks: bool = field(
        default=True,
        metadata={
            "help": "whether the number of masked timesteps must be the same across all "
            "examples in a batch"
        },
    )
    mask_dropout: float = field(
        default=0.0,
        metadata={"help": "percent of masks to unmask for each sample"},
    )

    # channel masking
    mask_channel_length: int = field(
        default=10, metadata={"help": "length of the mask for features (channels)"}
    )
    mask_channel_prob: float = field(
        default=0.0, metadata={"help": "probability of replacing a feature with 0"}
    )
    mask_channel_selection: MASKING_DISTRIBUTION_CHOICES = field(
        default="static",
        metadata={"help": "how to choose mask length for channel masking"},
    )
    mask_channel_other: float = field(
        default=0,
        metadata={
            "help": "secondary mask argument (used for more complex distributions), "
            "see help in compute_mask_indices"
        },
    )
    no_mask_channel_overlap: bool = field(
        default=False, metadata={"help": "whether to allow channel masks to overlap"}
    )
    freeze_finetune_updates: int = field(
        default=0, metadata={"help": "don't finetune wav2vec for this many updates"}
    )
    feature_grad_mult: float = field(
        default=0.0, metadata={"help": "reset feature grad mult in wav2vec 2.0 to this"}
    )
    layerdrop: float = field(
        default=0.0, metadata={"help": "probability of dropping a layer in wav2vec 2.0"}
    )
    mask_channel_min_space: Optional[int] = field(
        default=1,
        metadata={"help": "min space between spans (if no overlap is enabled)"},
    )
    mask_channel_before: bool = False
    normalize: bool = II("task.normalize")
    data: str = II("task.data")
    # this holds the loaded wav2vec args
    w2v_args: Any = None
    offload_activations: bool = field(
        default=False,
        metadata={"help": "offload activations to CPU to save GPU memory"},
    )
    min_params_to_wrap: int = field(
        default=int(1e8),
        metadata={
            "help": "minimum number of params for a layer to be wrapped with FSDP() when "
            "training with --ddp-backend=fully_sharded. Smaller values will "
            "improve memory efficiency, but may make torch.distributed "
            "communication less efficient due to smaller input sizes. This option "
            "is set to 0 (i.e., always wrap) when --checkpoint-activations or "
            "--offload-activations are passed."
        },
    )

    checkpoint_activations: bool = field(
        default=False,
        metadata={"help": "recompute activations and save memory for extra compute"},
    )
    ddp_backend: str = II("distributed_training.ddp_backend")

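
# Usage sketch for Wav2Vec2AsrConfig (illustrative note, not part of the original
# fairseq source): the config is normally populated by fairseq's Hydra/omegaconf
# machinery from the command line, but as a dataclass it can also be built directly,
# e.g.
#     cfg = Wav2Vec2AsrConfig(w2v_path="/path/to/wav2vec_small.pt", apply_mask=True)
# Fields declared with II("...") (normalize, data, ddp_backend) are omegaconf
# interpolations that resolve against the top-level task/distributed_training config
# at runtime.
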
@dataclass
class Wav2Vec2CtcConfig(Wav2Vec2AsrConfig):
    blank_weight: float = 0
    blank_mode: str = "add"

@register_model("wav2vec_ctc", dataclass=Wav2Vec2CtcConfig)
class Wav2VecCtc(BaseFairseqModel):
    def __init__(self, cfg: Wav2Vec2CtcConfig, w2v_encoder: BaseFairseqModel):
        super().__init__()
        self.cfg = cfg
        self.w2v_encoder = w2v_encoder
        self.blank_weight = cfg.blank_weight
        self.blank_mode = cfg.blank_mode

    def upgrade_state_dict_named(self, state_dict, name):
        super().upgrade_state_dict_named(state_dict, name)
        return state_dict

    @classmethod
    def build_model(cls, cfg: Wav2Vec2CtcConfig, task: FairseqTask):
        """Build a new model instance."""
        w2v_encoder = Wav2VecEncoder(cfg, len(task.target_dictionary))
        return cls(cfg, w2v_encoder)

    def get_logits(self, net_output, normalize=False):
        logits = net_output["encoder_out"]
        if self.blank_weight != 0:
            if self.blank_mode == "add":
                logits[..., 0] += self.blank_weight
            elif self.blank_mode == "set":
                logits[..., 0] = self.blank_weight
            else:
                raise Exception(f"invalid blank mode {self.blank_mode}")

        if net_output["padding_mask"] is not None and net_output["padding_mask"].any():
            number_of_classes = logits.size(-1)
            masking_tensor = torch.ones(
                number_of_classes, device=logits.device
            ) * float("-inf")
            masking_tensor[0] = 0
            logits[net_output["padding_mask"].T] = masking_tensor.type_as(logits)

        if normalize:
            logits = utils.log_softmax(logits.float(), dim=-1)

        return logits

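    # Illustration of the blank handling in get_logits above (added note, not in the
    # original source): with blank_weight=2.0 and logits [1.0, 0.5, 0.3] for the
    # classes (<blank>, "a", "b"), blank_mode="add" yields [3.0, 0.5, 0.3], while
    # blank_mode="set" forces the blank logit to exactly 2.0. Frames flagged in
    # padding_mask are then overwritten so that only the blank class survives
    # (logit 0, everything else -inf), which makes CTC decoding ignore them.
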
    def get_normalized_probs(self, net_output, log_probs):
        """Get normalized probabilities (or log probs) from a net's output."""

        logits = self.get_logits(net_output)

        if log_probs:
            return utils.log_softmax(logits.float(), dim=-1)
        else:
            return utils.softmax(logits.float(), dim=-1)

    def forward(self, **kwargs):
        x = self.w2v_encoder(**kwargs)
        return x

@dataclass
class Wav2Vec2Seq2SeqConfig(Wav2Vec2AsrConfig):
    decoder_embed_dim: int = field(
        default=768, metadata={"help": "decoder embedding dimension"}
    )
    decoder_ffn_embed_dim: int = field(
        default=3072, metadata={"help": "decoder embedding dimension for FFN"}
    )
    decoder_layers: int = field(default=6, metadata={"help": "num of decoder layers"})
    decoder_layerdrop: float = field(
        default=0.0, metadata={"help": "decoder layerdrop chance"}
    )
    decoder_attention_heads: int = field(
        default=4, metadata={"help": "num decoder attention heads"}
    )
    decoder_learned_pos: bool = field(
        default=False,
        metadata={"help": "use learned positional embeddings in the decoder"},
    )
    decoder_normalize_before: bool = field(
        default=False, metadata={"help": "apply layernorm before each decoder block"}
    )
    no_token_positional_embeddings: bool = field(
        default=False,
        metadata={
            "help": "if set, disables positional embeddings (outside self attention)"
        },
    )
    decoder_dropout: float = field(
        default=0.0, metadata={"help": "dropout probability in the decoder"}
    )
    decoder_attention_dropout: float = field(
        default=0.0,
        metadata={
            "help": "dropout probability for attention weights inside the decoder"
        },
    )
    decoder_activation_dropout: float = field(
        default=0.0,
        metadata={
            "help": "dropout probability after activation in FFN inside the decoder"
        },
    )
    max_target_positions: int = field(
        default=2048, metadata={"help": "max target positions"}
    )
    share_decoder_input_output_embed: bool = field(
        default=False, metadata={"help": "share decoder input and output embeddings"}
    )
    autoregressive: bool = II("task.autoregressive")

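# Added note (not in the original source): if decoder_embed_dim differs from the
# wav2vec 2.0 encoder_embed_dim, Wav2VecEncoder (below) inserts a Linear projection so
# that encoder outputs match the decoder's expected input size.
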
@register_model("wav2vec_seq2seq", dataclass=Wav2Vec2Seq2SeqConfig)
class Wav2Vec2Seq2SeqModel(FairseqEncoderDecoderModel):
    def __init__(self, encoder, decoder):
        super().__init__(encoder, decoder)

    @classmethod
    def build_model(cls, cfg: Wav2Vec2Seq2SeqConfig, task: FairseqTask):
        """Build a new model instance."""

        assert (
            cfg.autoregressive
        ), "Please set task.autoregressive=true for seq2seq asr models"

        src_dict, tgt_dict = task.source_dictionary, task.target_dictionary

        def build_embedding(dictionary, embed_dim):
            num_embeddings = len(dictionary)
            padding_idx = dictionary.pad()
            emb = Embedding(num_embeddings, embed_dim, padding_idx)
            return emb

        decoder_embed_tokens = build_embedding(tgt_dict, cfg.decoder_embed_dim)

        encoder = cls.build_encoder(cfg)
        decoder = cls.build_decoder(cfg, tgt_dict, decoder_embed_tokens)

        return Wav2Vec2Seq2SeqModel(encoder, decoder)

    @classmethod
    def build_encoder(cls, cfg: Wav2Vec2AsrConfig):
        return Wav2VecEncoder(cfg)

    @classmethod
    def build_decoder(cls, cfg: Wav2Vec2Seq2SeqConfig, tgt_dict, embed_tokens):
        return TransformerDecoder(cfg, tgt_dict, embed_tokens)

    def forward(self, **kwargs):
        encoder_out = self.encoder(**kwargs)
        decoder_out = self.decoder(encoder_out=encoder_out, **kwargs)
        return decoder_out

    def upgrade_state_dict_named(self, state_dict, name):
        super().upgrade_state_dict_named(state_dict, name)
        return state_dict

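# Descriptive note (added, not in the original source): Wav2VecEncoder wraps a
# pretrained wav2vec 2.0 model as a fairseq encoder. It loads the checkpoint given by
# w2v_path (or reuses cfg.w2v_args), runs the raw waveform through extract_features(),
# and optionally projects the resulting T x B x C features to an output vocabulary
# (CTC) or to the decoder dimension (seq2seq).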
class Wav2VecEncoder(FairseqEncoder):
    def __init__(self, cfg: Wav2Vec2AsrConfig, output_size=None):
        self.apply_mask = cfg.apply_mask

        arg_overrides = {
            "dropout": cfg.dropout,
            "activation_dropout": cfg.activation_dropout,
            "dropout_input": cfg.dropout_input,
            "attention_dropout": cfg.attention_dropout,
            "mask_length": cfg.mask_length,
            "mask_prob": cfg.mask_prob,
            "require_same_masks": getattr(cfg, "require_same_masks", True),
            "pct_holes": getattr(cfg, "mask_dropout", 0),
            "mask_selection": cfg.mask_selection,
            "mask_other": cfg.mask_other,
            "no_mask_overlap": cfg.no_mask_overlap,
            "mask_channel_length": cfg.mask_channel_length,
            "mask_channel_prob": cfg.mask_channel_prob,
            "mask_channel_before": cfg.mask_channel_before,
            "mask_channel_selection": cfg.mask_channel_selection,
            "mask_channel_other": cfg.mask_channel_other,
            "no_mask_channel_overlap": cfg.no_mask_channel_overlap,
            "encoder_layerdrop": cfg.layerdrop,
            "feature_grad_mult": cfg.feature_grad_mult,
            "checkpoint_activations": cfg.checkpoint_activations,
            "offload_activations": cfg.offload_activations,
            "min_params_to_wrap": cfg.min_params_to_wrap,
        }

        if cfg.w2v_args is None:
            state = checkpoint_utils.load_checkpoint_to_cpu(cfg.w2v_path, arg_overrides)
            w2v_args = state.get("cfg", None)
            if w2v_args is None:
                w2v_args = convert_namespace_to_omegaconf(state["args"])
            w2v_args.criterion = None
            w2v_args.lr_scheduler = None
            cfg.w2v_args = w2v_args

            logger.info(w2v_args)

        else:
            state = None
            w2v_args = cfg.w2v_args
            if isinstance(w2v_args, Namespace):
                cfg.w2v_args = w2v_args = convert_namespace_to_omegaconf(w2v_args)

        model_normalized = w2v_args.task.get(
            "normalize", w2v_args.model.get("normalize", False)
        )
        assert cfg.normalize == model_normalized, (
            "Fine-tuning works best when data normalization is the same. "
            "Please check that --normalize is set or unset for both pre-training and here"
        )

        if hasattr(cfg, "checkpoint_activations") and cfg.checkpoint_activations:
            with open_dict(w2v_args):
                w2v_args.model.checkpoint_activations = cfg.checkpoint_activations

        w2v_args.task.data = cfg.data
        task = tasks.setup_task(w2v_args.task)
        model = task.build_model(w2v_args.model, from_checkpoint=True)

        model.remove_pretraining_modules()

        if state is not None and not cfg.no_pretrained_weights:
            self.load_model_weights(state, model, cfg)

        super().__init__(task.source_dictionary)

        d = w2v_args.model.encoder_embed_dim

        self.w2v_model = model

        self.final_dropout = nn.Dropout(cfg.final_dropout)
        self.freeze_finetune_updates = cfg.freeze_finetune_updates
        self.num_updates = 0

        targ_d = None
        self.proj = None

        if output_size is not None:
            targ_d = output_size
        elif getattr(cfg, "decoder_embed_dim", d) != d:
            targ_d = cfg.decoder_embed_dim

        if targ_d is not None:
            self.proj = Linear(d, targ_d)

    def load_model_weights(self, state, model, cfg):
        if cfg.ddp_backend == "fully_sharded":
            from fairseq.distributed import FullyShardedDataParallel

            for name, module in model.named_modules():
                if "encoder.layers" in name and len(name.split(".")) == 3:
                    # Only for layers, we do a special handling and load the weights
                    # one by one. We don't load all weights together, as that won't be
                    # memory efficient and may cause OOM.
                    new_dict = {
                        k.replace(name + ".", ""): v
                        for (k, v) in state["model"].items()
                        if name + "." in k
                    }
                    assert isinstance(module, FullyShardedDataParallel)
                    with module.summon_full_params():
                        module.load_state_dict(new_dict, strict=True)
                        module._reset_lazy_init()

            # Once layers are loaded, filter them out and load everything else.
            r = re.compile(r"encoder.layers.\d.")
            filtered_list = list(filter(r.match, state["model"].keys()))

            new_big_dict = {
                k: v for (k, v) in state["model"].items() if k not in filtered_list
            }

            model.load_state_dict(new_big_dict, strict=False)
        else:
            if "_ema" in state["model"]:
                del state["model"]["_ema"]
            model.load_state_dict(state["model"], strict=True)

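    # Added note (not in the original source): the final load_state_dict above uses
    # strict=False because the per-layer weights were already copied into the FSDP
    # sub-modules and are filtered out of new_big_dict.
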
    def set_num_updates(self, num_updates):
        """Set the number of parameter updates."""
        super().set_num_updates(num_updates)
        self.num_updates = num_updates

    def forward(self, source, padding_mask, **kwargs):

        w2v_args = {
            "source": source,
            "padding_mask": padding_mask,
            "mask": self.apply_mask and self.training,
        }

        ft = self.freeze_finetune_updates <= self.num_updates

        with torch.no_grad() if not ft else contextlib.ExitStack():
            res = self.w2v_model.extract_features(**w2v_args)

            x = res["x"]
            padding_mask = res["padding_mask"]

            # B x T x C -> T x B x C
            x = x.transpose(0, 1)

        x = self.final_dropout(x)

        if self.proj:
            x = self.proj(x)

        return {
            "encoder_out": x,  # T x B x C
            "padding_mask": padding_mask,  # B x T
            "layer_results": res["layer_results"],
        }

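    # Note on freezing (added comment, not in the original source): for the first
    # `freeze_finetune_updates` optimizer steps, `ft` in forward() above is False, so
    # feature extraction runs under torch.no_grad() and the pretrained wav2vec 2.0
    # weights receive no gradient; after that, the contextlib.ExitStack() branch is a
    # no-op context manager and the whole model is fine-tuned.
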
    def forward_torchscript(self, net_input):
        if torch.jit.is_scripting():
            return self.forward(net_input["source"], net_input["padding_mask"])
        else:
            return self.forward_non_torchscript(net_input)

    def reorder_encoder_out(self, encoder_out, new_order):
        if encoder_out["encoder_out"] is not None:
            encoder_out["encoder_out"] = encoder_out["encoder_out"].index_select(
                1, new_order
            )
        if encoder_out["padding_mask"] is not None:
            encoder_out["padding_mask"] = encoder_out["padding_mask"].index_select(
                0, new_order
            )
        return encoder_out

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        return None

    def upgrade_state_dict_named(self, state_dict, name):
        return state_dict

class TransformerDecoder(FairseqIncrementalDecoder):
    """
    Transformer decoder consisting of *cfg.decoder_layers* layers. Each layer
    is a :class:`TransformerDecoderLayer`.

    Args:
        cfg (Wav2Vec2Seq2SeqConfig): decoder configuration
        dictionary (~fairseq.data.Dictionary): decoding dictionary
        embed_tokens (torch.nn.Embedding): output embedding
        no_encoder_attn (bool, optional): if True, do not attend to encoder outputs
            (default: False).
    """

    def __init__(
        self,
        cfg: Wav2Vec2Seq2SeqConfig,
        dictionary,
        embed_tokens,
        no_encoder_attn=False,
    ):
        super().__init__(dictionary)

        self.dropout = cfg.decoder_dropout
        self.share_input_output_embed = cfg.share_decoder_input_output_embed

        input_embed_dim = embed_tokens.embedding_dim
        embed_dim = cfg.decoder_embed_dim
        self.output_embed_dim = cfg.decoder_embed_dim

        self.layerdrop = cfg.decoder_layerdrop

        self.padding_idx = embed_tokens.padding_idx
        self.max_target_positions = cfg.max_target_positions

        self.embed_tokens = embed_tokens
        self.embed_scale = math.sqrt(embed_dim)  # todo: try with input_embed_dim

        self.project_in_dim = (
            Linear(input_embed_dim, embed_dim, bias=False)
            if embed_dim != input_embed_dim
            else None
        )

        self.embed_positions = (
            PositionalEmbedding(
                cfg.max_target_positions,
                embed_dim,
                self.padding_idx,
                learned=cfg.decoder_learned_pos,
            )
            if not cfg.no_token_positional_embeddings
            else None
        )

        # TODO: update this when transformer gets converted to dataclass configs
        transformer_cfg = copy.deepcopy(cfg)
        with open_dict(transformer_cfg):
            transformer_cfg.dropout = transformer_cfg.decoder_dropout
            transformer_cfg.attention_dropout = (
                transformer_cfg.decoder_attention_dropout
            )
            transformer_cfg.activation_dropout = (
                transformer_cfg.decoder_activation_dropout
            )

        self.layers = nn.ModuleList([])
        self.layers.extend(
            [
                TransformerDecoderLayer(transformer_cfg, no_encoder_attn)
                for _ in range(transformer_cfg.decoder_layers)
            ]
        )

        if not self.share_input_output_embed:
            self.embed_out = nn.Parameter(
                torch.Tensor(len(dictionary), self.output_embed_dim)
            )
            nn.init.normal_(self.embed_out, mean=0, std=self.output_embed_dim**-0.5)

        if transformer_cfg.decoder_normalize_before:
            self.layer_norm = LayerNorm(embed_dim)
        else:
            self.layer_norm = None

    def forward(
        self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused
    ):
        """
        Args:
            prev_output_tokens (LongTensor): previous decoder outputs of shape
                `(batch, tgt_len)`, for teacher forcing
            encoder_out (Tensor, optional): output from the encoder, used for
                encoder-side attention
            incremental_state (dict): dictionary used for storing state during
                :ref:`Incremental decoding`

        Returns:
            tuple:
                - the decoder's output of shape `(batch, tgt_len, vocab)`
                - a dictionary with any model-specific outputs
        """
        if isinstance(prev_output_tokens, list):
            max_len = max(len(x) for x in prev_output_tokens)
            tmp = torch.zeros(
                [len(prev_output_tokens), max_len], device=prev_output_tokens[0].device
            )
            for (i, p) in enumerate(prev_output_tokens):
                tmp[i, : len(p)] = p
            prev_output_tokens = tmp
        prev_output_tokens = prev_output_tokens.long()
        x, extra = self.extract_features(
            prev_output_tokens, encoder_out, incremental_state
        )
        x = self.output_layer(x)
        return x, extra

    def extract_features(
        self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused
    ):
        """
        Similar to *forward* but only return features.

        Returns:
            tuple:
                - the decoder's features of shape `(batch, tgt_len, embed_dim)`
                - a dictionary with any model-specific outputs
        """

        # embed positions
        positions = (
            self.embed_positions(
                prev_output_tokens, incremental_state=incremental_state
            )
            if self.embed_positions is not None
            else None
        )

        if incremental_state is not None:
            prev_output_tokens = prev_output_tokens[:, -1:]
            if positions is not None:
                positions = positions[:, -1:]

        # embed tokens and positions
        x = self.embed_scale * self.embed_tokens(prev_output_tokens)

        if self.project_in_dim is not None:
            x = self.project_in_dim(x)

        if positions is not None:
            x += positions
        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)
        attn = None

        inner_states = [x]

        # decoder layers
        self_attn_padding_mask = None
        if prev_output_tokens.eq(self.padding_idx).any():
            self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx)
        for layer in self.layers:
            dropout_probability = np.random.random()
            if not self.training or (dropout_probability > self.layerdrop):
                x, attn, _ = layer(
                    x,
                    encoder_out["encoder_out"] if encoder_out is not None else None,
                    encoder_out["padding_mask"] if encoder_out is not None else None,
                    incremental_state,
                    self_attn_mask=self.buffered_future_mask(x)
                    if incremental_state is None
                    else None,
                    self_attn_padding_mask=self_attn_padding_mask,
                )
                inner_states.append(x)

        if self.layer_norm:
            x = self.layer_norm(x)

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        return x, {"attn": attn, "inner_states": inner_states}

    def output_layer(self, features, **kwargs):
        """Project features to the vocabulary size."""
        # project back to size of vocabulary
        if self.share_input_output_embed:
            return F.linear(features, self.embed_tokens.weight)
        else:
            return F.linear(features, self.embed_out)

    def max_positions(self):
        """Maximum output length supported by the decoder."""
        if self.embed_positions is None:
            return self.max_target_positions
        return min(self.max_target_positions, self.embed_positions.max_positions)

    def buffered_future_mask(self, tensor):
        dim = tensor.size(0)
        if (
            not hasattr(self, "_future_mask")
            or self._future_mask is None
            or self._future_mask.device != tensor.device
            or self._future_mask.size(0) < dim
        ):
            self._future_mask = torch.triu(
                utils.fill_with_neg_inf(tensor.new(dim, dim)), 1
            )
        return self._future_mask[:dim, :dim]

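    # Shape sketch for buffered_future_mask above (added comment, not in the original
    # source): for dim=3 the cached mask is
    #   [[0., -inf, -inf],
    #    [0.,   0., -inf],
    #    [0.,   0.,   0.]]
    # i.e. position t may only attend to positions <= t. The mask is rebuilt only when
    # the device changes or a longer sequence is seen.
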
    def upgrade_state_dict_named(self, state_dict, name):
        return state_dict

def Embedding(num_embeddings, embedding_dim, padding_idx):
    m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
    nn.init.normal_(m.weight, mean=0, std=embedding_dim**-0.5)
    nn.init.constant_(m.weight[padding_idx], 0)
    return m

def Linear(in_features, out_features, bias=True):
    m = nn.Linear(in_features, out_features, bias)
    nn.init.xavier_uniform_(m.weight)
    if bias:
        nn.init.constant_(m.bias, 0.0)
    return m
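

if __name__ == "__main__":
    # Minimal self-check sketch (added for illustration, not part of fairseq). It only
    # exercises the two helper constructors above and assumes torch (and the fairseq
    # imports at the top of this file) are available when the module is run directly.
    emb = Embedding(num_embeddings=10, embedding_dim=8, padding_idx=1)
    assert torch.all(emb.weight[1] == 0)  # padding row is zero-initialized
    lin = Linear(8, 4)
    out = lin(torch.zeros(2, 8))
    print(out.shape)  # torch.Size([2, 4]); bias is initialized to 0, so output is all zeros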