Allow the user to set the attention backend. Add a method to recover from the occasional OOM if it is a rare event. Still exit if it OOMs 3 times in a row.

Jaret Burkett
2025-09-27 08:56:15 -06:00
parent 6da417261c
commit 3b1f7b0948
2 changed files with 40 additions and 10 deletions


@@ -266,6 +266,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
         self.current_boundary_index = 0
         self.steps_this_boundary = 0
+        self.num_consecutive_oom = 0
     def post_process_generate_image_config_list(self, generate_image_config_list: List[GenerateImageConfig]):
         # override in subclass
@@ -1592,6 +1593,19 @@ class BaseSDTrainProcess(BaseTrainProcess):
                     # if it has it
                     if hasattr(te, 'enable_xformers_memory_efficient_attention'):
                         te.enable_xformers_memory_efficient_attention()
+        if self.train_config.attention_backend != 'native':
+            if hasattr(vae, 'set_attention_backend'):
+                vae.set_attention_backend(self.train_config.attention_backend)
+            if hasattr(unet, 'set_attention_backend'):
+                unet.set_attention_backend(self.train_config.attention_backend)
+            if isinstance(text_encoder, list):
+                for te in text_encoder:
+                    if hasattr(te, 'set_attention_backend'):
+                        te.set_attention_backend(self.train_config.attention_backend)
+            else:
+                if hasattr(text_encoder, 'set_attention_backend'):
+                    text_encoder.set_attention_backend(self.train_config.attention_backend)
         if self.train_config.sdp:
             torch.backends.cuda.enable_math_sdp(True)
             torch.backends.cuda.enable_flash_sdp(True)
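The hasattr() guards above lean on the set_attention_backend() method that newer diffusers model classes expose (see the doc link added to TrainConfig below); older releases simply skip the call. A minimal sketch of the same probe-then-set pattern, with the model argument and the 'flash' example value being illustrative rather than taken from this commit:

def apply_attention_backend(model, backend: str) -> bool:
    # 'native' means "leave PyTorch SDPA alone", so there is nothing to switch.
    if backend == 'native':
        return False
    # Older diffusers versions (or plain nn.Modules) have no attention dispatcher.
    if not hasattr(model, 'set_attention_backend'):
        return False
    model.set_attention_backend(backend)  # raises if the chosen kernel isn't installed
    return True

# e.g. apply_attention_backend(unet, 'flash')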
@@ -2137,17 +2151,31 @@ class BaseSDTrainProcess(BaseTrainProcess):
             ### HOOK ###
             if self.torch_profiler is not None:
                 self.torch_profiler.start()
-            with self.accelerator.accumulate(self.modules_being_trained):
-                try:
+            did_oom = False
+            try:
+                with self.accelerator.accumulate(self.modules_being_trained):
                     loss_dict = self.hook_train_loop(batch_list)
-                except Exception as e:
-                    traceback.print_exc()
-                    #print batch info
-                    print("Batch Items:")
-                    for batch in batch_list:
-                        for item in batch.file_items:
-                            print(f" - {item.path}")
-                    raise e
+            except torch.cuda.OutOfMemoryError:
+                did_oom = True
+            except RuntimeError as e:
+                if "CUDA out of memory" in str(e):
+                    did_oom = True
+                else:
+                    raise  # not an OOM; surface real errors
+            if did_oom:
+                self.num_consecutive_oom += 1
+                if self.num_consecutive_oom > 3:
+                    raise RuntimeError("OOM during training step 3 times in a row, aborting training")
+                optimizer.zero_grad(set_to_none=True)
+                flush()
+                torch.cuda.ipc_collect()
+                # skip this step and keep going
+                print_acc("")
+                print_acc("################################################")
+                print_acc(f"# OOM during training step, skipping batch {self.num_consecutive_oom}/3 #")
+                print_acc("################################################")
+                print_acc("")
+            self.num_consecutive_oom = 0
             if self.torch_profiler is not None:
                 torch.cuda.synchronize()  # Make sure all CUDA ops are done
                 self.torch_profiler.stop()
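Pulled out of the training loop, the recovery amounts to a "skip the batch, but only so many times in a row" guard. A minimal sketch under the same assumptions (names here are illustrative; flush() in the diff is assumed to clear the CUDA cache and is approximated below with torch.cuda.empty_cache()), resetting the streak only after a clean step as the commit message intends:

import torch

class OomGuard:
    """Skip a rare CUDA OOM during a training step; abort once it repeats too often."""

    def __init__(self, max_consecutive: int = 3):
        self.max_consecutive = max_consecutive
        self.consecutive = 0

    def run_step(self, step_fn, optimizer):
        try:
            loss = step_fn()  # forward + backward for one batch
        except torch.cuda.OutOfMemoryError:
            self.consecutive += 1
            if self.consecutive > self.max_consecutive:
                raise  # persistent OOM: stop training instead of spinning
            optimizer.zero_grad(set_to_none=True)  # drop partial grads from the failed step
            torch.cuda.empty_cache()               # hand cached blocks back to the allocator
            torch.cuda.ipc_collect()
            return None  # skip this batch and keep going
        self.consecutive = 0  # a clean step resets the streak
        return loss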


@@ -357,6 +357,8 @@ class TrainConfig:
         self.dtype: str = kwargs.get('dtype', 'fp32')
         self.xformers = kwargs.get('xformers', False)
         self.sdp = kwargs.get('sdp', False)
+        # see https://huggingface.co/docs/diffusers/main/optimization/attention_backends#available-backends for options
+        self.attention_backend: str = kwargs.get('attention_backend', 'native')  # native, flash, _flash_3_hub, _flash_3,
         self.train_unet = kwargs.get('train_unet', True)
         self.train_text_encoder = kwargs.get('train_text_encoder', False)
         self.train_refiner = kwargs.get('train_refiner', True)
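Since TrainConfig reads everything out of kwargs, the new option only needs to appear in whatever dict the train section of a job config is parsed into; a small sketch of that assumed construction path (the 'flash' value is one of the backend names from the linked diffusers docs, picked purely as an example):

from toolkit.config_modules import TrainConfig  # assumed module path

# The parsed `train:` block is splatted into TrainConfig(**kwargs), so unknown keys
# simply fall back to their defaults ('native' leaves the new code path disabled).
train_kwargs = {
    'dtype': 'bf16',
    'attention_backend': 'flash',  # anything other than 'native' triggers set_attention_backend()
}
config = TrainConfig(**train_kwargs)
assert config.attention_backend == 'flash'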