added ema

This commit is contained in:
Jaret Burkett
2024-06-28 10:03:26 -06:00
parent 657fd09f25
commit 603ceca3ca
4 changed files with 367 additions and 3 deletions

View File

@@ -29,6 +29,7 @@ import gc
import torch
from jobs.process import BaseSDTrainProcess
from torchvision import transforms
from diffusers import EMAModel
import math
@@ -1510,6 +1511,9 @@ class SDTrainer(BaseSDTrainProcess):
# self.scaler.update()
# self.optimizer.step()
self.optimizer.zero_grad(set_to_none=True)
if self.ema is not None:
with self.timer('ema_update'):
self.ema.update()
else:
# gradient accumulation. Just a place for breakpoint
pass