Initial support for qwen image edit plus

This commit is contained in:
Jaret Burkett
2025-09-24 11:39:10 -06:00
parent f74475161e
commit 454be0958a
11 changed files with 445 additions and 32 deletions

View File

@@ -850,6 +850,9 @@ class ControlFileItemDTOMixin:
self.has_control_image = False
self.control_path: Union[str, List[str], None] = None
self.control_tensor: Union[torch.Tensor, None] = None
self.control_tensor_list: Union[List[torch.Tensor], None] = None
sd = kwargs.get('sd', None)
self.use_raw_control_images = sd is not None and sd.use_raw_control_images
dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
self.full_size_control_images = False
if dataset_config.control_path is not None:
@@ -900,23 +903,14 @@ class ControlFileItemDTOMixin:
except Exception as e:
print_acc(f"Error: {e}")
print_acc(f"Error loading image: {control_path}")
if not self.full_size_control_images:
# we just scale them to 512x512:
w, h = img.size
img = img.resize((512, 512), Image.BICUBIC)
else:
elif not self.use_raw_control_images:
w, h = img.size
if w > h and self.scale_to_width < self.scale_to_height:
# throw error, they should match
raise ValueError(
f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
elif h > w and self.scale_to_height < self.scale_to_width:
# throw error, they should match
raise ValueError(
f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
if self.flip_x:
# do a flip
img = img.transpose(Image.FLIP_LEFT_RIGHT)
@@ -950,11 +944,15 @@ class ControlFileItemDTOMixin:
self.control_tensor = None
elif len(control_tensors) == 1:
self.control_tensor = control_tensors[0]
elif self.use_raw_control_images:
# just send the list of tensors as their shapes wont match
self.control_tensor_list = control_tensors
else:
self.control_tensor = torch.stack(control_tensors, dim=0)
def cleanup_control(self: 'FileItemDTO'):
self.control_tensor = None
self.control_tensor_list = None
class ClipImageFileItemDTOMixin:
@@ -1884,14 +1882,31 @@ class TextEmbeddingCachingMixin:
if file_item.encode_control_in_text_embeddings:
if file_item.control_path is None:
raise Exception(f"Could not find a control image for {file_item.path} which is needed for this model")
# load the control image and feed it into the text encoder
ctrl_img = Image.open(file_item.control_path).convert("RGB")
# convert to 0 to 1 tensor
ctrl_img = (
TF.to_tensor(ctrl_img)
.unsqueeze(0)
.to(self.sd.device_torch, dtype=self.sd.torch_dtype)
)
ctrl_img_list = []
control_path_list = file_item.control_path
if not isinstance(file_item.control_path, list):
control_path_list = [control_path_list]
for i in range(len(control_path_list)):
try:
img = Image.open(control_path_list[i]).convert("RGB")
img = exif_transpose(img)
# convert to 0 to 1 tensor
img = (
TF.to_tensor(img)
.unsqueeze(0)
.to(self.sd.device_torch, dtype=self.sd.torch_dtype)
)
ctrl_img_list.append(img)
except Exception as e:
print_acc(f"Error: {e}")
print_acc(f"Error loading control image: {control_path_list[i]}")
if len(ctrl_img_list) == 0:
ctrl_img = None
elif not self.sd.has_multiple_control_images:
ctrl_img = ctrl_img_list[0]
else:
ctrl_img = ctrl_img_list
prompt_embeds: PromptEmbeds = self.sd.encode_prompt(file_item.caption, control_images=ctrl_img)
else:
prompt_embeds: PromptEmbeds = self.sd.encode_prompt(file_item.caption)