mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-11 08:20:35 +00:00
Fixed an issue with Qwen image edit models when using multiple control images while not caching text embeddings.
This commit is contained in:
@@ -161,7 +161,7 @@ class QwenImageEditPlusModel(QwenImageModel):
|
||||
# we get the control image from the batch
|
||||
return latents.detach()
|
||||
|
||||
def get_prompt_embeds(self, prompt: str, control_images=None) -> PromptEmbeds:
|
||||
def get_prompt_embeds(self, prompt: List, control_images=None) -> PromptEmbeds:
|
||||
# todo handle not caching text encoder
|
||||
if self.pipeline.text_encoder.device != self.device_torch:
|
||||
self.pipeline.text_encoder.to(self.device_torch)
|
||||
@@ -169,36 +169,54 @@ class QwenImageEditPlusModel(QwenImageModel):
|
||||
if control_images is None:
|
||||
raise ValueError("Missing control images for QwenImageEditPlusModel")
|
||||
|
||||
if not isinstance(control_images, List):
|
||||
if not isinstance(control_images, list):
|
||||
control_images = [control_images]
|
||||
|
||||
# expects a list of list of control images List[List[Tensor]] where each item corresponds to a batch item,
|
||||
# and each item in the inner list corresponds to a control image for that batch item.
|
||||
# for single image/caching, it may come in as just List[Tensor], so we handle that case by wrapping it in another list
|
||||
if not isinstance(control_images[0], list):
|
||||
control_images = [control_images]
|
||||
|
||||
if len(prompt) != len(control_images):
|
||||
raise ValueError("Number of prompts must match number of control image sets")
|
||||
|
||||
prompt_embeds_list = []
|
||||
prompt_embeds_mask_list = []
|
||||
|
||||
for b in range(len(prompt)):
|
||||
batch_control_images = control_images[b]
|
||||
|
||||
if control_images is not None and len(control_images) > 0:
|
||||
for i in range(len(control_images)):
|
||||
for i in range(len(batch_control_images)):
|
||||
if len(batch_control_images[i].shape) == 3:
|
||||
batch_control_images[i] = batch_control_images[i].unsqueeze(0)
|
||||
# control images are 0 - 1 scale, shape (bs, ch, height, width)
|
||||
ratio = control_images[i].shape[2] / control_images[i].shape[3]
|
||||
ratio = batch_control_images[i].shape[2] / batch_control_images[i].shape[3]
|
||||
height = math.sqrt(CONDITION_IMAGE_SIZE * ratio)
|
||||
width = height / ratio
|
||||
|
||||
width = round(width / 32) * 32
|
||||
height = round(height / 32) * 32
|
||||
|
||||
control_images[i] = F.interpolate(
|
||||
control_images[i], size=(height, width), mode="bilinear"
|
||||
batch_control_images[i] = F.interpolate(
|
||||
batch_control_images[i], size=(height, width), mode="bilinear"
|
||||
)
|
||||
|
||||
prompt_embeds, prompt_embeds_mask = self.pipeline.encode_prompt(
|
||||
prompt,
|
||||
image=control_images,
|
||||
device=self.device_torch,
|
||||
num_images_per_prompt=1,
|
||||
)
|
||||
# diffusers >=0.37 returns None when all tokens are valid (no padding)
|
||||
if prompt_embeds_mask is None:
|
||||
prompt_embeds_mask = torch.ones(
|
||||
prompt_embeds.shape[:2], device=prompt_embeds.device, dtype=torch.int64
|
||||
prompt_embeds, prompt_embeds_mask = self.pipeline.encode_prompt(
|
||||
prompt,
|
||||
image=batch_control_images,
|
||||
device=self.device_torch,
|
||||
num_images_per_prompt=1,
|
||||
)
|
||||
pe = PromptEmbeds(prompt_embeds)
|
||||
pe.attention_mask = prompt_embeds_mask
|
||||
# diffusers >=0.37 returns None when all tokens are valid (no padding)
|
||||
if prompt_embeds_mask is None:
|
||||
prompt_embeds_mask = torch.ones(
|
||||
prompt_embeds.shape[:2], device=prompt_embeds.device, dtype=torch.int64
|
||||
)
|
||||
prompt_embeds_list.append(prompt_embeds)
|
||||
prompt_embeds_mask_list.append(prompt_embeds_mask)
|
||||
pe = PromptEmbeds(torch.cat(prompt_embeds_list, dim=0))
|
||||
pe.attention_mask = torch.cat(prompt_embeds_mask_list, dim=0)
|
||||
return pe
|
||||
|
||||
def get_noise_prediction(
|
||||
|
||||
@@ -1602,6 +1602,8 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
self.sd.text_encoder.eval()
|
||||
if isinstance(self.adapter, CustomAdapter):
|
||||
self.adapter.is_unconditional_run = False
|
||||
if self.sd.encode_control_in_text_embeddings and batch.control_tensor_list is not None:
|
||||
prompt_kwargs['control_images'] = batch.control_tensor_list
|
||||
conditional_embeds = self.sd.encode_prompt(
|
||||
conditioned_prompts, prompt_2,
|
||||
dropout_prob=self.train_config.prompt_dropout_prob,
|
||||
|
||||
Reference in New Issue
Block a user