Initial support for qwen image edit plus

This commit is contained in:
Jaret Burkett
2025-09-24 11:39:10 -06:00
parent f74475161e
commit 454be0958a
11 changed files with 445 additions and 32 deletions

View File

@@ -129,17 +129,64 @@ class SDTrainer(BaseSDTrainProcess):
prompt=prompt, # it will autoparse the prompt
negative_prompt=sample_item.neg,
output_path=output_path,
ctrl_img=sample_item.ctrl_img
ctrl_img=sample_item.ctrl_img,
ctrl_img_1=sample_item.ctrl_img_1,
ctrl_img_2=sample_item.ctrl_img_2,
ctrl_img_3=sample_item.ctrl_img_3,
)
has_control_images = False
if gen_img_config.ctrl_img is not None or gen_img_config.ctrl_img_1 is not None or gen_img_config.ctrl_img_2 is not None or gen_img_config.ctrl_img_3 is not None:
has_control_images = True
# see if we need to encode the control images
if self.sd.encode_control_in_text_embeddings and gen_img_config.ctrl_img is not None:
ctrl_img = Image.open(gen_img_config.ctrl_img).convert("RGB")
# convert to 0 to 1 tensor
ctrl_img = (
TF.to_tensor(ctrl_img)
.unsqueeze(0)
.to(self.sd.device_torch, dtype=self.sd.torch_dtype)
)
if self.sd.encode_control_in_text_embeddings and has_control_images:
ctrl_img_list = []
if gen_img_config.ctrl_img is not None:
ctrl_img = Image.open(gen_img_config.ctrl_img).convert("RGB")
# convert to 0 to 1 tensor
ctrl_img = (
TF.to_tensor(ctrl_img)
.unsqueeze(0)
.to(self.sd.device_torch, dtype=self.sd.torch_dtype)
)
ctrl_img_list.append(ctrl_img)
if gen_img_config.ctrl_img_1 is not None:
ctrl_img_1 = Image.open(gen_img_config.ctrl_img_1).convert("RGB")
# convert to 0 to 1 tensor
ctrl_img_1 = (
TF.to_tensor(ctrl_img_1)
.unsqueeze(0)
.to(self.sd.device_torch, dtype=self.sd.torch_dtype)
)
ctrl_img_list.append(ctrl_img_1)
if gen_img_config.ctrl_img_2 is not None:
ctrl_img_2 = Image.open(gen_img_config.ctrl_img_2).convert("RGB")
# convert to 0 to 1 tensor
ctrl_img_2 = (
TF.to_tensor(ctrl_img_2)
.unsqueeze(0)
.to(self.sd.device_torch, dtype=self.sd.torch_dtype)
)
ctrl_img_list.append(ctrl_img_2)
if gen_img_config.ctrl_img_3 is not None:
ctrl_img_3 = Image.open(gen_img_config.ctrl_img_3).convert("RGB")
# convert to 0 to 1 tensor
ctrl_img_3 = (
TF.to_tensor(ctrl_img_3)
.unsqueeze(0)
.to(self.sd.device_torch, dtype=self.sd.torch_dtype)
)
ctrl_img_list.append(ctrl_img_3)
if self.sd.has_multiple_control_images:
ctrl_img = ctrl_img_list
else:
ctrl_img = ctrl_img_list[0] if len(ctrl_img_list) > 0 else None
positive = self.sd.encode_prompt(
gen_img_config.prompt,
control_images=ctrl_img
@@ -202,6 +249,9 @@ class SDTrainer(BaseSDTrainProcess):
if self.sd.encode_control_in_text_embeddings:
# just do a blank image for unconditionals
control_image = torch.zeros((1, 3, 224, 224), device=self.sd.device_torch, dtype=self.sd.torch_dtype)
if self.sd.has_multiple_control_images:
control_image = [control_image]
kwargs['control_images'] = control_image
self.unconditional_embeds = self.sd.encode_prompt(
[self.train_config.unconditional_prompt],
@@ -272,6 +322,8 @@ class SDTrainer(BaseSDTrainProcess):
if self.sd.encode_control_in_text_embeddings:
# just do a blank image for unconditionals
control_image = torch.zeros((1, 3, 224, 224), device=self.sd.device_torch, dtype=self.sd.torch_dtype)
if self.sd.has_multiple_control_images:
control_image = [control_image]
encode_kwargs['control_images'] = control_image
self.cached_blank_embeds = self.sd.encode_prompt("", **encode_kwargs)
if self.trigger_word is not None: