Mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-01-26 16:39:47 +00:00)
Allow the user to set the attention backend. Add a method to recover from the occasional OOM if it is a rare event. Still exit if it OOMs 3 times in a row.
@@ -266,6 +266,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
         self.current_boundary_index = 0
         self.steps_this_boundary = 0
+        self.num_consecutive_oom = 0

     def post_process_generate_image_config_list(self, generate_image_config_list: List[GenerateImageConfig]):
         # override in subclass
@@ -1592,6 +1593,19 @@ class BaseSDTrainProcess(BaseTrainProcess):
                     # if it has it
                     if hasattr(te, 'enable_xformers_memory_efficient_attention'):
                         te.enable_xformers_memory_efficient_attention()
+        if self.train_config.attention_backend != 'native':
+            if hasattr(vae, 'set_attention_backend'):
+                vae.set_attention_backend(self.train_config.attention_backend)
+            if hasattr(unet, 'set_attention_backend'):
+                unet.set_attention_backend(self.train_config.attention_backend)
+            if isinstance(text_encoder, list):
+                for te in text_encoder:
+                    if hasattr(te, 'set_attention_backend'):
+                        te.set_attention_backend(self.train_config.attention_backend)
+            else:
+                if hasattr(text_encoder, 'set_attention_backend'):
+                    text_encoder.set_attention_backend(self.train_config.attention_backend)
+
         if self.train_config.sdp:
             torch.backends.cuda.enable_math_sdp(True)
             torch.backends.cuda.enable_flash_sdp(True)
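
For context, the guarded calls above lean on the diffusers-side set_attention_backend hook. Below is a minimal sketch of the same pattern applied to a single model outside the trainer; it assumes a recent diffusers release that exposes ModelMixin.set_attention_backend and an installed flash-attn, and the hub checkpoint id is only a placeholder:

import torch
from diffusers import UNet2DConditionModel

# Placeholder checkpoint; any diffusers model inheriting ModelMixin works the same way.
unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", torch_dtype=torch.float16
)

backend = "flash"  # e.g. taken from train_config.attention_backend
if backend != "native" and hasattr(unet, "set_attention_backend"):
    # Routes this model's attention through the requested kernel; older
    # diffusers versions without the hook simply skip this branch.
    unet.set_attention_backend(backend)
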
@@ -2137,17 +2151,31 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 ### HOOK ###
                 if self.torch_profiler is not None:
                     self.torch_profiler.start()
-                with self.accelerator.accumulate(self.modules_being_trained):
-                    try:
+                did_oom = False
+                try:
+                    with self.accelerator.accumulate(self.modules_being_trained):
                         loss_dict = self.hook_train_loop(batch_list)
-                    except Exception as e:
-                        traceback.print_exc()
-                        # print batch info
-                        print("Batch Items:")
-                        for batch in batch_list:
-                            for item in batch.file_items:
-                                print(f" - {item.path}")
-                        raise e
+                except torch.cuda.OutOfMemoryError:
+                    did_oom = True
+                except RuntimeError as e:
+                    if "CUDA out of memory" in str(e):
+                        did_oom = True
+                    else:
+                        raise  # not an OOM; surface real errors
+                if did_oom:
+                    self.num_consecutive_oom += 1
+                    if self.num_consecutive_oom > 3:
+                        raise RuntimeError("OOM during training step 3 times in a row, aborting training")
+                    optimizer.zero_grad(set_to_none=True)
+                    flush()
+                    torch.cuda.ipc_collect()
+                    # skip this step and keep going
+                    print_acc("")
+                    print_acc("################################################")
+                    print_acc(f"# OOM during training step, skipping batch {self.num_consecutive_oom}/3 #")
+                    print_acc("################################################")
+                    print_acc("")
+                else:
+                    self.num_consecutive_oom = 0
                 if self.torch_profiler is not None:
                     torch.cuda.synchronize()  # Make sure all CUDA ops are done
                     self.torch_profiler.stop()
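
The skip-and-count recovery above can be reproduced in a plain PyTorch loop. A minimal self-contained sketch, assuming a CUDA device; the model, optimizer, and batch here are stand-ins rather than ai-toolkit objects:

import torch

model = torch.nn.Linear(1024, 1024).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

consecutive_oom = 0
for step in range(1000):
    batch = torch.randn(8, 1024, device="cuda")
    try:
        loss = model(batch).pow(2).mean()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        consecutive_oom = 0  # a successful step resets the counter
    except torch.cuda.OutOfMemoryError:
        consecutive_oom += 1
        if consecutive_oom > 3:
            raise  # persistent OOM: abort instead of spinning forever
        # Drop partial gradients and cached blocks, then skip this batch.
        optimizer.zero_grad(set_to_none=True)
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
        print(f"OOM on step {step}, skipping batch ({consecutive_oom}/3)")

The counter is only reset by a step that completes, so three OOMs in a row still abort, matching the commit's intent.
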
@@ -357,6 +357,8 @@ class TrainConfig:
         self.dtype: str = kwargs.get('dtype', 'fp32')
         self.xformers = kwargs.get('xformers', False)
         self.sdp = kwargs.get('sdp', False)
+        # see https://huggingface.co/docs/diffusers/main/optimization/attention_backends#available-backends for options
+        self.attention_backend: str = kwargs.get('attention_backend', 'native')  # native, flash, _flash_3_hub, _flash_3,
         self.train_unet = kwargs.get('train_unet', True)
         self.train_text_encoder = kwargs.get('train_text_encoder', False)
         self.train_refiner = kwargs.get('train_refiner', True)
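
For illustration, a sketch of how the new kwarg is consumed, assuming TrainConfig accepts its settings as keyword arguments (as the kwargs.get calls above imply) and that it lives in toolkit.config_modules; both details are assumptions, not confirmed by this diff:

# Hypothetical usage; import path and constructor signature are assumptions.
from toolkit.config_modules import TrainConfig

train_config = TrainConfig(
    dtype='bf16',
    attention_backend='flash',  # any value other than 'native' triggers set_attention_backend
)
print(train_config.attention_backend)  # -> 'flash'
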