# Originally from Microsoft Corporation.
# Licensed under the MIT License.

""" Wrapper for ngram_repeat_block cuda extension """

import math
import warnings
from typing import List

import torch
from torch import nn

try:
    from fairseq import ngram_repeat_block_cuda

    EXTENSION_BUILT = True
except ImportError:
    EXTENSION_BUILT = False


def is_cuda_extension_usable() -> bool:
    """Check whether ngram_repeat_block_cuda is built properly."""
    if not EXTENSION_BUILT or not torch.cuda.is_available():
        return False
    # Probe the kernel with a tiny dummy batch; a RuntimeError means the
    # extension was built for the wrong architecture or is otherwise unusable.
    bsz = 2
    tokens = torch.tensor([[4, 4, 3, 2], [1, 2, 3, 4]], dtype=torch.long, device="cuda")
    lprobs = torch.rand((8, 12), device="cuda")
    try:
        outputs = ngram_repeat_block_cuda.forward(tokens, lprobs, bsz, 3, 4, 3)
        outputs = outputs + 4  # This line breaks if the extension is built incorrectly.
        return True
    except RuntimeError:
        warnings.warn(
            "NGramRepeatBlock extension must be rebuilt. "
            'Run TORCH_CUDA_ARCH_LIST="6.0;6.1;7.0" python setup.py build_ext --inplace'
        )
        return False

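
# Illustrative note (not part of the upstream file): a caller could probe the
# extension once and pass the result in explicitly, e.g.
#
#     blocker = NGramRepeatBlock(no_repeat_ngram_size=3,
#                                use_extension=is_cuda_extension_usable())
#
# although NGramRepeatBlock.__init__ below already performs this probe itself
# whenever use_extension=True.
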
class NGramRepeatBlock(nn.Module):
    """Wrapper class for calling ngram_repeat_block cuda extension"""

    def __init__(self, no_repeat_ngram_size: int, use_extension: bool = True):
        super().__init__()
        self.use_extension = is_cuda_extension_usable() if use_extension else False
        self.no_repeat_ngram_size = no_repeat_ngram_size

    def reset_parameters(self):
        pass

    @torch.jit.unused
    def call_cuda_extension(
        self,
        tokens,
        lprobs,
        bsz: int,
        beam_size: int,
        step: int,
    ):
        return ngram_repeat_block_cuda.forward(
            tokens, lprobs, bsz, step, beam_size, self.no_repeat_ngram_size
        )

    def forward(
        self,
        tokens,
        lprobs,
        bsz: int,
        beam_size: int,
        step: int,
    ):
        """
        Args:
            tokens(Tensor): Input tokens (bsz * beam, seq_len)
            lprobs(Tensor): token log-probabilities (bsz * beam, vocab_size);
                expected to be updated in place
            bsz(int): batch size
            beam_size(int): beam size
            step(int): current decoding step

        Note:
            no_repeat_ngram_size is read from the attribute set in __init__.
        """
        msg = f"expected {bsz * beam_size} got"
        assert tokens.size(0) == bsz * beam_size, f"{msg} {tokens.size(0)}"
        assert lprobs.size(0) == bsz * beam_size, f"{msg} {lprobs.size(0)}"
        if self.use_extension:
            return self.call_cuda_extension(tokens, lprobs, bsz, beam_size, step)
        else:
            return self._no_repeat_ngram(
                tokens,
                lprobs,
                bsz,
                beam_size,
                step,
            )

    def _no_repeat_ngram(self, tokens, lprobs, bsz: int, beam_size: int, step: int):
        """For each hypothesis generate a list of previous ngrams and set associated lprobs to -inf"""
        banned_tokens = [
            torch.jit.annotate(List[int], []) for bbsz_idx in range(bsz * beam_size)
        ]
        if step + 2 - self.no_repeat_ngram_size >= 0:
            cpu_tokens: List[List[int]] = tokens.cpu().tolist()
            check_start_pos = step + 2 - self.no_repeat_ngram_size
            for bbsz_idx in range(bsz * beam_size):
                # The last (n - 1) generated tokens form the prefix of the ngram
                # that would be completed at this step.
                ngram_to_check = cpu_tokens[bbsz_idx][
                    -(self.no_repeat_ngram_size - 1) :
                ]
                # If that prefix already occurred earlier in the hypothesis, ban
                # the token that followed it there, so the ngram cannot repeat.
                for i in range(check_start_pos):
                    if (
                        ngram_to_check
                        == cpu_tokens[bbsz_idx][i : i + self.no_repeat_ngram_size - 1]
                    ):
                        banned_tokens[bbsz_idx].append(
                            cpu_tokens[bbsz_idx][i + self.no_repeat_ngram_size - 1]
                        )
        # Set the log-probabilities of all banned tokens to -inf (in place).
        for bbsz_idx in range(bsz * beam_size):
            lprobs[bbsz_idx][
                torch.tensor(banned_tokens[bbsz_idx], dtype=torch.int64)
            ] = torch.tensor(-math.inf).to(lprobs)
        return lprobs
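

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the upstream module). It exercises
# the pure-Python fallback (use_extension=False), so only torch is required.
# It assumes `tokens` holds the prefix generated so far for each of the
# bsz * beam_size hypotheses and that `step` is the index of the last
# generated position; shapes follow the forward() docstring above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    bsz, beam_size, vocab_size = 1, 2, 8
    blocker = NGramRepeatBlock(no_repeat_ngram_size=2, use_extension=False)
    # Both hypotheses end with token 6, and the bigram (6, 7) already occurred
    # earlier, so token 7 should be banned at this step.
    tokens = torch.tensor([[5, 6, 7, 5, 6], [5, 6, 7, 5, 6]], dtype=torch.long)
    lprobs = torch.zeros((bsz * beam_size, vocab_size))
    step = tokens.size(1) - 1
    out = blocker(tokens, lprobs, bsz, beam_size, step)
    assert out[0, 7].item() == -math.inf
    assert out[1, 7].item() == -math.inf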