# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict, List, NamedTuple, Optional

import torch
import torch.nn as nn
from torch import Tensor

EncoderOut = NamedTuple(
    "EncoderOut",
    [
        ("encoder_out", Tensor),  # T x B x C
        ("encoder_padding_mask", Optional[Tensor]),  # B x T
        ("encoder_embedding", Optional[Tensor]),  # B x T x C
        ("encoder_states", Optional[List[Tensor]]),  # List[T x B x C]
        ("src_tokens", Optional[Tensor]),  # B x T
        ("src_lengths", Optional[Tensor]),  # B x 1
    ],
)
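
# Illustrative sketch (not part of the original file): an EncoderOut for a
# batch of B=2 sequences padded to T=4 with C=8 model dimensions. Any of the
# optional fields may be None when an encoder does not produce them:
#
#     out = EncoderOut(
#         encoder_out=torch.zeros(4, 2, 8),  # T x B x C
#         encoder_padding_mask=None,
#         encoder_embedding=None,
#         encoder_states=None,
#         src_tokens=None,
#         src_lengths=None,
#     )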


class FairseqEncoder(nn.Module):
    """Base class for encoders."""

    def __init__(self, dictionary):
        super().__init__()
        self.dictionary = dictionary

    def forward(self, src_tokens, src_lengths=None, **kwargs):
        """
        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            src_lengths (LongTensor): lengths of each source sentence of shape
                `(batch)`
        """
        raise NotImplementedError
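
    # Illustrative call against a hypothetical concrete subclass `MyEncoder`
    # (see the sketch at the bottom of this file): two sentences padded to
    # length 7, with true lengths 7 and 5.
    #
    #     enc = MyEncoder(dictionary)
    #     out = enc(
    #         src_tokens=torch.zeros(2, 7, dtype=torch.long),  # (batch, src_len)
    #         src_lengths=torch.tensor([7, 5]),                # (batch)
    #     )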

    def forward_torchscript(self, net_input: Dict[str, Tensor]):
        """A TorchScript-compatible version of forward.

        Encoders which use additional arguments may want to override
        this method for TorchScript compatibility.
        """
        if torch.jit.is_scripting():
            return self.forward(
                src_tokens=net_input["src_tokens"],
                src_lengths=net_input["src_lengths"],
            )
        else:
            return self.forward_non_torchscript(net_input)
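
    # Sketch of such an override for an encoder whose forward() takes one
    # extra argument (hypothetical example; `token_embeddings` is assumed,
    # not defined by this base class):
    #
    #     def forward_torchscript(self, net_input: Dict[str, Tensor]):
    #         return self.forward(
    #             src_tokens=net_input["src_tokens"],
    #             src_lengths=net_input["src_lengths"],
    #             token_embeddings=net_input["token_embeddings"],
    #         )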

    @torch.jit.unused
    def forward_non_torchscript(self, net_input: Dict[str, Tensor]):
        encoder_input = {
            k: v for k, v in net_input.items() if k != "prev_output_tokens"
        }
        return self.forward(**encoder_input)

    def reorder_encoder_out(self, encoder_out, new_order):
        """
        Reorder encoder output according to `new_order`.

        Args:
            encoder_out: output from the ``forward()`` method
            new_order (LongTensor): desired order

        Returns:
            `encoder_out` rearranged according to `new_order`
        """
        raise NotImplementedError
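
    # Typical implementation sketch (an assumption based on the EncoderOut
    # layout above, not a verbatim fairseq implementation): beam search
    # gathers each batch-indexed field along its batch dimension, e.g.
    #
    #     def reorder_encoder_out(self, encoder_out, new_order):
    #         return encoder_out._replace(
    #             encoder_out=encoder_out.encoder_out.index_select(1, new_order),
    #         )
    #
    # Dim 1 is used because encoder_out is T x B x C; a B x T field such as
    # encoder_padding_mask would use dim 0, after checking it is not None.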

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        return 1e6  # an arbitrary large number

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade old state dicts to work with newer code."""
        return state_dict

    def set_num_updates(self, num_updates):
        """State from trainer to pass along to model at every update."""

        def _apply(m):
            if hasattr(m, "set_num_updates") and m != self:
                m.set_num_updates(num_updates)

        self.apply(_apply)
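

# Minimal illustrative subclass (not part of fairseq): a sketch of how the
# abstract hooks above are typically filled in. `MyEncoder`, its vocabulary
# and embedding sizes, and the `__main__` demo are assumptions for
# illustration only.
if __name__ == "__main__":

    class MyEncoder(FairseqEncoder):
        def __init__(self, dictionary, embed_dim=8):
            super().__init__(dictionary)
            self.embed = nn.Embedding(16, embed_dim)

        def forward(self, src_tokens, src_lengths=None, **kwargs):
            # B x T x C -> T x B x C, matching the EncoderOut convention
            x = self.embed(src_tokens).transpose(0, 1)
            return EncoderOut(
                encoder_out=x,
                encoder_padding_mask=None,
                encoder_embedding=None,
                encoder_states=None,
                src_tokens=None,
                src_lengths=None,
            )

        def reorder_encoder_out(self, encoder_out, new_order):
            # Gather along dim 1, the batch dimension of T x B x C
            return encoder_out._replace(
                encoder_out=encoder_out.encoder_out.index_select(1, new_order)
            )

    enc = MyEncoder(dictionary=None)
    out = enc(torch.zeros(2, 7, dtype=torch.long))
    print(out.encoder_out.shape)  # torch.Size([7, 2, 8])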