mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-24 08:19:24 +00:00
Initial support for qwen image edit plus
This commit is contained in:
@@ -56,6 +56,11 @@ class SampleItem:
|
||||
self.num_frames: int = kwargs.get('num_frames', sample_config.num_frames)
|
||||
self.ctrl_img: Optional[str] = kwargs.get('ctrl_img', None)
|
||||
self.ctrl_idx: int = kwargs.get('ctrl_idx', 0)
|
||||
# for multi control image models
|
||||
self.ctrl_img_1: Optional[str] = kwargs.get('ctrl_img_1', self.ctrl_img)
|
||||
self.ctrl_img_2: Optional[str] = kwargs.get('ctrl_img_2', None)
|
||||
self.ctrl_img_3: Optional[str] = kwargs.get('ctrl_img_3', None)
|
||||
|
||||
self.network_multiplier: float = kwargs.get('network_multiplier', sample_config.network_multiplier)
|
||||
# convert to a number if it is a string
|
||||
if isinstance(self.network_multiplier, str):
|
||||
@@ -966,6 +971,9 @@ class GenerateImageConfig:
|
||||
extra_values: List[float] = None, # extra values to save with prompt file
|
||||
logger: Optional[EmptyLogger] = None,
|
||||
ctrl_img: Optional[str] = None, # control image for controlnet
|
||||
ctrl_img_1: Optional[str] = None, # first control image for multi control model
|
||||
ctrl_img_2: Optional[str] = None, # second control image for multi control model
|
||||
ctrl_img_3: Optional[str] = None, # third control image for multi control model
|
||||
num_frames: int = 1,
|
||||
fps: int = 15,
|
||||
ctrl_idx: int = 0
|
||||
@@ -1002,6 +1010,12 @@ class GenerateImageConfig:
|
||||
self.ctrl_img = ctrl_img
|
||||
self.ctrl_idx = ctrl_idx
|
||||
|
||||
if ctrl_img_1 is None and ctrl_img is not None:
|
||||
ctrl_img_1 = ctrl_img
|
||||
|
||||
self.ctrl_img_1 = ctrl_img_1
|
||||
self.ctrl_img_2 = ctrl_img_2
|
||||
self.ctrl_img_3 = ctrl_img_3
|
||||
|
||||
# prompt string will override any settings above
|
||||
self._process_prompt_string()
|
||||
|
||||
@@ -144,6 +144,7 @@ class DataLoaderBatchDTO:
|
||||
self.tensor: Union[torch.Tensor, None] = None
|
||||
self.latents: Union[torch.Tensor, None] = None
|
||||
self.control_tensor: Union[torch.Tensor, None] = None
|
||||
self.control_tensor_list: Union[List[List[torch.Tensor]], None] = None
|
||||
self.clip_image_tensor: Union[torch.Tensor, None] = None
|
||||
self.mask_tensor: Union[torch.Tensor, None] = None
|
||||
self.unaugmented_tensor: Union[torch.Tensor, None] = None
|
||||
@@ -160,7 +161,6 @@ class DataLoaderBatchDTO:
|
||||
self.latents: Union[torch.Tensor, None] = None
|
||||
if is_latents_cached:
|
||||
self.latents = torch.cat([x.get_latent().unsqueeze(0) for x in self.file_items])
|
||||
self.control_tensor: Union[torch.Tensor, None] = None
|
||||
self.prompt_embeds: Union[PromptEmbeds, None] = None
|
||||
# if self.file_items[0].control_tensor is not None:
|
||||
# if any have a control tensor, we concatenate them
|
||||
@@ -178,6 +178,16 @@ class DataLoaderBatchDTO:
|
||||
else:
|
||||
control_tensors.append(x.control_tensor)
|
||||
self.control_tensor = torch.cat([x.unsqueeze(0) for x in control_tensors])
|
||||
|
||||
# handle control tensor list
|
||||
if any([x.control_tensor_list is not None for x in self.file_items]):
|
||||
self.control_tensor_list = []
|
||||
for x in self.file_items:
|
||||
if x.control_tensor_list is not None:
|
||||
self.control_tensor_list.append(x.control_tensor_list)
|
||||
else:
|
||||
raise Exception(f"Could not find control tensors for all file items, missing for {x.path}")
|
||||
|
||||
|
||||
self.inpaint_tensor: Union[torch.Tensor, None] = None
|
||||
if any([x.inpaint_tensor is not None for x in self.file_items]):
|
||||
|
||||
@@ -850,6 +850,9 @@ class ControlFileItemDTOMixin:
|
||||
self.has_control_image = False
|
||||
self.control_path: Union[str, List[str], None] = None
|
||||
self.control_tensor: Union[torch.Tensor, None] = None
|
||||
self.control_tensor_list: Union[List[torch.Tensor], None] = None
|
||||
sd = kwargs.get('sd', None)
|
||||
self.use_raw_control_images = sd is not None and sd.use_raw_control_images
|
||||
dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
|
||||
self.full_size_control_images = False
|
||||
if dataset_config.control_path is not None:
|
||||
@@ -900,23 +903,14 @@ class ControlFileItemDTOMixin:
|
||||
except Exception as e:
|
||||
print_acc(f"Error: {e}")
|
||||
print_acc(f"Error loading image: {control_path}")
|
||||
|
||||
|
||||
if not self.full_size_control_images:
|
||||
# we just scale them to 512x512:
|
||||
w, h = img.size
|
||||
img = img.resize((512, 512), Image.BICUBIC)
|
||||
|
||||
else:
|
||||
elif not self.use_raw_control_images:
|
||||
w, h = img.size
|
||||
if w > h and self.scale_to_width < self.scale_to_height:
|
||||
# throw error, they should match
|
||||
raise ValueError(
|
||||
f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
|
||||
elif h > w and self.scale_to_height < self.scale_to_width:
|
||||
# throw error, they should match
|
||||
raise ValueError(
|
||||
f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
|
||||
|
||||
if self.flip_x:
|
||||
# do a flip
|
||||
img = img.transpose(Image.FLIP_LEFT_RIGHT)
|
||||
@@ -950,11 +944,15 @@ class ControlFileItemDTOMixin:
|
||||
self.control_tensor = None
|
||||
elif len(control_tensors) == 1:
|
||||
self.control_tensor = control_tensors[0]
|
||||
elif self.use_raw_control_images:
|
||||
# just send the list of tensors as their shapes wont match
|
||||
self.control_tensor_list = control_tensors
|
||||
else:
|
||||
self.control_tensor = torch.stack(control_tensors, dim=0)
|
||||
|
||||
def cleanup_control(self: 'FileItemDTO'):
|
||||
self.control_tensor = None
|
||||
self.control_tensor_list = None
|
||||
|
||||
|
||||
class ClipImageFileItemDTOMixin:
|
||||
@@ -1884,14 +1882,31 @@ class TextEmbeddingCachingMixin:
|
||||
if file_item.encode_control_in_text_embeddings:
|
||||
if file_item.control_path is None:
|
||||
raise Exception(f"Could not find a control image for {file_item.path} which is needed for this model")
|
||||
# load the control image and feed it into the text encoder
|
||||
ctrl_img = Image.open(file_item.control_path).convert("RGB")
|
||||
# convert to 0 to 1 tensor
|
||||
ctrl_img = (
|
||||
TF.to_tensor(ctrl_img)
|
||||
.unsqueeze(0)
|
||||
.to(self.sd.device_torch, dtype=self.sd.torch_dtype)
|
||||
)
|
||||
ctrl_img_list = []
|
||||
control_path_list = file_item.control_path
|
||||
if not isinstance(file_item.control_path, list):
|
||||
control_path_list = [control_path_list]
|
||||
for i in range(len(control_path_list)):
|
||||
try:
|
||||
img = Image.open(control_path_list[i]).convert("RGB")
|
||||
img = exif_transpose(img)
|
||||
# convert to 0 to 1 tensor
|
||||
img = (
|
||||
TF.to_tensor(img)
|
||||
.unsqueeze(0)
|
||||
.to(self.sd.device_torch, dtype=self.sd.torch_dtype)
|
||||
)
|
||||
ctrl_img_list.append(img)
|
||||
except Exception as e:
|
||||
print_acc(f"Error: {e}")
|
||||
print_acc(f"Error loading control image: {control_path_list[i]}")
|
||||
|
||||
if len(ctrl_img_list) == 0:
|
||||
ctrl_img = None
|
||||
elif not self.sd.has_multiple_control_images:
|
||||
ctrl_img = ctrl_img_list[0]
|
||||
else:
|
||||
ctrl_img = ctrl_img_list
|
||||
prompt_embeds: PromptEmbeds = self.sd.encode_prompt(file_item.caption, control_images=ctrl_img)
|
||||
else:
|
||||
prompt_embeds: PromptEmbeds = self.sd.encode_prompt(file_item.caption)
|
||||
|
||||
@@ -181,6 +181,10 @@ class BaseModel:
|
||||
|
||||
# set true for models that encode control image into text embeddings
|
||||
self.encode_control_in_text_embeddings = False
|
||||
# control images will come in as a list for encoding some things if true
|
||||
self.has_multiple_control_images = False
|
||||
# do not resize control images
|
||||
self.use_raw_control_images = False
|
||||
|
||||
# properties for old arch for backwards compatibility
|
||||
@property
|
||||
|
||||
@@ -219,6 +219,10 @@ class StableDiffusion:
|
||||
|
||||
# set true for models that encode control image into text embeddings
|
||||
self.encode_control_in_text_embeddings = False
|
||||
# control images will come in as a list for encoding some things if true
|
||||
self.has_multiple_control_images = False
|
||||
# do not resize control images
|
||||
self.use_raw_control_images = False
|
||||
|
||||
# properties for old arch for backwards compatibility
|
||||
@property
|
||||
|
||||
Reference in New Issue
Block a user