Initial support for qwen image edit plus

This commit is contained in:
Jaret Burkett
2025-09-24 11:39:10 -06:00
parent f74475161e
commit 454be0958a
11 changed files with 445 additions and 32 deletions

View File

@@ -56,6 +56,11 @@ class SampleItem:
self.num_frames: int = kwargs.get('num_frames', sample_config.num_frames)
self.ctrl_img: Optional[str] = kwargs.get('ctrl_img', None)
self.ctrl_idx: int = kwargs.get('ctrl_idx', 0)
# for multi control image models
self.ctrl_img_1: Optional[str] = kwargs.get('ctrl_img_1', self.ctrl_img)
self.ctrl_img_2: Optional[str] = kwargs.get('ctrl_img_2', None)
self.ctrl_img_3: Optional[str] = kwargs.get('ctrl_img_3', None)
self.network_multiplier: float = kwargs.get('network_multiplier', sample_config.network_multiplier)
# convert to a number if it is a string
if isinstance(self.network_multiplier, str):
@@ -966,6 +971,9 @@ class GenerateImageConfig:
extra_values: List[float] = None, # extra values to save with prompt file
logger: Optional[EmptyLogger] = None,
ctrl_img: Optional[str] = None, # control image for controlnet
ctrl_img_1: Optional[str] = None, # first control image for multi control model
ctrl_img_2: Optional[str] = None, # second control image for multi control model
ctrl_img_3: Optional[str] = None, # third control image for multi control model
num_frames: int = 1,
fps: int = 15,
ctrl_idx: int = 0
@@ -1002,6 +1010,12 @@ class GenerateImageConfig:
self.ctrl_img = ctrl_img
self.ctrl_idx = ctrl_idx
if ctrl_img_1 is None and ctrl_img is not None:
ctrl_img_1 = ctrl_img
self.ctrl_img_1 = ctrl_img_1
self.ctrl_img_2 = ctrl_img_2
self.ctrl_img_3 = ctrl_img_3
# prompt string will override any settings above
self._process_prompt_string()

View File

@@ -144,6 +144,7 @@ class DataLoaderBatchDTO:
self.tensor: Union[torch.Tensor, None] = None
self.latents: Union[torch.Tensor, None] = None
self.control_tensor: Union[torch.Tensor, None] = None
self.control_tensor_list: Union[List[List[torch.Tensor]], None] = None
self.clip_image_tensor: Union[torch.Tensor, None] = None
self.mask_tensor: Union[torch.Tensor, None] = None
self.unaugmented_tensor: Union[torch.Tensor, None] = None
@@ -160,7 +161,6 @@ class DataLoaderBatchDTO:
self.latents: Union[torch.Tensor, None] = None
if is_latents_cached:
self.latents = torch.cat([x.get_latent().unsqueeze(0) for x in self.file_items])
self.control_tensor: Union[torch.Tensor, None] = None
self.prompt_embeds: Union[PromptEmbeds, None] = None
# if self.file_items[0].control_tensor is not None:
# if any have a control tensor, we concatenate them
@@ -178,6 +178,16 @@ class DataLoaderBatchDTO:
else:
control_tensors.append(x.control_tensor)
self.control_tensor = torch.cat([x.unsqueeze(0) for x in control_tensors])
# handle control tensor list
if any([x.control_tensor_list is not None for x in self.file_items]):
self.control_tensor_list = []
for x in self.file_items:
if x.control_tensor_list is not None:
self.control_tensor_list.append(x.control_tensor_list)
else:
raise Exception(f"Could not find control tensors for all file items, missing for {x.path}")
self.inpaint_tensor: Union[torch.Tensor, None] = None
if any([x.inpaint_tensor is not None for x in self.file_items]):

View File

@@ -850,6 +850,9 @@ class ControlFileItemDTOMixin:
self.has_control_image = False
self.control_path: Union[str, List[str], None] = None
self.control_tensor: Union[torch.Tensor, None] = None
self.control_tensor_list: Union[List[torch.Tensor], None] = None
sd = kwargs.get('sd', None)
self.use_raw_control_images = sd is not None and sd.use_raw_control_images
dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
self.full_size_control_images = False
if dataset_config.control_path is not None:
@@ -900,23 +903,14 @@ class ControlFileItemDTOMixin:
except Exception as e:
print_acc(f"Error: {e}")
print_acc(f"Error loading image: {control_path}")
if not self.full_size_control_images:
# we just scale them to 512x512:
w, h = img.size
img = img.resize((512, 512), Image.BICUBIC)
else:
elif not self.use_raw_control_images:
w, h = img.size
if w > h and self.scale_to_width < self.scale_to_height:
# throw error, they should match
raise ValueError(
f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
elif h > w and self.scale_to_height < self.scale_to_width:
# throw error, they should match
raise ValueError(
f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
if self.flip_x:
# do a flip
img = img.transpose(Image.FLIP_LEFT_RIGHT)
@@ -950,11 +944,15 @@ class ControlFileItemDTOMixin:
self.control_tensor = None
elif len(control_tensors) == 1:
self.control_tensor = control_tensors[0]
elif self.use_raw_control_images:
# just send the list of tensors as their shapes wont match
self.control_tensor_list = control_tensors
else:
self.control_tensor = torch.stack(control_tensors, dim=0)
def cleanup_control(self: 'FileItemDTO'):
self.control_tensor = None
self.control_tensor_list = None
class ClipImageFileItemDTOMixin:
@@ -1884,14 +1882,31 @@ class TextEmbeddingCachingMixin:
if file_item.encode_control_in_text_embeddings:
if file_item.control_path is None:
raise Exception(f"Could not find a control image for {file_item.path} which is needed for this model")
# load the control image and feed it into the text encoder
ctrl_img = Image.open(file_item.control_path).convert("RGB")
# convert to 0 to 1 tensor
ctrl_img = (
TF.to_tensor(ctrl_img)
.unsqueeze(0)
.to(self.sd.device_torch, dtype=self.sd.torch_dtype)
)
ctrl_img_list = []
control_path_list = file_item.control_path
if not isinstance(file_item.control_path, list):
control_path_list = [control_path_list]
for i in range(len(control_path_list)):
try:
img = Image.open(control_path_list[i]).convert("RGB")
img = exif_transpose(img)
# convert to 0 to 1 tensor
img = (
TF.to_tensor(img)
.unsqueeze(0)
.to(self.sd.device_torch, dtype=self.sd.torch_dtype)
)
ctrl_img_list.append(img)
except Exception as e:
print_acc(f"Error: {e}")
print_acc(f"Error loading control image: {control_path_list[i]}")
if len(ctrl_img_list) == 0:
ctrl_img = None
elif not self.sd.has_multiple_control_images:
ctrl_img = ctrl_img_list[0]
else:
ctrl_img = ctrl_img_list
prompt_embeds: PromptEmbeds = self.sd.encode_prompt(file_item.caption, control_images=ctrl_img)
else:
prompt_embeds: PromptEmbeds = self.sd.encode_prompt(file_item.caption)

View File

@@ -181,6 +181,10 @@ class BaseModel:
# set true for models that encode control image into text embeddings
self.encode_control_in_text_embeddings = False
# control images will come in as a list for encoding some things if true
self.has_multiple_control_images = False
# do not resize control images
self.use_raw_control_images = False
# properties for old arch for backwards compatibility
@property

View File

@@ -219,6 +219,10 @@ class StableDiffusion:
# set true for models that encode control image into text embeddings
self.encode_control_in_text_embeddings = False
# control images will come in as a list for encoding some things if true
self.has_multiple_control_images = False
# do not resize control images
self.use_raw_control_images = False
# properties for old arch for backwards compatibility
@property