added ema

This commit is contained in:
Jaret Burkett
2024-06-28 10:03:26 -06:00
parent 657fd09f25
commit 603ceca3ca
4 changed files with 367 additions and 3 deletions

View File

@@ -22,6 +22,7 @@ from toolkit.clip_vision_adapter import ClipVisionAdapter
from toolkit.custom_adapter import CustomAdapter
from toolkit.data_loader import get_dataloader_from_datasets, trigger_dataloader_setup_epoch
from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
from toolkit.ema import ExponentialMovingAverage
from toolkit.embedding import Embedding
from toolkit.image_utils import show_tensors, show_latents
from toolkit.ip_adapter import IPAdapter
@@ -174,6 +175,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
if self.embed_config is not None or is_training_adapter:
self.named_lora = True
self.snr_gos: Union[LearnableSNRGamma, None] = None
self.ema: ExponentialMovingAverage = None
def post_process_generate_image_config_list(self, generate_image_config_list: List[GenerateImageConfig]):
# override in subclass
@@ -253,9 +255,16 @@ class BaseSDTrainProcess(BaseTrainProcess):
# post process
gen_img_config_list = self.post_process_generate_image_config_list(gen_img_config_list)
# if we have an ema, set it to validation mode
if self.ema is not None:
self.ema.eval()
# send to be generated
self.sd.generate_images(gen_img_config_list, sampler=sample_config.sampler)
if self.ema is not None:
self.ema.train()
def update_training_metadata(self):
o_dict = OrderedDict({
"training_info": self.get_training_info()
@@ -369,6 +378,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
def save(self, step=None):
flush()
if self.ema is not None:
# always save params as ema
self.ema.eval()
if not os.path.exists(self.save_root):
os.makedirs(self.save_root, exist_ok=True)
@@ -527,6 +540,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.print(f"Saved to {file_path}")
self.clean_up_saves()
self.post_save_hook(file_path)
if self.ema is not None:
self.ema.train()
flush()
# Called before the model is loaded
@@ -541,6 +557,18 @@ class BaseSDTrainProcess(BaseTrainProcess):
def hook_before_train_loop(self):
    """No-op extension point; the base implementation does nothing.

    Judging by its name, this is meant to run just before the training loop
    starts (the call site is not visible here — confirm against the caller).
    """
def setup_ema(self):
    """Build the EMA tracker over the optimizer's parameters when enabled.

    Does nothing unless ``train_config.ema_config.use_ema`` is truthy. When
    it is, every parameter from every optimizer param group is collected
    into one flat list and passed to ``ExponentialMovingAverage`` together
    with the configured ``ema_decay``; the result is stored on ``self.ema``.
    """
    ema_config = self.train_config.ema_config
    if not ema_config.use_ema:
        return

    # the optimizer keeps its params in groups; flatten them into a single list
    flat_params = [
        param
        for group in self.optimizer.param_groups
        for param in group['params']
    ]
    self.ema = ExponentialMovingAverage(flat_params, ema_config.ema_decay)
def before_dataset_load(self):
    """No-op extension point; the base implementation does nothing.

    Judging by its name, this is meant to run before datasets are loaded
    (the call site is not visible here — confirm against the caller).
    """