mirror of
https://github.com/SillyTavern/SillyTavern-Extras.git
synced 2026-02-26 08:04:10 +00:00
79 lines
2.0 KiB
Python
79 lines
2.0 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
#
|
|
# This source code is licensed under the MIT license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
import logging
|
|
from typing import List
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def uniform(dataset_sizes: List[int]):
|
|
return [1.0] * len(dataset_sizes)
|
|
|
|
|
|
def temperature_sampling(dataset_sizes, temp):
|
|
total_size = sum(dataset_sizes)
|
|
return [(size / total_size) ** (1.0 / temp) for size in dataset_sizes]
|
|
|
|
|
|
def make_temperature_sampling(temp=1.0):
|
|
def sampling_func(dataset_sizes):
|
|
return temperature_sampling(dataset_sizes, temp)
|
|
|
|
return sampling_func
|
|
|
|
|
|
def make_ratio_sampling(ratios):
|
|
def sampling_func(dataset_sizes):
|
|
return ratios
|
|
|
|
return sampling_func
|
|
|
|
|
|
class SamplingMethod:
|
|
@staticmethod
|
|
def add_arguments(parser):
|
|
parser.add_argument(
|
|
"--sampling-method",
|
|
choices=[
|
|
"uniform",
|
|
"temperature",
|
|
"concat",
|
|
"RoundRobin",
|
|
],
|
|
type=str,
|
|
default="concat",
|
|
help="The method to sample data per language pairs",
|
|
)
|
|
parser.add_argument(
|
|
"--sampling-temperature",
|
|
default=1.5,
|
|
type=float,
|
|
help="only work with --sampling-method temperature",
|
|
)
|
|
|
|
@staticmethod
|
|
def build_sampler(args, task):
|
|
return SamplingMethod(args, task)
|
|
|
|
def __init__(self, args, task):
|
|
self.args = args
|
|
self.task = task
|
|
|
|
def is_adaptive(self):
|
|
return False
|
|
|
|
def sampling_method_selector(self):
|
|
args = self.args
|
|
logger.info(f"selected sampler: {args.sampling_method}")
|
|
if args.sampling_method == "uniform":
|
|
return uniform
|
|
elif args.sampling_method == "temperature" or self.is_adaptive():
|
|
return make_temperature_sampling(float(args.sampling_temperature))
|
|
else:
|
|
# default to concating all data set together
|
|
return None
|