From 454be0958a3fbf2d82a7e1ff31dfe1368ec3339c Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Wed, 24 Sep 2025 11:39:10 -0600 Subject: [PATCH] Initial support for qwen image edit plus --- .../diffusion_models/__init__.py | 3 +- .../diffusion_models/qwen_image/__init__.py | 3 +- .../qwen_image/qwen_image_edit_plus.py | 309 ++++++++++++++++++ extensions_built_in/sd_trainer/SDTrainer.py | 70 +++- jobs/process/BaseSDTrainProcess.py | 3 + requirements.txt | 2 +- toolkit/config_modules.py | 14 + toolkit/data_transfer_object/data_loader.py | 12 +- toolkit/dataloader_mixins.py | 53 +-- toolkit/models/base_model.py | 4 + toolkit/stable_diffusion_model.py | 4 + 11 files changed, 445 insertions(+), 32 deletions(-) create mode 100644 extensions_built_in/diffusion_models/qwen_image/qwen_image_edit_plus.py diff --git a/extensions_built_in/diffusion_models/__init__.py b/extensions_built_in/diffusion_models/__init__.py index 6449fe61..f7d874e5 100644 --- a/extensions_built_in/diffusion_models/__init__.py +++ b/extensions_built_in/diffusion_models/__init__.py @@ -4,7 +4,7 @@ from .f_light import FLiteModel from .omnigen2 import OmniGen2Model from .flux_kontext import FluxKontextModel from .wan22 import Wan225bModel, Wan2214bModel, Wan2214bI2VModel -from .qwen_image import QwenImageModel, QwenImageEditModel +from .qwen_image import QwenImageModel, QwenImageEditModel, QwenImageEditPlusModel AI_TOOLKIT_MODELS = [ # put a list of models here @@ -20,4 +20,5 @@ AI_TOOLKIT_MODELS = [ Wan2214bModel, QwenImageModel, QwenImageEditModel, + QwenImageEditPlusModel, ] diff --git a/extensions_built_in/diffusion_models/qwen_image/__init__.py b/extensions_built_in/diffusion_models/qwen_image/__init__.py index d8b32a85..3ff1797b 100644 --- a/extensions_built_in/diffusion_models/qwen_image/__init__.py +++ b/extensions_built_in/diffusion_models/qwen_image/__init__.py @@ -1,2 +1,3 @@ from .qwen_image import QwenImageModel -from .qwen_image_edit import QwenImageEditModel \ No newline at end of file +from .qwen_image_edit import QwenImageEditModel +from .qwen_image_edit_plus import QwenImageEditPlusModel diff --git a/extensions_built_in/diffusion_models/qwen_image/qwen_image_edit_plus.py b/extensions_built_in/diffusion_models/qwen_image/qwen_image_edit_plus.py new file mode 100644 index 00000000..32fcaa70 --- /dev/null +++ b/extensions_built_in/diffusion_models/qwen_image/qwen_image_edit_plus.py @@ -0,0 +1,309 @@ +import math +import torch +from .qwen_image import QwenImageModel +import os +from typing import TYPE_CHECKING, List, Optional +import yaml +from toolkit import train_tools +from toolkit.config_modules import GenerateImageConfig, ModelConfig +from PIL import Image +from toolkit.models.base_model import BaseModel +from toolkit.basic import flush +from toolkit.prompt_utils import PromptEmbeds +from toolkit.samplers.custom_flowmatch_sampler import ( + CustomFlowMatchEulerDiscreteScheduler, +) +from toolkit.accelerator import get_accelerator, unwrap_model +from optimum.quanto import freeze, QTensor +from toolkit.util.quantize import quantize, get_qtype, quantize_model +import torch.nn.functional as F + +from diffusers import ( + QwenImageTransformer2DModel, + AutoencoderKLQwenImage, +) +from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer +from tqdm import tqdm + + +if TYPE_CHECKING: + from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO + +try: + from diffusers import QwenImageEditPlusPipeline + from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit_plus import CONDITION_IMAGE_SIZE, VAE_IMAGE_SIZE +except ImportError: + raise ImportError( + "Diffusers is out of date. Update diffusers to the latest version by doing 'pip uninstall diffusers' and then 'pip install -r requirements.txt'" + ) + + +class QwenImageEditPlusModel(QwenImageModel): + arch = "qwen_image_edit_plus" + _qwen_image_keep_visual = True + _qwen_pipeline = QwenImageEditPlusPipeline + + def __init__( + self, + device, + model_config: ModelConfig, + dtype="bf16", + custom_pipeline=None, + noise_scheduler=None, + **kwargs, + ): + super().__init__( + device, model_config, dtype, custom_pipeline, noise_scheduler, **kwargs + ) + self.is_flow_matching = True + self.is_transformer = True + self.target_lora_modules = ["QwenImageTransformer2DModel"] + + # set true for models that encode control image into text embeddings + self.encode_control_in_text_embeddings = True + # control images will come in as a list for encoding some things if true + self.has_multiple_control_images = True + # do not resize control images + self.use_raw_control_images = True + + def load_model(self): + super().load_model() + + def get_generation_pipeline(self): + scheduler = QwenImageModel.get_train_scheduler() + + pipeline: QwenImageEditPlusPipeline = QwenImageEditPlusPipeline( + scheduler=scheduler, + text_encoder=unwrap_model(self.text_encoder[0]), + tokenizer=self.tokenizer[0], + processor=self.processor, + vae=unwrap_model(self.vae), + transformer=unwrap_model(self.transformer), + ) + + pipeline = pipeline.to(self.device_torch) + + return pipeline + + def generate_single_image( + self, + pipeline: QwenImageEditPlusPipeline, + gen_config: GenerateImageConfig, + conditional_embeds: PromptEmbeds, + unconditional_embeds: PromptEmbeds, + generator: torch.Generator, + extra: dict, + ): + self.model.to(self.device_torch, dtype=self.torch_dtype) + sc = self.get_bucket_divisibility() + gen_config.width = int(gen_config.width // sc * sc) + gen_config.height = int(gen_config.height // sc * sc) + + control_img_list = [] + if gen_config.ctrl_img is not None: + control_img = Image.open(gen_config.ctrl_img) + control_img = control_img.convert("RGB") + control_img_list.append(control_img) + elif gen_config.ctrl_img_1 is not None: + control_img = Image.open(gen_config.ctrl_img_1) + control_img = control_img.convert("RGB") + control_img_list.append(control_img) + + if gen_config.ctrl_img_2 is not None: + control_img = Image.open(gen_config.ctrl_img_2) + control_img = control_img.convert("RGB") + control_img_list.append(control_img) + if gen_config.ctrl_img_3 is not None: + control_img = Image.open(gen_config.ctrl_img_3) + control_img = control_img.convert("RGB") + control_img_list.append(control_img) + + # flush for low vram if we are doing that + # flush_between_steps = self.model_config.low_vram + flush_between_steps = False + + # Fix a bug in diffusers/torch + def callback_on_step_end(pipe, i, t, callback_kwargs): + if flush_between_steps: + flush() + latents = callback_kwargs["latents"] + + return {"latents": latents} + + img = pipeline( + image=control_img_list, + prompt_embeds=conditional_embeds.text_embeds, + prompt_embeds_mask=conditional_embeds.attention_mask.to( + self.device_torch, dtype=torch.int64 + ), + negative_prompt_embeds=unconditional_embeds.text_embeds, + negative_prompt_embeds_mask=unconditional_embeds.attention_mask.to( + self.device_torch, dtype=torch.int64 + ), + height=gen_config.height, + width=gen_config.width, + num_inference_steps=gen_config.num_inference_steps, + true_cfg_scale=gen_config.guidance_scale, + latents=gen_config.latents, + generator=generator, + callback_on_step_end=callback_on_step_end, + **extra, + ).images[0] + return img + + def condition_noisy_latents( + self, latents: torch.Tensor, batch: "DataLoaderBatchDTO" + ): + # we get the control image from the batch + return latents.detach() + + def get_prompt_embeds(self, prompt: str, control_images=None) -> PromptEmbeds: + # todo handle not caching text encoder + if self.pipeline.text_encoder.device != self.device_torch: + self.pipeline.text_encoder.to(self.device_torch) + + if control_images is not None and len(control_images) > 0: + for i in range(len(control_images)): + # control images are 0 - 1 scale, shape (bs, ch, height, width) + ratio = control_images[i].shape[2] / control_images[i].shape[3] + width = math.sqrt(CONDITION_IMAGE_SIZE * ratio) + height = width / ratio + + width = round(width / 32) * 32 + height = round(height / 32) * 32 + + control_images[i] = F.interpolate( + control_images[i], size=(height, width), mode="bilinear" + ) + + prompt_embeds, prompt_embeds_mask = self.pipeline.encode_prompt( + prompt, + image=control_images, + device=self.device_torch, + num_images_per_prompt=1, + ) + pe = PromptEmbeds(prompt_embeds) + pe.attention_mask = prompt_embeds_mask + return pe + + def get_noise_prediction( + self, + latent_model_input: torch.Tensor, + timestep: torch.Tensor, # 0 to 1000 scale + text_embeddings: PromptEmbeds, + batch: "DataLoaderBatchDTO" = None, + **kwargs, + ): + with torch.no_grad(): + batch_size, num_channels_latents, height, width = latent_model_input.shape + + # pack image tokens + latent_model_input = latent_model_input.view( + batch_size, num_channels_latents, height // 2, 2, width // 2, 2 + ) + latent_model_input = latent_model_input.permute(0, 2, 4, 1, 3, 5) + latent_model_input = latent_model_input.reshape( + batch_size, (height // 2) * (width // 2), num_channels_latents * 4 + ) + + raw_packed_latents = latent_model_input + + img_h2, img_w2 = height // 2, width // 2 + + img_shapes = [ + [(1, img_h2, img_w2)] + ] * batch_size + + # pack controls + if batch is None: + raise ValueError("Batch is required for QwenImageEditPlusModel") + + # split the latents into batch items so we can concat the controls + packed_latents_list = torch.chunk(latent_model_input, batch_size, dim=0) + packed_latents_with_controls_list = [] + + if batch.control_tensor_list is not None: + if len(batch.control_tensor_list) != batch_size: + raise ValueError("Control tensor list length does not match batch size") + b = 0 + for control_tensor_list in batch.control_tensor_list: + # control tensor list is a list of tensors for this batch item + controls = [] + # pack control + for control_img in control_tensor_list: + # control images are 0 - 1 scale, shape (1, ch, height, width) + control_img = control_img.to(self.device_torch, dtype=self.torch_dtype) + # if it is only 3 dim, add batch dim + if len(control_img.shape) == 3: + control_img = control_img.unsqueeze(0) + ratio = control_img.shape[2] / control_img.shape[3] + c_width = math.sqrt(VAE_IMAGE_SIZE * ratio) + c_height = c_width / ratio + + c_width = round(c_width / 32) * 32 + c_height = round(c_height / 32) * 32 + + control_img = F.interpolate( + control_img, size=(c_height, c_width), mode="bilinear" + ) + + control_latent = self.encode_images( + control_img, + device=self.device_torch, + dtype=self.torch_dtype, + ) + + clb, cl_num_channels_latents, cl_height, cl_width = control_latent.shape + + control = control_latent.view( + 1, cl_num_channels_latents, cl_height // 2, 2, cl_width // 2, 2 + ) + control = control.permute(0, 2, 4, 1, 3, 5) + control = control.reshape( + 1, (cl_height // 2) * (cl_width // 2), num_channels_latents * 4 + ) + + img_shapes[b].append((1, cl_height // 2, cl_width // 2)) + controls.append(control) + + # stack controls on dim 1 + control = torch.cat(controls, dim=1).to(packed_latents_list[b].device, dtype=packed_latents_list[b].dtype) + # concat with latents + packed_latents_with_control = torch.cat([packed_latents_list[b], control], dim=1) + + packed_latents_with_controls_list.append(packed_latents_with_control) + + b += 1 + + latent_model_input = torch.cat(packed_latents_with_controls_list, dim=0) + + prompt_embeds_mask = text_embeddings.attention_mask.to( + self.device_torch, dtype=torch.int64 + ) + txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() + enc_hs = text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype) + prompt_embeds_mask = text_embeddings.attention_mask.to( + self.device_torch, dtype=torch.int64 + ) + + noise_pred = self.transformer( + hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype), + timestep=timestep / 1000, + guidance=None, + encoder_hidden_states=enc_hs, + encoder_hidden_states_mask=prompt_embeds_mask, + img_shapes=img_shapes, + txt_seq_lens=txt_seq_lens, + return_dict=False, + **kwargs, + )[0] + + noise_pred = noise_pred[:, : raw_packed_latents.size(1)] + + # unpack + noise_pred = noise_pred.view( + batch_size, height // 2, width // 2, num_channels_latents, 2, 2 + ) + noise_pred = noise_pred.permute(0, 3, 1, 4, 2, 5) + noise_pred = noise_pred.reshape(batch_size, num_channels_latents, height, width) + return noise_pred diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index 16c392d7..86c198ea 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -129,17 +129,64 @@ class SDTrainer(BaseSDTrainProcess): prompt=prompt, # it will autoparse the prompt negative_prompt=sample_item.neg, output_path=output_path, - ctrl_img=sample_item.ctrl_img + ctrl_img=sample_item.ctrl_img, + ctrl_img_1=sample_item.ctrl_img_1, + ctrl_img_2=sample_item.ctrl_img_2, + ctrl_img_3=sample_item.ctrl_img_3, ) + + has_control_images = False + if gen_img_config.ctrl_img is not None or gen_img_config.ctrl_img_1 is not None or gen_img_config.ctrl_img_2 is not None or gen_img_config.ctrl_img_3 is not None: + has_control_images = True # see if we need to encode the control images - if self.sd.encode_control_in_text_embeddings and gen_img_config.ctrl_img is not None: - ctrl_img = Image.open(gen_img_config.ctrl_img).convert("RGB") - # convert to 0 to 1 tensor - ctrl_img = ( - TF.to_tensor(ctrl_img) - .unsqueeze(0) - .to(self.sd.device_torch, dtype=self.sd.torch_dtype) - ) + if self.sd.encode_control_in_text_embeddings and has_control_images: + + ctrl_img_list = [] + + if gen_img_config.ctrl_img is not None: + ctrl_img = Image.open(gen_img_config.ctrl_img).convert("RGB") + # convert to 0 to 1 tensor + ctrl_img = ( + TF.to_tensor(ctrl_img) + .unsqueeze(0) + .to(self.sd.device_torch, dtype=self.sd.torch_dtype) + ) + ctrl_img_list.append(ctrl_img) + + if gen_img_config.ctrl_img_1 is not None: + ctrl_img_1 = Image.open(gen_img_config.ctrl_img_1).convert("RGB") + # convert to 0 to 1 tensor + ctrl_img_1 = ( + TF.to_tensor(ctrl_img_1) + .unsqueeze(0) + .to(self.sd.device_torch, dtype=self.sd.torch_dtype) + ) + ctrl_img_list.append(ctrl_img_1) + if gen_img_config.ctrl_img_2 is not None: + ctrl_img_2 = Image.open(gen_img_config.ctrl_img_2).convert("RGB") + # convert to 0 to 1 tensor + ctrl_img_2 = ( + TF.to_tensor(ctrl_img_2) + .unsqueeze(0) + .to(self.sd.device_torch, dtype=self.sd.torch_dtype) + ) + ctrl_img_list.append(ctrl_img_2) + if gen_img_config.ctrl_img_3 is not None: + ctrl_img_3 = Image.open(gen_img_config.ctrl_img_3).convert("RGB") + # convert to 0 to 1 tensor + ctrl_img_3 = ( + TF.to_tensor(ctrl_img_3) + .unsqueeze(0) + .to(self.sd.device_torch, dtype=self.sd.torch_dtype) + ) + ctrl_img_list.append(ctrl_img_3) + + if self.sd.has_multiple_control_images: + ctrl_img = ctrl_img_list + else: + ctrl_img = ctrl_img_list[0] if len(ctrl_img_list) > 0 else None + + positive = self.sd.encode_prompt( gen_img_config.prompt, control_images=ctrl_img @@ -202,6 +249,9 @@ class SDTrainer(BaseSDTrainProcess): if self.sd.encode_control_in_text_embeddings: # just do a blank image for unconditionals control_image = torch.zeros((1, 3, 224, 224), device=self.sd.device_torch, dtype=self.sd.torch_dtype) + if self.sd.has_multiple_control_images: + control_image = [control_image] + kwargs['control_images'] = control_image self.unconditional_embeds = self.sd.encode_prompt( [self.train_config.unconditional_prompt], @@ -272,6 +322,8 @@ class SDTrainer(BaseSDTrainProcess): if self.sd.encode_control_in_text_embeddings: # just do a blank image for unconditionals control_image = torch.zeros((1, 3, 224, 224), device=self.sd.device_torch, dtype=self.sd.torch_dtype) + if self.sd.has_multiple_control_images: + control_image = [control_image] encode_kwargs['control_images'] = control_image self.cached_blank_embeds = self.sd.encode_prompt("", **encode_kwargs) if self.trigger_word is not None: diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index e9b3d260..01d37b81 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -348,6 +348,9 @@ class BaseSDTrainProcess(BaseTrainProcess): fps=sample_item.fps, ctrl_img=sample_item.ctrl_img, ctrl_idx=sample_item.ctrl_idx, + ctrl_img_1=sample_item.ctrl_img_1, + ctrl_img_2=sample_item.ctrl_img_2, + ctrl_img_3=sample_item.ctrl_img_3, **extra_args )) diff --git a/requirements.txt b/requirements.txt index 39fa3dd0..e5442be3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ torchao==0.10.0 safetensors git+https://github.com/jaretburkett/easy_dwpose.git -git+https://github.com/huggingface/diffusers@7a2b78bf0f788d311cc96b61e660a8e13e3b1e63 +git+https://github.com/huggingface/diffusers@1448b035859dd57bbb565239dcdd79a025a85422 transformers==4.52.4 lycoris-lora==1.8.3 flatten_json diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index ca0b3bf2..24da74fe 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -56,6 +56,11 @@ class SampleItem: self.num_frames: int = kwargs.get('num_frames', sample_config.num_frames) self.ctrl_img: Optional[str] = kwargs.get('ctrl_img', None) self.ctrl_idx: int = kwargs.get('ctrl_idx', 0) + # for multi control image models + self.ctrl_img_1: Optional[str] = kwargs.get('ctrl_img_1', self.ctrl_img) + self.ctrl_img_2: Optional[str] = kwargs.get('ctrl_img_2', None) + self.ctrl_img_3: Optional[str] = kwargs.get('ctrl_img_3', None) + self.network_multiplier: float = kwargs.get('network_multiplier', sample_config.network_multiplier) # convert to a number if it is a string if isinstance(self.network_multiplier, str): @@ -966,6 +971,9 @@ class GenerateImageConfig: extra_values: List[float] = None, # extra values to save with prompt file logger: Optional[EmptyLogger] = None, ctrl_img: Optional[str] = None, # control image for controlnet + ctrl_img_1: Optional[str] = None, # first control image for multi control model + ctrl_img_2: Optional[str] = None, # second control image for multi control model + ctrl_img_3: Optional[str] = None, # third control image for multi control model num_frames: int = 1, fps: int = 15, ctrl_idx: int = 0 @@ -1002,6 +1010,12 @@ class GenerateImageConfig: self.ctrl_img = ctrl_img self.ctrl_idx = ctrl_idx + if ctrl_img_1 is None and ctrl_img is not None: + ctrl_img_1 = ctrl_img + + self.ctrl_img_1 = ctrl_img_1 + self.ctrl_img_2 = ctrl_img_2 + self.ctrl_img_3 = ctrl_img_3 # prompt string will override any settings above self._process_prompt_string() diff --git a/toolkit/data_transfer_object/data_loader.py b/toolkit/data_transfer_object/data_loader.py index a7ed3759..ca27dd47 100644 --- a/toolkit/data_transfer_object/data_loader.py +++ b/toolkit/data_transfer_object/data_loader.py @@ -144,6 +144,7 @@ class DataLoaderBatchDTO: self.tensor: Union[torch.Tensor, None] = None self.latents: Union[torch.Tensor, None] = None self.control_tensor: Union[torch.Tensor, None] = None + self.control_tensor_list: Union[List[List[torch.Tensor]], None] = None self.clip_image_tensor: Union[torch.Tensor, None] = None self.mask_tensor: Union[torch.Tensor, None] = None self.unaugmented_tensor: Union[torch.Tensor, None] = None @@ -160,7 +161,6 @@ class DataLoaderBatchDTO: self.latents: Union[torch.Tensor, None] = None if is_latents_cached: self.latents = torch.cat([x.get_latent().unsqueeze(0) for x in self.file_items]) - self.control_tensor: Union[torch.Tensor, None] = None self.prompt_embeds: Union[PromptEmbeds, None] = None # if self.file_items[0].control_tensor is not None: # if any have a control tensor, we concatenate them @@ -178,6 +178,16 @@ class DataLoaderBatchDTO: else: control_tensors.append(x.control_tensor) self.control_tensor = torch.cat([x.unsqueeze(0) for x in control_tensors]) + + # handle control tensor list + if any([x.control_tensor_list is not None for x in self.file_items]): + self.control_tensor_list = [] + for x in self.file_items: + if x.control_tensor_list is not None: + self.control_tensor_list.append(x.control_tensor_list) + else: + raise Exception(f"Could not find control tensors for all file items, missing for {x.path}") + self.inpaint_tensor: Union[torch.Tensor, None] = None if any([x.inpaint_tensor is not None for x in self.file_items]): diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py index 29b2aced..7a114bdd 100644 --- a/toolkit/dataloader_mixins.py +++ b/toolkit/dataloader_mixins.py @@ -850,6 +850,9 @@ class ControlFileItemDTOMixin: self.has_control_image = False self.control_path: Union[str, List[str], None] = None self.control_tensor: Union[torch.Tensor, None] = None + self.control_tensor_list: Union[List[torch.Tensor], None] = None + sd = kwargs.get('sd', None) + self.use_raw_control_images = sd is not None and sd.use_raw_control_images dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None) self.full_size_control_images = False if dataset_config.control_path is not None: @@ -900,23 +903,14 @@ class ControlFileItemDTOMixin: except Exception as e: print_acc(f"Error: {e}") print_acc(f"Error loading image: {control_path}") - + if not self.full_size_control_images: # we just scale them to 512x512: w, h = img.size img = img.resize((512, 512), Image.BICUBIC) - else: + elif not self.use_raw_control_images: w, h = img.size - if w > h and self.scale_to_width < self.scale_to_height: - # throw error, they should match - raise ValueError( - f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}") - elif h > w and self.scale_to_height < self.scale_to_width: - # throw error, they should match - raise ValueError( - f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}") - if self.flip_x: # do a flip img = img.transpose(Image.FLIP_LEFT_RIGHT) @@ -950,11 +944,15 @@ class ControlFileItemDTOMixin: self.control_tensor = None elif len(control_tensors) == 1: self.control_tensor = control_tensors[0] + elif self.use_raw_control_images: + # just send the list of tensors as their shapes wont match + self.control_tensor_list = control_tensors else: self.control_tensor = torch.stack(control_tensors, dim=0) def cleanup_control(self: 'FileItemDTO'): self.control_tensor = None + self.control_tensor_list = None class ClipImageFileItemDTOMixin: @@ -1884,14 +1882,31 @@ class TextEmbeddingCachingMixin: if file_item.encode_control_in_text_embeddings: if file_item.control_path is None: raise Exception(f"Could not find a control image for {file_item.path} which is needed for this model") - # load the control image and feed it into the text encoder - ctrl_img = Image.open(file_item.control_path).convert("RGB") - # convert to 0 to 1 tensor - ctrl_img = ( - TF.to_tensor(ctrl_img) - .unsqueeze(0) - .to(self.sd.device_torch, dtype=self.sd.torch_dtype) - ) + ctrl_img_list = [] + control_path_list = file_item.control_path + if not isinstance(file_item.control_path, list): + control_path_list = [control_path_list] + for i in range(len(control_path_list)): + try: + img = Image.open(control_path_list[i]).convert("RGB") + img = exif_transpose(img) + # convert to 0 to 1 tensor + img = ( + TF.to_tensor(img) + .unsqueeze(0) + .to(self.sd.device_torch, dtype=self.sd.torch_dtype) + ) + ctrl_img_list.append(img) + except Exception as e: + print_acc(f"Error: {e}") + print_acc(f"Error loading control image: {control_path_list[i]}") + + if len(ctrl_img_list) == 0: + ctrl_img = None + elif not self.sd.has_multiple_control_images: + ctrl_img = ctrl_img_list[0] + else: + ctrl_img = ctrl_img_list prompt_embeds: PromptEmbeds = self.sd.encode_prompt(file_item.caption, control_images=ctrl_img) else: prompt_embeds: PromptEmbeds = self.sd.encode_prompt(file_item.caption) diff --git a/toolkit/models/base_model.py b/toolkit/models/base_model.py index af5ea06f..58d48b4f 100644 --- a/toolkit/models/base_model.py +++ b/toolkit/models/base_model.py @@ -181,6 +181,10 @@ class BaseModel: # set true for models that encode control image into text embeddings self.encode_control_in_text_embeddings = False + # control images will come in as a list for encoding some things if true + self.has_multiple_control_images = False + # do not resize control images + self.use_raw_control_images = False # properties for old arch for backwards compatibility @property diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index 53bd689a..78960ed1 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -219,6 +219,10 @@ class StableDiffusion: # set true for models that encode control image into text embeddings self.encode_control_in_text_embeddings = False + # control images will come in as a list for encoding some things if true + self.has_multiple_control_images = False + # do not resize control images + self.use_raw_control_images = False # properties for old arch for backwards compatibility @property