From 9b164a8688abc206c57f7c90d011e7ce7f3f60ba Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Sun, 27 Aug 2023 09:40:01 -0600 Subject: [PATCH] Fixed issue with bucket dataloader corpping in too much. Added normalization capabilities to LoRA modules. Testing effects, but should prevent them from burning and also make them more compatable with stacking many LoRAs --- jobs/process/BaseSDTrainProcess.py | 125 ++++++++++++++++------------- toolkit/config_modules.py | 1 + toolkit/dataloader_mixins.py | 30 ++++--- toolkit/lora_special.py | 123 ++++++++++++++++++++-------- toolkit/stable_diffusion_model.py | 14 +++- 5 files changed, 190 insertions(+), 103 deletions(-) diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index dbe1010c..9bb0a171 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -208,7 +208,9 @@ class BaseSDTrainProcess(BaseTrainProcess): if self.network is not None: prev_multiplier = self.network.multiplier self.network.multiplier = 1.0 - # TODO handle dreambooth, fine tuning, etc + if self.network_config.normalize: + # apply the normalization + self.network.apply_stored_normalizer() self.network.save_weights( file_path, dtype=get_torch_dtype(self.save_config.dtype), @@ -323,7 +325,6 @@ class BaseSDTrainProcess(BaseTrainProcess): imgs = imgs.to(self.device_torch, dtype=dtype) latents = self.sd.encode_images(imgs) - self.sd.noise_scheduler.set_timesteps( self.train_config.max_denoising_steps, device=self.device_torch ) @@ -429,6 +430,9 @@ class BaseSDTrainProcess(BaseTrainProcess): if self.train_config.gradient_checkpointing: self.network.enable_gradient_checkpointing() + # set the network to normalize if we are + self.network.is_normalizing = self.network_config.normalize + latest_save_path = self.get_latest_save_path() if latest_save_path is not None: self.print(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####") @@ -522,71 +526,84 @@ class BaseSDTrainProcess(BaseTrainProcess): dataloader_reg = None dataloader_iterator_reg = None + # zero any gradients + optimizer.zero_grad() + # self.step_num = 0 for step in range(self.step_num, self.train_config.steps): - # if is even step and we have a reg dataset, use that - # todo improve this logic to send one of each through if we can buckets and batch size might be an issue - if step % 2 == 0 and dataloader_reg is not None: - try: - batch = next(dataloader_iterator_reg) - except StopIteration: - # hit the end of an epoch, reset - dataloader_iterator_reg = iter(dataloader_reg) - batch = next(dataloader_iterator_reg) - elif dataloader is not None: - try: - batch = next(dataloader_iterator) - except StopIteration: - # hit the end of an epoch, reset - dataloader_iterator = iter(dataloader) - batch = next(dataloader_iterator) - else: - batch = None + with torch.no_grad(): + # if is even step and we have a reg dataset, use that + # todo improve this logic to send one of each through if we can buckets and batch size might be an issue + if step % 2 == 0 and dataloader_reg is not None: + try: + batch = next(dataloader_iterator_reg) + except StopIteration: + # hit the end of an epoch, reset + dataloader_iterator_reg = iter(dataloader_reg) + batch = next(dataloader_iterator_reg) + elif dataloader is not None: + try: + batch = next(dataloader_iterator) + except StopIteration: + # hit the end of an epoch, reset + dataloader_iterator = iter(dataloader) + batch = next(dataloader_iterator) + else: + batch = None + + # turn on normalization if we are using it and it is not on + if self.network is not None and self.network_config.normalize and not self.network.is_normalizing: + self.network.is_normalizing = True ### HOOK ### loss_dict = self.hook_train_loop(batch) flush() - if self.train_config.optimizer.lower().startswith('dadaptation') or \ - self.train_config.optimizer.lower().startswith('prodigy'): - learning_rate = ( - optimizer.param_groups[0]["d"] * - optimizer.param_groups[0]["lr"] - ) - else: - learning_rate = optimizer.param_groups[0]['lr'] + with torch.no_grad(): + if self.train_config.optimizer.lower().startswith('dadaptation') or \ + self.train_config.optimizer.lower().startswith('prodigy'): + learning_rate = ( + optimizer.param_groups[0]["d"] * + optimizer.param_groups[0]["lr"] + ) + else: + learning_rate = optimizer.param_groups[0]['lr'] - prog_bar_string = f"lr: {learning_rate:.1e}" - for key, value in loss_dict.items(): - prog_bar_string += f" {key}: {value:.3e}" + prog_bar_string = f"lr: {learning_rate:.1e}" + for key, value in loss_dict.items(): + prog_bar_string += f" {key}: {value:.3e}" - self.progress_bar.set_postfix_str(prog_bar_string) + self.progress_bar.set_postfix_str(prog_bar_string) - # don't do on first step - if self.step_num != self.start_step: - # pause progress bar - self.progress_bar.unpause() # makes it so doesn't track time - if self.sample_config.sample_every and self.step_num % self.sample_config.sample_every == 0: - # print above the progress bar - self.sample(self.step_num) + # don't do on first step + if self.step_num != self.start_step: + # pause progress bar + self.progress_bar.unpause() # makes it so doesn't track time + if self.sample_config.sample_every and self.step_num % self.sample_config.sample_every == 0: + # print above the progress bar + self.sample(self.step_num) - if self.save_config.save_every and self.step_num % self.save_config.save_every == 0: - # print above the progress bar - self.print(f"Saving at step {self.step_num}") - self.save(self.step_num) + if self.save_config.save_every and self.step_num % self.save_config.save_every == 0: + # print above the progress bar + self.print(f"Saving at step {self.step_num}") + self.save(self.step_num) - if self.logging_config.log_every and self.step_num % self.logging_config.log_every == 0: - # log to tensorboard - if self.writer is not None: - for key, value in loss_dict.items(): - self.writer.add_scalar(f"{key}", value, self.step_num) - self.writer.add_scalar(f"lr", learning_rate, self.step_num) - self.progress_bar.refresh() + if self.logging_config.log_every and self.step_num % self.logging_config.log_every == 0: + # log to tensorboard + if self.writer is not None: + for key, value in loss_dict.items(): + self.writer.add_scalar(f"{key}", value, self.step_num) + self.writer.add_scalar(f"lr", learning_rate, self.step_num) + self.progress_bar.refresh() - # sets progress bar to match out step - self.progress_bar.update(step - self.progress_bar.n) - # end of step - self.step_num = step + # sets progress bar to match out step + self.progress_bar.update(step - self.progress_bar.n) + # end of step + self.step_num = step + + # apply network normalizer if we are using it + if self.network is not None and self.network.is_normalizing: + self.network.apply_stored_normalizer() self.sample(self.step_num + 1) print("") diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index 6d90be9d..5005bee4 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -48,6 +48,7 @@ class NetworkConfig: self.alpha: float = kwargs.get('alpha', 1.0) self.linear_alpha: float = kwargs.get('linear_alpha', self.alpha) self.conv_alpha: float = kwargs.get('conv_alpha', self.conv) + self.normalize = kwargs.get('normalize', False) class EmbeddingConfig: diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py index 82b39e11..7234551b 100644 --- a/toolkit/dataloader_mixins.py +++ b/toolkit/dataloader_mixins.py @@ -95,34 +95,40 @@ class BucketsMixin: # the other dimension should be the same ratio it is now (bigger) new_width = resolution new_height = resolution - new_x = file_item.crop_x - new_y = file_item.crop_y if width > height: # scale width to match new resolution, new_width = int(width * (resolution / height)) + file_item.crop_width = new_width + file_item.scale_to_width = new_width + file_item.crop_height = resolution + file_item.scale_to_height = resolution # make sure new_width is divisible by bucket_tolerance if new_width % bucket_tolerance != 0: # reduce it to the nearest divisible number reduction = new_width % bucket_tolerance - new_width = new_width - reduction + file_item.crop_width = new_width - reduction # adjust the new x position so we evenly crop - new_x = int(new_x + (reduction / 2)) + file_item.crop_x = int(file_item.crop_x + (reduction / 2)) elif height > width: # scale height to match new resolution new_height = int(height * (resolution / width)) + file_item.crop_height = new_height + file_item.scale_to_height = new_height + file_item.scale_to_width = resolution + file_item.crop_width = resolution # make sure new_height is divisible by bucket_tolerance if new_height % bucket_tolerance != 0: # reduce it to the nearest divisible number reduction = new_height % bucket_tolerance - new_height = new_height - reduction + file_item.crop_height = new_height - reduction # adjust the new x position so we evenly crop - new_y = int(new_y + (reduction / 2)) - - # add info to file - file_item.crop_x = new_x - file_item.crop_y = new_y - file_item.crop_width = new_width - file_item.crop_height = new_height + file_item.crop_y = int(file_item.crop_y + (reduction / 2)) + else: + # square image + file_item.crop_height = resolution + file_item.scale_to_height = resolution + file_item.scale_to_width = resolution + file_item.crop_width = resolution # check if bucket exists, if not, create it bucket_key = f'{new_width}x{new_height}' diff --git a/toolkit/lora_special.py b/toolkit/lora_special.py index 79ece460..f13c9e43 100644 --- a/toolkit/lora_special.py +++ b/toolkit/lora_special.py @@ -8,6 +8,7 @@ import torch from transformers import CLIPTextModel from .paths import SD_SCRIPTS_ROOT +from .train_tools import get_torch_dtype sys.path.append(SD_SCRIPTS_ROOT) @@ -78,6 +79,8 @@ class LoRAModule(torch.nn.Module): self.rank_dropout = rank_dropout self.module_dropout = module_dropout self.is_checkpointing = False + self.is_normalizing = False + self.normalize_scaler = 1.0 def apply_to(self): self.org_forward = self.org_module.forward @@ -91,8 +94,8 @@ class LoRAModule(torch.nn.Module): batch_size = lora_up.size(0) # batch will have all negative prompts first and positive prompts second # our multiplier list is for a prompt pair. So we need to repeat it for positive and negative prompts - # if there is more than our multiplier, it is liekly a batch size increase, so we need to - # interleve the multipliers + # if there is more than our multiplier, it is likely a batch size increase, so we need to + # interleave the multipliers if isinstance(self.multiplier, list): if len(self.multiplier) == 0: # single item, just return it @@ -153,25 +156,30 @@ class LoRAModule(torch.nn.Module): return lx * multiplier * scale - def create_custom_forward(self): - def custom_forward(*inputs): - return self._call_forward(*inputs) - - return custom_forward - def forward(self, x): org_forwarded = self.org_forward(x) - # TODO this just loses the grad. Not sure why. Probably why no one else is doing it either - # if torch.is_grad_enabled() and self.is_checkpointing and self.training: - # lora_output = checkpoint( - # self.create_custom_forward(), - # x, - # ) - # else: - # lora_output = self._call_forward(x) - lora_output = self._call_forward(x) + if self.is_normalizing: + # get a dim array from orig forward that had index of all dimensions except the batch and channel + + # Calculate the target magnitude for the combined output + orig_max = torch.max(torch.abs(org_forwarded)) + + # Calculate the additional increase in magnitude that lora_output would introduce + potential_max_increase = torch.max(torch.abs(org_forwarded + lora_output) - torch.abs(org_forwarded)) + + epsilon = 1e-6 # Small constant to avoid division by zero + + # Calculate the scaling factor for the lora_output + # to ensure that the potential increase in magnitude doesn't change the original max + normalize_scaler = orig_max / (orig_max + potential_max_increase + epsilon) + + # save the scaler so it can be applied later + self.normalize_scaler = normalize_scaler.clone().detach() + + lora_output *= normalize_scaler + return org_forwarded + lora_output def enable_gradient_checkpointing(self): @@ -180,11 +188,39 @@ class LoRAModule(torch.nn.Module): def disable_gradient_checkpointing(self): self.is_checkpointing = False + @torch.no_grad() + def apply_stored_normalizer(self, target_normalize_scaler: float = 1.0): + """ + Applied the previous normalization calculation to the module. + This must be called before saving or normalization will be lost. + It is probably best to call after each batch as well. + We just scale the up down weights to match this vector + :return: + """ + # get state dict + state_dict = self.state_dict() + dtype = state_dict['lora_up.weight'].dtype + device = state_dict['lora_up.weight'].device + + # todo should we do this at fp32? + + total_module_scale = torch.tensor(self.normalize_scaler / target_normalize_scaler) \ + .to(device, dtype=dtype) + num_modules_layers = 2 # up and down + up_down_scale = torch.pow(total_module_scale, 1.0 / num_modules_layers) \ + .to(device, dtype=dtype) + + # apply the scaler to the up and down weights + for key in state_dict.keys(): + if key.endswith('.lora_up.weight') or key.endswith('.lora_down.weight'): + # do it inplace do params are updated + state_dict[key] *= up_down_scale + + # reset the normalization scaler + self.normalize_scaler = target_normalize_scaler + class LoRASpecialNetwork(LoRANetwork): - _multiplier: float = 1.0 - is_active: bool = False - NUM_OF_BLOCKS = 12 # フルモデル相当でのup,downの層の数 UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"] @@ -230,7 +266,6 @@ class LoRASpecialNetwork(LoRANetwork): """ # call the parent of the parent we are replacing (LoRANetwork) init super(LoRANetwork, self).__init__() - self.multiplier = multiplier self.lora_dim = lora_dim self.alpha = alpha @@ -240,6 +275,11 @@ class LoRASpecialNetwork(LoRANetwork): self.rank_dropout = rank_dropout self.module_dropout = module_dropout self.is_checkpointing = False + self._multiplier: float = 1.0 + self.is_active: bool = False + self._is_normalizing: bool = False + # triggers the state updates + self.multiplier = multiplier if modules_dim is not None: print(f"create LoRA network from weights") @@ -451,21 +491,20 @@ class LoRASpecialNetwork(LoRANetwork): for lora in loras: lora.to(device, dtype) + def get_all_modules(self): + loras = [] + if hasattr(self, 'unet_loras'): + loras += self.unet_loras + if hasattr(self, 'text_encoder_loras'): + loras += self.text_encoder_loras + return loras + def _update_checkpointing(self): - if self.is_checkpointing: - if hasattr(self, 'unet_loras'): - for lora in self.unet_loras: - lora.enable_gradient_checkpointing() - if hasattr(self, 'text_encoder_loras'): - for lora in self.text_encoder_loras: - lora.enable_gradient_checkpointing() - else: - if hasattr(self, 'unet_loras'): - for lora in self.unet_loras: - lora.disable_gradient_checkpointing() - if hasattr(self, 'text_encoder_loras'): - for lora in self.text_encoder_loras: - lora.disable_gradient_checkpointing() + for module in self.get_all_modules(): + if self.is_checkpointing: + module.enable_gradient_checkpointing() + else: + module.disable_gradient_checkpointing() def enable_gradient_checkpointing(self): # not supported @@ -476,3 +515,17 @@ class LoRASpecialNetwork(LoRANetwork): # not supported self.is_checkpointing = False self._update_checkpointing() + + @property + def is_normalizing(self) -> bool: + return self._is_normalizing + + @is_normalizing.setter + def is_normalizing(self, value: bool): + self._is_normalizing = value + for module in self.get_all_modules(): + module.is_normalizing = self._is_normalizing + + def apply_stored_normalizer(self, target_normalize_scaler: float = 1.0): + for module in self.get_all_modules(): + module.apply_stored_normalizer(target_normalize_scaler) diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index 704d3c00..5794a575 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -38,10 +38,13 @@ SD_PREFIX_TEXT_ENCODER2 = "te2" class BlankNetwork: - multiplier = 1.0 - is_active = True def __init__(self): + self.multiplier = 1.0 + self.is_active = True + self.is_normalizing = False + + def apply_stored_normalizer(self, target_normalize_scaler: float = 1.0): pass def __enter__(self): @@ -258,6 +261,12 @@ class StableDiffusion: else: network = BlankNetwork() + was_network_normalizing = network.is_normalizing + # apply the normalizer if it is normalizing before inference and disable it + if network.is_normalizing: + network.apply_stored_normalizer() + network.is_normalizing = False + # save current seed state for training rng_state = torch.get_rng_state() cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None @@ -377,6 +386,7 @@ class StableDiffusion: if self.network is not None: self.network.train() self.network.multiplier = start_multiplier + self.network.is_normalizing = was_network_normalizing # self.tokenizer.to(original_device_dict['tokenizer']) def get_latent_noise(