Mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-01-26 16:39:47 +00:00)
Added initial support for layer offloading with Wan 2.2 14B models.
@@ -2149,6 +2149,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
             if self.torch_profiler is not None:
                 self.torch_profiler.start()
             did_oom = False
+            loss_dict = None
             try:
                 with self.accelerator.accumulate(self.modules_being_trained):
                     loss_dict = self.hook_train_loop(batch_list)
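Note (not part of the commit): pre-initializing loss_dict before the try block matters because an OOM raised inside hook_train_loop would otherwise leave the name unbound for the code that runs after the except. A minimal, illustrative sketch of that pattern, with hypothetical run_step/train_step names:

    import torch

    def run_step(train_step, batch):
        # Sketch only: generic "skip the step on OOM" wiring, names are illustrative.
        loss_dict = None      # bind the name up front so it can be checked afterwards
        did_oom = False
        try:
            loss_dict = train_step(batch)   # may raise torch.cuda.OutOfMemoryError
        except torch.cuda.OutOfMemoryError:
            did_oom = True
            torch.cuda.empty_cache()        # release cached blocks before moving on
        # when the step was skipped, loss_dict is still None and did_oom is True
        return did_oom, loss_dict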
@@ -2172,7 +2173,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 print_acc(f"# OOM during training step, skipping batch {self.num_consecutive_oom}/3 #")
                 print_acc("################################################")
                 print_acc("")
-                self.num_consecutive_oom = 0
+            else:
+                self.num_consecutive_oom = 0
             if self.torch_profiler is not None:
                 torch.cuda.synchronize() # Make sure all CUDA ops are done
                 self.torch_profiler.stop()
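Note (not part of the commit): resetting num_consecutive_oom in the try's else clause means the counter only clears after a fully successful step, so several OOMs in a row still abort. An illustrative, self-contained sketch of that counter logic (OomSkipper and max_oom are made up for the example):

    import torch

    class OomSkipper:
        """Sketch only: skip a step on CUDA OOM, abort after max_oom OOMs in a row."""
        def __init__(self, max_oom: int = 3):
            self.max_oom = max_oom
            self.num_consecutive_oom = 0

        def run(self, train_step, batch):
            try:
                loss = train_step(batch)          # may raise torch.cuda.OutOfMemoryError
            except torch.cuda.OutOfMemoryError:
                self.num_consecutive_oom += 1
                if self.num_consecutive_oom >= self.max_oom:
                    raise                          # too many back-to-back OOMs, give up
                torch.cuda.empty_cache()           # free cached blocks before the next batch
                return None
            else:
                self.num_consecutive_oom = 0       # a successful step resets the streak
                return loss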
@@ -2191,25 +2193,26 @@ class BaseSDTrainProcess(BaseTrainProcess):
             with torch.no_grad():
                 # torch.cuda.empty_cache()
                 # if optimizer has get_lrs method, then use it
-                if hasattr(optimizer, 'get_avg_learning_rate'):
-                    learning_rate = optimizer.get_avg_learning_rate()
-                elif hasattr(optimizer, 'get_learning_rates'):
-                    learning_rate = optimizer.get_learning_rates()[0]
-                elif self.train_config.optimizer.lower().startswith('dadaptation') or \
-                        self.train_config.optimizer.lower().startswith('prodigy'):
-                    learning_rate = (
-                        optimizer.param_groups[0]["d"] *
-                        optimizer.param_groups[0]["lr"]
-                    )
-                else:
-                    learning_rate = optimizer.param_groups[0]['lr']
+                if not did_oom and loss_dict is not None:
+                    if hasattr(optimizer, 'get_avg_learning_rate'):
+                        learning_rate = optimizer.get_avg_learning_rate()
+                    elif hasattr(optimizer, 'get_learning_rates'):
+                        learning_rate = optimizer.get_learning_rates()[0]
+                    elif self.train_config.optimizer.lower().startswith('dadaptation') or \
+                            self.train_config.optimizer.lower().startswith('prodigy'):
+                        learning_rate = (
+                            optimizer.param_groups[0]["d"] *
+                            optimizer.param_groups[0]["lr"]
+                        )
+                    else:
+                        learning_rate = optimizer.param_groups[0]['lr']
 
-                prog_bar_string = f"lr: {learning_rate:.1e}"
-                for key, value in loss_dict.items():
-                    prog_bar_string += f" {key}: {value:.3e}"
+                    prog_bar_string = f"lr: {learning_rate:.1e}"
+                    for key, value in loss_dict.items():
+                        prog_bar_string += f" {key}: {value:.3e}"
 
-                if self.progress_bar is not None:
-                    self.progress_bar.set_postfix_str(prog_bar_string)
+                    if self.progress_bar is not None:
+                        self.progress_bar.set_postfix_str(prog_bar_string)
 
             # if the batch is a DataLoaderBatchDTO, then we need to clean it up
             if isinstance(batch, DataLoaderBatchDTO):
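Note (not part of the commit): the learning-rate readout above differs by optimizer family; adaptive optimizers such as Prodigy and D-Adaptation report an effective rate of d * lr. A hedged standalone helper that mirrors the same branch order (current_learning_rate is an illustrative name, not an ai-toolkit API):

    def current_learning_rate(optimizer, optimizer_name: str) -> float:
        # Sketch only: mirrors the branch order used in the hunk above.
        if hasattr(optimizer, 'get_avg_learning_rate'):
            return optimizer.get_avg_learning_rate()
        if hasattr(optimizer, 'get_learning_rates'):
            return optimizer.get_learning_rates()[0]
        name = optimizer_name.lower()
        if name.startswith('dadaptation') or name.startswith('prodigy'):
            group = optimizer.param_groups[0]
            return group["d"] * group["lr"]   # effective lr = adaptive d * base lr
        return optimizer.param_groups[0]['lr']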