mirror of
https://github.com/SillyTavern/SillyTavern-Extras.git
synced 2026-03-01 17:40:27 +00:00
414 lines
15 KiB
Python
414 lines
15 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
#
|
|
# This source code is licensed under the MIT license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
import argparse
|
|
from pathlib import Path
|
|
from typing import Callable, List, Optional, Union
|
|
|
|
import torch
|
|
from fairseq import utils
|
|
from fairseq.data.indexed_dataset import get_available_dataset_impl
|
|
from fairseq.dataclass.configs import (
|
|
CheckpointConfig,
|
|
CommonConfig,
|
|
CommonEvalConfig,
|
|
DatasetConfig,
|
|
DistributedTrainingConfig,
|
|
EvalLMConfig,
|
|
GenerationConfig,
|
|
InteractiveConfig,
|
|
OptimizationConfig,
|
|
EMAConfig,
|
|
)
|
|
from fairseq.dataclass.utils import gen_parser_from_dataclass
|
|
|
|
# this import is for backward compatibility
|
|
from fairseq.utils import csv_str_list, eval_bool, eval_str_dict, eval_str_list # noqa
|
|
|
|
|
|
def get_preprocessing_parser(default_task="translation"):
    """Build the argument parser used by the fairseq preprocessing CLI.

    Args:
        default_task (str): task selected when --task is not given.

    Returns:
        argparse.ArgumentParser: base parser with preprocessing options attached.
    """
    # add_preprocess_args returns the parser it was given, so the calls chain.
    return add_preprocess_args(get_parser("Preprocessing", default_task))
|
|
|
|
|
|
def get_training_parser(default_task="translation"):
    """Build the argument parser used by the fairseq training CLI.

    Args:
        default_task (str): task selected when --task is not given.

    Returns:
        argparse.ArgumentParser: parser with dataset, distributed, model,
        optimization, checkpoint and EMA argument groups attached.
    """
    parser = get_parser("Trainer", default_task)
    # Attach each argument group in the order it should appear in --help.
    attachers = (
        lambda p: add_dataset_args(p, train=True),
        add_distributed_training_args,
        add_model_args,
        add_optimization_args,
        add_checkpoint_args,
        add_ema_args,
    )
    for attach in attachers:
        attach(parser)
    return parser
|
|
|
|
|
|
def get_generation_parser(interactive=False, default_task="translation"):
    """Build the argument parser used by the fairseq generation CLI.

    Args:
        interactive (bool): also attach interactive-mode options.
        default_task (str): task selected when --task is not given.

    Returns:
        argparse.ArgumentParser: the configured generation parser.
    """
    gen_parser = get_parser("Generation", default_task)
    add_dataset_args(gen_parser, gen=True)
    # Generation defaults to a single worker rather than one per GPU.
    add_distributed_training_args(gen_parser, default_world_size=1)
    add_generation_args(gen_parser)
    add_checkpoint_args(gen_parser)
    if interactive:
        add_interactive_args(gen_parser)
    return gen_parser
|
|
|
|
|
|
def get_speech_generation_parser(default_task="text_to_speech"):
    """Build the argument parser used for speech (TTS) generation.

    Args:
        default_task (str): task selected when --task is not given.

    Returns:
        argparse.ArgumentParser: the configured speech-generation parser.
    """
    speech_parser = get_parser("Speech Generation", default_task)
    add_dataset_args(speech_parser, gen=True)
    # Single worker by default, as for text generation.
    add_distributed_training_args(speech_parser, default_world_size=1)
    add_speech_generation_args(speech_parser)
    return speech_parser
|
|
|
|
|
|
def get_interactive_generation_parser(default_task="translation"):
    """Return a generation parser with interactive-mode options attached.

    Thin convenience wrapper around :func:`get_generation_parser`.
    """
    parser = get_generation_parser(interactive=True, default_task=default_task)
    return parser
|
|
|
|
|
|
def get_eval_lm_parser(default_task="language_modeling"):
    """Build the argument parser used for language-model evaluation.

    Args:
        default_task (str): task selected when --task is not given.

    Returns:
        argparse.ArgumentParser: the configured evaluation parser.
    """
    lm_parser = get_parser("Evaluate Language Model", default_task)
    add_dataset_args(lm_parser, gen=True)
    # Evaluation runs with a single worker by default.
    add_distributed_training_args(lm_parser, default_world_size=1)
    add_eval_lm_args(lm_parser)
    return lm_parser
|
|
|
|
|
|
def get_validation_parser(default_task=None):
    """Build the argument parser used for standalone validation runs.

    Args:
        default_task (Optional[str]): task selected when --task is not given.

    Returns:
        argparse.ArgumentParser: the configured validation parser.
    """
    val_parser = get_parser("Validation", default_task)
    add_dataset_args(val_parser, train=True)
    add_distributed_training_args(val_parser, default_world_size=1)
    # Validation reuses the common evaluation options in their own help group.
    eval_group = val_parser.add_argument_group("Evaluation")
    gen_parser_from_dataclass(eval_group, CommonEvalConfig())
    return val_parser
|
|
|
|
|
|
def parse_args_and_arch(
    parser: argparse.ArgumentParser,
    input_args: Optional[List[str]] = None,
    parse_known: bool = False,
    suppress_defaults: bool = False,
    modify_parser: Optional[Callable[[argparse.ArgumentParser], None]] = None,
):
    """
    Parse command-line arguments in two passes, augmenting *parser* with
    model-, task- and registry-specific options discovered from the first
    pass, then post-process the resulting namespace (fp16/bf16/tpu flags,
    default seed, batch-size fallbacks) and apply the architecture defaults.

    Args:
        parser (ArgumentParser): the parser
        input_args (List[str]): strings to parse, defaults to sys.argv
        parse_known (bool): only parse known arguments, similar to
            `ArgumentParser.parse_known_args`
        suppress_defaults (bool): parse while ignoring all default values
        modify_parser (Optional[Callable[[ArgumentParser], None]]):
            function to modify the parser, e.g., to set default values

    Returns:
        argparse.Namespace, or ``(Namespace, extra)`` when ``parse_known``
        is True, where ``extra`` is the list of unrecognized arguments.
    """
    if suppress_defaults:
        # Parse args without any default values. This requires us to parse
        # twice, once to identify all the necessary task/model args, and a second
        # time with all defaults set to None.
        args = parse_args_and_arch(
            parser,
            input_args=input_args,
            parse_known=parse_known,
            suppress_defaults=False,
        )
        # Re-parse with every known destination defaulted to None, then keep
        # only the values that were explicitly provided on the command line.
        suppressed_parser = argparse.ArgumentParser(add_help=False, parents=[parser])
        suppressed_parser.set_defaults(**{k: None for k, v in vars(args).items()})
        args = suppressed_parser.parse_args(input_args)
        return argparse.Namespace(
            **{k: v for k, v in vars(args).items() if v is not None}
        )

    # Imported lazily: the registries are populated as a side effect of
    # importing fairseq.models (and any --user-dir module below).
    from fairseq.models import ARCH_MODEL_REGISTRY, ARCH_CONFIG_REGISTRY, MODEL_REGISTRY

    # Before creating the true parser, we need to import optional user module
    # in order to eagerly import custom tasks, optimizers, architectures, etc.
    usr_parser = argparse.ArgumentParser(add_help=False, allow_abbrev=False)
    usr_parser.add_argument("--user-dir", default=None)
    usr_args, _ = usr_parser.parse_known_args(input_args)
    utils.import_user_module(usr_args)

    if modify_parser is not None:
        modify_parser(parser)

    # The parser doesn't know about model/criterion/optimizer-specific args, so
    # we parse twice. First we parse the model/criterion/optimizer, then we
    # parse a second time after adding the *-specific arguments.
    # If input_args is given, we will parse those args instead of sys.argv.
    args, _ = parser.parse_known_args(input_args)

    # Add model-specific args to parser.
    if hasattr(args, "arch"):
        model_specific_group = parser.add_argument_group(
            "Model-specific configuration",
            # Only include attributes which are explicitly given as command-line
            # arguments or which have default values.
            argument_default=argparse.SUPPRESS,
        )
        if args.arch in ARCH_MODEL_REGISTRY:
            ARCH_MODEL_REGISTRY[args.arch].add_args(model_specific_group)
        elif args.arch in MODEL_REGISTRY:
            MODEL_REGISTRY[args.arch].add_args(model_specific_group)
        else:
            # --arch value passed the parser's choices check but matches no
            # registered architecture or model.
            raise RuntimeError()

    if hasattr(args, "task"):
        from fairseq.tasks import TASK_REGISTRY

        TASK_REGISTRY[args.task].add_args(parser)
    if getattr(args, "use_bmuf", False):
        # hack to support extra args for block distributed data parallelism
        from fairseq.optim.bmuf import FairseqBMUF

        FairseqBMUF.add_args(parser)

    # Add *-specific args to parser.
    from fairseq.registry import REGISTRIES

    for registry_name, REGISTRY in REGISTRIES.items():
        choice = getattr(args, registry_name, None)
        if choice is not None:
            cls = REGISTRY["registry"][choice]
            # A registered class exposes its options either via an add_args
            # classmethod or via an attached dataclass.
            if hasattr(cls, "add_args"):
                cls.add_args(parser)
            elif hasattr(cls, "__dataclass"):
                gen_parser_from_dataclass(parser, cls.__dataclass())

    # Modify the parser a second time, since defaults may have been reset
    if modify_parser is not None:
        modify_parser(parser)

    # Parse a second time.
    if parse_known:
        args, extra = parser.parse_known_args(input_args)
    else:
        args = parser.parse_args(input_args)
        extra = None
    # Post-process args.
    # Fall back to the training batch size / token budget for validation
    # when no explicit validation values were given.
    if (
        hasattr(args, "batch_size_valid") and args.batch_size_valid is None
    ) or not hasattr(args, "batch_size_valid"):
        args.batch_size_valid = args.batch_size
    if hasattr(args, "max_tokens_valid") and args.max_tokens_valid is None:
        args.max_tokens_valid = args.max_tokens
    # Memory-efficient variants imply the corresponding precision flag.
    if getattr(args, "memory_efficient_fp16", False):
        args.fp16 = True
    if getattr(args, "memory_efficient_bf16", False):
        args.bf16 = True
    args.tpu = getattr(args, "tpu", False)
    args.bf16 = getattr(args, "bf16", False)
    if args.bf16:
        # NOTE(review): bf16 here unconditionally implies TPU mode; GPU bf16
        # is presumably handled elsewhere — confirm before relying on this.
        args.tpu = True
    if args.tpu and args.fp16:
        raise ValueError("Cannot combine --fp16 and --tpu, use --bf16 on TPUs")

    if getattr(args, "seed", None) is None:
        args.seed = 1  # default seed for training
        args.no_seed_provided = True
    else:
        args.no_seed_provided = False

    # Derive update_epoch_batch_itr from grouped_shuffling when unset.
    if getattr(args, "update_epoch_batch_itr", None) is None:
        if hasattr(args, "grouped_shuffling"):
            args.update_epoch_batch_itr = args.grouped_shuffling
        else:
            args.grouped_shuffling = False
            args.update_epoch_batch_itr = False

    # Apply architecture configuration.
    if hasattr(args, "arch") and args.arch in ARCH_CONFIG_REGISTRY:
        ARCH_CONFIG_REGISTRY[args.arch](args)

    if parse_known:
        return args, extra
    else:
        return args
|
|
|
|
|
|
def get_parser(desc, default_task="translation"):
    """Build the base fairseq argument parser shared by all CLI entry points.

    Args:
        desc (str): human-readable description (currently unused).
        default_task (str): task selected when --task is not given.

    Returns:
        argparse.ArgumentParser: parser with common options, one flag per
        registry, and the --task option.
    """
    # Import an optional --user-dir module first so that any custom tasks,
    # optimizers, architectures, etc. it defines get registered before the
    # real parser is constructed.
    preload = argparse.ArgumentParser(add_help=False, allow_abbrev=False)
    preload.add_argument("--user-dir", default=None)
    preload_args, _ = preload.parse_known_args()
    utils.import_user_module(preload_args)

    parser = argparse.ArgumentParser(allow_abbrev=False)
    gen_parser_from_dataclass(parser, CommonConfig())

    from fairseq.registry import REGISTRIES

    # One command-line flag per registry (criterion, optimizer, ...).
    for reg_name, registry in REGISTRIES.items():
        parser.add_argument(
            "--{}".format(reg_name.replace("_", "-")),
            default=registry["default"],
            choices=registry["registry"].keys(),
        )

    # Task definitions can be found under fairseq/tasks/
    from fairseq.tasks import TASK_REGISTRY

    parser.add_argument(
        "--task",
        metavar="TASK",
        default=default_task,
        choices=TASK_REGISTRY.keys(),
        help="task",
    )
    return parser
|
|
|
|
|
|
def add_preprocess_args(parser):
    """Add the binarization options used by the preprocessing CLI.

    All options are registered on a single "Preprocessing" argument group.
    Fix: ``--dataset-impl`` was previously added via ``parser.add_argument``
    instead of ``group.add_argument``, which excluded it from the
    "Preprocessing" section of ``--help``; parsing behavior is unchanged.

    Args:
        parser (argparse.ArgumentParser): parser to extend.

    Returns:
        argparse.ArgumentParser: the same parser, for chaining.
    """
    group = parser.add_argument_group("Preprocessing")
    # fmt: off
    group.add_argument("-s", "--source-lang", default=None, metavar="SRC",
                       help="source language")
    group.add_argument("-t", "--target-lang", default=None, metavar="TARGET",
                       help="target language")
    group.add_argument("--trainpref", metavar="FP", default=None,
                       help="train file prefix (also used to build dictionaries)")
    group.add_argument("--validpref", metavar="FP", default=None,
                       help="comma separated, valid file prefixes "
                            "(words missing from train set are replaced with <unk>)")
    group.add_argument("--testpref", metavar="FP", default=None,
                       help="comma separated, test file prefixes "
                            "(words missing from train set are replaced with <unk>)")
    group.add_argument("--align-suffix", metavar="FP", default=None,
                       help="alignment file suffix")
    group.add_argument("--destdir", metavar="DIR", default="data-bin",
                       help="destination dir")
    group.add_argument("--thresholdtgt", metavar="N", default=0, type=int,
                       help="map words appearing less than threshold times to unknown")
    group.add_argument("--thresholdsrc", metavar="N", default=0, type=int,
                       help="map words appearing less than threshold times to unknown")
    group.add_argument("--tgtdict", metavar="FP",
                       help="reuse given target dictionary")
    group.add_argument("--srcdict", metavar="FP",
                       help="reuse given source dictionary")
    group.add_argument("--nwordstgt", metavar="N", default=-1, type=int,
                       help="number of target words to retain")
    group.add_argument("--nwordssrc", metavar="N", default=-1, type=int,
                       help="number of source words to retain")
    group.add_argument("--alignfile", metavar="ALIGN", default=None,
                       help="an alignment file (optional)")
    # Was parser.add_argument: registered on the group for consistent --help.
    group.add_argument('--dataset-impl', metavar='FORMAT', default='mmap',
                       choices=get_available_dataset_impl(),
                       help='output dataset implementation')
    group.add_argument("--joined-dictionary", action="store_true",
                       help="Generate joined dictionary")
    group.add_argument("--only-source", action="store_true",
                       help="Only process the source language")
    group.add_argument("--padding-factor", metavar="N", default=8, type=int,
                       help="Pad dictionary size to be multiple of N")
    group.add_argument("--workers", metavar="N", default=1, type=int,
                       help="number of parallel workers")
    group.add_argument("--dict-only", action='store_true',
                       help="if true, only builds a dictionary and then exits")
    # fmt: on
    return parser
|
|
|
|
|
|
def add_dataset_args(parser, train=False, gen=False):
    """Attach the dataset / data-loading argument group.

    Args:
        parser (argparse.ArgumentParser): parser to extend.
        train (bool): accepted for backward compatibility; currently unused.
        gen (bool): accepted for backward compatibility; currently unused.

    Returns:
        the "dataset_data_loading" argument group.
    """
    data_group = parser.add_argument_group("dataset_data_loading")
    gen_parser_from_dataclass(data_group, DatasetConfig())
    return data_group
|
|
|
|
|
|
def add_distributed_training_args(parser, default_world_size=None):
    """Attach the distributed-training argument group.

    Args:
        parser (argparse.ArgumentParser): parser to extend.
        default_world_size (Optional[int]): default for
            --distributed-world-size; when None, one worker per visible GPU
            (at least 1).

    Returns:
        the "distributed_training" argument group.
    """
    dist_group = parser.add_argument_group("distributed_training")
    world_size = (
        default_world_size
        if default_world_size is not None
        else max(1, torch.cuda.device_count())
    )
    gen_parser_from_dataclass(
        dist_group, DistributedTrainingConfig(distributed_world_size=world_size)
    )
    return dist_group
|
|
|
|
|
|
def add_optimization_args(parser):
    """Attach the optimization argument group (OptimizationConfig).

    Returns:
        the "optimization" argument group.
    """
    opt_group = parser.add_argument_group("optimization")
    gen_parser_from_dataclass(opt_group, OptimizationConfig())
    return opt_group
|
|
|
|
|
|
def add_checkpoint_args(parser):
    """Attach the checkpointing argument group (CheckpointConfig).

    Returns:
        the "checkpoint" argument group.
    """
    ckpt_group = parser.add_argument_group("checkpoint")
    gen_parser_from_dataclass(ckpt_group, CheckpointConfig())
    return ckpt_group
|
|
|
|
|
|
def add_common_eval_args(group):
    """Populate *group* with the evaluation options shared by all
    evaluation-style CLIs (CommonEvalConfig)."""
    gen_parser_from_dataclass(group, CommonEvalConfig())
|
|
|
|
|
|
def add_eval_lm_args(parser):
    """Attach the language-model evaluation argument group.

    Adds the common evaluation options plus EvalLMConfig.

    Returns:
        the "LM Evaluation" argument group — returned for consistency with
        the other add_*_args helpers (e.g. add_generation_args); previously
        this function returned None, so existing callers are unaffected.
    """
    group = parser.add_argument_group("LM Evaluation")
    add_common_eval_args(group)
    gen_parser_from_dataclass(group, EvalLMConfig())
    return group
|
|
|
|
|
|
def add_generation_args(parser):
    """Attach the text-generation argument group.

    Adds the common evaluation options plus GenerationConfig.

    Returns:
        the "Generation" argument group.
    """
    gen_group = parser.add_argument_group("Generation")
    add_common_eval_args(gen_group)
    gen_parser_from_dataclass(gen_group, GenerationConfig())
    return gen_group
|
|
|
|
|
|
def add_speech_generation_args(parser):
    """Attach the speech (TTS) generation argument group.

    Returns:
        the "Speech Generation" argument group.
    """
    speech_group = parser.add_argument_group("Speech Generation")
    add_common_eval_args(speech_group)  # NOTE: remove_bpe is not needed
    speech_group.add_argument(
        '--eos_prob_threshold', default=0.5, type=float,
        help='terminate when eos probability exceeds this',
    )
    return speech_group
|
|
|
|
|
|
def add_interactive_args(parser):
    """Attach the interactive-mode argument group (InteractiveConfig).

    Returns:
        the "Interactive" argument group — returned for consistency with the
        other add_*_args helpers; previously this function returned None, so
        existing callers are unaffected.
    """
    group = parser.add_argument_group("Interactive")
    gen_parser_from_dataclass(group, InteractiveConfig())
    return group
|
|
|
|
|
|
def add_model_args(parser):
    """Attach the model-configuration argument group (--arch).

    Model definitions live under fairseq/models/. The architecture can be
    specified in several ways, in increasing order of priority:
      1) model defaults (lowest priority)
      2) --arch argument
      3) --encoder/decoder-* arguments (highest priority)

    Returns:
        the "Model configuration" argument group.
    """
    model_group = parser.add_argument_group("Model configuration")
    from fairseq.models import ARCH_MODEL_REGISTRY
    model_group.add_argument(
        '--arch', '-a', metavar='ARCH',
        choices=ARCH_MODEL_REGISTRY.keys(),
        help='model architecture',
    )
    return model_group
|
|
|
|
|
|
def get_args(
    data: Union[str, Path],
    task: str = "translation",
    arch: str = "transformer",
    **overrides
):
    """Programmatically build a parsed training namespace.

    Args:
        data: path to the data directory (positional arg of the trainer CLI).
        task: fairseq task name.
        arch: model architecture name.
        **overrides: attribute values set on the namespace after parsing.

    Returns:
        argparse.Namespace: parsed (and overridden) training arguments.
    """
    train_parser = get_training_parser(task)
    argv = [str(data), "--task", task, "--arch", arch]
    args = parse_args_and_arch(train_parser, argv)

    # Namespace attributes live in its __dict__, so updating vars() is
    # equivalent to setattr per key.
    vars(args).update(overrides)

    return args
|
|
|
|
|
|
def add_ema_args(parser):
    """Attach the exponential-moving-average argument group (EMAConfig)."""
    ema_group = parser.add_argument_group("EMA configuration")
    gen_parser_from_dataclass(ema_group, EMAConfig())