Mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-01-27 00:49:47 +00:00)
Prevent creating LyCORIS network modules when that part of the network is not being trained. Skew timesteps to favor later steps; it performs better.
@@ -437,6 +437,19 @@ class BaseSDTrainProcess(BaseTrainProcess):
             else:
                 min_timestep = self.train_config.min_denoising_steps
+
+                # todo improve this, but it skews the odds toward higher timesteps
+                # 50% chance to use the midpoint as the min_timestep
+                mid_point = (self.train_config.max_denoising_steps + min_timestep) / 2
+                if torch.rand(1) > 0.5:
+                    min_timestep = mid_point
+
+                # 50% chance to do it again, compounding the skew
+                mid_point = (self.train_config.max_denoising_steps + min_timestep) / 2
+                if torch.rand(1) > 0.5:
+                    min_timestep = mid_point
+
+                min_timestep = int(min_timestep)
 
             timesteps = torch.randint(
                 min_timestep,
                 self.train_config.max_denoising_steps,
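With two independent 50% checks, the effective lower bound stays at the configured minimum 25% of the time, moves to the halfway point 50% of the time, and moves to the three-quarter point 25% of the time, so sampled timesteps concentrate in the later part of the schedule. Below is a minimal, standalone sketch of that bias; the 0/1000 bounds and the helper name are illustrative only, not taken from the training config.

import torch

# Illustrative bounds; the real values come from train_config.
MIN_STEP, MAX_STEP = 0, 1000

def skewed_min_timestep(min_t: float, max_t: float) -> int:
    # Two independent 50% chances to raise the lower bound to the current
    # midpoint, mirroring the logic in the hunk above.
    for _ in range(2):
        mid_point = (max_t + min_t) / 2
        if torch.rand(1) > 0.5:
            min_t = mid_point
    return int(min_t)

# Count how often sampled timesteps land in each quartile of the schedule.
quartile_counts = torch.zeros(4)
for _ in range(10_000):
    t = torch.randint(skewed_min_timestep(MIN_STEP, MAX_STEP), MAX_STEP, (1,)).item()
    quartile_counts[min(t * 4 // MAX_STEP, 3)] += 1
print(quartile_counts / quartile_counts.sum())  # mass shifts toward the later quartiles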
@@ -154,6 +154,8 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork):
             conv_alpha: Optional[float] = None,
             use_cp: Optional[bool] = False,
             network_module: Type[object] = LoConSpecialModule,
+            train_unet: bool = True,
+            train_text_encoder: bool = True,
             **kwargs,
     ) -> None:
         # call ToolkitNetworkMixin super
@@ -170,6 +172,8 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork):
         rank_dropout = 0
         if module_dropout is None:
             module_dropout = 0
+        self.train_unet = train_unet
+        self.train_text_encoder = train_text_encoder
 
         self.torch_multiplier = None
         # triggers a tensor update
@@ -326,16 +330,19 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork):
             use_index = False
 
         self.text_encoder_loras = []
-        for i, te in enumerate(text_encoders):
-            self.text_encoder_loras.extend(create_modules(
-                LycorisSpecialNetwork.LORA_PREFIX_TEXT_ENCODER + (f'{i + 1}' if use_index else ''),
-                te,
-                LycorisSpecialNetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
-            ))
+        if self.train_text_encoder:
+            for i, te in enumerate(text_encoders):
+                self.text_encoder_loras.extend(create_modules(
+                    LycorisSpecialNetwork.LORA_PREFIX_TEXT_ENCODER + (f'{i + 1}' if use_index else ''),
+                    te,
+                    LycorisSpecialNetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
+                ))
         print(f"create LyCORIS for Text Encoder: {len(self.text_encoder_loras)} modules.")
 
-        self.unet_loras = create_modules(LycorisSpecialNetwork.LORA_PREFIX_UNET, unet,
-                                         LycorisSpecialNetwork.UNET_TARGET_REPLACE_MODULE)
+        if self.train_unet:
+            self.unet_loras = create_modules(LycorisSpecialNetwork.LORA_PREFIX_UNET, unet,
+                                             LycorisSpecialNetwork.UNET_TARGET_REPLACE_MODULE)
+        else:
+            self.unet_loras = []
         print(f"create LyCORIS for U-Net: {len(self.unet_loras)} modules.")
 
         self.weights_sd = None
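The second half of the commit threads `train_unet` / `train_text_encoder` through the constructor and gates module creation in `apply_to`, so a run that freezes one branch no longer builds LyCORIS adapters for it. A condensed sketch of the same gating pattern, with a hypothetical `create_modules` callback and prefixes standing in for the toolkit's internals:

from typing import Callable, List, Sequence, Tuple

def build_adapter_lists(
    create_modules: Callable[[str, object], List[object]],  # hypothetical factory
    text_encoders: Sequence[object],
    unet: object,
    train_text_encoder: bool = True,
    train_unet: bool = True,
) -> Tuple[List[object], List[object]]:
    # Only instantiate adapters for the parts that will actually train;
    # the skipped branch keeps an empty list so downstream loops stay valid.
    text_encoder_loras: List[object] = []
    if train_text_encoder:
        for i, te in enumerate(text_encoders):
            text_encoder_loras.extend(create_modules(f"lora_te{i + 1}", te))
    unet_loras = create_modules("lora_unet", unet) if train_unet else []
    return text_encoder_loras, unet_loras

The diff above applies this directly inside `LycorisSpecialNetwork.apply_to`; the sketch only isolates the control flow.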