Initial support for qwen image edit plus

This commit is contained in:
Jaret Burkett
2025-09-24 11:39:10 -06:00
parent f74475161e
commit 454be0958a
11 changed files with 445 additions and 32 deletions

View File

@@ -144,6 +144,7 @@ class DataLoaderBatchDTO:
self.tensor: Union[torch.Tensor, None] = None
self.latents: Union[torch.Tensor, None] = None
self.control_tensor: Union[torch.Tensor, None] = None
self.control_tensor_list: Union[List[List[torch.Tensor]], None] = None
self.clip_image_tensor: Union[torch.Tensor, None] = None
self.mask_tensor: Union[torch.Tensor, None] = None
self.unaugmented_tensor: Union[torch.Tensor, None] = None
@@ -160,7 +161,6 @@ class DataLoaderBatchDTO:
self.latents: Union[torch.Tensor, None] = None
if is_latents_cached:
self.latents = torch.cat([x.get_latent().unsqueeze(0) for x in self.file_items])
self.control_tensor: Union[torch.Tensor, None] = None
self.prompt_embeds: Union[PromptEmbeds, None] = None
# if self.file_items[0].control_tensor is not None:
# if any have a control tensor, we concatenate them
@@ -178,6 +178,16 @@ class DataLoaderBatchDTO:
else:
control_tensors.append(x.control_tensor)
self.control_tensor = torch.cat([x.unsqueeze(0) for x in control_tensors])
# handle control tensor list
if any([x.control_tensor_list is not None for x in self.file_items]):
self.control_tensor_list = []
for x in self.file_items:
if x.control_tensor_list is not None:
self.control_tensor_list.append(x.control_tensor_list)
else:
raise Exception(f"Could not find control tensors for all file items, missing for {x.path}")
self.inpaint_tensor: Union[torch.Tensor, None] = None
if any([x.inpaint_tensor is not None for x in self.file_items]):