# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Base classes for various fairseq models.
"""

import logging
from argparse import Namespace
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import utils
from fairseq.data import Dictionary
from fairseq.dataclass.utils import (
    convert_namespace_to_omegaconf,
    gen_parser_from_dataclass,
)
from fairseq.models import FairseqDecoder, FairseqEncoder
from omegaconf import DictConfig
from torch import Tensor


logger = logging.getLogger(__name__)


def check_type(module, expected_type):
    if hasattr(module, "unwrapped_module"):
        assert isinstance(
            module.unwrapped_module, expected_type
        ), f"{type(module.unwrapped_module)} != {expected_type}"
    else:
        assert isinstance(module, expected_type), f"{type(module)} != {expected_type}"
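

# A minimal sketch of why check_type() looks through "unwrapped_module"
# (hypothetical wrapper; distributed/sharded wrappers can expose the inner
# module this way so type checks still pass):
#
#     class Wrapper(nn.Module):
#         def __init__(self, module):
#             super().__init__()
#             self.module = module
#
#         @property
#         def unwrapped_module(self):
#             return self.module
#
#     check_type(Wrapper(my_encoder), FairseqEncoder)  # checks my_encoder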


class BaseFairseqModel(nn.Module):
    """Base class for fairseq models."""

    def __init__(self):
        super().__init__()
        self._is_generation_fast = False

    @classmethod
    def add_args(cls, parser):
        """Add model-specific arguments to the parser."""
        dc = getattr(cls, "__dataclass", None)
        if dc is not None:
            # do not set defaults so that setting defaults from various architectures still works
            gen_parser_from_dataclass(parser, dc(), delete_default=True)

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""
        raise NotImplementedError("Model must implement the build_model method")

    def get_targets(self, sample, net_output):
        """Get targets from either the sample or the net's output."""
        return sample["target"]

    def get_normalized_probs(
        self,
        net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
        log_probs: bool,
        sample: Optional[Dict[str, Tensor]] = None,
    ):
        """Get normalized probabilities (or log probs) from a net's output."""
        return self.get_normalized_probs_scriptable(net_output, log_probs, sample)

    # TorchScript doesn't support super() calls, so a scriptable subclass
    # can't access this base class method directly. The current workaround
    # is a helper function with a different name that scriptable subclasses
    # call instead.
    def get_normalized_probs_scriptable(
        self,
        net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
        log_probs: bool,
        sample: Optional[Dict[str, Tensor]] = None,
    ):
        """Scriptable helper function for get_normalized_probs in ~BaseFairseqModel"""
        if hasattr(self, "decoder"):
            return self.decoder.get_normalized_probs(net_output, log_probs, sample)
        elif torch.is_tensor(net_output):
            # syntactic sugar for simple models which don't have a decoder
            # (e.g., the classification tutorial)
            logits = net_output.float()
            if log_probs:
                return F.log_softmax(logits, dim=-1)
            else:
                return F.softmax(logits, dim=-1)
        raise NotImplementedError
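
    # A minimal usage sketch, assuming a simple model whose net_output is a
    # plain logits tensor (the "classification tutorial" case handled above;
    # `model` and the shapes are hypothetical):
    #
    #     logits = torch.randn(8, 5)  # (batch, num_classes)
    #     probs = model.get_normalized_probs(logits, log_probs=False)
    #     probs.sum(dim=-1)           # ~1.0 per row after softmax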

    def extract_features(self, *args, **kwargs):
        """Similar to *forward* but only return features."""
        return self(*args, **kwargs)

    def max_positions(self):
        """Maximum length supported by the model."""
        return None

    def load_state_dict(
        self,
        state_dict,
        strict=True,
        model_cfg: Optional[DictConfig] = None,
        args: Optional[Namespace] = None,
    ):
        """Copies parameters and buffers from *state_dict* into this module and
        its descendants.

        Overrides the method in :class:`nn.Module`. Compared with that method
        this additionally "upgrades" *state_dicts* from old checkpoints.
        """

        if model_cfg is None and args is not None:
            logger.warning(
                "using 'args' is deprecated, please update your code to use dataclass config"
            )
            model_cfg = convert_namespace_to_omegaconf(args).model

        self.upgrade_state_dict(state_dict)

        from fairseq.checkpoint_utils import prune_state_dict

        new_state_dict = prune_state_dict(state_dict, model_cfg)
        return super().load_state_dict(new_state_dict, strict)
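
    # A hedged usage sketch (hypothetical paths; "model" is the key fairseq
    # checkpoints use for model weights): the override first upgrades
    # old-style keys in place, then prunes and delegates to
    # nn.Module.load_state_dict():
    #
    #     state = torch.load("checkpoint_last.pt", map_location="cpu")
    #     model.load_state_dict(state["model"], model_cfg=cfg.model)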

    def upgrade_state_dict(self, state_dict):
        """Upgrade old state dicts to work with newer code."""
        self.upgrade_state_dict_named(state_dict, "")

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade old state dicts to work with newer code.

        Args:
            state_dict (dict): state dictionary to upgrade, in place
            name (str): the state dict key corresponding to the current module
        """
        assert state_dict is not None

        def do_upgrade(m, prefix):
            if len(prefix) > 0:
                prefix += "."

            for n, c in m.named_children():
                name = prefix + n
                if hasattr(c, "upgrade_state_dict_named"):
                    c.upgrade_state_dict_named(state_dict, name)
                elif hasattr(c, "upgrade_state_dict"):
                    c.upgrade_state_dict(state_dict)
                do_upgrade(c, name)

        do_upgrade(self, name)
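
    # A minimal sketch of the hook a child module can implement (key names
    # are hypothetical): do_upgrade() above passes each child its dotted key
    # prefix, so a module can rename stale keys from old checkpoints:
    #
    #     class MyLayer(nn.Module):
    #         def upgrade_state_dict_named(self, state_dict, name):
    #             old, new = name + ".gamma", name + ".weight"
    #             if old in state_dict:
    #                 state_dict[new] = state_dict.pop(old)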

    def set_num_updates(self, num_updates):
        """State from trainer to pass along to model at every update."""
        for m in self.modules():
            if hasattr(m, "set_num_updates") and m != self:
                m.set_num_updates(num_updates)

    def prepare_for_inference_(self, cfg: DictConfig):
        """Prepare model for inference."""
        kwargs = {}
        kwargs["beamable_mm_beam_size"] = (
            None
            if getattr(cfg.generation, "no_beamable_mm", False)
            else getattr(cfg.generation, "beam", 5)
        )
        kwargs["need_attn"] = getattr(cfg.generation, "print_alignment", False)
        if getattr(cfg.generation, "retain_dropout", False):
            kwargs["retain_dropout"] = cfg.generation.retain_dropout
            kwargs["retain_dropout_modules"] = cfg.generation.retain_dropout_modules
        self.make_generation_fast_(**kwargs)

    def make_generation_fast_(self, **kwargs):
        """
        Legacy entry point to optimize model for faster generation.
        Prefer prepare_for_inference_.
        """
        if self._is_generation_fast:
            return  # only apply once
        self._is_generation_fast = True

        # remove weight norm from all modules in the network
        def apply_remove_weight_norm(module):
            try:
                nn.utils.remove_weight_norm(module)
            except (AttributeError, ValueError):  # this module didn't have weight norm
                return

        self.apply(apply_remove_weight_norm)

        def apply_make_generation_fast_(module, prefix):
            if len(prefix) > 0:
                prefix += "."

            base_func = BaseFairseqModel.make_generation_fast_
            for n, m in module.named_modules():
                if (
                    m != self
                    and hasattr(m, "make_generation_fast_")
                    # don't call this implementation again, e.g., if
                    # children modules also inherit from BaseFairseqModel
                    and m.make_generation_fast_.__func__ is not base_func
                ):
                    name = prefix + n
                    m.make_generation_fast_(name=name, **kwargs)

        apply_make_generation_fast_(self, "")

        def train(mode=True):
            if mode:
                raise RuntimeError("cannot train after make_generation_fast")

        # this model should no longer be used for training
        self.eval()
        self.train = train
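
    # A hedged usage sketch: once the generation-time optimizations are
    # applied, the model is frozen for inference (cfg is a fairseq DictConfig
    # with a `generation` section, per prepare_for_inference_ above):
    #
    #     model.prepare_for_inference_(cfg)  # calls make_generation_fast_()
    #     model.train()                      # raises RuntimeError
    #     model.train(mode=False)            # no-op; model is already eval'd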

    def prepare_for_onnx_export_(self, **kwargs):
        """Make model exportable via ONNX trace."""
        seen = set()

        def apply_prepare_for_onnx_export_(module):
            if (
                module != self
                and hasattr(module, "prepare_for_onnx_export_")
                and module not in seen
            ):
                seen.add(module)
                module.prepare_for_onnx_export_(**kwargs)

        self.apply(apply_prepare_for_onnx_export_)

    @classmethod
    def from_pretrained(
        cls,
        model_name_or_path,
        checkpoint_file="model.pt",
        data_name_or_path=".",
        **kwargs,
    ):
        """
        Load a :class:`~fairseq.models.FairseqModel` from a pre-trained model
        file. Downloads and caches the pre-trained model file if needed.

        The base implementation returns a
        :class:`~fairseq.hub_utils.GeneratorHubInterface`, which can be used to
        generate translations or sample from language models. The underlying
        :class:`~fairseq.models.FairseqModel` can be accessed via the
        *generator.models* attribute.

        Other models may override this to implement custom hub interfaces.

        Args:
            model_name_or_path (str): either the name of a pre-trained model to
                load or a path/URL to a pre-trained model state dict
            checkpoint_file (str, optional): colon-separated list of checkpoint
                files in the model archive to ensemble (default: 'model.pt')
            data_name_or_path (str, optional): point args.data to the archive
                at the given path/URL. Can start with '.' or './' to reuse the
                model archive path.
        """
        from fairseq import hub_utils

        x = hub_utils.from_pretrained(
            model_name_or_path,
            checkpoint_file,
            data_name_or_path,
            archive_map=cls.hub_models(),
            **kwargs,
        )
        logger.info(x["args"])
        return hub_utils.GeneratorHubInterface(x["args"], x["task"], x["models"])

    @classmethod
    def hub_models(cls):
        return {}
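
    # A hedged usage sketch (paths and data dir are hypothetical; any
    # registered model class exposes this entry point):
    #
    #     from fairseq.models.transformer import TransformerModel
    #     en2de = TransformerModel.from_pretrained(
    #         "/path/to/checkpoints",
    #         checkpoint_file="checkpoint_best.pt",
    #         data_name_or_path="data-bin/wmt16_en_de",
    #     )
    #     print(en2de.translate("Hello world!"))  # GeneratorHubInterface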


class FairseqEncoderDecoderModel(BaseFairseqModel):
    """Base class for encoder-decoder models.

    Args:
        encoder (FairseqEncoder): the encoder
        decoder (FairseqDecoder): the decoder
    """

    def __init__(self, encoder, decoder):
        super().__init__()

        self.encoder = encoder
        self.decoder = decoder

        check_type(self.encoder, FairseqEncoder)
        check_type(self.decoder, FairseqDecoder)

    def forward(self, src_tokens, src_lengths, prev_output_tokens, **kwargs):
        """
        Run the forward pass for an encoder-decoder model.

        First feed a batch of source tokens through the encoder. Then, feed the
        encoder output and previous decoder outputs (i.e., teacher forcing) to
        the decoder to produce the next outputs::

            encoder_out = self.encoder(src_tokens, src_lengths)
            return self.decoder(prev_output_tokens, encoder_out)

        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            src_lengths (LongTensor): source sentence lengths of shape `(batch)`
            prev_output_tokens (LongTensor): previous decoder outputs of shape
                `(batch, tgt_len)`, for teacher forcing

        Returns:
            tuple:
                - the decoder's output of shape `(batch, tgt_len, vocab)`
                - a dictionary with any model-specific outputs
        """
        encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
        decoder_out = self.decoder(
            prev_output_tokens, encoder_out=encoder_out, **kwargs
        )
        return decoder_out

    def forward_decoder(self, prev_output_tokens, **kwargs):
        return self.decoder(prev_output_tokens, **kwargs)

    def extract_features(self, src_tokens, src_lengths, prev_output_tokens, **kwargs):
        """
        Similar to *forward* but only return features.

        Returns:
            tuple:
                - the decoder's features of shape `(batch, tgt_len, embed_dim)`
                - a dictionary with any model-specific outputs
        """
        encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
        features = self.decoder.extract_features(
            prev_output_tokens, encoder_out=encoder_out, **kwargs
        )
        return features

    def output_layer(self, features, **kwargs):
        """Project features to the default output size (typically vocabulary size)."""
        return self.decoder.output_layer(features, **kwargs)

    def max_positions(self):
        """Maximum length supported by the model."""
        return (self.encoder.max_positions(), self.decoder.max_positions())

    def max_decoder_positions(self):
        """Maximum length supported by the decoder."""
        return self.decoder.max_positions()


class FairseqModel(FairseqEncoderDecoderModel):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        utils.deprecation_warning(
            "FairseqModel is deprecated, please use FairseqEncoderDecoderModel "
            "or BaseFairseqModel instead",
            stacklevel=4,
        )


class FairseqMultiModel(BaseFairseqModel):
    """Base class for combining multiple encoder-decoder models."""

    def __init__(self, encoders, decoders):
        super().__init__()
        assert encoders.keys() == decoders.keys()
        self.keys = list(encoders.keys())
        for key in self.keys:
            check_type(encoders[key], FairseqEncoder)
            check_type(decoders[key], FairseqDecoder)

        self.models = nn.ModuleDict(
            {
                key: FairseqEncoderDecoderModel(encoders[key], decoders[key])
                for key in self.keys
            }
        )

    @staticmethod
    def build_shared_embeddings(
        dicts: Dict[str, Dictionary],
        langs: List[str],
        embed_dim: int,
        build_embedding: callable,
        pretrained_embed_path: Optional[str] = None,
    ):
        """
        Helper function to build shared embeddings for a set of languages after
        checking that all dicts corresponding to those languages are equivalent.

        Args:
            dicts: Dict of lang_id to its corresponding Dictionary
            langs: languages that we want to share embeddings for
            embed_dim: embedding dimension
            build_embedding: callable function to actually build the embedding
            pretrained_embed_path: Optional path to load pretrained embeddings
        """
        shared_dict = dicts[langs[0]]
        if any(dicts[lang] != shared_dict for lang in langs):
            raise ValueError(
                "--share-*-embeddings requires a joined dictionary: "
                "--share-encoder-embeddings requires a joined source "
                "dictionary, --share-decoder-embeddings requires a joined "
                "target dictionary, and --share-all-embeddings requires a "
                "joint source + target dictionary."
            )
        return build_embedding(shared_dict, embed_dim, pretrained_embed_path)
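
    # A minimal sketch of the expected callable (hypothetical builder): it
    # receives the shared Dictionary and returns one nn.Embedding reused by
    # every listed language's encoder/decoder:
    #
    #     def build_embedding(dictionary, embed_dim, path=None):
    #         return nn.Embedding(
    #             len(dictionary), embed_dim, padding_idx=dictionary.pad()
    #         )
    #
    #     shared = FairseqMultiModel.build_shared_embeddings(
    #         dicts, langs=["en", "de"], embed_dim=512,
    #         build_embedding=build_embedding,
    #     )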

    def forward(self, src_tokens, src_lengths, prev_output_tokens, **kwargs):
        raise NotImplementedError

    def max_positions(self):
        """Maximum length supported by the model."""
        return {
            key: (
                self.models[key].encoder.max_positions(),
                self.models[key].decoder.max_positions(),
            )
            for key in self.keys
        }

    def max_decoder_positions(self):
        """Maximum length supported by the decoder."""
        return min(model.decoder.max_positions() for model in self.models.values())

    @property
    def encoder(self):
        return self.models[self.keys[0]].encoder

    @property
    def decoder(self):
        return self.models[self.keys[0]].decoder

    def forward_decoder(self, prev_output_tokens, **kwargs):
        return self.decoder(prev_output_tokens, **kwargs)

    def load_state_dict(
        self,
        state_dict,
        strict=True,
        model_cfg=None,
        args: Optional[Namespace] = None,
    ):
        """Copies parameters and buffers from *state_dict* into this module and
        its descendants.

        Overrides the method in :class:`nn.Module`. Compared with that method
        this additionally "upgrades" *state_dicts* from old checkpoints.
        """

        if model_cfg is None and args is not None:
            logger.warning(
                "using 'args' is deprecated, please update your code to use dataclass config"
            )
            model_cfg = convert_namespace_to_omegaconf(args).model

        self.upgrade_state_dict(state_dict)

        from fairseq.checkpoint_utils import prune_state_dict

        new_state_dict = prune_state_dict(state_dict, model_cfg)
        return super().load_state_dict(new_state_dict, strict)


class FairseqLanguageModel(BaseFairseqModel):
    """Base class for decoder-only models.

    Args:
        decoder (FairseqDecoder): the decoder
    """

    def __init__(self, decoder):
        super().__init__()
        self.decoder = decoder
        check_type(self.decoder, FairseqDecoder)

    def forward(self, src_tokens, **kwargs):
        """
        Run the forward pass for a decoder-only model.

        Feeds a batch of tokens through the decoder to predict the next tokens.

        Args:
            src_tokens (LongTensor): tokens on which to condition the decoder,
                of shape `(batch, tgt_len)`
            src_lengths (LongTensor): source sentence lengths of shape `(batch)`

        Returns:
            tuple:
                - the decoder's output of shape `(batch, seq_len, vocab)`
                - a dictionary with any model-specific outputs
        """
        return self.decoder(src_tokens, **kwargs)

    def forward_decoder(self, prev_output_tokens, **kwargs):
        return self.decoder(prev_output_tokens, **kwargs)

    def extract_features(self, src_tokens, **kwargs):
        """
        Similar to *forward* but only return features.

        Returns:
            tuple:
                - the decoder's features of shape `(batch, seq_len, embed_dim)`
                - a dictionary with any model-specific outputs
        """
        return self.decoder.extract_features(src_tokens, **kwargs)

    def output_layer(self, features, **kwargs):
        """Project features to the default output size (typically vocabulary size)."""
        return self.decoder.output_layer(features, **kwargs)

    def max_positions(self):
        """Maximum length supported by the model."""
        return self.decoder.max_positions()

    def max_decoder_positions(self):
        """Maximum length supported by the decoder."""
        return self.decoder.max_positions()

    @property
    def supported_targets(self):
        return {"future"}


class FairseqEncoderModel(BaseFairseqModel):
    """Base class for encoder-only models.

    Args:
        encoder (FairseqEncoder): the encoder
    """

    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder
        check_type(self.encoder, FairseqEncoder)

    def forward(self, src_tokens, src_lengths, **kwargs):
        """
        Run the forward pass for an encoder-only model.

        Feeds a batch of tokens through the encoder to generate features.

        Args:
            src_tokens (LongTensor): input tokens of shape `(batch, src_len)`
            src_lengths (LongTensor): source sentence lengths of shape `(batch)`

        Returns:
            the encoder's output, typically of shape `(batch, src_len, features)`
        """
        return self.encoder(src_tokens, src_lengths, **kwargs)

    def get_normalized_probs(self, net_output, log_probs, sample=None):
        """Get normalized probabilities (or log probs) from a net's output."""
        encoder_out = net_output["encoder_out"]
        if torch.is_tensor(encoder_out):
            logits = encoder_out.float()
            if log_probs:
                return F.log_softmax(logits, dim=-1)
            else:
                return F.softmax(logits, dim=-1)
        raise NotImplementedError

    def max_positions(self):
        """Maximum length supported by the model."""
        return self.encoder.max_positions()