Initial support for fine-tuning Qwen Image. Will only work with caching for now; controls still need to be added everywhere.

Jaret Burkett
2025-08-21 16:41:17 -06:00
parent 38d3814be7
commit bf2700f7be
12 changed files with 399 additions and 31 deletions

View File

@@ -1232,5 +1232,10 @@ def validate_configs(
for dataset in dataset_configs:
if not dataset.cache_text_embeddings:
raise ValueError("All datasets must have cache_text_embeddings set to True when caching text embeddings is enabled.")
# qwen image edit cannot unload the text encoder
if model_config.arch == 'qwen_image_edit':
if train_config.unload_text_encoder:
raise ValueError("Cannot unload the text encoder with the qwen_image_edit model. Control images are encoded together with the text embeddings. You can still cache the text embeddings.")

View File

@@ -497,6 +497,7 @@ class AiToolkitDataset(LatentCachingMixin, ControlCachingMixin, CLIPCachingMixin
dataloader_transforms=self.transform,
size_database=self.size_database,
dataset_root=dataset_folder,
encode_control_in_text_embeddings=self.sd.encode_control_in_text_embeddings if self.sd else False,
)
self.file_list.append(file_item)
except Exception as e:

View File

@@ -50,6 +50,7 @@ class FileItemDTO(
self.is_video = self.dataset_config.num_frames > 1
size_database = kwargs.get('size_database', {})
dataset_root = kwargs.get('dataset_root', None)
self.encode_control_in_text_embeddings = kwargs.get('encode_control_in_text_embeddings', False)
if dataset_root is not None:
# remove dataset root from path
file_key = self.path.replace(dataset_root, '')
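
The kwargs.get default above is what keeps older call sites working. A self-contained sketch of the pattern (FileItemDTOSketch is illustrative, not the real DTO class):

class FileItemDTOSketch:
    def __init__(self, **kwargs):
        # default False: a DTO built by an older caller behaves exactly as before
        self.encode_control_in_text_embeddings = kwargs.get('encode_control_in_text_embeddings', False)

print(FileItemDTOSketch().encode_control_in_text_embeddings)  # False
print(FileItemDTOSketch(encode_control_in_text_embeddings=True).encode_control_in_text_embeddings)  # True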

View File

@@ -30,6 +30,7 @@ import albumentations as A
from toolkit.print import print_acc
from toolkit.accelerator import get_accelerator
from toolkit.prompt_utils import PromptEmbeds
from torchvision.transforms import functional as TF
from toolkit.train_tools import get_torch_dtype
@@ -1802,6 +1803,9 @@ class TextEmbeddingFileItemDTOMixin:
("text_embedding_space_version", self.text_embedding_space_version),
("text_embedding_version", self.text_embedding_version),
])
# if we have a control image, include its path in the cache key
if self.encode_control_in_text_embeddings and self.control_path is not None:
item["control_path"] = self.control_path
return item
def get_text_embedding_path(self: 'FileItemDTO', recalculate=False):
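
Adding control_path to the ordered item dict changes the derived cache key whenever the control image does, so a stale embedding file is never reused. A sketch of the effect, assuming the key is hashed from the dict's items (the toolkit's actual key derivation may differ):

import hashlib
import json
from collections import OrderedDict

def embedding_key(control_path=None):
    # mirrors the item dict above; version values are placeholders
    item = OrderedDict([
        ("text_embedding_space_version", 1),
        ("text_embedding_version", 1),
    ])
    if control_path is not None:
        item["control_path"] = control_path
    return hashlib.sha256(json.dumps(item).encode()).hexdigest()[:16]

print(embedding_key())                           # key with no control image
print(embedding_key("/data/controls/0001.png"))  # hypothetical path -> different key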
@@ -1860,7 +1864,19 @@ class TextEmbeddingCachingMixin:
if not did_move:
self.sd.set_device_state_preset('cache_text_encoder')
did_move = True
if file_item.encode_control_in_text_embeddings and file_item.control_path is not None:
# load the control image and feed it into the text encoder
ctrl_img = Image.open(file_item.control_path).convert("RGB")
# convert to a [0, 1] tensor
ctrl_img = (
TF.to_tensor(ctrl_img)
.unsqueeze(0)
.to(self.sd.device_torch, dtype=self.sd.torch_dtype)
)
prompt_embeds: PromptEmbeds = self.sd.encode_prompt(file_item.caption, control_images=ctrl_img)
else:
prompt_embeds: PromptEmbeds = self.sd.encode_prompt(file_item.caption)
# save it
prompt_embeds.save(text_embedding_path)
del prompt_embeds
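
The same PIL-to-tensor conversion is used again in the sampling path below. A standalone sketch of what it produces, using a synthetic image so it runs without a dataset:

# TF.to_tensor scales pixel values into [0, 1] and returns (C, H, W);
# unsqueeze(0) adds the batch dimension the encoder expects.
import torch
from PIL import Image
from torchvision.transforms import functional as TF

img = Image.new("RGB", (64, 64), color=(128, 64, 255))  # synthetic stand-in image
ctrl = TF.to_tensor(img).unsqueeze(0).to(torch.float32)

print(ctrl.shape)                            # torch.Size([1, 3, 64, 64])
print(ctrl.min().item(), ctrl.max().item())  # values scaled into [0.0, 1.0]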

View File

@@ -36,6 +36,7 @@ from diffusers import \
UNet2DConditionModel
from diffusers import PixArtAlphaPipeline
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
from torchvision.transforms import functional as TF
from toolkit.accelerator import get_accelerator, unwrap_model
from typing import TYPE_CHECKING
@@ -177,6 +178,9 @@ class BaseModel:
self.multistage_boundaries: List[float] = [0.0]
# a list of trainable multistage boundaries
self.trainable_multistage_boundaries: List[int] = [0]
# set to True for models that encode control images into the text embeddings
self.encode_control_in_text_embeddings = False
# properties for old arch for backwards compatibility
@property
@@ -287,7 +291,7 @@ class BaseModel:
raise NotImplementedError(
"get_noise_prediction must be implemented in child classes")
def get_prompt_embeds(self, prompt: str, control_images=None) -> PromptEmbeds:
raise NotImplementedError(
"get_prompt_embeds must be implemented in child classes")
@@ -496,17 +500,34 @@ class BaseModel:
if self.sample_prompts_cache is not None:
conditional_embeds = self.sample_prompts_cache[i]['conditional'].to(self.device_torch, dtype=self.torch_dtype)
unconditional_embeds = self.sample_prompts_cache[i]['unconditional'].to(self.device_torch, dtype=self.torch_dtype)
else:
ctrl_img = None
# load the control image if our model uses it in text encoding
if gen_config.ctrl_img is not None and self.encode_control_in_text_embeddings:
ctrl_img = Image.open(gen_config.ctrl_img).convert("RGB")
# convert to a [0, 1] tensor
ctrl_img = (
TF.to_tensor(ctrl_img)
.unsqueeze(0)
.to(self.device_torch, dtype=self.torch_dtype)
)
# encode the prompt ourselves so we can do fun stuff with embeddings
if isinstance(self.adapter, CustomAdapter):
self.adapter.is_unconditional_run = False
conditional_embeds = self.encode_prompt(
gen_config.prompt,
gen_config.prompt_2,
force_all=True,
control_images=ctrl_img
)
if isinstance(self.adapter, CustomAdapter):
self.adapter.is_unconditional_run = True
unconditional_embeds = self.encode_prompt(
gen_config.negative_prompt,
gen_config.negative_prompt_2,
force_all=True,
control_images=ctrl_img
)
if isinstance(self.adapter, CustomAdapter):
self.adapter.is_unconditional_run = False
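
Note that the same ctrl_img tensor goes into both the conditional and unconditional encodes, so during classifier-free guidance the two predictions differ only in the prompt text. A sketch of the standard blend those embeddings ultimately feed into (random tensors stand in for real model predictions):

import torch

guidance_scale = 4.0
cond_pred = torch.randn(1, 4, 64, 64)    # prediction from prompt + control embeds
uncond_pred = torch.randn(1, 4, 64, 64)  # prediction from negative prompt + same control
noise_pred = uncond_pred + guidance_scale * (cond_pred - uncond_pred)
print(noise_pred.shape)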
@@ -989,6 +1010,7 @@ class BaseModel:
long_prompts=False,
max_length=None,
dropout_prob=0.0,
control_images=None,
) -> PromptEmbeds:
# sd1.5 embeddings are (bs, 77, 768)
prompt = prompt
@@ -998,6 +1020,9 @@ class BaseModel:
if prompt2 is not None and not isinstance(prompt2, list):
prompt2 = [prompt2]
# only pass control_images when the model encodes control images into its text embeddings; this keeps existing plugins from breaking
if self.encode_control_in_text_embeddings:
return self.get_prompt_embeds(prompt, control_images=control_images)
return self.get_prompt_embeds(prompt)
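
The flag check above is a capability-dispatch pattern: models that never set the flag keep their old one-argument get_prompt_embeds and are still called without the new keyword. A self-contained sketch (LegacyModel and the free-standing encode_prompt are illustrative):

class LegacyModel:
    encode_control_in_text_embeddings = False

    def get_prompt_embeds(self, prompt):  # old signature, no control_images kwarg
        return f"embeds({prompt})"

def encode_prompt(model, prompt, control_images=None):
    if model.encode_control_in_text_embeddings:
        return model.get_prompt_embeds(prompt, control_images=control_images)
    return model.get_prompt_embeds(prompt)  # legacy path: keyword never passed

print(encode_prompt(LegacyModel(), "a cat"))  # works unchanged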

View File

@@ -217,6 +217,9 @@ class StableDiffusion:
# a list of trainable multistage boundaries
self.trainable_multistage_boundaries: List[int] = [0]
# set to True for models that encode control images into the text embeddings
self.encode_control_in_text_embeddings = False
# properties for old arch for backwards compatibility
@property
def is_xl(self):
@@ -2356,6 +2359,7 @@ class StableDiffusion:
long_prompts=False,
max_length=None,
dropout_prob=0.0,
control_images=None,
) -> PromptEmbeds:
# sd1.5 embeddings are (bs, 77, 768)
prompt = prompt