diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py
index a528fd45..a4b3cf28 100644
--- a/extensions_built_in/sd_trainer/SDTrainer.py
+++ b/extensions_built_in/sd_trainer/SDTrainer.py
@@ -72,6 +72,7 @@ class SDTrainer(BaseSDTrainProcess):
         # 9.18 gb
         noise = noise.to(self.device_torch, dtype=dtype).detach()
+
         if self.sd.prediction_type == 'v_prediction':
             # v-parameterization training
             target = self.sd.noise_scheduler.get_velocity(noisy_latents, noise, timesteps)
diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index cb1a6c23..5cfa75b6 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -479,9 +479,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
             is_lycoris = False
             # default to LoCON if there are any conv layers or if it is named
             NetworkClass = LoRASpecialNetwork
-            if self.network_config.conv is not None and self.network_config.conv > 0:
-                NetworkClass = LycorisSpecialNetwork
-                is_lycoris = True
             if self.network_config.type.lower() == 'locon' or self.network_config.type.lower() == 'lycoris':
                 NetworkClass = LycorisSpecialNetwork
                 is_lycoris = True
diff --git a/toolkit/lora_special.py b/toolkit/lora_special.py
index e83c468b..2ef1d3aa 100644
--- a/toolkit/lora_special.py
+++ b/toolkit/lora_special.py
@@ -46,9 +46,14 @@ class LoRAModule(ToolkitModuleMixin, torch.nn.Module):
             dropout=None,
             rank_dropout=None,
             module_dropout=None,
+            parent=None,
+            **kwargs
     ):
         """if alpha == 0 or None, alpha is rank (no scaling)."""
-        super().__init__()
+        super().__init__(
+            org_module=org_module,
+            parent=parent
+        )
         self.lora_name = lora_name
         self.scalar = torch.tensor(1.0)
@@ -256,6 +261,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
                             dropout=dropout,
                             rank_dropout=rank_dropout,
                             module_dropout=module_dropout,
+                            parent=module,
                         )
                         loras.append(lora)
             return loras, skipped
diff --git a/toolkit/lycoris_special.py b/toolkit/lycoris_special.py
index c71de612..88e98ab3 100644
--- a/toolkit/lycoris_special.py
+++ b/toolkit/lycoris_special.py
@@ -29,17 +29,19 @@ class LoConSpecialModule(ToolkitModuleMixin, LoConModule):
             lora_dim=4, alpha=1,
             dropout=0., rank_dropout=0., module_dropout=0.,
             use_cp=False,
+            parent=None,
             **kwargs,
     ):
         """ if alpha == 0 or None, alpha is rank (no scaling). """
+        # call super of super
+        torch.nn.Module.__init__(self)
         # call super of
         super().__init__(
+            org_module=org_module,
             call_super_init=False,
+            parent=parent,
             **kwargs
         )
-        # call super of super
-        super(LoConModule, self).__init__()
-
         self.lora_name = lora_name
         self.lora_dim = lora_dim
         self.cp = False
@@ -163,7 +165,7 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork):
             **kwargs
         )
         # call the parent of the parent LycorisNetwork
-        super(LycorisNetwork, self).__init__()
+        torch.nn.Module.__init__(self)

         # LyCORIS unique stuff
         if dropout is None:
@@ -176,8 +178,9 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork):
         self.multiplier = multiplier
         self.lora_dim = lora_dim

-        if not self.ENABLE_CONV:
+        if not self.ENABLE_CONV or conv_lora_dim is None:
             conv_lora_dim = 0
+            conv_alpha = 0

         self.conv_lora_dim = int(conv_lora_dim)
         if self.conv_lora_dim and self.conv_lora_dim != self.lora_dim:
@@ -231,6 +234,7 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork):
                                 self.lora_dim, self.alpha,
                                 self.dropout, self.rank_dropout, self.module_dropout,
                                 use_cp,
+                                parent=module,
                                 **kwargs
                             )
                         elif child_module.__class__.__name__ in CONV_MODULES:
@@ -241,6 +245,7 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork):
                                 self.lora_dim, self.alpha,
                                 self.dropout, self.rank_dropout, self.module_dropout,
                                 use_cp,
+                                parent=module,
                                 **kwargs
                             )
                         elif conv_lora_dim > 0:
@@ -249,6 +254,7 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork):
                                 self.conv_lora_dim, self.conv_alpha,
                                 self.dropout, self.rank_dropout, self.module_dropout,
                                 use_cp,
+                                parent=module,
                                 **kwargs
                             )
                         else:
@@ -269,6 +275,7 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork):
                         self.lora_dim, self.alpha,
                         self.dropout, self.rank_dropout, self.module_dropout,
                         use_cp,
+                        parent=module,
                         **kwargs
                     )
                 elif module.__class__.__name__ == 'Conv2d':
@@ -279,6 +286,7 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork):
                         self.lora_dim, self.alpha,
                         self.dropout, self.rank_dropout, self.module_dropout,
                         use_cp,
+                        parent=module,
                         **kwargs
                     )
                 elif conv_lora_dim > 0:
@@ -287,6 +295,7 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork):
                         self.conv_lora_dim, self.conv_alpha,
                         self.dropout, self.rank_dropout, self.module_dropout,
                         use_cp,
+                        parent=module,
                         **kwargs
                     )
                 else:
diff --git a/toolkit/network_mixins.py b/toolkit/network_mixins.py
index 7f314748..ea350542 100644
--- a/toolkit/network_mixins.py
+++ b/toolkit/network_mixins.py
@@ -18,6 +18,30 @@ if TYPE_CHECKING:
 Network = Union['LycorisSpecialNetwork', 'LoRASpecialNetwork']
 Module = Union['LoConSpecialModule', 'LoRAModule']

+LINEAR_MODULES = [
+    'Linear',
+    'LoRACompatibleLinear'
+    # 'GroupNorm',
+]
+CONV_MODULES = [
+    'Conv2d',
+    'LoRACompatibleConv'
+]
+
+
+def broadcast_and_multiply(tensor, multiplier):
+    # Determine the number of dimensions required
+    num_extra_dims = tensor.dim() - multiplier.dim()
+
+    # Unsqueezing the tensor to match the dimensionality
+    for _ in range(num_extra_dims):
+        multiplier = multiplier.unsqueeze(-1)
+
+    # Multiplying the broadcasted tensor with the output tensor
+    result = tensor * multiplier
+
+    return result
+

@@ -28,49 +52,41 @@ class ToolkitModuleMixin:
     def __init__(
             self: Module,
             *args,
             **kwargs
     ):
         if call_super_init:
             super().__init__(*args, **kwargs)
-        self.org_module: torch.nn.Module = kwargs.get('org_module', None)
+        self.tk_orig_module: torch.nn.Module = kwargs.get('org_module', None)
+        self.tk_orig_parent = kwargs.get('parent', None)
         self.is_checkpointing = False
         self.is_normalizing = False
         self.normalize_scaler = 1.0
+        # see if is conv or linear
+        self.is_conv = False
+        self.is_linear = False
+        if self.tk_orig_module.__class__.__name__ in LINEAR_MODULES:
+            self.is_linear = True
+        elif self.tk_orig_module.__class__.__name__ in CONV_MODULES:
+            self.is_conv = True
+
+        self._multiplier: Union[float, list, torch.Tensor] = 1.0

     # this allows us to set different multipliers on a per item in a batch basis
     # allowing us to run positive and negative weights in the same batch
-    # really only useful for slider training for now
-    def get_multiplier(self: Module, lora_up):
+    def set_multiplier(self: Module, multiplier):
+        device = self.lora_down.weight.device
+        dtype = self.lora_down.weight.dtype
         with torch.no_grad():
-            batch_size = lora_up.size(0)
-            # batch will have all negative prompts first and positive prompts second
-            # our multiplier list is for a prompt pair. So we need to repeat it for positive and negative prompts
-            # if there is more than our multiplier, it is likely a batch size increase, so we need to
-            # interleave the multipliers
-            if isinstance(self.multiplier, list):
-                if len(self.multiplier) == 0:
-                    # single item, just return it
-                    return self.multiplier[0]
-                elif len(self.multiplier) == batch_size:
-                    # not doing CFG
-                    multiplier_tensor = torch.tensor(self.multiplier).to(lora_up.device, dtype=lora_up.dtype)
-                else:
+            tensor_multiplier = None
+            if isinstance(multiplier, int) or isinstance(multiplier, float):
+                tensor_multiplier = torch.tensor((multiplier,)).to(device, dtype=dtype)
+            elif isinstance(multiplier, list):
+                tensor_list = []
+                for m in multiplier:
+                    if isinstance(m, int) or isinstance(m, float):
+                        tensor_list.append(torch.tensor((m,)).to(device, dtype=dtype))
+                    elif isinstance(m, torch.Tensor):
+                        tensor_list.append(m.clone().detach().to(device, dtype=dtype))
+                tensor_multiplier = torch.cat(tensor_list)
+            elif isinstance(multiplier, torch.Tensor):
+                tensor_multiplier = multiplier.clone().detach().to(device, dtype=dtype)

-                    # we have a list of multipliers, so we need to get the multiplier for this batch
-                    multiplier_tensor = torch.tensor(self.multiplier * 2).to(lora_up.device, dtype=lora_up.dtype)
-                    # should be 1 for if total batch size was 1
-                    num_interleaves = (batch_size // 2) // len(self.multiplier)
-                    multiplier_tensor = multiplier_tensor.repeat_interleave(num_interleaves)
-
-                    # match lora_up rank
-                    if len(lora_up.size()) == 2:
-                        multiplier_tensor = multiplier_tensor.view(-1, 1)
-                    elif len(lora_up.size()) == 3:
-                        multiplier_tensor = multiplier_tensor.view(-1, 1, 1)
-                    elif len(lora_up.size()) == 4:
-                        multiplier_tensor = multiplier_tensor.view(-1, 1, 1, 1)
-                    return multiplier_tensor.detach()
-
-            else:
-                if isinstance(self.multiplier, torch.Tensor):
-                    return self.multiplier.detach()
-                return self.multiplier
+            self._multiplier = tensor_multiplier.clone().detach()

     def _call_forward(self: Module, x):
         # module dropout
@@ -111,15 +127,26 @@ class ToolkitModuleMixin:
         # handle trainable scaler method locon does
         if hasattr(self, 'scalar'):
-            scale *= self.scalar
+            scale = scale * self.scalar

         return lx * scale

     def forward(self: Module, x):
-        x = x.detach()
+        org_forwarded = self.org_forward(x)
         lora_output = self._call_forward(x)
-        multiplier = self.get_multiplier(lora_output)
+        multiplier = self._multiplier.clone().detach()
+
+        lora_output_batch_size = lora_output.size(0)
+        multiplier_batch_size = multiplier.size(0)
+        if lora_output_batch_size != multiplier_batch_size:
+            print(
+                f"Warning: lora_output_batch_size {lora_output_batch_size} != multiplier_batch_size {multiplier_batch_size}")
+            # doing cfg
+            # should be 1 for if total batch size was 1
+            num_interleaves = (lora_output_batch_size // 2) // multiplier_batch_size
+            multiplier = multiplier.repeat_interleave(num_interleaves)
+        # multiplier = 1.0

         if self.is_normalizing:
             with torch.no_grad():
@@ -150,9 +177,9 @@ class ToolkitModuleMixin:
                 # save the scaler so it can be applied later
                 self.normalize_scaler = normalize_scaler.clone().detach()
-                lora_output *= normalize_scaler
+                lora_output = lora_output * normalize_scaler

-        return org_forwarded + (lora_output * multiplier)
+        return org_forwarded + broadcast_and_multiply(lora_output, multiplier)

     def enable_gradient_checkpointing(self: Module):
         self.is_checkpointing = True
@@ -320,19 +347,11 @@ class ToolkitNetworkMixin:

     def _update_lora_multiplier(self: Network):
         if self.is_active:
-            if hasattr(self, 'unet_loras'):
-                for lora in self.unet_loras:
-                    lora.multiplier = self._multiplier
-            if hasattr(self, 'text_encoder_loras'):
-                for lora in self.text_encoder_loras:
-                    lora.multiplier = self._multiplier
+            for lora in self.get_all_modules():
+                lora.set_multiplier(self._multiplier)
         else:
-            if hasattr(self, 'unet_loras'):
-                for lora in self.unet_loras:
-                    lora.multiplier = 0
-            if hasattr(self, 'text_encoder_loras'):
-                for lora in self.text_encoder_loras:
-                    lora.multiplier = 0
+            for lora in self.get_all_modules():
+                lora.set_multiplier(0)

     # called when the context manager is entered
     # ie: with network:
@@ -369,15 +388,15 @@ class ToolkitNetworkMixin:
             else:
                 module.disable_gradient_checkpointing()

-    # def enable_gradient_checkpointing(self: Network):
-    #     # not supported
-    #     self.is_checkpointing = True
-    #     self._update_checkpointing()
-    #
-    # def disable_gradient_checkpointing(self: Network):
-    #     # not supported
-    #     self.is_checkpointing = False
-    #     self._update_checkpointing()
+    def enable_gradient_checkpointing(self: Network):
+        # not supported
+        self.is_checkpointing = True
+        self._update_checkpointing()
+
+    def disable_gradient_checkpointing(self: Network):
+        # not supported
+        self.is_checkpointing = False
+        self._update_checkpointing()

     @property
     def is_normalizing(self: Network) -> bool: