Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-01-26 16:39:47 +00:00
Bug fixes
@@ -557,6 +557,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
     def hook_before_train_loop(self):
         pass
 
+    def ensure_params_requires_grad(self):
+        # get param groups
+        for group in self.optimizer.param_groups:
+            for param in group['params']:
+                param.requires_grad = True
+
     def setup_ema(self):
         if self.train_config.ema_config.use_ema:
             # our params are in groups. We need them as a single iterable
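The new helper walks every param group the optimizer tracks and switches gradients back on. This matters because sampling and checkpoint saving can leave parameters with requires_grad=False, which would silently stop training. A minimal standalone sketch of the same pattern (the model, optimizer, and freezing step here are illustrative assumptions, not from this repo):

    import torch

    model = torch.nn.Linear(4, 4)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    # Simulate what a sample/save step might do: freeze everything.
    for p in model.parameters():
        p.requires_grad_(False)

    # The fix: re-enable grads on every param the optimizer owns,
    # exactly as ensure_params_requires_grad() does above.
    for group in optimizer.param_groups:
        for param in group['params']:
            param.requires_grad = True

    assert all(p.requires_grad for p in model.parameters())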
@@ -1535,6 +1541,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
         # print(f"Compiling Model")
         # torch.compile(self.sd.unet, dynamic=True)
+
+        # make sure all params require grad
+        self.ensure_params_requires_grad()
+
 
         ###################################################################
         # TRAIN LOOP
         ###################################################################
@@ -1652,6 +1662,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 if self.train_config.free_u:
                     self.sd.pipeline.disable_freeu()
                 self.sample(self.step_num)
+                self.ensure_params_requires_grad()
                 self.progress_bar.unpause()
 
             if is_save_step:
@@ -1659,6 +1670,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 self.progress_bar.pause()
                 self.print(f"Saving at step {self.step_num}")
                 self.save(self.step_num)
+                self.ensure_params_requires_grad()
                 self.progress_bar.unpause()
 
             if self.logging_config.log_every and self.step_num % self.logging_config.log_every == 0:
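Both call sites follow the same pattern: sample() and save() can flip requires_grad off on trainable params, so the helper re-enables them before training resumes. An alternative sketch, a hypothetical preserve_requires_grad context manager (not in this repo), would snapshot and restore the flags around any such call:

    from contextlib import contextmanager

    @contextmanager
    def preserve_requires_grad(optimizer):
        # Snapshot the flag for every param the optimizer tracks ...
        params = [p for g in optimizer.param_groups for p in g['params']]
        flags = [p.requires_grad for p in params]
        try:
            yield
        finally:
            # ... and restore it, even if the wrapped call raises.
            for p, flag in zip(params, flags):
                p.requires_grad = flag

Blanket-setting requires_grad=True, as this commit does, is simpler but assumes every param the optimizer holds is meant to be trainable; the snapshot approach would also preserve intentionally frozen params.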