mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 18:51:37 +00:00
Initial support for qwen image edit plus
This commit is contained in:
@@ -129,17 +129,64 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
prompt=prompt, # it will autoparse the prompt
|
||||
negative_prompt=sample_item.neg,
|
||||
output_path=output_path,
|
||||
ctrl_img=sample_item.ctrl_img
|
||||
ctrl_img=sample_item.ctrl_img,
|
||||
ctrl_img_1=sample_item.ctrl_img_1,
|
||||
ctrl_img_2=sample_item.ctrl_img_2,
|
||||
ctrl_img_3=sample_item.ctrl_img_3,
|
||||
)
|
||||
|
||||
has_control_images = False
|
||||
if gen_img_config.ctrl_img is not None or gen_img_config.ctrl_img_1 is not None or gen_img_config.ctrl_img_2 is not None or gen_img_config.ctrl_img_3 is not None:
|
||||
has_control_images = True
|
||||
# see if we need to encode the control images
|
||||
if self.sd.encode_control_in_text_embeddings and gen_img_config.ctrl_img is not None:
|
||||
ctrl_img = Image.open(gen_img_config.ctrl_img).convert("RGB")
|
||||
# convert to 0 to 1 tensor
|
||||
ctrl_img = (
|
||||
TF.to_tensor(ctrl_img)
|
||||
.unsqueeze(0)
|
||||
.to(self.sd.device_torch, dtype=self.sd.torch_dtype)
|
||||
)
|
||||
if self.sd.encode_control_in_text_embeddings and has_control_images:
|
||||
|
||||
ctrl_img_list = []
|
||||
|
||||
if gen_img_config.ctrl_img is not None:
|
||||
ctrl_img = Image.open(gen_img_config.ctrl_img).convert("RGB")
|
||||
# convert to 0 to 1 tensor
|
||||
ctrl_img = (
|
||||
TF.to_tensor(ctrl_img)
|
||||
.unsqueeze(0)
|
||||
.to(self.sd.device_torch, dtype=self.sd.torch_dtype)
|
||||
)
|
||||
ctrl_img_list.append(ctrl_img)
|
||||
|
||||
if gen_img_config.ctrl_img_1 is not None:
|
||||
ctrl_img_1 = Image.open(gen_img_config.ctrl_img_1).convert("RGB")
|
||||
# convert to 0 to 1 tensor
|
||||
ctrl_img_1 = (
|
||||
TF.to_tensor(ctrl_img_1)
|
||||
.unsqueeze(0)
|
||||
.to(self.sd.device_torch, dtype=self.sd.torch_dtype)
|
||||
)
|
||||
ctrl_img_list.append(ctrl_img_1)
|
||||
if gen_img_config.ctrl_img_2 is not None:
|
||||
ctrl_img_2 = Image.open(gen_img_config.ctrl_img_2).convert("RGB")
|
||||
# convert to 0 to 1 tensor
|
||||
ctrl_img_2 = (
|
||||
TF.to_tensor(ctrl_img_2)
|
||||
.unsqueeze(0)
|
||||
.to(self.sd.device_torch, dtype=self.sd.torch_dtype)
|
||||
)
|
||||
ctrl_img_list.append(ctrl_img_2)
|
||||
if gen_img_config.ctrl_img_3 is not None:
|
||||
ctrl_img_3 = Image.open(gen_img_config.ctrl_img_3).convert("RGB")
|
||||
# convert to 0 to 1 tensor
|
||||
ctrl_img_3 = (
|
||||
TF.to_tensor(ctrl_img_3)
|
||||
.unsqueeze(0)
|
||||
.to(self.sd.device_torch, dtype=self.sd.torch_dtype)
|
||||
)
|
||||
ctrl_img_list.append(ctrl_img_3)
|
||||
|
||||
if self.sd.has_multiple_control_images:
|
||||
ctrl_img = ctrl_img_list
|
||||
else:
|
||||
ctrl_img = ctrl_img_list[0] if len(ctrl_img_list) > 0 else None
|
||||
|
||||
|
||||
positive = self.sd.encode_prompt(
|
||||
gen_img_config.prompt,
|
||||
control_images=ctrl_img
|
||||
@@ -202,6 +249,9 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
if self.sd.encode_control_in_text_embeddings:
|
||||
# just do a blank image for unconditionals
|
||||
control_image = torch.zeros((1, 3, 224, 224), device=self.sd.device_torch, dtype=self.sd.torch_dtype)
|
||||
if self.sd.has_multiple_control_images:
|
||||
control_image = [control_image]
|
||||
|
||||
kwargs['control_images'] = control_image
|
||||
self.unconditional_embeds = self.sd.encode_prompt(
|
||||
[self.train_config.unconditional_prompt],
|
||||
@@ -272,6 +322,8 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
if self.sd.encode_control_in_text_embeddings:
|
||||
# just do a blank image for unconditionals
|
||||
control_image = torch.zeros((1, 3, 224, 224), device=self.sd.device_torch, dtype=self.sd.torch_dtype)
|
||||
if self.sd.has_multiple_control_images:
|
||||
control_image = [control_image]
|
||||
encode_kwargs['control_images'] = control_image
|
||||
self.cached_blank_embeds = self.sd.encode_prompt("", **encode_kwargs)
|
||||
if self.trigger_word is not None:
|
||||
|
||||
Reference in New Issue
Block a user