mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-02-25 14:53:57 +00:00
Added experimental param multiplier to the ema module
This commit is contained in:
@@ -588,8 +588,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
params.append(param)
|
||||
self.ema = ExponentialMovingAverage(
|
||||
params,
|
||||
self.train_config.ema_config.ema_decay,
|
||||
decay=self.train_config.ema_config.ema_decay,
|
||||
use_feedback=self.train_config.ema_config.use_feedback,
|
||||
param_multiplier=self.train_config.ema_config.param_multiplier,
|
||||
)
|
||||
|
||||
def before_dataset_load(self):
|
||||
|
||||
@@ -458,6 +458,11 @@ class EMAConfig:
|
||||
self.ema_decay: float = kwargs.get('ema_decay', 0.999)
|
||||
# feeds back the decay difference into the parameter
|
||||
self.use_feedback: bool = kwargs.get('use_feedback', False)
|
||||
|
||||
# every update, the params are multiplied by this amount
|
||||
# only use for things without a bias like lora
|
||||
# similar to a decay in an optimizer but the opposite
|
||||
self.param_multiplier: float = kwargs.get('param_multiplier', 1.0)
|
||||
|
||||
|
||||
class ReferenceDatasetConfig:
|
||||
@@ -546,6 +551,8 @@ class DatasetConfig:
|
||||
self.dataset_path: str = kwargs.get('dataset_path', None)
|
||||
|
||||
self.default_caption: str = kwargs.get('default_caption', None)
|
||||
# trigger word for just this dataset
|
||||
self.trigger_word: str = kwargs.get('trigger_word', None)
|
||||
random_triggers = kwargs.get('random_triggers', [])
|
||||
# if they are a string, load them from a file
|
||||
if isinstance(random_triggers, str) and os.path.exists(random_triggers):
|
||||
|
||||
@@ -45,7 +45,8 @@ class ExponentialMovingAverage:
|
||||
decay: float = 0.995,
|
||||
use_num_updates: bool = True,
|
||||
# feeds back the decay to the parameter
|
||||
use_feedback: bool = False
|
||||
use_feedback: bool = False,
|
||||
param_multiplier: float = 1.0
|
||||
):
|
||||
if parameters is None:
|
||||
raise ValueError("parameters must be provided")
|
||||
@@ -54,6 +55,7 @@ class ExponentialMovingAverage:
|
||||
self.decay = decay
|
||||
self.num_updates = 0 if use_num_updates else None
|
||||
self.use_feedback = use_feedback
|
||||
self.param_multiplier = param_multiplier
|
||||
parameters = list(parameters)
|
||||
self.shadow_params = [
|
||||
p.clone().detach()
|
||||
@@ -128,6 +130,9 @@ class ExponentialMovingAverage:
|
||||
|
||||
if self.use_feedback:
|
||||
param.add_(tmp)
|
||||
|
||||
if self.param_multiplier != 1.0:
|
||||
param.mul_(self.param_multiplier)
|
||||
|
||||
def copy_to(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user