mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 02:31:17 +00:00
Initial support for qwen image edit plus
This commit is contained in:
@@ -850,6 +850,9 @@ class ControlFileItemDTOMixin:
|
||||
self.has_control_image = False
|
||||
self.control_path: Union[str, List[str], None] = None
|
||||
self.control_tensor: Union[torch.Tensor, None] = None
|
||||
self.control_tensor_list: Union[List[torch.Tensor], None] = None
|
||||
sd = kwargs.get('sd', None)
|
||||
self.use_raw_control_images = sd is not None and sd.use_raw_control_images
|
||||
dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
|
||||
self.full_size_control_images = False
|
||||
if dataset_config.control_path is not None:
|
||||
@@ -900,23 +903,14 @@ class ControlFileItemDTOMixin:
|
||||
except Exception as e:
|
||||
print_acc(f"Error: {e}")
|
||||
print_acc(f"Error loading image: {control_path}")
|
||||
|
||||
|
||||
if not self.full_size_control_images:
|
||||
# we just scale them to 512x512:
|
||||
w, h = img.size
|
||||
img = img.resize((512, 512), Image.BICUBIC)
|
||||
|
||||
else:
|
||||
elif not self.use_raw_control_images:
|
||||
w, h = img.size
|
||||
if w > h and self.scale_to_width < self.scale_to_height:
|
||||
# throw error, they should match
|
||||
raise ValueError(
|
||||
f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
|
||||
elif h > w and self.scale_to_height < self.scale_to_width:
|
||||
# throw error, they should match
|
||||
raise ValueError(
|
||||
f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
|
||||
|
||||
if self.flip_x:
|
||||
# do a flip
|
||||
img = img.transpose(Image.FLIP_LEFT_RIGHT)
|
||||
@@ -950,11 +944,15 @@ class ControlFileItemDTOMixin:
|
||||
self.control_tensor = None
|
||||
elif len(control_tensors) == 1:
|
||||
self.control_tensor = control_tensors[0]
|
||||
elif self.use_raw_control_images:
|
||||
# just send the list of tensors as their shapes wont match
|
||||
self.control_tensor_list = control_tensors
|
||||
else:
|
||||
self.control_tensor = torch.stack(control_tensors, dim=0)
|
||||
|
||||
def cleanup_control(self: 'FileItemDTO'):
|
||||
self.control_tensor = None
|
||||
self.control_tensor_list = None
|
||||
|
||||
|
||||
class ClipImageFileItemDTOMixin:
|
||||
@@ -1884,14 +1882,31 @@ class TextEmbeddingCachingMixin:
|
||||
if file_item.encode_control_in_text_embeddings:
|
||||
if file_item.control_path is None:
|
||||
raise Exception(f"Could not find a control image for {file_item.path} which is needed for this model")
|
||||
# load the control image and feed it into the text encoder
|
||||
ctrl_img = Image.open(file_item.control_path).convert("RGB")
|
||||
# convert to 0 to 1 tensor
|
||||
ctrl_img = (
|
||||
TF.to_tensor(ctrl_img)
|
||||
.unsqueeze(0)
|
||||
.to(self.sd.device_torch, dtype=self.sd.torch_dtype)
|
||||
)
|
||||
ctrl_img_list = []
|
||||
control_path_list = file_item.control_path
|
||||
if not isinstance(file_item.control_path, list):
|
||||
control_path_list = [control_path_list]
|
||||
for i in range(len(control_path_list)):
|
||||
try:
|
||||
img = Image.open(control_path_list[i]).convert("RGB")
|
||||
img = exif_transpose(img)
|
||||
# convert to 0 to 1 tensor
|
||||
img = (
|
||||
TF.to_tensor(img)
|
||||
.unsqueeze(0)
|
||||
.to(self.sd.device_torch, dtype=self.sd.torch_dtype)
|
||||
)
|
||||
ctrl_img_list.append(img)
|
||||
except Exception as e:
|
||||
print_acc(f"Error: {e}")
|
||||
print_acc(f"Error loading control image: {control_path_list[i]}")
|
||||
|
||||
if len(ctrl_img_list) == 0:
|
||||
ctrl_img = None
|
||||
elif not self.sd.has_multiple_control_images:
|
||||
ctrl_img = ctrl_img_list[0]
|
||||
else:
|
||||
ctrl_img = ctrl_img_list
|
||||
prompt_embeds: PromptEmbeds = self.sd.encode_prompt(file_item.caption, control_images=ctrl_img)
|
||||
else:
|
||||
prompt_embeds: PromptEmbeds = self.sd.encode_prompt(file_item.caption)
|
||||
|
||||
Reference in New Issue
Block a user