mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-25 16:59:22 +00:00
added prompt dropout to happen independently on each TE
This commit is contained in:
@@ -419,6 +419,16 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
with open(path_to_save, 'w') as f:
|
||||
json.dump(json_data, f, indent=4)
|
||||
|
||||
# save optimizer
|
||||
if self.optimizer is not None:
|
||||
try:
|
||||
filename = f'optimizer.pt'
|
||||
file_path = os.path.join(self.save_root, filename)
|
||||
torch.save(self.optimizer.state_dict(), file_path)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
print("Could not save optimizer")
|
||||
|
||||
self.print(f"Saved to {file_path}")
|
||||
self.clean_up_saves()
|
||||
self.post_save_hook(file_path)
|
||||
@@ -1121,6 +1131,19 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
optimizer_params=self.train_config.optimizer_params)
|
||||
self.optimizer = optimizer
|
||||
|
||||
# check if it exists
|
||||
optimizer_state_filename = f'optimizer.pt'
|
||||
optimizer_state_file_path = os.path.join(self.save_root, optimizer_state_filename)
|
||||
if os.path.exists(optimizer_state_file_path):
|
||||
# try to load
|
||||
try:
|
||||
print(f"Loading optimizer state from {optimizer_state_file_path}")
|
||||
optimizer_state_dict = torch.load(optimizer_state_file_path)
|
||||
optimizer.load_state_dict(optimizer_state_dict)
|
||||
except Exception as e:
|
||||
print(f"Failed to load optimizer state from {optimizer_state_file_path}")
|
||||
print(e)
|
||||
|
||||
lr_scheduler_params = self.train_config.lr_scheduler_params
|
||||
|
||||
# make sure it had bare minimum
|
||||
|
||||
Reference in New Issue
Block a user