This commit is contained in:
Jaret Burkett
2025-04-18 11:44:49 -06:00
parent 1628884254
commit d455e76c4f
3 changed files with 91 additions and 78 deletions

View File

@@ -631,6 +631,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
path_to_save = file_path = os.path.join(self.save_root, 'learnable_snr.json')
with open(path_to_save, 'w') as f:
json.dump(json_data, f, indent=4)
print_acc(f"Saved checkpoint to {file_path}")
# save optimizer
if self.optimizer is not None:
@@ -639,11 +641,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
file_path = os.path.join(self.save_root, filename)
state_dict = self.optimizer.state_dict()
torch.save(state_dict, file_path)
print_acc(f"Saved optimizer to {file_path}")
except Exception as e:
print_acc(e)
print_acc("Could not save optimizer")
print_acc(f"Saved to {file_path}")
self.clean_up_saves()
self.post_save_hook(file_path)
@@ -2095,7 +2097,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
# print above the progress bar
if self.progress_bar is not None:
self.progress_bar.pause()
print_acc(f"Saving at step {self.step_num}")
print_acc(f"\nSaving at step {self.step_num}")
self.save(self.step_num)
self.ensure_params_requires_grad()
if self.progress_bar is not None: