Added experimental param multiplier to the ema module

This commit is contained in:
Jaret Burkett
2024-10-22 09:25:52 -06:00
parent bedb8197a2
commit 9f94c7b61e
3 changed files with 15 additions and 2 deletions

View File

@@ -588,8 +588,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
params.append(param)
self.ema = ExponentialMovingAverage(
params,
self.train_config.ema_config.ema_decay,
decay=self.train_config.ema_config.ema_decay,
use_feedback=self.train_config.ema_config.use_feedback,
param_multiplier=self.train_config.ema_config.param_multiplier,
)
def before_dataset_load(self):