diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index 01d37b81..7e8d0576 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -266,6 +266,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
 
         self.current_boundary_index = 0
         self.steps_this_boundary = 0
+        self.num_consecutive_oom = 0
 
     def post_process_generate_image_config_list(self, generate_image_config_list: List[GenerateImageConfig]):
         # override in subclass
@@ -1592,6 +1593,19 @@ class BaseSDTrainProcess(BaseTrainProcess):
                     # if it has it
                     if hasattr(te, 'enable_xformers_memory_efficient_attention'):
                         te.enable_xformers_memory_efficient_attention()
+
+        if self.train_config.attention_backend != 'native':
+            if hasattr(vae, 'set_attention_backend'):
+                vae.set_attention_backend(self.train_config.attention_backend)
+            if hasattr(unet, 'set_attention_backend'):
+                unet.set_attention_backend(self.train_config.attention_backend)
+            if isinstance(text_encoder, list):
+                for te in text_encoder:
+                    if hasattr(te, 'set_attention_backend'):
+                        te.set_attention_backend(self.train_config.attention_backend)
+            else:
+                if hasattr(text_encoder, 'set_attention_backend'):
+                    text_encoder.set_attention_backend(self.train_config.attention_backend)
         if self.train_config.sdp:
             torch.backends.cuda.enable_math_sdp(True)
             torch.backends.cuda.enable_flash_sdp(True)
@@ -2137,17 +2151,32 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 ### HOOK ###
                 if self.torch_profiler is not None:
                     self.torch_profiler.start()
-                with self.accelerator.accumulate(self.modules_being_trained):
-                    try:
+                did_oom = False
+                try:
+                    with self.accelerator.accumulate(self.modules_being_trained):
                         loss_dict = self.hook_train_loop(batch_list)
-                    except Exception as e:
-                        traceback.print_exc()
-                        #print batch info
-                        print("Batch Items:")
-                        for batch in batch_list:
-                            for item in batch.file_items:
-                                print(f" - {item.path}")
-                        raise e
+                except torch.cuda.OutOfMemoryError:
+                    did_oom = True
+                except RuntimeError as e:
+                    if "CUDA out of memory" in str(e):
+                        did_oom = True
+                    else:
+                        raise  # not an OOM; surface real errors
+                if did_oom:
+                    self.num_consecutive_oom += 1
+                    if self.num_consecutive_oom > 3:
+                        raise RuntimeError("OOM during training step 3 times in a row, aborting training")
+                    optimizer.zero_grad(set_to_none=True)
+                    flush()
+                    torch.cuda.ipc_collect()
+                    # skip this step and keep going
+                    print_acc("")
+                    print_acc("################################################")
+                    print_acc(f"# OOM during training step, skipping batch {self.num_consecutive_oom}/3 #")
+                    print_acc("################################################")
+                    print_acc("")
+                else:
+                    self.num_consecutive_oom = 0  # clean step; reset the counter
                 if self.torch_profiler is not None:
                     torch.cuda.synchronize()  # Make sure all CUDA ops are done
                     self.torch_profiler.stop()
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index ad6cc1f3..75abb5ff 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -357,6 +357,8 @@ class TrainConfig:
         self.dtype: str = kwargs.get('dtype', 'fp32')
         self.xformers = kwargs.get('xformers', False)
         self.sdp = kwargs.get('sdp', False)
+        # see https://huggingface.co/docs/diffusers/main/optimization/attention_backends#available-backends for options
+        self.attention_backend: str = kwargs.get('attention_backend', 'native')  # e.g. native, flash, _flash_3_hub, _flash_3, ...
         self.train_unet = kwargs.get('train_unet', True)
         self.train_text_encoder = kwargs.get('train_text_encoder', False)
         self.train_refiner = kwargs.get('train_refiner', True)
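
For reviewers: the new OOM handling in the train loop boils down to the pattern below. This is a minimal standalone sketch, not toolkit code: `run_with_oom_skip` and `train_step` are made-up names, and the toolkit's `flush()` helper is approximated here with `torch.cuda.empty_cache()`.

```python
import torch

MAX_CONSECUTIVE_OOM = 3  # same limit the diff uses

def run_with_oom_skip(train_step, optimizer, batches):
    # Run train_step over batches, skipping batches that OOM
    # but aborting after MAX_CONSECUTIVE_OOM back-to-back failures.
    num_consecutive_oom = 0
    for batch in batches:
        try:
            train_step(batch)
            num_consecutive_oom = 0  # reset only after a clean step
        except torch.cuda.OutOfMemoryError:
            num_consecutive_oom += 1
            if num_consecutive_oom > MAX_CONSECUTIVE_OOM:
                raise RuntimeError("OOM during training step 3 times in a row, aborting training")
            optimizer.zero_grad(set_to_none=True)  # drop partial gradients from the failed step
            torch.cuda.empty_cache()               # release cached allocator blocks
            torch.cuda.ipc_collect()               # reclaim inter-process memory handles
```

Resetting the counter only after a successful step is what makes the 3-strike limit meaningful; resetting it on every iteration would cap the counter at 1 and let a persistently OOMing run skip batches forever, which is why the reset lives in the `else:` branch.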
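
The `attention_backend` value is forwarded to diffusers' `set_attention_backend()`, guarded by `hasattr()` so older models and diffusers versions are unaffected. For context, a rough sketch of the underlying diffusers call per the linked docs; the checkpoint is a placeholder and assumes a diffusers release that ships the attention dispatcher:

```python
import torch
from diffusers import FluxTransformer2DModel

# Placeholder checkpoint; any diffusers model inheriting ModelMixin exposes
# set_attention_backend() in recent releases, hence the hasattr() guards above.
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)
transformer.set_attention_backend("flash")  # "flash" requires flash-attn to be installed
```

Since `TrainConfig` reads the option with `kwargs.get('attention_backend', 'native')` and the model-setup code skips the call for `'native'`, existing configs keep the current behavior by default.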