Switch order to save first, then sample.

This commit is contained in:
Jaret Burkett
2025-08-27 11:07:03 -06:00
parent 1f541bc5d8
commit fc5b41666a
2 changed files with 17 additions and 16 deletions

View File

@@ -2189,6 +2189,22 @@ class BaseSDTrainProcess(BaseTrainProcess):
if self.step_num != self.start_step:
if is_sample_step or is_save_step:
self.accelerator.wait_for_everyone()
if is_save_step:
self.accelerator
# print above the progress bar
if self.progress_bar is not None:
self.progress_bar.pause()
print_acc(f"\nSaving at step {self.step_num}")
self.save(self.step_num)
self.ensure_params_requires_grad()
# clear any grads
optimizer.zero_grad()
flush()
flush_next = True
if self.progress_bar is not None:
self.progress_bar.unpause()
if is_sample_step:
if self.progress_bar is not None:
self.progress_bar.pause()
@@ -2206,21 +2222,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
if self.progress_bar is not None:
self.progress_bar.unpause()
if is_save_step:
self.accelerator
# print above the progress bar
if self.progress_bar is not None:
self.progress_bar.pause()
print_acc(f"\nSaving at step {self.step_num}")
self.save(self.step_num)
self.ensure_params_requires_grad()
# clear any grads
optimizer.zero_grad()
flush()
flush_next = True
if self.progress_bar is not None:
self.progress_bar.unpause()
if self.logging_config.log_every and self.step_num % self.logging_config.log_every == 0:
if self.progress_bar is not None:
self.progress_bar.pause()

View File

@@ -1 +1 @@
VERSION = "0.5.5"
VERSION = "0.5.6"