From 9f94c7b61e8dd76b409ccc058eb42135e993fc61 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Tue, 22 Oct 2024 09:25:52 -0600 Subject: [PATCH] Added experimental param multiplier to the ema module --- jobs/process/BaseSDTrainProcess.py | 3 ++- toolkit/config_modules.py | 7 +++++++ toolkit/ema.py | 7 ++++++- 3 files changed, 15 insertions(+), 2 deletions(-) diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index 3067ae57..43d45f73 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -588,8 +588,9 @@ class BaseSDTrainProcess(BaseTrainProcess): params.append(param) self.ema = ExponentialMovingAverage( params, - self.train_config.ema_config.ema_decay, + decay=self.train_config.ema_config.ema_decay, use_feedback=self.train_config.ema_config.use_feedback, + param_multiplier=self.train_config.ema_config.param_multiplier, ) def before_dataset_load(self): diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index c897e9ed..51bb57e7 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -458,6 +458,11 @@ class EMAConfig: self.ema_decay: float = kwargs.get('ema_decay', 0.999) # feeds back the decay difference into the parameter self.use_feedback: bool = kwargs.get('use_feedback', False) + + # every update, the params are multiplied by this amount + # only use for things without a bias like lora + # similar to a decay in an optimizer but the opposite + self.param_multiplier: float = kwargs.get('param_multiplier', 1.0) class ReferenceDatasetConfig: @@ -546,6 +551,8 @@ class DatasetConfig: self.dataset_path: str = kwargs.get('dataset_path', None) self.default_caption: str = kwargs.get('default_caption', None) + # trigger word for just this dataset + self.trigger_word: str = kwargs.get('trigger_word', None) random_triggers = kwargs.get('random_triggers', []) # if they are a string, load them from a file if isinstance(random_triggers, str) and os.path.exists(random_triggers): diff --git a/toolkit/ema.py b/toolkit/ema.py index 6a5df2df..3be7c406 100644 --- a/toolkit/ema.py +++ b/toolkit/ema.py @@ -45,7 +45,8 @@ class ExponentialMovingAverage: decay: float = 0.995, use_num_updates: bool = True, # feeds back the decat to the parameter - use_feedback: bool = False + use_feedback: bool = False, + param_multiplier: float = 1.0 ): if parameters is None: raise ValueError("parameters must be provided") @@ -54,6 +55,7 @@ class ExponentialMovingAverage: self.decay = decay self.num_updates = 0 if use_num_updates else None self.use_feedback = use_feedback + self.param_multiplier = param_multiplier parameters = list(parameters) self.shadow_params = [ p.clone().detach() @@ -128,6 +130,9 @@ class ExponentialMovingAverage: if self.use_feedback: param.add_(tmp) + + if self.param_multiplier != 1.0: + param.mul_(self.param_multiplier) def copy_to( self,