Added experimental param_multiplier option to the EMA module

This commit is contained in:
Jaret Burkett
2024-10-22 09:25:52 -06:00
parent bedb8197a2
commit 9f94c7b61e
3 changed files with 15 additions and 2 deletions

View File

@@ -588,8 +588,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
params.append(param)
self.ema = ExponentialMovingAverage(
params,
self.train_config.ema_config.ema_decay,
decay=self.train_config.ema_config.ema_decay,
use_feedback=self.train_config.ema_config.use_feedback,
param_multiplier=self.train_config.ema_config.param_multiplier,
)
def before_dataset_load(self):

View File

@@ -458,6 +458,11 @@ class EMAConfig:
self.ema_decay: float = kwargs.get('ema_decay', 0.999)
# feeds back the decay difference into the parameter
self.use_feedback: bool = kwargs.get('use_feedback', False)
# every update, the params are multiplied by this amount
# only use for parameters without a bias term, such as LoRA weights
# similar to weight decay in an optimizer, but acting in the opposite direction
self.param_multiplier: float = kwargs.get('param_multiplier', 1.0)
class ReferenceDatasetConfig:
@@ -546,6 +551,8 @@ class DatasetConfig:
self.dataset_path: str = kwargs.get('dataset_path', None)
self.default_caption: str = kwargs.get('default_caption', None)
# trigger word for just this dataset
self.trigger_word: str = kwargs.get('trigger_word', None)
random_triggers = kwargs.get('random_triggers', [])
# if they are a string, load them from a file
if isinstance(random_triggers, str) and os.path.exists(random_triggers):

View File

@@ -45,7 +45,8 @@ class ExponentialMovingAverage:
decay: float = 0.995,
use_num_updates: bool = True,
# feeds back the decay to the parameter
use_feedback: bool = False
use_feedback: bool = False,
param_multiplier: float = 1.0
):
if parameters is None:
raise ValueError("parameters must be provided")
@@ -54,6 +55,7 @@ class ExponentialMovingAverage:
self.decay = decay
self.num_updates = 0 if use_num_updates else None
self.use_feedback = use_feedback
self.param_multiplier = param_multiplier
parameters = list(parameters)
self.shadow_params = [
p.clone().detach()
@@ -128,6 +130,9 @@ class ExponentialMovingAverage:
if self.use_feedback:
param.add_(tmp)
if self.param_multiplier != 1.0:
param.mul_(self.param_multiplier)
def copy_to(
self,