mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-01 03:31:35 +00:00
Switch order to save first, then sample.
This commit is contained in:
@@ -2189,6 +2189,22 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
if self.step_num != self.start_step:
|
if self.step_num != self.start_step:
|
||||||
if is_sample_step or is_save_step:
|
if is_sample_step or is_save_step:
|
||||||
self.accelerator.wait_for_everyone()
|
self.accelerator.wait_for_everyone()
|
||||||
|
|
||||||
|
if is_save_step:
|
||||||
|
self.accelerator
|
||||||
|
# print above the progress bar
|
||||||
|
if self.progress_bar is not None:
|
||||||
|
self.progress_bar.pause()
|
||||||
|
print_acc(f"\nSaving at step {self.step_num}")
|
||||||
|
self.save(self.step_num)
|
||||||
|
self.ensure_params_requires_grad()
|
||||||
|
# clear any grads
|
||||||
|
optimizer.zero_grad()
|
||||||
|
flush()
|
||||||
|
flush_next = True
|
||||||
|
if self.progress_bar is not None:
|
||||||
|
self.progress_bar.unpause()
|
||||||
|
|
||||||
if is_sample_step:
|
if is_sample_step:
|
||||||
if self.progress_bar is not None:
|
if self.progress_bar is not None:
|
||||||
self.progress_bar.pause()
|
self.progress_bar.pause()
|
||||||
@@ -2206,21 +2222,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
if self.progress_bar is not None:
|
if self.progress_bar is not None:
|
||||||
self.progress_bar.unpause()
|
self.progress_bar.unpause()
|
||||||
|
|
||||||
if is_save_step:
|
|
||||||
self.accelerator
|
|
||||||
# print above the progress bar
|
|
||||||
if self.progress_bar is not None:
|
|
||||||
self.progress_bar.pause()
|
|
||||||
print_acc(f"\nSaving at step {self.step_num}")
|
|
||||||
self.save(self.step_num)
|
|
||||||
self.ensure_params_requires_grad()
|
|
||||||
# clear any grads
|
|
||||||
optimizer.zero_grad()
|
|
||||||
flush()
|
|
||||||
flush_next = True
|
|
||||||
if self.progress_bar is not None:
|
|
||||||
self.progress_bar.unpause()
|
|
||||||
|
|
||||||
if self.logging_config.log_every and self.step_num % self.logging_config.log_every == 0:
|
if self.logging_config.log_every and self.step_num % self.logging_config.log_every == 0:
|
||||||
if self.progress_bar is not None:
|
if self.progress_bar is not None:
|
||||||
self.progress_bar.pause()
|
self.progress_bar.pause()
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
VERSION = "0.5.5"
|
VERSION = "0.5.6"
|
||||||
Reference in New Issue
Block a user