mirror of
https://github.com/SillyTavern/SillyTavern-Extras.git
synced 2026-02-27 08:34:10 +00:00
395 lines
16 KiB
Python
395 lines
16 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
#
|
|
# This source code is licensed under the MIT license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
"""
|
|
BART: Denoising Sequence-to-Sequence Pre-training for
|
|
Natural Language Generation, Translation, and Comprehension
|
|
"""
|
|
import logging
|
|
from typing import Optional
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from fairseq import utils
|
|
from fairseq.models import register_model, register_model_architecture
|
|
from fairseq.models.transformer import TransformerModel
|
|
from fairseq.modules.transformer_sentence_encoder import init_bert_params
|
|
|
|
from .hub_interface import BARTHubInterface
|
|
|
|
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@register_model("bart")
class BARTModel(TransformerModel):
    """Transformer encoder-decoder registered under the name ``"bart"``.

    On top of ``TransformerModel`` this class keeps a ``ModuleDict`` of named
    classification heads (``classification_heads``) that can be applied to
    the decoder features in :meth:`forward`.
    """

    __jit_unused_properties__ = ["supported_targets"]

    @classmethod
    def hub_models(cls):
        """Return the mapping from released model names to archive URLs."""
        # NOTE(review): these archives are addressed over plain HTTP; consider
        # HTTPS if the host serves it -- TODO confirm before changing.
        return {
            "bart.base": "http://dl.fbaipublicfiles.com/fairseq/models/bart.base.tar.gz",
            "bart.large": "http://dl.fbaipublicfiles.com/fairseq/models/bart.large.tar.gz",
            "bart.large.mnli": "http://dl.fbaipublicfiles.com/fairseq/models/bart.large.mnli.tar.gz",
            "bart.large.cnn": "http://dl.fbaipublicfiles.com/fairseq/models/bart.large.cnn.tar.gz",
            "bart.large.xsum": "http://dl.fbaipublicfiles.com/fairseq/models/bart.large.xsum.tar.gz",
        }

    def __init__(self, args, encoder, decoder):
        """Build the model, re-initialise its weights and create the (empty)
        container of classification heads.

        ``self.eos`` is only set when the encoder exposes a ``dictionary``.
        """
        super().__init__(args, encoder, decoder)

        # We follow BERT's random weight initialization
        self.apply(init_bert_params)

        self.classification_heads = nn.ModuleDict()
        if hasattr(self.encoder, "dictionary"):
            self.eos: int = self.encoder.dictionary.eos()

    @staticmethod
    def add_args(parser):
        """Add the parent model's arguments plus the pooler / classification
        head options to ``parser``."""
        super(BARTModel, BARTModel).add_args(parser)
        parser.add_argument(
            "--pooler-dropout",
            type=float,
            metavar="D",
            help="dropout probability in the masked_lm pooler layers",
        )
        parser.add_argument(
            "--pooler-activation-fn",
            choices=utils.get_available_activation_fns(),
            help="activation function to use for pooler layer",
        )
        parser.add_argument(
            "--spectral-norm-classification-head",
            action="store_true",
            help="Apply spectral normalization on the classification head",
        )

    @property
    def supported_targets(self):
        return {"self"}

    def forward(
        self,
        src_tokens,
        src_lengths,
        prev_output_tokens,
        features_only: bool = False,
        classification_head_name: Optional[str] = None,
        token_embeddings: Optional[torch.Tensor] = None,
        return_all_hiddens: bool = True,
        alignment_layer: Optional[int] = None,
        alignment_heads: Optional[int] = None,
    ):
        """Run the encoder and decoder, optionally followed by a named
        classification head.

        When ``classification_head_name`` is given, ``features_only`` is
        forced to ``True``; for every row the decoder feature at the last
        position where ``src_tokens`` equals ``self.eos`` is selected and fed
        to the head registered under that name.  If no head with that name
        exists, the decoder features are returned unchanged.

        Returns:
            The pair ``(x, extra)`` where ``x`` is the decoder output (or the
            head output) and ``extra`` is the decoder's second return value.

        Raises:
            ValueError: if a classification head is requested and the rows of
                ``src_tokens`` do not all contain the same number of eos
                tokens.
        """
        if classification_head_name is not None:
            features_only = True

        encoder_out = self.encoder(
            src_tokens,
            src_lengths=src_lengths,
            token_embeddings=token_embeddings,
            return_all_hiddens=return_all_hiddens,
        )
        x, extra = self.decoder(
            prev_output_tokens,
            encoder_out=encoder_out,
            features_only=features_only,
            alignment_layer=alignment_layer,
            alignment_heads=alignment_heads,
            src_lengths=src_lengths,
            return_all_hiddens=return_all_hiddens,
        )
        if classification_head_name is not None:
            # ``self.eos`` only exists when the encoder has a dictionary, so
            # it is read here, in the only branch that needs it.
            eos: int = self.eos
            eos_mask = src_tokens.eq(eos)
            # The reshape below assumes every row contributes the same number
            # of eos positions; otherwise features of different rows would be
            # silently mixed whenever the total happens to be divisible by
            # the batch size.
            eos_counts = eos_mask.sum(dim=1)
            if eos_counts.numel() > 0 and bool(
                (eos_counts != eos_counts[0]).any()
            ):
                raise ValueError(
                    "All examples must have the same number of <eos> tokens."
                )
            # assumes x's leading dims match src_tokens' shape (the boolean
            # mask indexes them) -- TODO confirm against the decoder output.
            sentence_representation = x[eos_mask, :].view(
                x.size(0), -1, x.size(-1)
            )[:, -1, :]
            for k, head in self.classification_heads.items():
                # for torch script only supports iteration
                if k == classification_head_name:
                    x = head(sentence_representation)
                    break
        return x, extra

    @classmethod
    def from_pretrained(
        cls,
        model_name_or_path,
        checkpoint_file="model.pt",
        data_name_or_path=".",
        bpe="gpt2",
        sample_break_mode="eos",
        **kwargs,
    ):
        """Load a checkpoint via ``hub_utils.from_pretrained`` and wrap the
        first loaded model in a ``BARTHubInterface``.

        Checkpoint classification heads are requested with
        ``load_checkpoint_heads=True``; remaining ``kwargs`` are forwarded.
        """
        from fairseq import hub_utils

        x = hub_utils.from_pretrained(
            model_name_or_path,
            checkpoint_file,
            data_name_or_path,
            archive_map=cls.hub_models(),
            bpe=bpe,
            load_checkpoint_heads=True,
            sample_break_mode=sample_break_mode,
            **kwargs,
        )
        return BARTHubInterface(x["args"], x["task"], x["models"][0])

    def register_classification_head(
        self, name, num_classes=None, inner_dim=None, **kwargs
    ):
        """Register a classification head.

        Args:
            name: key under which the head is stored in
                ``classification_heads``.
            num_classes: output size of the head.
            inner_dim: hidden size of the head; falls back to
                ``args.encoder_embed_dim`` when falsy.
            **kwargs: accepted and ignored.

        Registering under an existing name replaces that head; a warning is
        logged when the requested sizes differ from the existing ones.
        """
        logger.info("Registering classification head: %s", name)
        if name in self.classification_heads:
            prev_num_classes = self.classification_heads[name].out_proj.out_features
            prev_inner_dim = self.classification_heads[name].dense.out_features
            if num_classes != prev_num_classes or inner_dim != prev_inner_dim:
                logger.warning(
                    're-registering head "%s" with num_classes %s (prev: %s) '
                    "and inner_dim %s (prev: %s)",
                    name,
                    num_classes,
                    prev_num_classes,
                    inner_dim,
                    prev_inner_dim,
                )
        self.classification_heads[name] = BARTClassificationHead(
            input_dim=self.args.encoder_embed_dim,
            inner_dim=inner_dim or self.args.encoder_embed_dim,
            num_classes=num_classes,
            activation_fn=self.args.pooler_activation_fn,
            pooler_dropout=self.args.pooler_dropout,
            do_spectral_norm=getattr(
                self.args, "spectral_norm_classification_head", False
            ),
        )

    def upgrade_state_dict_named(self, state_dict, name):
        """Reconcile a loaded ``state_dict`` with this model, in place.

        After delegating to the parent class this method:

        1. registers heads found in the checkpoint when
           ``args.load_checkpoint_heads`` is set, otherwise drops checkpoint
           heads that are unknown to the model or have different sizes;
        2. drops the last embedding row when the checkpoint has exactly one
           more row than the dictionary and the dictionary has no ``<mask>``;
        3. for the ``multilingual_denoising`` task, inserts freshly
           initialised rows when the checkpoint has fewer rows than the
           dictionary;
        4. copies weights of heads that exist on the model but not in the
           checkpoint into ``state_dict``.
        """
        super().upgrade_state_dict_named(state_dict, name)

        prefix = name + "." if name != "" else ""
        heads_prefix = prefix + "classification_heads."
        # ``keys()`` is a live view: heads registered inside the loop below
        # become visible to the membership tests of later iterations.
        current_head_names = (
            []
            if not hasattr(self, "classification_heads")
            else self.classification_heads.keys()
        )

        # Handle new classification heads present in the state dict.
        keys_to_delete = []
        for k in state_dict.keys():
            if not k.startswith(heads_prefix):
                continue

            head_name = k[len(heads_prefix) :].split(".")[0]
            num_classes = state_dict[
                heads_prefix + head_name + ".out_proj.weight"
            ].size(0)
            inner_dim = state_dict[
                heads_prefix + head_name + ".dense.weight"
            ].size(0)

            if getattr(self.args, "load_checkpoint_heads", False):
                if head_name not in current_head_names:
                    self.register_classification_head(head_name, num_classes, inner_dim)
            else:
                if head_name not in current_head_names:
                    logger.warning(
                        "deleting classification head (%s) from checkpoint "
                        "not present in current model: %s",
                        head_name,
                        k,
                    )
                    keys_to_delete.append(k)
                elif (
                    num_classes
                    != self.classification_heads[head_name].out_proj.out_features
                    or inner_dim
                    != self.classification_heads[head_name].dense.out_features
                ):
                    logger.warning(
                        "deleting classification head (%s) from checkpoint "
                        "with different dimensions than current model: %s",
                        head_name,
                        k,
                    )
                    keys_to_delete.append(k)
        # Deletion is deferred so the dict is not mutated while iterated.
        for k in keys_to_delete:
            del state_dict[k]

        def truncate_emb(key):
            # Drop the last row of the matrix stored under ``key``, if any.
            if key in state_dict:
                state_dict[key] = state_dict[key][:-1, :]

        # When finetuning on translation task, remove last row of
        # embedding matrix that corresponds to mask_idx token.
        loaded_dict_size = state_dict["encoder.embed_tokens.weight"].size(0)
        if (
            loaded_dict_size == len(self.encoder.dictionary) + 1
            and "<mask>" not in self.encoder.dictionary
        ):
            truncate_emb("encoder.embed_tokens.weight")
            truncate_emb("decoder.embed_tokens.weight")
            truncate_emb("encoder.output_projection.weight")
            truncate_emb("decoder.output_projection.weight")

        # When continued pretraining on new set of languages for mbart,
        # add extra lang embeddings at the end of embed_tokens.
        # Note: newly added languages are assumed to have been added at the end.
        if self.args.task == "multilingual_denoising" and loaded_dict_size < len(
            self.encoder.dictionary
        ):
            logger.info(
                "Adding extra language embeddings not found in pretrained model for "
                "continued pretraining of MBART on new set of languages."
            )
            loaded_encoder_embed = state_dict["encoder.embed_tokens.weight"]
            loaded_mask_token_embedding = loaded_encoder_embed[-1, :]

            num_langids_to_add = len(self.encoder.dictionary) - loaded_dict_size
            embed_dim = loaded_encoder_embed.size(1)

            new_lang_embed_to_add = torch.zeros(num_langids_to_add, embed_dim)
            nn.init.normal_(new_lang_embed_to_add, mean=0, std=embed_dim**-0.5)
            # Match the checkpoint's device as well as its dtype; torch.cat
            # below cannot combine tensors living on different devices.
            new_lang_embed_to_add = new_lang_embed_to_add.to(
                device=loaded_encoder_embed.device,
                dtype=loaded_encoder_embed.dtype,
            )

            # New layout: loaded rows except the last, the new rows, then the
            # loaded last row (the mask-token embedding per the name above).
            state_dict["encoder.embed_tokens.weight"] = torch.cat(
                [
                    state_dict["encoder.embed_tokens.weight"][
                        : loaded_dict_size - 1, :
                    ],
                    new_lang_embed_to_add,
                    loaded_mask_token_embedding.unsqueeze(0),
                ]
            )
            state_dict["decoder.embed_tokens.weight"] = torch.cat(
                [
                    state_dict["decoder.embed_tokens.weight"][
                        : loaded_dict_size - 1, :
                    ],
                    new_lang_embed_to_add,
                    loaded_mask_token_embedding.unsqueeze(0),
                ]
            )

        # Copy any newly-added classification heads into the state dict
        # with their current weights.
        if hasattr(self, "classification_heads"):
            cur_state = self.classification_heads.state_dict()
            for k, v in cur_state.items():
                full_key = heads_prefix + k
                if full_key not in state_dict:
                    logger.info("Overwriting %s", full_key)
                    state_dict[full_key] = v

    def set_beam_size(self, beam):
        """Set beam size for efficient beamable enc-dec attention."""
        beamable = False
        for layer in self.decoder.layers:
            if layer.encoder_attn is not None:
                if hasattr(layer.encoder_attn, "set_beam_size"):
                    layer.encoder_attn.set_beam_size(beam)
                    beamable = True
        if beamable:
            # At least one layer accepted the beam size: switch the encoder
            # to its ``_reorder_encoder_out`` implementation.
            self.encoder.reorder_encoder_out = self.encoder._reorder_encoder_out
|
|
|
|
|
|
class BARTClassificationHead(nn.Module):
    """Projection head mapping a sentence feature vector to class scores.

    The pipeline is dropout -> ``dense`` -> activation -> dropout ->
    ``out_proj``.
    """

    def __init__(
        self,
        input_dim,
        inner_dim,
        num_classes,
        activation_fn,
        pooler_dropout,
        do_spectral_norm=False,
    ):
        """Create the layers of the head.

        Args:
            input_dim: size of the incoming feature vectors.
            inner_dim: size of the hidden ``dense`` layer.
            num_classes: size of the output produced by ``out_proj``.
            activation_fn: name resolved through ``utils.get_activation_fn``.
            pooler_dropout: probability used by the shared dropout module.
            do_spectral_norm: when true, ``out_proj`` is passed through
                ``torch.nn.utils.spectral_norm``.
        """
        super().__init__()
        self.dense = nn.Linear(input_dim, inner_dim)
        self.activation_fn = utils.get_activation_fn(activation_fn)
        self.dropout = nn.Dropout(p=pooler_dropout)

        projection = nn.Linear(inner_dim, num_classes)
        self.out_proj = (
            torch.nn.utils.spectral_norm(projection)
            if do_spectral_norm
            else projection
        )

    def forward(self, features, **kwargs):
        """Apply the head to ``features``; extra keyword arguments are ignored."""
        hidden = self.dense(self.dropout(features))
        hidden = self.dropout(self.activation_fn(hidden))
        return self.out_proj(hidden)
|
|
|
|
|
|
@register_model_architecture("bart", "bart_large")
def bart_large_architecture(args):
    """Populate ``args`` with the ``bart_large`` defaults.

    Every attribute already present on ``args`` keeps its value; missing
    ones receive the fallback listed below.  Order matters: some fallbacks
    are read from attributes set earlier in this function.
    """

    def fill(attr, fallback):
        # Re-assign the current value if present, else install the fallback.
        setattr(args, attr, getattr(args, attr, fallback))

    # Encoder
    fill("encoder_embed_path", None)
    fill("encoder_embed_dim", 1024)
    fill("encoder_ffn_embed_dim", 4 * 1024)
    fill("encoder_layers", 12)
    fill("encoder_attention_heads", 16)
    fill("encoder_normalize_before", False)
    fill("encoder_learned_pos", True)

    # Decoder (sizes default to the encoder's)
    fill("decoder_embed_path", None)
    fill("decoder_embed_dim", args.encoder_embed_dim)
    fill("decoder_ffn_embed_dim", args.encoder_ffn_embed_dim)
    fill("decoder_layers", 12)
    fill("decoder_attention_heads", 16)
    fill("decoder_normalize_before", False)
    fill("decoder_learned_pos", True)

    # Regularisation and sequence limits
    fill("attention_dropout", 0.0)
    fill("relu_dropout", 0.0)
    fill("dropout", 0.1)
    fill("max_target_positions", 1024)
    fill("max_source_positions", 1024)
    fill("adaptive_softmax_cutoff", None)
    fill("adaptive_softmax_dropout", 0)

    # Embedding sharing
    fill("share_decoder_input_output_embed", True)
    fill("share_all_embeddings", True)

    fill("decoder_output_dim", args.decoder_embed_dim)
    fill("decoder_input_dim", args.decoder_embed_dim)

    fill("no_scale_embedding", True)
    fill("layernorm_embedding", True)

    # Activations and pooler
    fill("activation_fn", "gelu")
    fill("pooler_activation_fn", "tanh")
    fill("pooler_dropout", 0.0)
|
|
|
|
|
|
@register_model_architecture("bart", "bart_base")
def bart_base_architecture(args):
    """Populate ``args`` with the ``bart_base`` defaults.

    The base-specific sizes are installed first (only where missing), then
    ``bart_large_architecture`` supplies every remaining default.
    """
    base_defaults = (
        ("encoder_embed_dim", 768),
        ("encoder_ffn_embed_dim", 4 * 768),
        ("encoder_layers", 6),
        ("encoder_attention_heads", 12),
        ("decoder_layers", 6),
        ("decoder_attention_heads", 12),
    )
    for attr, fallback in base_defaults:
        setattr(args, attr, getattr(args, attr, fallback))
    bart_large_architecture(args)
|
|
|
|
|
|
@register_model_architecture("bart", "mbart_large")
def mbart_large_architecture(args):
    """Populate ``args`` with the ``mbart_large`` defaults.

    ``no_scale_embedding`` falls back to ``False`` here; everything else
    comes from ``bart_large_architecture``.
    """
    scale_flag = getattr(args, "no_scale_embedding", False)
    args.no_scale_embedding = scale_flag
    bart_large_architecture(args)
|
|
|
|
|
|
@register_model_architecture("bart", "mbart_base")
def mbart_base_architecture(args):
    """Populate ``args`` with the ``mbart_base`` defaults.

    ``no_scale_embedding`` falls back to ``False`` here; everything else
    comes from ``bart_base_architecture``.
    """
    scale_flag = getattr(args, "no_scale_embedding", False)
    args.no_scale_embedding = scale_flag
    bart_base_architecture(args)
|
|
|
|
|
|
@register_model_architecture("bart", "mbart_base_wmt20")
def mbart_base_wmt20_architecture(args):
    """Populate ``args`` with the ``mbart_base_wmt20`` defaults.

    ``layernorm_embedding`` falls back to ``False`` here; everything else
    comes from ``mbart_base_architecture``.
    """
    layernorm_flag = getattr(args, "layernorm_embedding", False)
    args.layernorm_embedding = layernorm_flag
    mbart_base_architecture(args)
|