mirror of https://github.com/ostris/ai-toolkit.git

Initial support for qwen image edit plus
@@ -4,7 +4,7 @@ from .f_light import FLiteModel
from .omnigen2 import OmniGen2Model
from .flux_kontext import FluxKontextModel
from .wan22 import Wan225bModel, Wan2214bModel, Wan2214bI2VModel
from .qwen_image import QwenImageModel, QwenImageEditModel
from .qwen_image import QwenImageModel, QwenImageEditModel, QwenImageEditPlusModel

AI_TOOLKIT_MODELS = [
    # put a list of models here

@@ -20,4 +20,5 @@ AI_TOOLKIT_MODELS = [
    Wan2214bModel,
    QwenImageModel,
    QwenImageEditModel,
    QwenImageEditPlusModel,
]

@@ -1,2 +1,3 @@
from .qwen_image import QwenImageModel
from .qwen_image_edit import QwenImageEditModel
from .qwen_image_edit import QwenImageEditModel
from .qwen_image_edit_plus import QwenImageEditPlusModel

@@ -0,0 +1,309 @@
import math
import torch
from .qwen_image import QwenImageModel
import os
from typing import TYPE_CHECKING, List, Optional
import yaml
from toolkit import train_tools
from toolkit.config_modules import GenerateImageConfig, ModelConfig
from PIL import Image
from toolkit.models.base_model import BaseModel
from toolkit.basic import flush
from toolkit.prompt_utils import PromptEmbeds
from toolkit.samplers.custom_flowmatch_sampler import (
    CustomFlowMatchEulerDiscreteScheduler,
)
from toolkit.accelerator import get_accelerator, unwrap_model
from optimum.quanto import freeze, QTensor
from toolkit.util.quantize import quantize, get_qtype, quantize_model
import torch.nn.functional as F

from diffusers import (
    QwenImageTransformer2DModel,
    AutoencoderKLQwenImage,
)
from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer
from tqdm import tqdm


if TYPE_CHECKING:
    from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO

try:
    from diffusers import QwenImageEditPlusPipeline
    from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit_plus import CONDITION_IMAGE_SIZE, VAE_IMAGE_SIZE
except ImportError:
    raise ImportError(
        "Diffusers is out of date. Update diffusers to the latest version by doing 'pip uninstall diffusers' and then 'pip install -r requirements.txt'"
    )


class QwenImageEditPlusModel(QwenImageModel):
    arch = "qwen_image_edit_plus"
    _qwen_image_keep_visual = True
    _qwen_pipeline = QwenImageEditPlusPipeline

    def __init__(
        self,
        device,
        model_config: ModelConfig,
        dtype="bf16",
        custom_pipeline=None,
        noise_scheduler=None,
        **kwargs,
    ):
        super().__init__(
            device, model_config, dtype, custom_pipeline, noise_scheduler, **kwargs
        )
        self.is_flow_matching = True
        self.is_transformer = True
        self.target_lora_modules = ["QwenImageTransformer2DModel"]

        # set true for models that encode control image into text embeddings
        self.encode_control_in_text_embeddings = True
        # control images will come in as a list for encoding some things if true
        self.has_multiple_control_images = True
        # do not resize control images
        self.use_raw_control_images = True
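        # note: these flags are read by the trainer and data loader changes elsewhere in
        # this commit; with has_multiple_control_images / use_raw_control_images set,
        # control images stay as per-item lists of differently sized tensors
        # (batch.control_tensor_list) instead of being resized and stacked into a single
        # control_tensor.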

    def load_model(self):
        super().load_model()

    def get_generation_pipeline(self):
        scheduler = QwenImageModel.get_train_scheduler()

        pipeline: QwenImageEditPlusPipeline = QwenImageEditPlusPipeline(
            scheduler=scheduler,
            text_encoder=unwrap_model(self.text_encoder[0]),
            tokenizer=self.tokenizer[0],
            processor=self.processor,
            vae=unwrap_model(self.vae),
            transformer=unwrap_model(self.transformer),
        )

        pipeline = pipeline.to(self.device_torch)

        return pipeline

    def generate_single_image(
        self,
        pipeline: QwenImageEditPlusPipeline,
        gen_config: GenerateImageConfig,
        conditional_embeds: PromptEmbeds,
        unconditional_embeds: PromptEmbeds,
        generator: torch.Generator,
        extra: dict,
    ):
        self.model.to(self.device_torch, dtype=self.torch_dtype)
        sc = self.get_bucket_divisibility()
        gen_config.width = int(gen_config.width // sc * sc)
        gen_config.height = int(gen_config.height // sc * sc)

        control_img_list = []
        if gen_config.ctrl_img is not None:
            control_img = Image.open(gen_config.ctrl_img)
            control_img = control_img.convert("RGB")
            control_img_list.append(control_img)
        elif gen_config.ctrl_img_1 is not None:
            control_img = Image.open(gen_config.ctrl_img_1)
            control_img = control_img.convert("RGB")
            control_img_list.append(control_img)

        if gen_config.ctrl_img_2 is not None:
            control_img = Image.open(gen_config.ctrl_img_2)
            control_img = control_img.convert("RGB")
            control_img_list.append(control_img)
        if gen_config.ctrl_img_3 is not None:
            control_img = Image.open(gen_config.ctrl_img_3)
            control_img = control_img.convert("RGB")
            control_img_list.append(control_img)

        # flush for low vram if we are doing that
        # flush_between_steps = self.model_config.low_vram
        flush_between_steps = False

        # Fix a bug in diffusers/torch
        def callback_on_step_end(pipe, i, t, callback_kwargs):
            if flush_between_steps:
                flush()
            latents = callback_kwargs["latents"]

            return {"latents": latents}

        img = pipeline(
            image=control_img_list,
            prompt_embeds=conditional_embeds.text_embeds,
            prompt_embeds_mask=conditional_embeds.attention_mask.to(
                self.device_torch, dtype=torch.int64
            ),
            negative_prompt_embeds=unconditional_embeds.text_embeds,
            negative_prompt_embeds_mask=unconditional_embeds.attention_mask.to(
                self.device_torch, dtype=torch.int64
            ),
            height=gen_config.height,
            width=gen_config.width,
            num_inference_steps=gen_config.num_inference_steps,
            true_cfg_scale=gen_config.guidance_scale,
            latents=gen_config.latents,
            generator=generator,
            callback_on_step_end=callback_on_step_end,
            **extra,
        ).images[0]
        return img

    def condition_noisy_latents(
        self, latents: torch.Tensor, batch: "DataLoaderBatchDTO"
    ):
        # we get the control image from the batch
        return latents.detach()

    def get_prompt_embeds(self, prompt: str, control_images=None) -> PromptEmbeds:
        # todo handle not caching text encoder
        if self.pipeline.text_encoder.device != self.device_torch:
            self.pipeline.text_encoder.to(self.device_torch)

        if control_images is not None and len(control_images) > 0:
            for i in range(len(control_images)):
                # control images are 0 - 1 scale, shape (bs, ch, height, width)
                ratio = control_images[i].shape[2] / control_images[i].shape[3]
                width = math.sqrt(CONDITION_IMAGE_SIZE * ratio)
                height = width / ratio

                width = round(width / 32) * 32
                height = round(height / 32) * 32
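                # rough worked example: for a square control image (ratio 1.0) and a
                # CONDITION_IMAGE_SIZE on the order of 384 * 384 (the value exported by
                # recent diffusers releases), this gives width = height = sqrt(147456)
                # = 384, already a multiple of 32, so the condition image fed to the
                # text encoder ends up at 384x384.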

                control_images[i] = F.interpolate(
                    control_images[i], size=(height, width), mode="bilinear"
                )

        prompt_embeds, prompt_embeds_mask = self.pipeline.encode_prompt(
            prompt,
            image=control_images,
            device=self.device_torch,
            num_images_per_prompt=1,
        )
        pe = PromptEmbeds(prompt_embeds)
        pe.attention_mask = prompt_embeds_mask
        return pe

    def get_noise_prediction(
        self,
        latent_model_input: torch.Tensor,
        timestep: torch.Tensor, # 0 to 1000 scale
        text_embeddings: PromptEmbeds,
        batch: "DataLoaderBatchDTO" = None,
        **kwargs,
    ):
        with torch.no_grad():
            batch_size, num_channels_latents, height, width = latent_model_input.shape

            # pack image tokens
            latent_model_input = latent_model_input.view(
                batch_size, num_channels_latents, height // 2, 2, width // 2, 2
            )
            latent_model_input = latent_model_input.permute(0, 2, 4, 1, 3, 5)
            latent_model_input = latent_model_input.reshape(
                batch_size, (height // 2) * (width // 2), num_channels_latents * 4
            )
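            # shape example: a (1, 16, 64, 64) latent becomes (1, 16, 32, 2, 32, 2) after
            # the view, (1, 32, 32, 16, 2, 2) after the permute, and finally (1, 1024, 64)
            # -- one token per 2x2 latent patch with num_channels_latents * 4 features.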

            raw_packed_latents = latent_model_input

            img_h2, img_w2 = height // 2, width // 2

            img_shapes = [
                [(1, img_h2, img_w2)]
            ] * batch_size
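            # img_shapes tracks one (frames, h//2, w//2) entry per image for each batch
            # item; the packed control latents appended below add their own entries so the
            # transformer can account for the extra tokens.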

            # pack controls
            if batch is None:
                raise ValueError("Batch is required for QwenImageEditPlusModel")

            # split the latents into batch items so we can concat the controls
            packed_latents_list = torch.chunk(latent_model_input, batch_size, dim=0)
            packed_latents_with_controls_list = []

            if batch.control_tensor_list is not None:
                if len(batch.control_tensor_list) != batch_size:
                    raise ValueError("Control tensor list length does not match batch size")
                b = 0
                for control_tensor_list in batch.control_tensor_list:
                    # control tensor list is a list of tensors for this batch item
                    controls = []
                    # pack control
                    for control_img in control_tensor_list:
                        # control images are 0 - 1 scale, shape (1, ch, height, width)
                        control_img = control_img.to(self.device_torch, dtype=self.torch_dtype)
                        # if it is only 3 dim, add batch dim
                        if len(control_img.shape) == 3:
                            control_img = control_img.unsqueeze(0)
                        ratio = control_img.shape[2] / control_img.shape[3]
                        c_width = math.sqrt(VAE_IMAGE_SIZE * ratio)
                        c_height = c_width / ratio

                        c_width = round(c_width / 32) * 32
                        c_height = round(c_height / 32) * 32
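                        # same area-targeted resize as in get_prompt_embeds, but aimed at
                        # VAE_IMAGE_SIZE (roughly a 1024x1024 pixel budget in recent
                        # diffusers) since these control images go through the VAE rather
                        # than the vision-language text encoder.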

                        control_img = F.interpolate(
                            control_img, size=(c_height, c_width), mode="bilinear"
                        )

                        control_latent = self.encode_images(
                            control_img,
                            device=self.device_torch,
                            dtype=self.torch_dtype,
                        )

                        clb, cl_num_channels_latents, cl_height, cl_width = control_latent.shape

                        control = control_latent.view(
                            1, cl_num_channels_latents, cl_height // 2, 2, cl_width // 2, 2
                        )
                        control = control.permute(0, 2, 4, 1, 3, 5)
                        control = control.reshape(
                            1, (cl_height // 2) * (cl_width // 2), num_channels_latents * 4
                        )

                        img_shapes[b].append((1, cl_height // 2, cl_width // 2))
                        controls.append(control)

                    # stack controls on dim 1
                    control = torch.cat(controls, dim=1).to(packed_latents_list[b].device, dtype=packed_latents_list[b].dtype)
                    # concat with latents
                    packed_latents_with_control = torch.cat([packed_latents_list[b], control], dim=1)

                    packed_latents_with_controls_list.append(packed_latents_with_control)

                    b += 1

                latent_model_input = torch.cat(packed_latents_with_controls_list, dim=0)

            prompt_embeds_mask = text_embeddings.attention_mask.to(
                self.device_torch, dtype=torch.int64
            )
            txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist()
            enc_hs = text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype)
            prompt_embeds_mask = text_embeddings.attention_mask.to(
                self.device_torch, dtype=torch.int64
            )

        noise_pred = self.transformer(
            hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype),
            timestep=timestep / 1000,
            guidance=None,
            encoder_hidden_states=enc_hs,
            encoder_hidden_states_mask=prompt_embeds_mask,
            img_shapes=img_shapes,
            txt_seq_lens=txt_seq_lens,
            return_dict=False,
            **kwargs,
        )[0]

        noise_pred = noise_pred[:, : raw_packed_latents.size(1)]
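        # the transformer also returns predictions for the appended control tokens; keep
        # only the leading tokens that correspond to the noisy image latents before unpacking.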

        # unpack
        noise_pred = noise_pred.view(
            batch_size, height // 2, width // 2, num_channels_latents, 2, 2
        )
        noise_pred = noise_pred.permute(0, 3, 1, 4, 2, 5)
        noise_pred = noise_pred.reshape(batch_size, num_channels_latents, height, width)
        return noise_pred

@@ -129,17 +129,64 @@ class SDTrainer(BaseSDTrainProcess):
                prompt=prompt, # it will autoparse the prompt
                negative_prompt=sample_item.neg,
                output_path=output_path,
                ctrl_img=sample_item.ctrl_img
                ctrl_img=sample_item.ctrl_img,
                ctrl_img_1=sample_item.ctrl_img_1,
                ctrl_img_2=sample_item.ctrl_img_2,
                ctrl_img_3=sample_item.ctrl_img_3,
            )

            has_control_images = False
            if gen_img_config.ctrl_img is not None or gen_img_config.ctrl_img_1 is not None or gen_img_config.ctrl_img_2 is not None or gen_img_config.ctrl_img_3 is not None:
                has_control_images = True
            # see if we need to encode the control images
            if self.sd.encode_control_in_text_embeddings and gen_img_config.ctrl_img is not None:
                ctrl_img = Image.open(gen_img_config.ctrl_img).convert("RGB")
                # convert to 0 to 1 tensor
                ctrl_img = (
                    TF.to_tensor(ctrl_img)
                    .unsqueeze(0)
                    .to(self.sd.device_torch, dtype=self.sd.torch_dtype)
                )
            if self.sd.encode_control_in_text_embeddings and has_control_images:

                ctrl_img_list = []

                if gen_img_config.ctrl_img is not None:
                    ctrl_img = Image.open(gen_img_config.ctrl_img).convert("RGB")
                    # convert to 0 to 1 tensor
                    ctrl_img = (
                        TF.to_tensor(ctrl_img)
                        .unsqueeze(0)
                        .to(self.sd.device_torch, dtype=self.sd.torch_dtype)
                    )
                    ctrl_img_list.append(ctrl_img)

                if gen_img_config.ctrl_img_1 is not None:
                    ctrl_img_1 = Image.open(gen_img_config.ctrl_img_1).convert("RGB")
                    # convert to 0 to 1 tensor
                    ctrl_img_1 = (
                        TF.to_tensor(ctrl_img_1)
                        .unsqueeze(0)
                        .to(self.sd.device_torch, dtype=self.sd.torch_dtype)
                    )
                    ctrl_img_list.append(ctrl_img_1)
                if gen_img_config.ctrl_img_2 is not None:
                    ctrl_img_2 = Image.open(gen_img_config.ctrl_img_2).convert("RGB")
                    # convert to 0 to 1 tensor
                    ctrl_img_2 = (
                        TF.to_tensor(ctrl_img_2)
                        .unsqueeze(0)
                        .to(self.sd.device_torch, dtype=self.sd.torch_dtype)
                    )
                    ctrl_img_list.append(ctrl_img_2)
                if gen_img_config.ctrl_img_3 is not None:
                    ctrl_img_3 = Image.open(gen_img_config.ctrl_img_3).convert("RGB")
                    # convert to 0 to 1 tensor
                    ctrl_img_3 = (
                        TF.to_tensor(ctrl_img_3)
                        .unsqueeze(0)
                        .to(self.sd.device_torch, dtype=self.sd.torch_dtype)
                    )
                    ctrl_img_list.append(ctrl_img_3)

                if self.sd.has_multiple_control_images:
                    ctrl_img = ctrl_img_list
                else:
                    ctrl_img = ctrl_img_list[0] if len(ctrl_img_list) > 0 else None

            positive = self.sd.encode_prompt(
                gen_img_config.prompt,
                control_images=ctrl_img

@@ -202,6 +249,9 @@ class SDTrainer(BaseSDTrainProcess):
            if self.sd.encode_control_in_text_embeddings:
                # just do a blank image for unconditionals
                control_image = torch.zeros((1, 3, 224, 224), device=self.sd.device_torch, dtype=self.sd.torch_dtype)
                if self.sd.has_multiple_control_images:
                    control_image = [control_image]

                kwargs['control_images'] = control_image
            self.unconditional_embeds = self.sd.encode_prompt(
                [self.train_config.unconditional_prompt],

@@ -272,6 +322,8 @@ class SDTrainer(BaseSDTrainProcess):
            if self.sd.encode_control_in_text_embeddings:
                # just do a blank image for unconditionals
                control_image = torch.zeros((1, 3, 224, 224), device=self.sd.device_torch, dtype=self.sd.torch_dtype)
                if self.sd.has_multiple_control_images:
                    control_image = [control_image]
                encode_kwargs['control_images'] = control_image
            self.cached_blank_embeds = self.sd.encode_prompt("", **encode_kwargs)
            if self.trigger_word is not None:

@@ -348,6 +348,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
                fps=sample_item.fps,
                ctrl_img=sample_item.ctrl_img,
                ctrl_idx=sample_item.ctrl_idx,
                ctrl_img_1=sample_item.ctrl_img_1,
                ctrl_img_2=sample_item.ctrl_img_2,
                ctrl_img_3=sample_item.ctrl_img_3,
                **extra_args
            ))

@@ -1,7 +1,7 @@
torchao==0.10.0
safetensors
git+https://github.com/jaretburkett/easy_dwpose.git
git+https://github.com/huggingface/diffusers@7a2b78bf0f788d311cc96b61e660a8e13e3b1e63
git+https://github.com/huggingface/diffusers@1448b035859dd57bbb565239dcdd79a025a85422
transformers==4.52.4
lycoris-lora==1.8.3
flatten_json

@@ -56,6 +56,11 @@ class SampleItem:
        self.num_frames: int = kwargs.get('num_frames', sample_config.num_frames)
        self.ctrl_img: Optional[str] = kwargs.get('ctrl_img', None)
        self.ctrl_idx: int = kwargs.get('ctrl_idx', 0)
        # for multi control image models
        self.ctrl_img_1: Optional[str] = kwargs.get('ctrl_img_1', self.ctrl_img)
        self.ctrl_img_2: Optional[str] = kwargs.get('ctrl_img_2', None)
        self.ctrl_img_3: Optional[str] = kwargs.get('ctrl_img_3', None)

        self.network_multiplier: float = kwargs.get('network_multiplier', sample_config.network_multiplier)
        # convert to a number if it is a string
        if isinstance(self.network_multiplier, str):

@@ -966,6 +971,9 @@ class GenerateImageConfig:
        extra_values: List[float] = None, # extra values to save with prompt file
        logger: Optional[EmptyLogger] = None,
        ctrl_img: Optional[str] = None, # control image for controlnet
        ctrl_img_1: Optional[str] = None, # first control image for multi control model
        ctrl_img_2: Optional[str] = None, # second control image for multi control model
        ctrl_img_3: Optional[str] = None, # third control image for multi control model
        num_frames: int = 1,
        fps: int = 15,
        ctrl_idx: int = 0

@@ -1002,6 +1010,12 @@ class GenerateImageConfig:
        self.ctrl_img = ctrl_img
        self.ctrl_idx = ctrl_idx

        if ctrl_img_1 is None and ctrl_img is not None:
            ctrl_img_1 = ctrl_img

        self.ctrl_img_1 = ctrl_img_1
        self.ctrl_img_2 = ctrl_img_2
        self.ctrl_img_3 = ctrl_img_3

        # prompt string will override any settings above
        self._process_prompt_string()

@@ -144,6 +144,7 @@ class DataLoaderBatchDTO:
        self.tensor: Union[torch.Tensor, None] = None
        self.latents: Union[torch.Tensor, None] = None
        self.control_tensor: Union[torch.Tensor, None] = None
        self.control_tensor_list: Union[List[List[torch.Tensor]], None] = None
        self.clip_image_tensor: Union[torch.Tensor, None] = None
        self.mask_tensor: Union[torch.Tensor, None] = None
        self.unaugmented_tensor: Union[torch.Tensor, None] = None

@@ -160,7 +161,6 @@ class DataLoaderBatchDTO:
        self.latents: Union[torch.Tensor, None] = None
        if is_latents_cached:
            self.latents = torch.cat([x.get_latent().unsqueeze(0) for x in self.file_items])
        self.control_tensor: Union[torch.Tensor, None] = None
        self.prompt_embeds: Union[PromptEmbeds, None] = None
        # if self.file_items[0].control_tensor is not None:
        # if any have a control tensor, we concatenate them

@@ -178,6 +178,16 @@ class DataLoaderBatchDTO:
                else:
                    control_tensors.append(x.control_tensor)
            self.control_tensor = torch.cat([x.unsqueeze(0) for x in control_tensors])

        # handle control tensor list
        if any([x.control_tensor_list is not None for x in self.file_items]):
            self.control_tensor_list = []
            for x in self.file_items:
                if x.control_tensor_list is not None:
                    self.control_tensor_list.append(x.control_tensor_list)
                else:
                    raise Exception(f"Could not find control tensors for all file items, missing for {x.path}")

        self.inpaint_tensor: Union[torch.Tensor, None] = None
        if any([x.inpaint_tensor is not None for x in self.file_items]):

@@ -850,6 +850,9 @@ class ControlFileItemDTOMixin:
        self.has_control_image = False
        self.control_path: Union[str, List[str], None] = None
        self.control_tensor: Union[torch.Tensor, None] = None
        self.control_tensor_list: Union[List[torch.Tensor], None] = None
        sd = kwargs.get('sd', None)
        self.use_raw_control_images = sd is not None and sd.use_raw_control_images
        dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
        self.full_size_control_images = False
        if dataset_config.control_path is not None:

@@ -900,23 +903,14 @@ class ControlFileItemDTOMixin:
            except Exception as e:
                print_acc(f"Error: {e}")
                print_acc(f"Error loading image: {control_path}")

            if not self.full_size_control_images:
                # we just scale them to 512x512:
                w, h = img.size
                img = img.resize((512, 512), Image.BICUBIC)

            else:
            elif not self.use_raw_control_images:
                w, h = img.size
                if w > h and self.scale_to_width < self.scale_to_height:
                    # throw error, they should match
                    raise ValueError(
                        f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
                elif h > w and self.scale_to_height < self.scale_to_width:
                    # throw error, they should match
                    raise ValueError(
                        f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")

            if self.flip_x:
                # do a flip
                img = img.transpose(Image.FLIP_LEFT_RIGHT)

@@ -950,11 +944,15 @@ class ControlFileItemDTOMixin:
            self.control_tensor = None
        elif len(control_tensors) == 1:
            self.control_tensor = control_tensors[0]
        elif self.use_raw_control_images:
            # just send the list of tensors as their shapes won't match
            self.control_tensor_list = control_tensors
        else:
            self.control_tensor = torch.stack(control_tensors, dim=0)

    def cleanup_control(self: 'FileItemDTO'):
        self.control_tensor = None
        self.control_tensor_list = None


class ClipImageFileItemDTOMixin:

@@ -1884,14 +1882,31 @@ class TextEmbeddingCachingMixin:
            if file_item.encode_control_in_text_embeddings:
                if file_item.control_path is None:
                    raise Exception(f"Could not find a control image for {file_item.path} which is needed for this model")
                # load the control image and feed it into the text encoder
                ctrl_img = Image.open(file_item.control_path).convert("RGB")
                # convert to 0 to 1 tensor
                ctrl_img = (
                    TF.to_tensor(ctrl_img)
                    .unsqueeze(0)
                    .to(self.sd.device_torch, dtype=self.sd.torch_dtype)
                )
                ctrl_img_list = []
                control_path_list = file_item.control_path
                if not isinstance(file_item.control_path, list):
                    control_path_list = [control_path_list]
                for i in range(len(control_path_list)):
                    try:
                        img = Image.open(control_path_list[i]).convert("RGB")
                        img = exif_transpose(img)
                        # convert to 0 to 1 tensor
                        img = (
                            TF.to_tensor(img)
                            .unsqueeze(0)
                            .to(self.sd.device_torch, dtype=self.sd.torch_dtype)
                        )
                        ctrl_img_list.append(img)
                    except Exception as e:
                        print_acc(f"Error: {e}")
                        print_acc(f"Error loading control image: {control_path_list[i]}")

                if len(ctrl_img_list) == 0:
                    ctrl_img = None
                elif not self.sd.has_multiple_control_images:
                    ctrl_img = ctrl_img_list[0]
                else:
                    ctrl_img = ctrl_img_list
                prompt_embeds: PromptEmbeds = self.sd.encode_prompt(file_item.caption, control_images=ctrl_img)
            else:
                prompt_embeds: PromptEmbeds = self.sd.encode_prompt(file_item.caption)

@@ -181,6 +181,10 @@ class BaseModel:

        # set true for models that encode control image into text embeddings
        self.encode_control_in_text_embeddings = False
        # control images will come in as a list for encoding some things if true
        self.has_multiple_control_images = False
        # do not resize control images
        self.use_raw_control_images = False

    # properties for old arch for backwards compatibility
    @property

@@ -219,6 +219,10 @@ class StableDiffusion:

        # set true for models that encode control image into text embeddings
        self.encode_control_in_text_embeddings = False
        # control images will come in as a list for encoding some things if true
        self.has_multiple_control_images = False
        # do not resize control images
        self.use_raw_control_images = False

    # properties for old arch for backwards compatibility
    @property