From 454be0958a3fbf2d82a7e1ff31dfe1368ec3339c Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Wed, 24 Sep 2025 11:39:10 -0600 Subject: [PATCH 1/2] Initial support for qwen image edit plus --- .../diffusion_models/__init__.py | 3 +- .../diffusion_models/qwen_image/__init__.py | 3 +- .../qwen_image/qwen_image_edit_plus.py | 309 ++++++++++++++++++ extensions_built_in/sd_trainer/SDTrainer.py | 70 +++- jobs/process/BaseSDTrainProcess.py | 3 + requirements.txt | 2 +- toolkit/config_modules.py | 14 + toolkit/data_transfer_object/data_loader.py | 12 +- toolkit/dataloader_mixins.py | 53 +-- toolkit/models/base_model.py | 4 + toolkit/stable_diffusion_model.py | 4 + 11 files changed, 445 insertions(+), 32 deletions(-) create mode 100644 extensions_built_in/diffusion_models/qwen_image/qwen_image_edit_plus.py diff --git a/extensions_built_in/diffusion_models/__init__.py b/extensions_built_in/diffusion_models/__init__.py index 6449fe61..f7d874e5 100644 --- a/extensions_built_in/diffusion_models/__init__.py +++ b/extensions_built_in/diffusion_models/__init__.py @@ -4,7 +4,7 @@ from .f_light import FLiteModel from .omnigen2 import OmniGen2Model from .flux_kontext import FluxKontextModel from .wan22 import Wan225bModel, Wan2214bModel, Wan2214bI2VModel -from .qwen_image import QwenImageModel, QwenImageEditModel +from .qwen_image import QwenImageModel, QwenImageEditModel, QwenImageEditPlusModel AI_TOOLKIT_MODELS = [ # put a list of models here @@ -20,4 +20,5 @@ AI_TOOLKIT_MODELS = [ Wan2214bModel, QwenImageModel, QwenImageEditModel, + QwenImageEditPlusModel, ] diff --git a/extensions_built_in/diffusion_models/qwen_image/__init__.py b/extensions_built_in/diffusion_models/qwen_image/__init__.py index d8b32a85..3ff1797b 100644 --- a/extensions_built_in/diffusion_models/qwen_image/__init__.py +++ b/extensions_built_in/diffusion_models/qwen_image/__init__.py @@ -1,2 +1,3 @@ from .qwen_image import QwenImageModel -from .qwen_image_edit import QwenImageEditModel \ No newline at end of file +from .qwen_image_edit import QwenImageEditModel +from .qwen_image_edit_plus import QwenImageEditPlusModel diff --git a/extensions_built_in/diffusion_models/qwen_image/qwen_image_edit_plus.py b/extensions_built_in/diffusion_models/qwen_image/qwen_image_edit_plus.py new file mode 100644 index 00000000..32fcaa70 --- /dev/null +++ b/extensions_built_in/diffusion_models/qwen_image/qwen_image_edit_plus.py @@ -0,0 +1,309 @@ +import math +import torch +from .qwen_image import QwenImageModel +import os +from typing import TYPE_CHECKING, List, Optional +import yaml +from toolkit import train_tools +from toolkit.config_modules import GenerateImageConfig, ModelConfig +from PIL import Image +from toolkit.models.base_model import BaseModel +from toolkit.basic import flush +from toolkit.prompt_utils import PromptEmbeds +from toolkit.samplers.custom_flowmatch_sampler import ( + CustomFlowMatchEulerDiscreteScheduler, +) +from toolkit.accelerator import get_accelerator, unwrap_model +from optimum.quanto import freeze, QTensor +from toolkit.util.quantize import quantize, get_qtype, quantize_model +import torch.nn.functional as F + +from diffusers import ( + QwenImageTransformer2DModel, + AutoencoderKLQwenImage, +) +from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer +from tqdm import tqdm + + +if TYPE_CHECKING: + from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO + +try: + from diffusers import QwenImageEditPlusPipeline + from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit_plus import CONDITION_IMAGE_SIZE, VAE_IMAGE_SIZE +except ImportError: + raise ImportError( + "Diffusers is out of date. Update diffusers to the latest version by doing 'pip uninstall diffusers' and then 'pip install -r requirements.txt'" + ) + + +class QwenImageEditPlusModel(QwenImageModel): + arch = "qwen_image_edit_plus" + _qwen_image_keep_visual = True + _qwen_pipeline = QwenImageEditPlusPipeline + + def __init__( + self, + device, + model_config: ModelConfig, + dtype="bf16", + custom_pipeline=None, + noise_scheduler=None, + **kwargs, + ): + super().__init__( + device, model_config, dtype, custom_pipeline, noise_scheduler, **kwargs + ) + self.is_flow_matching = True + self.is_transformer = True + self.target_lora_modules = ["QwenImageTransformer2DModel"] + + # set true for models that encode control image into text embeddings + self.encode_control_in_text_embeddings = True + # control images will come in as a list for encoding some things if true + self.has_multiple_control_images = True + # do not resize control images + self.use_raw_control_images = True + + def load_model(self): + super().load_model() + + def get_generation_pipeline(self): + scheduler = QwenImageModel.get_train_scheduler() + + pipeline: QwenImageEditPlusPipeline = QwenImageEditPlusPipeline( + scheduler=scheduler, + text_encoder=unwrap_model(self.text_encoder[0]), + tokenizer=self.tokenizer[0], + processor=self.processor, + vae=unwrap_model(self.vae), + transformer=unwrap_model(self.transformer), + ) + + pipeline = pipeline.to(self.device_torch) + + return pipeline + + def generate_single_image( + self, + pipeline: QwenImageEditPlusPipeline, + gen_config: GenerateImageConfig, + conditional_embeds: PromptEmbeds, + unconditional_embeds: PromptEmbeds, + generator: torch.Generator, + extra: dict, + ): + self.model.to(self.device_torch, dtype=self.torch_dtype) + sc = self.get_bucket_divisibility() + gen_config.width = int(gen_config.width // sc * sc) + gen_config.height = int(gen_config.height // sc * sc) + + control_img_list = [] + if gen_config.ctrl_img is not None: + control_img = Image.open(gen_config.ctrl_img) + control_img = control_img.convert("RGB") + control_img_list.append(control_img) + elif gen_config.ctrl_img_1 is not None: + control_img = Image.open(gen_config.ctrl_img_1) + control_img = control_img.convert("RGB") + control_img_list.append(control_img) + + if gen_config.ctrl_img_2 is not None: + control_img = Image.open(gen_config.ctrl_img_2) + control_img = control_img.convert("RGB") + control_img_list.append(control_img) + if gen_config.ctrl_img_3 is not None: + control_img = Image.open(gen_config.ctrl_img_3) + control_img = control_img.convert("RGB") + control_img_list.append(control_img) + + # flush for low vram if we are doing that + # flush_between_steps = self.model_config.low_vram + flush_between_steps = False + + # Fix a bug in diffusers/torch + def callback_on_step_end(pipe, i, t, callback_kwargs): + if flush_between_steps: + flush() + latents = callback_kwargs["latents"] + + return {"latents": latents} + + img = pipeline( + image=control_img_list, + prompt_embeds=conditional_embeds.text_embeds, + prompt_embeds_mask=conditional_embeds.attention_mask.to( + self.device_torch, dtype=torch.int64 + ), + negative_prompt_embeds=unconditional_embeds.text_embeds, + negative_prompt_embeds_mask=unconditional_embeds.attention_mask.to( + self.device_torch, dtype=torch.int64 + ), + height=gen_config.height, + width=gen_config.width, + num_inference_steps=gen_config.num_inference_steps, + true_cfg_scale=gen_config.guidance_scale, + latents=gen_config.latents, + generator=generator, + callback_on_step_end=callback_on_step_end, + **extra, + ).images[0] + return img + + def condition_noisy_latents( + self, latents: torch.Tensor, batch: "DataLoaderBatchDTO" + ): + # we get the control image from the batch + return latents.detach() + + def get_prompt_embeds(self, prompt: str, control_images=None) -> PromptEmbeds: + # todo handle not caching text encoder + if self.pipeline.text_encoder.device != self.device_torch: + self.pipeline.text_encoder.to(self.device_torch) + + if control_images is not None and len(control_images) > 0: + for i in range(len(control_images)): + # control images are 0 - 1 scale, shape (bs, ch, height, width) + ratio = control_images[i].shape[2] / control_images[i].shape[3] + width = math.sqrt(CONDITION_IMAGE_SIZE * ratio) + height = width / ratio + + width = round(width / 32) * 32 + height = round(height / 32) * 32 + + control_images[i] = F.interpolate( + control_images[i], size=(height, width), mode="bilinear" + ) + + prompt_embeds, prompt_embeds_mask = self.pipeline.encode_prompt( + prompt, + image=control_images, + device=self.device_torch, + num_images_per_prompt=1, + ) + pe = PromptEmbeds(prompt_embeds) + pe.attention_mask = prompt_embeds_mask + return pe + + def get_noise_prediction( + self, + latent_model_input: torch.Tensor, + timestep: torch.Tensor, # 0 to 1000 scale + text_embeddings: PromptEmbeds, + batch: "DataLoaderBatchDTO" = None, + **kwargs, + ): + with torch.no_grad(): + batch_size, num_channels_latents, height, width = latent_model_input.shape + + # pack image tokens + latent_model_input = latent_model_input.view( + batch_size, num_channels_latents, height // 2, 2, width // 2, 2 + ) + latent_model_input = latent_model_input.permute(0, 2, 4, 1, 3, 5) + latent_model_input = latent_model_input.reshape( + batch_size, (height // 2) * (width // 2), num_channels_latents * 4 + ) + + raw_packed_latents = latent_model_input + + img_h2, img_w2 = height // 2, width // 2 + + img_shapes = [ + [(1, img_h2, img_w2)] + ] * batch_size + + # pack controls + if batch is None: + raise ValueError("Batch is required for QwenImageEditPlusModel") + + # split the latents into batch items so we can concat the controls + packed_latents_list = torch.chunk(latent_model_input, batch_size, dim=0) + packed_latents_with_controls_list = [] + + if batch.control_tensor_list is not None: + if len(batch.control_tensor_list) != batch_size: + raise ValueError("Control tensor list length does not match batch size") + b = 0 + for control_tensor_list in batch.control_tensor_list: + # control tensor list is a list of tensors for this batch item + controls = [] + # pack control + for control_img in control_tensor_list: + # control images are 0 - 1 scale, shape (1, ch, height, width) + control_img = control_img.to(self.device_torch, dtype=self.torch_dtype) + # if it is only 3 dim, add batch dim + if len(control_img.shape) == 3: + control_img = control_img.unsqueeze(0) + ratio = control_img.shape[2] / control_img.shape[3] + c_width = math.sqrt(VAE_IMAGE_SIZE * ratio) + c_height = c_width / ratio + + c_width = round(c_width / 32) * 32 + c_height = round(c_height / 32) * 32 + + control_img = F.interpolate( + control_img, size=(c_height, c_width), mode="bilinear" + ) + + control_latent = self.encode_images( + control_img, + device=self.device_torch, + dtype=self.torch_dtype, + ) + + clb, cl_num_channels_latents, cl_height, cl_width = control_latent.shape + + control = control_latent.view( + 1, cl_num_channels_latents, cl_height // 2, 2, cl_width // 2, 2 + ) + control = control.permute(0, 2, 4, 1, 3, 5) + control = control.reshape( + 1, (cl_height // 2) * (cl_width // 2), num_channels_latents * 4 + ) + + img_shapes[b].append((1, cl_height // 2, cl_width // 2)) + controls.append(control) + + # stack controls on dim 1 + control = torch.cat(controls, dim=1).to(packed_latents_list[b].device, dtype=packed_latents_list[b].dtype) + # concat with latents + packed_latents_with_control = torch.cat([packed_latents_list[b], control], dim=1) + + packed_latents_with_controls_list.append(packed_latents_with_control) + + b += 1 + + latent_model_input = torch.cat(packed_latents_with_controls_list, dim=0) + + prompt_embeds_mask = text_embeddings.attention_mask.to( + self.device_torch, dtype=torch.int64 + ) + txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() + enc_hs = text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype) + prompt_embeds_mask = text_embeddings.attention_mask.to( + self.device_torch, dtype=torch.int64 + ) + + noise_pred = self.transformer( + hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype), + timestep=timestep / 1000, + guidance=None, + encoder_hidden_states=enc_hs, + encoder_hidden_states_mask=prompt_embeds_mask, + img_shapes=img_shapes, + txt_seq_lens=txt_seq_lens, + return_dict=False, + **kwargs, + )[0] + + noise_pred = noise_pred[:, : raw_packed_latents.size(1)] + + # unpack + noise_pred = noise_pred.view( + batch_size, height // 2, width // 2, num_channels_latents, 2, 2 + ) + noise_pred = noise_pred.permute(0, 3, 1, 4, 2, 5) + noise_pred = noise_pred.reshape(batch_size, num_channels_latents, height, width) + return noise_pred diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index 16c392d7..86c198ea 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -129,17 +129,64 @@ class SDTrainer(BaseSDTrainProcess): prompt=prompt, # it will autoparse the prompt negative_prompt=sample_item.neg, output_path=output_path, - ctrl_img=sample_item.ctrl_img + ctrl_img=sample_item.ctrl_img, + ctrl_img_1=sample_item.ctrl_img_1, + ctrl_img_2=sample_item.ctrl_img_2, + ctrl_img_3=sample_item.ctrl_img_3, ) + + has_control_images = False + if gen_img_config.ctrl_img is not None or gen_img_config.ctrl_img_1 is not None or gen_img_config.ctrl_img_2 is not None or gen_img_config.ctrl_img_3 is not None: + has_control_images = True # see if we need to encode the control images - if self.sd.encode_control_in_text_embeddings and gen_img_config.ctrl_img is not None: - ctrl_img = Image.open(gen_img_config.ctrl_img).convert("RGB") - # convert to 0 to 1 tensor - ctrl_img = ( - TF.to_tensor(ctrl_img) - .unsqueeze(0) - .to(self.sd.device_torch, dtype=self.sd.torch_dtype) - ) + if self.sd.encode_control_in_text_embeddings and has_control_images: + + ctrl_img_list = [] + + if gen_img_config.ctrl_img is not None: + ctrl_img = Image.open(gen_img_config.ctrl_img).convert("RGB") + # convert to 0 to 1 tensor + ctrl_img = ( + TF.to_tensor(ctrl_img) + .unsqueeze(0) + .to(self.sd.device_torch, dtype=self.sd.torch_dtype) + ) + ctrl_img_list.append(ctrl_img) + + if gen_img_config.ctrl_img_1 is not None: + ctrl_img_1 = Image.open(gen_img_config.ctrl_img_1).convert("RGB") + # convert to 0 to 1 tensor + ctrl_img_1 = ( + TF.to_tensor(ctrl_img_1) + .unsqueeze(0) + .to(self.sd.device_torch, dtype=self.sd.torch_dtype) + ) + ctrl_img_list.append(ctrl_img_1) + if gen_img_config.ctrl_img_2 is not None: + ctrl_img_2 = Image.open(gen_img_config.ctrl_img_2).convert("RGB") + # convert to 0 to 1 tensor + ctrl_img_2 = ( + TF.to_tensor(ctrl_img_2) + .unsqueeze(0) + .to(self.sd.device_torch, dtype=self.sd.torch_dtype) + ) + ctrl_img_list.append(ctrl_img_2) + if gen_img_config.ctrl_img_3 is not None: + ctrl_img_3 = Image.open(gen_img_config.ctrl_img_3).convert("RGB") + # convert to 0 to 1 tensor + ctrl_img_3 = ( + TF.to_tensor(ctrl_img_3) + .unsqueeze(0) + .to(self.sd.device_torch, dtype=self.sd.torch_dtype) + ) + ctrl_img_list.append(ctrl_img_3) + + if self.sd.has_multiple_control_images: + ctrl_img = ctrl_img_list + else: + ctrl_img = ctrl_img_list[0] if len(ctrl_img_list) > 0 else None + + positive = self.sd.encode_prompt( gen_img_config.prompt, control_images=ctrl_img @@ -202,6 +249,9 @@ class SDTrainer(BaseSDTrainProcess): if self.sd.encode_control_in_text_embeddings: # just do a blank image for unconditionals control_image = torch.zeros((1, 3, 224, 224), device=self.sd.device_torch, dtype=self.sd.torch_dtype) + if self.sd.has_multiple_control_images: + control_image = [control_image] + kwargs['control_images'] = control_image self.unconditional_embeds = self.sd.encode_prompt( [self.train_config.unconditional_prompt], @@ -272,6 +322,8 @@ class SDTrainer(BaseSDTrainProcess): if self.sd.encode_control_in_text_embeddings: # just do a blank image for unconditionals control_image = torch.zeros((1, 3, 224, 224), device=self.sd.device_torch, dtype=self.sd.torch_dtype) + if self.sd.has_multiple_control_images: + control_image = [control_image] encode_kwargs['control_images'] = control_image self.cached_blank_embeds = self.sd.encode_prompt("", **encode_kwargs) if self.trigger_word is not None: diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index e9b3d260..01d37b81 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -348,6 +348,9 @@ class BaseSDTrainProcess(BaseTrainProcess): fps=sample_item.fps, ctrl_img=sample_item.ctrl_img, ctrl_idx=sample_item.ctrl_idx, + ctrl_img_1=sample_item.ctrl_img_1, + ctrl_img_2=sample_item.ctrl_img_2, + ctrl_img_3=sample_item.ctrl_img_3, **extra_args )) diff --git a/requirements.txt b/requirements.txt index 39fa3dd0..e5442be3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ torchao==0.10.0 safetensors git+https://github.com/jaretburkett/easy_dwpose.git -git+https://github.com/huggingface/diffusers@7a2b78bf0f788d311cc96b61e660a8e13e3b1e63 +git+https://github.com/huggingface/diffusers@1448b035859dd57bbb565239dcdd79a025a85422 transformers==4.52.4 lycoris-lora==1.8.3 flatten_json diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index ca0b3bf2..24da74fe 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -56,6 +56,11 @@ class SampleItem: self.num_frames: int = kwargs.get('num_frames', sample_config.num_frames) self.ctrl_img: Optional[str] = kwargs.get('ctrl_img', None) self.ctrl_idx: int = kwargs.get('ctrl_idx', 0) + # for multi control image models + self.ctrl_img_1: Optional[str] = kwargs.get('ctrl_img_1', self.ctrl_img) + self.ctrl_img_2: Optional[str] = kwargs.get('ctrl_img_2', None) + self.ctrl_img_3: Optional[str] = kwargs.get('ctrl_img_3', None) + self.network_multiplier: float = kwargs.get('network_multiplier', sample_config.network_multiplier) # convert to a number if it is a string if isinstance(self.network_multiplier, str): @@ -966,6 +971,9 @@ class GenerateImageConfig: extra_values: List[float] = None, # extra values to save with prompt file logger: Optional[EmptyLogger] = None, ctrl_img: Optional[str] = None, # control image for controlnet + ctrl_img_1: Optional[str] = None, # first control image for multi control model + ctrl_img_2: Optional[str] = None, # second control image for multi control model + ctrl_img_3: Optional[str] = None, # third control image for multi control model num_frames: int = 1, fps: int = 15, ctrl_idx: int = 0 @@ -1002,6 +1010,12 @@ class GenerateImageConfig: self.ctrl_img = ctrl_img self.ctrl_idx = ctrl_idx + if ctrl_img_1 is None and ctrl_img is not None: + ctrl_img_1 = ctrl_img + + self.ctrl_img_1 = ctrl_img_1 + self.ctrl_img_2 = ctrl_img_2 + self.ctrl_img_3 = ctrl_img_3 # prompt string will override any settings above self._process_prompt_string() diff --git a/toolkit/data_transfer_object/data_loader.py b/toolkit/data_transfer_object/data_loader.py index a7ed3759..ca27dd47 100644 --- a/toolkit/data_transfer_object/data_loader.py +++ b/toolkit/data_transfer_object/data_loader.py @@ -144,6 +144,7 @@ class DataLoaderBatchDTO: self.tensor: Union[torch.Tensor, None] = None self.latents: Union[torch.Tensor, None] = None self.control_tensor: Union[torch.Tensor, None] = None + self.control_tensor_list: Union[List[List[torch.Tensor]], None] = None self.clip_image_tensor: Union[torch.Tensor, None] = None self.mask_tensor: Union[torch.Tensor, None] = None self.unaugmented_tensor: Union[torch.Tensor, None] = None @@ -160,7 +161,6 @@ class DataLoaderBatchDTO: self.latents: Union[torch.Tensor, None] = None if is_latents_cached: self.latents = torch.cat([x.get_latent().unsqueeze(0) for x in self.file_items]) - self.control_tensor: Union[torch.Tensor, None] = None self.prompt_embeds: Union[PromptEmbeds, None] = None # if self.file_items[0].control_tensor is not None: # if any have a control tensor, we concatenate them @@ -178,6 +178,16 @@ class DataLoaderBatchDTO: else: control_tensors.append(x.control_tensor) self.control_tensor = torch.cat([x.unsqueeze(0) for x in control_tensors]) + + # handle control tensor list + if any([x.control_tensor_list is not None for x in self.file_items]): + self.control_tensor_list = [] + for x in self.file_items: + if x.control_tensor_list is not None: + self.control_tensor_list.append(x.control_tensor_list) + else: + raise Exception(f"Could not find control tensors for all file items, missing for {x.path}") + self.inpaint_tensor: Union[torch.Tensor, None] = None if any([x.inpaint_tensor is not None for x in self.file_items]): diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py index 29b2aced..7a114bdd 100644 --- a/toolkit/dataloader_mixins.py +++ b/toolkit/dataloader_mixins.py @@ -850,6 +850,9 @@ class ControlFileItemDTOMixin: self.has_control_image = False self.control_path: Union[str, List[str], None] = None self.control_tensor: Union[torch.Tensor, None] = None + self.control_tensor_list: Union[List[torch.Tensor], None] = None + sd = kwargs.get('sd', None) + self.use_raw_control_images = sd is not None and sd.use_raw_control_images dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None) self.full_size_control_images = False if dataset_config.control_path is not None: @@ -900,23 +903,14 @@ class ControlFileItemDTOMixin: except Exception as e: print_acc(f"Error: {e}") print_acc(f"Error loading image: {control_path}") - + if not self.full_size_control_images: # we just scale them to 512x512: w, h = img.size img = img.resize((512, 512), Image.BICUBIC) - else: + elif not self.use_raw_control_images: w, h = img.size - if w > h and self.scale_to_width < self.scale_to_height: - # throw error, they should match - raise ValueError( - f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}") - elif h > w and self.scale_to_height < self.scale_to_width: - # throw error, they should match - raise ValueError( - f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}") - if self.flip_x: # do a flip img = img.transpose(Image.FLIP_LEFT_RIGHT) @@ -950,11 +944,15 @@ class ControlFileItemDTOMixin: self.control_tensor = None elif len(control_tensors) == 1: self.control_tensor = control_tensors[0] + elif self.use_raw_control_images: + # just send the list of tensors as their shapes wont match + self.control_tensor_list = control_tensors else: self.control_tensor = torch.stack(control_tensors, dim=0) def cleanup_control(self: 'FileItemDTO'): self.control_tensor = None + self.control_tensor_list = None class ClipImageFileItemDTOMixin: @@ -1884,14 +1882,31 @@ class TextEmbeddingCachingMixin: if file_item.encode_control_in_text_embeddings: if file_item.control_path is None: raise Exception(f"Could not find a control image for {file_item.path} which is needed for this model") - # load the control image and feed it into the text encoder - ctrl_img = Image.open(file_item.control_path).convert("RGB") - # convert to 0 to 1 tensor - ctrl_img = ( - TF.to_tensor(ctrl_img) - .unsqueeze(0) - .to(self.sd.device_torch, dtype=self.sd.torch_dtype) - ) + ctrl_img_list = [] + control_path_list = file_item.control_path + if not isinstance(file_item.control_path, list): + control_path_list = [control_path_list] + for i in range(len(control_path_list)): + try: + img = Image.open(control_path_list[i]).convert("RGB") + img = exif_transpose(img) + # convert to 0 to 1 tensor + img = ( + TF.to_tensor(img) + .unsqueeze(0) + .to(self.sd.device_torch, dtype=self.sd.torch_dtype) + ) + ctrl_img_list.append(img) + except Exception as e: + print_acc(f"Error: {e}") + print_acc(f"Error loading control image: {control_path_list[i]}") + + if len(ctrl_img_list) == 0: + ctrl_img = None + elif not self.sd.has_multiple_control_images: + ctrl_img = ctrl_img_list[0] + else: + ctrl_img = ctrl_img_list prompt_embeds: PromptEmbeds = self.sd.encode_prompt(file_item.caption, control_images=ctrl_img) else: prompt_embeds: PromptEmbeds = self.sd.encode_prompt(file_item.caption) diff --git a/toolkit/models/base_model.py b/toolkit/models/base_model.py index af5ea06f..58d48b4f 100644 --- a/toolkit/models/base_model.py +++ b/toolkit/models/base_model.py @@ -181,6 +181,10 @@ class BaseModel: # set true for models that encode control image into text embeddings self.encode_control_in_text_embeddings = False + # control images will come in as a list for encoding some things if true + self.has_multiple_control_images = False + # do not resize control images + self.use_raw_control_images = False # properties for old arch for backwards compatibility @property diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index 53bd689a..78960ed1 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -219,6 +219,10 @@ class StableDiffusion: # set true for models that encode control image into text embeddings self.encode_control_in_text_embeddings = False + # control images will come in as a list for encoding some things if true + self.has_multiple_control_images = False + # do not resize control images + self.use_raw_control_images = False # properties for old arch for backwards compatibility @property From 1069dee0e444c900d2eba78ab6f4ef78a43f3ffc Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Thu, 25 Sep 2025 11:10:02 -0600 Subject: [PATCH 2/2] Added ui sopport for multi control samples and datasets. Added qwen image edit 5209 to the ui --- .../train_lora_qwen_image_edit_2509_32gb.yaml | 105 +++++++++ toolkit/config_modules.py | 15 ++ ui/src/app/jobs/new/SimpleJob.tsx | 179 +++++++-------- ui/src/app/jobs/new/jobConfig.ts | 1 - ui/src/app/jobs/new/options.ts | 25 +++ ui/src/app/jobs/new/utils.ts | 105 +++++++++ ui/src/components/SampleControlImage.tsx | 206 ++++++++++++++++++ ui/src/docs.tsx | 16 ++ ui/src/types.ts | 8 +- version.py | 2 +- 10 files changed, 574 insertions(+), 88 deletions(-) create mode 100644 config/examples/train_lora_qwen_image_edit_2509_32gb.yaml create mode 100644 ui/src/app/jobs/new/utils.ts create mode 100644 ui/src/components/SampleControlImage.tsx diff --git a/config/examples/train_lora_qwen_image_edit_2509_32gb.yaml b/config/examples/train_lora_qwen_image_edit_2509_32gb.yaml new file mode 100644 index 00000000..845ac92b --- /dev/null +++ b/config/examples/train_lora_qwen_image_edit_2509_32gb.yaml @@ -0,0 +1,105 @@ +--- +job: extension +config: + # this name will be the folder and filename name + name: "my_first_qwen_image_edit_2509_lora_v1" + process: + - type: 'diffusion_trainer' + # root folder to save training sessions/samples/weights + training_folder: "output" + # uncomment to see performance stats in the terminal every N steps +# performance_log_every: 1000 + device: cuda:0 + network: + type: "lora" + linear: 16 + linear_alpha: 16 + save: + dtype: float16 # precision to save + save_every: 250 # save every this many steps + max_step_saves_to_keep: 4 # how many intermittent saves to keep + datasets: + # datasets are a folder of images. captions need to be txt files with the same name as the image + # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently + # images will automatically be resized and bucketed into the resolution specified + # on windows, escape back slashes with another backslash so + # "C:\\path\\to\\images\\folder" + - folder_path: "/path/to/images/folder" + # can do up to 3 control image folders, file names must match target file names, but aspect/size can be different + control_path: + - "/path/to/control/images/folder1" + - "/path/to/control/images/folder2" + - "/path/to/control/images/folder3" + caption_ext: "txt" + # default_caption: "a person" # if caching text embeddings, if you don't have captions, this will get cached + caption_dropout_rate: 0.05 # will drop out the caption 5% of time + resolution: [ 512, 768, 1024 ] # qwen image enjoys multiple resolutions + # a trigger word that can be cached with the text embeddings + # trigger_word: "optional trigger word" + train: + batch_size: 1 + # caching text embeddings is required for 32GB + cache_text_embeddings: true + # unload_text_encoder: true + + steps: 3000 # total number of steps to train 500 - 4000 is a good range + gradient_accumulation: 1 + timestep_type: "weighted" + train_unet: true + train_text_encoder: false # probably won't work with qwen image + gradient_checkpointing: true # need the on unless you have a ton of vram + noise_scheduler: "flowmatch" # for training only + optimizer: "adamw8bit" + lr: 1e-4 + # uncomment this to skip the pre training sample + # skip_first_sample: true + # uncomment to completely disable sampling + # disable_sampling: true + dtype: bf16 + model: + # huggingface model name or path + name_or_path: "Qwen/Qwen-Image-Edit-2509" + arch: "qwen_image_edit_plus" + quantize: true + # to use the ARA use the | pipe to point to hf path, or a local path if you have one. + # 3bit is required for 32GB + qtype: "uint3|ostris/accuracy_recovery_adapters/qwen_image_edit_2509_torchao_uint3.safetensors" + quantize_te: true + qtype_te: "qfloat8" + low_vram: true + sample: + sampler: "flowmatch" # must match train.noise_scheduler + sample_every: 250 # sample every this many steps + width: 1024 + height: 1024 + # you can provide up to 3 control images here + samples: + - prompt: "Do whatever with Image1 and Image2" + ctrl_img_1: "/path/to/image1.png" + ctrl_img_2: "/path/to/image2.png" + # ctrl_img_3: "/path/to/image3.png" + - prompt: "Do whatever with Image1 and Image2" + ctrl_img_1: "/path/to/image1.png" + ctrl_img_2: "/path/to/image2.png" + # ctrl_img_3: "/path/to/image3.png" + - prompt: "Do whatever with Image1 and Image2" + ctrl_img_1: "/path/to/image1.png" + ctrl_img_2: "/path/to/image2.png" + # ctrl_img_3: "/path/to/image3.png" + - prompt: "Do whatever with Image1 and Image2" + ctrl_img_1: "/path/to/image1.png" + ctrl_img_2: "/path/to/image2.png" + # ctrl_img_3: "/path/to/image3.png" + - prompt: "Do whatever with Image1 and Image2" + ctrl_img_1: "/path/to/image1.png" + ctrl_img_2: "/path/to/image2.png" + # ctrl_img_3: "/path/to/image3.png" + neg: "" + seed: 42 + walk_seed: true + guidance_scale: 3 + sample_steps: 25 +# you can add any additional meta info here. [name] is replaced with config name at top +meta: + name: "[name]" + version: '1.0' diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index 24da74fe..ad6cc1f3 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -831,6 +831,21 @@ class DatasetConfig: if self.control_path == '': self.control_path = None + # handle multi control inputs from the ui. It is just easier to handle it here for a cleaner ui experience + control_path_1 = kwargs.get('control_path_1', None) + control_path_2 = kwargs.get('control_path_2', None) + control_path_3 = kwargs.get('control_path_3', None) + + if any([control_path_1, control_path_2, control_path_3]): + control_paths = [] + if control_path_1: + control_paths.append(control_path_1) + if control_path_2: + control_paths.append(control_path_2) + if control_path_3: + control_paths.append(control_path_3) + self.control_path = control_paths + # color for transparent reigon of control images with transparency self.control_transparent_color: List[int] = kwargs.get('control_transparent_color', [0, 0, 0]) # inpaint images should be webp/png images with alpha channel. The alpha 0 (invisible) section will diff --git a/ui/src/app/jobs/new/SimpleJob.tsx b/ui/src/app/jobs/new/SimpleJob.tsx index 3953c182..4ae5bcc9 100644 --- a/ui/src/app/jobs/new/SimpleJob.tsx +++ b/ui/src/app/jobs/new/SimpleJob.tsx @@ -15,7 +15,9 @@ import { TextInput, SelectInput, Checkbox, FormGroup, NumberInput } from '@/comp import Card from '@/components/Card'; import { X } from 'lucide-react'; import AddSingleImageModal, { openAddImageModal } from '@/components/AddSingleImageModal'; +import SampleControlImage from '@/components/SampleControlImage'; import { FlipHorizontal2, FlipVertical2 } from 'lucide-react'; +import { handleModelArchChange } from './utils'; type Props = { jobConfig: JobConfig; @@ -185,58 +187,7 @@ export default function SimpleJob({ label="Model Architecture" value={jobConfig.config.process[0].model.arch} onChange={value => { - const currentArch = modelArchs.find(a => a.name === jobConfig.config.process[0].model.arch); - if (!currentArch || currentArch.name === value) { - return; - } - // update the defaults when a model is selected - const newArch = modelArchs.find(model => model.name === value); - - // update vram setting - if (!newArch?.additionalSections?.includes('model.low_vram')) { - setJobConfig(false, 'config.process[0].model.low_vram'); - } - - // revert defaults from previous model - for (const key in currentArch.defaults) { - setJobConfig(currentArch.defaults[key][1], key); - } - - if (newArch?.defaults) { - for (const key in newArch.defaults) { - setJobConfig(newArch.defaults[key][0], key); - } - } - // set new model - setJobConfig(value, 'config.process[0].model.arch'); - - // update datasets - const hasControlPath = newArch?.additionalSections?.includes('datasets.control_path') || false; - const hasNumFrames = newArch?.additionalSections?.includes('datasets.num_frames') || false; - const controls = newArch?.controls ?? []; - const datasets = jobConfig.config.process[0].datasets.map(dataset => { - const newDataset = objectCopy(dataset); - newDataset.controls = controls; - if (!hasControlPath) { - newDataset.control_path = null; // reset control path if not applicable - } - if (!hasNumFrames) { - newDataset.num_frames = 1; // reset num_frames if not applicable - } - return newDataset; - }); - setJobConfig(datasets, 'config.process[0].datasets'); - - // update samples - const hasSampleCtrlImg = newArch?.additionalSections?.includes('sample.ctrl_img') || false; - const samples = jobConfig.config.process[0].sample.samples.map(sample => { - const newSample = objectCopy(sample); - if (!hasSampleCtrlImg) { - delete newSample.ctrl_img; // remove ctrl_img if not applicable - } - return newSample; - }); - setJobConfig(samples, 'config.process[0].sample.samples'); + handleModelArchChange(jobConfig.config.process[0].model.arch, value, jobConfig, setJobConfig); }} options={groupedModelOptions} /> @@ -557,17 +508,19 @@ export default function SimpleJob({ )} - { - setJobConfig(value, 'config.process[0].train.unload_text_encoder'); - if (value) { - setJobConfig(false, 'config.process[0].train.cache_text_embeddings'); - } - }} - /> + {!disableSections.includes('train.unload_text_encoder') && ( + { + setJobConfig(value, 'config.process[0].train.unload_text_encoder'); + if (value) { + setJobConfig(false, 'config.process[0].train.cache_text_embeddings'); + } + }} + /> + )}
setJobConfig(value, `config.process[0].datasets[${i}].folder_path`)} options={datasetOptions} @@ -659,6 +612,49 @@ export default function SimpleJob({ options={[{ value: '', label: <>  }, ...datasetOptions]} /> )} + {modelArch?.additionalSections?.includes('datasets.multi_control_paths') && ( + <> + + setJobConfig( + value == '' ? null : value, + `config.process[0].datasets[${i}].control_path_1`, + ) + } + options={[{ value: '', label: <>  }, ...datasetOptions]} + /> + + setJobConfig( + value == '' ? null : value, + `config.process[0].datasets[${i}].control_path_2`, + ) + } + options={[{ value: '', label: <>  }, ...datasetOptions]} + /> + + setJobConfig( + value == '' ? null : value, + `config.process[0].datasets[${i}].control_path_3`, + ) + } + options={[{ value: '', label: <>  }, ...datasetOptions]} + /> + + )}
- + {modelArch?.additionalSections?.includes('datasets.multi_control_paths') && ( + +
+ {['ctrl_img_1', 'ctrl_img_2', 'ctrl_img_3'].map((ctrlKey, ctrl_idx) => ( + { + if (!imagePath) { + let newSamples = objectCopy(jobConfig.config.process[0].sample.samples); + delete newSamples[i][ctrlKey as keyof typeof sample]; + setJobConfig(newSamples, 'config.process[0].sample.samples'); + } else { + setJobConfig(imagePath, `config.process[0].sample.samples[${i}].${ctrlKey}`); + } + }} + /> + ))} +
+
+ )} {modelArch?.additionalSections?.includes('sample.ctrl_img') && ( -
{ - openAddImageModal(imagePath => { - console.log('Selected image path:', imagePath); - if (!imagePath) return; + { + if (!imagePath) { + let newSamples = objectCopy(jobConfig.config.process[0].sample.samples); + delete newSamples[i].ctrl_img; + setJobConfig(newSamples, 'config.process[0].sample.samples'); + } else { setJobConfig(imagePath, `config.process[0].sample.samples[${i}].ctrl_img`); - }); + } }} - > - {!sample.ctrl_img && ( -
Add Control Image
- )} -
+ /> )}
diff --git a/ui/src/app/jobs/new/jobConfig.ts b/ui/src/app/jobs/new/jobConfig.ts index c18f7327..54f34ca9 100644 --- a/ui/src/app/jobs/new/jobConfig.ts +++ b/ui/src/app/jobs/new/jobConfig.ts @@ -2,7 +2,6 @@ import { JobConfig, DatasetConfig, SliderConfig } from '@/types'; export const defaultDatasetConfig: DatasetConfig = { folder_path: '/path/to/images/folder', - control_path: null, mask_path: null, mask_min_value: 0.1, default_caption: '', diff --git a/ui/src/app/jobs/new/options.ts b/ui/src/app/jobs/new/options.ts index 7735ae35..eec57717 100644 --- a/ui/src/app/jobs/new/options.ts +++ b/ui/src/app/jobs/new/options.ts @@ -9,12 +9,15 @@ type DisableableSections = | 'network.conv' | 'trigger_word' | 'train.diff_output_preservation' + | 'train.unload_text_encoder' | 'slider'; type AdditionalSections = | 'datasets.control_path' + | 'datasets.multi_control_paths' | 'datasets.do_i2v' | 'sample.ctrl_img' + | 'sample.multi_ctrl_imgs' | 'datasets.num_frames' | 'model.multistage' | 'model.low_vram'; @@ -335,6 +338,28 @@ export const modelArchs: ModelArch[] = [ '3 bit with ARA': 'uint3|ostris/accuracy_recovery_adapters/qwen_image_edit_torchao_uint3.safetensors', }, }, + { + name: 'qwen_image_edit_plus', + label: 'Qwen-Image-Edit-2509', + group: 'instruction', + defaults: { + // default updates when [selected, unselected] in the UI + 'config.process[0].model.name_or_path': ['Qwen/Qwen-Image-Edit-2509', defaultNameOrPath], + 'config.process[0].model.quantize': [true, false], + 'config.process[0].model.quantize_te': [true, false], + 'config.process[0].model.low_vram': [true, false], + 'config.process[0].train.unload_text_encoder': [false, false], + 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.timestep_type': ['weighted', 'sigmoid'], + 'config.process[0].model.qtype': ['qfloat8', 'qfloat8'], + }, + disableSections: ['network.conv', 'train.unload_text_encoder'], + additionalSections: ['datasets.multi_control_paths', 'sample.multi_ctrl_imgs', 'model.low_vram'], + accuracyRecoveryAdapters: { + '3 bit with ARA': 'uint3|ostris/accuracy_recovery_adapters/qwen_image_edit_2509_torchao_uint3.safetensors', + }, + }, { name: 'hidream', label: 'HiDream', diff --git a/ui/src/app/jobs/new/utils.ts b/ui/src/app/jobs/new/utils.ts new file mode 100644 index 00000000..5d3803e8 --- /dev/null +++ b/ui/src/app/jobs/new/utils.ts @@ -0,0 +1,105 @@ +import { GroupedSelectOption, JobConfig, SelectOption } from '@/types'; +import { modelArchs, ModelArch } from './options'; +import { objectCopy } from '@/utils/basic'; + +export const handleModelArchChange = ( + currentArchName: string, + newArchName: string, + jobConfig: JobConfig, + setJobConfig: (value: any, key: string) => void, +) => { + const currentArch = modelArchs.find(a => a.name === currentArchName); + if (!currentArch || currentArch.name === newArchName) { + return; + } + + // update the defaults when a model is selected + const newArch = modelArchs.find(model => model.name === newArchName); + + // update vram setting + if (!newArch?.additionalSections?.includes('model.low_vram')) { + setJobConfig(false, 'config.process[0].model.low_vram'); + } + + // revert defaults from previous model + for (const key in currentArch.defaults) { + setJobConfig(currentArch.defaults[key][1], key); + } + + if (newArch?.defaults) { + for (const key in newArch.defaults) { + setJobConfig(newArch.defaults[key][0], key); + } + } + // set new model + setJobConfig(newArchName, 'config.process[0].model.arch'); + + // update datasets + const hasControlPath = newArch?.additionalSections?.includes('datasets.control_path') || false; + const hasMultiControlPaths = newArch?.additionalSections?.includes('datasets.multi_control_paths') || false; + const hasNumFrames = newArch?.additionalSections?.includes('datasets.num_frames') || false; + const controls = newArch?.controls ?? []; + const datasets = jobConfig.config.process[0].datasets.map(dataset => { + const newDataset = objectCopy(dataset); + newDataset.controls = controls; + if (hasMultiControlPaths) { + // make sure the config has the multi control paths + newDataset.control_path_1 = newDataset.control_path_1 || null; + newDataset.control_path_2 = newDataset.control_path_2 || null; + newDataset.control_path_3 = newDataset.control_path_3 || null; + // if we previously had a single control path and now + // we selected a multi control model + if (newDataset.control_path && newDataset.control_path !== '') { + // only set if not overwriting + if (!newDataset.control_path_1) { + newDataset.control_path_1 = newDataset.control_path; + } + } + delete newDataset.control_path; // remove single control path + } else if (hasControlPath) { + newDataset.control_path = newDataset.control_path || null; + if (newDataset.control_path_1 && newDataset.control_path_1 !== '') { + newDataset.control_path = newDataset.control_path_1; + } + if (newDataset.control_path_1) { + delete newDataset.control_path_1; + } + if (newDataset.control_path_2) { + delete newDataset.control_path_2; + } + if (newDataset.control_path_3) { + delete newDataset.control_path_3; + } + } else { + // does not have control images + if (newDataset.control_path) { + delete newDataset.control_path; + } + if (newDataset.control_path_1) { + delete newDataset.control_path_1; + } + if (newDataset.control_path_2) { + delete newDataset.control_path_2; + } + if (newDataset.control_path_3) { + delete newDataset.control_path_3; + } + } + if (!hasNumFrames) { + newDataset.num_frames = 1; // reset num_frames if not applicable + } + return newDataset; + }); + setJobConfig(datasets, 'config.process[0].datasets'); + + // update samples + const hasSampleCtrlImg = newArch?.additionalSections?.includes('sample.ctrl_img') || false; + const samples = jobConfig.config.process[0].sample.samples.map(sample => { + const newSample = objectCopy(sample); + if (!hasSampleCtrlImg) { + delete newSample.ctrl_img; // remove ctrl_img if not applicable + } + return newSample; + }); + setJobConfig(samples, 'config.process[0].sample.samples'); +}; diff --git a/ui/src/components/SampleControlImage.tsx b/ui/src/components/SampleControlImage.tsx new file mode 100644 index 00000000..b10b2cef --- /dev/null +++ b/ui/src/components/SampleControlImage.tsx @@ -0,0 +1,206 @@ +'use client'; + +import React, { useCallback, useMemo, useRef, useState } from 'react'; +import classNames from 'classnames'; +import { useDropzone } from 'react-dropzone'; +import { FaUpload, FaImage, FaTimes } from 'react-icons/fa'; +import { apiClient } from '@/utils/api'; +import type { AxiosProgressEvent } from 'axios'; + +interface Props { + src: string | null | undefined; + className?: string; + instruction?: string; + onNewImageSelected: (imagePath: string | null) => void; +} + +export default function SampleControlImage({ + src, + className, + instruction = 'Add Control Image', + onNewImageSelected, +}: Props) { + const [isUploading, setIsUploading] = useState(false); + const [uploadProgress, setUploadProgress] = useState(0); + const [localPreview, setLocalPreview] = useState(null); + const fileInputRef = useRef(null); + + const backgroundUrl = useMemo(() => { + if (localPreview) return localPreview; + if (src) return `/api/img/${encodeURIComponent(src)}`; + return null; + }, [src, localPreview]); + + const handleUpload = useCallback( + async (file: File) => { + if (!file) return; + setIsUploading(true); + setUploadProgress(0); + + const objectUrl = URL.createObjectURL(file); + setLocalPreview(objectUrl); + + const formData = new FormData(); + formData.append('files', file); + + try { + const resp = await apiClient.post(`/api/img/upload`, formData, { + headers: { 'Content-Type': 'multipart/form-data' }, + onUploadProgress: (evt: AxiosProgressEvent) => { + const total = evt.total ?? 100; + const loaded = evt.loaded ?? 0; + setUploadProgress(Math.round((loaded * 100) / total)); + }, + timeout: 0, + }); + + const uploaded = resp?.data?.files?.[0] ?? null; + onNewImageSelected(uploaded); + } catch (err) { + console.error('Upload failed:', err); + setLocalPreview(null); + } finally { + setIsUploading(false); + setUploadProgress(0); + URL.revokeObjectURL(objectUrl); + if (fileInputRef.current) fileInputRef.current.value = ''; + } + }, + [onNewImageSelected], + ); + + const onDrop = useCallback( + (acceptedFiles: File[]) => { + if (acceptedFiles.length === 0) return; + handleUpload(acceptedFiles[0]); + }, + [handleUpload], + ); + + const clearImage = useCallback( + (e?: React.MouseEvent) => { + console.log('clearImage'); + if (e) { + e.stopPropagation(); + e.preventDefault(); + } + setLocalPreview(null); + onNewImageSelected(null); + if (fileInputRef.current) fileInputRef.current.value = ''; + }, + [onNewImageSelected], + ); + + // Drag & drop only; click handled via our own hidden input + const { getRootProps, isDragActive } = useDropzone({ + onDrop, + accept: { 'image/*': ['.png', '.jpg', '.jpeg', '.gif', '.bmp', '.webp'] }, + multiple: false, + noClick: true, + noKeyboard: true, + }); + + const rootProps = getRootProps(); + + return ( +
!isUploading && fileInputRef.current?.click()} + > + {/* Hidden input for click-to-open */} + { + const file = e.currentTarget.files?.[0]; + if (file) handleUpload(file); + }} + /> + + {/* Empty state — centered */} + {!backgroundUrl && ( +
+ +
{instruction}
+
Click or drop
+
+ )} + + {/* Existing image overlays */} + {backgroundUrl && !isUploading && ( + <> +
+
+ + Replace +
+
+ + {/* Clear (X) button */} + + + )} + + {/* Uploading overlay */} + {isUploading && ( +
+
+
+
+
+
Uploading… {uploadProgress}%
+
+
+ )} +
+ ); +} diff --git a/ui/src/docs.tsx b/ui/src/docs.tsx index cadf0e01..b93ae3a2 100644 --- a/ui/src/docs.tsx +++ b/ui/src/docs.tsx @@ -53,10 +53,26 @@ const docs: { [key: string]: ConfigDoc } = { }, 'datasets.control_path': { title: 'Control Dataset', + description: ( + <> + The control dataset needs to have files that match the filenames of your training dataset. They should be + matching file pairs. These images are fed as control/input images during training. The control images will be + resized to match the training images. + + ), + }, + 'datasets.multi_control_paths': { + title: 'Multi Control Dataset', description: ( <> The control dataset needs to have files that match the filenames of your training dataset. They should be matching file pairs. These images are fed as control/input images during training. +
+
+ For multi control datasets, the controls will all be applied in the order they are listed. If the model does not + require the images to be the same aspect ratios, such as with Qwen/Qwen-Image-Edit-2509, then the control images + do not need to match the aspect size or aspect ratio of the target image and they will be automatically resized to + the ideal resolutions for the model / target images. ), }, diff --git a/ui/src/types.ts b/ui/src/types.ts index e5ffb25b..71048ca5 100644 --- a/ui/src/types.ts +++ b/ui/src/types.ts @@ -83,12 +83,15 @@ export interface DatasetConfig { cache_latents_to_disk?: boolean; resolution: number[]; controls: string[]; - control_path: string | null; + control_path?: string | null; num_frames: number; shrink_video_to_frames: boolean; do_i2v: boolean; flip_x: boolean; flip_y: boolean; + control_path_1?: string | null; + control_path_2?: string | null; + control_path_3?: string | null; } export interface EMAConfig { @@ -155,6 +158,9 @@ export interface SampleItem { ctrl_img?: string | null; ctrl_idx?: number; network_multiplier?: number; + ctrl_img_1?: string | null; + ctrl_img_2?: string | null; + ctrl_img_3?: string | null; } export interface SampleConfig { diff --git a/version.py b/version.py index dd1e6dbf..af9f7a48 100644 --- a/version.py +++ b/version.py @@ -1 +1 @@ -VERSION = "0.5.10" \ No newline at end of file +VERSION = "0.6.0" \ No newline at end of file