Mirror of https://github.com/SillyTavern/SillyTavern-Extras.git
Synced 2026-04-30 11:21:28 +00:00
40 lines, 1.3 KiB, Python
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
|
|
|
|
from dataclasses import dataclass
|
|
from fairseq.data.legacy.masked_lm_dictionary import MaskedLMDictionary
|
|
from fairseq.tasks.translation import TranslationConfig, TranslationTask
|
|
|
|
from . import register_task
|
|
|
|
|
|
@dataclass
class TranslationFromPretrainedXLMConfig(TranslationConfig):
    """Configuration for the pretrained-XLM translation task.

    Declares no fields of its own: every option is inherited unchanged
    from ``TranslationConfig``. The subclass exists so the task can be
    registered with its own dedicated config type.
    """
|
|
|
|
|
|
@register_task(
    "translation_from_pretrained_xlm",
    dataclass=TranslationFromPretrainedXLMConfig,
)
class TranslationFromPretrainedXLMTask(TranslationTask):
    """Translation task whose dictionaries are ``MaskedLMDictionary`` objects.

    Behaves exactly like ``TranslationTask``; the only override is
    ``load_dictionary``. This lets the task read data that was binarized
    with the ``MaskedLMDictionary`` class.

    Use this task for the whole pipeline when an NMT model is initialized
    from a pretrained XLM checkpoint: binarizing the NMT data, training
    with the pretrained XLM checkpoint, and evaluating the trained model.
    """

    @classmethod
    def load_dictionary(cls, filename):
        """Read a dictionary file with ``MaskedLMDictionary.load``.

        Args:
            filename (str): path of the dictionary file to read.

        Returns:
            Whatever ``MaskedLMDictionary.load(filename)`` returns.
        """
        # Overriding the base-class loader is the whole point of this task.
        masked_lm_dict = MaskedLMDictionary.load(filename)
        return masked_lm_dict
|