mirror of
https://github.com/SillyTavern/SillyTavern-Extras.git
synced 2026-03-11 22:40:03 +00:00
42 lines
1.2 KiB
Python
42 lines
1.2 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
#
|
|
# This source code is licensed under the MIT license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
import logging
|
|
from fairseq.tasks import register_task
|
|
from fairseq.tasks.speech_to_text import SpeechToTextTask
|
|
from fairseq.tasks.translation import TranslationTask, TranslationConfig
|
|
|
|
# Optional dependency: the simultaneous-translation example package is
# presumably only importable when fairseq is installed from a source checkout
# (see the `pip install -e $FAIRSEQ_DIR` hint raised by check_import).
# Record the outcome instead of failing here, so this module can still be
# imported and the error surfaces only when a simul task is constructed.
try:
    import examples.simultaneous_translation  # noqa: F401
except Exception:
    # Catch broadly: importing the example package may fail with errors other
    # than ImportError. Do NOT catch BaseException, which would also swallow
    # KeyboardInterrupt / SystemExit raised during the import.
    import_successful = False
else:
    import_successful = True

logger = logging.getLogger(__name__)
|
|
|
|
|
|
def check_import(flag):
    """Raise if the optional simultaneous-translation examples failed to import.

    Args:
        flag: Truthy when ``examples.simultaneous_translation`` was imported
            successfully (callers in this module pass ``import_successful``).

    Raises:
        ImportError: If ``flag`` is falsy, with a hint on how to fix the
            installation.
    """
    if flag:
        return
    raise ImportError(
        "'examples.simultaneous_translation' is not correctly imported. "
        "Please consider running `pip install -e $FAIRSEQ_DIR`."
    )
|
|
|
|
|
|
@register_task("simul_speech_to_text")
class SimulSpeechToTextTask(SpeechToTextTask):
    """Speech-to-text task registered under the name ``simul_speech_to_text``.

    Adds no behavior beyond :class:`SpeechToTextTask` except a guard that
    refuses construction (raising ``ImportError`` via ``check_import``) when
    ``examples.simultaneous_translation`` could not be imported.
    """

    def __init__(self, args, tgt_dict):
        # Fail fast with an actionable message before any base-class setup.
        check_import(flag=import_successful)
        super().__init__(args, tgt_dict)
|
|
|
|
|
|
@register_task("simul_text_to_text", dataclass=TranslationConfig)
class SimulTextToTextTask(TranslationTask):
    """Text-to-text task registered as ``simul_text_to_text``.

    Configured with :class:`TranslationConfig` and otherwise identical to
    :class:`TranslationTask`, except that construction raises ``ImportError``
    (via ``check_import``) when ``examples.simultaneous_translation`` could
    not be imported.
    """

    def __init__(self, cfg, src_dict, tgt_dict):
        # Guard first so a broken install is reported before base-class setup.
        check_import(flag=import_successful)
        super().__init__(cfg, src_dict, tgt_dict)
|