mirror of
https://github.com/SillyTavern/SillyTavern-Extras.git
synced 2026-04-27 09:52:03 +00:00
61 lines
1.5 KiB
Python
61 lines
1.5 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
#
|
|
# This source code is licensed under the MIT license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
from fairseq.data import Dictionary
|
|
|
|
|
|
class MaskedLMDictionary(Dictionary):
    """Vocabulary for masked language modelling.

    Behaves like :class:`Dictionary`, but also registers a mask symbol and
    counts it among the special symbols (``nspecial``).

    Args:
        pad: padding symbol, forwarded to :class:`Dictionary`.
        eos: end-of-sentence symbol, forwarded to :class:`Dictionary`.
        unk: unknown-word symbol, forwarded to :class:`Dictionary`.
        mask: symbol used to mark masked positions.
    """

    def __init__(
        self,
        pad="<pad>",
        eos="</s>",
        unk="<unk>",
        mask="<mask>",
    ):
        super().__init__(unk=unk, eos=eos, pad=pad)
        self.mask_word = mask
        # Add the mask symbol after the base constructor has run. Then refresh
        # nspecial so it counts every symbol registered so far, mask included.
        mask_idx = self.add_symbol(mask)
        self.mask_index = mask_idx
        self.nspecial = len(self.symbols)

    def mask(self):
        """Return the index assigned to the mask symbol."""
        return self.mask_index
|
|
|
|
|
|
class BertDictionary(MaskedLMDictionary):
    """Vocabulary for BERT-style tasks.

    Extends :class:`MaskedLMDictionary` with a classification (``cls``) symbol
    and a separator (``sep``) symbol. Both are counted as special symbols.

    Args:
        pad: padding symbol.
        eos: end-of-sentence symbol.
        unk: unknown-word symbol.
        mask: mask symbol, forwarded to :class:`MaskedLMDictionary`.
        cls: classification symbol.
        sep: separator symbol.
    """

    def __init__(
        self,
        pad="<pad>",
        eos="</s>",
        unk="<unk>",
        mask="<mask>",
        cls="<cls>",
        sep="<sep>",
    ):
        super().__init__(pad=pad, eos=eos, unk=unk, mask=mask)
        self.cls_word, self.sep_word = cls, sep
        # The operands are evaluated left to right, so cls is added before sep.
        self.cls_index, self.sep_index = self.add_symbol(cls), self.add_symbol(sep)
        # Recount the specials so that cls and sep are included.
        self.nspecial = len(self.symbols)

    def cls(self):
        """Return the index assigned to the cls symbol."""
        return self.cls_index

    def sep(self):
        """Return the index assigned to the sep symbol."""
        return self.sep_index
|