added prompt dropout to happen independently on each TE

This commit is contained in:
Jaret Burkett
2023-11-14 05:26:51 -07:00
parent 7782caa468
commit 4f9cdd916a
7 changed files with 144 additions and 15 deletions

View File

@@ -419,6 +419,16 @@ class BaseSDTrainProcess(BaseTrainProcess):
with open(path_to_save, 'w') as f:
json.dump(json_data, f, indent=4)
# save optimizer
if self.optimizer is not None:
try:
filename = f'optimizer.pt'
file_path = os.path.join(self.save_root, filename)
torch.save(self.optimizer.state_dict(), file_path)
except Exception as e:
print(e)
print("Could not save optimizer")
self.print(f"Saved to {file_path}")
self.clean_up_saves()
self.post_save_hook(file_path)
@@ -1121,6 +1131,19 @@ class BaseSDTrainProcess(BaseTrainProcess):
optimizer_params=self.train_config.optimizer_params)
self.optimizer = optimizer
# check if it exists
optimizer_state_filename = f'optimizer.pt'
optimizer_state_file_path = os.path.join(self.save_root, optimizer_state_filename)
if os.path.exists(optimizer_state_file_path):
# try to load
try:
print(f"Loading optimizer state from {optimizer_state_file_path}")
optimizer_state_dict = torch.load(optimizer_state_file_path)
optimizer.load_state_dict(optimizer_state_dict)
except Exception as e:
print(f"Failed to load optimizer state from {optimizer_state_file_path}")
print(e)
lr_scheduler_params = self.train_config.lr_scheduler_params
# make sure it had bare minimum