mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Cleanup
This commit is contained in:
@@ -631,6 +631,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
path_to_save = file_path = os.path.join(self.save_root, 'learnable_snr.json')
|
||||
with open(path_to_save, 'w') as f:
|
||||
json.dump(json_data, f, indent=4)
|
||||
|
||||
print_acc(f"Saved checkpoint to {file_path}")
|
||||
|
||||
# save optimizer
|
||||
if self.optimizer is not None:
|
||||
@@ -639,11 +641,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
file_path = os.path.join(self.save_root, filename)
|
||||
state_dict = self.optimizer.state_dict()
|
||||
torch.save(state_dict, file_path)
|
||||
print_acc(f"Saved optimizer to {file_path}")
|
||||
except Exception as e:
|
||||
print_acc(e)
|
||||
print_acc("Could not save optimizer")
|
||||
|
||||
print_acc(f"Saved to {file_path}")
|
||||
self.clean_up_saves()
|
||||
self.post_save_hook(file_path)
|
||||
|
||||
@@ -2095,7 +2097,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
# print above the progress bar
|
||||
if self.progress_bar is not None:
|
||||
self.progress_bar.pause()
|
||||
print_acc(f"Saving at step {self.step_num}")
|
||||
print_acc(f"\nSaving at step {self.step_num}")
|
||||
self.save(self.step_num)
|
||||
self.ensure_params_requires_grad()
|
||||
if self.progress_bar is not None:
|
||||
|
||||
Reference in New Issue
Block a user