mirror of
https://github.com/SillyTavern/SillyTavern-Extras.git
synced 2026-04-22 15:39:01 +00:00
405 lines
15 KiB
Python
405 lines
15 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
#
|
|
# This source code is licensed under the MIT license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
import logging
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from fairseq import utils
|
|
from fairseq.models import (
|
|
FairseqEncoder,
|
|
FairseqEncoderModel,
|
|
register_model,
|
|
register_model_architecture,
|
|
)
|
|
from fairseq.modules import (
|
|
LayerNorm,
|
|
SinusoidalPositionalEmbedding,
|
|
TransformerSentenceEncoder,
|
|
)
|
|
from fairseq.modules.transformer_sentence_encoder import init_bert_params
|
|
from fairseq.utils import safe_hasattr
|
|
|
|
|
|
# Module-scoped logger, named after this module.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
@register_model("masked_lm")
class MaskedLMModel(FairseqEncoderModel):
    """
    Masked Language Model built around a single encoder.

    In addition to token-level masked-LM predictions, the encoder can
    optionally produce sentence-level predictions when ``--sent-loss`` is set.
    """

    def __init__(self, args, encoder):
        super().__init__(encoder)
        self.args = args

        # BERT initialisation has to be triggered explicitly from the model so
        # that the output embeddings and projection layers owned by the
        # encoder are initialised correctly as well.
        wants_bert_init = getattr(args, "apply_bert_init", False)
        if wants_bert_init:
            self.apply(init_bert_params)

    @staticmethod
    def add_args(parser):
        """Register the model-specific command-line options on ``parser``."""
        add = parser.add_argument

        # --- dropout ---
        add("--dropout", type=float, metavar="D", help="dropout probability")
        add(
            "--attention-dropout",
            type=float,
            metavar="D",
            help="dropout probability for attention weights",
        )
        add(
            "--act-dropout",
            type=float,
            metavar="D",
            help="dropout probability after activation in FFN",
        )

        # --- hidden states and self-attention ---
        add(
            "--encoder-ffn-embed-dim",
            type=int,
            metavar="N",
            help="encoder embedding dimension for FFN",
        )
        add("--encoder-layers", type=int, metavar="N", help="num encoder layers")
        add(
            "--encoder-attention-heads",
            type=int,
            metavar="N",
            help="num encoder attention heads",
        )

        # --- input and output embeddings ---
        add(
            "--encoder-embed-dim",
            type=int,
            metavar="N",
            help="encoder embedding dimension",
        )
        add(
            "--share-encoder-input-output-embed",
            action="store_true",
            help="share encoder input and output embeddings",
        )
        add(
            "--encoder-learned-pos",
            action="store_true",
            help="use learned positional embeddings in the encoder",
        )
        add(
            "--no-token-positional-embeddings",
            action="store_true",
            help="if set, disables positional embeddings (outside self attention)",
        )
        add("--num-segment", type=int, metavar="N", help="num segment in the input")
        add(
            "--max-positions",
            type=int,
            help="number of positional embeddings to learn",
        )

        # --- sentence-level prediction ---
        add(
            "--sentence-class-num",
            type=int,
            metavar="N",
            help="number of classes for sentence task",
        )
        add(
            "--sent-loss",
            action="store_true",
            help="if set, calculate sentence level predictions",
        )

        # --- parameter initialisation ---
        add(
            "--apply-bert-init",
            action="store_true",
            help="use custom param initialization for BERT",
        )

        # --- miscellaneous ---
        add(
            "--activation-fn",
            choices=utils.get_available_activation_fns(),
            help="activation function to use",
        )
        add(
            "--pooler-activation-fn",
            choices=utils.get_available_activation_fns(),
            help="Which activation function to use for pooler layer.",
        )
        add(
            "--encoder-normalize-before",
            action="store_true",
            help="apply layernorm before each encoder block",
        )

    def forward(self, src_tokens, segment_labels=None, **kwargs):
        """Delegate the whole forward pass to the wrapped encoder."""
        encoder = self.encoder
        return encoder(src_tokens, segment_labels=segment_labels, **kwargs)

    def max_positions(self):
        """Return the maximum number of positions supported by the encoder."""
        # Read as an attribute (not called): MaskedLMEncoder assigns
        # ``max_positions`` as a plain value in its ``__init__``.
        return self.encoder.max_positions

    @classmethod
    def build_model(cls, args, task):
        """Construct a new model instance from ``args`` and ``task.dictionary``."""
        # Older configs may lack some options; fill in architecture defaults.
        base_architecture(args)

        if not safe_hasattr(args, "max_positions"):
            # Fall back to ``args.tokens_per_sample`` when unset.
            args.max_positions = args.tokens_per_sample

        logger.info(args)

        return cls(args, MaskedLMEncoder(args, task.dictionary))
|
|
|
|
|
|
class MaskedLMEncoder(FairseqEncoder):
    """
    Encoder for Masked Language Modelling.

    Wraps a ``TransformerSentenceEncoder`` and adds:
      * an LM head (dense layer + activation + LayerNorm, followed by a
        projection to the vocabulary, either through the tied input
        embedding matrix or a separate ``embed_out`` layer);
      * a pooler (dense + activation) over the sentence representation;
      * an optional sentence-level classification head.

    When ``args.remove_head`` is set (fine-tuning), the output heads
    (``lm_output_learned_bias``, ``embed_out``, ``sentence_projection_layer``)
    are not created, and ``upgrade_state_dict_named`` drops their weights from
    loaded checkpoints accordingly.
    """

    def __init__(self, args, dictionary):
        super().__init__(dictionary)

        self.padding_idx = dictionary.pad()
        self.vocab_size = len(dictionary)
        # NOTE(review): this instance attribute shadows the ``max_positions``
        # method defined below, so ``encoder.max_positions`` evaluates to the
        # stored value and ``encoder.max_positions()`` would fail.
        # MaskedLMModel reads it as an attribute, so the name is kept as-is.
        self.max_positions = args.max_positions

        self.sentence_encoder = TransformerSentenceEncoder(
            padding_idx=self.padding_idx,
            vocab_size=self.vocab_size,
            num_encoder_layers=args.encoder_layers,
            embedding_dim=args.encoder_embed_dim,
            ffn_embedding_dim=args.encoder_ffn_embed_dim,
            num_attention_heads=args.encoder_attention_heads,
            dropout=args.dropout,
            attention_dropout=args.attention_dropout,
            activation_dropout=args.act_dropout,
            max_seq_len=self.max_positions,
            num_segments=args.num_segment,
            use_position_embeddings=not args.no_token_positional_embeddings,
            encoder_normalize_before=args.encoder_normalize_before,
            apply_bert_init=args.apply_bert_init,
            activation_fn=args.activation_fn,
            learned_pos_embedding=args.encoder_learned_pos,
        )

        self.share_input_output_embed = args.share_encoder_input_output_embed
        self.sentence_out_dim = args.sentence_class_num

        # Output heads; they stay None unless created below.
        self.embed_out = None
        self.sentence_projection_layer = None
        self.lm_output_learned_bias = None

        # Remove head is set to true during fine-tuning
        self.load_softmax = not getattr(args, "remove_head", False)

        self.masked_lm_pooler = nn.Linear(
            args.encoder_embed_dim, args.encoder_embed_dim
        )
        self.pooler_activation = utils.get_activation_fn(args.pooler_activation_fn)

        self.lm_head_transform_weight = nn.Linear(
            args.encoder_embed_dim, args.encoder_embed_dim
        )
        self.activation_fn = utils.get_activation_fn(args.activation_fn)
        self.layer_norm = LayerNorm(args.encoder_embed_dim)

        # The output heads are only built when the softmax is kept. This must
        # mirror upgrade_state_dict_named, which strips exactly these weights
        # from checkpoints when ``load_softmax`` is False.
        if self.load_softmax:
            self.lm_output_learned_bias = nn.Parameter(torch.zeros(self.vocab_size))

            if not self.share_input_output_embed:
                self.embed_out = nn.Linear(
                    args.encoder_embed_dim, self.vocab_size, bias=False
                )

            if args.sent_loss:
                self.sentence_projection_layer = nn.Linear(
                    args.encoder_embed_dim, self.sentence_out_dim, bias=False
                )

    def forward(self, src_tokens, segment_labels=None, masked_tokens=None, **unused):
        """
        Forward pass for the Masked LM encoder.

        The token representation is built by the sentence encoder from the
        token embedding matrix, position embeddings (if enabled) and segment
        embeddings (if enabled). The sentence representation is assumed to
        correspond to the output of the classification token (see the
        bert / cross_lingual_lm tasks for details).

        Args:
            src_tokens: B x T matrix representing sentences.
            segment_labels: optional B x T matrix of segment labels for tokens.
            masked_tokens: optional selector applied to the B x T positions
                (presumably a B x T boolean mask — confirm against callers).
                When given, only the selected positions are projected, so the
                output is flattened to N x C (N = number of selected tokens).
            **unused: ignored.

        Returns:
            A tuple ``(x, extra)``:
              - ``x``: logits in format B x T x C (or N x C with
                ``masked_tokens``) to be used in a softmax afterwards. If no
                vocabulary projection is configured (no tied embeddings and
                no ``embed_out``), ``x`` is the transformed hidden features.
              - ``extra``: dict with ``inner_states`` (list of intermediate
                encoder states), ``pooled_output`` (pooled representation of
                the classification token) and ``sentence_logits`` (sentence-
                level logits, or ``None`` when no sentence head exists).
        """
        inner_states, sentence_rep = self.sentence_encoder(
            src_tokens,
            segment_labels=segment_labels,
        )

        # Last inner state is assumed T x B x C (fairseq convention — TODO
        # confirm); transpose to B x T x C so it lines up with src_tokens.
        x = inner_states[-1].transpose(0, 1)
        # project masked tokens only
        if masked_tokens is not None:
            x = x[masked_tokens, :]
        x = self.layer_norm(self.activation_fn(self.lm_head_transform_weight(x)))

        pooled_output = self.pooler_activation(self.masked_lm_pooler(sentence_rep))

        # project back to size of vocabulary
        if self.share_input_output_embed and hasattr(
            self.sentence_encoder.embed_tokens, "weight"
        ):
            x = F.linear(x, self.sentence_encoder.embed_tokens.weight)
        elif self.embed_out is not None:
            x = self.embed_out(x)
        if self.lm_output_learned_bias is not None:
            x = x + self.lm_output_learned_bias

        sentence_logits = None
        if self.sentence_projection_layer is not None:
            sentence_logits = self.sentence_projection_layer(pooled_output)

        return x, {
            "inner_states": inner_states,
            "pooled_output": pooled_output,
            "sentence_logits": sentence_logits,
        }

    def max_positions(self):
        """Maximum output length supported by the encoder."""
        # NOTE(review): unreachable through normal attribute lookup on an
        # instance, because ``__init__`` assigns ``self.max_positions``.
        return self.max_positions

    def upgrade_state_dict_named(self, state_dict, name):
        """
        Adapt a loaded ``state_dict`` to this encoder's configuration.

        Inserts a placeholder ``_float_tensor`` entry for sinusoidal position
        embeddings and, when the output heads are disabled
        (``load_softmax`` False), removes their weights so they are not
        expected by this module.

        Returns:
            The (mutated) ``state_dict``.
        """
        if isinstance(
            self.sentence_encoder.embed_positions, SinusoidalPositionalEmbedding
        ):
            state_dict[
                name + ".sentence_encoder.embed_positions._float_tensor"
            ] = torch.FloatTensor(1)
        if not self.load_softmax:
            # Substring match over every key in the dict (not only those
            # under ``name``).
            for k in list(state_dict.keys()):
                if (
                    "embed_out.weight" in k
                    or "sentence_projection_layer.weight" in k
                    or "lm_output_learned_bias" in k
                ):
                    del state_dict[k]
        return state_dict
|
|
|
|
|
|
@register_model_architecture("masked_lm", "masked_lm")
def base_architecture(args):
    """
    Fill in default hyper-parameters for the ``masked_lm`` architecture.

    Attributes absent from ``args`` receive the default listed below.
    Attributes that already exist (even if ``None``) keep their value.
    ``args`` is modified in place.
    """
    defaults = {
        # dropout
        "dropout": 0.1,
        "attention_dropout": 0.1,
        "act_dropout": 0.0,
        # hidden states and self-attention
        "encoder_ffn_embed_dim": 4096,
        "encoder_layers": 6,
        "encoder_attention_heads": 8,
        # input and output embeddings
        "encoder_embed_dim": 1024,
        "share_encoder_input_output_embed": False,
        "encoder_learned_pos": False,
        "no_token_positional_embeddings": False,
        "num_segment": 2,
        # sentence-level prediction
        "sentence_class_num": 2,
        "sent_loss": False,
        # parameter initialisation
        "apply_bert_init": False,
        # miscellaneous
        "activation_fn": "relu",
        "pooler_activation_fn": "tanh",
        "encoder_normalize_before": False,
    }
    for attr, fallback in defaults.items():
        setattr(args, attr, getattr(args, attr, fallback))
|
|
|
|
|
|
@register_model_architecture("masked_lm", "bert_base")
def bert_base_architecture(args):
    """
    Fill in BERT-base hyper-parameters, then the generic ``masked_lm`` defaults.

    Only attributes absent from ``args`` take the values below; anything
    still missing afterwards is filled by ``base_architecture``.
    """
    bert_base_defaults = {
        "encoder_embed_dim": 768,
        "share_encoder_input_output_embed": True,
        "no_token_positional_embeddings": False,
        "encoder_learned_pos": True,
        "num_segment": 2,
        "encoder_layers": 12,
        "encoder_attention_heads": 12,
        "encoder_ffn_embed_dim": 3072,
        "sentence_class_num": 2,
        "sent_loss": True,
        "apply_bert_init": True,
        "activation_fn": "gelu",
        "pooler_activation_fn": "tanh",
        "encoder_normalize_before": True,
    }
    for key, value in bert_base_defaults.items():
        setattr(args, key, getattr(args, key, value))

    base_architecture(args)
|
|
|
|
|
|
@register_model_architecture("masked_lm", "bert_large")
def bert_large_architecture(args):
    """Fill in BERT-large sizes, then defer to the BERT-base defaults."""
    for attr, fallback in (
        ("encoder_embed_dim", 1024),
        ("encoder_layers", 24),
        ("encoder_attention_heads", 16),
        ("encoder_ffn_embed_dim", 4096),
    ):
        setattr(args, attr, getattr(args, attr, fallback))

    bert_base_architecture(args)
|
|
|
|
|
|
@register_model_architecture("masked_lm", "xlm_base")
def xlm_architecture(args):
    """
    Fill in XLM-base hyper-parameters, then the generic ``masked_lm`` defaults.

    Only attributes absent from ``args`` take the values below.
    """
    xlm_defaults = {
        "encoder_embed_dim": 1024,
        "share_encoder_input_output_embed": True,
        "no_token_positional_embeddings": False,
        "encoder_learned_pos": True,
        "num_segment": 1,
        "encoder_layers": 6,
        "encoder_attention_heads": 8,
        "encoder_ffn_embed_dim": 4096,
        "sent_loss": False,
        "activation_fn": "gelu",
        "encoder_normalize_before": False,
        "pooler_activation_fn": "tanh",
        "apply_bert_init": True,
    }
    for name, default in xlm_defaults.items():
        setattr(args, name, getattr(args, name, default))

    base_architecture(args)
|