mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-26 01:09:19 +00:00
Allow for training loras on onle one text encoder for sdxl
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import copy
|
||||
import gc
|
||||
import json
|
||||
import random
|
||||
import shutil
|
||||
import typing
|
||||
from typing import Union, List, Literal, Iterator
|
||||
@@ -677,6 +678,14 @@ class StableDiffusion:
|
||||
if prompt2 is not None and not isinstance(prompt2, list):
|
||||
prompt2 = [prompt2]
|
||||
if self.is_xl:
|
||||
# todo make this a config
|
||||
# 50% chance to use an encoder anyway even if it is disabled
|
||||
# allows the other TE to compensate for the disabled one
|
||||
use_encoder_1 = self.use_text_encoder_1 or force_all or random.random() > 0.5
|
||||
use_encoder_2 = self.use_text_encoder_2 or force_all or random.random() > 0.5
|
||||
# use_encoder_1 = True
|
||||
# use_encoder_2 = True
|
||||
|
||||
return PromptEmbeds(
|
||||
train_tools.encode_prompts_xl(
|
||||
self.tokenizer,
|
||||
@@ -684,8 +693,8 @@ class StableDiffusion:
|
||||
prompt,
|
||||
prompt2,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
use_text_encoder_1=self.use_text_encoder_1 or force_all,
|
||||
use_text_encoder_2=self.use_text_encoder_2 or force_all,
|
||||
use_text_encoder_1=use_encoder_1,
|
||||
use_text_encoder_2=use_encoder_2,
|
||||
)
|
||||
)
|
||||
else:
|
||||
@@ -831,6 +840,13 @@ class StableDiffusion:
|
||||
if text_encoder:
|
||||
if isinstance(self.text_encoder, list):
|
||||
for i, encoder in enumerate(self.text_encoder):
|
||||
if self.is_xl and not self.model_config.use_text_encoder_1 and i == 0:
|
||||
# dont add these params
|
||||
continue
|
||||
if self.is_xl and not self.model_config.use_text_encoder_2 and i == 1:
|
||||
# dont add these params
|
||||
continue
|
||||
|
||||
for name, param in encoder.named_parameters(recurse=True, prefix=f"{SD_PREFIX_TEXT_ENCODER}{i}"):
|
||||
named_params[name] = param
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user