From 121a760c192ecc0eeb0bd89e8e3d6231596488f1 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Tue, 3 Sep 2024 07:24:18 -0600 Subject: [PATCH] Added proper grad accumulation --- extensions_built_in/sd_trainer/SDTrainer.py | 36 +++++++--- jobs/process/BaseSDTrainProcess.py | 77 +++++++++++---------- toolkit/config_modules.py | 9 +++ 3 files changed, 78 insertions(+), 44 deletions(-) diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index 280b4738..df24e435 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -910,7 +910,7 @@ class SDTrainer(BaseSDTrainProcess): **kwargs ) - def hook_train_loop(self, batch: 'DataLoaderBatchDTO'): + def train_single_accumulation(self, batch: DataLoaderBatchDTO): self.timer.start('preprocess_batch') batch = self.preprocess_batch(batch) dtype = get_torch_dtype(self.train_config.dtype) @@ -1243,7 +1243,8 @@ class SDTrainer(BaseSDTrainProcess): pred_kwargs = {} if has_adapter_img: - if (self.adapter and isinstance(self.adapter, T2IAdapter)) or (self.assistant_adapter and isinstance(self.assistant_adapter, T2IAdapter)): + if (self.adapter and isinstance(self.adapter, T2IAdapter)) or ( + self.assistant_adapter and isinstance(self.assistant_adapter, T2IAdapter)): with torch.set_grad_enabled(self.adapter is not None): adapter = self.assistant_adapter if self.assistant_adapter is not None else self.adapter adapter_multiplier = get_adapter_multiplier() @@ -1283,7 +1284,8 @@ class SDTrainer(BaseSDTrainProcess): if self.train_config.do_cfg: embeds = [ - load_file(random.choice(batch.clip_image_embeds_unconditional)) for i in range(noisy_latents.shape[0]) + load_file(random.choice(batch.clip_image_embeds_unconditional)) for i in + range(noisy_latents.shape[0]) ] unconditional_clip_embeds = self.adapter.parse_clip_image_embeds_from_cache( embeds, @@ -1424,7 +1426,6 @@ class SDTrainer(BaseSDTrainProcess): if prior_pred is not None: prior_pred = prior_pred.detach() - # do the custom adapter after the prior prediction if self.adapter and isinstance(self.adapter, CustomAdapter) and has_clip_image: quad_count = random.randint(1, 4) @@ -1450,10 +1451,12 @@ class SDTrainer(BaseSDTrainProcess): self.adapter.add_extra_values(batch.extra_values.detach()) if self.train_config.do_cfg: - self.adapter.add_extra_values(torch.zeros_like(batch.extra_values.detach()), is_unconditional=True) + self.adapter.add_extra_values(torch.zeros_like(batch.extra_values.detach()), + is_unconditional=True) if has_adapter_img: - if (self.adapter and isinstance(self.adapter, ControlNetModel)) or (self.assistant_adapter and isinstance(self.assistant_adapter, ControlNetModel)): + if (self.adapter and isinstance(self.adapter, ControlNetModel)) or ( + self.assistant_adapter and isinstance(self.assistant_adapter, ControlNetModel)): if self.train_config.do_cfg: raise ValueError("ControlNetModel is not supported with CFG") with torch.set_grad_enabled(self.adapter is not None): @@ -1478,7 +1481,6 @@ class SDTrainer(BaseSDTrainProcess): pred_kwargs['down_block_additional_residuals'] = down_block_res_samples pred_kwargs['mid_block_additional_residual'] = mid_block_res_sample - self.before_unet_predict() # do a prior pred if we have an unconditional image, we will swap out the giadance later if batch.unconditional_latents is not None or self.do_guided_loss: @@ -1526,7 +1528,6 @@ class SDTrainer(BaseSDTrainProcess): print("loss is nan") loss = torch.zeros_like(loss).requires_grad_(True) - with self.timer('backward'): # todo we have multiplier seperated. works for now as res are not in same batch, but need to change loss = loss * loss_multiplier.mean() @@ -1543,8 +1544,27 @@ class SDTrainer(BaseSDTrainProcess): loss.backward() else: self.scaler.scale(loss).backward() + + return loss.detach() # flush() + def hook_train_loop(self, batch: Union[DataLoaderBatchDTO, List[DataLoaderBatchDTO]]): + if isinstance(batch, list): + batch_list = batch + else: + batch_list = [batch] + total_loss = None + for batch in batch_list: + self.optimizer.zero_grad(set_to_none=True) + loss = self.train_single_accumulation(batch) + if total_loss is None: + total_loss = loss + else: + total_loss += loss + if len(batch_list) > 1 and self.model_config.low_vram: + torch.cuda.empty_cache() + + if not self.is_grad_accumulation_step: # fix this for multi params if self.train_config.optimizer != 'adafactor': diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index 930a556e..c50a5cdc 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -1648,42 +1648,47 @@ class BaseSDTrainProcess(BaseTrainProcess): is_sample_step = self.sample_config.sample_every and self.step_num % self.sample_config.sample_every == 0 if self.train_config.disable_sampling: is_sample_step = False - # don't do a reg step on sample or save steps as we dont want to normalize on those - if step % 2 == 0 and dataloader_reg is not None and not is_save_step and not is_sample_step: - try: - with self.timer('get_batch:reg'): - batch = next(dataloader_iterator_reg) - except StopIteration: - with self.timer('reset_batch:reg'): - # hit the end of an epoch, reset - self.progress_bar.pause() - dataloader_iterator_reg = iter(dataloader_reg) - trigger_dataloader_setup_epoch(dataloader_reg) - with self.timer('get_batch:reg'): - batch = next(dataloader_iterator_reg) - self.progress_bar.unpause() - is_reg_step = True - elif dataloader is not None: - try: - with self.timer('get_batch'): - batch = next(dataloader_iterator) - except StopIteration: - with self.timer('reset_batch'): - # hit the end of an epoch, reset - self.progress_bar.pause() - dataloader_iterator = iter(dataloader) - trigger_dataloader_setup_epoch(dataloader) - self.epoch_num += 1 - if self.train_config.gradient_accumulation_steps == -1: - # if we are accumulating for an entire epoch, trigger a step - self.is_grad_accumulation_step = False - self.grad_accumulation_step = 0 - with self.timer('get_batch'): - batch = next(dataloader_iterator) - self.progress_bar.unpause() - else: - batch = None + batch_list = [] + + for b in range(self.train_config.gradient_accumulation): + # don't do a reg step on sample or save steps as we dont want to normalize on those + if step % 2 == 0 and dataloader_reg is not None and not is_save_step and not is_sample_step: + try: + with self.timer('get_batch:reg'): + batch = next(dataloader_iterator_reg) + except StopIteration: + with self.timer('reset_batch:reg'): + # hit the end of an epoch, reset + self.progress_bar.pause() + dataloader_iterator_reg = iter(dataloader_reg) + trigger_dataloader_setup_epoch(dataloader_reg) + + with self.timer('get_batch:reg'): + batch = next(dataloader_iterator_reg) + self.progress_bar.unpause() + is_reg_step = True + elif dataloader is not None: + try: + with self.timer('get_batch'): + batch = next(dataloader_iterator) + except StopIteration: + with self.timer('reset_batch'): + # hit the end of an epoch, reset + self.progress_bar.pause() + dataloader_iterator = iter(dataloader) + trigger_dataloader_setup_epoch(dataloader) + self.epoch_num += 1 + if self.train_config.gradient_accumulation_steps == -1: + # if we are accumulating for an entire epoch, trigger a step + self.is_grad_accumulation_step = False + self.grad_accumulation_step = 0 + with self.timer('get_batch'): + batch = next(dataloader_iterator) + self.progress_bar.unpause() + else: + batch = None + batch_list.append(batch) # setup accumulation if self.train_config.gradient_accumulation_steps == -1: @@ -1701,7 +1706,7 @@ class BaseSDTrainProcess(BaseTrainProcess): # flush() ### HOOK ### - loss_dict = self.hook_train_loop(batch) + loss_dict = self.hook_train_loop(batch_list) self.timer.stop('train_loop') if not did_first_flush: flush() diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index deb4d68b..4950868c 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -236,6 +236,7 @@ class TrainConfig: self.min_denoising_steps: int = kwargs.get('min_denoising_steps', 0) self.max_denoising_steps: int = kwargs.get('max_denoising_steps', 1000) self.batch_size: int = kwargs.get('batch_size', 1) + self.orig_batch_size: int = self.batch_size self.dtype: str = kwargs.get('dtype', 'fp32') self.xformers = kwargs.get('xformers', False) self.sdp = kwargs.get('sdp', False) @@ -284,8 +285,16 @@ class TrainConfig: # set to -1 to accumulate gradients for entire epoch # warning, only do this with a small dataset or you will run out of memory + # This is legacy but left in for backwards compatibility self.gradient_accumulation_steps = kwargs.get('gradient_accumulation_steps', 1) + # this will do proper gradient accumulation where you will not see a step until the end of the accumulation + # the method above will show a step every accumulation + self.gradient_accumulation = kwargs.get('gradient_accumulation', 1) + if self.gradient_accumulation > 1: + if self.gradient_accumulation_steps != 1: + raise ValueError("gradient_accumulation and gradient_accumulation_steps are mutually exclusive") + # short long captions will double your batch size. This only works when a dataset is # prepared with a json caption file that has both short and long captions in it. It will # Double up every image and run it through with both short and long captions. The idea