mirror of
https://github.com/SillyTavern/SillyTavern-Extras.git
synced 2026-03-11 14:30:03 +00:00
31 lines
909 B
Python
31 lines
909 B
Python
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
#
|
|
# This source code is licensed under the MIT license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
import torch
|
|
from fairseq.data import encoders
|
|
|
|
|
|
def get_whole_word_mask(args, dictionary):
    """Build a per-token mask marking which dictionary entries begin a word.

    A BPE encoder is constructed from ``args`` via ``encoders.build_bpe``.
    If that returns ``None`` (no BPE configured), no mask is built.

    Args:
        args: configuration object passed unchanged to
            ``encoders.build_bpe``.
        dictionary: token dictionary supporting ``len()``, integer indexing
            that yields the token string, and an ``nspecial`` attribute
            giving the number of leading special symbols.

    Returns:
        A 1-D ``torch.uint8`` tensor of length ``len(dictionary)`` whose
        entry ``i`` is 1 if token ``i`` begins a word and 0 otherwise, or
        ``None`` when no BPE encoder is available.
    """
    bpe = encoders.build_bpe(args)
    if bpe is None:
        return None

    def is_beginning_of_word(i):
        if i < dictionary.nspecial:
            # special elements are always considered beginnings
            return True
        tok = dictionary[i]
        if tok.startswith("madeupword"):
            # NOTE(review): presumably filler symbols added to pad the
            # vocabulary size; they are treated as word starts — confirm.
            return True
        try:
            return bpe.is_beginning_of_word(tok)
        except ValueError:
            # Tokens the BPE cannot interpret are treated as word starts.
            return True

    # torch.tensor with an explicit dtype replaces the legacy
    # torch.ByteTensor constructor; the dtype (uint8) and values are the
    # same, so downstream consumers of the mask are unaffected.
    return torch.tensor(
        [is_beginning_of_word(i) for i in range(len(dictionary))],
        dtype=torch.uint8,
    )
|