mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 11:11:37 +00:00
Adjustments to loading of flux. Added a feedback to ema
This commit is contained in:
@@ -43,7 +43,9 @@ class ExponentialMovingAverage:
|
||||
self,
|
||||
parameters: Iterable[torch.nn.Parameter] = None,
|
||||
decay: float = 0.995,
|
||||
use_num_updates: bool = True
|
||||
use_num_updates: bool = True,
|
||||
# feeds back the decat to the parameter
|
||||
use_feedback: bool = False
|
||||
):
|
||||
if parameters is None:
|
||||
raise ValueError("parameters must be provided")
|
||||
@@ -51,6 +53,7 @@ class ExponentialMovingAverage:
|
||||
raise ValueError('Decay must be between 0 and 1')
|
||||
self.decay = decay
|
||||
self.num_updates = 0 if use_num_updates else None
|
||||
self.use_feedback = use_feedback
|
||||
parameters = list(parameters)
|
||||
self.shadow_params = [
|
||||
p.clone().detach()
|
||||
@@ -123,6 +126,9 @@ class ExponentialMovingAverage:
|
||||
tmp.mul_(one_minus_decay)
|
||||
s_param.sub_(tmp)
|
||||
|
||||
if self.use_feedback:
|
||||
param.add_(tmp)
|
||||
|
||||
def copy_to(
|
||||
self,
|
||||
parameters: Optional[Iterable[torch.nn.Parameter]] = None
|
||||
|
||||
Reference in New Issue
Block a user