mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-28 10:11:14 +00:00
added ema
This commit is contained in:
@@ -29,6 +29,7 @@ import gc
|
||||
import torch
|
||||
from jobs.process import BaseSDTrainProcess
|
||||
from torchvision import transforms
|
||||
from diffusers import EMAModel
|
||||
import math
|
||||
|
||||
|
||||
@@ -1510,6 +1511,9 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
# self.scaler.update()
|
||||
# self.optimizer.step()
|
||||
self.optimizer.zero_grad(set_to_none=True)
|
||||
if self.ema is not None:
|
||||
with self.timer('ema_update'):
|
||||
self.ema.update()
|
||||
else:
|
||||
# gradient accumulation. Just a place for breakpoint
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user