# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
from typing import Any, Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from fairseq import utils
from fairseq.models import (
    FairseqEncoder,
    FairseqEncoderDecoderModel,
    FairseqIncrementalDecoder,
    register_model,
    register_model_architecture,
)
from fairseq.modules import (
    AdaptiveSoftmax,
    DynamicConv_scripatable as DynamicConv,
    FairseqDropout,
    LayerNorm,
    LightweightConv,
    MultiheadAttention,
    PositionalEmbedding,
)
from fairseq.utils import safe_hasattr
from torch import Tensor

@register_model("lightconv")
class LightConvModel(FairseqEncoderDecoderModel):
    """
    LightConv and DynamicConv model from `"Pay Less Attention with Lightweight
    and Dynamic Convolutions" (Wu et al., 2019)
    <https://openreview.net/pdf?id=SkVhlh09tX>`_.

    To use LightConv, set ``--encoder-conv-type lightweight --decoder-conv-type lightweight``.
    To use DynamicConv, set ``--encoder-conv-type dynamic --decoder-conv-type dynamic``.

    Args:
        encoder (LightConvEncoder): the encoder
        decoder (LightConvDecoder): the decoder

    The LightConv model provides the following named architectures and
    command-line arguments:

    .. argparse::
        :ref: fairseq.models.lightconv_parser
        :prog:
    """

    @classmethod
    def hub_models(cls):
        # fmt: off

        def moses_subword(path):
            return {
                'path': path,
                'tokenizer': 'moses',
                'bpe': 'subword_nmt',
            }

        return {
            'lightconv.no_glu.iwslt14.de-en': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/iwslt14.de-en.lightconv.tar.gz'),
            'dynamicconv.no_glu.iwslt14.de-en': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/iwslt14.de-en.dynamicconv.tar.gz'),
            'lightconv.no_glu.wmt16.en-de': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/wmt16.en-de.joined-dict.lightconv.tar.gz'),
            'dynamicconv.no_glu.wmt16.en-de': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/wmt16.en-de.joined-dict.dynamicconv.tar.gz'),
            'lightconv.glu.wmt16.en-de': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/wmt16.en-de.joined-dict.lightconv-glu.tar.gz'),
            'dynamicconv.glu.wmt16.en-de': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/wmt16.en-de.joined-dict.dynamicconv-glu.tar.gz'),
            'lightconv.glu.wmt17.en-de': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/wmt16.en-de.joined-dict.lightconv-glu.tar.gz'),
            'dynamicconv.glu.wmt17.en-de': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/wmt16.en-de.joined-dict.dynamicconv-glu.tar.gz'),
            'lightconv.glu.wmt14.en-fr': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/wmt14.en-fr.joined-dict.lightconv-glu.tar.gz'),
            'dynamicconv.glu.wmt14.en-fr': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/wmt14.en-fr.joined-dict.dynamicconv-glu.tar.gz'),
            'lightconv.glu.wmt17.zh-en': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/wmt17.zh-en.lightconv-glu.tar.gz'),
            'dynamicconv.glu.wmt17.zh-en': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/wmt17.zh-en.dynamicconv-glu.tar.gz'),
        }
        # fmt: on
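    # Example (comment only, not executed): the checkpoints listed in
    # ``hub_models`` can be loaded through the fairseq torch.hub interface.
    # A minimal sketch, assuming fairseq plus the tokenizer/BPE dependencies
    # (sacremoses, subword-nmt) are installed; the entry name must be one of
    # the keys returned above:
    #
    #   zh2en = torch.hub.load(
    #       "pytorch/fairseq",
    #       "dynamicconv.glu.wmt17.zh-en",
    #       tokenizer="moses",
    #       bpe="subword_nmt",
    #   )
    #   print(zh2en.translate("你好 世界"))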
    def __init__(self, encoder, decoder):
        super().__init__(encoder, decoder)

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        parser.add_argument(
            "--dropout", type=float, metavar="D", help="dropout probability"
        )
        parser.add_argument(
            "--attention-dropout",
            type=float,
            metavar="D",
            help="dropout probability for attention weights",
        )
        parser.add_argument(
            "--relu-dropout",
            type=float,
            metavar="D",
            help="dropout probability after ReLU in FFN",
        )
        parser.add_argument(
            "--input-dropout",
            type=float,
            metavar="D",
            help="dropout probability of the inputs",
        )
        parser.add_argument(
            "--encoder-embed-path",
            type=str,
            metavar="STR",
            help="path to pre-trained encoder embedding",
        )
        parser.add_argument(
            "--encoder-embed-dim",
            type=int,
            metavar="N",
            help="encoder embedding dimension",
        )
        parser.add_argument(
            "--encoder-conv-dim",
            type=int,
            metavar="N",
            help="encoder convolution dimension",
        )
        parser.add_argument(
            "--encoder-ffn-embed-dim",
            type=int,
            metavar="N",
            help="encoder embedding dimension for FFN",
        )
        parser.add_argument(
            "--encoder-layers", type=int, metavar="N", help="num encoder layers"
        )
        parser.add_argument(
            "--encoder-attention-heads",
            type=int,
            metavar="N",
            help="num encoder attention heads or LightConv/DynamicConv heads",
        )
        parser.add_argument(
            "--encoder-normalize-before",
            action="store_true",
            help="apply layernorm before each encoder block",
        )
        parser.add_argument(
            "--encoder-learned-pos",
            action="store_true",
            help="use learned positional embeddings in the encoder",
        )
        parser.add_argument(
            "--decoder-embed-path",
            type=str,
            metavar="STR",
            help="path to pre-trained decoder embedding",
        )
        parser.add_argument(
            "--decoder-embed-dim",
            type=int,
            metavar="N",
            help="decoder embedding dimension",
        )
        parser.add_argument(
            "--decoder-conv-dim",
            type=int,
            metavar="N",
            help="decoder convolution dimension",
        )
        parser.add_argument(
            "--decoder-ffn-embed-dim",
            type=int,
            metavar="N",
            help="decoder embedding dimension for FFN",
        )
        parser.add_argument(
            "--decoder-layers", type=int, metavar="N", help="num decoder layers"
        )
        parser.add_argument(
            "--decoder-attention-heads",
            type=int,
            metavar="N",
            help="num decoder attention heads or LightConv/DynamicConv heads",
        )
        parser.add_argument(
            "--decoder-learned-pos",
            action="store_true",
            help="use learned positional embeddings in the decoder",
        )
        parser.add_argument(
            "--decoder-normalize-before",
            action="store_true",
            help="apply layernorm before each decoder block",
        )
        parser.add_argument(
            "--share-decoder-input-output-embed",
            action="store_true",
            help="share decoder input and output embeddings",
        )
        parser.add_argument(
            "--share-all-embeddings",
            action="store_true",
            help="share encoder, decoder and output embeddings"
            " (requires shared dictionary and embed dim)",
        )
        parser.add_argument(
            "--adaptive-softmax-cutoff",
            metavar="EXPR",
            help="comma separated list of adaptive softmax cutoff points. "
            "Must be used with adaptive_loss criterion",
        )
        parser.add_argument(
            "--adaptive-softmax-dropout",
            type=float,
            metavar="D",
            help="sets adaptive softmax dropout for the tail projections",
        )

        # LightConv and DynamicConv arguments
        parser.add_argument(
            "--encoder-kernel-size-list",
            type=lambda x: utils.eval_str_list(x, int),
            help='list of kernel sizes (default: "[3,7,15,31,31,31,31]")',
        )
        parser.add_argument(
            "--decoder-kernel-size-list",
            type=lambda x: utils.eval_str_list(x, int),
            help='list of kernel sizes (default: "[3,7,15,31,31,31]")',
        )
        parser.add_argument(
            "--encoder-glu", type=utils.eval_bool, help="glu after in proj"
        )
        parser.add_argument(
            "--decoder-glu", type=utils.eval_bool, help="glu after in proj"
        )
        parser.add_argument(
            "--encoder-conv-type",
            default="dynamic",
            type=str,
            choices=["dynamic", "lightweight"],
            help="type of convolution",
        )
        parser.add_argument(
            "--decoder-conv-type",
            default="dynamic",
            type=str,
            choices=["dynamic", "lightweight"],
            help="type of convolution",
        )
        parser.add_argument("--weight-softmax", default=True, type=utils.eval_bool)
        parser.add_argument(
            "--weight-dropout",
            type=float,
            metavar="D",
            help="dropout probability for conv weights",
        )
    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""

        # make sure all arguments are present in older models
        base_architecture(args)

        if not safe_hasattr(args, "max_source_positions"):
            args.max_source_positions = 1024
        if not safe_hasattr(args, "max_target_positions"):
            args.max_target_positions = 1024

        src_dict, tgt_dict = task.source_dictionary, task.target_dictionary

        def build_embedding(dictionary, embed_dim, path=None):
            num_embeddings = len(dictionary)
            padding_idx = dictionary.pad()
            emb = Embedding(num_embeddings, embed_dim, padding_idx)
            # if provided, load from preloaded dictionaries
            if path:
                embed_dict = utils.parse_embedding(path)
                utils.load_embedding(embed_dict, dictionary, emb)
            return emb

        if args.share_all_embeddings:
            if src_dict != tgt_dict:
                raise RuntimeError(
                    "--share-all-embeddings requires a joined dictionary"
                )
            if args.encoder_embed_dim != args.decoder_embed_dim:
                raise RuntimeError(
                    "--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim"
                )
            if args.decoder_embed_path and (
                args.decoder_embed_path != args.encoder_embed_path
            ):
                raise RuntimeError(
                    "--share-all-embeddings not compatible with --decoder-embed-path"
                )
            encoder_embed_tokens = build_embedding(
                src_dict, args.encoder_embed_dim, args.encoder_embed_path
            )
            decoder_embed_tokens = encoder_embed_tokens
            args.share_decoder_input_output_embed = True
        else:
            encoder_embed_tokens = build_embedding(
                src_dict, args.encoder_embed_dim, args.encoder_embed_path
            )
            decoder_embed_tokens = build_embedding(
                tgt_dict, args.decoder_embed_dim, args.decoder_embed_path
            )

        encoder = LightConvEncoder(args, src_dict, encoder_embed_tokens)
        decoder = LightConvDecoder(args, tgt_dict, decoder_embed_tokens)
        return LightConvModel(encoder, decoder)

    def forward(
        self,
        src_tokens: Tensor,
        src_lengths: Tensor,
        prev_output_tokens: Tensor,
    ):
        """
        (The forward method inherited from the base class has a **kwargs
        argument in its input, which is not supported in torchscript. This
        method overrides the forward method definition without **kwargs.)

        Run the forward pass for an encoder-decoder model.

        First feed a batch of source tokens through the encoder. Then, feed the
        encoder output and previous decoder outputs (i.e., teacher forcing) to
        the decoder to produce the next outputs::

            encoder_out = self.encoder(src_tokens, src_lengths)
            return self.decoder(prev_output_tokens, encoder_out)

        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            src_lengths (LongTensor): source sentence lengths of shape `(batch)`
            prev_output_tokens (LongTensor): previous decoder outputs of shape
                `(batch, tgt_len)`, for teacher forcing

        Returns:
            tuple:
                - the decoder's output of shape `(batch, tgt_len, vocab)`
                - a dictionary with any model-specific outputs
        """
        encoder_out = self.encoder(src_tokens, src_lengths)
        decoder_out = self.decoder(prev_output_tokens, encoder_out=encoder_out)
        return decoder_out

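# Note on TorchScript: the explicit ``forward`` signature above (no **kwargs),
# the scriptable DynamicConv variant imported at the top of the file, and the
# ``@torch.jit.export`` annotation below exist so the model can be compiled.
# A minimal sketch (assuming ``model`` is a built LightConvModel instance):
#
#   scripted = torch.jit.script(model)
#   out, extra = scripted(src_tokens, src_lengths, prev_output_tokens)
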
class LightConvEncoder(FairseqEncoder):
    """
    LightConv encoder consisting of *args.encoder_layers* layers. Each layer
    is a :class:`LightConvEncoderLayer`.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        dictionary (~fairseq.data.Dictionary): encoding dictionary
        embed_tokens (torch.nn.Embedding): input embedding
    """

    def __init__(self, args, dictionary, embed_tokens):
        super().__init__(dictionary)
        self.dropout_module = FairseqDropout(
            args.dropout, module_name=self.__class__.__name__
        )

        embed_dim = embed_tokens.embedding_dim
        self.padding_idx = embed_tokens.padding_idx
        self.max_source_positions = args.max_source_positions

        self.embed_tokens = embed_tokens
        self.embed_scale = math.sqrt(embed_dim)
        self.embed_positions = (
            PositionalEmbedding(
                args.max_source_positions,
                embed_dim,
                self.padding_idx,
                learned=args.encoder_learned_pos,
            )
            if not args.no_token_positional_embeddings
            else None
        )

        self.layers = nn.ModuleList([])
        self.layers.extend(
            [
                LightConvEncoderLayer(
                    args, kernel_size=args.encoder_kernel_size_list[i]
                )
                for i in range(args.encoder_layers)
            ]
        )
        self.register_buffer("version", torch.Tensor([2]))
        self.normalize = args.encoder_normalize_before
        if self.normalize:
            self.layer_norm = LayerNorm(embed_dim)
        else:
            self.layer_norm = None

    def forward(
        self, src_tokens: Tensor, src_lengths: Optional[Tensor] = None
    ) -> Dict[str, List[Tensor]]:
        """
        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`

        Returns:
            dict:
                - **encoder_out** (Tensor): the last encoder layer's output of
                  shape `(src_len, batch, embed_dim)`
                - **encoder_padding_mask** (ByteTensor): the positions of
                  padding elements of shape `(batch, src_len)`
        """
        # embed tokens and positions
        x = self.embed_scale * self.embed_tokens(src_tokens)
        if self.embed_positions is not None:
            x += self.embed_positions(src_tokens)
        x = self.dropout_module(x)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # compute padding mask
        encoder_padding_mask = src_tokens.eq(self.padding_idx)  # B x T
        if not encoder_padding_mask.any():
            encoder_mask = None
        else:
            encoder_mask = encoder_padding_mask

        # encoder layers
        for layer in self.layers:
            x = layer(x, encoder_mask)

        if self.layer_norm is not None:
            x = self.layer_norm(x)

        output_dict: Dict[str, List[Tensor]] = {}
        if src_lengths is not None:
            output_dict["src_lengths"] = [src_lengths]
        output_dict["encoder_out"] = [x]  # T x B x C
        if encoder_mask is not None:
            output_dict["encoder_padding_mask"] = [encoder_mask]  # B x T

        return output_dict

    @torch.jit.export
    def reorder_encoder_out(
        self, encoder_out: Dict[str, List[Tensor]], new_order: Tensor
    ):
        """
        Reorder encoder output according to *new_order*.

        Args:
            encoder_out: output from the ``forward()`` method
            new_order (LongTensor): desired order

        Returns:
            *encoder_out* rearranged according to *new_order*
        """
        if len(encoder_out["encoder_out"]) == 0:
            encoder = []
        else:
            encoder = [encoder_out["encoder_out"][0].index_select(1, new_order)]
        output_dict = {"encoder_out": encoder}

        if ("encoder_padding_mask" not in encoder_out) or (
            len(encoder_out["encoder_padding_mask"]) == 0
        ):
            encoder_padding_mask = []
        else:
            encoder_padding_mask = [
                encoder_out["encoder_padding_mask"][0].index_select(0, new_order)
            ]
        output_dict["encoder_padding_mask"] = encoder_padding_mask
        return output_dict

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        if self.embed_positions is None:
            return self.max_source_positions
        return min(self.max_source_positions, self.embed_positions.max_positions)

class LightConvDecoder(FairseqIncrementalDecoder):
    """
    LightConv decoder consisting of *args.decoder_layers* layers. Each layer
    is a :class:`LightConvDecoderLayer`.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        dictionary (~fairseq.data.Dictionary): decoding dictionary
        embed_tokens (torch.nn.Embedding): output embedding
        no_encoder_attn (bool, optional): if ``True``, do not attend to encoder
            outputs. Default: ``False``
    """

    def __init__(
        self, args, dictionary, embed_tokens, no_encoder_attn=False, final_norm=True
    ):
        super().__init__(dictionary)
        self.dropout_module = FairseqDropout(
            args.dropout, module_name=self.__class__.__name__
        )
        self.share_input_output_embed = args.share_decoder_input_output_embed

        input_embed_dim = embed_tokens.embedding_dim
        embed_dim = args.decoder_embed_dim
        output_embed_dim = args.decoder_output_dim

        padding_idx = embed_tokens.padding_idx
        self.max_target_positions = args.max_target_positions

        self.embed_tokens = embed_tokens
        self.embed_scale = math.sqrt(embed_dim)  # todo: try with input_embed_dim

        self.project_in_dim = (
            Linear(input_embed_dim, embed_dim, bias=False)
            if embed_dim != input_embed_dim
            else None
        )

        self.embed_positions = (
            PositionalEmbedding(
                args.max_target_positions,
                embed_dim,
                padding_idx,
                learned=args.decoder_learned_pos,
            )
            if not args.no_token_positional_embeddings
            else None
        )

        self.layers = nn.ModuleList([])
        self.layers.extend(
            [
                LightConvDecoderLayer(
                    args,
                    no_encoder_attn,
                    kernel_size=args.decoder_kernel_size_list[i],
                    dictionary=dictionary,
                )
                for i in range(args.decoder_layers)
            ]
        )

        self.adaptive_softmax = None
        self.output_projection = None

        self.project_out_dim = (
            Linear(embed_dim, output_embed_dim, bias=False)
            if embed_dim != output_embed_dim and not args.tie_adaptive_weights
            else None
        )

        if args.adaptive_softmax_cutoff is not None:
            self.adaptive_softmax = AdaptiveSoftmax(
                len(dictionary),
                output_embed_dim,
                utils.eval_str_list(args.adaptive_softmax_cutoff, type=int),
                dropout=args.adaptive_softmax_dropout,
                adaptive_inputs=embed_tokens if args.tie_adaptive_weights else None,
                factor=args.adaptive_softmax_factor,
                tie_proj=args.tie_adaptive_proj,
            )
        elif self.share_input_output_embed:
            self.output_projection = nn.Linear(
                self.embed_tokens.weight.shape[1],
                self.embed_tokens.weight.shape[0],
                bias=False,
            )
            self.output_projection.weight = self.embed_tokens.weight
        else:
            self.output_projection = nn.Linear(
                output_embed_dim, len(dictionary), bias=False
            )
            nn.init.normal_(
                self.output_projection.weight, mean=0, std=output_embed_dim**-0.5
            )
        self.register_buffer("version", torch.Tensor([2]))
        self.normalize = args.decoder_normalize_before and final_norm
        if self.normalize:
            self.layer_norm = LayerNorm(embed_dim)
        else:
            self.layer_norm = None

    def forward(
        self,
        prev_output_tokens: Tensor,
        encoder_out: Optional[Dict[str, List[Tensor]]] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        src_lengths: Optional[Any] = None,
    ):
        """
        Args:
            prev_output_tokens (LongTensor): previous decoder outputs of shape
                `(batch, tgt_len)`, for teacher forcing
            encoder_out (Dict[str, List[Tensor]], optional): output from the
                encoder, used for encoder-side attention
            incremental_state (dict): dictionary used for storing state during
                :ref:`Incremental decoding`

        Returns:
            tuple:
                - the last decoder layer's output of shape `(batch, tgt_len,
                  vocab)`
                - the last decoder layer's attention weights of shape `(batch,
                  tgt_len, src_len)`
        """
        # embed positions
        positions = (
            self.embed_positions(
                prev_output_tokens,
                incremental_state=incremental_state,
            )
            if self.embed_positions is not None
            else None
        )

        if incremental_state is not None:
            prev_output_tokens = prev_output_tokens[:, -1:]
            if positions is not None:
                positions = positions[:, -1:]

        # embed tokens and positions
        x = self.embed_scale * self.embed_tokens(prev_output_tokens.contiguous())

        if self.project_in_dim is not None:
            x = self.project_in_dim(x)

        if positions is not None:
            x += positions
        x = self.dropout_module(x)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        inner_states: List[Optional[Tensor]] = [x]

        # decoder layers
        attn: Optional[Tensor] = None
        for layer in self.layers:
            encoder: Optional[Tensor] = None
            encoder_padding_mask: Optional[Tensor] = None
            if encoder_out is not None:
                if len(encoder_out["encoder_out"]) > 0:
                    encoder = encoder_out["encoder_out"][0]
                if (
                    "encoder_padding_mask" in encoder_out
                    and len(encoder_out["encoder_padding_mask"]) > 0
                ):
                    encoder_padding_mask = encoder_out["encoder_padding_mask"][0]
            x, attn = layer(
                x,
                encoder,
                encoder_padding_mask,
                incremental_state,
            )
            inner_states.append(x)

        if self.layer_norm is not None:
            x = self.layer_norm(x)

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        if self.project_out_dim is not None:
            x = self.project_out_dim(x)

        if self.adaptive_softmax is None:
            # project back to size of vocabulary
            x = self.output_projection(x)

        return x, {"attn": [attn], "inner_states": inner_states}

    def max_positions(self):
        """Maximum output length supported by the decoder."""
        if self.embed_positions is None:
            return self.max_target_positions
        return min(self.max_target_positions, self.embed_positions.max_positions)

    def buffered_future_mask(self, tensor):
        dim = tensor.size(0)
        if (
            not hasattr(self, "_future_mask")
            or self._future_mask is None
            or self._future_mask.device != tensor.device
        ):
            self._future_mask = torch.triu(
                utils.fill_with_neg_inf(tensor.new(dim, dim)), 1
            )
        if self._future_mask.size(0) < dim:
            self._future_mask = torch.triu(
                utils.fill_with_neg_inf(self._future_mask.resize_(dim, dim)), 1
            )
        return self._future_mask[:dim, :dim]

class LightConvEncoderLayer(nn.Module):
    """Encoder layer block.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        kernel_size: kernel size of the convolution
    """

    def __init__(self, args, kernel_size=0):
        super().__init__()
        self.embed_dim = args.encoder_embed_dim
        self.conv_dim = args.encoder_conv_dim
        padding_l = (
            kernel_size // 2
            if kernel_size % 2 == 1
            else ((kernel_size - 1) // 2, kernel_size // 2)
        )

        if args.encoder_glu:
            self.linear1 = Linear(self.embed_dim, 2 * self.conv_dim)
            self.act = nn.GLU()
        else:
            self.linear1 = Linear(self.embed_dim, self.conv_dim)
            self.act = None
        if args.encoder_conv_type == "lightweight":
            self.conv = LightweightConv(
                self.conv_dim,
                kernel_size,
                padding_l=padding_l,
                weight_softmax=args.weight_softmax,
                num_heads=args.encoder_attention_heads,
                weight_dropout=args.weight_dropout,
            )
        elif args.encoder_conv_type == "dynamic":
            self.conv = DynamicConv(
                self.conv_dim,
                kernel_size,
                padding_l=padding_l,
                weight_softmax=args.weight_softmax,
                num_heads=args.encoder_attention_heads,
                weight_dropout=args.weight_dropout,
            )
        else:
            raise NotImplementedError
        self.linear2 = Linear(self.conv_dim, self.embed_dim)

        self.dropout_module = FairseqDropout(
            args.dropout, module_name=self.__class__.__name__
        )
        self.relu_dropout_module = FairseqDropout(
            args.relu_dropout, module_name=self.__class__.__name__
        )
        self.input_dropout_module = FairseqDropout(
            args.input_dropout, module_name=self.__class__.__name__
        )
        self.normalize_before = args.encoder_normalize_before
        self.fc1 = Linear(self.embed_dim, args.encoder_ffn_embed_dim)
        self.fc2 = Linear(args.encoder_ffn_embed_dim, self.embed_dim)
        self.layer_norm1 = LayerNorm(self.embed_dim)
        self.layer_norm2 = LayerNorm(self.embed_dim)

    def forward(self, x, encoder_padding_mask: Optional[Tensor] = None) -> Tensor:
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_padding_mask (ByteTensor): binary ByteTensor of shape
                `(batch, src_len)` where padding elements are indicated by ``1``.

        Returns:
            encoded output of shape `(seq_len, batch, embed_dim)`
        """
        residual = x
        normalize = self.maybe_layer_norm(before=True)
        if normalize:
            x = self.layer_norm1(x)
        x = self.input_dropout_module(x)
        x = self.linear1(x)
        if self.act is not None:
            x = self.act(x)
        if encoder_padding_mask is not None:
            x = x.masked_fill(encoder_padding_mask.transpose(0, 1).unsqueeze(2), 0)
        x = self.conv(x)
        x = self.linear2(x)
        x = self.dropout_module(x)
        x = residual + x
        normalize = self.maybe_layer_norm(after=True)
        if normalize:
            x = self.layer_norm1(x)

        residual = x
        normalize = self.maybe_layer_norm(before=True)
        if normalize:
            x = self.layer_norm2(x)
        x = F.relu(self.fc1(x))
        x = self.relu_dropout_module(x)
        x = self.fc2(x)
        x = self.dropout_module(x)
        x = residual + x
        normalize = self.maybe_layer_norm(after=True)
        if normalize:
            x = self.layer_norm2(x)
        return x

    def maybe_layer_norm(self, before: bool = False, after: bool = False):
        # Pre-norm (normalize_before=True) applies LayerNorm before each
        # sub-block; post-norm applies it after the residual connection.
        # Exactly one of ``before``/``after`` must be set, and the XOR below
        # selects the right case.
        assert before ^ after, "Incorrect arguments"
        return after ^ self.normalize_before

    def extra_repr(self):
        return (
            "dropout={}, relu_dropout={}, input_dropout={}, normalize_before={}".format(
                self.dropout_module.p,
                self.relu_dropout_module.p,
                self.input_dropout_module.p,
                self.normalize_before,
            )
        )

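# Usage sketch (comment only, not executed): an encoder layer maps inputs of
# shape ``(seq_len, batch, embed_dim)`` to the same shape. The attribute names
# on ``cfg`` are the ones read by ``LightConvEncoderLayer.__init__`` above;
# the values are illustrative, and fairseq must be installed for the conv
# modules to resolve.
#
#   from argparse import Namespace
#   cfg = Namespace(
#       encoder_embed_dim=512, encoder_conv_dim=512, encoder_ffn_embed_dim=1024,
#       encoder_glu=True, encoder_conv_type="dynamic", weight_softmax=True,
#       encoder_attention_heads=4, weight_dropout=0.0, dropout=0.0,
#       relu_dropout=0.0, input_dropout=0.0, encoder_normalize_before=False,
#   )
#   layer = LightConvEncoderLayer(cfg, kernel_size=3)
#   y = layer(torch.rand(10, 2, 512))  # -> shape (10, 2, 512)
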
class LightConvDecoderLayer(nn.Module):
    """Decoder layer block.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        no_encoder_attn (bool, optional): if ``True``, do not attend to encoder
            outputs. Default: ``False``
        kernel_size: kernel size of the convolution
    """

    def __init__(self, args, no_encoder_attn=False, kernel_size=0, dictionary=None):
        super().__init__()
        self.embed_dim = args.decoder_embed_dim
        self.conv_dim = args.decoder_conv_dim
        if args.decoder_glu:
            self.linear1 = Linear(self.embed_dim, 2 * self.conv_dim)
            self.act = nn.GLU()
        else:
            self.linear1 = Linear(self.embed_dim, self.conv_dim)
            self.act = None
        if args.decoder_conv_type == "lightweight":
            self.conv = LightweightConv(
                self.conv_dim,
                kernel_size,
                padding_l=kernel_size - 1,
                weight_softmax=args.weight_softmax,
                num_heads=args.decoder_attention_heads,
                weight_dropout=args.weight_dropout,
            )
        elif args.decoder_conv_type == "dynamic":
            self.conv = DynamicConv(
                self.conv_dim,
                kernel_size,
                padding_l=kernel_size - 1,
                weight_softmax=args.weight_softmax,
                num_heads=args.decoder_attention_heads,
                weight_dropout=args.weight_dropout,
            )
        else:
            raise NotImplementedError
        self.linear2 = Linear(self.conv_dim, self.embed_dim)

        self.dropout_module = FairseqDropout(
            args.dropout, module_name=self.__class__.__name__
        )
        self.relu_dropout_module = FairseqDropout(
            args.relu_dropout, module_name=self.__class__.__name__
        )
        self.input_dropout_module = FairseqDropout(
            args.input_dropout, module_name=self.__class__.__name__
        )
        self.normalize_before = args.decoder_normalize_before

        self.conv_layer_norm = LayerNorm(self.embed_dim)

        if no_encoder_attn:
            self.encoder_attn = None
            self.encoder_attn_layer_norm = None
        else:
            self.encoder_attn = MultiheadAttention(
                self.embed_dim,
                args.decoder_attention_heads,
                dropout=args.attention_dropout,
                encoder_decoder_attention=True,
                dictionary=dictionary,
            )
            self.encoder_attn_layer_norm = LayerNorm(self.embed_dim)

        self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim)
        self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim)

        self.final_layer_norm = LayerNorm(self.embed_dim)
        self.need_attn = True

    def forward(
        self,
        x: Tensor,
        encoder_out: Optional[Tensor],
        encoder_padding_mask: Optional[Tensor],
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
        prev_conv_state: Optional[Tensor] = None,
        prev_attn_state: Optional[Tuple[Tensor, Tensor]] = None,
        conv_mask: Optional[Tensor] = None,
        conv_padding_mask: Optional[Tensor] = None,
    ):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_padding_mask (ByteTensor): binary ByteTensor of shape
                `(batch, src_len)` where padding elements are indicated by ``1``.

        Returns:
            encoded output of shape `(seq_len, batch, embed_dim)`
        """
        residual = x
        normalize = self.maybe_layer_norm(before=True)
        if normalize:
            x = self.conv_layer_norm(x)
        if prev_conv_state is not None:
            self.conv._set_input_buffer(incremental_state, prev_conv_state)
        x = self.input_dropout_module(x)
        x = self.linear1(x)
        if self.act is not None:
            x = self.act(x)
        x = self.conv(x, incremental_state=incremental_state)
        x = self.linear2(x)
        x = self.dropout_module(x)
        x = residual + x
        normalize = self.maybe_layer_norm(after=True)
        if normalize:
            x = self.conv_layer_norm(x)

        attn: Optional[Tensor] = None
        if self.encoder_attn is not None:
            residual = x
            normalize = self.maybe_layer_norm(before=True)
            if normalize:
                x = self.encoder_attn_layer_norm(x)

            if prev_attn_state is not None:
                saved_state: Dict[str, Optional[Tensor]] = {
                    "prev_key": prev_attn_state[0],
                    "prev_value": prev_attn_state[1],
                }
                self.encoder_attn._set_input_buffer(incremental_state, saved_state)
            x, attn = self.encoder_attn(
                query=x,
                key=encoder_out,
                value=encoder_out,
                key_padding_mask=encoder_padding_mask,
                incremental_state=incremental_state,
                static_kv=True,
                need_weights=(not self.training and self.need_attn),
            )
            x = self.dropout_module(x)
            x = residual + x
            normalize = self.maybe_layer_norm(after=True)
            if normalize:
                x = self.encoder_attn_layer_norm(x)

        residual = x
        normalize = self.maybe_layer_norm(before=True)
        if normalize:
            x = self.final_layer_norm(x)
        x = F.relu(self.fc1(x))
        x = self.relu_dropout_module(x)
        x = self.fc2(x)
        x = self.dropout_module(x)
        x = residual + x
        normalize = self.maybe_layer_norm(after=True)
        if normalize:
            x = self.final_layer_norm(x)
        return x, attn

    def maybe_layer_norm(self, before: bool = False, after: bool = False):
        assert before ^ after, "Incorrect usage"
        return after ^ self.normalize_before

    def make_generation_fast_(self, need_attn: bool = False, **kwargs):
        self.need_attn = need_attn

    def extra_repr(self):
        return (
            "dropout={}, relu_dropout={}, input_dropout={}, normalize_before={}".format(
                self.dropout_module.p,
                self.relu_dropout_module.p,
                self.input_dropout_module.p,
                self.normalize_before,
            )
        )

def Embedding(num_embeddings, embedding_dim, padding_idx):
    m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
    nn.init.normal_(m.weight, mean=0, std=embedding_dim**-0.5)
    nn.init.constant_(m.weight[padding_idx], 0)
    return m


def Linear(in_features, out_features, bias=True):
    m = nn.Linear(in_features, out_features, bias)
    nn.init.xavier_uniform_(m.weight)
    if bias:
        nn.init.constant_(m.bias, 0.0)
    return m

@register_model_architecture("lightconv", "lightconv")
def base_architecture(args):
    args.encoder_embed_path = getattr(args, "encoder_embed_path", None)
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
    args.encoder_layers = getattr(args, "encoder_layers", 7)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
    args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
    args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False)
    args.decoder_embed_path = getattr(args, "decoder_embed_path", None)
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)
    args.decoder_ffn_embed_dim = getattr(
        args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim
    )
    args.decoder_layers = getattr(args, "decoder_layers", 6)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
    args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False)
    args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
    args.attention_dropout = getattr(args, "attention_dropout", 0.0)
    args.relu_dropout = getattr(args, "relu_dropout", 0.0)
    args.dropout = getattr(args, "dropout", 0.1)
    args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
    args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
    args.share_decoder_input_output_embed = getattr(
        args, "share_decoder_input_output_embed", False
    )
    args.share_all_embeddings = getattr(args, "share_all_embeddings", False)
    args.no_token_positional_embeddings = getattr(
        args, "no_token_positional_embeddings", False
    )

    args.decoder_output_dim = getattr(
        args, "decoder_output_dim", args.decoder_embed_dim
    )
    args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)

    args.encoder_conv_dim = getattr(args, "encoder_conv_dim", args.encoder_embed_dim)
    args.decoder_conv_dim = getattr(args, "decoder_conv_dim", args.decoder_embed_dim)

    args.encoder_kernel_size_list = getattr(
        args, "encoder_kernel_size_list", [3, 7, 15, 31, 31, 31, 31]
    )
    args.decoder_kernel_size_list = getattr(
        args, "decoder_kernel_size_list", [3, 7, 15, 31, 31, 31]
    )
    if len(args.encoder_kernel_size_list) == 1:
        args.encoder_kernel_size_list = (
            args.encoder_kernel_size_list * args.encoder_layers
        )
    if len(args.decoder_kernel_size_list) == 1:
        args.decoder_kernel_size_list = (
            args.decoder_kernel_size_list * args.decoder_layers
        )
    assert (
        len(args.encoder_kernel_size_list) == args.encoder_layers
    ), "encoder_kernel_size_list doesn't match encoder_layers"
    assert (
        len(args.decoder_kernel_size_list) == args.decoder_layers
    ), "decoder_kernel_size_list doesn't match decoder_layers"
    args.encoder_glu = getattr(args, "encoder_glu", True)
    args.decoder_glu = getattr(args, "decoder_glu", True)
    args.input_dropout = getattr(args, "input_dropout", 0.1)
    args.weight_dropout = getattr(args, "weight_dropout", args.attention_dropout)

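# The architecture presets below only override ``base_architecture`` defaults
# for specific benchmarks. A hedged command-line sketch for selecting one of
# them (the data directory is a placeholder; the usual optimizer and training
# flags are omitted):
#
#   fairseq-train data-bin/iwslt14.tokenized.de-en \
#       --arch lightconv_iwslt_de_en \
#       --encoder-conv-type lightweight --decoder-conv-type lightweight
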
@register_model_architecture("lightconv", "lightconv_iwslt_de_en")
def lightconv_iwslt_de_en(args):
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 1024)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4)
    args.encoder_layers = getattr(args, "encoder_layers", 7)
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512)
    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 1024)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4)
    args.decoder_layers = getattr(args, "decoder_layers", 6)
    args.attention_dropout = getattr(args, "attention_dropout", 0.1)
    args.weight_dropout = getattr(args, "weight_dropout", 0.1)
    args.encoder_glu = getattr(args, "encoder_glu", False)
    args.decoder_glu = getattr(args, "decoder_glu", False)
    args.input_dropout = getattr(args, "input_dropout", 0.0)
    base_architecture(args)


@register_model_architecture("lightconv", "lightconv_wmt_en_de")
def lightconv_wmt_en_de(args):
    base_architecture(args)


@register_model_architecture("lightconv", "lightconv_wmt_en_de_big")
def lightconv_wmt_en_de_big(args):
    args.attention_dropout = getattr(args, "attention_dropout", 0.1)
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
    args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1024)
    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4096)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
    args.dropout = getattr(args, "dropout", 0.3)
    base_architecture(args)


@register_model_architecture("lightconv", "lightconv_wmt_en_fr_big")
def lightconv_wmt_en_fr_big(args):
    args.dropout = getattr(args, "dropout", 0.1)
    lightconv_wmt_en_de_big(args)


@register_model_architecture("lightconv", "lightconv_wmt_zh_en_big")
def lightconv_wmt_zh_en_big(args):
    args.dropout = getattr(args, "dropout", 0.2)
    args.attention_dropout = getattr(args, "attention_dropout", 0.2)
    args.weight_dropout = getattr(args, "weight_dropout", 0.2)
    lightconv_wmt_en_de_big(args)