# Mirror of https://github.com/SillyTavern/SillyTavern-Extras.git
# (synced 2026-03-03 02:20:02 +00:00)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import sys

import torch

from fairseq import utils
class SequenceScorer(object):
    """Scores the target for a given source sentence.

    Unlike a sequence generator, this does not search for hypotheses: it
    force-decodes ``sample["target"]`` through one or more models and returns
    the per-token and average log-probabilities of that target.
    """

    def __init__(
        self,
        tgt_dict,
        softmax_batch=None,
        compute_alignment=False,
        eos=None,
        symbols_to_strip_from_output=None,
    ):
        """
        Args:
            tgt_dict: target dictionary; must provide ``pad()`` and ``eos()``.
            softmax_batch (int, optional): max number of (batch * time) tokens
                to normalize in one softmax call; ``None``/0 means no limit.
            compute_alignment (bool): if True, extract a hard alignment from
                the averaged attention for each hypothesis.
            eos (int, optional): override for the EOS symbol index.
            symbols_to_strip_from_output (set, optional): extra symbol indices
                to strip when converting output to strings; EOS is always
                included.
        """
        self.pad = tgt_dict.pad()
        self.eos = tgt_dict.eos() if eos is None else eos
        # Falsy softmax_batch (None or 0) disables chunking entirely.
        self.softmax_batch = softmax_batch or sys.maxsize
        assert self.softmax_batch > 0
        self.compute_alignment = compute_alignment
        self.symbols_to_strip_from_output = (
            symbols_to_strip_from_output.union({self.eos})
            if symbols_to_strip_from_output is not None
            else {self.eos}
        )

    @torch.no_grad()
    def generate(self, models, sample, **kwargs):
        """Score a batch of translations.

        Args:
            models: iterable of models forming an ensemble; each is called
                with ``sample["net_input"]`` and must expose
                ``get_normalized_probs``.
            sample (dict): batch with ``"net_input"`` and ``"target"``
                (and optionally ``"start_indices"``) keys.

        Returns:
            list: one single-element list per sentence, containing a dict
            with ``tokens``, ``score``, ``attention``, ``alignment`` and
            ``positional_scores``.
        """
        net_input = sample["net_input"]

        def batch_for_softmax(dec_out, target):
            # Yield (decoder_out, target, is_single) chunks small enough to
            # normalize at once, to bound peak softmax memory.
            # assumes decoder_out[0] is the only thing needed (may not be correct for future models!)
            first, rest = dec_out[0], dec_out[1:]
            bsz, tsz, dim = first.shape
            if bsz * tsz < self.softmax_batch:
                yield dec_out, target, True
            else:
                # Flatten batch and time so we can slice fixed-size chunks.
                flat = first.contiguous().view(1, -1, dim)
                flat_tgt = target.contiguous().view(flat.shape[:-1])
                s = 0
                while s < flat.size(1):
                    e = s + self.softmax_batch
                    yield (flat[:, s:e],) + rest, flat_tgt[:, s:e], False
                    s = e

        def gather_target_probs(probs, target):
            # Pick, at each position, the probability of the gold token.
            probs = probs.gather(
                dim=2,
                index=target.unsqueeze(-1),
            )
            return probs

        orig_target = sample["target"]

        # compute scores for each model in the ensemble
        avg_probs = None
        avg_attn = None
        for model in models:
            model.eval()
            decoder_out = model(**net_input)
            attn = decoder_out[1] if len(decoder_out) > 1 else None
            if type(attn) is dict:
                attn = attn.get("attn", None)

            batched = batch_for_softmax(decoder_out, orig_target)
            probs, idx = None, 0
            for bd, tgt, is_single in batched:
                # get_normalized_probs may read sample["target"], so swap in
                # the (possibly chunked) target for the duration of the call.
                sample["target"] = tgt
                # For a single model we can work in log space directly; for an
                # ensemble we average raw probabilities and take log at the end.
                curr_prob = model.get_normalized_probs(
                    bd, log_probs=len(models) == 1, sample=sample
                ).data
                if is_single:
                    probs = gather_target_probs(curr_prob, orig_target)
                else:
                    if probs is None:
                        # Flat buffer covering every target token; chunks are
                        # written into [idx:end) slices as they arrive.
                        probs = curr_prob.new(orig_target.numel())
                    step = curr_prob.size(0) * curr_prob.size(1)
                    end = step + idx
                    tgt_probs = gather_target_probs(
                        curr_prob.view(tgt.shape + (curr_prob.size(-1),)), tgt
                    )
                    probs[idx:end] = tgt_probs.view(-1)
                    idx = end
            # Restore the caller's sample to its original state.
            sample["target"] = orig_target

            probs = probs.view(sample["target"].shape)

            if avg_probs is None:
                avg_probs = probs
            else:
                avg_probs.add_(probs)
            if attn is not None:
                if torch.is_tensor(attn):
                    attn = attn.data
                else:
                    attn = attn[0]
                if avg_attn is None:
                    avg_attn = attn
                else:
                    avg_attn.add_(attn)
        if len(models) > 1:
            # Ensemble: convert averaged probabilities back to log space.
            avg_probs.div_(len(models))
            avg_probs.log_()
            if avg_attn is not None:
                avg_attn.div_(len(models))

        bsz = avg_probs.size(0)
        hypos = []
        start_idxs = sample["start_indices"] if "start_indices" in sample else [0] * bsz
        for i in range(bsz):
            # remove padding from ref
            ref = (
                utils.strip_pad(sample["target"][i, start_idxs[i] :], self.pad)
                if sample["target"] is not None
                else None
            )
            # NOTE(review): if sample["target"] is None, ref is None and the
            # next line raises AttributeError — callers appear to always
            # provide a target; confirm before relying on the None branch.
            tgt_len = ref.numel()
            avg_probs_i = avg_probs[i][start_idxs[i] : start_idxs[i] + tgt_len]
            # Length-normalized sequence score (average token log-prob).
            score_i = avg_probs_i.sum() / tgt_len
            if avg_attn is not None:
                avg_attn_i = avg_attn[i]
                if self.compute_alignment:
                    alignment = utils.extract_hard_alignment(
                        avg_attn_i,
                        sample["net_input"]["src_tokens"][i],
                        sample["target"][i],
                        self.pad,
                        self.eos,
                    )
                else:
                    alignment = None
            else:
                avg_attn_i = alignment = None
            hypos.append(
                [
                    {
                        "tokens": ref,
                        "score": score_i,
                        "attention": avg_attn_i,
                        "alignment": alignment,
                        "positional_scores": avg_probs_i,
                    }
                ]
            )
        return hypos