mirror of
https://github.com/SillyTavern/SillyTavern-Extras.git
synced 2026-05-02 12:21:19 +00:00
Add monkey patched fairseq package to run on python 3.11 (what is needed for our use of RVC at least)
This commit is contained in:
@@ -0,0 +1,78 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def uniform(dataset_sizes: List[int]):
|
||||
return [1.0] * len(dataset_sizes)
|
||||
|
||||
|
||||
def temperature_sampling(dataset_sizes, temp):
|
||||
total_size = sum(dataset_sizes)
|
||||
return [(size / total_size) ** (1.0 / temp) for size in dataset_sizes]
|
||||
|
||||
|
||||
def make_temperature_sampling(temp=1.0):
|
||||
def sampling_func(dataset_sizes):
|
||||
return temperature_sampling(dataset_sizes, temp)
|
||||
|
||||
return sampling_func
|
||||
|
||||
|
||||
def make_ratio_sampling(ratios):
|
||||
def sampling_func(dataset_sizes):
|
||||
return ratios
|
||||
|
||||
return sampling_func
|
||||
|
||||
|
||||
class SamplingMethod:
|
||||
@staticmethod
|
||||
def add_arguments(parser):
|
||||
parser.add_argument(
|
||||
"--sampling-method",
|
||||
choices=[
|
||||
"uniform",
|
||||
"temperature",
|
||||
"concat",
|
||||
"RoundRobin",
|
||||
],
|
||||
type=str,
|
||||
default="concat",
|
||||
help="The method to sample data per language pairs",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sampling-temperature",
|
||||
default=1.5,
|
||||
type=float,
|
||||
help="only work with --sampling-method temperature",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def build_sampler(args, task):
|
||||
return SamplingMethod(args, task)
|
||||
|
||||
def __init__(self, args, task):
|
||||
self.args = args
|
||||
self.task = task
|
||||
|
||||
def is_adaptive(self):
|
||||
return False
|
||||
|
||||
def sampling_method_selector(self):
|
||||
args = self.args
|
||||
logger.info(f"selected sampler: {args.sampling_method}")
|
||||
if args.sampling_method == "uniform":
|
||||
return uniform
|
||||
elif args.sampling_method == "temperature" or self.is_adaptive():
|
||||
return make_temperature_sampling(float(args.sampling_temperature))
|
||||
else:
|
||||
# default to concating all data set together
|
||||
return None
|
||||
Reference in New Issue
Block a user