Prevent lycoris network moduels if not training that part of network. Skew timesteps to favor later steps. It performs better

This commit is contained in:
Jaret Burkett
2023-09-14 15:13:24 -06:00
parent 569d7464d5
commit 17e4fe40d7
2 changed files with 29 additions and 9 deletions

View File

@@ -437,6 +437,19 @@ class BaseSDTrainProcess(BaseTrainProcess):
else:
min_timestep = self.train_config.min_denoising_steps
# todo improve this, but is skews odds for higher timesteps
# 50% chance to use midpoint as the min_time_step
mid_point = (self.train_config.max_denoising_steps + min_timestep) / 2
if torch.rand(1) > 0.5:
min_timestep = mid_point
# 50% chance to use midpoint as the min_time_step
mid_point = (self.train_config.max_denoising_steps + min_timestep) / 2
if torch.rand(1) > 0.5:
min_timestep = mid_point
min_timestep = int(min_timestep)
timesteps = torch.randint(
min_timestep,
self.train_config.max_denoising_steps,

View File

@@ -154,6 +154,8 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork):
conv_alpha: Optional[float] = None,
use_cp: Optional[bool] = False,
network_module: Type[object] = LoConSpecialModule,
train_unet: bool = True,
train_text_encoder: bool = True,
**kwargs,
) -> None:
# call ToolkitNetworkMixin super
@@ -170,6 +172,8 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork):
rank_dropout = 0
if module_dropout is None:
module_dropout = 0
self.train_unet = train_unet
self.train_text_encoder = train_text_encoder
self.torch_multiplier = None
# triggers a tensor update
@@ -326,16 +330,19 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork):
use_index = False
self.text_encoder_loras = []
for i, te in enumerate(text_encoders):
self.text_encoder_loras.extend(create_modules(
LycorisSpecialNetwork.LORA_PREFIX_TEXT_ENCODER + (f'{i + 1}' if use_index else ''),
te,
LycorisSpecialNetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
))
if self.train_text_encoder:
for i, te in enumerate(text_encoders):
self.text_encoder_loras.extend(create_modules(
LycorisSpecialNetwork.LORA_PREFIX_TEXT_ENCODER + (f'{i + 1}' if use_index else ''),
te,
LycorisSpecialNetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
))
print(f"create LyCORIS for Text Encoder: {len(self.text_encoder_loras)} modules.")
self.unet_loras = create_modules(LycorisSpecialNetwork.LORA_PREFIX_UNET, unet,
LycorisSpecialNetwork.UNET_TARGET_REPLACE_MODULE)
if self.train_unet:
self.unet_loras = create_modules(LycorisSpecialNetwork.LORA_PREFIX_UNET, unet,
LycorisSpecialNetwork.UNET_TARGET_REPLACE_MODULE)
else:
self.unet_loras = []
print(f"create LyCORIS for U-Net: {len(self.unet_loras)} modules.")
self.weights_sd = None