Added initial support for layer offloading with Wan 2.2 14B models.

Jaret Burkett
2025-10-20 14:54:30 -06:00
parent 8bbaa4e224
commit 76ce757e0c
7 changed files with 93 additions and 50 deletions
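The commit title points at layer offloading for the 14B-parameter Wan 2.2 transformer, but none of the hunks captured below touch that code path. For orientation only, here is a minimal, generic sketch of block-level offloading in PyTorch: park transformer blocks on the CPU and pull each one onto the GPU just for its own forward pass. This is not the toolkit's implementation; attach_block_offloading and wan_transformer.blocks are made-up names, and a real training setup also has to keep offloading consistent with autograd and optimizer state.

    import torch
    import torch.nn as nn

    def attach_block_offloading(blocks: nn.ModuleList, device: torch.device):
        # Hypothetical helper: keep every block on the CPU and use forward
        # hooks to shuttle each block to the GPU only while it is running.
        for block in blocks:
            block.to("cpu")

            def pre_hook(module, args):
                module.to(device, non_blocking=True)   # load just before the block runs

            def post_hook(module, args, output):
                module.to("cpu", non_blocking=True)    # free VRAM for the next block
                return output

            block.register_forward_pre_hook(pre_hook)
            block.register_forward_hook(post_hook)

    # usage (hypothetical model attribute):
    # attach_block_offloading(wan_transformer.blocks, torch.device("cuda"))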


@@ -2149,6 +2149,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
             if self.torch_profiler is not None:
                 self.torch_profiler.start()
             did_oom = False
+            loss_dict = None
             try:
                 with self.accelerator.accumulate(self.modules_being_trained):
                     loss_dict = self.hook_train_loop(batch_list)
@@ -2172,7 +2173,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 print_acc(f"# OOM during training step, skipping batch {self.num_consecutive_oom}/3 #")
                 print_acc("################################################")
                 print_acc("")
-            self.num_consecutive_oom = 0
+            else:
+                self.num_consecutive_oom = 0
             if self.torch_profiler is not None:
                 torch.cuda.synchronize() # Make sure all CUDA ops are done
                 self.torch_profiler.stop()
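The two hunks above are about surviving out-of-memory errors rather than offloading itself: loss_dict is now initialized before the try block so a failed step cannot leave it undefined, and the consecutive-OOM counter is reset only in an else branch, i.e. after a step that actually succeeded. A condensed, standalone sketch of that pattern (not the real BaseSDTrainProcess code; the exception type and cleanup calls are assumptions):

    import torch

    def run_step(train_step, batch_list, optimizer, num_consecutive_oom, max_oom=3):
        # Sketch of the skip-on-OOM pattern; the "/3" in the log above suggests max_oom=3.
        did_oom = False
        loss_dict = None  # pre-initialize so an aborted step can't leak a stale value
        try:
            loss_dict = train_step(batch_list)
        except torch.cuda.OutOfMemoryError:
            did_oom = True
            num_consecutive_oom += 1
            optimizer.zero_grad(set_to_none=True)  # drop any partial gradients
            torch.cuda.empty_cache()
            if num_consecutive_oom >= max_oom:
                raise  # too many back-to-back failures, give up
            print(f"OOM during training step, skipping batch {num_consecutive_oom}/{max_oom}")
        else:
            num_consecutive_oom = 0  # only a clean step resets the counter
        return did_oom, loss_dict, num_consecutive_oom

Downstream consumers (logging, the progress bar) can then key off did_oom or loss_dict is None instead of assuming every step produced a loss.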
@@ -2191,25 +2193,26 @@ class BaseSDTrainProcess(BaseTrainProcess):
             with torch.no_grad():
                 # torch.cuda.empty_cache()
                 # if optimizer has get_lrs method, then use it
-                if hasattr(optimizer, 'get_avg_learning_rate'):
-                    learning_rate = optimizer.get_avg_learning_rate()
-                elif hasattr(optimizer, 'get_learning_rates'):
-                    learning_rate = optimizer.get_learning_rates()[0]
-                elif self.train_config.optimizer.lower().startswith('dadaptation') or \
-                        self.train_config.optimizer.lower().startswith('prodigy'):
-                    learning_rate = (
-                        optimizer.param_groups[0]["d"] *
-                        optimizer.param_groups[0]["lr"]
-                    )
-                else:
-                    learning_rate = optimizer.param_groups[0]['lr']
-                prog_bar_string = f"lr: {learning_rate:.1e}"
-                for key, value in loss_dict.items():
-                    prog_bar_string += f" {key}: {value:.3e}"
-                if self.progress_bar is not None:
-                    self.progress_bar.set_postfix_str(prog_bar_string)
+                if not did_oom and loss_dict is not None:
+                    if hasattr(optimizer, 'get_avg_learning_rate'):
+                        learning_rate = optimizer.get_avg_learning_rate()
+                    elif hasattr(optimizer, 'get_learning_rates'):
+                        learning_rate = optimizer.get_learning_rates()[0]
+                    elif self.train_config.optimizer.lower().startswith('dadaptation') or \
+                            self.train_config.optimizer.lower().startswith('prodigy'):
+                        learning_rate = (
+                            optimizer.param_groups[0]["d"] *
+                            optimizer.param_groups[0]["lr"]
+                        )
+                    else:
+                        learning_rate = optimizer.param_groups[0]['lr']
+                    prog_bar_string = f"lr: {learning_rate:.1e}"
+                    for key, value in loss_dict.items():
+                        prog_bar_string += f" {key}: {value:.3e}"
+                    if self.progress_bar is not None:
+                        self.progress_bar.set_postfix_str(prog_bar_string)
             # if the batch is a DataLoaderBatchDTO, then we need to clean it up
             if isinstance(batch, DataLoaderBatchDTO):
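The last hunk wraps the learning-rate lookup and the progress-bar update in `if not did_oom and loss_dict is not None:`, so a skipped step no longer tries to format a loss that was never computed. The lookup itself is unchanged; pulled out as a standalone helper it reads roughly like this (hypothetical function name, optimizer_name standing in for self.train_config.optimizer):

    def current_learning_rate(optimizer, optimizer_name: str) -> float:
        # Mirror of the selection order in the hunk above.
        if hasattr(optimizer, 'get_avg_learning_rate'):
            return optimizer.get_avg_learning_rate()
        if hasattr(optimizer, 'get_learning_rates'):
            return optimizer.get_learning_rates()[0]
        name = optimizer_name.lower()
        if name.startswith('dadaptation') or name.startswith('prodigy'):
            # D-Adaptation and Prodigy expose their effective step size as d * lr
            group = optimizer.param_groups[0]
            return group["d"] * group["lr"]
        return optimizer.param_groups[0]['lr']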