From 3086a58e5b02ecd00c0fc65eb04d72e2a91ecfd3 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Wed, 1 Oct 2025 14:12:17 -0600 Subject: [PATCH] git status --- .../qwen_image/qwen_image_edit_plus.py | 86 +++-- .../qwen_image/qwen_image_pipelines.py | 354 ++++++++++++++++++ jobs/process/BaseSDTrainProcess.py | 1 + toolkit/config_modules.py | 9 +- toolkit/memory_management/__init__.py | 1 + toolkit/memory_management/manager.py | 12 + toolkit/models/base_model.py | 3 + toolkit/stable_diffusion_model.py | 3 + 8 files changed, 438 insertions(+), 31 deletions(-) create mode 100644 extensions_built_in/diffusion_models/qwen_image/qwen_image_pipelines.py create mode 100644 toolkit/memory_management/__init__.py create mode 100644 toolkit/memory_management/manager.py diff --git a/extensions_built_in/diffusion_models/qwen_image/qwen_image_edit_plus.py b/extensions_built_in/diffusion_models/qwen_image/qwen_image_edit_plus.py index 95a05c5a..edb894ad 100644 --- a/extensions_built_in/diffusion_models/qwen_image/qwen_image_edit_plus.py +++ b/extensions_built_in/diffusion_models/qwen_image/qwen_image_edit_plus.py @@ -30,8 +30,11 @@ if TYPE_CHECKING: from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO try: - from diffusers import QwenImageEditPlusPipeline - from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit_plus import CONDITION_IMAGE_SIZE, VAE_IMAGE_SIZE + from .qwen_image_pipelines import QwenImageEditPlusCustomPipeline + from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit_plus import ( + CONDITION_IMAGE_SIZE, + VAE_IMAGE_SIZE, + ) except ImportError: raise ImportError( "Diffusers is out of date. Update diffusers to the latest version by doing 'pip uninstall diffusers' and then 'pip install -r requirements.txt'" @@ -41,7 +44,7 @@ except ImportError: class QwenImageEditPlusModel(QwenImageModel): arch = "qwen_image_edit_plus" _qwen_image_keep_visual = True - _qwen_pipeline = QwenImageEditPlusPipeline + _qwen_pipeline = QwenImageEditPlusCustomPipeline def __init__( self, @@ -72,7 +75,7 @@ class QwenImageEditPlusModel(QwenImageModel): def get_generation_pipeline(self): scheduler = QwenImageModel.get_train_scheduler() - pipeline: QwenImageEditPlusPipeline = QwenImageEditPlusPipeline( + pipeline: QwenImageEditPlusCustomPipeline = QwenImageEditPlusCustomPipeline( scheduler=scheduler, text_encoder=unwrap_model(self.text_encoder[0]), tokenizer=self.tokenizer[0], @@ -87,7 +90,7 @@ class QwenImageEditPlusModel(QwenImageModel): def generate_single_image( self, - pipeline: QwenImageEditPlusPipeline, + pipeline: QwenImageEditPlusCustomPipeline, gen_config: GenerateImageConfig, conditional_embeds: PromptEmbeds, unconditional_embeds: PromptEmbeds, @@ -108,7 +111,7 @@ class QwenImageEditPlusModel(QwenImageModel): control_img = Image.open(gen_config.ctrl_img_1) control_img = control_img.convert("RGB") control_img_list.append(control_img) - + if gen_config.ctrl_img_2 is not None: control_img = Image.open(gen_config.ctrl_img_2) control_img = control_img.convert("RGB") @@ -147,6 +150,7 @@ class QwenImageEditPlusModel(QwenImageModel): latents=gen_config.latents, generator=generator, callback_on_step_end=callback_on_step_end, + do_cfg_norm=gen_config.do_cfg_norm, **extra, ).images[0] return img @@ -205,25 +209,27 @@ class QwenImageEditPlusModel(QwenImageModel): latent_model_input = latent_model_input.reshape( batch_size, (height // 2) * (width // 2), num_channels_latents * 4 ) - + raw_packed_latents = latent_model_input - + img_h2, img_w2 = height // 2, width // 2 - + # build distinct 
instances per batch item, per mamad8 img_shapes = [(1, img_h2, img_w2) for _ in range(batch_size)] - + # pack controls if batch is None: raise ValueError("Batch is required for QwenImageEditPlusModel") - + # split the latents into batch items so we can concat the controls packed_latents_list = torch.chunk(latent_model_input, batch_size, dim=0) packed_latents_with_controls_list = [] - + if batch.control_tensor_list is not None: if len(batch.control_tensor_list) != batch_size: - raise ValueError("Control tensor list length does not match batch size") + raise ValueError( + "Control tensor list length does not match batch size" + ) b = 0 for control_tensor_list in batch.control_tensor_list: # control tensor list is a list of tensors for this batch item @@ -231,7 +237,9 @@ class QwenImageEditPlusModel(QwenImageModel): # pack control for control_img in control_tensor_list: # control images are 0 - 1 scale, shape (1, ch, height, width) - control_img = control_img.to(self.device_torch, dtype=self.torch_dtype) + control_img = control_img.to( + self.device_torch, dtype=self.torch_dtype + ) # if it is only 3 dim, add batch dim if len(control_img.shape) == 3: control_img = control_img.unsqueeze(0) @@ -245,38 +253,54 @@ class QwenImageEditPlusModel(QwenImageModel): control_img = F.interpolate( control_img, size=(c_height, c_width), mode="bilinear" ) - + # scale to -1 to 1 control_img = control_img * 2 - 1 - + control_latent = self.encode_images( control_img, device=self.device_torch, dtype=self.torch_dtype, ) - - clb, cl_num_channels_latents, cl_height, cl_width = control_latent.shape - + + clb, cl_num_channels_latents, cl_height, cl_width = ( + control_latent.shape + ) + control = control_latent.view( - 1, cl_num_channels_latents, cl_height // 2, 2, cl_width // 2, 2 + 1, + cl_num_channels_latents, + cl_height // 2, + 2, + cl_width // 2, + 2, ) control = control.permute(0, 2, 4, 1, 3, 5) control = control.reshape( - 1, (cl_height // 2) * (cl_width // 2), num_channels_latents * 4 + 1, + (cl_height // 2) * (cl_width // 2), + num_channels_latents * 4, ) - + img_shapes[b].append((1, cl_height // 2, cl_width // 2)) controls.append(control) - + # stack controls on dim 1 - control = torch.cat(controls, dim=1).to(packed_latents_list[b].device, dtype=packed_latents_list[b].dtype) + control = torch.cat(controls, dim=1).to( + packed_latents_list[b].device, + dtype=packed_latents_list[b].dtype, + ) # concat with latents - packed_latents_with_control = torch.cat([packed_latents_list[b], control], dim=1) - - packed_latents_with_controls_list.append(packed_latents_with_control) - + packed_latents_with_control = torch.cat( + [packed_latents_list[b], control], dim=1 + ) + + packed_latents_with_controls_list.append( + packed_latents_with_control + ) + b += 1 - + latent_model_input = torch.cat(packed_latents_with_controls_list, dim=0) prompt_embeds_mask = text_embeddings.attention_mask.to( @@ -289,7 +313,9 @@ class QwenImageEditPlusModel(QwenImageModel): ) noise_pred = self.transformer( - hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype).detach(), + hidden_states=latent_model_input.to( + self.device_torch, self.torch_dtype + ).detach(), timestep=(timestep / 1000).detach(), guidance=None, encoder_hidden_states=enc_hs.detach(), diff --git a/extensions_built_in/diffusion_models/qwen_image/qwen_image_pipelines.py b/extensions_built_in/diffusion_models/qwen_image/qwen_image_pipelines.py new file mode 100644 index 00000000..dd6a9b1d --- /dev/null +++ 
b/extensions_built_in/diffusion_models/qwen_image/qwen_image_pipelines.py @@ -0,0 +1,354 @@ +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import torch + +try: + from diffusers import QwenImageEditPlusPipeline + from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit_plus import ( + CONDITION_IMAGE_SIZE, + VAE_IMAGE_SIZE, + XLA_AVAILABLE, + logger, + calculate_dimensions, + calculate_shift, + retrieve_timesteps, + ) +except ImportError: + raise ImportError( + "Diffusers is out of date. Update diffusers to the latest version by doing 'pip uninstall diffusers' and then 'pip install -r requirements.txt'" + ) + +from diffusers.image_processor import PipelineImageInput +from diffusers.pipelines.qwenimage.pipeline_output import QwenImagePipelineOutput + + +class QwenImageEditPlusCustomPipeline(QwenImageEditPlusPipeline): + @torch.no_grad() + def __call__( + self, + image: Optional[PipelineImageInput] = None, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + true_cfg_scale: float = 4.0, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + sigmas: Optional[List[float]] = None, + guidance_scale: Optional[float] = None, + num_images_per_prompt: int = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_embeds_mask: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds_mask: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + do_cfg_norm: bool = False, + ): + image_size = image[-1].size if isinstance(image, list) else image.size + calculated_width, calculated_height = calculate_dimensions( + 1024 * 1024, image_size[0] / image_size[1] + ) + height = height or calculated_height + width = width or calculated_width + + multiple_of = self.vae_scale_factor * 2 + width = width // multiple_of * multiple_of + height = height // multiple_of * multiple_of + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + negative_prompt_embeds_mask=negative_prompt_embeds_mask, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # 3. 
Preprocess image + if image is not None and not ( + isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels + ): + if not isinstance(image, list): + image = [image] + condition_image_sizes = [] + condition_images = [] + vae_image_sizes = [] + vae_images = [] + for img in image: + image_width, image_height = img.size + condition_width, condition_height = calculate_dimensions( + CONDITION_IMAGE_SIZE, image_width / image_height + ) + vae_width, vae_height = calculate_dimensions( + VAE_IMAGE_SIZE, image_width / image_height + ) + condition_image_sizes.append((condition_width, condition_height)) + vae_image_sizes.append((vae_width, vae_height)) + condition_images.append( + self.image_processor.resize(img, condition_height, condition_width) + ) + vae_images.append( + self.image_processor.preprocess( + img, vae_height, vae_width + ).unsqueeze(2) + ) + + has_neg_prompt = negative_prompt is not None or ( + negative_prompt_embeds is not None + and negative_prompt_embeds_mask is not None + ) + + if true_cfg_scale > 1 and not has_neg_prompt: + logger.warning( + f"true_cfg_scale is passed as {true_cfg_scale}, but classifier-free guidance is not enabled since no negative_prompt is provided." + ) + elif true_cfg_scale <= 1 and has_neg_prompt: + logger.warning( + " negative_prompt is passed but classifier-free guidance is not enabled since true_cfg_scale <= 1" + ) + + do_true_cfg = true_cfg_scale > 1 and has_neg_prompt + prompt_embeds, prompt_embeds_mask = self.encode_prompt( + image=condition_images, + prompt=prompt, + prompt_embeds=prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + if do_true_cfg: + negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt( + image=condition_images, + prompt=negative_prompt, + prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=negative_prompt_embeds_mask, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + latents, image_latents = self.prepare_latents( + vae_images, + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + img_shapes = [ + [ + ( + 1, + height // self.vae_scale_factor // 2, + width // self.vae_scale_factor // 2, + ), + *[ + ( + 1, + vae_height // self.vae_scale_factor // 2, + vae_width // self.vae_scale_factor // 2, + ) + for vae_width, vae_height in vae_image_sizes + ], + ] + ] * batch_size + + # 5. 
Prepare timesteps + sigmas = ( + np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + if sigmas is None + else sigmas + ) + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max( + len(timesteps) - num_inference_steps * self.scheduler.order, 0 + ) + self._num_timesteps = len(timesteps) + + # handle guidance + if self.transformer.config.guidance_embeds and guidance_scale is None: + raise ValueError("guidance_scale is required for guidance-distilled model.") + elif self.transformer.config.guidance_embeds: + guidance = torch.full( + [1], guidance_scale, device=device, dtype=torch.float32 + ) + guidance = guidance.expand(latents.shape[0]) + elif not self.transformer.config.guidance_embeds and guidance_scale is not None: + logger.warning( + f"guidance_scale is passed as {guidance_scale}, but ignored since the model is not guidance-distilled." + ) + guidance = None + elif not self.transformer.config.guidance_embeds and guidance_scale is None: + guidance = None + + if self.attention_kwargs is None: + self._attention_kwargs = {} + + txt_seq_lens = ( + prompt_embeds_mask.sum(dim=1).tolist() + if prompt_embeds_mask is not None + else None + ) + negative_txt_seq_lens = ( + negative_prompt_embeds_mask.sum(dim=1).tolist() + if negative_prompt_embeds_mask is not None + else None + ) + + # 6. Denoising loop + self.scheduler.set_begin_index(0) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + + latent_model_input = latents + if image_latents is not None: + latent_model_input = torch.cat([latents, image_latents], dim=1) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states_mask=prompt_embeds_mask, + encoder_hidden_states=prompt_embeds, + img_shapes=img_shapes, + txt_seq_lens=txt_seq_lens, + attention_kwargs=self.attention_kwargs, + return_dict=False, + )[0] + noise_pred = noise_pred[:, : latents.size(1)] + + if do_true_cfg: + with self.transformer.cache_context("uncond"): + neg_noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states_mask=negative_prompt_embeds_mask, + encoder_hidden_states=negative_prompt_embeds, + img_shapes=img_shapes, + txt_seq_lens=negative_txt_seq_lens, + attention_kwargs=self.attention_kwargs, + return_dict=False, + )[0] + neg_noise_pred = neg_noise_pred[:, : latents.size(1)] + comb_pred = neg_noise_pred + true_cfg_scale * ( + noise_pred - neg_noise_pred + ) + + if do_cfg_norm: + # the official code does this, but I find it hurts more often than it helps, leaving it optional but off by default + cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) + noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) + noise_pred = comb_pred * (cond_norm / noise_norm) + else: + noise_pred = comb_pred + + # compute the previous 
noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step( + noise_pred, t, latents, return_dict=False + )[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 + ): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + if output_type == "latent": + image = latents + else: + latents = self._unpack_latents( + latents, height, width, self.vae_scale_factor + ) + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view( + 1, self.vae.config.z_dim, 1, 1, 1 + ).to(latents.device, latents.dtype) + latents = latents / latents_std + latents_mean + image = self.vae.decode(latents, return_dict=False)[0][:, :, 0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return QwenImagePipelineOutput(images=image) diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index e63f899c..a75eb3b2 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -348,6 +348,7 @@ class BaseSDTrainProcess(BaseTrainProcess): ctrl_img_1=sample_item.ctrl_img_1, ctrl_img_2=sample_item.ctrl_img_2, ctrl_img_3=sample_item.ctrl_img_3, + do_cfg_norm=sample_config.do_cfg_norm, **extra_args )) diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index 75abb5ff..71871617 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -70,6 +70,8 @@ class SampleItem: print(f"Invalid network_multiplier {self.network_multiplier}, defaulting to 1.0") self.network_multiplier = 1.0 + # only for models that support it, (qwen image edit 2509 for now) + self.do_cfg_norm: bool = kwargs.get('do_cfg_norm', False) class SampleConfig: def __init__(self, **kwargs): @@ -104,6 +106,8 @@ class SampleConfig: ] raw_samples = kwargs.get('samples', default_samples_kwargs) self.samples = [SampleItem(self, **item) for item in raw_samples] + # only for models that support it, (qwen image edit 2509 for now) + self.do_cfg_norm: bool = kwargs.get('do_cfg_norm', False) @property def prompts(self): @@ -993,7 +997,8 @@ class GenerateImageConfig: ctrl_img_3: Optional[str] = None, # third control image for multi control model num_frames: int = 1, fps: int = 15, - ctrl_idx: int = 0 + ctrl_idx: int = 0, + do_cfg_norm: bool = False, ): self.width: int = width self.height: int = height @@ -1063,6 +1068,8 @@ class GenerateImageConfig: self.width = max(64, self.width - self.width % 8) # round to divisible by 8 self.logger = logger + + self.do_cfg_norm: bool = do_cfg_norm def set_gen_time(self, gen_time: int = None): if gen_time is not None: 
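
Note on the new flag threaded through the config above: `do_cfg_norm` ends up controlling only the guidance-combination step inside `QwenImageEditPlusCustomPipeline.__call__`. For reference, that step factored into a small standalone helper; this is an illustrative sketch rather than code from the patch, and the helper name plus the assumption that predictions are packed `[batch, sequence, channels]` tensors are assumptions, not part of the repository.

import torch


def combine_cfg_predictions(
    noise_pred: torch.Tensor,      # conditional prediction, assumed [batch, seq, channels]
    neg_noise_pred: torch.Tensor,  # negative/unconditional prediction, same shape
    true_cfg_scale: float,
    do_cfg_norm: bool = False,
) -> torch.Tensor:
    # true-CFG combination: push the prediction away from the negative branch
    comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
    if not do_cfg_norm:
        # default in this patch: return the raw combination unchanged
        return comb_pred
    # optional rescale so the combined prediction keeps the conditional
    # prediction's per-token norm (what the upstream pipeline always does)
    cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
    comb_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
    return comb_pred * (cond_norm / comb_norm)

The flag defaults to off because, as the comment in the pipeline notes, the normalization used by the official code tends to hurt more often than it helps.
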
diff --git a/toolkit/memory_management/__init__.py b/toolkit/memory_management/__init__.py new file mode 100644 index 00000000..2eeef37d --- /dev/null +++ b/toolkit/memory_management/__init__.py @@ -0,0 +1 @@ +from .manager import MemoryManager \ No newline at end of file diff --git a/toolkit/memory_management/manager.py b/toolkit/memory_management/manager.py new file mode 100644 index 00000000..4480ea94 --- /dev/null +++ b/toolkit/memory_management/manager.py @@ -0,0 +1,12 @@ +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from toolkit.models.base_model import BaseModel + + +class MemoryManager: + def __init__( + self, + model: "BaseModel", + ): + self.model: "BaseModel" = model diff --git a/toolkit/models/base_model.py b/toolkit/models/base_model.py index 58d48b4f..dd070249 100644 --- a/toolkit/models/base_model.py +++ b/toolkit/models/base_model.py @@ -41,6 +41,7 @@ from torchvision.transforms import functional as TF from toolkit.accelerator import get_accelerator, unwrap_model from typing import TYPE_CHECKING from toolkit.print import print_acc +from toolkit.memory_management import MemoryManager if TYPE_CHECKING: from toolkit.lora_special import LoRASpecialNetwork @@ -185,6 +186,8 @@ class BaseModel: self.has_multiple_control_images = False # do not resize control images self.use_raw_control_images = False + + self.memory_manager = MemoryManager(self) # properties for old arch for backwards compatibility @property diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index 78960ed1..8b72a1c5 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -70,6 +70,7 @@ from typing import TYPE_CHECKING from toolkit.print import print_acc from diffusers import FluxFillPipeline from transformers import AutoModel, AutoTokenizer, Gemma2Model, Qwen2Model, LlamaModel +from toolkit.memory_management import MemoryManager if TYPE_CHECKING: from toolkit.lora_special import LoRASpecialNetwork @@ -224,6 +225,8 @@ class StableDiffusion: # do not resize control images self.use_raw_control_images = False + self.memory_manager = MemoryManager(self) + # properties for old arch for backwards compatibility @property def is_xl(self):
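
For completeness, a hypothetical way to exercise the new pipeline class directly, outside the trainer. The model id `Qwen/Qwen-Image-Edit-2509`, the bfloat16 dtype, the file paths, and the assumption that the repository root is on `PYTHONPATH` are illustrative only; inside the trainer the same switch is instead driven by `do_cfg_norm` on the sample config, which `BaseSDTrainProcess` forwards through `GenerateImageConfig` as shown in the hunks above.

import torch
from PIL import Image

# assumes the repo root is on PYTHONPATH so the built-in extension package is importable
from extensions_built_in.diffusion_models.qwen_image.qwen_image_pipelines import (
    QwenImageEditPlusCustomPipeline,
)

# load the stock Qwen-Image-Edit-2509 weights into the custom subclass;
# this works because the subclass keeps the parent pipeline's components
pipe = QwenImageEditPlusCustomPipeline.from_pretrained(
    "Qwen/Qwen-Image-Edit-2509", torch_dtype=torch.bfloat16
).to("cuda")

image = Image.open("input.png").convert("RGB")  # hypothetical input path
out = pipe(
    image=[image],
    prompt="replace the background with a beach at sunset",
    negative_prompt=" ",   # true CFG only runs when a negative prompt is provided
    true_cfg_scale=4.0,
    num_inference_steps=25,
    do_cfg_norm=False,     # new switch added by this patch; off by default
).images[0]
out.save("output.png")
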