mirror of https://github.com/SillyTavern/SillyTavern-Extras.git — synced 2026-03-10 05:50:10 +00:00 (507 lines, 16 KiB, Python)
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
#
|
|
# This source code is licensed under the MIT license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
"""Implements tracking of constraints for a beam item.
|
|
|
|
A list of constraints is given as a list of one or more token
|
|
sequences, each of length at least one token. For example, for an input sentence
|
|
|
|
> Die maschinelle Übersetzung ist schwer zu kontrollieren.
|
|
|
|
We could have the constraints:
|
|
* to influence
|
|
* hard
|
|
|
|
There are two implementations:
|
|
* OrderedConstraintState: Tracks progress through an ordered list of multitoken constraints.
|
|
* UnorderedConstraintState: Tracks progress through an unordered list of multitoken constraints.
|
|
|
|
The difference is that in the first, the constraints are assumed to be
|
|
in order; the algorithm will permit zero or more tokens between them.
|
|
In the second, the constraints are not ordered, so many orderings will
|
|
be explored.
|
|
|
|
The same sequence can be present any number of times, and will appear
|
|
that many times in the output.
|
|
"""
|
|
|
|
from collections import Counter
|
|
from typing import List, Optional, Set, Tuple
|
|
|
|
import torch
|
|
|
|
|
|
class ConstraintState:
    """Abstract base for tracking one beam item's progress through its
    constraints.

    Concrete subclasses elsewhere in this file are
    UnorderedConstraintState (trie-based) and OrderedConstraintState
    (linear sequence with gaps).
    """

    def __init__(self):
        pass
|
|
|
|
|
|
def pack_constraints(batch_constraints: List[List[torch.Tensor]]) -> torch.Tensor:
    """Takes a list of list of constraints in tensor form (a list of
    tensor constraints for each sentence) and transforms it into a
    packed Tensor. For example, here is a batch of size 3 with 3, 0,
    and 1 constraints:

        [ [ [3 1 2], [3], [4 5 6 7], ]
          [],
          [ [1 8 9 10 1 4 11 12], ]
        ]

    Its corresponding packed structure is:

        [ [ 3  3  1  2  0  3  0  4  5  6  7  0],
          [ 0  0  0  0  0  0  0  0  0  0  0  0],
          [ 1  1  8  9 10  1  4 11 12  0  0  0] ]

    The packed tensor has shape (batch size, maxlen), where
    maxlen is defined below. Each row contains concatenated
    constraint tokens for that sentence, with 0 appended after
    each constraint. The first item in each row is the number
    of constraints for that sentence. So maxlen is the maximum
    of

        (number of constraints) + (sum length of constraints) + 1.

    across all sentences in the batch.
    """
    # The maximum word length of concatenated constraints for any
    # sentence. Starts at 1 because even a row with no constraints
    # needs the leading count slot.
    max_constraints_len = 1
    for sentence_constraints in batch_constraints:
        if sentence_constraints:
            # count slot, plus sum of constraint lens, plus a zero after each
            constraints_len = (
                1
                + sum(c.size(0) for c in sentence_constraints)
                + len(sentence_constraints)
            )
            max_constraints_len = max(max_constraints_len, constraints_len)

    batch_size = len(batch_constraints)
    # Allocate directly as int64 instead of allocating float zeros and
    # casting afterwards (the original also redundantly re-cast at return).
    constraints_tensor = torch.zeros(
        (batch_size, max_constraints_len), dtype=torch.long
    )
    for i, sentence_constraints in enumerate(batch_constraints):
        constraints_tensor[i, 0] = len(sentence_constraints)
        offset = 1
        for constraint in sentence_constraints:
            this_len = constraint.size(0)
            constraints_tensor[i, offset : offset + this_len] = constraint
            # +1 leaves the 0 separator after this constraint.
            offset += this_len + 1

    return constraints_tensor
|
|
|
|
|
|
def unpack_constraints(constraint_tensor: torch.Tensor) -> List[torch.Tensor]:
    """
    Transforms *one row* of a packed constraint tensor (e.g., for one
    sentence in the batch) into a list of constraint tensors.

    The row layout is the one produced by pack_constraints: the
    constraint count first, then each constraint's tokens followed by a
    single 0 separator.
    """
    constraint_list = []
    # Convert the 0-dim count tensor to a plain int up front rather than
    # relying on range() accepting a tensor.
    num_constraints = int(constraint_tensor[0])
    constraints = constraint_tensor.tolist()
    offset = 1
    for _ in range(num_constraints):
        # Each constraint runs up to its trailing 0 separator.
        where = constraints.index(0, offset)
        constraint_list.append(constraint_tensor[offset:where])
        offset = where + 1

    return constraint_list
|
|
|
|
|
|
class ConstraintNode:
    """
    Represents a node in a trie managing unordered constraints.
    """

    def __init__(
        self, token: Optional[int] = None, parent: Optional["ConstraintNode"] = None
    ):
        # The token associated with this node (None for the root)
        self.token = int(token) if token is not None else None
        # The parent (None at the root)
        self.parent = parent
        # Number of constraints that *end* at this node. A constraint
        # can be added multiple times, so this is a count, not a flag.
        self.terminal = 0
        # Child nodes, keyed by token
        self.children = {}

        # The cumulative number of constraints from this point in the
        # trie forward
        self.num_constraints = 0

    @property
    def id(self):
        return self.token

    def __str__(self):
        term = self.terminal != 0
        return f"[{self.token}].{term}#{self.num_constraints}"

    def __getitem__(self, key: int):
        # Returns the child for `key`, or None if there is no such edge.
        return self.children.get(key, None)

    def next_tokens(self) -> Set[int]:
        """The set of child labels."""
        return set(self.children.keys())

    @staticmethod
    def create(constraints: List[List[int]]):
        """Builds a trie from a list of constraint token sequences and
        returns its root node."""
        root = ConstraintNode()
        for sequence in constraints:
            root.add_sequence(sequence)

        return root

    @staticmethod
    def print_graph(node: "ConstraintNode"):
        """Returns a parenthesized string rendering of the subtrie at *node*."""
        if len(node.children) == 0:
            return str(node)
        else:
            s = f"({node}"
            for child in node.children.values():
                s += " " + ConstraintNode.print_graph(child)
            s += ")"
            return s

    def token_counts(self) -> Counter:
        """Returns a counter of the number of times each token is used
        in a constraint.
        """
        token_counts = Counter()
        # Iterative depth-first walk over all descendants.
        kids = list(self.children.values())
        while len(kids) > 0:
            kid = kids.pop()
            token_counts[kid.id] += kid.num_constraints
            kids += list(kid.children.values())

        return token_counts

    def tokens(self) -> Set[int]:
        """Returns the set of tokens in constraints."""
        return set(self.token_counts().keys())

    def add_sequence(self, sequence: List[int]):
        """Adds a constraint, represented as a list of integers, to
        the trie."""
        assert len(sequence) > 0

        token = int(sequence[0])
        if token not in self.children:
            self.children[token] = ConstraintNode(token, parent=self)

        node = self.children[token]
        if len(sequence) == 1:
            node.terminal += 1
            node.num_constraints += 1
            # Propagate the new constraint's count up to the root.
            parent = node.parent
            while parent is not None:
                parent.num_constraints += 1
                parent = parent.parent
        else:
            node.add_sequence(sequence[1:])
|
|
|
|
|
|
class UnorderedConstraintState(ConstraintState):
    """
    Records progress through the set of constraints for each item in the beam
    using a trie.
    """

    def __init__(self, node: ConstraintNode, copy_from: "ConstraintState" = None):
        # The trie node this state currently points at.
        self.node = node

        if copy_from is None:
            # The root node
            self.root = node
            # The set of states in the graph that have been completed
            self.completed = Counter()
            # Per-node count of traversals made along constraint paths so far
            self.generated = Counter()
            # The list of tokens we need to generate
            # NOTE: only set on the initial state; copies below do not
            # carry it forward.
            self.needed_tokens = self.root.tokens()
        else:
            self.completed = Counter(copy_from.completed)
            self.generated = Counter(copy_from.generated)
            self.root = copy_from.root

        # Mark the node as generated (no-op for the root itself)
        if self.node != self.root:
            self.generated[node] += 1

    @staticmethod
    def create(constraint_tensor: torch.Tensor):
        """Builds an initial (root) state from one row of a packed
        constraint tensor (see pack_constraints)."""
        constraint_list = unpack_constraints(constraint_tensor)
        constraint_trie_root = ConstraintNode.create(constraint_list)
        return UnorderedConstraintState(constraint_trie_root)

    def __str__(self):
        gen_str = ",".join([str(node) for node in self.generated])
        return f"{self.name}/{self.bank}({gen_str})x{self.num_completed}"

    def __copy__(self):
        copied_state = UnorderedConstraintState(self.node, copy_from=self)
        return copied_state

    def copy(self):
        return self.__copy__()

    @property
    def name(self):
        # Token label of the current node; the root carries no token.
        if self.node.id is None:
            return "ROOT"
        else:
            return str(self.node.id)

    @property
    def is_root(self):
        return self.node == self.root

    @property
    def bank(self):
        # Total number of constraint tokens generated so far.
        return sum(self.generated.values())

    @property
    def num_completed(self):
        """The number of constraints (not constraint tokens) that are completed.
        In addition to the already-completed states, we need to account for the
        current state, which might get marked as completed when another token
        is generated.
        """
        # True (counts as 1) when the current node ends a constraint
        # that has not yet been credited in self.completed.
        in_final = self.node.terminal and self.completed[self.node] < self.node.terminal
        return sum(self.completed.values()) + in_final

    @property
    def finished(self):
        return self.root.num_constraints - self.num_completed == 0

    @property
    def token_counts(self):
        return self.root.token_counts()

    @property
    def tokens(self):
        return self.root.tokens()

    @property
    def num_constraint_tokens(self):
        return sum(self.token_counts.values())

    def next_tokens(self) -> Set[int]:
        """Returns the list of tokens that could come next.
        These are (a) all tokens extending the root state and, for
        non-root states, additionally all tokens extending the current
        state."""

        if self.node != self.root:
            return self.root.next_tokens().union(self.node.next_tokens())
        else:
            return self.root.next_tokens()

    def advance(self, token: int):
        """Reads in a token and advances the state. Here's how it works.

        We can advance to the next state if:
        - there is a matching child
        - its path isn't blocked

        A path is blocked when all constraints that are descendants of
        that node have already been generated, in the current state.

        If we are not able to advance from the current state, we "fall
        off the graph" and return to the root state. There, we again
        try to advance, checking the same criteria.

        In any case, when falling off the graph, we need to do some
        bookkeeping. We:
        - check whether any constraints were met (all prefixes of
          current state)
        - if one is found, mark it as completed
        - adjust visited nodes accordingly
        """
        token = int(token)

        next_state = None
        # First try to advance along a child edge that is not saturated.
        child = self.node[token]
        if child is not None and self.generated[child] < child.num_constraints:
            next_state = UnorderedConstraintState(child, copy_from=self)

        def rewind():
            """If we're mid-trie and an "illegal" token is chosen next, we need
            to reset our state to the root state. However, along the way, we need
            to check whether a prefix of the current trie state represents a state
            we could mark as completed.

            Mutates the `next_state` already built by the enclosing scope.
            """
            node = self.node
            while node != self.root:
                if node.terminal and self.completed[node] < node.terminal:
                    # A prefix of the abandoned path ends a constraint:
                    # credit it and keep the generated counts below it.
                    next_state.completed[node] += 1
                    return

                # Otherwise un-count the visit to this node on the way up.
                next_state.generated[node] -= 1
                node = node.parent

        # Fall off the graph, check the root
        if next_state is None and token in self.root.next_tokens():
            child = self.root[token]
            # We can only traverse this edge if it's not saturated
            if self.generated[child] < child.num_constraints:
                next_state = UnorderedConstraintState(child, copy_from=self)
            else:
                next_state = UnorderedConstraintState(self.root, copy_from=self)

            # Rewind
            rewind()

        elif next_state is None:
            next_state = UnorderedConstraintState(self.root, copy_from=self)

            # Rewind
            rewind()

        return next_state
|
|
|
|
|
|
class ConstraintSequence:
    """Represents a set of possibly multitoken constraints by
    concatenating them and internally recording the end points.
    """

    def __init__(self, sequences: List[List[int]]):
        # All constraint tokens, concatenated in input order
        self.sequences = []
        # endpoints[i] is True iff sequences[i] is the last token of a constraint
        self.endpoints = []
        # Total number of constraint tokens
        self.num_tokens = 0
        # The set of distinct tokens appearing in any constraint
        self.tokens = set()
        for sequence in sequences:
            for token in sequence:
                self.tokens.add(token)
            self.num_tokens += len(sequence)
            self.endpoints += [False] * (len(sequence) - 1) + [True]
            self.sequences += sequence

    def token_counts(self) -> Counter:
        """Returns a counter of the number of times each token is used
        in a constraint.

        Added for interface parity with ConstraintNode.token_counts():
        OrderedConstraintState.token_counts calls this method, which was
        previously missing (an AttributeError waiting to happen).
        """
        return Counter(self.sequences)

    def __getitem__(self, key: int):
        return self.sequences[key]

    def __len__(self):
        return len(self.sequences)

    def __str__(self):
        return str(self.sequences)
|
|
|
|
|
|
class OrderedConstraintState(ConstraintState):
    """
    Records progress through the set of linear nonbranching constraints with gaps.
    """

    def __init__(self, sequence: ConstraintSequence, state: int = -1):
        # The flattened constraint sequence being tracked.
        self.sequence = sequence
        # Index of the most recently matched token; -1 means no progress yet.
        self.state = state

    @staticmethod
    def create(constraint_tensor: torch.Tensor):
        """Builds a root state from one row of a packed constraint tensor."""
        constraint_list = unpack_constraints(constraint_tensor)
        return OrderedConstraintState(ConstraintSequence(constraint_list), -1)

    def __str__(self):
        return f"{self.state}/{self.bank}x{self.num_completed}"

    def __copy__(self):
        return OrderedConstraintState(self.sequence, self.state)

    def copy(self):
        return self.__copy__()

    @property
    def num_completed(self):
        """Number of whole constraints completed so far."""
        if self.is_root:
            return 0
        # Completed constraints correspond to endpoint markers at or
        # before the current position.
        return sum(1 for is_end in self.sequence.endpoints[: self.state + 1] if is_end)

    @property
    def is_root(self):
        return self.state == -1

    @property
    def name(self):
        # Label of the most recently matched token, or "ROOT" at the start.
        return "ROOT" if self.state == -1 else str(self.sequence[self.state])

    @property
    def bank(self) -> int:
        # Number of constraint tokens matched so far.
        return self.state + 1

    @property
    def finished(self):
        return self.state + 1 == len(self.sequence)

    @property
    def token_counts(self):
        return self.sequence.token_counts()

    @property
    def tokens(self):
        return self.sequence.tokens

    @property
    def num_constraint_tokens(self):
        return sum(self.token_counts.values())

    def next_tokens(self) -> Set[int]:
        """Returns the set of tokens that could make progress from here:
        the next unmatched token (if any constraints remain), plus the
        sequence's first token once progress beyond it has been made
        (allowing a restart)."""

        candidates = set()
        if self.state > 0:
            candidates.add(self.sequence[0])
        if not self.finished:
            candidates.add(self.sequence[self.state + 1])
        return candidates

    def advance(self, token: int):
        """Consumes *token* and returns the resulting state.

        In order of precedence:
        - a finished state accepts anything and stays put;
        - a token matching the next needed token advances one position;
        - at a constraint boundary (or the root, since endpoints[-1] is
          True whenever any constraint exists) any token is accepted
          without progress;
        - a token equal to the first constraint token restarts matching
          at position 0;
        - anything else resets to the root.
        """
        token = int(token)

        if self.finished:
            # Accept anything
            return self.copy()

        if self.sequence[self.state + 1] == token:
            # Advance to the next token
            return OrderedConstraintState(self.sequence, self.state + 1)

        if self.sequence.endpoints[self.state]:
            # Accept anything between constraints
            return self.copy()

        if token == self.sequence[0]:
            # Start over having generated the first token
            return OrderedConstraintState(self.sequence, 0)

        # Start over from the root
        return OrderedConstraintState(self.sequence, -1)
|