mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Merge pull request #383 from ostris/qwen_image_edit
Add support for Qwen-Image-Edit
This commit is contained in:
@@ -4,7 +4,7 @@ from .f_light import FLiteModel
|
||||
from .omnigen2 import OmniGen2Model
|
||||
from .flux_kontext import FluxKontextModel
|
||||
from .wan22 import Wan225bModel, Wan2214bModel, Wan2214bI2VModel
|
||||
from .qwen_image import QwenImageModel
|
||||
from .qwen_image import QwenImageModel, QwenImageEditModel
|
||||
|
||||
AI_TOOLKIT_MODELS = [
|
||||
# put a list of models here
|
||||
@@ -18,4 +18,5 @@ AI_TOOLKIT_MODELS = [
|
||||
Wan2214bI2VModel,
|
||||
Wan2214bModel,
|
||||
QwenImageModel,
|
||||
QwenImageEditModel,
|
||||
]
|
||||
|
||||
@@ -1 +1,2 @@
|
||||
from .qwen_image import QwenImageModel
|
||||
from .qwen_image_edit import QwenImageEditModel
|
||||
@@ -16,7 +16,7 @@ from toolkit.util.quantize import quantize, get_qtype, quantize_model
|
||||
import torch.nn.functional as F
|
||||
|
||||
from diffusers import QwenImagePipeline, QwenImageTransformer2DModel, AutoencoderKLQwenImage
|
||||
from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer
|
||||
from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor
|
||||
from tqdm import tqdm
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -43,7 +43,8 @@ scheduler_config = {
|
||||
|
||||
class QwenImageModel(BaseModel):
|
||||
arch = "qwen_image"
|
||||
_qwen_image_keep_processor = False
|
||||
_qwen_image_keep_visual = False
|
||||
_qwen_pipeline = QwenImagePipeline
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -119,10 +120,9 @@ class QwenImageModel(BaseModel):
|
||||
|
||||
# remove the visual model as it is not needed for image generation
|
||||
self.processor = None
|
||||
if self._qwen_image_keep_processor:
|
||||
self.processor = text_encoder.model.visual
|
||||
text_encoder.model.visual = None
|
||||
|
||||
if not self._qwen_image_keep_visual:
|
||||
text_encoder.model.visual = None
|
||||
|
||||
text_encoder.to(self.device_torch, dtype=dtype)
|
||||
flush()
|
||||
|
||||
@@ -140,13 +140,27 @@ class QwenImageModel(BaseModel):
|
||||
self.noise_scheduler = QwenImageModel.get_train_scheduler()
|
||||
|
||||
self.print_and_status_update("Making pipe")
|
||||
|
||||
kwargs = {}
|
||||
|
||||
if self._qwen_image_keep_visual:
|
||||
try:
|
||||
self.processor = Qwen2VLProcessor.from_pretrained(
|
||||
model_path, subfolder="processor"
|
||||
)
|
||||
except OSError:
|
||||
self.processor = Qwen2VLProcessor.from_pretrained(
|
||||
base_model_path, subfolder="processor"
|
||||
)
|
||||
kwargs['processor'] = self.processor
|
||||
|
||||
pipe: QwenImagePipeline = QwenImagePipeline(
|
||||
pipe: QwenImagePipeline = self._qwen_pipeline(
|
||||
scheduler=self.noise_scheduler,
|
||||
text_encoder=None,
|
||||
tokenizer=tokenizer,
|
||||
vae=vae,
|
||||
transformer=None,
|
||||
**kwargs
|
||||
)
|
||||
# for quantization, it works best to do these after making the pipe
|
||||
pipe.text_encoder = text_encoder
|
||||
@@ -261,21 +275,13 @@ class QwenImageModel(BaseModel):
|
||||
latent_model_input = latent_model_input.permute(0, 2, 4, 1, 3, 5)
|
||||
latent_model_input = latent_model_input.reshape(batch_size, (height // ps) * (width // ps), num_channels_latents * (ps * ps))
|
||||
|
||||
# clamp text length to RoPE capacity for this image size
|
||||
# img_shapes passed to the model
|
||||
img_h2, img_w2 = height // ps, width // ps
|
||||
img_shapes = [(1, img_h2, img_w2)] * batch_size
|
||||
img_shapes = [[(1, img_h2, img_w2)]] * batch_size
|
||||
|
||||
# QwenEmbedRope logic:
|
||||
max_vid_index = max(img_h2 // ps, img_w2 // ps)
|
||||
|
||||
rope_cap = 1024 - max_vid_index # available text positions in RoPE cache
|
||||
seq_len_actual = text_embeddings.text_embeds.shape[1]
|
||||
use_len = min(seq_len_actual, rope_cap)
|
||||
|
||||
enc_hs = text_embeddings.text_embeds[:, :use_len].to(self.device_torch, self.torch_dtype)
|
||||
prompt_embeds_mask = text_embeddings.attention_mask.to(self.device_torch, dtype=torch.int64)[:, :use_len]
|
||||
txt_seq_lens = [use_len] * batch_size
|
||||
enc_hs = text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype)
|
||||
prompt_embeds_mask = text_embeddings.attention_mask.to(self.device_torch, dtype=torch.int64)
|
||||
txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist()
|
||||
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype),
|
||||
|
||||
@@ -0,0 +1,276 @@
|
||||
import math
|
||||
import torch
|
||||
from .qwen_image import QwenImageModel
|
||||
import os
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
import yaml
|
||||
from toolkit import train_tools
|
||||
from toolkit.config_modules import GenerateImageConfig, ModelConfig
|
||||
from PIL import Image
|
||||
from toolkit.models.base_model import BaseModel
|
||||
from toolkit.basic import flush
|
||||
from toolkit.prompt_utils import PromptEmbeds
|
||||
from toolkit.samplers.custom_flowmatch_sampler import (
|
||||
CustomFlowMatchEulerDiscreteScheduler,
|
||||
)
|
||||
from toolkit.accelerator import get_accelerator, unwrap_model
|
||||
from optimum.quanto import freeze, QTensor
|
||||
from toolkit.util.quantize import quantize, get_qtype, quantize_model
|
||||
import torch.nn.functional as F
|
||||
|
||||
from diffusers import (
|
||||
QwenImagePipeline,
|
||||
QwenImageTransformer2DModel,
|
||||
AutoencoderKLQwenImage,
|
||||
)
|
||||
from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer
|
||||
from tqdm import tqdm
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
|
||||
|
||||
try:
|
||||
from diffusers import QwenImageEditPipeline
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"QwenImageEditPipeline not found. Update diffusers to the latest version by doing pip uninstall diffusers and then pip install -r requirements.txt"
|
||||
)
|
||||
|
||||
|
||||
class QwenImageEditModel(QwenImageModel):
|
||||
arch = "qwen_image_edit"
|
||||
_qwen_image_keep_visual = True
|
||||
_qwen_pipeline = QwenImageEditPipeline
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
device,
|
||||
model_config: ModelConfig,
|
||||
dtype="bf16",
|
||||
custom_pipeline=None,
|
||||
noise_scheduler=None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
device, model_config, dtype, custom_pipeline, noise_scheduler, **kwargs
|
||||
)
|
||||
self.is_flow_matching = True
|
||||
self.is_transformer = True
|
||||
self.target_lora_modules = ["QwenImageTransformer2DModel"]
|
||||
|
||||
# set true for models that encode control image into text embeddings
|
||||
self.encode_control_in_text_embeddings = True
|
||||
|
||||
def load_model(self):
|
||||
super().load_model()
|
||||
|
||||
def get_generation_pipeline(self):
|
||||
scheduler = QwenImageModel.get_train_scheduler()
|
||||
|
||||
pipeline: QwenImageEditPipeline = QwenImageEditPipeline(
|
||||
scheduler=scheduler,
|
||||
text_encoder=unwrap_model(self.text_encoder[0]),
|
||||
tokenizer=self.tokenizer[0],
|
||||
processor=self.processor,
|
||||
vae=unwrap_model(self.vae),
|
||||
transformer=unwrap_model(self.transformer),
|
||||
)
|
||||
|
||||
pipeline = pipeline.to(self.device_torch)
|
||||
|
||||
return pipeline
|
||||
|
||||
def generate_single_image(
|
||||
self,
|
||||
pipeline: QwenImageEditPipeline,
|
||||
gen_config: GenerateImageConfig,
|
||||
conditional_embeds: PromptEmbeds,
|
||||
unconditional_embeds: PromptEmbeds,
|
||||
generator: torch.Generator,
|
||||
extra: dict,
|
||||
):
|
||||
self.model.to(self.device_torch, dtype=self.torch_dtype)
|
||||
sc = self.get_bucket_divisibility()
|
||||
gen_config.width = int(gen_config.width // sc * sc)
|
||||
gen_config.height = int(gen_config.height // sc * sc)
|
||||
|
||||
control_img = None
|
||||
if gen_config.ctrl_img is not None:
|
||||
control_img = Image.open(gen_config.ctrl_img)
|
||||
control_img = control_img.convert("RGB")
|
||||
# resize to width and height
|
||||
if control_img.size != (gen_config.width, gen_config.height):
|
||||
control_img = control_img.resize(
|
||||
(gen_config.width, gen_config.height), Image.BILINEAR
|
||||
)
|
||||
|
||||
# flush for low vram if we are doing that
|
||||
flush_between_steps = self.model_config.low_vram
|
||||
|
||||
# Fix a bug in diffusers/torch
|
||||
def callback_on_step_end(pipe, i, t, callback_kwargs):
|
||||
if flush_between_steps:
|
||||
flush()
|
||||
latents = callback_kwargs["latents"]
|
||||
|
||||
return {"latents": latents}
|
||||
|
||||
img = pipeline(
|
||||
image=control_img,
|
||||
prompt_embeds=conditional_embeds.text_embeds,
|
||||
prompt_embeds_mask=conditional_embeds.attention_mask.to(
|
||||
self.device_torch, dtype=torch.int64
|
||||
),
|
||||
negative_prompt_embeds=unconditional_embeds.text_embeds,
|
||||
negative_prompt_embeds_mask=unconditional_embeds.attention_mask.to(
|
||||
self.device_torch, dtype=torch.int64
|
||||
),
|
||||
height=gen_config.height,
|
||||
width=gen_config.width,
|
||||
num_inference_steps=gen_config.num_inference_steps,
|
||||
true_cfg_scale=gen_config.guidance_scale,
|
||||
latents=gen_config.latents,
|
||||
generator=generator,
|
||||
callback_on_step_end=callback_on_step_end,
|
||||
**extra,
|
||||
).images[0]
|
||||
return img
|
||||
|
||||
def condition_noisy_latents(
|
||||
self, latents: torch.Tensor, batch: "DataLoaderBatchDTO"
|
||||
):
|
||||
with torch.no_grad():
|
||||
control_tensor = batch.control_tensor
|
||||
if control_tensor is not None:
|
||||
self.vae.to(self.device_torch)
|
||||
# we are not packed here, so we just need to pass them so we can pack them later
|
||||
control_tensor = control_tensor * 2 - 1
|
||||
control_tensor = control_tensor.to(
|
||||
self.vae_device_torch, dtype=self.torch_dtype
|
||||
)
|
||||
|
||||
# if it is not the size of batch.tensor, (bs,ch,h,w) then we need to resize it
|
||||
if batch.tensor is not None:
|
||||
target_h, target_w = batch.tensor.shape[2], batch.tensor.shape[3]
|
||||
else:
|
||||
# When caching latents, batch.tensor is None. We get the size from the file_items instead.
|
||||
target_h = batch.file_items[0].crop_height
|
||||
target_w = batch.file_items[0].crop_width
|
||||
|
||||
if (
|
||||
control_tensor.shape[2] != target_h
|
||||
or control_tensor.shape[3] != target_w
|
||||
):
|
||||
control_tensor = F.interpolate(
|
||||
control_tensor, size=(target_h, target_w), mode="bilinear"
|
||||
)
|
||||
|
||||
control_latent = self.encode_images(control_tensor).to(
|
||||
latents.device, latents.dtype
|
||||
)
|
||||
latents = torch.cat((latents, control_latent), dim=1)
|
||||
|
||||
return latents.detach()
|
||||
|
||||
def get_prompt_embeds(self, prompt: str, control_images=None) -> PromptEmbeds:
|
||||
if self.pipeline.text_encoder.device != self.device_torch:
|
||||
self.pipeline.text_encoder.to(self.device_torch)
|
||||
|
||||
if control_images is not None:
|
||||
# control images are 0 - 1 scale, shape (bs, ch, height, width)
|
||||
# images are always run through at 1MP, based on diffusers inference code.
|
||||
target_area = 1024 * 1024
|
||||
ratio = control_images.shape[2] / control_images.shape[3]
|
||||
width = math.sqrt(target_area * ratio)
|
||||
height = width / ratio
|
||||
|
||||
width = round(width / 32) * 32
|
||||
height = round(height / 32) * 32
|
||||
|
||||
control_images = F.interpolate(
|
||||
control_images, size=(height, width), mode="bilinear"
|
||||
)
|
||||
|
||||
prompt_embeds, prompt_embeds_mask = self.pipeline.encode_prompt(
|
||||
prompt,
|
||||
image=control_images,
|
||||
device=self.device_torch,
|
||||
num_images_per_prompt=1,
|
||||
)
|
||||
pe = PromptEmbeds(prompt_embeds)
|
||||
pe.attention_mask = prompt_embeds_mask
|
||||
return pe
|
||||
|
||||
def get_noise_prediction(
|
||||
self,
|
||||
latent_model_input: torch.Tensor,
|
||||
timestep: torch.Tensor, # 0 to 1000 scale
|
||||
text_embeddings: PromptEmbeds,
|
||||
**kwargs,
|
||||
):
|
||||
# control is stacked on channels, move it to the batch dimension for packing
|
||||
latent_model_input, control = torch.chunk(latent_model_input, 2, 1)
|
||||
|
||||
batch_size, num_channels_latents, height, width = latent_model_input.shape
|
||||
(
|
||||
control_batch_size,
|
||||
control_num_channels_latents,
|
||||
control_height,
|
||||
control_width,
|
||||
) = control.shape
|
||||
|
||||
# pack image tokens
|
||||
latent_model_input = latent_model_input.view(
|
||||
batch_size, num_channels_latents, height // 2, 2, width // 2, 2
|
||||
)
|
||||
latent_model_input = latent_model_input.permute(0, 2, 4, 1, 3, 5)
|
||||
latent_model_input = latent_model_input.reshape(
|
||||
batch_size, (height // 2) * (width // 2), num_channels_latents * 4
|
||||
)
|
||||
|
||||
# pack control
|
||||
control = control.view(
|
||||
batch_size, num_channels_latents, height // 2, 2, width // 2, 2
|
||||
)
|
||||
control = control.permute(0, 2, 4, 1, 3, 5)
|
||||
control = control.reshape(
|
||||
batch_size, (height // 2) * (width // 2), num_channels_latents * 4
|
||||
)
|
||||
|
||||
img_h2, img_w2 = height // 2, width // 2
|
||||
control_img_h2, control_img_w2 = control_height // 2, control_width // 2
|
||||
|
||||
img_shapes = [[(1, img_h2, img_w2), (1, control_img_h2, control_img_w2)]] * batch_size
|
||||
|
||||
latents = latent_model_input
|
||||
latent_model_input = torch.cat([latent_model_input, control], dim=1)
|
||||
batch_size = latent_model_input.shape[0]
|
||||
|
||||
prompt_embeds_mask = text_embeddings.attention_mask.to(
|
||||
self.device_torch, dtype=torch.int64
|
||||
)
|
||||
txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist()
|
||||
enc_hs = text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype)
|
||||
prompt_embeds_mask = text_embeddings.attention_mask.to(self.device_torch, dtype=torch.int64)
|
||||
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype),
|
||||
timestep=timestep / 1000,
|
||||
guidance=None,
|
||||
encoder_hidden_states=enc_hs,
|
||||
encoder_hidden_states_mask=prompt_embeds_mask,
|
||||
img_shapes=img_shapes,
|
||||
txt_seq_lens=txt_seq_lens,
|
||||
return_dict=False,
|
||||
**kwargs,
|
||||
)[0]
|
||||
|
||||
noise_pred = noise_pred[:, : latents.size(1)]
|
||||
|
||||
# unpack
|
||||
noise_pred = noise_pred.view(
|
||||
batch_size, height // 2, width // 2, num_channels_latents, 2, 2
|
||||
)
|
||||
noise_pred = noise_pred.permute(0, 3, 1, 4, 2, 5)
|
||||
noise_pred = noise_pred.reshape(batch_size, num_channels_latents, height, width)
|
||||
return noise_pred
|
||||
@@ -37,6 +37,8 @@ from toolkit.models.diffusion_feature_extraction import DiffusionFeatureExtracto
|
||||
from toolkit.util.wavelet_loss import wavelet_loss
|
||||
import torch.nn.functional as F
|
||||
from toolkit.unloader import unload_text_encoder
|
||||
from PIL import Image
|
||||
from torchvision.transforms import functional as TF
|
||||
|
||||
|
||||
def flush():
|
||||
@@ -127,9 +129,28 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
prompt=prompt, # it will autoparse the prompt
|
||||
negative_prompt=sample_item.neg,
|
||||
output_path=output_path,
|
||||
ctrl_img=sample_item.ctrl_img
|
||||
)
|
||||
positive = self.sd.encode_prompt(gen_img_config.prompt).to('cpu')
|
||||
negative = self.sd.encode_prompt(gen_img_config.negative_prompt).to('cpu')
|
||||
# see if we need to encode the control images
|
||||
if self.sd.encode_control_in_text_embeddings and gen_img_config.ctrl_img is not None:
|
||||
ctrl_img = Image.open(gen_img_config.ctrl_img).convert("RGB")
|
||||
# convert to 0 to 1 tensor
|
||||
ctrl_img = (
|
||||
TF.to_tensor(ctrl_img)
|
||||
.unsqueeze(0)
|
||||
.to(self.sd.device_torch, dtype=self.sd.torch_dtype)
|
||||
)
|
||||
positive = self.sd.encode_prompt(
|
||||
gen_img_config.prompt,
|
||||
control_images=ctrl_img
|
||||
).to('cpu')
|
||||
negative = self.sd.encode_prompt(
|
||||
gen_img_config.negative_prompt,
|
||||
control_images=ctrl_img
|
||||
).to('cpu')
|
||||
else:
|
||||
positive = self.sd.encode_prompt(gen_img_config.prompt).to('cpu')
|
||||
negative = self.sd.encode_prompt(gen_img_config.negative_prompt).to('cpu')
|
||||
|
||||
self.sd.sample_prompts_cache.append({
|
||||
'conditional': positive,
|
||||
@@ -177,9 +198,15 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
|
||||
# cache unconditional embeds (blank prompt)
|
||||
with torch.no_grad():
|
||||
kwargs = {}
|
||||
if self.sd.encode_control_in_text_embeddings:
|
||||
# just do a blank image for unconditionals
|
||||
control_image = torch.zeros((1, 3, 224, 224), device=self.sd.device_torch, dtype=self.sd.torch_dtype)
|
||||
kwargs['control_images'] = control_image
|
||||
self.unconditional_embeds = self.sd.encode_prompt(
|
||||
[self.train_config.unconditional_prompt],
|
||||
long_prompts=self.do_long_prompts
|
||||
long_prompts=self.do_long_prompts,
|
||||
**kwargs
|
||||
).to(
|
||||
self.device_torch,
|
||||
dtype=self.sd.torch_dtype
|
||||
@@ -241,9 +268,14 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
print_acc("***********************************")
|
||||
print_acc("")
|
||||
self.sd.text_encoder_to(self.device_torch)
|
||||
self.cached_blank_embeds = self.sd.encode_prompt("")
|
||||
encode_kwargs = {}
|
||||
if self.sd.encode_control_in_text_embeddings:
|
||||
# just do a blank image for unconditionals
|
||||
control_image = torch.zeros((1, 3, 224, 224), device=self.sd.device_torch, dtype=self.sd.torch_dtype)
|
||||
encode_kwargs['control_images'] = control_image
|
||||
self.cached_blank_embeds = self.sd.encode_prompt("", **encode_kwargs)
|
||||
if self.trigger_word is not None:
|
||||
self.cached_trigger_embeds = self.sd.encode_prompt(self.trigger_word)
|
||||
self.cached_trigger_embeds = self.sd.encode_prompt(self.trigger_word, **encode_kwargs)
|
||||
if self.train_config.diff_output_preservation:
|
||||
self.diff_output_preservation_embeds = self.sd.encode_prompt(self.train_config.diff_output_preservation_class)
|
||||
|
||||
@@ -967,11 +999,16 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
if batch.prompt_embeds is not None:
|
||||
embeds_to_use = batch.prompt_embeds.clone().to(self.device_torch, dtype=dtype)
|
||||
else:
|
||||
prompt_kwargs = {}
|
||||
if self.sd.encode_control_in_text_embeddings and batch.control_tensor is not None:
|
||||
prompt_kwargs['control_images'] = batch.control_tensor.to(self.sd.device_torch, dtype=self.sd.torch_dtype)
|
||||
embeds_to_use = self.sd.encode_prompt(
|
||||
prompt_list,
|
||||
long_prompts=self.do_long_prompts).to(
|
||||
self.device_torch,
|
||||
dtype=dtype).detach()
|
||||
dtype=dtype,
|
||||
**prompt_kwargs
|
||||
).detach()
|
||||
|
||||
# dont use network on this
|
||||
# self.network.multiplier = 0.0
|
||||
@@ -1338,6 +1375,9 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
|
||||
with self.timer('encode_prompt'):
|
||||
unconditional_embeds = None
|
||||
prompt_kwargs = {}
|
||||
if self.sd.encode_control_in_text_embeddings and batch.control_tensor is not None:
|
||||
prompt_kwargs['control_images'] = batch.control_tensor.to(self.sd.device_torch, dtype=self.sd.torch_dtype)
|
||||
if self.train_config.unload_text_encoder or self.is_caching_text_embeddings:
|
||||
with torch.set_grad_enabled(False):
|
||||
if batch.prompt_embeds is not None:
|
||||
@@ -1374,7 +1414,9 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
conditional_embeds = self.sd.encode_prompt(
|
||||
conditioned_prompts, prompt_2,
|
||||
dropout_prob=self.train_config.prompt_dropout_prob,
|
||||
long_prompts=self.do_long_prompts).to(
|
||||
long_prompts=self.do_long_prompts,
|
||||
**prompt_kwargs
|
||||
).to(
|
||||
self.device_torch,
|
||||
dtype=dtype)
|
||||
|
||||
@@ -1386,7 +1428,9 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
self.batch_negative_prompt,
|
||||
self.batch_negative_prompt,
|
||||
dropout_prob=self.train_config.prompt_dropout_prob,
|
||||
long_prompts=self.do_long_prompts).to(
|
||||
long_prompts=self.do_long_prompts,
|
||||
**prompt_kwargs
|
||||
).to(
|
||||
self.device_torch,
|
||||
dtype=dtype)
|
||||
if isinstance(self.adapter, CustomAdapter):
|
||||
@@ -1404,7 +1448,9 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
conditional_embeds = self.sd.encode_prompt(
|
||||
conditioned_prompts, prompt_2,
|
||||
dropout_prob=self.train_config.prompt_dropout_prob,
|
||||
long_prompts=self.do_long_prompts).to(
|
||||
long_prompts=self.do_long_prompts,
|
||||
**prompt_kwargs
|
||||
).to(
|
||||
self.device_torch,
|
||||
dtype=dtype)
|
||||
if self.train_config.do_cfg:
|
||||
@@ -1413,7 +1459,9 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
unconditional_embeds = self.sd.encode_prompt(
|
||||
self.batch_negative_prompt,
|
||||
dropout_prob=self.train_config.prompt_dropout_prob,
|
||||
long_prompts=self.do_long_prompts).to(
|
||||
long_prompts=self.do_long_prompts,
|
||||
**prompt_kwargs
|
||||
).to(
|
||||
self.device_torch,
|
||||
dtype=dtype)
|
||||
if isinstance(self.adapter, CustomAdapter):
|
||||
@@ -1427,7 +1475,9 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
self.diff_output_preservation_embeds = self.sd.encode_prompt(
|
||||
dop_prompts, dop_prompts_2,
|
||||
dropout_prob=self.train_config.prompt_dropout_prob,
|
||||
long_prompts=self.do_long_prompts).to(
|
||||
long_prompts=self.do_long_prompts,
|
||||
**prompt_kwargs
|
||||
).to(
|
||||
self.device_torch,
|
||||
dtype=dtype)
|
||||
# detach the embeddings
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
torchao==0.10.0
|
||||
safetensors
|
||||
git+https://github.com/jaretburkett/easy_dwpose.git
|
||||
git+https://github.com/huggingface/diffusers@7ea065c5070a5278259e6f1effa9dccea232e62a
|
||||
git+https://github.com/huggingface/diffusers@7a2b78bf0f788d311cc96b61e660a8e13e3b1e63
|
||||
transformers==4.52.4
|
||||
lycoris-lora==1.8.3
|
||||
flatten_json
|
||||
|
||||
@@ -1232,5 +1232,10 @@ def validate_configs(
|
||||
for dataset in dataset_configs:
|
||||
if not dataset.cache_text_embeddings:
|
||||
raise ValueError("All datasets must have cache_text_embeddings set to True when caching text embeddings is enabled.")
|
||||
|
||||
# qwen image edit cannot cache text embeddings
|
||||
if model_config.arch == 'qwen_image_edit':
|
||||
if train_config.unload_text_encoder:
|
||||
raise ValueError("Cannot cache unload text encoder with qwen_image_edit model. Control images are encoded with text embeddings. You can cache the text embeddings though")
|
||||
|
||||
|
||||
|
||||
@@ -497,6 +497,7 @@ class AiToolkitDataset(LatentCachingMixin, ControlCachingMixin, CLIPCachingMixin
|
||||
dataloader_transforms=self.transform,
|
||||
size_database=self.size_database,
|
||||
dataset_root=dataset_folder,
|
||||
encode_control_in_text_embeddings=self.sd.encode_control_in_text_embeddings if self.sd else False,
|
||||
)
|
||||
self.file_list.append(file_item)
|
||||
except Exception as e:
|
||||
|
||||
@@ -50,6 +50,7 @@ class FileItemDTO(
|
||||
self.is_video = self.dataset_config.num_frames > 1
|
||||
size_database = kwargs.get('size_database', {})
|
||||
dataset_root = kwargs.get('dataset_root', None)
|
||||
self.encode_control_in_text_embeddings = kwargs.get('encode_control_in_text_embeddings', False)
|
||||
if dataset_root is not None:
|
||||
# remove dataset root from path
|
||||
file_key = self.path.replace(dataset_root, '')
|
||||
|
||||
@@ -30,6 +30,7 @@ import albumentations as A
|
||||
from toolkit.print import print_acc
|
||||
from toolkit.accelerator import get_accelerator
|
||||
from toolkit.prompt_utils import PromptEmbeds
|
||||
from torchvision.transforms import functional as TF
|
||||
|
||||
from toolkit.train_tools import get_torch_dtype
|
||||
|
||||
@@ -1802,6 +1803,9 @@ class TextEmbeddingFileItemDTOMixin:
|
||||
("text_embedding_space_version", self.text_embedding_space_version),
|
||||
("text_embedding_version", self.text_embedding_version),
|
||||
])
|
||||
# if we have a control image, cache the path
|
||||
if self.encode_control_in_text_embeddings and self.control_path is not None:
|
||||
item["control_path"] = self.control_path
|
||||
return item
|
||||
|
||||
def get_text_embedding_path(self: 'FileItemDTO', recalculate=False):
|
||||
@@ -1860,7 +1864,19 @@ class TextEmbeddingCachingMixin:
|
||||
if not did_move:
|
||||
self.sd.set_device_state_preset('cache_text_encoder')
|
||||
did_move = True
|
||||
prompt_embeds: PromptEmbeds = self.sd.encode_prompt(file_item.caption)
|
||||
|
||||
if file_item.encode_control_in_text_embeddings and file_item.control_path is not None:
|
||||
# load the control image and feed it into the text encoder
|
||||
ctrl_img = Image.open(file_item.control_path).convert("RGB")
|
||||
# convert to 0 to 1 tensor
|
||||
ctrl_img = (
|
||||
TF.to_tensor(ctrl_img)
|
||||
.unsqueeze(0)
|
||||
.to(self.sd.device_torch, dtype=self.sd.torch_dtype)
|
||||
)
|
||||
prompt_embeds: PromptEmbeds = self.sd.encode_prompt(file_item.caption, control_images=ctrl_img)
|
||||
else:
|
||||
prompt_embeds: PromptEmbeds = self.sd.encode_prompt(file_item.caption)
|
||||
# save it
|
||||
prompt_embeds.save(text_embedding_path)
|
||||
del prompt_embeds
|
||||
|
||||
@@ -36,6 +36,7 @@ from diffusers import \
|
||||
UNet2DConditionModel
|
||||
from diffusers import PixArtAlphaPipeline
|
||||
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
|
||||
from torchvision.transforms import functional as TF
|
||||
|
||||
from toolkit.accelerator import get_accelerator, unwrap_model
|
||||
from typing import TYPE_CHECKING
|
||||
@@ -177,6 +178,9 @@ class BaseModel:
|
||||
self.multistage_boundaries: List[float] = [0.0]
|
||||
# a list of trainable multistage boundaries
|
||||
self.trainable_multistage_boundaries: List[int] = [0]
|
||||
|
||||
# set true for models that encode control image into text embeddings
|
||||
self.encode_control_in_text_embeddings = False
|
||||
|
||||
# properties for old arch for backwards compatibility
|
||||
@property
|
||||
@@ -287,7 +291,7 @@ class BaseModel:
|
||||
raise NotImplementedError(
|
||||
"get_noise_prediction must be implemented in child classes")
|
||||
|
||||
def get_prompt_embeds(self, prompt: str) -> PromptEmbeds:
|
||||
def get_prompt_embeds(self, prompt: str, control_images=None) -> PromptEmbeds:
|
||||
raise NotImplementedError(
|
||||
"get_prompt_embeds must be implemented in child classes")
|
||||
|
||||
@@ -496,17 +500,34 @@ class BaseModel:
|
||||
if self.sample_prompts_cache is not None:
|
||||
conditional_embeds = self.sample_prompts_cache[i]['conditional'].to(self.device_torch, dtype=self.torch_dtype)
|
||||
unconditional_embeds = self.sample_prompts_cache[i]['unconditional'].to(self.device_torch, dtype=self.torch_dtype)
|
||||
else:
|
||||
else:
|
||||
ctrl_img = None
|
||||
# load the control image if out model uses it in text encoding
|
||||
if gen_config.ctrl_img is not None and self.encode_control_in_text_embeddings:
|
||||
ctrl_img = Image.open(gen_config.ctrl_img).convert("RGB")
|
||||
# convert to 0 to 1 tensor
|
||||
ctrl_img = (
|
||||
TF.to_tensor(ctrl_img)
|
||||
.unsqueeze(0)
|
||||
.to(self.device_torch, dtype=self.torch_dtype)
|
||||
)
|
||||
# encode the prompt ourselves so we can do fun stuff with embeddings
|
||||
if isinstance(self.adapter, CustomAdapter):
|
||||
self.adapter.is_unconditional_run = False
|
||||
conditional_embeds = self.encode_prompt(
|
||||
gen_config.prompt, gen_config.prompt_2, force_all=True)
|
||||
gen_config.prompt,
|
||||
gen_config.prompt_2,
|
||||
force_all=True,
|
||||
control_images=ctrl_img
|
||||
)
|
||||
|
||||
if isinstance(self.adapter, CustomAdapter):
|
||||
self.adapter.is_unconditional_run = True
|
||||
unconditional_embeds = self.encode_prompt(
|
||||
gen_config.negative_prompt, gen_config.negative_prompt_2, force_all=True
|
||||
gen_config.negative_prompt,
|
||||
gen_config.negative_prompt_2,
|
||||
force_all=True,
|
||||
control_images=ctrl_img
|
||||
)
|
||||
if isinstance(self.adapter, CustomAdapter):
|
||||
self.adapter.is_unconditional_run = False
|
||||
@@ -989,6 +1010,7 @@ class BaseModel:
|
||||
long_prompts=False,
|
||||
max_length=None,
|
||||
dropout_prob=0.0,
|
||||
control_images=None,
|
||||
) -> PromptEmbeds:
|
||||
# sd1.5 embeddings are (bs, 77, 768)
|
||||
prompt = prompt
|
||||
@@ -998,6 +1020,9 @@ class BaseModel:
|
||||
|
||||
if prompt2 is not None and not isinstance(prompt2, list):
|
||||
prompt2 = [prompt2]
|
||||
# if control_images in the signature, pass it. This keep from breaking plugins
|
||||
if self.encode_control_in_text_embeddings:
|
||||
return self.get_prompt_embeds(prompt, control_images=control_images)
|
||||
|
||||
return self.get_prompt_embeds(prompt)
|
||||
|
||||
|
||||
@@ -217,6 +217,9 @@ class StableDiffusion:
|
||||
# a list of trainable multistage boundaries
|
||||
self.trainable_multistage_boundaries: List[int] = [0]
|
||||
|
||||
# set true for models that encode control image into text embeddings
|
||||
self.encode_control_in_text_embeddings = False
|
||||
|
||||
# properties for old arch for backwards compatibility
|
||||
@property
|
||||
def is_xl(self):
|
||||
@@ -2356,6 +2359,7 @@ class StableDiffusion:
|
||||
long_prompts=False,
|
||||
max_length=None,
|
||||
dropout_prob=0.0,
|
||||
control_images=None,
|
||||
) -> PromptEmbeds:
|
||||
# sd1.5 embeddings are (bs, 77, 768)
|
||||
prompt = prompt
|
||||
|
||||
@@ -10,7 +10,7 @@ type AdditionalSections =
|
||||
| 'datasets.num_frames'
|
||||
| 'model.multistage'
|
||||
| 'model.low_vram';
|
||||
type ModelGroup = 'image' | 'video';
|
||||
type ModelGroup = 'image' | 'instruction' | 'video';
|
||||
|
||||
export interface ModelArch {
|
||||
name: string;
|
||||
@@ -44,7 +44,7 @@ export const modelArchs: ModelArch[] = [
|
||||
{
|
||||
name: 'flux_kontext',
|
||||
label: 'FLUX.1-Kontext-dev',
|
||||
group: 'image',
|
||||
group: 'instruction',
|
||||
defaults: {
|
||||
// default updates when [selected, unselected] in the UI
|
||||
'config.process[0].model.name_or_path': ['black-forest-labs/FLUX.1-Kontext-dev', defaultNameOrPath],
|
||||
@@ -306,6 +306,27 @@ export const modelArchs: ModelArch[] = [
|
||||
'3 bit with ARA': 'uint3|ostris/accuracy_recovery_adapters/qwen_image_torchao_uint3.safetensors',
|
||||
},
|
||||
},
|
||||
{
|
||||
name: 'qwen_image_edit',
|
||||
label: 'Qwen-Image-Edit',
|
||||
group: 'instruction',
|
||||
defaults: {
|
||||
// default updates when [selected, unselected] in the UI
|
||||
'config.process[0].model.name_or_path': ['Qwen/Qwen-Image-Edit', defaultNameOrPath],
|
||||
'config.process[0].model.quantize': [true, false],
|
||||
'config.process[0].model.quantize_te': [true, false],
|
||||
'config.process[0].model.low_vram': [true, false],
|
||||
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
|
||||
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
|
||||
'config.process[0].train.timestep_type': ['weighted', 'sigmoid'],
|
||||
'config.process[0].model.qtype': ['qfloat8', 'qfloat8'],
|
||||
},
|
||||
disableSections: ['network.conv'],
|
||||
additionalSections: ['datasets.control_path', 'sample.ctrl_img', 'model.low_vram'],
|
||||
accuracyRecoveryAdapters: {
|
||||
'3 bit with ARA': 'uint3|ostris/accuracy_recovery_adapters/qwen_image_edit_torchao_uint3.safetensors',
|
||||
},
|
||||
},
|
||||
{
|
||||
name: 'hidream',
|
||||
label: 'HiDream',
|
||||
@@ -327,7 +348,7 @@ export const modelArchs: ModelArch[] = [
|
||||
{
|
||||
name: 'hidream_e1',
|
||||
label: 'HiDream E1',
|
||||
group: 'image',
|
||||
group: 'instruction',
|
||||
defaults: {
|
||||
// default updates when [selected, unselected] in the UI
|
||||
'config.process[0].model.name_or_path': ['HiDream-ai/HiDream-E1-1', defaultNameOrPath],
|
||||
|
||||
@@ -1 +1 @@
|
||||
VERSION = "0.5.3"
|
||||
VERSION = "0.5.4"
|
||||
Reference in New Issue
Block a user