mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Allow for training loras on onle one text encoder for sdxl
This commit is contained in:
@@ -722,7 +722,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
conv_alpha=self.network_config.conv_alpha,
|
||||
is_sdxl=self.model_config.is_xl,
|
||||
is_v2=self.model_config.is_v2,
|
||||
dropout=self.network_config.dropout
|
||||
dropout=self.network_config.dropout,
|
||||
use_text_encoder_1=self.model_config.use_text_encoder_1,
|
||||
use_text_encoder_2=self.model_config.use_text_encoder_2,
|
||||
)
|
||||
|
||||
self.network.force_to(self.device_torch, dtype=dtype)
|
||||
|
||||
@@ -137,6 +137,8 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
|
||||
module_class: Type[object] = LoRAModule,
|
||||
varbose: Optional[bool] = False,
|
||||
train_text_encoder: Optional[bool] = True,
|
||||
use_text_encoder_1: bool = True,
|
||||
use_text_encoder_2: bool = True,
|
||||
train_unet: Optional[bool] = True,
|
||||
is_sdxl=False,
|
||||
is_v2=False,
|
||||
@@ -273,6 +275,10 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
|
||||
skipped_te = []
|
||||
if train_text_encoder:
|
||||
for i, text_encoder in enumerate(text_encoders):
|
||||
if not use_text_encoder_1 and i == 0:
|
||||
continue
|
||||
if not use_text_encoder_2 and i == 1:
|
||||
continue
|
||||
if len(text_encoders) > 1:
|
||||
index = i + 1
|
||||
print(f"create LoRA for Text Encoder {index}:")
|
||||
|
||||
@@ -156,6 +156,8 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork):
|
||||
network_module: Type[object] = LoConSpecialModule,
|
||||
train_unet: bool = True,
|
||||
train_text_encoder: bool = True,
|
||||
use_text_encoder_1: bool = True,
|
||||
use_text_encoder_2: bool = True,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
# call ToolkitNetworkMixin super
|
||||
@@ -332,6 +334,10 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork):
|
||||
self.text_encoder_loras = []
|
||||
if self.train_text_encoder:
|
||||
for i, te in enumerate(text_encoders):
|
||||
if not use_text_encoder_1 and i == 0:
|
||||
continue
|
||||
if not use_text_encoder_2 and i == 1:
|
||||
continue
|
||||
self.text_encoder_loras.extend(create_modules(
|
||||
LycorisSpecialNetwork.LORA_PREFIX_TEXT_ENCODER + (f'{i + 1}' if use_index else ''),
|
||||
te,
|
||||
|
||||
@@ -60,7 +60,7 @@ def get_train_sd_device_state_preset(
|
||||
preset['unet']['training'] = True
|
||||
|
||||
if train_lora:
|
||||
preset['text_encoder']['requires_grad'] = False
|
||||
# preset['text_encoder']['requires_grad'] = False
|
||||
preset['unet']['requires_grad'] = False
|
||||
|
||||
if train_adapter:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import copy
|
||||
import gc
|
||||
import json
|
||||
import random
|
||||
import shutil
|
||||
import typing
|
||||
from typing import Union, List, Literal, Iterator
|
||||
@@ -677,6 +678,14 @@ class StableDiffusion:
|
||||
if prompt2 is not None and not isinstance(prompt2, list):
|
||||
prompt2 = [prompt2]
|
||||
if self.is_xl:
|
||||
# todo make this a config
|
||||
# 50% chance to use an encoder anyway even if it is disabled
|
||||
# allows the other TE to compensate for the disabled one
|
||||
use_encoder_1 = self.use_text_encoder_1 or force_all or random.random() > 0.5
|
||||
use_encoder_2 = self.use_text_encoder_2 or force_all or random.random() > 0.5
|
||||
# use_encoder_1 = True
|
||||
# use_encoder_2 = True
|
||||
|
||||
return PromptEmbeds(
|
||||
train_tools.encode_prompts_xl(
|
||||
self.tokenizer,
|
||||
@@ -684,8 +693,8 @@ class StableDiffusion:
|
||||
prompt,
|
||||
prompt2,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
use_text_encoder_1=self.use_text_encoder_1 or force_all,
|
||||
use_text_encoder_2=self.use_text_encoder_2 or force_all,
|
||||
use_text_encoder_1=use_encoder_1,
|
||||
use_text_encoder_2=use_encoder_2,
|
||||
)
|
||||
)
|
||||
else:
|
||||
@@ -831,6 +840,13 @@ class StableDiffusion:
|
||||
if text_encoder:
|
||||
if isinstance(self.text_encoder, list):
|
||||
for i, encoder in enumerate(self.text_encoder):
|
||||
if self.is_xl and not self.model_config.use_text_encoder_1 and i == 0:
|
||||
# dont add these params
|
||||
continue
|
||||
if self.is_xl and not self.model_config.use_text_encoder_2 and i == 1:
|
||||
# dont add these params
|
||||
continue
|
||||
|
||||
for name, param in encoder.named_parameters(recurse=True, prefix=f"{SD_PREFIX_TEXT_ENCODER}{i}"):
|
||||
named_params[name] = param
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user