From 17e4fe40d722d8f91c759f8997b05fbcd00fb450 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Thu, 14 Sep 2023 15:13:24 -0600 Subject: [PATCH] Prevent lycoris network modules if not training that part of network. Skew timesteps to favor later steps. It performs better --- jobs/process/BaseSDTrainProcess.py | 13 +++++++++++++ toolkit/lycoris_special.py | 25 ++++++++++++++++--------- 2 files changed, 29 insertions(+), 9 deletions(-) diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index 872b7305..5d35ffb3 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -437,6 +437,19 @@ class BaseSDTrainProcess(BaseTrainProcess): else: min_timestep = self.train_config.min_denoising_steps + # todo improve this, but it skews odds for higher timesteps + # 50% chance to use midpoint as the min_time_step + mid_point = (self.train_config.max_denoising_steps + min_timestep) / 2 + if torch.rand(1) > 0.5: + min_timestep = mid_point + + # 50% chance to use midpoint as the min_time_step + mid_point = (self.train_config.max_denoising_steps + min_timestep) / 2 + if torch.rand(1) > 0.5: + min_timestep = mid_point + + min_timestep = int(min_timestep) + timesteps = torch.randint( min_timestep, self.train_config.max_denoising_steps, diff --git a/toolkit/lycoris_special.py b/toolkit/lycoris_special.py index 41bdf719..05df65b3 100644 --- a/toolkit/lycoris_special.py +++ b/toolkit/lycoris_special.py @@ -154,6 +154,8 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork): conv_alpha: Optional[float] = None, use_cp: Optional[bool] = False, network_module: Type[object] = LoConSpecialModule, + train_unet: bool = True, + train_text_encoder: bool = True, **kwargs, ) -> None: # call ToolkitNetworkMixin super @@ -170,6 +172,8 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork): rank_dropout = 0 if module_dropout is None: module_dropout = 0 + self.train_unet = train_unet + self.train_text_encoder = 
train_text_encoder self.torch_multiplier = None # triggers a tensor update @@ -326,16 +330,19 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork): use_index = False self.text_encoder_loras = [] - for i, te in enumerate(text_encoders): - self.text_encoder_loras.extend(create_modules( - LycorisSpecialNetwork.LORA_PREFIX_TEXT_ENCODER + (f'{i + 1}' if use_index else ''), - te, - LycorisSpecialNetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE - )) + if self.train_text_encoder: + for i, te in enumerate(text_encoders): + self.text_encoder_loras.extend(create_modules( + LycorisSpecialNetwork.LORA_PREFIX_TEXT_ENCODER + (f'{i + 1}' if use_index else ''), + te, + LycorisSpecialNetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE + )) print(f"create LyCORIS for Text Encoder: {len(self.text_encoder_loras)} modules.") - - self.unet_loras = create_modules(LycorisSpecialNetwork.LORA_PREFIX_UNET, unet, - LycorisSpecialNetwork.UNET_TARGET_REPLACE_MODULE) + if self.train_unet: + self.unet_loras = create_modules(LycorisSpecialNetwork.LORA_PREFIX_UNET, unet, + LycorisSpecialNetwork.UNET_TARGET_REPLACE_MODULE) + else: + self.unet_loras = [] print(f"create LyCORIS for U-Net: {len(self.unet_loras)} modules.") self.weights_sd = None