mirror of https://github.com/ostris/ai-toolkit.git
Initial support for finetuning Qwen Image. Will only work with caching for now; controls still need to be added everywhere.
@@ -1232,5 +1232,10 @@ def validate_configs(
         for dataset in dataset_configs:
             if not dataset.cache_text_embeddings:
                 raise ValueError("All datasets must have cache_text_embeddings set to True when caching text embeddings is enabled.")
 
+    # qwen image edit cannot unload the text encoder
+    if model_config.arch == 'qwen_image_edit':
+        if train_config.unload_text_encoder:
+            raise ValueError("Cannot unload the text encoder with the qwen_image_edit model. Control images are encoded with the text embeddings. You can cache the text embeddings, though.")
+
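The new check is easy to exercise in isolation. A minimal sketch, using SimpleNamespace stand-ins for the real config objects (the actual config classes in ai-toolkit have more fields than shown here):

from types import SimpleNamespace

# hypothetical config shapes, for illustration only
model_config = SimpleNamespace(arch='qwen_image_edit')
train_config = SimpleNamespace(unload_text_encoder=True)

# mirrors the new validation: the text encoder must stay resident for
# qwen_image_edit because control images are encoded through it
try:
    if model_config.arch == 'qwen_image_edit' and train_config.unload_text_encoder:
        raise ValueError("Cannot unload the text encoder with the qwen_image_edit model.")
except ValueError as e:
    print(e)  # the config is rejected before training starts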
@@ -497,6 +497,7 @@ class AiToolkitDataset(LatentCachingMixin, ControlCachingMixin, CLIPCachingMixin
                         dataloader_transforms=self.transform,
                         size_database=self.size_database,
                         dataset_root=dataset_folder,
+                        encode_control_in_text_embeddings=self.sd.encode_control_in_text_embeddings if self.sd else False,
                     )
                     self.file_list.append(file_item)
                 except Exception as e:
@@ -50,6 +50,7 @@ class FileItemDTO(
         self.is_video = self.dataset_config.num_frames > 1
         size_database = kwargs.get('size_database', {})
         dataset_root = kwargs.get('dataset_root', None)
+        self.encode_control_in_text_embeddings = kwargs.get('encode_control_in_text_embeddings', False)
         if dataset_root is not None:
             # remove dataset root from path
             file_key = self.path.replace(dataset_root, '')
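Taken together, these two hunks thread the flag from the model down into each file item. A rough sketch of the propagation path, with the surrounding plumbing stripped away (class names here are simplified stand-ins, not the toolkit's real classes):

# model -> dataset -> file item, in miniature
class Model:
    encode_control_in_text_embeddings = True  # e.g. a qwen_image_edit model

class FileItem:
    def __init__(self, **kwargs):
        # same default as the diff: False when the kwarg is absent
        self.encode_control_in_text_embeddings = kwargs.get(
            'encode_control_in_text_embeddings', False)

sd = Model()
item = FileItem(
    encode_control_in_text_embeddings=sd.encode_control_in_text_embeddings if sd else False)
assert item.encode_control_in_text_embeddings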
@@ -30,6 +30,7 @@ import albumentations as A
 from toolkit.print import print_acc
 from toolkit.accelerator import get_accelerator
 from toolkit.prompt_utils import PromptEmbeds
+from torchvision.transforms import functional as TF
 
 from toolkit.train_tools import get_torch_dtype
@@ -1802,6 +1803,9 @@ class TextEmbeddingFileItemDTOMixin:
             ("text_embedding_space_version", self.text_embedding_space_version),
             ("text_embedding_version", self.text_embedding_version),
         ])
+        # if we have a control image, cache the path
+        if self.encode_control_in_text_embeddings and self.control_path is not None:
+            item["control_path"] = self.control_path
         return item
 
     def get_text_embedding_path(self: 'FileItemDTO', recalculate=False):
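Adding control_path to this dict matters because the dict feeds the cache-key computation: swapping the control image must produce a different cached embedding file, or stale embeddings would be reused. A minimal sketch of that idea with a hashlib-based key (the toolkit's actual hashing scheme may differ):

import hashlib, json
from collections import OrderedDict

item = OrderedDict([
    ("caption", "a photo of sks dog"),
    ("text_embedding_space_version", 1),
    ("text_embedding_version", 1),
])
item["control_path"] = "/data/control/dog_pose.png"  # hypothetical path

# two items differing only in control_path now get different cache keys
key = hashlib.sha256(json.dumps(item).encode()).hexdigest()[:16]
print(key)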
@@ -1860,7 +1864,19 @@ class TextEmbeddingCachingMixin:
                 if not did_move:
                     self.sd.set_device_state_preset('cache_text_encoder')
                     did_move = True
-                prompt_embeds: PromptEmbeds = self.sd.encode_prompt(file_item.caption)
+
+                if file_item.encode_control_in_text_embeddings and file_item.control_path is not None:
+                    # load the control image and feed it into the text encoder
+                    ctrl_img = Image.open(file_item.control_path).convert("RGB")
+                    # convert to a 0..1 tensor with a batch dimension
+                    ctrl_img = (
+                        TF.to_tensor(ctrl_img)
+                        .unsqueeze(0)
+                        .to(self.sd.device_torch, dtype=self.sd.torch_dtype)
+                    )
+                    prompt_embeds: PromptEmbeds = self.sd.encode_prompt(file_item.caption, control_images=ctrl_img)
+                else:
+                    prompt_embeds: PromptEmbeds = self.sd.encode_prompt(file_item.caption)
                 # save it
                 prompt_embeds.save(text_embedding_path)
                 del prompt_embeds
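The PIL-to-tensor conversion above is standard torchvision behavior and easy to sanity-check on its own: TF.to_tensor scales uint8 pixels into [0, 1] floats in CHW order, and unsqueeze(0) adds the batch dimension the encoder expects. A standalone check:

from PIL import Image
from torchvision.transforms import functional as TF

img = Image.new("RGB", (64, 64), color=(255, 0, 0))  # stand-in control image
t = TF.to_tensor(img).unsqueeze(0)                   # shape (1, 3, 64, 64)
assert t.shape == (1, 3, 64, 64)
assert t.min() >= 0.0 and t.max() <= 1.0             # 0..1 range, not 0..255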
@@ -36,6 +36,7 @@ from diffusers import \
     UNet2DConditionModel
 from diffusers import PixArtAlphaPipeline
 from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
+from torchvision.transforms import functional as TF
 
 from toolkit.accelerator import get_accelerator, unwrap_model
 from typing import TYPE_CHECKING
@@ -177,6 +178,9 @@ class BaseModel:
         self.multistage_boundaries: List[float] = [0.0]
         # a list of trainable multistage boundaries
         self.trainable_multistage_boundaries: List[int] = [0]
 
+        # set true for models that encode control image into text embeddings
+        self.encode_control_in_text_embeddings = False
+
     # properties for old arch for backwards compatibility
     @property
@@ -287,7 +291,7 @@ class BaseModel:
         raise NotImplementedError(
             "get_noise_prediction must be implemented in child classes")
 
-    def get_prompt_embeds(self, prompt: str) -> PromptEmbeds:
+    def get_prompt_embeds(self, prompt: str, control_images=None) -> PromptEmbeds:
         raise NotImplementedError(
             "get_prompt_embeds must be implemented in child classes")
 
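Child models that encode control images into their text embeddings override this widened signature. A hypothetical sketch of what a qwen_image_edit-style override could look like; the real model class is not part of this diff, and both helper methods below are invented for illustration:

class QwenImageEditModel(BaseModel):  # hypothetical subclass, illustration only
    def get_prompt_embeds(self, prompt: str, control_images=None) -> PromptEmbeds:
        # a vision-language text encoder can attend over image tokens alongside
        # the prompt, so the control image becomes part of the embedding itself
        if control_images is not None:
            return self._encode_text_and_image(prompt, control_images)  # hypothetical helper
        return self._encode_text(prompt)  # hypothetical helper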
@@ -496,17 +500,34 @@ class BaseModel:
             if self.sample_prompts_cache is not None:
                 conditional_embeds = self.sample_prompts_cache[i]['conditional'].to(self.device_torch, dtype=self.torch_dtype)
                 unconditional_embeds = self.sample_prompts_cache[i]['unconditional'].to(self.device_torch, dtype=self.torch_dtype)
             else:
+                ctrl_img = None
+                # load the control image if our model uses it in text encoding
+                if gen_config.ctrl_img is not None and self.encode_control_in_text_embeddings:
+                    ctrl_img = Image.open(gen_config.ctrl_img).convert("RGB")
+                    # convert to a 0..1 tensor with a batch dimension
+                    ctrl_img = (
+                        TF.to_tensor(ctrl_img)
+                        .unsqueeze(0)
+                        .to(self.device_torch, dtype=self.torch_dtype)
+                    )
                 # encode the prompt ourselves so we can do fun stuff with embeddings
                 if isinstance(self.adapter, CustomAdapter):
                     self.adapter.is_unconditional_run = False
                 conditional_embeds = self.encode_prompt(
-                    gen_config.prompt, gen_config.prompt_2, force_all=True)
+                    gen_config.prompt,
+                    gen_config.prompt_2,
+                    force_all=True,
+                    control_images=ctrl_img
+                )
 
                 if isinstance(self.adapter, CustomAdapter):
                     self.adapter.is_unconditional_run = True
                 unconditional_embeds = self.encode_prompt(
-                    gen_config.negative_prompt, gen_config.negative_prompt_2, force_all=True
-                )
+                    gen_config.negative_prompt,
+                    gen_config.negative_prompt_2,
+                    force_all=True,
+                    control_images=ctrl_img
+                )
                 if isinstance(self.adapter, CustomAdapter):
                     self.adapter.is_unconditional_run = False
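Note the design choice here: the same ctrl_img tensor feeds both the conditional and the unconditional encode, so under classifier-free guidance the two embeddings differ only in the text prompt, not in the control conditioning.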
@@ -989,6 +1010,7 @@ class BaseModel:
             long_prompts=False,
             max_length=None,
             dropout_prob=0.0,
+            control_images=None,
     ) -> PromptEmbeds:
         # sd1.5 embeddings are (bs, 77, 768)
         prompt = prompt
@@ -998,6 +1020,9 @@ class BaseModel:
 
         if prompt2 is not None and not isinstance(prompt2, list):
             prompt2 = [prompt2]
+        # only pass control_images when the model uses them; this keeps plugins with the old signature from breaking
+        if self.encode_control_in_text_embeddings:
+            return self.get_prompt_embeds(prompt, control_images=control_images)
 
         return self.get_prompt_embeds(prompt)
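This gate is what preserves backward compatibility: a plugin model that still implements the old one-argument get_prompt_embeds is never handed the new keyword. A small self-contained sketch of the two dispatch paths (dummy classes, not the toolkit's real ones):

class LegacyModel:
    encode_control_in_text_embeddings = False
    def get_prompt_embeds(self, prompt):  # old signature, untouched
        return f"embeds({prompt})"

class ControlAwareModel:
    encode_control_in_text_embeddings = True
    def get_prompt_embeds(self, prompt, control_images=None):
        return f"embeds({prompt}, ctrl={control_images is not None})"

def encode_prompt(model, prompt, control_images=None):
    # mirrors the dispatch added to BaseModel.encode_prompt
    if model.encode_control_in_text_embeddings:
        return model.get_prompt_embeds(prompt, control_images=control_images)
    return model.get_prompt_embeds(prompt)

print(encode_prompt(LegacyModel(), "a cat"))                          # old path
print(encode_prompt(ControlAwareModel(), "a cat", control_images=1))  # new path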
@@ -217,6 +217,9 @@ class StableDiffusion:
         # a list of trainable multistage boundaries
         self.trainable_multistage_boundaries: List[int] = [0]
 
+        # set true for models that encode control image into text embeddings
+        self.encode_control_in_text_embeddings = False
+
     # properties for old arch for backwards compatibility
     @property
     def is_xl(self):
@@ -2356,6 +2359,7 @@ class StableDiffusion:
             long_prompts=False,
             max_length=None,
             dropout_prob=0.0,
+            control_images=None,
     ) -> PromptEmbeds:
         # sd1.5 embeddings are (bs, 77, 768)
         prompt = prompt
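With both BaseModel and the legacy StableDiffusion class widened, call sites can pass a control tensor uniformly. A hedged usage sketch, assuming sd is a loaded model whose arch encodes control images in text embeddings (e.g. qwen_image_edit) and ctrl is a (1, 3, H, W) tensor already on sd.device_torch:

# prompt and tensor here are placeholders, not values from this commit
embeds = sd.encode_prompt("a photo of sks dog", control_images=ctrl)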