diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index 50d415af..9b3a5b5c 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -722,7 +722,9 @@ class BaseSDTrainProcess(BaseTrainProcess): conv_alpha=self.network_config.conv_alpha, is_sdxl=self.model_config.is_xl, is_v2=self.model_config.is_v2, - dropout=self.network_config.dropout + dropout=self.network_config.dropout, + use_text_encoder_1=self.model_config.use_text_encoder_1, + use_text_encoder_2=self.model_config.use_text_encoder_2, ) self.network.force_to(self.device_torch, dtype=dtype) diff --git a/toolkit/lora_special.py b/toolkit/lora_special.py index aadaf541..5e599227 100644 --- a/toolkit/lora_special.py +++ b/toolkit/lora_special.py @@ -137,6 +137,8 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork): module_class: Type[object] = LoRAModule, varbose: Optional[bool] = False, train_text_encoder: Optional[bool] = True, + use_text_encoder_1: bool = True, + use_text_encoder_2: bool = True, train_unet: Optional[bool] = True, is_sdxl=False, is_v2=False, @@ -273,6 +275,10 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork): skipped_te = [] if train_text_encoder: for i, text_encoder in enumerate(text_encoders): + if not use_text_encoder_1 and i == 0: + continue + if not use_text_encoder_2 and i == 1: + continue if len(text_encoders) > 1: index = i + 1 print(f"create LoRA for Text Encoder {index}:") diff --git a/toolkit/lycoris_special.py b/toolkit/lycoris_special.py index 36cdd8b7..d007de6c 100644 --- a/toolkit/lycoris_special.py +++ b/toolkit/lycoris_special.py @@ -156,6 +156,8 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork): network_module: Type[object] = LoConSpecialModule, train_unet: bool = True, train_text_encoder: bool = True, + use_text_encoder_1: bool = True, + use_text_encoder_2: bool = True, **kwargs, ) -> None: # call ToolkitNetworkMixin super @@ -332,6 +334,10 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork): self.text_encoder_loras = [] if self.train_text_encoder: for i, te in enumerate(text_encoders): + if not use_text_encoder_1 and i == 0: + continue + if not use_text_encoder_2 and i == 1: + continue self.text_encoder_loras.extend(create_modules( LycorisSpecialNetwork.LORA_PREFIX_TEXT_ENCODER + (f'{i + 1}' if use_index else ''), te, diff --git a/toolkit/sd_device_states_presets.py b/toolkit/sd_device_states_presets.py index fc16fd9e..3cf78017 100644 --- a/toolkit/sd_device_states_presets.py +++ b/toolkit/sd_device_states_presets.py @@ -60,7 +60,7 @@ def get_train_sd_device_state_preset( preset['unet']['training'] = True if train_lora: - preset['text_encoder']['requires_grad'] = False + # preset['text_encoder']['requires_grad'] = False preset['unet']['requires_grad'] = False if train_adapter: diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index f1959beb..8d3ae4c4 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -1,6 +1,7 @@ import copy import gc import json +import random import shutil import typing from typing import Union, List, Literal, Iterator @@ -677,6 +678,14 @@ class StableDiffusion: if prompt2 is not None and not isinstance(prompt2, list): prompt2 = [prompt2] if self.is_xl: + # todo make this a config + # 50% chance to use an encoder anyway even if it is disabled + # allows the other TE to compensate for the disabled one + use_encoder_1 = self.use_text_encoder_1 or force_all or random.random() > 0.5 + use_encoder_2 = self.use_text_encoder_2 or force_all or random.random() > 0.5 + # use_encoder_1 = True + # use_encoder_2 = True + return PromptEmbeds( train_tools.encode_prompts_xl( self.tokenizer, @@ -684,8 +693,8 @@ class StableDiffusion: prompt, prompt2, num_images_per_prompt=num_images_per_prompt, - use_text_encoder_1=self.use_text_encoder_1 or force_all, - use_text_encoder_2=self.use_text_encoder_2 or force_all, + use_text_encoder_1=use_encoder_1, + use_text_encoder_2=use_encoder_2, ) ) else: @@ -831,6 +840,13 @@ class StableDiffusion: if text_encoder: if isinstance(self.text_encoder, list): for i, encoder in enumerate(self.text_encoder): + if self.is_xl and not self.model_config.use_text_encoder_1 and i == 0: + # dont add these params + continue + if self.is_xl and not self.model_config.use_text_encoder_2 and i == 1: + # dont add these params + continue + for name, param in encoder.named_parameters(recurse=True, prefix=f"{SD_PREFIX_TEXT_ENCODER}{i}"): named_params[name] = param else: