From 298001439a4f6cb34da55eb9a9c9cf6ccbb71617 Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Sat, 28 Oct 2023 13:14:29 -0600
Subject: [PATCH] Added gradient accumulation finally

---
 extensions_built_in/sd_trainer/SDTrainer.py | 17 +++---
 jobs/process/BaseSDTrainProcess.py          | 63 ++++++++++++++++++---
 toolkit/config_modules.py                   |  7 +++
 toolkit/lorm.py                             |  8 +++
 4 files changed, 79 insertions(+), 16 deletions(-)

diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py
index 269e5892..da2baacf 100644
--- a/extensions_built_in/sd_trainer/SDTrainer.py
+++ b/extensions_built_in/sd_trainer/SDTrainer.py
@@ -448,16 +448,17 @@ class SDTrainer(BaseSDTrainProcess):
         # I spent weeks on fighting this. DON'T DO IT
         # with fsdp_overlap_step_with_backward():
         loss.backward()
-
-        torch.nn.utils.clip_grad_norm_(self.params, self.train_config.max_grad_norm)
         # flush()
-        with self.timer('optimizer_step'):
-            # apply gradients
-            self.optimizer.step()
-            self.optimizer.zero_grad(set_to_none=True)
-        with self.timer('scheduler_step'):
-            self.lr_scheduler.step()
+        if not self.is_grad_accumulation_step:
+            torch.nn.utils.clip_grad_norm_(self.params, self.train_config.max_grad_norm)
+            # only step if we are not accumulating
+            with self.timer('optimizer_step'):
+                # apply gradients
+                self.optimizer.step()
+                self.optimizer.zero_grad(set_to_none=True)
+            with self.timer('scheduler_step'):
+                self.lr_scheduler.step()
 
         if self.embedding is not None:
             with self.timer('restore_embeddings'):
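Note: the hunk above gates gradient clipping and the optimizer/scheduler step
behind `is_grad_accumulation_step`, so gradients from several backward passes
sum in `.grad` before a single step is applied. A minimal sketch of the same
pattern in plain PyTorch (`model`, `batches`, and `accum_steps` are
illustrative placeholders, not names from this patch):

    import torch

    model = torch.nn.Linear(8, 1)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
    batches = [torch.randn(4, 8) for _ in range(8)]  # dummy data
    accum_steps = 4

    for i, batch in enumerate(batches):
        loss = model(batch).mean()
        loss.backward()  # gradients accumulate in .grad across iterations
        if (i + 1) % accum_steps == 0:
            # clip and step only at the window boundary, once per accum_steps batches
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)
            scheduler.step()

Since the loss is not divided by the window size here (or in the hunk above),
the applied gradient is the sum over the window rather than the mean, which
effectively scales gradient magnitude with the accumulation window.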
diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index ce8f5b00..53266604 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -60,6 +60,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
         self.custom_pipeline = custom_pipeline
         self.step_num = 0
         self.start_step = 0
+        self.epoch_num = 0
+        # start at 1 so we can do a sample at the start
+        self.grad_accumulation_step = 1
+        # if true, we skip the optimizer step; we are still accumulating gradients
+        self.is_grad_accumulation_step = False
         self.device = self.get_conf('device', self.job.device)
         self.device_torch = torch.device(self.device)
         network_config = self.get_conf('network', None)
@@ -250,7 +255,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
 
     def get_training_info(self):
         info = OrderedDict({
-            'step': self.step_num + 1
+            'step': self.step_num,
+            'epoch': self.epoch_num,
         })
         return info
 
@@ -417,6 +423,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
         # if 'training_info' in Orderdict keys
         if 'training_info' in meta and 'step' in meta['training_info'] and self.train_config.start_step is None:
             self.step_num = meta['training_info']['step']
+            if 'epoch' in meta['training_info']:
+                self.epoch_num = meta['training_info']['epoch']
             self.start_step = self.step_num
             print(f"Found step {self.step_num} in metadata, starting from there")
 
@@ -441,6 +449,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
         # if 'training_info' in Orderdict keys
         if 'training_info' in meta and 'step' in meta['training_info']:
             self.step_num = meta['training_info']['step']
+            if 'epoch' in meta['training_info']:
+                self.epoch_num = meta['training_info']['epoch']
             self.start_step = self.step_num
             print(f"Found step {self.step_num} in metadata, starting from there")
 
@@ -712,6 +722,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
         # if 'training_info' in Orderdict keys
         if 'training_info' in meta and 'step' in meta['training_info']:
             self.step_num = meta['training_info']['step']
+            if 'epoch' in meta['training_info']:
+                self.epoch_num = meta['training_info']['epoch']
             self.start_step = self.step_num
             print(f"Found step {self.step_num} in metadata, starting from there")
 
@@ -1035,7 +1047,16 @@ class BaseSDTrainProcess(BaseTrainProcess):
         self.sd.set_device_state(self.train_device_state_preset)
         flush()
         # self.step_num = 0
-        for step in range(self.step_num, self.train_config.steps):
+
+        ###################################################################
+        # TRAIN LOOP
+        ###################################################################
+
+        start_step_num = self.step_num
+        for step in range(start_step_num, self.train_config.steps):
+            self.step_num = step
+            # default to true so various things can turn it off
+            self.is_grad_accumulation_step = True
             if self.train_config.free_u:
                 self.sd.pipeline.enable_freeu(s1=0.9, s2=0.2, b1=1.1, b2=1.2)
             self.progress_bar.unpause()
@@ -1071,12 +1092,31 @@ class BaseSDTrainProcess(BaseTrainProcess):
                     self.progress_bar.pause()
                     dataloader_iterator = iter(dataloader)
                     trigger_dataloader_setup_epoch(dataloader)
+                    self.epoch_num += 1
+                    if self.train_config.gradient_accumulation_steps == -1:
+                        # if we are accumulating for an entire epoch, trigger a step
+                        self.is_grad_accumulation_step = False
+                        self.grad_accumulation_step = 0
                     with self.timer('get_batch'):
                         batch = next(dataloader_iterator)
                     self.progress_bar.unpause()
             else:
                 batch = None
 
+            # set up accumulation
+            if self.train_config.gradient_accumulation_steps == -1:
+                # the epoch rollover handles the accumulation, don't touch it
+                pass
+            else:
+                # determine whether we are accumulating or stepping.
+                # the optimizer step happens later in this same iteration,
+                # so the boundary check has to fire one step early
+                optimizer_step_at = self.train_config.gradient_accumulation_steps
+                is_optimizer_step = self.grad_accumulation_step >= optimizer_step_at
+                self.is_grad_accumulation_step = not is_optimizer_step
+                if is_optimizer_step:
+                    self.grad_accumulation_step = 0
+
             # flush()
             ### HOOK ###
             self.timer.start('train_loop')
@@ -1144,17 +1184,24 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 # sets progress bar to match our step
                 self.progress_bar.update(step - self.progress_bar.n)
 
-            # end of step
-            self.step_num = step
-            # flush every 10 steps
-            # if self.step_num % 10 == 0:
-            #     flush()
+            #############################
+            # End of step
+            #############################
+
+            # update various steps
+            self.step_num = step + 1
+            self.grad_accumulation_step += 1
+
+
+        ###################################################################
+        ## END TRAIN LOOP
+        ###################################################################
 
         self.progress_bar.close()
         if self.train_config.free_u:
             self.sd.pipeline.disable_freeu()
-        self.sample(self.step_num + 1)
+        self.sample(self.step_num)
         print("")
 
         self.save()
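Note: in the loop changes above, `grad_accumulation_step` is compared with `>=`
and the flag is computed before the train step runs, because the optimizer step
for the current batch happens later in the same iteration, after backward. A
pure-Python trace of that bookkeeping, with an illustrative window of 4
(`grad_accumulation_step` starts at 1, matching the constructor default):

    gradient_accumulation_steps = 4
    grad_accumulation_step = 1  # matches the initial value set in __init__

    for step in range(8):
        is_optimizer_step = grad_accumulation_step >= gradient_accumulation_steps
        is_grad_accumulation_step = not is_optimizer_step
        if is_optimizer_step:
            grad_accumulation_step = 0
        print(step, "optimizer step" if is_optimizer_step else "accumulating")
        grad_accumulation_step += 1  # mirrors the end-of-step increment

This prints "optimizer step" on steps 3 and 7, i.e. once every 4 batches, with
the first window covering the first 4 batches. With
gradient_accumulation_steps == -1 the boundary is instead triggered by the
epoch rollover branch.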
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index b307c3e0..71315b18 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -52,6 +52,7 @@ class LormModuleSettingsConfig:
 class LoRMConfig:
     def __init__(self, **kwargs):
         self.extract_mode: str = kwargs.get('extract_mode', 'ratio')
+        self.do_conv: bool = kwargs.get('do_conv', False)
         self.extract_mode_param: dict = kwargs.get('extract_mode_param', 0.25)
         self.parameter_threshold: int = kwargs.get('parameter_threshold', 0)
         module_settings = kwargs.get('module_settings', [])
@@ -110,6 +111,8 @@ class NetworkConfig:
             # set linear to arbitrary values so it makes them
             self.linear = 4
             self.rank = 4
+            if self.lorm_config.do_conv:
+                self.conv = 4
 
 
 AdapterTypes = Literal['t2i', 'ip', 'ip+']
@@ -177,6 +180,10 @@ class TrainConfig:
         self.noise_multiplier = kwargs.get('noise_multiplier', 1.0)
         self.img_multiplier = kwargs.get('img_multiplier', 1.0)
 
+        # set to -1 to accumulate gradients for an entire epoch.
+        # warning: only do this with a small dataset or you will run out of memory
+        self.gradient_accumulation_steps = kwargs.get('gradient_accumulation_steps', 1)
+
         # short long captions will double your batch size. This only works when a dataset is
         # prepared with a json caption file that has both short and long captions in it. It will
         # Double up every image and run it through with both short and long captions. The idea
diff --git a/toolkit/lorm.py b/toolkit/lorm.py
index 782f1a86..0a432838 100644
--- a/toolkit/lorm.py
+++ b/toolkit/lorm.py
@@ -191,6 +191,8 @@ def extract_conv(
     if lora_rank >= out_ch / 2:
         lora_rank = int(out_ch / 2)
         print(f"rank is higher than it should be")
+        # print(f"Skipping layer as determined rank is too high")
+        # return None, None, None, None
         # return weight, 'full'
 
     U = U[:, :lora_rank]
@@ -243,6 +245,8 @@ def extract_linear(
         # print(f"rank is higher than it should be")
         lora_rank = int(out_ch / 2)
         # return weight, 'full'
+        # print(f"Skipping layer as determined rank is too high")
+        # return None, None, None, None
 
     U = U[:, :lora_rank]
     S = S[:lora_rank]
@@ -358,6 +362,8 @@ def convert_diffusers_unet_to_lorm(
                     mode_param=extract_mode_param,
                     device=child_module.weight.device,
                 )
+                if down_weight is None:
+                    continue
                 down_weight = down_weight.to(dtype=dtype)
                 up_weight = up_weight.to(dtype=dtype)
                 bias_weight = None
@@ -398,6 +404,8 @@ def convert_diffusers_unet_to_lorm(
                     mode_param=extract_mode_param,
                     device=child_module.weight.device,
                 )
+                if down_weight is None:
+                    continue
                 down_weight = down_weight.to(dtype=dtype)
                 up_weight = up_weight.to(dtype=dtype)
                 bias_weight = None
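Note: the new TrainConfig kwarg defaults to 1, which preserves the previous
step-every-batch behavior. A hypothetical usage sketch, assuming TrainConfig's
remaining kwargs all have defaults as in the hunks above (values are examples
only; -1 accumulates across a whole epoch, which the inline comment recommends
only for small datasets):

    from toolkit.config_modules import TrainConfig

    cfg = TrainConfig(gradient_accumulation_steps=4)         # step once per 4 batches
    epoch_cfg = TrainConfig(gradient_accumulation_steps=-1)  # step once per epoch
    default_cfg = TrainConfig()                              # 1: step every batch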