Bug fixes

Jaret Burkett
2024-07-03 10:56:34 -06:00
parent bb57623a35
commit acb06d6ff3
6 changed files with 133 additions and 10 deletions


@@ -557,6 +557,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
     def hook_before_train_loop(self):
         pass
 
+    def ensure_params_requires_grad(self):
+        # get param groups
+        for group in self.optimizer.param_groups:
+            for param in group['params']:
+                param.requires_grad = True
+
     def setup_ema(self):
         if self.train_config.ema_config.use_ema:
             # our params are in groups. We need them as a single iterable
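
The new ensure_params_requires_grad() helper simply walks the optimizer's param groups and flips every tracked parameter back to requires_grad=True. A minimal, self-contained sketch of the same pattern; the toy linear model and SGD optimizer are placeholders for illustration, not code from this repo:

import torch

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# Simulate a code path (e.g. sampling or saving) that froze the weights.
for p in model.parameters():
    p.requires_grad = False

# Same pattern as ensure_params_requires_grad(): iterate the optimizer's
# param groups and re-enable gradients on every parameter it tracks.
for group in optimizer.param_groups:
    for param in group["params"]:
        param.requires_grad = True

assert all(p.requires_grad for p in model.parameters())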
@@ -1535,6 +1541,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
         # print(f"Compiling Model")
         # torch.compile(self.sd.unet, dynamic=True)
 
+        # make sure all params require grad
+        self.ensure_params_requires_grad()
+
         ###################################################################
         # TRAIN LOOP
         ###################################################################
@@ -1652,6 +1662,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 if self.train_config.free_u:
                     self.sd.pipeline.disable_freeu()
                 self.sample(self.step_num)
+                self.ensure_params_requires_grad()
                 self.progress_bar.unpause()
 
             if is_save_step:
@@ -1659,6 +1670,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 self.progress_bar.pause()
                 self.print(f"Saving at step {self.step_num}")
                 self.save(self.step_num)
+                self.ensure_params_requires_grad()
                 self.progress_bar.unpause()
 
             if self.logging_config.log_every and self.step_num % self.logging_config.log_every == 0:
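
The helper is also re-applied immediately after self.sample(self.step_num) and self.save(self.step_num): if either path leaves parameters with requires_grad=False (for example while swapping in EMA weights or writing a checkpoint), they are made trainable again before the next optimizer step. A hedged sketch of that wrap-then-restore idea; run_then_restore_grads is a hypothetical helper, not part of this commit (the trainer inlines the restore after each call instead):

def run_then_restore_grads(fn, optimizer):
    # Hypothetical wrapper illustrating the pattern above: run a
    # side-effectful step (sampling, checkpoint saving), then make sure
    # every optimizer-tracked parameter is trainable again afterwards.
    result = fn()
    for group in optimizer.param_groups:
        for param in group["params"]:
            param.requires_grad = True
    return result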