Merge pull request #383 from ostris/qwen_image_edit

Add support for Qwen-Image-Edit
This commit is contained in:
Jaret Burkett
2025-08-22 10:29:25 -06:00
committed by GitHub
14 changed files with 448 additions and 41 deletions

View File

@@ -4,7 +4,7 @@ from .f_light import FLiteModel
from .omnigen2 import OmniGen2Model
from .flux_kontext import FluxKontextModel
from .wan22 import Wan225bModel, Wan2214bModel, Wan2214bI2VModel
from .qwen_image import QwenImageModel
from .qwen_image import QwenImageModel, QwenImageEditModel
AI_TOOLKIT_MODELS = [
# put a list of models here
@@ -18,4 +18,5 @@ AI_TOOLKIT_MODELS = [
Wan2214bI2VModel,
Wan2214bModel,
QwenImageModel,
QwenImageEditModel,
]

View File

@@ -1 +1,2 @@
from .qwen_image import QwenImageModel
from .qwen_image_edit import QwenImageEditModel

View File

@@ -16,7 +16,7 @@ from toolkit.util.quantize import quantize, get_qtype, quantize_model
import torch.nn.functional as F
from diffusers import QwenImagePipeline, QwenImageTransformer2DModel, AutoencoderKLQwenImage
from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer
from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor
from tqdm import tqdm
if TYPE_CHECKING:
@@ -43,7 +43,8 @@ scheduler_config = {
class QwenImageModel(BaseModel):
arch = "qwen_image"
_qwen_image_keep_processor = False
_qwen_image_keep_visual = False
_qwen_pipeline = QwenImagePipeline
def __init__(
self,
@@ -119,10 +120,9 @@ class QwenImageModel(BaseModel):
# remove the visual model as it is not needed for image generation
self.processor = None
if self._qwen_image_keep_processor:
self.processor = text_encoder.model.visual
text_encoder.model.visual = None
if not self._qwen_image_keep_visual:
text_encoder.model.visual = None
text_encoder.to(self.device_torch, dtype=dtype)
flush()
@@ -140,13 +140,27 @@ class QwenImageModel(BaseModel):
self.noise_scheduler = QwenImageModel.get_train_scheduler()
self.print_and_status_update("Making pipe")
kwargs = {}
if self._qwen_image_keep_visual:
try:
self.processor = Qwen2VLProcessor.from_pretrained(
model_path, subfolder="processor"
)
except OSError:
self.processor = Qwen2VLProcessor.from_pretrained(
base_model_path, subfolder="processor"
)
kwargs['processor'] = self.processor
pipe: QwenImagePipeline = QwenImagePipeline(
pipe: QwenImagePipeline = self._qwen_pipeline(
scheduler=self.noise_scheduler,
text_encoder=None,
tokenizer=tokenizer,
vae=vae,
transformer=None,
**kwargs
)
# for quantization, it works best to do these after making the pipe
pipe.text_encoder = text_encoder
@@ -261,21 +275,13 @@ class QwenImageModel(BaseModel):
latent_model_input = latent_model_input.permute(0, 2, 4, 1, 3, 5)
latent_model_input = latent_model_input.reshape(batch_size, (height // ps) * (width // ps), num_channels_latents * (ps * ps))
# clamp text length to RoPE capacity for this image size
# img_shapes passed to the model
img_h2, img_w2 = height // ps, width // ps
img_shapes = [(1, img_h2, img_w2)] * batch_size
img_shapes = [[(1, img_h2, img_w2)]] * batch_size
# QwenEmbedRope logic:
max_vid_index = max(img_h2 // ps, img_w2 // ps)
rope_cap = 1024 - max_vid_index # available text positions in RoPE cache
seq_len_actual = text_embeddings.text_embeds.shape[1]
use_len = min(seq_len_actual, rope_cap)
enc_hs = text_embeddings.text_embeds[:, :use_len].to(self.device_torch, self.torch_dtype)
prompt_embeds_mask = text_embeddings.attention_mask.to(self.device_torch, dtype=torch.int64)[:, :use_len]
txt_seq_lens = [use_len] * batch_size
enc_hs = text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype)
prompt_embeds_mask = text_embeddings.attention_mask.to(self.device_torch, dtype=torch.int64)
txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist()
noise_pred = self.transformer(
hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype),

View File

@@ -0,0 +1,276 @@
import math
import torch
from .qwen_image import QwenImageModel
import os
from typing import TYPE_CHECKING, List, Optional
import yaml
from toolkit import train_tools
from toolkit.config_modules import GenerateImageConfig, ModelConfig
from PIL import Image
from toolkit.models.base_model import BaseModel
from toolkit.basic import flush
from toolkit.prompt_utils import PromptEmbeds
from toolkit.samplers.custom_flowmatch_sampler import (
CustomFlowMatchEulerDiscreteScheduler,
)
from toolkit.accelerator import get_accelerator, unwrap_model
from optimum.quanto import freeze, QTensor
from toolkit.util.quantize import quantize, get_qtype, quantize_model
import torch.nn.functional as F
from diffusers import (
QwenImagePipeline,
QwenImageTransformer2DModel,
AutoencoderKLQwenImage,
)
from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer
from tqdm import tqdm
if TYPE_CHECKING:
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
try:
from diffusers import QwenImageEditPipeline
except ImportError:
raise ImportError(
"QwenImageEditPipeline not found. Update diffusers to the latest version by doing pip uninstall diffusers and then pip install -r requirements.txt"
)
class QwenImageEditModel(QwenImageModel):
arch = "qwen_image_edit"
_qwen_image_keep_visual = True
_qwen_pipeline = QwenImageEditPipeline
def __init__(
self,
device,
model_config: ModelConfig,
dtype="bf16",
custom_pipeline=None,
noise_scheduler=None,
**kwargs,
):
super().__init__(
device, model_config, dtype, custom_pipeline, noise_scheduler, **kwargs
)
self.is_flow_matching = True
self.is_transformer = True
self.target_lora_modules = ["QwenImageTransformer2DModel"]
# set true for models that encode control image into text embeddings
self.encode_control_in_text_embeddings = True
def load_model(self):
super().load_model()
def get_generation_pipeline(self):
scheduler = QwenImageModel.get_train_scheduler()
pipeline: QwenImageEditPipeline = QwenImageEditPipeline(
scheduler=scheduler,
text_encoder=unwrap_model(self.text_encoder[0]),
tokenizer=self.tokenizer[0],
processor=self.processor,
vae=unwrap_model(self.vae),
transformer=unwrap_model(self.transformer),
)
pipeline = pipeline.to(self.device_torch)
return pipeline
def generate_single_image(
self,
pipeline: QwenImageEditPipeline,
gen_config: GenerateImageConfig,
conditional_embeds: PromptEmbeds,
unconditional_embeds: PromptEmbeds,
generator: torch.Generator,
extra: dict,
):
self.model.to(self.device_torch, dtype=self.torch_dtype)
sc = self.get_bucket_divisibility()
gen_config.width = int(gen_config.width // sc * sc)
gen_config.height = int(gen_config.height // sc * sc)
control_img = None
if gen_config.ctrl_img is not None:
control_img = Image.open(gen_config.ctrl_img)
control_img = control_img.convert("RGB")
# resize to width and height
if control_img.size != (gen_config.width, gen_config.height):
control_img = control_img.resize(
(gen_config.width, gen_config.height), Image.BILINEAR
)
# flush for low vram if we are doing that
flush_between_steps = self.model_config.low_vram
# Fix a bug in diffusers/torch
def callback_on_step_end(pipe, i, t, callback_kwargs):
if flush_between_steps:
flush()
latents = callback_kwargs["latents"]
return {"latents": latents}
img = pipeline(
image=control_img,
prompt_embeds=conditional_embeds.text_embeds,
prompt_embeds_mask=conditional_embeds.attention_mask.to(
self.device_torch, dtype=torch.int64
),
negative_prompt_embeds=unconditional_embeds.text_embeds,
negative_prompt_embeds_mask=unconditional_embeds.attention_mask.to(
self.device_torch, dtype=torch.int64
),
height=gen_config.height,
width=gen_config.width,
num_inference_steps=gen_config.num_inference_steps,
true_cfg_scale=gen_config.guidance_scale,
latents=gen_config.latents,
generator=generator,
callback_on_step_end=callback_on_step_end,
**extra,
).images[0]
return img
def condition_noisy_latents(
self, latents: torch.Tensor, batch: "DataLoaderBatchDTO"
):
with torch.no_grad():
control_tensor = batch.control_tensor
if control_tensor is not None:
self.vae.to(self.device_torch)
# we are not packed here, so we just need to pass them so we can pack them later
control_tensor = control_tensor * 2 - 1
control_tensor = control_tensor.to(
self.vae_device_torch, dtype=self.torch_dtype
)
# if it is not the size of batch.tensor, (bs,ch,h,w) then we need to resize it
if batch.tensor is not None:
target_h, target_w = batch.tensor.shape[2], batch.tensor.shape[3]
else:
# When caching latents, batch.tensor is None. We get the size from the file_items instead.
target_h = batch.file_items[0].crop_height
target_w = batch.file_items[0].crop_width
if (
control_tensor.shape[2] != target_h
or control_tensor.shape[3] != target_w
):
control_tensor = F.interpolate(
control_tensor, size=(target_h, target_w), mode="bilinear"
)
control_latent = self.encode_images(control_tensor).to(
latents.device, latents.dtype
)
latents = torch.cat((latents, control_latent), dim=1)
return latents.detach()
def get_prompt_embeds(self, prompt: str, control_images=None) -> PromptEmbeds:
if self.pipeline.text_encoder.device != self.device_torch:
self.pipeline.text_encoder.to(self.device_torch)
if control_images is not None:
# control images are 0 - 1 scale, shape (bs, ch, height, width)
# images are always run through at 1MP, based on diffusers inference code.
target_area = 1024 * 1024
ratio = control_images.shape[2] / control_images.shape[3]
width = math.sqrt(target_area * ratio)
height = width / ratio
width = round(width / 32) * 32
height = round(height / 32) * 32
control_images = F.interpolate(
control_images, size=(height, width), mode="bilinear"
)
prompt_embeds, prompt_embeds_mask = self.pipeline.encode_prompt(
prompt,
image=control_images,
device=self.device_torch,
num_images_per_prompt=1,
)
pe = PromptEmbeds(prompt_embeds)
pe.attention_mask = prompt_embeds_mask
return pe
def get_noise_prediction(
self,
latent_model_input: torch.Tensor,
timestep: torch.Tensor, # 0 to 1000 scale
text_embeddings: PromptEmbeds,
**kwargs,
):
# control is stacked on channels, move it to the batch dimension for packing
latent_model_input, control = torch.chunk(latent_model_input, 2, 1)
batch_size, num_channels_latents, height, width = latent_model_input.shape
(
control_batch_size,
control_num_channels_latents,
control_height,
control_width,
) = control.shape
# pack image tokens
latent_model_input = latent_model_input.view(
batch_size, num_channels_latents, height // 2, 2, width // 2, 2
)
latent_model_input = latent_model_input.permute(0, 2, 4, 1, 3, 5)
latent_model_input = latent_model_input.reshape(
batch_size, (height // 2) * (width // 2), num_channels_latents * 4
)
# pack control
control = control.view(
batch_size, num_channels_latents, height // 2, 2, width // 2, 2
)
control = control.permute(0, 2, 4, 1, 3, 5)
control = control.reshape(
batch_size, (height // 2) * (width // 2), num_channels_latents * 4
)
img_h2, img_w2 = height // 2, width // 2
control_img_h2, control_img_w2 = control_height // 2, control_width // 2
img_shapes = [[(1, img_h2, img_w2), (1, control_img_h2, control_img_w2)]] * batch_size
latents = latent_model_input
latent_model_input = torch.cat([latent_model_input, control], dim=1)
batch_size = latent_model_input.shape[0]
prompt_embeds_mask = text_embeddings.attention_mask.to(
self.device_torch, dtype=torch.int64
)
txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist()
enc_hs = text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype)
prompt_embeds_mask = text_embeddings.attention_mask.to(self.device_torch, dtype=torch.int64)
noise_pred = self.transformer(
hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype),
timestep=timestep / 1000,
guidance=None,
encoder_hidden_states=enc_hs,
encoder_hidden_states_mask=prompt_embeds_mask,
img_shapes=img_shapes,
txt_seq_lens=txt_seq_lens,
return_dict=False,
**kwargs,
)[0]
noise_pred = noise_pred[:, : latents.size(1)]
# unpack
noise_pred = noise_pred.view(
batch_size, height // 2, width // 2, num_channels_latents, 2, 2
)
noise_pred = noise_pred.permute(0, 3, 1, 4, 2, 5)
noise_pred = noise_pred.reshape(batch_size, num_channels_latents, height, width)
return noise_pred

View File

@@ -37,6 +37,8 @@ from toolkit.models.diffusion_feature_extraction import DiffusionFeatureExtracto
from toolkit.util.wavelet_loss import wavelet_loss
import torch.nn.functional as F
from toolkit.unloader import unload_text_encoder
from PIL import Image
from torchvision.transforms import functional as TF
def flush():
@@ -127,9 +129,28 @@ class SDTrainer(BaseSDTrainProcess):
prompt=prompt, # it will autoparse the prompt
negative_prompt=sample_item.neg,
output_path=output_path,
ctrl_img=sample_item.ctrl_img
)
positive = self.sd.encode_prompt(gen_img_config.prompt).to('cpu')
negative = self.sd.encode_prompt(gen_img_config.negative_prompt).to('cpu')
# see if we need to encode the control images
if self.sd.encode_control_in_text_embeddings and gen_img_config.ctrl_img is not None:
ctrl_img = Image.open(gen_img_config.ctrl_img).convert("RGB")
# convert to 0 to 1 tensor
ctrl_img = (
TF.to_tensor(ctrl_img)
.unsqueeze(0)
.to(self.sd.device_torch, dtype=self.sd.torch_dtype)
)
positive = self.sd.encode_prompt(
gen_img_config.prompt,
control_images=ctrl_img
).to('cpu')
negative = self.sd.encode_prompt(
gen_img_config.negative_prompt,
control_images=ctrl_img
).to('cpu')
else:
positive = self.sd.encode_prompt(gen_img_config.prompt).to('cpu')
negative = self.sd.encode_prompt(gen_img_config.negative_prompt).to('cpu')
self.sd.sample_prompts_cache.append({
'conditional': positive,
@@ -177,9 +198,15 @@ class SDTrainer(BaseSDTrainProcess):
# cache unconditional embeds (blank prompt)
with torch.no_grad():
kwargs = {}
if self.sd.encode_control_in_text_embeddings:
# just do a blank image for unconditionals
control_image = torch.zeros((1, 3, 224, 224), device=self.sd.device_torch, dtype=self.sd.torch_dtype)
kwargs['control_images'] = control_image
self.unconditional_embeds = self.sd.encode_prompt(
[self.train_config.unconditional_prompt],
long_prompts=self.do_long_prompts
long_prompts=self.do_long_prompts,
**kwargs
).to(
self.device_torch,
dtype=self.sd.torch_dtype
@@ -241,9 +268,14 @@ class SDTrainer(BaseSDTrainProcess):
print_acc("***********************************")
print_acc("")
self.sd.text_encoder_to(self.device_torch)
self.cached_blank_embeds = self.sd.encode_prompt("")
encode_kwargs = {}
if self.sd.encode_control_in_text_embeddings:
# just do a blank image for unconditionals
control_image = torch.zeros((1, 3, 224, 224), device=self.sd.device_torch, dtype=self.sd.torch_dtype)
encode_kwargs['control_images'] = control_image
self.cached_blank_embeds = self.sd.encode_prompt("", **encode_kwargs)
if self.trigger_word is not None:
self.cached_trigger_embeds = self.sd.encode_prompt(self.trigger_word)
self.cached_trigger_embeds = self.sd.encode_prompt(self.trigger_word, **encode_kwargs)
if self.train_config.diff_output_preservation:
self.diff_output_preservation_embeds = self.sd.encode_prompt(self.train_config.diff_output_preservation_class)
@@ -967,11 +999,16 @@ class SDTrainer(BaseSDTrainProcess):
if batch.prompt_embeds is not None:
embeds_to_use = batch.prompt_embeds.clone().to(self.device_torch, dtype=dtype)
else:
prompt_kwargs = {}
if self.sd.encode_control_in_text_embeddings and batch.control_tensor is not None:
prompt_kwargs['control_images'] = batch.control_tensor.to(self.sd.device_torch, dtype=self.sd.torch_dtype)
embeds_to_use = self.sd.encode_prompt(
prompt_list,
long_prompts=self.do_long_prompts).to(
self.device_torch,
dtype=dtype).detach()
dtype=dtype,
**prompt_kwargs
).detach()
# dont use network on this
# self.network.multiplier = 0.0
@@ -1338,6 +1375,9 @@ class SDTrainer(BaseSDTrainProcess):
with self.timer('encode_prompt'):
unconditional_embeds = None
prompt_kwargs = {}
if self.sd.encode_control_in_text_embeddings and batch.control_tensor is not None:
prompt_kwargs['control_images'] = batch.control_tensor.to(self.sd.device_torch, dtype=self.sd.torch_dtype)
if self.train_config.unload_text_encoder or self.is_caching_text_embeddings:
with torch.set_grad_enabled(False):
if batch.prompt_embeds is not None:
@@ -1374,7 +1414,9 @@ class SDTrainer(BaseSDTrainProcess):
conditional_embeds = self.sd.encode_prompt(
conditioned_prompts, prompt_2,
dropout_prob=self.train_config.prompt_dropout_prob,
long_prompts=self.do_long_prompts).to(
long_prompts=self.do_long_prompts,
**prompt_kwargs
).to(
self.device_torch,
dtype=dtype)
@@ -1386,7 +1428,9 @@ class SDTrainer(BaseSDTrainProcess):
self.batch_negative_prompt,
self.batch_negative_prompt,
dropout_prob=self.train_config.prompt_dropout_prob,
long_prompts=self.do_long_prompts).to(
long_prompts=self.do_long_prompts,
**prompt_kwargs
).to(
self.device_torch,
dtype=dtype)
if isinstance(self.adapter, CustomAdapter):
@@ -1404,7 +1448,9 @@ class SDTrainer(BaseSDTrainProcess):
conditional_embeds = self.sd.encode_prompt(
conditioned_prompts, prompt_2,
dropout_prob=self.train_config.prompt_dropout_prob,
long_prompts=self.do_long_prompts).to(
long_prompts=self.do_long_prompts,
**prompt_kwargs
).to(
self.device_torch,
dtype=dtype)
if self.train_config.do_cfg:
@@ -1413,7 +1459,9 @@ class SDTrainer(BaseSDTrainProcess):
unconditional_embeds = self.sd.encode_prompt(
self.batch_negative_prompt,
dropout_prob=self.train_config.prompt_dropout_prob,
long_prompts=self.do_long_prompts).to(
long_prompts=self.do_long_prompts,
**prompt_kwargs
).to(
self.device_torch,
dtype=dtype)
if isinstance(self.adapter, CustomAdapter):
@@ -1427,7 +1475,9 @@ class SDTrainer(BaseSDTrainProcess):
self.diff_output_preservation_embeds = self.sd.encode_prompt(
dop_prompts, dop_prompts_2,
dropout_prob=self.train_config.prompt_dropout_prob,
long_prompts=self.do_long_prompts).to(
long_prompts=self.do_long_prompts,
**prompt_kwargs
).to(
self.device_torch,
dtype=dtype)
# detach the embeddings

View File

@@ -1,7 +1,7 @@
torchao==0.10.0
safetensors
git+https://github.com/jaretburkett/easy_dwpose.git
git+https://github.com/huggingface/diffusers@7ea065c5070a5278259e6f1effa9dccea232e62a
git+https://github.com/huggingface/diffusers@7a2b78bf0f788d311cc96b61e660a8e13e3b1e63
transformers==4.52.4
lycoris-lora==1.8.3
flatten_json

View File

@@ -1232,5 +1232,10 @@ def validate_configs(
for dataset in dataset_configs:
if not dataset.cache_text_embeddings:
raise ValueError("All datasets must have cache_text_embeddings set to True when caching text embeddings is enabled.")
# qwen image edit cannot cache text embeddings
if model_config.arch == 'qwen_image_edit':
if train_config.unload_text_encoder:
raise ValueError("Cannot cache unload text encoder with qwen_image_edit model. Control images are encoded with text embeddings. You can cache the text embeddings though")

View File

@@ -497,6 +497,7 @@ class AiToolkitDataset(LatentCachingMixin, ControlCachingMixin, CLIPCachingMixin
dataloader_transforms=self.transform,
size_database=self.size_database,
dataset_root=dataset_folder,
encode_control_in_text_embeddings=self.sd.encode_control_in_text_embeddings if self.sd else False,
)
self.file_list.append(file_item)
except Exception as e:

View File

@@ -50,6 +50,7 @@ class FileItemDTO(
self.is_video = self.dataset_config.num_frames > 1
size_database = kwargs.get('size_database', {})
dataset_root = kwargs.get('dataset_root', None)
self.encode_control_in_text_embeddings = kwargs.get('encode_control_in_text_embeddings', False)
if dataset_root is not None:
# remove dataset root from path
file_key = self.path.replace(dataset_root, '')

View File

@@ -30,6 +30,7 @@ import albumentations as A
from toolkit.print import print_acc
from toolkit.accelerator import get_accelerator
from toolkit.prompt_utils import PromptEmbeds
from torchvision.transforms import functional as TF
from toolkit.train_tools import get_torch_dtype
@@ -1802,6 +1803,9 @@ class TextEmbeddingFileItemDTOMixin:
("text_embedding_space_version", self.text_embedding_space_version),
("text_embedding_version", self.text_embedding_version),
])
# if we have a control image, cache the path
if self.encode_control_in_text_embeddings and self.control_path is not None:
item["control_path"] = self.control_path
return item
def get_text_embedding_path(self: 'FileItemDTO', recalculate=False):
@@ -1860,7 +1864,19 @@ class TextEmbeddingCachingMixin:
if not did_move:
self.sd.set_device_state_preset('cache_text_encoder')
did_move = True
prompt_embeds: PromptEmbeds = self.sd.encode_prompt(file_item.caption)
if file_item.encode_control_in_text_embeddings and file_item.control_path is not None:
# load the control image and feed it into the text encoder
ctrl_img = Image.open(file_item.control_path).convert("RGB")
# convert to 0 to 1 tensor
ctrl_img = (
TF.to_tensor(ctrl_img)
.unsqueeze(0)
.to(self.sd.device_torch, dtype=self.sd.torch_dtype)
)
prompt_embeds: PromptEmbeds = self.sd.encode_prompt(file_item.caption, control_images=ctrl_img)
else:
prompt_embeds: PromptEmbeds = self.sd.encode_prompt(file_item.caption)
# save it
prompt_embeds.save(text_embedding_path)
del prompt_embeds

View File

@@ -36,6 +36,7 @@ from diffusers import \
UNet2DConditionModel
from diffusers import PixArtAlphaPipeline
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
from torchvision.transforms import functional as TF
from toolkit.accelerator import get_accelerator, unwrap_model
from typing import TYPE_CHECKING
@@ -177,6 +178,9 @@ class BaseModel:
self.multistage_boundaries: List[float] = [0.0]
# a list of trainable multistage boundaries
self.trainable_multistage_boundaries: List[int] = [0]
# set true for models that encode control image into text embeddings
self.encode_control_in_text_embeddings = False
# properties for old arch for backwards compatibility
@property
@@ -287,7 +291,7 @@ class BaseModel:
raise NotImplementedError(
"get_noise_prediction must be implemented in child classes")
def get_prompt_embeds(self, prompt: str) -> PromptEmbeds:
def get_prompt_embeds(self, prompt: str, control_images=None) -> PromptEmbeds:
raise NotImplementedError(
"get_prompt_embeds must be implemented in child classes")
@@ -496,17 +500,34 @@ class BaseModel:
if self.sample_prompts_cache is not None:
conditional_embeds = self.sample_prompts_cache[i]['conditional'].to(self.device_torch, dtype=self.torch_dtype)
unconditional_embeds = self.sample_prompts_cache[i]['unconditional'].to(self.device_torch, dtype=self.torch_dtype)
else:
else:
ctrl_img = None
# load the control image if out model uses it in text encoding
if gen_config.ctrl_img is not None and self.encode_control_in_text_embeddings:
ctrl_img = Image.open(gen_config.ctrl_img).convert("RGB")
# convert to 0 to 1 tensor
ctrl_img = (
TF.to_tensor(ctrl_img)
.unsqueeze(0)
.to(self.device_torch, dtype=self.torch_dtype)
)
# encode the prompt ourselves so we can do fun stuff with embeddings
if isinstance(self.adapter, CustomAdapter):
self.adapter.is_unconditional_run = False
conditional_embeds = self.encode_prompt(
gen_config.prompt, gen_config.prompt_2, force_all=True)
gen_config.prompt,
gen_config.prompt_2,
force_all=True,
control_images=ctrl_img
)
if isinstance(self.adapter, CustomAdapter):
self.adapter.is_unconditional_run = True
unconditional_embeds = self.encode_prompt(
gen_config.negative_prompt, gen_config.negative_prompt_2, force_all=True
gen_config.negative_prompt,
gen_config.negative_prompt_2,
force_all=True,
control_images=ctrl_img
)
if isinstance(self.adapter, CustomAdapter):
self.adapter.is_unconditional_run = False
@@ -989,6 +1010,7 @@ class BaseModel:
long_prompts=False,
max_length=None,
dropout_prob=0.0,
control_images=None,
) -> PromptEmbeds:
# sd1.5 embeddings are (bs, 77, 768)
prompt = prompt
@@ -998,6 +1020,9 @@ class BaseModel:
if prompt2 is not None and not isinstance(prompt2, list):
prompt2 = [prompt2]
# if control_images in the signature, pass it. This keep from breaking plugins
if self.encode_control_in_text_embeddings:
return self.get_prompt_embeds(prompt, control_images=control_images)
return self.get_prompt_embeds(prompt)

View File

@@ -217,6 +217,9 @@ class StableDiffusion:
# a list of trainable multistage boundaries
self.trainable_multistage_boundaries: List[int] = [0]
# set true for models that encode control image into text embeddings
self.encode_control_in_text_embeddings = False
# properties for old arch for backwards compatibility
@property
def is_xl(self):
@@ -2356,6 +2359,7 @@ class StableDiffusion:
long_prompts=False,
max_length=None,
dropout_prob=0.0,
control_images=None,
) -> PromptEmbeds:
# sd1.5 embeddings are (bs, 77, 768)
prompt = prompt

View File

@@ -10,7 +10,7 @@ type AdditionalSections =
| 'datasets.num_frames'
| 'model.multistage'
| 'model.low_vram';
type ModelGroup = 'image' | 'video';
type ModelGroup = 'image' | 'instruction' | 'video';
export interface ModelArch {
name: string;
@@ -44,7 +44,7 @@ export const modelArchs: ModelArch[] = [
{
name: 'flux_kontext',
label: 'FLUX.1-Kontext-dev',
group: 'image',
group: 'instruction',
defaults: {
// default updates when [selected, unselected] in the UI
'config.process[0].model.name_or_path': ['black-forest-labs/FLUX.1-Kontext-dev', defaultNameOrPath],
@@ -306,6 +306,27 @@ export const modelArchs: ModelArch[] = [
'3 bit with ARA': 'uint3|ostris/accuracy_recovery_adapters/qwen_image_torchao_uint3.safetensors',
},
},
{
name: 'qwen_image_edit',
label: 'Qwen-Image-Edit',
group: 'instruction',
defaults: {
// default updates when [selected, unselected] in the UI
'config.process[0].model.name_or_path': ['Qwen/Qwen-Image-Edit', defaultNameOrPath],
'config.process[0].model.quantize': [true, false],
'config.process[0].model.quantize_te': [true, false],
'config.process[0].model.low_vram': [true, false],
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
'config.process[0].train.timestep_type': ['weighted', 'sigmoid'],
'config.process[0].model.qtype': ['qfloat8', 'qfloat8'],
},
disableSections: ['network.conv'],
additionalSections: ['datasets.control_path', 'sample.ctrl_img', 'model.low_vram'],
accuracyRecoveryAdapters: {
'3 bit with ARA': 'uint3|ostris/accuracy_recovery_adapters/qwen_image_edit_torchao_uint3.safetensors',
},
},
{
name: 'hidream',
label: 'HiDream',
@@ -327,7 +348,7 @@ export const modelArchs: ModelArch[] = [
{
name: 'hidream_e1',
label: 'HiDream E1',
group: 'image',
group: 'instruction',
defaults: {
// default updates when [selected, unselected] in the UI
'config.process[0].model.name_or_path': ['HiDream-ai/HiDream-E1-1', defaultNameOrPath],

View File

@@ -1 +1 @@
VERSION = "0.5.3"
VERSION = "0.5.4"