mirror of
https://github.com/SillyTavern/SillyTavern-Extras.git
synced 2026-02-23 22:54:12 +00:00
58 lines
1.9 KiB
Python
58 lines
1.9 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
#
|
|
# This source code is licensed under the MIT license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
from .fairseq_encoder import FairseqEncoder
|
|
|
|
|
|
class CompositeEncoder(FairseqEncoder):
    """
    Bundle several named :class:`FairseqEncoder` objects behind one encoder.

    Calling the composite runs every wrapped encoder on the same input and
    gathers the results in a dictionary keyed by encoder name. The dictionary
    of the first encoder (in iteration order) is the one handed to the
    base-class constructor.

    Args:
        encoders (dict): a dictionary of :class:`FairseqEncoder` objects.
    """

    def __init__(self, encoders):
        # The base class needs a dictionary; borrow it from the first entry.
        first_encoder = next(iter(encoders.values()))
        super().__init__(first_encoder.dictionary)
        self.encoders = encoders
        # Register every wrapped encoder on this object under its own key.
        for name, encoder in self.encoders.items():
            self.add_module(name, encoder)

    def forward(self, src_tokens, src_lengths):
        """
        Run every wrapped encoder on the same source batch.

        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            src_lengths (LongTensor): lengths of each source sentence of shape
                `(batch)`

        Returns:
            dict:
                maps each encoder's key to the output that encoder produced
        """
        return {
            name: encoder(src_tokens, src_lengths)
            for name, encoder in self.encoders.items()
        }

    def reorder_encoder_out(self, encoder_out, new_order):
        """Reorder each encoder's entry of *encoder_out* according to *new_order*.

        The entries are replaced in the dictionary that was passed in, and
        that same dictionary is returned.
        """
        for name, encoder in self.encoders.items():
            reordered = encoder.reorder_encoder_out(encoder_out[name], new_order)
            encoder_out[name] = reordered
        return encoder_out

    def max_positions(self):
        """Return the smallest ``max_positions()`` value among the wrapped encoders."""
        limits = [encoder.max_positions() for encoder in self.encoders.values()]
        return min(limits)

    def upgrade_state_dict(self, state_dict):
        """Pass *state_dict* to every wrapped encoder's ``upgrade_state_dict``.

        The values returned by the wrapped encoders are not used; the
        *state_dict* object that was passed in is what gets returned.
        """
        for encoder in self.encoders.values():
            encoder.upgrade_state_dict(state_dict)
        return state_dict
|