diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index dbe1010c..9bb0a171 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -208,7 +208,9 @@ class BaseSDTrainProcess(BaseTrainProcess): if self.network is not None: prev_multiplier = self.network.multiplier self.network.multiplier = 1.0 - # TODO handle dreambooth, fine tuning, etc + if self.network_config.normalize: + # apply the normalization + self.network.apply_stored_normalizer() self.network.save_weights( file_path, dtype=get_torch_dtype(self.save_config.dtype), @@ -323,7 +325,6 @@ class BaseSDTrainProcess(BaseTrainProcess): imgs = imgs.to(self.device_torch, dtype=dtype) latents = self.sd.encode_images(imgs) - self.sd.noise_scheduler.set_timesteps( self.train_config.max_denoising_steps, device=self.device_torch ) @@ -429,6 +430,9 @@ class BaseSDTrainProcess(BaseTrainProcess): if self.train_config.gradient_checkpointing: self.network.enable_gradient_checkpointing() + # set the network to normalize if we are + self.network.is_normalizing = self.network_config.normalize + latest_save_path = self.get_latest_save_path() if latest_save_path is not None: self.print(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####") @@ -522,71 +526,84 @@ class BaseSDTrainProcess(BaseTrainProcess): dataloader_reg = None dataloader_iterator_reg = None + # zero any gradients + optimizer.zero_grad() + # self.step_num = 0 for step in range(self.step_num, self.train_config.steps): - # if is even step and we have a reg dataset, use that - # todo improve this logic to send one of each through if we can buckets and batch size might be an issue - if step % 2 == 0 and dataloader_reg is not None: - try: - batch = next(dataloader_iterator_reg) - except StopIteration: - # hit the end of an epoch, reset - dataloader_iterator_reg = iter(dataloader_reg) - batch = next(dataloader_iterator_reg) - elif dataloader is not None: - try: - batch = next(dataloader_iterator) - 
except StopIteration: - # hit the end of an epoch, reset - dataloader_iterator = iter(dataloader) - batch = next(dataloader_iterator) - else: - batch = None + with torch.no_grad(): + # if is even step and we have a reg dataset, use that + # todo improve this logic to send one of each through if we can buckets and batch size might be an issue + if step % 2 == 0 and dataloader_reg is not None: + try: + batch = next(dataloader_iterator_reg) + except StopIteration: + # hit the end of an epoch, reset + dataloader_iterator_reg = iter(dataloader_reg) + batch = next(dataloader_iterator_reg) + elif dataloader is not None: + try: + batch = next(dataloader_iterator) + except StopIteration: + # hit the end of an epoch, reset + dataloader_iterator = iter(dataloader) + batch = next(dataloader_iterator) + else: + batch = None + + # turn on normalization if we are using it and it is not on + if self.network is not None and self.network_config.normalize and not self.network.is_normalizing: + self.network.is_normalizing = True ### HOOK ### loss_dict = self.hook_train_loop(batch) flush() - if self.train_config.optimizer.lower().startswith('dadaptation') or \ - self.train_config.optimizer.lower().startswith('prodigy'): - learning_rate = ( - optimizer.param_groups[0]["d"] * - optimizer.param_groups[0]["lr"] - ) - else: - learning_rate = optimizer.param_groups[0]['lr'] + with torch.no_grad(): + if self.train_config.optimizer.lower().startswith('dadaptation') or \ + self.train_config.optimizer.lower().startswith('prodigy'): + learning_rate = ( + optimizer.param_groups[0]["d"] * + optimizer.param_groups[0]["lr"] + ) + else: + learning_rate = optimizer.param_groups[0]['lr'] - prog_bar_string = f"lr: {learning_rate:.1e}" - for key, value in loss_dict.items(): - prog_bar_string += f" {key}: {value:.3e}" + prog_bar_string = f"lr: {learning_rate:.1e}" + for key, value in loss_dict.items(): + prog_bar_string += f" {key}: {value:.3e}" - self.progress_bar.set_postfix_str(prog_bar_string) + 
self.progress_bar.set_postfix_str(prog_bar_string) - # don't do on first step - if self.step_num != self.start_step: - # pause progress bar - self.progress_bar.unpause() # makes it so doesn't track time - if self.sample_config.sample_every and self.step_num % self.sample_config.sample_every == 0: - # print above the progress bar - self.sample(self.step_num) + # don't do on first step + if self.step_num != self.start_step: + # pause progress bar + self.progress_bar.unpause() # makes it so doesn't track time + if self.sample_config.sample_every and self.step_num % self.sample_config.sample_every == 0: + # print above the progress bar + self.sample(self.step_num) - if self.save_config.save_every and self.step_num % self.save_config.save_every == 0: - # print above the progress bar - self.print(f"Saving at step {self.step_num}") - self.save(self.step_num) + if self.save_config.save_every and self.step_num % self.save_config.save_every == 0: + # print above the progress bar + self.print(f"Saving at step {self.step_num}") + self.save(self.step_num) - if self.logging_config.log_every and self.step_num % self.logging_config.log_every == 0: - # log to tensorboard - if self.writer is not None: - for key, value in loss_dict.items(): - self.writer.add_scalar(f"{key}", value, self.step_num) - self.writer.add_scalar(f"lr", learning_rate, self.step_num) - self.progress_bar.refresh() + if self.logging_config.log_every and self.step_num % self.logging_config.log_every == 0: + # log to tensorboard + if self.writer is not None: + for key, value in loss_dict.items(): + self.writer.add_scalar(f"{key}", value, self.step_num) + self.writer.add_scalar(f"lr", learning_rate, self.step_num) + self.progress_bar.refresh() - # sets progress bar to match out step - self.progress_bar.update(step - self.progress_bar.n) - # end of step - self.step_num = step + # sets progress bar to match out step + self.progress_bar.update(step - self.progress_bar.n) + # end of step + self.step_num = step + + # 
apply network normalizer if we are using it + if self.network is not None and self.network.is_normalizing: + self.network.apply_stored_normalizer() self.sample(self.step_num + 1) print("") diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index 6d90be9d..5005bee4 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -48,6 +48,7 @@ class NetworkConfig: self.alpha: float = kwargs.get('alpha', 1.0) self.linear_alpha: float = kwargs.get('linear_alpha', self.alpha) self.conv_alpha: float = kwargs.get('conv_alpha', self.conv) + self.normalize = kwargs.get('normalize', False) class EmbeddingConfig: diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py index 82b39e11..7234551b 100644 --- a/toolkit/dataloader_mixins.py +++ b/toolkit/dataloader_mixins.py @@ -95,34 +95,40 @@ class BucketsMixin: # the other dimension should be the same ratio it is now (bigger) new_width = resolution new_height = resolution - new_x = file_item.crop_x - new_y = file_item.crop_y if width > height: # scale width to match new resolution, new_width = int(width * (resolution / height)) + file_item.crop_width = new_width + file_item.scale_to_width = new_width + file_item.crop_height = resolution + file_item.scale_to_height = resolution # make sure new_width is divisible by bucket_tolerance if new_width % bucket_tolerance != 0: # reduce it to the nearest divisible number reduction = new_width % bucket_tolerance - new_width = new_width - reduction + file_item.crop_width = new_width - reduction # adjust the new x position so we evenly crop - new_x = int(new_x + (reduction / 2)) + file_item.crop_x = int(file_item.crop_x + (reduction / 2)) elif height > width: # scale height to match new resolution new_height = int(height * (resolution / width)) + file_item.crop_height = new_height + file_item.scale_to_height = new_height + file_item.scale_to_width = resolution + file_item.crop_width = resolution # make sure new_height is divisible by bucket_tolerance 
if new_height % bucket_tolerance != 0: # reduce it to the nearest divisible number reduction = new_height % bucket_tolerance - new_height = new_height - reduction + file_item.crop_height = new_height - reduction # adjust the new x position so we evenly crop - new_y = int(new_y + (reduction / 2)) - - # add info to file - file_item.crop_x = new_x - file_item.crop_y = new_y - file_item.crop_width = new_width - file_item.crop_height = new_height + file_item.crop_y = int(file_item.crop_y + (reduction / 2)) + else: + # square image + file_item.crop_height = resolution + file_item.scale_to_height = resolution + file_item.scale_to_width = resolution + file_item.crop_width = resolution # check if bucket exists, if not, create it bucket_key = f'{new_width}x{new_height}' diff --git a/toolkit/lora_special.py b/toolkit/lora_special.py index 79ece460..f13c9e43 100644 --- a/toolkit/lora_special.py +++ b/toolkit/lora_special.py @@ -8,6 +8,7 @@ import torch from transformers import CLIPTextModel from .paths import SD_SCRIPTS_ROOT +from .train_tools import get_torch_dtype sys.path.append(SD_SCRIPTS_ROOT) @@ -78,6 +79,8 @@ class LoRAModule(torch.nn.Module): self.rank_dropout = rank_dropout self.module_dropout = module_dropout self.is_checkpointing = False + self.is_normalizing = False + self.normalize_scaler = 1.0 def apply_to(self): self.org_forward = self.org_module.forward @@ -91,8 +94,8 @@ class LoRAModule(torch.nn.Module): batch_size = lora_up.size(0) # batch will have all negative prompts first and positive prompts second # our multiplier list is for a prompt pair. 
So we need to repeat it for positive and negative prompts - # if there is more than our multiplier, it is liekly a batch size increase, so we need to - # interleve the multipliers + # if there is more than our multiplier, it is likely a batch size increase, so we need to + # interleave the multipliers if isinstance(self.multiplier, list): if len(self.multiplier) == 0: # single item, just return it @@ -153,25 +156,30 @@ class LoRAModule(torch.nn.Module): return lx * multiplier * scale - def create_custom_forward(self): - def custom_forward(*inputs): - return self._call_forward(*inputs) - - return custom_forward - def forward(self, x): org_forwarded = self.org_forward(x) - # TODO this just loses the grad. Not sure why. Probably why no one else is doing it either - # if torch.is_grad_enabled() and self.is_checkpointing and self.training: - # lora_output = checkpoint( - # self.create_custom_forward(), - # x, - # ) - # else: - # lora_output = self._call_forward(x) - lora_output = self._call_forward(x) + if self.is_normalizing: + # get a dim array from orig forward that had index of all dimensions except the batch and channel + + # Calculate the target magnitude for the combined output + orig_max = torch.max(torch.abs(org_forwarded)) + + # Calculate the additional increase in magnitude that lora_output would introduce + potential_max_increase = torch.max(torch.abs(org_forwarded + lora_output) - torch.abs(org_forwarded)) + + epsilon = 1e-6 # Small constant to avoid division by zero + + # Calculate the scaling factor for the lora_output + # to ensure that the potential increase in magnitude doesn't change the original max + normalize_scaler = orig_max / (orig_max + potential_max_increase + epsilon) + + # save the scaler so it can be applied later + self.normalize_scaler = normalize_scaler.clone().detach() + + lora_output *= normalize_scaler + return org_forwarded + lora_output def enable_gradient_checkpointing(self): @@ -180,11 +188,39 @@ class LoRAModule(torch.nn.Module): 
def disable_gradient_checkpointing(self):
         self.is_checkpointing = False
 
+    @torch.no_grad()
+    def apply_stored_normalizer(self, target_normalize_scaler: float = 1.0):
+        """
+        Applies the previous normalization calculation to the module.
+        This must be called before saving or normalization will be lost.
+        It is probably best to call after each batch as well.
+        We just scale the up/down weights to match this vector
+        :return:
+        """
+        # get state dict
+        state_dict = self.state_dict()
+        dtype = state_dict['lora_up.weight'].dtype
+        device = state_dict['lora_up.weight'].device
+
+        # todo should we do this at fp32?
+
+        total_module_scale = torch.tensor(self.normalize_scaler / target_normalize_scaler) \
+            .to(device, dtype=dtype)
+        num_modules_layers = 2  # up and down
+        up_down_scale = torch.pow(total_module_scale, 1.0 / num_modules_layers) \
+            .to(device, dtype=dtype)
+
+        # apply the scaler to the up and down weights
+        for key in state_dict.keys():
+            if key.endswith('.lora_up.weight') or key.endswith('.lora_down.weight'):
+                # do it in place so params are updated
+                state_dict[key] *= up_down_scale
+
+        # reset the normalization scaler
+        self.normalize_scaler = target_normalize_scaler
+
 
 class LoRASpecialNetwork(LoRANetwork):
-    _multiplier: float = 1.0
-    is_active: bool = False
-
     NUM_OF_BLOCKS = 12  # フルモデル相当でのup,downの層の数
 
     UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
@@ -230,7 +266,6 @@ class LoRASpecialNetwork(LoRANetwork):
         """
         # call the parent of the parent we are replacing (LoRANetwork) init
         super(LoRANetwork, self).__init__()
-        self.multiplier = multiplier
         self.lora_dim = lora_dim
         self.alpha = alpha
@@ -240,6 +275,11 @@ class LoRASpecialNetwork(LoRANetwork):
         self.rank_dropout = rank_dropout
         self.module_dropout = module_dropout
         self.is_checkpointing = False
+        self._multiplier: float = 1.0
+        self.is_active: bool = False
+        self._is_normalizing: bool = False
+        # triggers the state updates
+        self.multiplier = multiplier
 
         if modules_dim is not None:
             print(f"create LoRA 
network from weights") @@ -451,21 +491,20 @@ class LoRASpecialNetwork(LoRANetwork): for lora in loras: lora.to(device, dtype) + def get_all_modules(self): + loras = [] + if hasattr(self, 'unet_loras'): + loras += self.unet_loras + if hasattr(self, 'text_encoder_loras'): + loras += self.text_encoder_loras + return loras + def _update_checkpointing(self): - if self.is_checkpointing: - if hasattr(self, 'unet_loras'): - for lora in self.unet_loras: - lora.enable_gradient_checkpointing() - if hasattr(self, 'text_encoder_loras'): - for lora in self.text_encoder_loras: - lora.enable_gradient_checkpointing() - else: - if hasattr(self, 'unet_loras'): - for lora in self.unet_loras: - lora.disable_gradient_checkpointing() - if hasattr(self, 'text_encoder_loras'): - for lora in self.text_encoder_loras: - lora.disable_gradient_checkpointing() + for module in self.get_all_modules(): + if self.is_checkpointing: + module.enable_gradient_checkpointing() + else: + module.disable_gradient_checkpointing() def enable_gradient_checkpointing(self): # not supported @@ -476,3 +515,17 @@ class LoRASpecialNetwork(LoRANetwork): # not supported self.is_checkpointing = False self._update_checkpointing() + + @property + def is_normalizing(self) -> bool: + return self._is_normalizing + + @is_normalizing.setter + def is_normalizing(self, value: bool): + self._is_normalizing = value + for module in self.get_all_modules(): + module.is_normalizing = self._is_normalizing + + def apply_stored_normalizer(self, target_normalize_scaler: float = 1.0): + for module in self.get_all_modules(): + module.apply_stored_normalizer(target_normalize_scaler) diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index 704d3c00..5794a575 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -38,10 +38,13 @@ SD_PREFIX_TEXT_ENCODER2 = "te2" class BlankNetwork: - multiplier = 1.0 - is_active = True def __init__(self): + self.multiplier = 1.0 + self.is_active 
= True + self.is_normalizing = False + + def apply_stored_normalizer(self, target_normalize_scaler: float = 1.0): pass def __enter__(self): @@ -258,6 +261,12 @@ class StableDiffusion: else: network = BlankNetwork() + was_network_normalizing = network.is_normalizing + # apply the normalizer if it is normalizing before inference and disable it + if network.is_normalizing: + network.apply_stored_normalizer() + network.is_normalizing = False + # save current seed state for training rng_state = torch.get_rng_state() cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None @@ -377,6 +386,7 @@ class StableDiffusion: if self.network is not None: self.network.train() self.network.multiplier = start_multiplier + self.network.is_normalizing = was_network_normalizing # self.tokenizer.to(original_device_dict['tokenizer']) def get_latent_noise(