diff --git a/extensions_built_in/diffusion_models/qwen_image/qwen_image_edit_plus.py b/extensions_built_in/diffusion_models/qwen_image/qwen_image_edit_plus.py index 8e366d03..1f5a376f 100644 --- a/extensions_built_in/diffusion_models/qwen_image/qwen_image_edit_plus.py +++ b/extensions_built_in/diffusion_models/qwen_image/qwen_image_edit_plus.py @@ -161,7 +161,7 @@ class QwenImageEditPlusModel(QwenImageModel): # we get the control image from the batch return latents.detach() - def get_prompt_embeds(self, prompt: str, control_images=None) -> PromptEmbeds: + def get_prompt_embeds(self, prompt: List, control_images=None) -> PromptEmbeds: # todo handle not caching text encoder if self.pipeline.text_encoder.device != self.device_torch: self.pipeline.text_encoder.to(self.device_torch) @@ -169,36 +169,54 @@ class QwenImageEditPlusModel(QwenImageModel): if control_images is None: raise ValueError("Missing control images for QwenImageEditPlusModel") - if not isinstance(control_images, List): + if not isinstance(control_images, list): control_images = [control_images] + + # expects a list of list of control images List[List[Tensor]] where each item corresponds to a batch item, + # and each item in the inner list corresponds to a control image for that batch item. 
+ # for single image/caching, it may come in as just List[Tensor], so we handle that case by wrapping it in another list + if not isinstance(control_images[0], list): + control_images = [control_images] + + if len(prompt) != len(control_images): + raise ValueError("Number of prompts must match number of control image sets") + + prompt_embeds_list = [] + prompt_embeds_mask_list = [] + + for b in range(len(prompt)): + batch_control_images = control_images[b] - if control_images is not None and len(control_images) > 0: - for i in range(len(control_images)): + for i in range(len(batch_control_images)): + if len(batch_control_images[i].shape) == 3: + batch_control_images[i] = batch_control_images[i].unsqueeze(0) # control images are 0 - 1 scale, shape (bs, ch, height, width) - ratio = control_images[i].shape[2] / control_images[i].shape[3] + ratio = batch_control_images[i].shape[2] / batch_control_images[i].shape[3] height = math.sqrt(CONDITION_IMAGE_SIZE * ratio) width = height / ratio width = round(width / 32) * 32 height = round(height / 32) * 32 - control_images[i] = F.interpolate( - control_images[i], size=(height, width), mode="bilinear" + batch_control_images[i] = F.interpolate( + batch_control_images[i], size=(height, width), mode="bilinear" ) - prompt_embeds, prompt_embeds_mask = self.pipeline.encode_prompt( - prompt, - image=control_images, - device=self.device_torch, - num_images_per_prompt=1, - ) - # diffusers >=0.37 returns None when all tokens are valid (no padding) - if prompt_embeds_mask is None: - prompt_embeds_mask = torch.ones( - prompt_embeds.shape[:2], device=prompt_embeds.device, dtype=torch.int64 + prompt_embeds, prompt_embeds_mask = self.pipeline.encode_prompt( + [prompt[b]], + image=batch_control_images, + device=self.device_torch, + num_images_per_prompt=1, ) - pe = PromptEmbeds(prompt_embeds) - pe.attention_mask = prompt_embeds_mask + # diffusers >=0.37 returns None when all tokens are valid (no padding) + if prompt_embeds_mask is None: + 
prompt_embeds_mask = torch.ones( + prompt_embeds.shape[:2], device=prompt_embeds.device, dtype=torch.int64 + ) + prompt_embeds_list.append(prompt_embeds) + prompt_embeds_mask_list.append(prompt_embeds_mask) + pe = PromptEmbeds(torch.cat(prompt_embeds_list, dim=0)) + pe.attention_mask = torch.cat(prompt_embeds_mask_list, dim=0) return pe def get_noise_prediction( diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index 8d1bcffa..63c79402 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -1602,6 +1602,8 @@ class SDTrainer(BaseSDTrainProcess): self.sd.text_encoder.eval() if isinstance(self.adapter, CustomAdapter): self.adapter.is_unconditional_run = False + if self.sd.encode_control_in_text_embeddings and batch.control_tensor_list is not None: + prompt_kwargs['control_images'] = batch.control_tensor_list conditional_embeds = self.sd.encode_prompt( conditioned_prompts, prompt_2, dropout_prob=self.train_config.prompt_dropout_prob,