Fixed an issue with Qwen image edit models when using multiple control images while not caching text embeddings.

This commit is contained in:
Jaret Burkett
2026-04-30 11:59:12 +00:00
parent deb409085a
commit e9ab387dfd
2 changed files with 39 additions and 19 deletions

View File

@@ -161,7 +161,7 @@ class QwenImageEditPlusModel(QwenImageModel):
# we get the control image from the batch
return latents.detach()
def get_prompt_embeds(self, prompt: str, control_images=None) -> PromptEmbeds:
def get_prompt_embeds(self, prompt: List, control_images=None) -> PromptEmbeds:
# todo handle not caching text encoder
if self.pipeline.text_encoder.device != self.device_torch:
self.pipeline.text_encoder.to(self.device_torch)
@@ -169,36 +169,54 @@ class QwenImageEditPlusModel(QwenImageModel):
if control_images is None:
raise ValueError("Missing control images for QwenImageEditPlusModel")
if not isinstance(control_images, List):
if not isinstance(control_images, list):
control_images = [control_images]
# expects a list of list of control images List[List[Tensor]] where each item corresponds to a batch item,
# and each item in the inner list corresponds to a control image for that batch item.
# for single image/caching, it may come in as just List[Tensor], so we handle that case by wrapping it in another list
if not isinstance(control_images[0], list):
control_images = [control_images]
if len(prompt) != len(control_images):
raise ValueError("Number of prompts must match number of control image sets")
prompt_embeds_list = []
prompt_embeds_mask_list = []
for b in range(len(prompt)):
batch_control_images = control_images[b]
if control_images is not None and len(control_images) > 0:
for i in range(len(control_images)):
for i in range(len(batch_control_images)):
if len(batch_control_images[i].shape) == 3:
batch_control_images[i] = batch_control_images[i].unsqueeze(0)
# control images are 0 - 1 scale, shape (bs, ch, height, width)
ratio = control_images[i].shape[2] / control_images[i].shape[3]
ratio = batch_control_images[i].shape[2] / batch_control_images[i].shape[3]
height = math.sqrt(CONDITION_IMAGE_SIZE * ratio)
width = height / ratio
width = round(width / 32) * 32
height = round(height / 32) * 32
control_images[i] = F.interpolate(
control_images[i], size=(height, width), mode="bilinear"
batch_control_images[i] = F.interpolate(
batch_control_images[i], size=(height, width), mode="bilinear"
)
prompt_embeds, prompt_embeds_mask = self.pipeline.encode_prompt(
prompt,
image=control_images,
device=self.device_torch,
num_images_per_prompt=1,
)
# diffusers >=0.37 returns None when all tokens are valid (no padding)
if prompt_embeds_mask is None:
prompt_embeds_mask = torch.ones(
prompt_embeds.shape[:2], device=prompt_embeds.device, dtype=torch.int64
prompt_embeds, prompt_embeds_mask = self.pipeline.encode_prompt(
prompt,
image=batch_control_images,
device=self.device_torch,
num_images_per_prompt=1,
)
pe = PromptEmbeds(prompt_embeds)
pe.attention_mask = prompt_embeds_mask
# diffusers >=0.37 returns None when all tokens are valid (no padding)
if prompt_embeds_mask is None:
prompt_embeds_mask = torch.ones(
prompt_embeds.shape[:2], device=prompt_embeds.device, dtype=torch.int64
)
prompt_embeds_list.append(prompt_embeds)
prompt_embeds_mask_list.append(prompt_embeds_mask)
pe = PromptEmbeds(torch.cat(prompt_embeds_list, dim=0))
pe.attention_mask = torch.cat(prompt_embeds_mask_list, dim=0)
return pe
def get_noise_prediction(

View File

@@ -1602,6 +1602,8 @@ class SDTrainer(BaseSDTrainProcess):
self.sd.text_encoder.eval()
if isinstance(self.adapter, CustomAdapter):
self.adapter.is_unconditional_run = False
if self.sd.encode_control_in_text_embeddings and batch.control_tensor_list is not None:
prompt_kwargs['control_images'] = batch.control_tensor_list
conditional_embeds = self.sd.encode_prompt(
conditioned_prompts, prompt_2,
dropout_prob=self.train_config.prompt_dropout_prob,