Adjustments to loading of flux. Added a feedback to ema

This commit is contained in:
Jaret Burkett
2024-08-07 13:17:26 -06:00
parent 653fe60f16
commit acafe9984f
5 changed files with 27 additions and 10 deletions

View File

@@ -43,7 +43,9 @@ class ExponentialMovingAverage:
self,
parameters: Iterable[torch.nn.Parameter] = None,
decay: float = 0.995,
use_num_updates: bool = True
use_num_updates: bool = True,
# feeds back the decat to the parameter
use_feedback: bool = False
):
if parameters is None:
raise ValueError("parameters must be provided")
@@ -51,6 +53,7 @@ class ExponentialMovingAverage:
raise ValueError('Decay must be between 0 and 1')
self.decay = decay
self.num_updates = 0 if use_num_updates else None
self.use_feedback = use_feedback
parameters = list(parameters)
self.shadow_params = [
p.clone().detach()
@@ -123,6 +126,9 @@ class ExponentialMovingAverage:
tmp.mul_(one_minus_decay)
s_param.sub_(tmp)
if self.use_feedback:
param.add_(tmp)
def copy_to(
self,
parameters: Optional[Iterable[torch.nn.Parameter]] = None