diff --git a/extensions_built_in/diffusion_models/__init__.py b/extensions_built_in/diffusion_models/__init__.py index d8bf232b..5682bc46 100644 --- a/extensions_built_in/diffusion_models/__init__.py +++ b/extensions_built_in/diffusion_models/__init__.py @@ -4,7 +4,7 @@ from .f_light import FLiteModel from .omnigen2 import OmniGen2Model from .flux_kontext import FluxKontextModel from .wan22 import Wan225bModel, Wan2214bModel, Wan2214bI2VModel -from .qwen_image import QwenImageModel +from .qwen_image import QwenImageModel, QwenImageEditModel AI_TOOLKIT_MODELS = [ # put a list of models here @@ -18,4 +18,5 @@ AI_TOOLKIT_MODELS = [ Wan2214bI2VModel, Wan2214bModel, QwenImageModel, + QwenImageEditModel, ] diff --git a/extensions_built_in/diffusion_models/qwen_image/__init__.py b/extensions_built_in/diffusion_models/qwen_image/__init__.py index df7af5e0..d8b32a85 100644 --- a/extensions_built_in/diffusion_models/qwen_image/__init__.py +++ b/extensions_built_in/diffusion_models/qwen_image/__init__.py @@ -1 +1,2 @@ from .qwen_image import QwenImageModel +from .qwen_image_edit import QwenImageEditModel \ No newline at end of file diff --git a/extensions_built_in/diffusion_models/qwen_image/qwen_image.py b/extensions_built_in/diffusion_models/qwen_image/qwen_image.py index 3c002878..07dcde15 100644 --- a/extensions_built_in/diffusion_models/qwen_image/qwen_image.py +++ b/extensions_built_in/diffusion_models/qwen_image/qwen_image.py @@ -16,7 +16,7 @@ from toolkit.util.quantize import quantize, get_qtype, quantize_model import torch.nn.functional as F from diffusers import QwenImagePipeline, QwenImageTransformer2DModel, AutoencoderKLQwenImage -from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer +from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor from tqdm import tqdm if TYPE_CHECKING: @@ -43,7 +43,8 @@ scheduler_config = { class QwenImageModel(BaseModel): arch = "qwen_image" - _qwen_image_keep_processor = False + _qwen_image_keep_visual = False + _qwen_pipeline = QwenImagePipeline def __init__( self, @@ -119,10 +120,9 @@ class QwenImageModel(BaseModel): # remove the visual model as it is not needed for image generation self.processor = None - if self._qwen_image_keep_processor: - self.processor = text_encoder.model.visual - text_encoder.model.visual = None - + if not self._qwen_image_keep_visual: + text_encoder.model.visual = None + text_encoder.to(self.device_torch, dtype=dtype) flush() @@ -140,13 +140,27 @@ class QwenImageModel(BaseModel): self.noise_scheduler = QwenImageModel.get_train_scheduler() self.print_and_status_update("Making pipe") + + kwargs = {} + + if self._qwen_image_keep_visual: + try: + self.processor = Qwen2VLProcessor.from_pretrained( + model_path, subfolder="processor" + ) + except OSError: + self.processor = Qwen2VLProcessor.from_pretrained( + base_model_path, subfolder="processor" + ) + kwargs['processor'] = self.processor - pipe: QwenImagePipeline = QwenImagePipeline( + pipe: QwenImagePipeline = self._qwen_pipeline( scheduler=self.noise_scheduler, text_encoder=None, tokenizer=tokenizer, vae=vae, transformer=None, + **kwargs ) # for quantization, it works best to do these after making the pipe pipe.text_encoder = text_encoder @@ -261,21 +275,13 @@ class QwenImageModel(BaseModel): latent_model_input = latent_model_input.permute(0, 2, 4, 1, 3, 5) latent_model_input = latent_model_input.reshape(batch_size, (height // ps) * (width // ps), num_channels_latents * (ps * ps)) - # clamp text length to RoPE capacity for this image size # img_shapes passed to the model img_h2, img_w2 = height // ps, width // ps - img_shapes = [(1, img_h2, img_w2)] * batch_size + img_shapes = [[(1, img_h2, img_w2)]] * batch_size - # QwenEmbedRope logic: - max_vid_index = max(img_h2 // ps, img_w2 // ps) - - rope_cap = 1024 - max_vid_index # available text positions in RoPE cache - seq_len_actual = text_embeddings.text_embeds.shape[1] - use_len = min(seq_len_actual, rope_cap) - - enc_hs = text_embeddings.text_embeds[:, :use_len].to(self.device_torch, self.torch_dtype) - prompt_embeds_mask = text_embeddings.attention_mask.to(self.device_torch, dtype=torch.int64)[:, :use_len] - txt_seq_lens = [use_len] * batch_size + enc_hs = text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype) + prompt_embeds_mask = text_embeddings.attention_mask.to(self.device_torch, dtype=torch.int64) + txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() noise_pred = self.transformer( hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype), diff --git a/extensions_built_in/diffusion_models/qwen_image/qwen_image_edit.py b/extensions_built_in/diffusion_models/qwen_image/qwen_image_edit.py new file mode 100644 index 00000000..bcc8d735 --- /dev/null +++ b/extensions_built_in/diffusion_models/qwen_image/qwen_image_edit.py @@ -0,0 +1,276 @@ +import math +import torch +from .qwen_image import QwenImageModel +import os +from typing import TYPE_CHECKING, List, Optional +import yaml +from toolkit import train_tools +from toolkit.config_modules import GenerateImageConfig, ModelConfig +from PIL import Image +from toolkit.models.base_model import BaseModel +from toolkit.basic import flush +from toolkit.prompt_utils import PromptEmbeds +from toolkit.samplers.custom_flowmatch_sampler import ( + CustomFlowMatchEulerDiscreteScheduler, +) +from toolkit.accelerator import get_accelerator, unwrap_model +from optimum.quanto import freeze, QTensor +from toolkit.util.quantize import quantize, get_qtype, quantize_model +import torch.nn.functional as F + +from diffusers import ( + QwenImagePipeline, + QwenImageTransformer2DModel, + AutoencoderKLQwenImage, +) +from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer +from tqdm import tqdm + +if TYPE_CHECKING: + from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO + +try: + from diffusers import QwenImageEditPipeline +except ImportError: + raise ImportError( + "QwenImageEditPipeline not found. Update diffusers to the latest version by doing pip uninstall diffusers and then pip install -r requirements.txt" + ) + + +class QwenImageEditModel(QwenImageModel): + arch = "qwen_image_edit" + _qwen_image_keep_visual = True + _qwen_pipeline = QwenImageEditPipeline + + def __init__( + self, + device, + model_config: ModelConfig, + dtype="bf16", + custom_pipeline=None, + noise_scheduler=None, + **kwargs, + ): + super().__init__( + device, model_config, dtype, custom_pipeline, noise_scheduler, **kwargs + ) + self.is_flow_matching = True + self.is_transformer = True + self.target_lora_modules = ["QwenImageTransformer2DModel"] + + # set true for models that encode control image into text embeddings + self.encode_control_in_text_embeddings = True + + def load_model(self): + super().load_model() + + def get_generation_pipeline(self): + scheduler = QwenImageModel.get_train_scheduler() + + pipeline: QwenImageEditPipeline = QwenImageEditPipeline( + scheduler=scheduler, + text_encoder=unwrap_model(self.text_encoder[0]), + tokenizer=self.tokenizer[0], + processor=self.processor, + vae=unwrap_model(self.vae), + transformer=unwrap_model(self.transformer), + ) + + pipeline = pipeline.to(self.device_torch) + + return pipeline + + def generate_single_image( + self, + pipeline: QwenImageEditPipeline, + gen_config: GenerateImageConfig, + conditional_embeds: PromptEmbeds, + unconditional_embeds: PromptEmbeds, + generator: torch.Generator, + extra: dict, + ): + self.model.to(self.device_torch, dtype=self.torch_dtype) + sc = self.get_bucket_divisibility() + gen_config.width = int(gen_config.width // sc * sc) + gen_config.height = int(gen_config.height // sc * sc) + + control_img = None + if gen_config.ctrl_img is not None: + control_img = Image.open(gen_config.ctrl_img) + control_img = control_img.convert("RGB") + # resize to width and height + if control_img.size != (gen_config.width, gen_config.height): + control_img = control_img.resize( + (gen_config.width, gen_config.height), Image.BILINEAR + ) + + # flush for low vram if we are doing that + flush_between_steps = self.model_config.low_vram + + # Fix a bug in diffusers/torch + def callback_on_step_end(pipe, i, t, callback_kwargs): + if flush_between_steps: + flush() + latents = callback_kwargs["latents"] + + return {"latents": latents} + + img = pipeline( + image=control_img, + prompt_embeds=conditional_embeds.text_embeds, + prompt_embeds_mask=conditional_embeds.attention_mask.to( + self.device_torch, dtype=torch.int64 + ), + negative_prompt_embeds=unconditional_embeds.text_embeds, + negative_prompt_embeds_mask=unconditional_embeds.attention_mask.to( + self.device_torch, dtype=torch.int64 + ), + height=gen_config.height, + width=gen_config.width, + num_inference_steps=gen_config.num_inference_steps, + true_cfg_scale=gen_config.guidance_scale, + latents=gen_config.latents, + generator=generator, + callback_on_step_end=callback_on_step_end, + **extra, + ).images[0] + return img + + def condition_noisy_latents( + self, latents: torch.Tensor, batch: "DataLoaderBatchDTO" + ): + with torch.no_grad(): + control_tensor = batch.control_tensor + if control_tensor is not None: + self.vae.to(self.device_torch) + # we are not packed here, so we just need to pass them so we can pack them later + control_tensor = control_tensor * 2 - 1 + control_tensor = control_tensor.to( + self.vae_device_torch, dtype=self.torch_dtype + ) + + # if it is not the size of batch.tensor, (bs,ch,h,w) then we need to resize it + if batch.tensor is not None: + target_h, target_w = batch.tensor.shape[2], batch.tensor.shape[3] + else: + # When caching latents, batch.tensor is None. We get the size from the file_items instead. + target_h = batch.file_items[0].crop_height + target_w = batch.file_items[0].crop_width + + if ( + control_tensor.shape[2] != target_h + or control_tensor.shape[3] != target_w + ): + control_tensor = F.interpolate( + control_tensor, size=(target_h, target_w), mode="bilinear" + ) + + control_latent = self.encode_images(control_tensor).to( + latents.device, latents.dtype + ) + latents = torch.cat((latents, control_latent), dim=1) + + return latents.detach() + + def get_prompt_embeds(self, prompt: str, control_images=None) -> PromptEmbeds: + if self.pipeline.text_encoder.device != self.device_torch: + self.pipeline.text_encoder.to(self.device_torch) + + if control_images is not None: + # control images are 0 - 1 scale, shape (bs, ch, height, width) + # images are always run through at 1MP, based on diffusers inference code. + target_area = 1024 * 1024 + ratio = control_images.shape[2] / control_images.shape[3] + width = math.sqrt(target_area * ratio) + height = width / ratio + + width = round(width / 32) * 32 + height = round(height / 32) * 32 + + control_images = F.interpolate( + control_images, size=(height, width), mode="bilinear" + ) + + prompt_embeds, prompt_embeds_mask = self.pipeline.encode_prompt( + prompt, + image=control_images, + device=self.device_torch, + num_images_per_prompt=1, + ) + pe = PromptEmbeds(prompt_embeds) + pe.attention_mask = prompt_embeds_mask + return pe + + def get_noise_prediction( + self, + latent_model_input: torch.Tensor, + timestep: torch.Tensor, # 0 to 1000 scale + text_embeddings: PromptEmbeds, + **kwargs, + ): + # control is stacked on channels, move it to the batch dimension for packing + latent_model_input, control = torch.chunk(latent_model_input, 2, 1) + + batch_size, num_channels_latents, height, width = latent_model_input.shape + ( + control_batch_size, + control_num_channels_latents, + control_height, + control_width, + ) = control.shape + + # pack image tokens + latent_model_input = latent_model_input.view( + batch_size, num_channels_latents, height // 2, 2, width // 2, 2 + ) + latent_model_input = latent_model_input.permute(0, 2, 4, 1, 3, 5) + latent_model_input = latent_model_input.reshape( + batch_size, (height // 2) * (width // 2), num_channels_latents * 4 + ) + + # pack control + control = control.view( + batch_size, num_channels_latents, height // 2, 2, width // 2, 2 + ) + control = control.permute(0, 2, 4, 1, 3, 5) + control = control.reshape( + batch_size, (height // 2) * (width // 2), num_channels_latents * 4 + ) + + img_h2, img_w2 = height // 2, width // 2 + control_img_h2, control_img_w2 = control_height // 2, control_width // 2 + + img_shapes = [[(1, img_h2, img_w2), (1, control_img_h2, control_img_w2)]] * batch_size + + latents = latent_model_input + latent_model_input = torch.cat([latent_model_input, control], dim=1) + batch_size = latent_model_input.shape[0] + + prompt_embeds_mask = text_embeddings.attention_mask.to( + self.device_torch, dtype=torch.int64 + ) + txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() + enc_hs = text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype) + prompt_embeds_mask = text_embeddings.attention_mask.to(self.device_torch, dtype=torch.int64) + + noise_pred = self.transformer( + hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype), + timestep=timestep / 1000, + guidance=None, + encoder_hidden_states=enc_hs, + encoder_hidden_states_mask=prompt_embeds_mask, + img_shapes=img_shapes, + txt_seq_lens=txt_seq_lens, + return_dict=False, + **kwargs, + )[0] + + noise_pred = noise_pred[:, : latents.size(1)] + + # unpack + noise_pred = noise_pred.view( + batch_size, height // 2, width // 2, num_channels_latents, 2, 2 + ) + noise_pred = noise_pred.permute(0, 3, 1, 4, 2, 5) + noise_pred = noise_pred.reshape(batch_size, num_channels_latents, height, width) + return noise_pred diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index d9d7aff5..7755d22e 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -37,6 +37,8 @@ from toolkit.models.diffusion_feature_extraction import DiffusionFeatureExtracto from toolkit.util.wavelet_loss import wavelet_loss import torch.nn.functional as F from toolkit.unloader import unload_text_encoder +from PIL import Image +from torchvision.transforms import functional as TF def flush(): @@ -127,9 +129,28 @@ class SDTrainer(BaseSDTrainProcess): prompt=prompt, # it will autoparse the prompt negative_prompt=sample_item.neg, output_path=output_path, + ctrl_img=sample_item.ctrl_img ) - positive = self.sd.encode_prompt(gen_img_config.prompt).to('cpu') - negative = self.sd.encode_prompt(gen_img_config.negative_prompt).to('cpu') + # see if we need to encode the control images + if self.sd.encode_control_in_text_embeddings and gen_img_config.ctrl_img is not None: + ctrl_img = Image.open(gen_img_config.ctrl_img).convert("RGB") + # convert to 0 to 1 tensor + ctrl_img = ( + TF.to_tensor(ctrl_img) + .unsqueeze(0) + .to(self.sd.device_torch, dtype=self.sd.torch_dtype) + ) + positive = self.sd.encode_prompt( + gen_img_config.prompt, + control_images=ctrl_img + ).to('cpu') + negative = self.sd.encode_prompt( + gen_img_config.negative_prompt, + control_images=ctrl_img + ).to('cpu') + else: + positive = self.sd.encode_prompt(gen_img_config.prompt).to('cpu') + negative = self.sd.encode_prompt(gen_img_config.negative_prompt).to('cpu') self.sd.sample_prompts_cache.append({ 'conditional': positive, @@ -177,9 +198,15 @@ class SDTrainer(BaseSDTrainProcess): # cache unconditional embeds (blank prompt) with torch.no_grad(): + kwargs = {} + if self.sd.encode_control_in_text_embeddings: + # just do a blank image for unconditionals + control_image = torch.zeros((1, 3, 224, 224), device=self.sd.device_torch, dtype=self.sd.torch_dtype) + kwargs['control_images'] = control_image self.unconditional_embeds = self.sd.encode_prompt( [self.train_config.unconditional_prompt], - long_prompts=self.do_long_prompts + long_prompts=self.do_long_prompts, + **kwargs ).to( self.device_torch, dtype=self.sd.torch_dtype @@ -241,9 +268,14 @@ class SDTrainer(BaseSDTrainProcess): print_acc("***********************************") print_acc("") self.sd.text_encoder_to(self.device_torch) - self.cached_blank_embeds = self.sd.encode_prompt("") + encode_kwargs = {} + if self.sd.encode_control_in_text_embeddings: + # just do a blank image for unconditionals + control_image = torch.zeros((1, 3, 224, 224), device=self.sd.device_torch, dtype=self.sd.torch_dtype) + encode_kwargs['control_images'] = control_image + self.cached_blank_embeds = self.sd.encode_prompt("", **encode_kwargs) if self.trigger_word is not None: - self.cached_trigger_embeds = self.sd.encode_prompt(self.trigger_word) + self.cached_trigger_embeds = self.sd.encode_prompt(self.trigger_word, **encode_kwargs) if self.train_config.diff_output_preservation: self.diff_output_preservation_embeds = self.sd.encode_prompt(self.train_config.diff_output_preservation_class) diff --git a/requirements.txt b/requirements.txt index 8aa997d2..39fa3dd0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ torchao==0.10.0 safetensors git+https://github.com/jaretburkett/easy_dwpose.git -git+https://github.com/huggingface/diffusers@7ea065c5070a5278259e6f1effa9dccea232e62a +git+https://github.com/huggingface/diffusers@7a2b78bf0f788d311cc96b61e660a8e13e3b1e63 transformers==4.52.4 lycoris-lora==1.8.3 flatten_json diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index 403abc54..b6c0a3d7 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -1232,5 +1232,10 @@ def validate_configs( for dataset in dataset_configs: if not dataset.cache_text_embeddings: raise ValueError("All datasets must have cache_text_embeddings set to True when caching text embeddings is enabled.") + + # qwen image edit cannot cache text embeddings + if model_config.arch == 'qwen_image_edit': + if train_config.unload_text_encoder: + raise ValueError("Cannot cache unload text encoder with qwen_image_edit model. Control images are encoded with text embeddings. You can cache the text embeddings though") diff --git a/toolkit/data_loader.py b/toolkit/data_loader.py index 43f03647..95075a61 100644 --- a/toolkit/data_loader.py +++ b/toolkit/data_loader.py @@ -497,6 +497,7 @@ class AiToolkitDataset(LatentCachingMixin, ControlCachingMixin, CLIPCachingMixin dataloader_transforms=self.transform, size_database=self.size_database, dataset_root=dataset_folder, + encode_control_in_text_embeddings=self.sd.encode_control_in_text_embeddings if self.sd else False, ) self.file_list.append(file_item) except Exception as e: diff --git a/toolkit/data_transfer_object/data_loader.py b/toolkit/data_transfer_object/data_loader.py index c37bec1c..b6863663 100644 --- a/toolkit/data_transfer_object/data_loader.py +++ b/toolkit/data_transfer_object/data_loader.py @@ -50,6 +50,7 @@ class FileItemDTO( self.is_video = self.dataset_config.num_frames > 1 size_database = kwargs.get('size_database', {}) dataset_root = kwargs.get('dataset_root', None) + self.encode_control_in_text_embeddings = kwargs.get('encode_control_in_text_embeddings', False) if dataset_root is not None: # remove dataset root from path file_key = self.path.replace(dataset_root, '') diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py index e39be2b9..4a01d752 100644 --- a/toolkit/dataloader_mixins.py +++ b/toolkit/dataloader_mixins.py @@ -30,6 +30,7 @@ import albumentations as A from toolkit.print import print_acc from toolkit.accelerator import get_accelerator from toolkit.prompt_utils import PromptEmbeds +from torchvision.transforms import functional as TF from toolkit.train_tools import get_torch_dtype @@ -1802,6 +1803,9 @@ class TextEmbeddingFileItemDTOMixin: ("text_embedding_space_version", self.text_embedding_space_version), ("text_embedding_version", self.text_embedding_version), ]) + # if we have a control image, cache the path + if self.encode_control_in_text_embeddings and self.control_path is not None: + item["control_path"] = self.control_path return item def get_text_embedding_path(self: 'FileItemDTO', recalculate=False): @@ -1860,7 +1864,19 @@ class TextEmbeddingCachingMixin: if not did_move: self.sd.set_device_state_preset('cache_text_encoder') did_move = True - prompt_embeds: PromptEmbeds = self.sd.encode_prompt(file_item.caption) + + if file_item.encode_control_in_text_embeddings and file_item.control_path is not None: + # load the control image and feed it into the text encoder + ctrl_img = Image.open(file_item.control_path).convert("RGB") + # convert to 0 to 1 tensor + ctrl_img = ( + TF.to_tensor(ctrl_img) + .unsqueeze(0) + .to(self.sd.device_torch, dtype=self.sd.torch_dtype) + ) + prompt_embeds: PromptEmbeds = self.sd.encode_prompt(file_item.caption, control_images=ctrl_img) + else: + prompt_embeds: PromptEmbeds = self.sd.encode_prompt(file_item.caption) # save it prompt_embeds.save(text_embedding_path) del prompt_embeds diff --git a/toolkit/models/base_model.py b/toolkit/models/base_model.py index d446c25e..af5ea06f 100644 --- a/toolkit/models/base_model.py +++ b/toolkit/models/base_model.py @@ -36,6 +36,7 @@ from diffusers import \ UNet2DConditionModel from diffusers import PixArtAlphaPipeline from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection +from torchvision.transforms import functional as TF from toolkit.accelerator import get_accelerator, unwrap_model from typing import TYPE_CHECKING @@ -177,6 +178,9 @@ class BaseModel: self.multistage_boundaries: List[float] = [0.0] # a list of trainable multistage boundaries self.trainable_multistage_boundaries: List[int] = [0] + + # set true for models that encode control image into text embeddings + self.encode_control_in_text_embeddings = False # properties for old arch for backwards compatibility @property @@ -287,7 +291,7 @@ class BaseModel: raise NotImplementedError( "get_noise_prediction must be implemented in child classes") - def get_prompt_embeds(self, prompt: str) -> PromptEmbeds: + def get_prompt_embeds(self, prompt: str, control_images=None) -> PromptEmbeds: raise NotImplementedError( "get_prompt_embeds must be implemented in child classes") @@ -496,17 +500,34 @@ class BaseModel: if self.sample_prompts_cache is not None: conditional_embeds = self.sample_prompts_cache[i]['conditional'].to(self.device_torch, dtype=self.torch_dtype) unconditional_embeds = self.sample_prompts_cache[i]['unconditional'].to(self.device_torch, dtype=self.torch_dtype) - else: + else: + ctrl_img = None + # load the control image if out model uses it in text encoding + if gen_config.ctrl_img is not None and self.encode_control_in_text_embeddings: + ctrl_img = Image.open(gen_config.ctrl_img).convert("RGB") + # convert to 0 to 1 tensor + ctrl_img = ( + TF.to_tensor(ctrl_img) + .unsqueeze(0) + .to(self.device_torch, dtype=self.torch_dtype) + ) # encode the prompt ourselves so we can do fun stuff with embeddings if isinstance(self.adapter, CustomAdapter): self.adapter.is_unconditional_run = False conditional_embeds = self.encode_prompt( - gen_config.prompt, gen_config.prompt_2, force_all=True) + gen_config.prompt, + gen_config.prompt_2, + force_all=True, + control_images=ctrl_img + ) if isinstance(self.adapter, CustomAdapter): self.adapter.is_unconditional_run = True unconditional_embeds = self.encode_prompt( - gen_config.negative_prompt, gen_config.negative_prompt_2, force_all=True + gen_config.negative_prompt, + gen_config.negative_prompt_2, + force_all=True, + control_images=ctrl_img ) if isinstance(self.adapter, CustomAdapter): self.adapter.is_unconditional_run = False @@ -989,6 +1010,7 @@ class BaseModel: long_prompts=False, max_length=None, dropout_prob=0.0, + control_images=None, ) -> PromptEmbeds: # sd1.5 embeddings are (bs, 77, 768) prompt = prompt @@ -998,6 +1020,9 @@ class BaseModel: if prompt2 is not None and not isinstance(prompt2, list): prompt2 = [prompt2] + # if control_images in the signature, pass it. This keep from breaking plugins + if self.encode_control_in_text_embeddings: + return self.get_prompt_embeds(prompt, control_images=control_images) return self.get_prompt_embeds(prompt) diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index 86908884..b042d699 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -217,6 +217,9 @@ class StableDiffusion: # a list of trainable multistage boundaries self.trainable_multistage_boundaries: List[int] = [0] + # set true for models that encode control image into text embeddings + self.encode_control_in_text_embeddings = False + # properties for old arch for backwards compatibility @property def is_xl(self): @@ -2356,6 +2359,7 @@ class StableDiffusion: long_prompts=False, max_length=None, dropout_prob=0.0, + control_images=None, ) -> PromptEmbeds: # sd1.5 embeddings are (bs, 77, 768) prompt = prompt