mirror of
https://github.com/SillyTavern/SillyTavern-Extras.git
synced 2026-03-10 14:00:13 +00:00
52 lines
1.7 KiB
Python
52 lines
1.7 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn.functional as F

from fairseq.models.roberta.hub_interface import RobertaHubInterface

class XMODHubInterface(RobertaHubInterface):
    """Hub interface for XMOD models.

    Extends the RoBERTa hub interface by threading an optional ``lang_id``
    through feature extraction and classification-head prediction.
    """

    def extract_features(
        self,
        tokens: torch.LongTensor,
        return_all_hiddens: bool = False,
        lang_id=None,
    ) -> torch.Tensor:
        """Run the encoder and return token features.

        Args:
            tokens: token ids, either 1-D (single sequence) or 2-D (batch).
            return_all_hiddens: if True, return the per-layer hidden states
                instead of only the last layer's features.
            lang_id: optional language id forwarded to the XMOD model.

        Returns:
            The last layer's features (B x T x C), or — when
            ``return_all_hiddens`` is True — a list of per-layer states,
            each transposed to B x T x C.

        Raises:
            ValueError: if the sequence is longer than the model's
                maximum supported positions.
        """
        # Promote an unbatched 1-D sequence to a batch of one.
        if tokens.dim() == 1:
            tokens = tokens.unsqueeze(0)

        max_len = self.model.max_positions()
        seq_len = tokens.size(-1)
        if seq_len > max_len:
            raise ValueError(
                "tokens exceeds maximum length: {} > {}".format(seq_len, max_len)
            )

        features, extra = self.model(
            tokens.to(device=self.device),
            features_only=True,
            return_all_hiddens=return_all_hiddens,
            lang_id=lang_id,
        )

        if not return_all_hiddens:
            # Only the last layer's features are needed.
            return features

        # Each inner state comes back as T x B x C; flip to B x T x C.
        return [state.transpose(0, 1) for state in extra["inner_states"]]

    def predict(
        self,
        head: str,
        tokens: torch.LongTensor,
        return_logits: bool = False,
        lang_id=None,
    ):
        """Score ``tokens`` with a registered classification head.

        Args:
            head: name of the classification head to apply.
            tokens: token ids, 1-D or 2-D.
            return_logits: if True, return raw logits rather than
                log-probabilities.
            lang_id: optional language id forwarded to the XMOD model.

        Returns:
            Raw logits when ``return_logits`` is True, otherwise
            log-softmax probabilities over the head's classes.
        """
        features = self.extract_features(
            tokens.to(device=self.device), lang_id=lang_id
        )
        logits = self.model.classification_heads[head](features)
        return logits if return_logits else F.log_softmax(logits, dim=-1)