Allow for training loras on only one text encoder for sdxl

This commit is contained in:
Jaret Burkett
2023-10-06 08:11:56 -06:00
parent f73402473b
commit cac8754399
5 changed files with 34 additions and 4 deletions

View File

@@ -722,7 +722,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
conv_alpha=self.network_config.conv_alpha,
is_sdxl=self.model_config.is_xl,
is_v2=self.model_config.is_v2,
dropout=self.network_config.dropout
dropout=self.network_config.dropout,
use_text_encoder_1=self.model_config.use_text_encoder_1,
use_text_encoder_2=self.model_config.use_text_encoder_2,
)
self.network.force_to(self.device_torch, dtype=dtype)

View File

@@ -137,6 +137,8 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
module_class: Type[object] = LoRAModule,
varbose: Optional[bool] = False,
train_text_encoder: Optional[bool] = True,
use_text_encoder_1: bool = True,
use_text_encoder_2: bool = True,
train_unet: Optional[bool] = True,
is_sdxl=False,
is_v2=False,
@@ -273,6 +275,10 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
skipped_te = []
if train_text_encoder:
for i, text_encoder in enumerate(text_encoders):
if not use_text_encoder_1 and i == 0:
continue
if not use_text_encoder_2 and i == 1:
continue
if len(text_encoders) > 1:
index = i + 1
print(f"create LoRA for Text Encoder {index}:")

View File

@@ -156,6 +156,8 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork):
network_module: Type[object] = LoConSpecialModule,
train_unet: bool = True,
train_text_encoder: bool = True,
use_text_encoder_1: bool = True,
use_text_encoder_2: bool = True,
**kwargs,
) -> None:
# call ToolkitNetworkMixin super
@@ -332,6 +334,10 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork):
self.text_encoder_loras = []
if self.train_text_encoder:
for i, te in enumerate(text_encoders):
if not use_text_encoder_1 and i == 0:
continue
if not use_text_encoder_2 and i == 1:
continue
self.text_encoder_loras.extend(create_modules(
LycorisSpecialNetwork.LORA_PREFIX_TEXT_ENCODER + (f'{i + 1}' if use_index else ''),
te,

View File

@@ -60,7 +60,7 @@ def get_train_sd_device_state_preset(
preset['unet']['training'] = True
if train_lora:
preset['text_encoder']['requires_grad'] = False
# preset['text_encoder']['requires_grad'] = False
preset['unet']['requires_grad'] = False
if train_adapter:

View File

@@ -1,6 +1,7 @@
import copy
import gc
import json
import random
import shutil
import typing
from typing import Union, List, Literal, Iterator
@@ -677,6 +678,14 @@ class StableDiffusion:
if prompt2 is not None and not isinstance(prompt2, list):
prompt2 = [prompt2]
if self.is_xl:
# todo make this a config
# 50% chance to use an encoder anyway even if it is disabled
# allows the other TE to compensate for the disabled one
use_encoder_1 = self.use_text_encoder_1 or force_all or random.random() > 0.5
use_encoder_2 = self.use_text_encoder_2 or force_all or random.random() > 0.5
# use_encoder_1 = True
# use_encoder_2 = True
return PromptEmbeds(
train_tools.encode_prompts_xl(
self.tokenizer,
@@ -684,8 +693,8 @@ class StableDiffusion:
prompt,
prompt2,
num_images_per_prompt=num_images_per_prompt,
use_text_encoder_1=self.use_text_encoder_1 or force_all,
use_text_encoder_2=self.use_text_encoder_2 or force_all,
use_text_encoder_1=use_encoder_1,
use_text_encoder_2=use_encoder_2,
)
)
else:
@@ -831,6 +840,13 @@ class StableDiffusion:
if text_encoder:
if isinstance(self.text_encoder, list):
for i, encoder in enumerate(self.text_encoder):
if self.is_xl and not self.model_config.use_text_encoder_1 and i == 0:
# dont add these params
continue
if self.is_xl and not self.model_config.use_text_encoder_2 and i == 1:
# dont add these params
continue
for name, param in encoder.named_parameters(recurse=True, prefix=f"{SD_PREFIX_TEXT_ENCODER}{i}"):
named_params[name] = param
else: