diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index 393ba831..b4f768d9 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -629,7 +629,10 @@ class BaseSDTrainProcess(BaseTrainProcess): try: filename = f'optimizer.pt' file_path = os.path.join(self.save_root, filename) - state_dict = unwrap_model(self.optimizer).state_dict() + try: + state_dict = unwrap_model(self.optimizer).state_dict() + except Exception as e: + state_dict = self.optimizer.state_dict() torch.save(state_dict, file_path) print_acc(f"Saved optimizer to {file_path}") except Exception as e: