mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Switch order to save first, then sample.
This commit is contained in:
@@ -2189,6 +2189,22 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
if self.step_num != self.start_step:
|
||||
if is_sample_step or is_save_step:
|
||||
self.accelerator.wait_for_everyone()
|
||||
|
||||
if is_save_step:
|
||||
self.accelerator
|
||||
# print above the progress bar
|
||||
if self.progress_bar is not None:
|
||||
self.progress_bar.pause()
|
||||
print_acc(f"\nSaving at step {self.step_num}")
|
||||
self.save(self.step_num)
|
||||
self.ensure_params_requires_grad()
|
||||
# clear any grads
|
||||
optimizer.zero_grad()
|
||||
flush()
|
||||
flush_next = True
|
||||
if self.progress_bar is not None:
|
||||
self.progress_bar.unpause()
|
||||
|
||||
if is_sample_step:
|
||||
if self.progress_bar is not None:
|
||||
self.progress_bar.pause()
|
||||
@@ -2206,21 +2222,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
if self.progress_bar is not None:
|
||||
self.progress_bar.unpause()
|
||||
|
||||
if is_save_step:
|
||||
self.accelerator
|
||||
# print above the progress bar
|
||||
if self.progress_bar is not None:
|
||||
self.progress_bar.pause()
|
||||
print_acc(f"\nSaving at step {self.step_num}")
|
||||
self.save(self.step_num)
|
||||
self.ensure_params_requires_grad()
|
||||
# clear any grads
|
||||
optimizer.zero_grad()
|
||||
flush()
|
||||
flush_next = True
|
||||
if self.progress_bar is not None:
|
||||
self.progress_bar.unpause()
|
||||
|
||||
if self.logging_config.log_every and self.step_num % self.logging_config.log_every == 0:
|
||||
if self.progress_bar is not None:
|
||||
self.progress_bar.pause()
|
||||
|
||||
@@ -1 +1 @@
|
||||
VERSION = "0.5.5"
|
||||
VERSION = "0.5.6"
|
||||
Reference in New Issue
Block a user