Merge pull request #383 from ostris/qwen_image_edit

Add support for Qwen-Image-Edit
Jaret Burkett committed 2025-08-22 10:29:25 -06:00 (committed by GitHub)
14 changed files with 448 additions and 41 deletions

View File

@@ -4,7 +4,7 @@ from .f_light import FLiteModel
 from .omnigen2 import OmniGen2Model
 from .flux_kontext import FluxKontextModel
 from .wan22 import Wan225bModel, Wan2214bModel, Wan2214bI2VModel
-from .qwen_image import QwenImageModel
+from .qwen_image import QwenImageModel, QwenImageEditModel

 AI_TOOLKIT_MODELS = [
     # put a list of models here
@@ -18,4 +18,5 @@ AI_TOOLKIT_MODELS = [
     Wan2214bI2VModel,
     Wan2214bModel,
     QwenImageModel,
+    QwenImageEditModel,
 ]
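
AI_TOOLKIT_MODELS is the registry the trainer consults when resolving the arch string from a config, so registering QwenImageEditModel here is what makes `arch: qwen_image_edit` selectable. A minimal sketch of how such a lookup presumably works (the helper name is hypothetical, not part of the PR):

def get_model_class(arch: str):
    # each model class carries an `arch` class attribute, e.g. "qwen_image_edit"
    for model_cls in AI_TOOLKIT_MODELS:
        if model_cls.arch == arch:
            return model_cls
    raise ValueError(f"unknown model arch: {arch}")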

View File

@@ -1 +1,2 @@
 from .qwen_image import QwenImageModel
+from .qwen_image_edit import QwenImageEditModel

View File

@@ -16,7 +16,7 @@ from toolkit.util.quantize import quantize, get_qtype, quantize_model
 import torch.nn.functional as F
 from diffusers import QwenImagePipeline, QwenImageTransformer2DModel, AutoencoderKLQwenImage
-from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer
+from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor
 from tqdm import tqdm

 if TYPE_CHECKING:
@@ -43,7 +43,8 @@ scheduler_config = {
 class QwenImageModel(BaseModel):
     arch = "qwen_image"
-    _qwen_image_keep_processor = False
+    _qwen_image_keep_visual = False
+    _qwen_pipeline = QwenImagePipeline

     def __init__(
         self,
@@ -119,10 +120,9 @@ class QwenImageModel(BaseModel):
         # remove the visual model as it is not needed for image generation
         self.processor = None
-        if self._qwen_image_keep_processor:
-            self.processor = text_encoder.model.visual
-        text_encoder.model.visual = None
+        if not self._qwen_image_keep_visual:
+            text_encoder.model.visual = None

         text_encoder.to(self.device_torch, dtype=dtype)
         flush()
@@ -140,13 +140,27 @@ class QwenImageModel(BaseModel):
         self.noise_scheduler = QwenImageModel.get_train_scheduler()
         self.print_and_status_update("Making pipe")

+        kwargs = {}
+        if self._qwen_image_keep_visual:
+            try:
+                self.processor = Qwen2VLProcessor.from_pretrained(
+                    model_path, subfolder="processor"
+                )
+            except OSError:
+                self.processor = Qwen2VLProcessor.from_pretrained(
+                    base_model_path, subfolder="processor"
+                )
+            kwargs['processor'] = self.processor
+
-        pipe: QwenImagePipeline = QwenImagePipeline(
+        pipe: QwenImagePipeline = self._qwen_pipeline(
             scheduler=self.noise_scheduler,
             text_encoder=None,
             tokenizer=tokenizer,
             vae=vae,
             transformer=None,
+            **kwargs
         )

         # for quantization, it works best to do these after making the pipe
         pipe.text_encoder = text_encoder
@@ -261,21 +275,13 @@ class QwenImageModel(BaseModel):
         latent_model_input = latent_model_input.permute(0, 2, 4, 1, 3, 5)
         latent_model_input = latent_model_input.reshape(batch_size, (height // ps) * (width // ps), num_channels_latents * (ps * ps))

-        # clamp text length to RoPE capacity for this image size
         # img_shapes passed to the model
         img_h2, img_w2 = height // ps, width // ps
-        img_shapes = [(1, img_h2, img_w2)] * batch_size
-
-        # QwenEmbedRope logic:
-        max_vid_index = max(img_h2 // ps, img_w2 // ps)
-        rope_cap = 1024 - max_vid_index  # available text positions in RoPE cache
-        seq_len_actual = text_embeddings.text_embeds.shape[1]
-        use_len = min(seq_len_actual, rope_cap)
-        enc_hs = text_embeddings.text_embeds[:, :use_len].to(self.device_torch, self.torch_dtype)
-        prompt_embeds_mask = text_embeddings.attention_mask.to(self.device_torch, dtype=torch.int64)[:, :use_len]
-        txt_seq_lens = [use_len] * batch_size
+        img_shapes = [[(1, img_h2, img_w2)]] * batch_size
+
+        enc_hs = text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype)
+        prompt_embeds_mask = text_embeddings.attention_mask.to(self.device_torch, dtype=torch.int64)
+        txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist()

         noise_pred = self.transformer(
             hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype),
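
Two behavioral changes land in this hunk: img_shapes becomes a list of per-sample shape lists (the RoPE layer can now see several grids per sample, which the edit model uses for image plus control), and the old RoPE length clamping is dropped in favor of per-sample text lengths read straight off the attention mask. A tiny sketch of that mask-to-lengths step, with an assumed (bs, seq_len) mask:

import torch

mask = torch.tensor([[1, 1, 1, 0, 0],
                     [1, 1, 1, 1, 1]])   # 1s mark real text tokens
txt_seq_lens = mask.sum(dim=1).tolist()  # -> [3, 5], one length per sample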

View File

@@ -0,0 +1,276 @@
import math
import torch
from .qwen_image import QwenImageModel
import os
from typing import TYPE_CHECKING, List, Optional
import yaml
from toolkit import train_tools
from toolkit.config_modules import GenerateImageConfig, ModelConfig
from PIL import Image
from toolkit.models.base_model import BaseModel
from toolkit.basic import flush
from toolkit.prompt_utils import PromptEmbeds
from toolkit.samplers.custom_flowmatch_sampler import (
CustomFlowMatchEulerDiscreteScheduler,
)
from toolkit.accelerator import get_accelerator, unwrap_model
from optimum.quanto import freeze, QTensor
from toolkit.util.quantize import quantize, get_qtype, quantize_model
import torch.nn.functional as F
from diffusers import (
QwenImagePipeline,
QwenImageTransformer2DModel,
AutoencoderKLQwenImage,
)
from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer
from tqdm import tqdm
if TYPE_CHECKING:
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
try:
from diffusers import QwenImageEditPipeline
except ImportError:
raise ImportError(
"QwenImageEditPipeline not found. Update diffusers to the latest version by doing pip uninstall diffusers and then pip install -r requirements.txt"
)
class QwenImageEditModel(QwenImageModel):
arch = "qwen_image_edit"
_qwen_image_keep_visual = True
_qwen_pipeline = QwenImageEditPipeline
def __init__(
self,
device,
model_config: ModelConfig,
dtype="bf16",
custom_pipeline=None,
noise_scheduler=None,
**kwargs,
):
super().__init__(
device, model_config, dtype, custom_pipeline, noise_scheduler, **kwargs
)
self.is_flow_matching = True
self.is_transformer = True
self.target_lora_modules = ["QwenImageTransformer2DModel"]
# set true for models that encode control image into text embeddings
self.encode_control_in_text_embeddings = True
def load_model(self):
super().load_model()
def get_generation_pipeline(self):
scheduler = QwenImageModel.get_train_scheduler()
pipeline: QwenImageEditPipeline = QwenImageEditPipeline(
scheduler=scheduler,
text_encoder=unwrap_model(self.text_encoder[0]),
tokenizer=self.tokenizer[0],
processor=self.processor,
vae=unwrap_model(self.vae),
transformer=unwrap_model(self.transformer),
)
pipeline = pipeline.to(self.device_torch)
return pipeline
def generate_single_image(
self,
pipeline: QwenImageEditPipeline,
gen_config: GenerateImageConfig,
conditional_embeds: PromptEmbeds,
unconditional_embeds: PromptEmbeds,
generator: torch.Generator,
extra: dict,
):
self.model.to(self.device_torch, dtype=self.torch_dtype)
sc = self.get_bucket_divisibility()
gen_config.width = int(gen_config.width // sc * sc)
gen_config.height = int(gen_config.height // sc * sc)
control_img = None
if gen_config.ctrl_img is not None:
control_img = Image.open(gen_config.ctrl_img)
control_img = control_img.convert("RGB")
# resize to width and height
if control_img.size != (gen_config.width, gen_config.height):
control_img = control_img.resize(
(gen_config.width, gen_config.height), Image.BILINEAR
)
# flush for low vram if we are doing that
flush_between_steps = self.model_config.low_vram
# Fix a bug in diffusers/torch
def callback_on_step_end(pipe, i, t, callback_kwargs):
if flush_between_steps:
flush()
latents = callback_kwargs["latents"]
return {"latents": latents}
img = pipeline(
image=control_img,
prompt_embeds=conditional_embeds.text_embeds,
prompt_embeds_mask=conditional_embeds.attention_mask.to(
self.device_torch, dtype=torch.int64
),
negative_prompt_embeds=unconditional_embeds.text_embeds,
negative_prompt_embeds_mask=unconditional_embeds.attention_mask.to(
self.device_torch, dtype=torch.int64
),
height=gen_config.height,
width=gen_config.width,
num_inference_steps=gen_config.num_inference_steps,
true_cfg_scale=gen_config.guidance_scale,
latents=gen_config.latents,
generator=generator,
callback_on_step_end=callback_on_step_end,
**extra,
).images[0]
return img
def condition_noisy_latents(
self, latents: torch.Tensor, batch: "DataLoaderBatchDTO"
):
with torch.no_grad():
control_tensor = batch.control_tensor
if control_tensor is not None:
self.vae.to(self.device_torch)
# we are not packed here, so we just need to pass them so we can pack them later
control_tensor = control_tensor * 2 - 1
control_tensor = control_tensor.to(
self.vae_device_torch, dtype=self.torch_dtype
)
# if it is not the size of batch.tensor, (bs,ch,h,w) then we need to resize it
if batch.tensor is not None:
target_h, target_w = batch.tensor.shape[2], batch.tensor.shape[3]
else:
# When caching latents, batch.tensor is None. We get the size from the file_items instead.
target_h = batch.file_items[0].crop_height
target_w = batch.file_items[0].crop_width
if (
control_tensor.shape[2] != target_h
or control_tensor.shape[3] != target_w
):
control_tensor = F.interpolate(
control_tensor, size=(target_h, target_w), mode="bilinear"
)
control_latent = self.encode_images(control_tensor).to(
latents.device, latents.dtype
)
latents = torch.cat((latents, control_latent), dim=1)
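# note: this channel-concat is only transport; get_noise_prediction chunks the
# control latents back off the channel dim before packing tokens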
return latents.detach()
def get_prompt_embeds(self, prompt: str, control_images=None) -> PromptEmbeds:
if self.pipeline.text_encoder.device != self.device_torch:
self.pipeline.text_encoder.to(self.device_torch)
if control_images is not None:
# control images are 0 - 1 scale, shape (bs, ch, height, width)
# images are always run through at 1MP, based on diffusers inference code.
target_area = 1024 * 1024
# ratio must be width / height for the sqrt below (as in diffusers'
# calculate_dimensions); height / width would swap the output aspect
ratio = control_images.shape[3] / control_images.shape[2]
width = math.sqrt(target_area * ratio)
height = width / ratio
width = round(width / 32) * 32
height = round(height / 32) * 32
control_images = F.interpolate(
control_images, size=(height, width), mode="bilinear"
)
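# worked example: a 1024x768 (w x h) control image gives ratio 4/3, so
# width = sqrt(1048576 * 4/3) ~= 1182 -> 1184 and height ~= 887 -> 896 after
# snapping to multiples of 32; the encoder always sees roughly one megapixel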
prompt_embeds, prompt_embeds_mask = self.pipeline.encode_prompt(
prompt,
image=control_images,
device=self.device_torch,
num_images_per_prompt=1,
)
pe = PromptEmbeds(prompt_embeds)
pe.attention_mask = prompt_embeds_mask
return pe
def get_noise_prediction(
self,
latent_model_input: torch.Tensor,
timestep: torch.Tensor, # 0 to 1000 scale
text_embeddings: PromptEmbeds,
**kwargs,
):
# control is stacked on channels, move it to the batch dimension for packing
latent_model_input, control = torch.chunk(latent_model_input, 2, 1)
batch_size, num_channels_latents, height, width = latent_model_input.shape
(
control_batch_size,
control_num_channels_latents,
control_height,
control_width,
) = control.shape
# pack image tokens
latent_model_input = latent_model_input.view(
batch_size, num_channels_latents, height // 2, 2, width // 2, 2
)
latent_model_input = latent_model_input.permute(0, 2, 4, 1, 3, 5)
latent_model_input = latent_model_input.reshape(
batch_size, (height // 2) * (width // 2), num_channels_latents * 4
)
# pack control
control = control.view(
batch_size, num_channels_latents, height // 2, 2, width // 2, 2
)
control = control.permute(0, 2, 4, 1, 3, 5)
control = control.reshape(
batch_size, (height // 2) * (width // 2), num_channels_latents * 4
)
img_h2, img_w2 = height // 2, width // 2
control_img_h2, control_img_w2 = control_height // 2, control_width // 2
img_shapes = [[(1, img_h2, img_w2), (1, control_img_h2, control_img_w2)]] * batch_size
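# each sample advertises two RoPE grids: one for the image tokens and one for
# the control tokens appended below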
latents = latent_model_input
latent_model_input = torch.cat([latent_model_input, control], dim=1)
batch_size = latent_model_input.shape[0]
prompt_embeds_mask = text_embeddings.attention_mask.to(
self.device_torch, dtype=torch.int64
)
txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist()
enc_hs = text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype)
prompt_embeds_mask = text_embeddings.attention_mask.to(self.device_torch, dtype=torch.int64)
noise_pred = self.transformer(
hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype),
timestep=timestep / 1000,
guidance=None,
encoder_hidden_states=enc_hs,
encoder_hidden_states_mask=prompt_embeds_mask,
img_shapes=img_shapes,
txt_seq_lens=txt_seq_lens,
return_dict=False,
**kwargs,
)[0]
noise_pred = noise_pred[:, : latents.size(1)]
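# only the image tokens carry the prediction; the control tokens are dropped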
# unpack
noise_pred = noise_pred.view(
batch_size, height // 2, width // 2, num_channels_latents, 2, 2
)
noise_pred = noise_pred.permute(0, 3, 1, 4, 2, 5)
noise_pred = noise_pred.reshape(batch_size, num_channels_latents, height, width)
return noise_pred
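
The packing above is the standard 2x2 patchify used by the Qwen-Image transformer: each 2x2 latent patch becomes one token of dim ch*4, and the prediction is un-packed by the inverse permute. A self-contained round-trip sketch of just that transform (shapes assumed):

import torch

def pack(latents: torch.Tensor) -> torch.Tensor:
    # (bs, ch, h, w) -> (bs, (h/2)*(w/2), ch*4), one token per 2x2 patch
    bs, ch, h, w = latents.shape
    x = latents.view(bs, ch, h // 2, 2, w // 2, 2)
    x = x.permute(0, 2, 4, 1, 3, 5)
    return x.reshape(bs, (h // 2) * (w // 2), ch * 4)

def unpack(tokens: torch.Tensor, ch: int, h: int, w: int) -> torch.Tensor:
    # inverse of pack, mirroring the permute in get_noise_prediction
    bs = tokens.shape[0]
    x = tokens.view(bs, h // 2, w // 2, ch, 2, 2)
    x = x.permute(0, 3, 1, 4, 2, 5)
    return x.reshape(bs, ch, h, w)

x = torch.randn(1, 16, 64, 64)
assert torch.equal(unpack(pack(x), 16, 64, 64), x)  # lossless round trip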

View File

@@ -37,6 +37,8 @@ from toolkit.models.diffusion_feature_extraction import DiffusionFeatureExtracto
 from toolkit.util.wavelet_loss import wavelet_loss
 import torch.nn.functional as F
 from toolkit.unloader import unload_text_encoder
+from PIL import Image
+from torchvision.transforms import functional as TF


 def flush():
@@ -127,9 +129,28 @@ class SDTrainer(BaseSDTrainProcess):
                     prompt=prompt,  # it will autoparse the prompt
                     negative_prompt=sample_item.neg,
                     output_path=output_path,
+                    ctrl_img=sample_item.ctrl_img
                 )
-                positive = self.sd.encode_prompt(gen_img_config.prompt).to('cpu')
-                negative = self.sd.encode_prompt(gen_img_config.negative_prompt).to('cpu')
+                # see if we need to encode the control images
+                if self.sd.encode_control_in_text_embeddings and gen_img_config.ctrl_img is not None:
+                    ctrl_img = Image.open(gen_img_config.ctrl_img).convert("RGB")
+                    # convert to 0 to 1 tensor
+                    ctrl_img = (
+                        TF.to_tensor(ctrl_img)
+                        .unsqueeze(0)
+                        .to(self.sd.device_torch, dtype=self.sd.torch_dtype)
+                    )
+                    positive = self.sd.encode_prompt(
+                        gen_img_config.prompt,
+                        control_images=ctrl_img
+                    ).to('cpu')
+                    negative = self.sd.encode_prompt(
+                        gen_img_config.negative_prompt,
+                        control_images=ctrl_img
+                    ).to('cpu')
+                else:
+                    positive = self.sd.encode_prompt(gen_img_config.prompt).to('cpu')
+                    negative = self.sd.encode_prompt(gen_img_config.negative_prompt).to('cpu')

                 self.sd.sample_prompts_cache.append({
                     'conditional': positive,
@@ -177,9 +198,15 @@ class SDTrainer(BaseSDTrainProcess):
             # cache unconditional embeds (blank prompt)
             with torch.no_grad():
+                kwargs = {}
+                if self.sd.encode_control_in_text_embeddings:
+                    # just do a blank image for unconditionals
+                    control_image = torch.zeros((1, 3, 224, 224), device=self.sd.device_torch, dtype=self.sd.torch_dtype)
+                    kwargs['control_images'] = control_image
                 self.unconditional_embeds = self.sd.encode_prompt(
                     [self.train_config.unconditional_prompt],
-                    long_prompts=self.do_long_prompts
+                    long_prompts=self.do_long_prompts,
+                    **kwargs
                 ).to(
                     self.device_torch,
                     dtype=self.sd.torch_dtype
@@ -241,9 +268,14 @@ class SDTrainer(BaseSDTrainProcess):
print_acc("***********************************") print_acc("***********************************")
print_acc("") print_acc("")
self.sd.text_encoder_to(self.device_torch) self.sd.text_encoder_to(self.device_torch)
self.cached_blank_embeds = self.sd.encode_prompt("") encode_kwargs = {}
if self.sd.encode_control_in_text_embeddings:
# just do a blank image for unconditionals
control_image = torch.zeros((1, 3, 224, 224), device=self.sd.device_torch, dtype=self.sd.torch_dtype)
encode_kwargs['control_images'] = control_image
self.cached_blank_embeds = self.sd.encode_prompt("", **encode_kwargs)
if self.trigger_word is not None: if self.trigger_word is not None:
self.cached_trigger_embeds = self.sd.encode_prompt(self.trigger_word) self.cached_trigger_embeds = self.sd.encode_prompt(self.trigger_word, **encode_kwargs)
if self.train_config.diff_output_preservation: if self.train_config.diff_output_preservation:
self.diff_output_preservation_embeds = self.sd.encode_prompt(self.train_config.diff_output_preservation_class) self.diff_output_preservation_embeds = self.sd.encode_prompt(self.train_config.diff_output_preservation_class)
@@ -967,11 +999,16 @@ class SDTrainer(BaseSDTrainProcess):
                 if batch.prompt_embeds is not None:
                     embeds_to_use = batch.prompt_embeds.clone().to(self.device_torch, dtype=dtype)
                 else:
+                    prompt_kwargs = {}
+                    if self.sd.encode_control_in_text_embeddings and batch.control_tensor is not None:
+                        prompt_kwargs['control_images'] = batch.control_tensor.to(self.sd.device_torch, dtype=self.sd.torch_dtype)
                     embeds_to_use = self.sd.encode_prompt(
                         prompt_list,
-                        long_prompts=self.do_long_prompts).to(
+                        long_prompts=self.do_long_prompts,
+                        **prompt_kwargs
+                    ).to(
                         self.device_torch,
                         dtype=dtype).detach()

             # dont use network on this
             # self.network.multiplier = 0.0
@@ -1338,6 +1375,9 @@ class SDTrainer(BaseSDTrainProcess):
         with self.timer('encode_prompt'):
             unconditional_embeds = None
+            prompt_kwargs = {}
+            if self.sd.encode_control_in_text_embeddings and batch.control_tensor is not None:
+                prompt_kwargs['control_images'] = batch.control_tensor.to(self.sd.device_torch, dtype=self.sd.torch_dtype)
             if self.train_config.unload_text_encoder or self.is_caching_text_embeddings:
                 with torch.set_grad_enabled(False):
                     if batch.prompt_embeds is not None:
@@ -1374,7 +1414,9 @@ class SDTrainer(BaseSDTrainProcess):
                     conditional_embeds = self.sd.encode_prompt(
                         conditioned_prompts, prompt_2,
                         dropout_prob=self.train_config.prompt_dropout_prob,
-                        long_prompts=self.do_long_prompts).to(
+                        long_prompts=self.do_long_prompts,
+                        **prompt_kwargs
+                    ).to(
                         self.device_torch,
                         dtype=dtype)
@@ -1386,7 +1428,9 @@ class SDTrainer(BaseSDTrainProcess):
                         self.batch_negative_prompt,
                         self.batch_negative_prompt,
                         dropout_prob=self.train_config.prompt_dropout_prob,
-                        long_prompts=self.do_long_prompts).to(
+                        long_prompts=self.do_long_prompts,
+                        **prompt_kwargs
+                    ).to(
                         self.device_torch,
                         dtype=dtype)
             if isinstance(self.adapter, CustomAdapter):
@@ -1404,7 +1448,9 @@ class SDTrainer(BaseSDTrainProcess):
                     conditional_embeds = self.sd.encode_prompt(
                         conditioned_prompts, prompt_2,
                         dropout_prob=self.train_config.prompt_dropout_prob,
-                        long_prompts=self.do_long_prompts).to(
+                        long_prompts=self.do_long_prompts,
+                        **prompt_kwargs
+                    ).to(
                         self.device_torch,
                         dtype=dtype)
                     if self.train_config.do_cfg:
@@ -1413,7 +1459,9 @@ class SDTrainer(BaseSDTrainProcess):
                         unconditional_embeds = self.sd.encode_prompt(
                             self.batch_negative_prompt,
                             dropout_prob=self.train_config.prompt_dropout_prob,
-                            long_prompts=self.do_long_prompts).to(
+                            long_prompts=self.do_long_prompts,
+                            **prompt_kwargs
+                        ).to(
                             self.device_torch,
                             dtype=dtype)
             if isinstance(self.adapter, CustomAdapter):
@@ -1427,7 +1475,9 @@ class SDTrainer(BaseSDTrainProcess):
                     self.diff_output_preservation_embeds = self.sd.encode_prompt(
                         dop_prompts, dop_prompts_2,
                         dropout_prob=self.train_config.prompt_dropout_prob,
-                        long_prompts=self.do_long_prompts).to(
+                        long_prompts=self.do_long_prompts,
+                        **prompt_kwargs
+                    ).to(
                         self.device_torch,
                         dtype=dtype)
                 # detach the embeddings
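
Every call site in this file repeats the same guard: build a kwargs dict that carries control_images only when the active model sets encode_control_in_text_embeddings, then splat it into encode_prompt. A condensed sketch of that pattern as a hypothetical helper (not in the PR; the zeros tensor mirrors the blank image the diff uses for unconditional prompts):

import torch

def build_prompt_kwargs(sd, control_tensor=None, blank_if_missing=False):
    kwargs = {}
    if sd.encode_control_in_text_embeddings:
        if control_tensor is not None:
            kwargs['control_images'] = control_tensor.to(sd.device_torch, dtype=sd.torch_dtype)
        elif blank_if_missing:
            # a blank 224x224 image stands in for unconditional/blank prompts
            kwargs['control_images'] = torch.zeros(
                (1, 3, 224, 224), device=sd.device_torch, dtype=sd.torch_dtype
            )
    return kwargs

# usage: sd.encode_prompt(prompts, **build_prompt_kwargs(sd, batch.control_tensor))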

View File

@@ -1,7 +1,7 @@
 torchao==0.10.0
 safetensors
 git+https://github.com/jaretburkett/easy_dwpose.git
-git+https://github.com/huggingface/diffusers@7ea065c5070a5278259e6f1effa9dccea232e62a
+git+https://github.com/huggingface/diffusers@7a2b78bf0f788d311cc96b61e660a8e13e3b1e63
 transformers==4.52.4
 lycoris-lora==1.8.3
 flatten_json

View File

@@ -1232,5 +1232,10 @@ def validate_configs(
         for dataset in dataset_configs:
             if not dataset.cache_text_embeddings:
                 raise ValueError("All datasets must have cache_text_embeddings set to True when caching text embeddings is enabled.")
+
+    # qwen_image_edit cannot unload the text encoder, since control images are
+    # encoded together with the text embeddings
+    if model_config.arch == 'qwen_image_edit':
+        if train_config.unload_text_encoder:
+            raise ValueError("Cannot unload the text encoder with the qwen_image_edit model. Control images are encoded with the text embeddings. You can cache the text embeddings instead.")

View File

@@ -497,6 +497,7 @@ class AiToolkitDataset(LatentCachingMixin, ControlCachingMixin, CLIPCachingMixin
                     dataloader_transforms=self.transform,
                     size_database=self.size_database,
                     dataset_root=dataset_folder,
+                    encode_control_in_text_embeddings=self.sd.encode_control_in_text_embeddings if self.sd else False,
                 )
                 self.file_list.append(file_item)
             except Exception as e:

View File

@@ -50,6 +50,7 @@ class FileItemDTO(
         self.is_video = self.dataset_config.num_frames > 1
         size_database = kwargs.get('size_database', {})
         dataset_root = kwargs.get('dataset_root', None)
+        self.encode_control_in_text_embeddings = kwargs.get('encode_control_in_text_embeddings', False)
         if dataset_root is not None:
             # remove dataset root from path
             file_key = self.path.replace(dataset_root, '')

View File

@@ -30,6 +30,7 @@ import albumentations as A
 from toolkit.print import print_acc
 from toolkit.accelerator import get_accelerator
 from toolkit.prompt_utils import PromptEmbeds
+from torchvision.transforms import functional as TF
 from toolkit.train_tools import get_torch_dtype
@@ -1802,6 +1803,9 @@ class TextEmbeddingFileItemDTOMixin:
("text_embedding_space_version", self.text_embedding_space_version), ("text_embedding_space_version", self.text_embedding_space_version),
("text_embedding_version", self.text_embedding_version), ("text_embedding_version", self.text_embedding_version),
]) ])
# if we have a control image, cache the path
if self.encode_control_in_text_embeddings and self.control_path is not None:
item["control_path"] = self.control_path
return item return item
def get_text_embedding_path(self: 'FileItemDTO', recalculate=False): def get_text_embedding_path(self: 'FileItemDTO', recalculate=False):
@@ -1860,7 +1864,19 @@ class TextEmbeddingCachingMixin:
                     if not did_move:
                         self.sd.set_device_state_preset('cache_text_encoder')
                         did_move = True
-                    prompt_embeds: PromptEmbeds = self.sd.encode_prompt(file_item.caption)
+                    if file_item.encode_control_in_text_embeddings and file_item.control_path is not None:
+                        # load the control image and feed it into the text encoder
+                        ctrl_img = Image.open(file_item.control_path).convert("RGB")
+                        # convert to 0 to 1 tensor
+                        ctrl_img = (
+                            TF.to_tensor(ctrl_img)
+                            .unsqueeze(0)
+                            .to(self.sd.device_torch, dtype=self.sd.torch_dtype)
+                        )
+                        prompt_embeds: PromptEmbeds = self.sd.encode_prompt(file_item.caption, control_images=ctrl_img)
+                    else:
+                        prompt_embeds: PromptEmbeds = self.sd.encode_prompt(file_item.caption)
                     # save it
                     prompt_embeds.save(text_embedding_path)
                     del prompt_embeds
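
Recording control_path in the item dict matters because that dict feeds the text-embedding cache key: the same caption paired with a different control image must resolve to a different cached embedding. A sketch of the idea (the actual hashing inside get_text_embedding_path may differ):

import hashlib, json

def embedding_cache_key(item: dict) -> str:
    # any field change (caption, versions, control_path) changes the key
    payload = json.dumps(item, sort_keys=True)
    return hashlib.md5(payload.encode('utf-8')).hexdigest()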

View File

@@ -36,6 +36,7 @@ from diffusers import \
     UNet2DConditionModel
 from diffusers import PixArtAlphaPipeline
 from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
+from torchvision.transforms import functional as TF

 from toolkit.accelerator import get_accelerator, unwrap_model
 from typing import TYPE_CHECKING
@@ -177,6 +178,9 @@ class BaseModel:
         self.multistage_boundaries: List[float] = [0.0]
         # a list of trainable multistage boundaries
         self.trainable_multistage_boundaries: List[int] = [0]
+        # set true for models that encode the control image into the text embeddings
+        self.encode_control_in_text_embeddings = False

     # properties for old arch for backwards compatibility
     @property
@@ -287,7 +291,7 @@ class BaseModel:
         raise NotImplementedError(
             "get_noise_prediction must be implemented in child classes")

-    def get_prompt_embeds(self, prompt: str) -> PromptEmbeds:
+    def get_prompt_embeds(self, prompt: str, control_images=None) -> PromptEmbeds:
         raise NotImplementedError(
             "get_prompt_embeds must be implemented in child classes")
@@ -496,17 +500,34 @@ class BaseModel:
             if self.sample_prompts_cache is not None:
                 conditional_embeds = self.sample_prompts_cache[i]['conditional'].to(self.device_torch, dtype=self.torch_dtype)
                 unconditional_embeds = self.sample_prompts_cache[i]['unconditional'].to(self.device_torch, dtype=self.torch_dtype)
             else:
+                ctrl_img = None
+                # load the control image if our model uses it in text encoding
+                if gen_config.ctrl_img is not None and self.encode_control_in_text_embeddings:
+                    ctrl_img = Image.open(gen_config.ctrl_img).convert("RGB")
+                    # convert to 0 to 1 tensor
+                    ctrl_img = (
+                        TF.to_tensor(ctrl_img)
+                        .unsqueeze(0)
+                        .to(self.device_torch, dtype=self.torch_dtype)
+                    )
                 # encode the prompt ourselves so we can do fun stuff with embeddings
                 if isinstance(self.adapter, CustomAdapter):
                     self.adapter.is_unconditional_run = False
                 conditional_embeds = self.encode_prompt(
-                    gen_config.prompt, gen_config.prompt_2, force_all=True)
+                    gen_config.prompt,
+                    gen_config.prompt_2,
+                    force_all=True,
+                    control_images=ctrl_img
+                )
                 if isinstance(self.adapter, CustomAdapter):
                     self.adapter.is_unconditional_run = True
                 unconditional_embeds = self.encode_prompt(
-                    gen_config.negative_prompt, gen_config.negative_prompt_2, force_all=True
+                    gen_config.negative_prompt,
+                    gen_config.negative_prompt_2,
+                    force_all=True,
+                    control_images=ctrl_img
                 )
                 if isinstance(self.adapter, CustomAdapter):
                     self.adapter.is_unconditional_run = False
@@ -989,6 +1010,7 @@ class BaseModel:
             long_prompts=False,
             max_length=None,
             dropout_prob=0.0,
+            control_images=None,
     ) -> PromptEmbeds:
         # sd1.5 embeddings are (bs, 77, 768)
         prompt = prompt
@@ -998,6 +1020,9 @@ class BaseModel:
         if prompt2 is not None and not isinstance(prompt2, list):
             prompt2 = [prompt2]

+        # only pass control_images through when the model opts in; this keeps
+        # existing plugins that override get_prompt_embeds(prompt) from breaking
+        if self.encode_control_in_text_embeddings:
+            return self.get_prompt_embeds(prompt, control_images=control_images)
         return self.get_prompt_embeds(prompt)
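
Taken together, these BaseModel changes define a small opt-in contract for edit/instruction models. A minimal hypothetical subclass (not part of the PR) showing the two pieces a model supplies:

from toolkit.models.base_model import BaseModel
from toolkit.prompt_utils import PromptEmbeds

class MyEditModel(BaseModel):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # opt in: encode_prompt will now forward control_images to us
        self.encode_control_in_text_embeddings = True

    def get_prompt_embeds(self, prompt: str, control_images=None) -> PromptEmbeds:
        # encode the prompt and the control image together, as qwen_image_edit does
        raise NotImplementedError("illustrative stub")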

View File

@@ -217,6 +217,9 @@ class StableDiffusion:
         # a list of trainable multistage boundaries
         self.trainable_multistage_boundaries: List[int] = [0]
+        # set true for models that encode the control image into the text embeddings
+        self.encode_control_in_text_embeddings = False

     # properties for old arch for backwards compatibility
     @property
     def is_xl(self):
@@ -2356,6 +2359,7 @@ class StableDiffusion:
             long_prompts=False,
             max_length=None,
             dropout_prob=0.0,
+            control_images=None,
     ) -> PromptEmbeds:
         # sd1.5 embeddings are (bs, 77, 768)
         prompt = prompt

View File

@@ -10,7 +10,7 @@ type AdditionalSections =
   | 'datasets.num_frames'
   | 'model.multistage'
   | 'model.low_vram';

-type ModelGroup = 'image' | 'video';
+type ModelGroup = 'image' | 'instruction' | 'video';

 export interface ModelArch {
   name: string;
@@ -44,7 +44,7 @@ export const modelArchs: ModelArch[] = [
   {
     name: 'flux_kontext',
     label: 'FLUX.1-Kontext-dev',
-    group: 'image',
+    group: 'instruction',
     defaults: {
       // default updates when [selected, unselected] in the UI
       'config.process[0].model.name_or_path': ['black-forest-labs/FLUX.1-Kontext-dev', defaultNameOrPath],
@@ -306,6 +306,27 @@ export const modelArchs: ModelArch[] = [
       '3 bit with ARA': 'uint3|ostris/accuracy_recovery_adapters/qwen_image_torchao_uint3.safetensors',
     },
   },
+  {
+    name: 'qwen_image_edit',
+    label: 'Qwen-Image-Edit',
+    group: 'instruction',
+    defaults: {
+      // default updates when [selected, unselected] in the UI
+      'config.process[0].model.name_or_path': ['Qwen/Qwen-Image-Edit', defaultNameOrPath],
+      'config.process[0].model.quantize': [true, false],
+      'config.process[0].model.quantize_te': [true, false],
+      'config.process[0].model.low_vram': [true, false],
+      'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
+      'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
+      'config.process[0].train.timestep_type': ['weighted', 'sigmoid'],
+      'config.process[0].model.qtype': ['qfloat8', 'qfloat8'],
+    },
+    disableSections: ['network.conv'],
+    additionalSections: ['datasets.control_path', 'sample.ctrl_img', 'model.low_vram'],
+    accuracyRecoveryAdapters: {
+      '3 bit with ARA': 'uint3|ostris/accuracy_recovery_adapters/qwen_image_edit_torchao_uint3.safetensors',
+    },
+  },
   {
     name: 'hidream',
     label: 'HiDream',
@@ -327,7 +348,7 @@ export const modelArchs: ModelArch[] = [
   {
     name: 'hidream_e1',
     label: 'HiDream E1',
-    group: 'image',
+    group: 'instruction',
     defaults: {
       // default updates when [selected, unselected] in the UI
       'config.process[0].model.name_or_path': ['HiDream-ai/HiDream-E1-1', defaultNameOrPath],

View File

@@ -1 +1 @@
VERSION = "0.5.3" VERSION = "0.5.4"