diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index 43508d4f..20a85e55 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -2189,6 +2189,22 @@ class BaseSDTrainProcess(BaseTrainProcess): if self.step_num != self.start_step: if is_sample_step or is_save_step: self.accelerator.wait_for_everyone() + + if is_save_step: + self.accelerator + # print above the progress bar + if self.progress_bar is not None: + self.progress_bar.pause() + print_acc(f"\nSaving at step {self.step_num}") + self.save(self.step_num) + self.ensure_params_requires_grad() + # clear any grads + optimizer.zero_grad() + flush() + flush_next = True + if self.progress_bar is not None: + self.progress_bar.unpause() + if is_sample_step: if self.progress_bar is not None: self.progress_bar.pause() @@ -2206,21 +2222,6 @@ class BaseSDTrainProcess(BaseTrainProcess): if self.progress_bar is not None: self.progress_bar.unpause() - if is_save_step: - self.accelerator - # print above the progress bar - if self.progress_bar is not None: - self.progress_bar.pause() - print_acc(f"\nSaving at step {self.step_num}") - self.save(self.step_num) - self.ensure_params_requires_grad() - # clear any grads - optimizer.zero_grad() - flush() - flush_next = True - if self.progress_bar is not None: - self.progress_bar.unpause() - if self.logging_config.log_every and self.step_num % self.logging_config.log_every == 0: if self.progress_bar is not None: self.progress_bar.pause() diff --git a/version.py b/version.py index d613b95e..e7b5a5b6 100644 --- a/version.py +++ b/version.py @@ -1 +1 @@ -VERSION = "0.5.5" \ No newline at end of file +VERSION = "0.5.6" \ No newline at end of file