# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from collections import namedtuple

import numpy as np
import torch
from fairseq import utils


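# DecoderOut carries the decoder state between refinement steps: the current
# tokens and their scores, optional attention, the step counters, and (when
# history is retained) a list of intermediate outputs.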
DecoderOut = namedtuple(
    "IterativeRefinementDecoderOut",
    ["output_tokens", "output_scores", "attn", "step", "max_step", "history"],
)


class IterativeRefinementGenerator(object):
    def __init__(
        self,
        tgt_dict,
        models=None,
        eos_penalty=0.0,
        max_iter=10,
        max_ratio=2,
        beam_size=1,
        decoding_format=None,
        retain_dropout=False,
        adaptive=True,
        retain_history=False,
        reranking=False,
    ):
        """
        Generates translations based on iterative refinement.

        Args:
            tgt_dict: target dictionary
            eos_penalty: if > 0.0, penalizes early stopping in decoding
            max_iter: maximum number of refinement iterations
            max_ratio: generate sequences of maximum length ax, where x is the source length
            decoding_format: decoding mode in {'unigram', 'ensemble', 'vote', 'dp', 'bs'}
            retain_dropout: keep dropout enabled during inference
            adaptive: decode with early stopping (stop refining a sentence once
                its output no longer changes)
        """
        self.bos = tgt_dict.bos()
        self.pad = tgt_dict.pad()
        self.unk = tgt_dict.unk()
        self.eos = tgt_dict.eos()
        self.vocab_size = len(tgt_dict)
        self.eos_penalty = eos_penalty
        self.max_iter = max_iter
        self.max_ratio = max_ratio
        self.beam_size = beam_size
        self.reranking = reranking
        self.decoding_format = decoding_format
        self.retain_dropout = retain_dropout
        self.retain_history = retain_history
        self.adaptive = adaptive
        self.models = models
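
    # A minimal usage sketch (`task`, `model`, and `sample` are hypothetical
    # objects from a loaded fairseq checkpoint, not part of this file):
    #
    #     generator = IterativeRefinementGenerator(
    #         task.target_dictionary, models=[model], max_iter=10
    #     )
    #     hypos = generator.generate([model], sample)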
    def generate_batched_itr(
        self,
        data_itr,
        maxlen_a=None,
        maxlen_b=None,
        cuda=False,
        timer=None,
        prefix_size=0,
    ):
        """Iterate over a batched dataset and yield individual translations.

        Args:
            maxlen_a/b: generate sequences of maximum length ax + b,
                where x is the source sentence length.
            cuda: use GPU for generation
            timer: StopwatchMeter for timing generations.
        """

        for sample in data_itr:
            if "net_input" not in sample:
                continue
            if timer is not None:
                timer.start()
            with torch.no_grad():
                hypos = self.generate(
                    self.models,
                    sample,
                    prefix_tokens=sample["target"][:, :prefix_size]
                    if prefix_size > 0
                    else None,
                )
            if timer is not None:
                timer.stop(sample["ntokens"])
            for i, id in enumerate(sample["id"]):
                # remove padding
                src = utils.strip_pad(sample["net_input"]["src_tokens"][i, :], self.pad)
                ref = utils.strip_pad(sample["target"][i, :], self.pad)
                yield id, src, ref, hypos[i]

    @torch.no_grad()
    def generate(self, models, sample, prefix_tokens=None, constraints=None):
        if constraints is not None:
            raise NotImplementedError(
                "Constrained decoding with the IterativeRefinementGenerator is not supported"
            )

        # TODO: iterative refinement generator does not support ensemble for now.
        if not self.retain_dropout:
            for model in models:
                model.eval()

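        # By convention, when reranking is enabled the last entry of `models`
        # is an autoregressive reranker; the remaining models do the decoding.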
        model, reranker = models[0], None
        if self.reranking:
            assert len(models) > 1, "Assuming the last checkpoint is the reranker"
            assert (
                self.beam_size > 1
            ), "Reranking requires multiple translations for each example"

            reranker = models[-1]
            models = models[:-1]

        if len(models) > 1 and hasattr(model, "enable_ensemble"):
            assert model.allow_ensemble, "{} does not support ensembling".format(
                model.__class__.__name__
            )
            model.enable_ensemble(models)

        # TODO: better encoder inputs?
        src_tokens = sample["net_input"]["src_tokens"]
        src_lengths = sample["net_input"]["src_lengths"]
        bsz, src_len = src_tokens.size()

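        # `initialize_output_tokens` below typically produces a fully masked
        # target at a length predicted from the encoder output (mask-predict
        # style models).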
        # initialize
        encoder_out = model.forward_encoder([src_tokens, src_lengths])
        prev_decoder_out = model.initialize_output_tokens(encoder_out, src_tokens)

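        # Length-beam decoding: duplicate every sentence beam_size times and
        # give each copy a different initial target length; the best-scoring
        # candidate per sentence is kept at the end of refinement.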
        if self.beam_size > 1:
            assert (
                model.allow_length_beam
            ), "{} does not support decoding with length beam.".format(
                model.__class__.__name__
            )

            # regenerate data based on length-beam
            length_beam_order = (
                utils.new_arange(src_tokens, self.beam_size, bsz).t().reshape(-1)
            )
            encoder_out = model.encoder.reorder_encoder_out(
                encoder_out, length_beam_order
            )
            prev_decoder_out = model.regenerate_length_beam(
                prev_decoder_out, self.beam_size
            )
            bsz = bsz * self.beam_size

        sent_idxs = torch.arange(bsz)
        prev_output_tokens = prev_decoder_out.output_tokens.clone()

        if self.retain_history:
            prev_decoder_out = prev_decoder_out._replace(history=[prev_output_tokens])

        finalized = [[] for _ in range(bsz)]

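        # Pads the shorter of the previous/current outputs with PAD so the two
        # can be compared elementwise; a sentence whose output did not change
        # between steps has reached a fixed point and can stop refining.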
        def is_a_loop(x, y, s, a):
            b, l_x, l_y = x.size(0), x.size(1), y.size(1)
            if l_x > l_y:
                y = torch.cat([y, x.new_zeros(b, l_x - l_y).fill_(self.pad)], 1)
                s = torch.cat([s, s.new_zeros(b, l_x - l_y)], 1)
                if a is not None:
                    a = torch.cat([a, a.new_zeros(b, l_x - l_y, a.size(2))], 1)
            elif l_x < l_y:
                x = torch.cat([x, y.new_zeros(b, l_y - l_x).fill_(self.pad)], 1)
            return (x == y).all(1), y, s, a

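        # Strips padding from a finished hypothesis and packages the tokens,
        # per-token scores, mean score, attention, and a hard alignment
        # (argmax over source positions) into the output dictionary.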
        def finalized_hypos(step, prev_out_token, prev_out_score, prev_out_attn):
            cutoff = prev_out_token.ne(self.pad)
            tokens = prev_out_token[cutoff]
            if prev_out_score is None:
                scores, score = None, None
            else:
                scores = prev_out_score[cutoff]
                score = scores.mean()

            if prev_out_attn is None:
                hypo_attn, alignment = None, None
            else:
                hypo_attn = prev_out_attn[cutoff]
                alignment = hypo_attn.max(dim=1)[1]
            return {
                "steps": step,
                "tokens": tokens,
                "positional_scores": scores,
                "score": score,
                "hypo_attn": hypo_attn,
                "alignment": alignment,
            }

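        # Main refinement loop: decode one step, detect terminated sentences,
        # move them into `finalized`, then shrink the batch to the sentences
        # that are still being refined.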
        for step in range(self.max_iter + 1):

            decoder_options = {
                "eos_penalty": self.eos_penalty,
                "max_ratio": self.max_ratio,
                "decoding_format": self.decoding_format,
            }
            prev_decoder_out = prev_decoder_out._replace(
                step=step,
                max_step=self.max_iter + 1,
            )

            decoder_out = model.forward_decoder(
                prev_decoder_out, encoder_out, **decoder_options
            )

            if self.adaptive:
                # terminate if there is a loop
                terminated, out_tokens, out_scores, out_attn = is_a_loop(
                    prev_output_tokens,
                    decoder_out.output_tokens,
                    decoder_out.output_scores,
                    decoder_out.attn,
                )
                decoder_out = decoder_out._replace(
                    output_tokens=out_tokens,
                    output_scores=out_scores,
                    attn=out_attn,
                )

            else:
                terminated = decoder_out.output_tokens.new_zeros(
                    decoder_out.output_tokens.size(0)
                ).bool()

            if step == self.max_iter:  # reached the last iteration, terminate
                terminated.fill_(1)

            # collect finalized sentences
            finalized_idxs = sent_idxs[terminated]
            finalized_tokens = decoder_out.output_tokens[terminated]
            finalized_scores = decoder_out.output_scores[terminated]
            finalized_attn = (
                None
                if (decoder_out.attn is None or decoder_out.attn.size(0) == 0)
                else decoder_out.attn[terminated]
            )

            if self.retain_history:
                finalized_history_tokens = [h[terminated] for h in decoder_out.history]

            for i in range(finalized_idxs.size(0)):
                finalized[finalized_idxs[i]] = [
                    finalized_hypos(
                        step,
                        finalized_tokens[i],
                        finalized_scores[i],
                        None if finalized_attn is None else finalized_attn[i],
                    )
                ]

                if self.retain_history:
                    finalized[finalized_idxs[i]][0]["history"] = []
                    for j in range(len(finalized_history_tokens)):
                        finalized[finalized_idxs[i]][0]["history"].append(
                            finalized_hypos(
                                step, finalized_history_tokens[j][i], None, None
                            )
                        )

            # check if all terminated
            if terminated.sum() == terminated.size(0):
                break

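            # Below, finished sentences drop out of the working batch and the
            # encoder output is reordered to match the survivors.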
            # for next step
            not_terminated = ~terminated
            prev_decoder_out = decoder_out._replace(
                output_tokens=decoder_out.output_tokens[not_terminated],
                output_scores=decoder_out.output_scores[not_terminated],
                attn=decoder_out.attn[not_terminated]
                if (decoder_out.attn is not None and decoder_out.attn.size(0) > 0)
                else None,
                history=[h[not_terminated] for h in decoder_out.history]
                if decoder_out.history is not None
                else None,
            )
            encoder_out = model.encoder.reorder_encoder_out(
                encoder_out, not_terminated.nonzero(as_tuple=False).squeeze()
            )
            sent_idxs = sent_idxs[not_terminated]
            prev_output_tokens = prev_decoder_out.output_tokens.clone()

        if self.beam_size > 1:
            if reranker is not None:
                finalized = self.rerank(
                    reranker, finalized, [src_tokens, src_lengths], self.beam_size
                )

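            # Hypotheses for sentence i occupy slots
            # [beam_size * i, beam_size * (i + 1)); keep only the
            # highest-scoring length-beam candidate for each sentence.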
            # aggregate information from length beam
            finalized = [
                finalized[
                    np.argmax(
                        [
                            finalized[self.beam_size * i + j][0]["score"]
                            for j in range(self.beam_size)
                        ]
                    )
                    + self.beam_size * i
                ]
                for i in range(len(finalized) // self.beam_size)
            ]

        return finalized

    def rerank(self, reranker, finalized, encoder_input, beam_size):
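        # Pack the top hypothesis of every sentence into one right-padded
        # batch so the autoregressive reranker can score them all in a single
        # forward pass.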
        def rebuild_batch(finalized):
            finalized_tokens = [f[0]["tokens"] for f in finalized]
            finalized_maxlen = max(f.size(0) for f in finalized_tokens)
            final_output_tokens = (
                finalized_tokens[0]
                .new_zeros(len(finalized_tokens), finalized_maxlen)
                .fill_(self.pad)
            )
            for i, f in enumerate(finalized_tokens):
                final_output_tokens[i, : f.size(0)] = f
            return final_output_tokens

        final_output_tokens = rebuild_batch(finalized)
        final_output_tokens[
            :, 0
        ] = self.eos  # autoregressive model assumes starting with EOS

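        # Score each hypothesis by teacher-forcing it through the reranker:
        # gather the log-probability of every target token, zero out PAD
        # positions, and length-normalize the sum.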
        reranker_encoder_out = reranker.encoder(*encoder_input)
        length_beam_order = (
            utils.new_arange(
                final_output_tokens, beam_size, reranker_encoder_out.encoder_out.size(1)
            )
            .t()
            .reshape(-1)
        )
        reranker_encoder_out = reranker.encoder.reorder_encoder_out(
            reranker_encoder_out, length_beam_order
        )
        reranking_scores = reranker.get_normalized_probs(
            reranker.decoder(final_output_tokens[:, :-1], reranker_encoder_out),
            True,
            None,
        )
        reranking_scores = reranking_scores.gather(2, final_output_tokens[:, 1:, None])
        reranking_masks = final_output_tokens[:, 1:].ne(self.pad)
        reranking_scores = (
            reranking_scores[:, :, 0].masked_fill_(~reranking_masks, 0).sum(1)
        )
        reranking_scores = reranking_scores / reranking_masks.sum(1).type_as(
            reranking_scores
        )

        for i in range(len(finalized)):
            finalized[i][0]["score"] = reranking_scores[i]

        return finalized