# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from fairseq.models import register_model, register_model_architecture
from fairseq.models.transformer import (
    TransformerModel,
    base_architecture,
    transformer_wmt_en_de_big,
)


@register_model("transformer_align")
class TransformerAlignModel(TransformerModel):
    """
    See "Jointly Learning to Align and Translate with Transformer
    Models" (Garg et al., EMNLP 2019).
    """

    def __init__(self, encoder, decoder, args):
        super().__init__(args, encoder, decoder)
        self.alignment_heads = args.alignment_heads
        self.alignment_layer = args.alignment_layer
        self.full_context_alignment = args.full_context_alignment

    @staticmethod
    def add_args(parser):
        # fmt: off
        super(TransformerAlignModel, TransformerAlignModel).add_args(parser)
        parser.add_argument('--alignment-heads', type=int, metavar='D',
                            help='number of cross-attention heads per layer to supervise with alignments')
        parser.add_argument('--alignment-layer', type=int, metavar='D',
                            help='layer number to supervise; 0 corresponds to the bottommost layer')
        parser.add_argument('--full-context-alignment', action='store_true',
                            help='whether alignment is supervised conditioned on the full target context')
        # fmt: on
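
    # The flags above surface as attributes on the parsed argument namespace:
    # `--alignment-heads 1 --alignment-layer 4 --full-context-alignment`
    # yields args.alignment_heads == 1, args.alignment_layer == 4, and
    # args.full_context_alignment == True, which __init__ reads directly.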

    @classmethod
    def build_model(cls, args, task):
        # set any default arguments
        transformer_align(args)

        transformer_model = TransformerModel.build_model(args, task)
        return TransformerAlignModel(
            transformer_model.encoder, transformer_model.decoder, args
        )

    def forward(self, src_tokens, src_lengths, prev_output_tokens):
        encoder_out = self.encoder(src_tokens, src_lengths)
        return self.forward_decoder(prev_output_tokens, encoder_out)

    def forward_decoder(
        self,
        prev_output_tokens,
        encoder_out=None,
        incremental_state=None,
        features_only=False,
        **extra_args,
    ):
        attn_args = {
            "alignment_layer": self.alignment_layer,
            "alignment_heads": self.alignment_heads,
        }
        # Standard decoding pass; alignment_layer/alignment_heads tell the
        # decoder which cross-attention to expose as the alignment signal.
        decoder_out = self.decoder(prev_output_tokens, encoder_out, **attn_args)

        if self.full_context_alignment:
            attn_args["full_context_alignment"] = self.full_context_alignment
            # Second, features-only pass without the causal mask, so the
            # alignment attention is conditioned on the full target context.
            _, alignment_out = self.decoder(
                prev_output_tokens,
                encoder_out,
                features_only=True,
                **attn_args,
                **extra_args,
            )
            # Keep only the attention weights from the full-context pass.
            decoder_out[1]["attn"] = alignment_out["attn"]

        return decoder_out
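

# A minimal sketch, not part of the original module: decoder_out[1]["attn"]
# produced above holds alignment attention of shape (batch, tgt_len, src_len).
# Following Garg et al. (2019), hard word alignments can be read off by
# taking, for each target position, the argmax over source positions. The
# helper below is hypothetical (fairseq ships its own alignment-extraction
# utilities) and ignores padding for brevity.
import torch


def extract_hard_alignments(attn: torch.Tensor) -> torch.Tensor:
    """Map each target position to its most-attended source position.

    attn: (batch, tgt_len, src_len) alignment attention weights.
    Returns a (batch, tgt_len) LongTensor of aligned source indices.
    """
    return attn.argmax(dim=-1)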


@register_model_architecture("transformer_align", "transformer_align")
def transformer_align(args):
    args.alignment_heads = getattr(args, "alignment_heads", 1)
    args.alignment_layer = getattr(args, "alignment_layer", 4)
    args.full_context_alignment = getattr(args, "full_context_alignment", False)
    base_architecture(args)


@register_model_architecture("transformer_align", "transformer_wmt_en_de_big_align")
def transformer_wmt_en_de_big_align(args):
    args.alignment_heads = getattr(args, "alignment_heads", 1)
    args.alignment_layer = getattr(args, "alignment_layer", 4)
    transformer_wmt_en_de_big(args)
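

# Usage sketch (an assumption based on the joint alignment-translation recipe
# that accompanies Garg et al. (2019), not something this module defines):
# training is driven from the fairseq CLI, selecting one of the architectures
# registered above together with an alignment-aware criterion, e.g.
#
#   fairseq-train data-bin/wmt18_en_de \
#       --arch transformer_wmt_en_de_big_align \
#       --alignment-layer 4 --alignment-heads 1 \
#       --criterion label_smoothed_cross_entropy_with_alignment \
#       --load-alignments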