mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-28 18:21:16 +00:00
Prevent lycoris network moduels if not training that part of network. Skew timesteps to favor later steps. It performs better
This commit is contained in:
@@ -437,6 +437,19 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
else:
|
else:
|
||||||
min_timestep = self.train_config.min_denoising_steps
|
min_timestep = self.train_config.min_denoising_steps
|
||||||
|
|
||||||
|
# todo improve this, but is skews odds for higher timesteps
|
||||||
|
# 50% chance to use midpoint as the min_time_step
|
||||||
|
mid_point = (self.train_config.max_denoising_steps + min_timestep) / 2
|
||||||
|
if torch.rand(1) > 0.5:
|
||||||
|
min_timestep = mid_point
|
||||||
|
|
||||||
|
# 50% chance to use midpoint as the min_time_step
|
||||||
|
mid_point = (self.train_config.max_denoising_steps + min_timestep) / 2
|
||||||
|
if torch.rand(1) > 0.5:
|
||||||
|
min_timestep = mid_point
|
||||||
|
|
||||||
|
min_timestep = int(min_timestep)
|
||||||
|
|
||||||
timesteps = torch.randint(
|
timesteps = torch.randint(
|
||||||
min_timestep,
|
min_timestep,
|
||||||
self.train_config.max_denoising_steps,
|
self.train_config.max_denoising_steps,
|
||||||
|
|||||||
@@ -154,6 +154,8 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork):
|
|||||||
conv_alpha: Optional[float] = None,
|
conv_alpha: Optional[float] = None,
|
||||||
use_cp: Optional[bool] = False,
|
use_cp: Optional[bool] = False,
|
||||||
network_module: Type[object] = LoConSpecialModule,
|
network_module: Type[object] = LoConSpecialModule,
|
||||||
|
train_unet: bool = True,
|
||||||
|
train_text_encoder: bool = True,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> None:
|
) -> None:
|
||||||
# call ToolkitNetworkMixin super
|
# call ToolkitNetworkMixin super
|
||||||
@@ -170,6 +172,8 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork):
|
|||||||
rank_dropout = 0
|
rank_dropout = 0
|
||||||
if module_dropout is None:
|
if module_dropout is None:
|
||||||
module_dropout = 0
|
module_dropout = 0
|
||||||
|
self.train_unet = train_unet
|
||||||
|
self.train_text_encoder = train_text_encoder
|
||||||
|
|
||||||
self.torch_multiplier = None
|
self.torch_multiplier = None
|
||||||
# triggers a tensor update
|
# triggers a tensor update
|
||||||
@@ -326,16 +330,19 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork):
|
|||||||
use_index = False
|
use_index = False
|
||||||
|
|
||||||
self.text_encoder_loras = []
|
self.text_encoder_loras = []
|
||||||
for i, te in enumerate(text_encoders):
|
if self.train_text_encoder:
|
||||||
self.text_encoder_loras.extend(create_modules(
|
for i, te in enumerate(text_encoders):
|
||||||
LycorisSpecialNetwork.LORA_PREFIX_TEXT_ENCODER + (f'{i + 1}' if use_index else ''),
|
self.text_encoder_loras.extend(create_modules(
|
||||||
te,
|
LycorisSpecialNetwork.LORA_PREFIX_TEXT_ENCODER + (f'{i + 1}' if use_index else ''),
|
||||||
LycorisSpecialNetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
|
te,
|
||||||
))
|
LycorisSpecialNetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
|
||||||
|
))
|
||||||
print(f"create LyCORIS for Text Encoder: {len(self.text_encoder_loras)} modules.")
|
print(f"create LyCORIS for Text Encoder: {len(self.text_encoder_loras)} modules.")
|
||||||
|
if self.train_unet:
|
||||||
self.unet_loras = create_modules(LycorisSpecialNetwork.LORA_PREFIX_UNET, unet,
|
self.unet_loras = create_modules(LycorisSpecialNetwork.LORA_PREFIX_UNET, unet,
|
||||||
LycorisSpecialNetwork.UNET_TARGET_REPLACE_MODULE)
|
LycorisSpecialNetwork.UNET_TARGET_REPLACE_MODULE)
|
||||||
|
else:
|
||||||
|
self.unet_loras = []
|
||||||
print(f"create LyCORIS for U-Net: {len(self.unet_loras)} modules.")
|
print(f"create LyCORIS for U-Net: {len(self.unet_loras)} modules.")
|
||||||
|
|
||||||
self.weights_sd = None
|
self.weights_sd = None
|
||||||
|
|||||||
Reference in New Issue
Block a user