From 3413fa537f78b884d16b38b99eba55a1ae63f456 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Thu, 14 Aug 2025 13:03:27 -0600 Subject: [PATCH 1/5] Wan22 14b training is working, still need tons of testing and some bug fixes --- .../diffusion_models/__init__.py | 3 +- .../diffusion_models/wan22/__init__.py | 3 +- .../diffusion_models/wan22/wan22_14b_model.py | 470 ++++++++++++++++++ .../diffusion_models/wan22/wan22_pipeline.py | 51 +- jobs/process/BaseSDTrainProcess.py | 17 +- toolkit/config_modules.py | 3 + toolkit/models/wan21/wan21.py | 11 +- toolkit/network_mixins.py | 20 +- 8 files changed, 554 insertions(+), 24 deletions(-) create mode 100644 extensions_built_in/diffusion_models/wan22/wan22_14b_model.py diff --git a/extensions_built_in/diffusion_models/__init__.py b/extensions_built_in/diffusion_models/__init__.py index 04ff0ff3..a2d5df69 100644 --- a/extensions_built_in/diffusion_models/__init__.py +++ b/extensions_built_in/diffusion_models/__init__.py @@ -3,7 +3,7 @@ from .hidream import HidreamModel, HidreamE1Model from .f_light import FLiteModel from .omnigen2 import OmniGen2Model from .flux_kontext import FluxKontextModel -from .wan22 import Wan225bModel +from .wan22 import Wan225bModel, Wan2214bModel from .qwen_image import QwenImageModel AI_TOOLKIT_MODELS = [ @@ -15,5 +15,6 @@ AI_TOOLKIT_MODELS = [ OmniGen2Model, FluxKontextModel, Wan225bModel, + Wan2214bModel, QwenImageModel, ] diff --git a/extensions_built_in/diffusion_models/wan22/__init__.py b/extensions_built_in/diffusion_models/wan22/__init__.py index 993acbfb..5c88152b 100644 --- a/extensions_built_in/diffusion_models/wan22/__init__.py +++ b/extensions_built_in/diffusion_models/wan22/__init__.py @@ -1 +1,2 @@ -from .wan22_5b_model import Wan225bModel \ No newline at end of file +from .wan22_5b_model import Wan225bModel +from .wan22_14b_model import Wan2214bModel \ No newline at end of file diff --git a/extensions_built_in/diffusion_models/wan22/wan22_14b_model.py b/extensions_built_in/diffusion_models/wan22/wan22_14b_model.py new file mode 100644 index 00000000..3266d320 --- /dev/null +++ b/extensions_built_in/diffusion_models/wan22/wan22_14b_model.py @@ -0,0 +1,470 @@ +from functools import partial +import os +from typing import Any, Dict, Optional, Union +from typing_extensions import Self +import torch +import yaml +from toolkit.accelerator import unwrap_model +from toolkit.basic import flush +from toolkit.prompt_utils import PromptEmbeds +from PIL import Image +from diffusers import UniPCMultistepScheduler +import torch +from toolkit.config_modules import GenerateImageConfig, ModelConfig +from toolkit.samplers.custom_flowmatch_sampler import ( + CustomFlowMatchEulerDiscreteScheduler, +) +from toolkit.util.quantize import quantize_model +from .wan22_pipeline import Wan22Pipeline +from diffusers import WanTransformer3DModel + +from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO +from torchvision.transforms import functional as TF + +from toolkit.models.wan21.wan21 import AggressiveWanUnloadPipeline +from .wan22_5b_model import ( + scheduler_config, + time_text_monkeypatch, + Wan225bModel, +) +from safetensors.torch import load_file, save_file + + +boundary_ratio_t2v = 0.875 +boundary_ratio_i2v = 0.9 + +scheduler_configUniPC = { + "_class_name": "UniPCMultistepScheduler", + "_diffusers_version": "0.35.0.dev0", + "beta_end": 0.02, + "beta_schedule": "linear", + "beta_start": 0.0001, + "disable_corrector": [], + "dynamic_thresholding_ratio": 0.995, + "final_sigmas_type": "zero", + "flow_shift": 3.0, 
+ "lower_order_final": True, + "num_train_timesteps": 1000, + "predict_x0": True, + "prediction_type": "flow_prediction", + "rescale_betas_zero_snr": False, + "sample_max_value": 1.0, + "solver_order": 2, + "solver_p": None, + "solver_type": "bh2", + "steps_offset": 0, + "thresholding": False, + "time_shift_type": "exponential", + "timestep_spacing": "linspace", + "trained_betas": None, + "use_beta_sigmas": False, + "use_dynamic_shifting": False, + "use_exponential_sigmas": False, + "use_flow_sigmas": True, + "use_karras_sigmas": False, +} + + +class DualWanTransformer3DModel(torch.nn.Module): + def __init__( + self, + transformer_1: WanTransformer3DModel, + transformer_2: WanTransformer3DModel, + torch_dtype: Optional[Union[str, torch.dtype]] = None, + device: Optional[Union[str, torch.device]] = None, + boundary_ratio: float = boundary_ratio_t2v, + low_vram: bool = False, + ) -> None: + super().__init__() + self.transformer_1: WanTransformer3DModel = transformer_1 + self.transformer_2: WanTransformer3DModel = transformer_2 + self.torch_dtype: torch.dtype = torch_dtype + self.device_torch: torch.device = device + self.boundary_ratio: float = boundary_ratio + self.boundary: float = self.boundary_ratio * 1000 + self.low_vram: bool = low_vram + self._active_transformer_name = "transformer_1" # default to transformer_1 + + @property + def device(self) -> torch.device: + return self.device_torch + + @property + def dtype(self) -> torch.dtype: + return self.torch_dtype + + @property + def config(self): + return self.transformer_1.config + + @property + def transformer(self) -> WanTransformer3DModel: + return getattr(self, self._active_transformer_name) + + def enable_gradient_checkpointing(self): + """ + Enable gradient checkpointing for both transformers. + """ + self.transformer_1.enable_gradient_checkpointing() + self.transformer_2.enable_gradient_checkpointing() + + def forward( + self, + hidden_states: torch.Tensor, + timestep: torch.LongTensor, + encoder_hidden_states: torch.Tensor, + encoder_hidden_states_image: Optional[torch.Tensor] = None, + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: + # determine if doing high noise or low noise by meaning the timestep. 
+ # timesteps are in the range of 0 to 1000, so we can use a threshold + with torch.no_grad(): + if timestep.float().mean().item() >= self.boundary: + t_name = "transformer_1" + else: + t_name = "transformer_2" + + # check if we are changing the active transformer, if so, we need to swap the one in + # vram if low_vram is enabled + # todo swap the loras as well + if t_name != self._active_transformer_name: + if self.low_vram: + getattr(self, self._active_transformer_name).to("cpu") + getattr(self, t_name).to(self.device_torch) + torch.cuda.empty_cache() + self._active_transformer_name = t_name + if self.transformer.device != hidden_states.device: + raise ValueError( + f"Transformer device {self.transformer.device} does not match hidden states device {hidden_states.device}" + ) + + return self.transformer( + hidden_states=hidden_states, + timestep=timestep, + encoder_hidden_states=encoder_hidden_states, + encoder_hidden_states_image=encoder_hidden_states_image, + return_dict=return_dict, + attention_kwargs=attention_kwargs, + ) + + def to(self, *args, **kwargs) -> Self: + # do not do to, this will be handled separately + return self + + +class Wan2214bModel(Wan225bModel): + arch = "wan22_14b" + _wan_generation_scheduler_config = scheduler_configUniPC + _wan_expand_timesteps = True + _wan_vae_path = "ai-toolkit/wan2.1-vae" + + def __init__( + self, + device, + model_config: ModelConfig, + dtype="bf16", + custom_pipeline=None, + noise_scheduler=None, + **kwargs, + ): + super().__init__( + device=device, + model_config=model_config, + dtype=dtype, + custom_pipeline=custom_pipeline, + noise_scheduler=noise_scheduler, + **kwargs, + ) + # target it so we can target both transformers + self.target_lora_modules = ["DualWanTransformer3DModel"] + self._wan_cache = None + + @property + def max_step_saves_to_keep_multiplier(self): + # the cleanup mechanism checks this to see how many saves to keep + # if we are training a LoRA, we need to set this to 2 so we keep both the high noise and low noise LoRAs at saves to keep + if self.network is not None: + return 2 + return 1 + + def load_model(self): + # load model from patent parent. Wan21 not immediate parent + # super().load_model() + super(Wan225bModel, self).load_model() + + # we have to split up the model on the pipeline + self.pipeline.transformer = self.model.transformer_1 + self.pipeline.transformer_2 = self.model.transformer_2 + + # patch the condition embedder + self.model.transformer_1.condition_embedder.forward = partial( + time_text_monkeypatch, self.model.transformer_1.condition_embedder + ) + self.model.transformer_2.condition_embedder.forward = partial( + time_text_monkeypatch, self.model.transformer_2.condition_embedder + ) + + def get_bucket_divisibility(self): + # 16x compression and 2x2 patch size + return 32 + + def load_wan_transformer(self, transformer_path, subfolder=None): + if self.model_config.split_model_over_gpus: + raise ValueError( + "Splitting model over gpus is not supported for Wan2.2 models" + ) + + if ( + self.model_config.assistant_lora_path is not None + or self.model_config.inference_lora_path is not None + ): + raise ValueError( + "Assistant LoRA is not supported for Wan2.2 models currently" + ) + + if self.model_config.lora_path is not None: + raise ValueError( + "Loading LoRA is not supported for Wan2.2 models currently" + ) + + # transformer path can be a directory that ends with /transformer or a hf path. 
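+        # transformer_1 loads from the given path/subfolder; transformer_2 is resolved
+        # below as a sibling "transformer_2" directory (local path) or subfolder (HF repo)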
+ + transformer_path_1 = transformer_path + subfolder_1 = subfolder + + transformer_path_2 = transformer_path + subfolder_2 = subfolder + + if subfolder_2 is None: + # we have a local path, replace it with transformer_2 folder + transformer_path_2 = os.path.join( + os.path.dirname(transformer_path_1), "transformer_2" + ) + else: + # we have a hf path, replace it with transformer_2 subfolder + subfolder_2 = "transformer_2" + + self.print_and_status_update("Loading transformer 1") + dtype = self.torch_dtype + transformer_1 = WanTransformer3DModel.from_pretrained( + transformer_path_1, + subfolder=subfolder_1, + torch_dtype=dtype, + ).to(dtype=dtype) + + flush() + + if not self.model_config.low_vram: + # quantize on the device + transformer_1.to(self.quantize_device, dtype=dtype) + flush() + + if self.model_config.quantize: + # todo handle two ARAs + self.print_and_status_update("Quantizing Transformer 1") + quantize_model(self, transformer_1) + flush() + + if self.model_config.low_vram: + self.print_and_status_update("Moving transformer 1 to CPU") + transformer_1.to("cpu") + + self.print_and_status_update("Loading transformer 2") + dtype = self.torch_dtype + transformer_2 = WanTransformer3DModel.from_pretrained( + transformer_path_2, + subfolder=subfolder_2, + torch_dtype=dtype, + ).to(dtype=dtype) + + flush() + + if not self.model_config.low_vram: + # quantize on the device + transformer_2.to(self.quantize_device, dtype=dtype) + flush() + + if self.model_config.quantize: + # todo handle two ARAs + self.print_and_status_update("Quantizing Transformer 2") + quantize_model(self, transformer_2) + flush() + + if self.model_config.low_vram: + self.print_and_status_update("Moving transformer 2 to CPU") + transformer_2.to("cpu") + + # make the combined model + self.print_and_status_update("Creating DualWanTransformer3DModel") + transformer = DualWanTransformer3DModel( + transformer_1=transformer_1, + transformer_2=transformer_2, + torch_dtype=self.torch_dtype, + device=self.device_torch, + boundary_ratio=boundary_ratio_t2v, + low_vram=self.model_config.low_vram, + ) + + return transformer + + def get_generation_pipeline(self): + scheduler = UniPCMultistepScheduler(**self._wan_generation_scheduler_config) + pipeline = Wan22Pipeline( + vae=self.vae, + transformer=self.model.transformer_1, + transformer_2=self.model.transformer_2, + text_encoder=self.text_encoder, + tokenizer=self.tokenizer, + scheduler=scheduler, + expand_timesteps=self._wan_expand_timesteps, + device=self.device_torch, + aggressive_offload=self.model_config.low_vram, + # todo detect if it is i2v or t2v + boundary_ratio=boundary_ratio_t2v, + ) + + # pipeline = pipeline.to(self.device_torch) + + return pipeline + + # static method to get the scheduler + @staticmethod + def get_train_scheduler(): + scheduler = CustomFlowMatchEulerDiscreteScheduler(**scheduler_config) + return scheduler + + def get_base_model_version(self): + return "wan_2.2_14b" + + def generate_single_image( + self, + pipeline: AggressiveWanUnloadPipeline, + gen_config: GenerateImageConfig, + conditional_embeds: PromptEmbeds, + unconditional_embeds: PromptEmbeds, + generator: torch.Generator, + extra: dict, + ): + return super().generate_single_image( + pipeline=pipeline, + gen_config=gen_config, + conditional_embeds=conditional_embeds, + unconditional_embeds=unconditional_embeds, + generator=generator, + extra=extra, + ) + + def get_noise_prediction( + self, + latent_model_input: torch.Tensor, + timestep: torch.Tensor, # 0 to 1000 scale + text_embeddings: 
PromptEmbeds, + batch: DataLoaderBatchDTO, + **kwargs, + ): + # todo do we need to override this? Adjust timesteps? + return super().get_noise_prediction( + latent_model_input=latent_model_input, + timestep=timestep, + text_embeddings=text_embeddings, + batch=batch, + **kwargs, + ) + + def get_model_has_grad(self): + return False + + def get_te_has_grad(self): + return False + + def save_model(self, output_path, meta, save_dtype): + transformer_combo: DualWanTransformer3DModel = unwrap_model(self.model) + transformer_combo.transformer_1.save_pretrained( + save_directory=os.path.join(output_path, "transformer"), + safe_serialization=True, + ) + transformer_combo.transformer_2.save_pretrained( + save_directory=os.path.join(output_path, "transformer_2"), + safe_serialization=True, + ) + + meta_path = os.path.join(output_path, "aitk_meta.yaml") + with open(meta_path, "w") as f: + yaml.dump(meta, f) + + def save_lora( + self, + state_dict: Dict[str, torch.Tensor], + output_path: str, + metadata: Optional[Dict[str, Any]] = None, + ): + if not self.network.network_config.split_multistage_loras: + # just save as a combo lora + save_file(state_dict, output_path, metadata=metadata) + return + + # we need to build out both dictionaries for high and low noise LoRAs + high_noise_lora = {} + low_noise_lora = {} + + for key in state_dict: + if ".transformer_1." in key: + # this is a high noise LoRA + new_key = key.replace(".transformer_1.", ".") + high_noise_lora[new_key] = state_dict[key] + elif ".transformer_2." in key: + # this is a low noise LoRA + new_key = key.replace(".transformer_2.", ".") + low_noise_lora[new_key] = state_dict[key] + + # loras have either LORA_MODEL_NAME_000005000.safetensors or LORA_MODEL_NAME.safetensors + if len(high_noise_lora.keys()) > 0: + # save the high noise LoRA + high_noise_lora_path = output_path.replace( + ".safetensors", "_high_noise.safetensors" + ) + save_file(high_noise_lora, high_noise_lora_path, metadata=metadata) + + if len(low_noise_lora.keys()) > 0: + # save the low noise LoRA + low_noise_lora_path = output_path.replace( + ".safetensors", "_low_noise.safetensors" + ) + save_file(low_noise_lora, low_noise_lora_path, metadata=metadata) + + def load_lora(self, file: str): + # if it doesnt have high_noise or low_noise, it is a combo LoRA + if "_high_noise.safetensors" not in file and "_low_noise.safetensors" not in file: + # this is a combined LoRA, we need to split it up + sd = load_file(file) + return sd + + # we may have been passed the high_noise or the low_noise LoRA path, but we need to load both + high_noise_lora_path = file.replace( + "_low_noise.safetensors", "_high_noise.safetensors" + ) + low_noise_lora_path = file.replace( + "_high_noise.safetensors", "_low_noise.safetensors" + ) + + combined_dict = {} + + if os.path.exists(high_noise_lora_path): + # load the high noise LoRA + high_noise_lora = load_file(high_noise_lora_path) + for key in high_noise_lora: + new_key = key.replace( + "diffusion_model.", "diffusion_model.transformer_1." + ) + combined_dict[new_key] = high_noise_lora[key] + if os.path.exists(low_noise_lora_path): + # load the low noise LoRA + low_noise_lora = load_file(low_noise_lora_path) + for key in low_noise_lora: + new_key = key.replace( + "diffusion_model.", "diffusion_model.transformer_2." 
+ ) + combined_dict[new_key] = low_noise_lora[key] + + return combined_dict diff --git a/extensions_built_in/diffusion_models/wan22/wan22_pipeline.py b/extensions_built_in/diffusion_models/wan22/wan22_pipeline.py index 6c7b93b7..dafa2012 100644 --- a/extensions_built_in/diffusion_models/wan22/wan22_pipeline.py +++ b/extensions_built_in/diffusion_models/wan22/wan22_pipeline.py @@ -12,7 +12,6 @@ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback from typing import Any, Callable, Dict, List, Optional, Union - class Wan22Pipeline(WanPipeline): def __init__( self, @@ -52,6 +51,7 @@ class Wan22Pipeline(WanPipeline): num_frames: int = 81, num_inference_steps: int = 50, guidance_scale: float = 5.0, + guidance_scale_2: Optional[float] = None, num_videos_per_prompt: Optional[int] = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, @@ -77,11 +77,15 @@ class Wan22Pipeline(WanPipeline): vae_device = self.vae.device transformer_device = self.transformer.device text_encoder_device = self.text_encoder.device - device = self.transformer.device + device = self._exec_device if self._aggressive_offload: print("Unloading vae") self.vae.to("cpu") + print("Unloading transformer") + self.transformer.to("cpu") + if self.transformer_2 is not None: + self.transformer_2.to("cpu") self.text_encoder.to(device) flush() @@ -95,9 +99,14 @@ class Wan22Pipeline(WanPipeline): prompt_embeds, negative_prompt_embeds, callback_on_step_end_tensor_inputs, + guidance_scale_2 ) + + if self.config.boundary_ratio is not None and guidance_scale_2 is None: + guidance_scale_2 = guidance_scale self._guidance_scale = guidance_scale + self._guidance_scale_2 = guidance_scale_2 self._attention_kwargs = attention_kwargs self._current_timestep = None self._interrupt = False @@ -160,6 +169,13 @@ class Wan22Pipeline(WanPipeline): num_warmup_steps = len(timesteps) - \ num_inference_steps * self.scheduler.order self._num_timesteps = len(timesteps) + + if self.config.boundary_ratio is not None: + boundary_timestep = self.config.boundary_ratio * self.scheduler.config.num_train_timesteps + else: + boundary_timestep = None + + current_model = self.transformer with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): @@ -167,6 +183,25 @@ class Wan22Pipeline(WanPipeline): continue self._current_timestep = t + + if boundary_timestep is None or t >= boundary_timestep: + if self._aggressive_offload and current_model != self.transformer: + if self.transformer_2 is not None: + self.transformer_2.to("cpu") + self.transformer.to(device) + # wan2.1 or high-noise stage in wan2.2 + current_model = self.transformer + current_guidance_scale = guidance_scale + else: + if self._aggressive_offload and current_model != self.transformer_2: + if self.transformer is not None: + self.transformer.to("cpu") + if self.transformer_2 is not None: + self.transformer_2.to(device) + # low-noise stage in wan2.2 + current_model = self.transformer_2 + current_guidance_scale = guidance_scale_2 + latent_model_input = latents.to(device, transformer_dtype) if self.config.expand_timesteps: # seq_len: num_latent_frames * latent_height//2 * latent_width//2 @@ -176,7 +211,7 @@ class Wan22Pipeline(WanPipeline): else: timestep = t.expand(latents.shape[0]) - noise_pred = self.transformer( + noise_pred = current_model( hidden_states=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds, @@ -185,14 +220,14 @@ class Wan22Pipeline(WanPipeline): )[0] if self.do_classifier_free_guidance: - 
noise_uncond = self.transformer( + noise_uncond = current_model( hidden_states=latent_model_input, timestep=timestep, encoder_hidden_states=negative_prompt_embeds, attention_kwargs=attention_kwargs, return_dict=False, )[0] - noise_pred = noise_uncond + guidance_scale * \ + noise_pred = noise_uncond + current_guidance_scale * \ (noise_pred - noise_uncond) # compute the previous noisy sample x_t -> x_t-1 @@ -259,10 +294,8 @@ class Wan22Pipeline(WanPipeline): # move transformer back to device if self._aggressive_offload: - print("Moving transformer back to device") - self.transformer.to(self._execution_device) - if self.transformer_2 is not None: - self.transformer_2.to(self._execution_device) + # print("Moving transformer back to device") + # self.transformer.to(self._execution_device) flush() if not return_dict: diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index 851237cb..daf3d33c 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -437,19 +437,24 @@ class BaseSDTrainProcess(BaseTrainProcess): # Combine and sort the lists combined_items = safetensors_files + directories + pt_files combined_items.sort(key=os.path.getctime) + + num_saves_to_keep = self.save_config.max_step_saves_to_keep + + if hasattr(self.sd, 'max_step_saves_to_keep_multiplier'): + num_saves_to_keep *= self.sd.max_step_saves_to_keep_multiplier # Use slicing with a check to avoid 'NoneType' error safetensors_to_remove = safetensors_files[ - :-self.save_config.max_step_saves_to_keep] if safetensors_files else [] - pt_files_to_remove = pt_files[:-self.save_config.max_step_saves_to_keep] if pt_files else [] - directories_to_remove = directories[:-self.save_config.max_step_saves_to_keep] if directories else [] - embeddings_to_remove = embed_files[:-self.save_config.max_step_saves_to_keep] if embed_files else [] - critic_to_remove = critic_items[:-self.save_config.max_step_saves_to_keep] if critic_items else [] + :-num_saves_to_keep] if safetensors_files else [] + pt_files_to_remove = pt_files[:-num_saves_to_keep] if pt_files else [] + directories_to_remove = directories[:-num_saves_to_keep] if directories else [] + embeddings_to_remove = embed_files[:-num_saves_to_keep] if embed_files else [] + critic_to_remove = critic_items[:-num_saves_to_keep] if critic_items else [] items_to_remove = safetensors_to_remove + pt_files_to_remove + directories_to_remove + embeddings_to_remove + critic_to_remove # remove all but the latest max_step_saves_to_keep - # items_to_remove = combined_items[:-self.save_config.max_step_saves_to_keep] + # items_to_remove = combined_items[:-num_saves_to_keep] # remove duplicates items_to_remove = list(dict.fromkeys(items_to_remove)) diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index a825830b..81415a10 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -185,6 +185,9 @@ class NetworkConfig: self.conv_alpha = 9999999999 # -1 automatically finds the largest factor self.lokr_factor = kwargs.get('lokr_factor', -1) + + # for multi stage models + self.split_multistage_loras = kwargs.get('split_multistage_loras', True) AdapterTypes = Literal['t2i', 'ip', 'ip+', 'clip', 'ilora', 'photo_maker', 'control_net', 'control_lora', 'i2v'] diff --git a/toolkit/models/wan21/wan21.py b/toolkit/models/wan21/wan21.py index bdc2f601..ecdc8f3f 100644 --- a/toolkit/models/wan21/wan21.py +++ b/toolkit/models/wan21/wan21.py @@ -310,6 +310,7 @@ class Wan21(BaseModel): arch = 'wan21' _wan_generation_scheduler_config 
= scheduler_configUniPC _wan_expand_timesteps = False + _wan_vae_path = None _comfy_te_file = ['text_encoders/umt5_xxl_fp16.safetensors', 'text_encoders/umt5_xxl_fp8_e4m3fn_scaled.safetensors'] def __init__( @@ -431,8 +432,14 @@ class Wan21(BaseModel): scheduler = Wan21.get_train_scheduler() self.print_and_status_update("Loading VAE") # todo, example does float 32? check if quality suffers - vae = AutoencoderKLWan.from_pretrained( - vae_path, subfolder="vae", torch_dtype=dtype).to(dtype=dtype) + + if self._wan_vae_path is not None: + # load the vae from individual repo + vae = AutoencoderKLWan.from_pretrained( + self._wan_vae_path, torch_dtype=dtype).to(dtype=dtype) + else: + vae = AutoencoderKLWan.from_pretrained( + vae_path, subfolder="vae", torch_dtype=dtype).to(dtype=dtype) flush() self.print_and_status_update("Making pipe") diff --git a/toolkit/network_mixins.py b/toolkit/network_mixins.py index 796e4042..59c15d3d 100644 --- a/toolkit/network_mixins.py +++ b/toolkit/network_mixins.py @@ -565,6 +565,13 @@ class ToolkitNetworkMixin: if metadata is None: metadata = OrderedDict() metadata = add_model_hash_to_meta(save_dict, metadata) + # let the model handle the saving + + if self.base_model_ref is not None and hasattr(self.base_model_ref(), 'save_lora'): + # call the base model save lora method + self.base_model_ref().save_lora(save_dict, file, metadata) + return + if os.path.splitext(file)[1] == ".safetensors": from safetensors.torch import save_file save_file(save_dict, file, metadata) @@ -577,12 +584,15 @@ class ToolkitNetworkMixin: keymap = {} if keymap is None else keymap if isinstance(file, str): - if os.path.splitext(file)[1] == ".safetensors": - from safetensors.torch import load_file - - weights_sd = load_file(file) + if self.base_model_ref is not None and hasattr(self.base_model_ref(), 'load_lora'): + # call the base model load lora method + weights_sd = self.base_model_ref().load_lora(file) else: - weights_sd = torch.load(file, map_location="cpu") + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location="cpu") else: # probably a state dict weights_sd = file From 1c96b95617f6e2dd927c05fb8931842baeb987a7 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Thu, 14 Aug 2025 14:24:41 -0600 Subject: [PATCH 2/5] Fix issue where sometimes the transformer does not get loaded properly. 
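
Previously, DualWanTransformer3DModel.forward raised a ValueError when the
active transformer's device did not match the incoming hidden states. The
forward pass now moves the active transformer onto the input's device
instead, offloading the inactive one first when low_vram is enabled. The
relevant new logic (excerpted and condensed from the hunk below, same
attribute names):

    if self.transformer.device != hidden_states.device:
        if self.low_vram:
            # keep only the active transformer on the GPU
            other_tname = (
                "transformer_1" if t_name == "transformer_2" else "transformer_2"
            )
            getattr(self, other_tname).to("cpu")
        self.transformer.to(hidden_states.device)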
--- .../diffusion_models/wan22/wan22_14b_model.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/extensions_built_in/diffusion_models/wan22/wan22_14b_model.py b/extensions_built_in/diffusion_models/wan22/wan22_14b_model.py index 3266d320..79886ea4 100644 --- a/extensions_built_in/diffusion_models/wan22/wan22_14b_model.py +++ b/extensions_built_in/diffusion_models/wan22/wan22_14b_model.py @@ -134,10 +134,14 @@ class DualWanTransformer3DModel(torch.nn.Module): getattr(self, t_name).to(self.device_torch) torch.cuda.empty_cache() self._active_transformer_name = t_name + if self.transformer.device != hidden_states.device: - raise ValueError( - f"Transformer device {self.transformer.device} does not match hidden states device {hidden_states.device}" - ) + if self.low_vram: + # move other transformer to cpu + other_tname = 'transformer_1' if t_name == 'transformer_2' else 'transformer_2' + getattr(self, other_tname).to("cpu") + + self.transformer.to(hidden_states.device) return self.transformer( hidden_states=hidden_states, From ca7bfa414b68b2bf22dc87c74b712a7cd8616e2d Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Sat, 16 Aug 2025 05:27:38 -0600 Subject: [PATCH 3/5] Increase max number of samples to 40 --- ui/src/components/SampleImages.tsx | 42 +++++++++++++++++++++++++++++- 1 file changed, 41 insertions(+), 1 deletion(-) diff --git a/ui/src/components/SampleImages.tsx b/ui/src/components/SampleImages.tsx index 6f134ad4..8f4b80d9 100644 --- a/ui/src/components/SampleImages.tsx +++ b/ui/src/components/SampleImages.tsx @@ -27,7 +27,7 @@ export default function SampleImages({ job }: SampleImagesProps) { // This way Tailwind can properly generate the class // I hate this, but it's the only way to make it work const gridColsClass = useMemo(() => { - const cols = Math.min(numSamples, 20); + const cols = Math.min(numSamples, 40); switch (cols) { case 1: @@ -70,6 +70,46 @@ export default function SampleImages({ job }: SampleImagesProps) { return 'grid-cols-19'; case 20: return 'grid-cols-20'; + case 21: + return 'grid-cols-21'; + case 22: + return 'grid-cols-22'; + case 23: + return 'grid-cols-23'; + case 24: + return 'grid-cols-24'; + case 25: + return 'grid-cols-25'; + case 26: + return 'grid-cols-26'; + case 27: + return 'grid-cols-27'; + case 28: + return 'grid-cols-28'; + case 29: + return 'grid-cols-29'; + case 30: + return 'grid-cols-30'; + case 31: + return 'grid-cols-31'; + case 32: + return 'grid-cols-32'; + case 33: + return 'grid-cols-33'; + case 34: + return 'grid-cols-34'; + case 35: + return 'grid-cols-35'; + case 36: + return 'grid-cols-36'; + case 37: + return 'grid-cols-37'; + case 38: + return 'grid-cols-38'; + case 39: + return 'grid-cols-39'; + case 40: + return 'grid-cols-40'; default: return 'grid-cols-1'; } From 8ea2cf00f65c990c152d961c158a239954dc7025 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Sat, 16 Aug 2025 05:51:37 -0600 Subject: [PATCH 4/5] Added training to the ui. Still testing, but everything seems to be working. 
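
Exposes the Wan2.2 14b multistage options in the UI and wires stage/boundary
switching into the training loop. For reference, a hypothetical config
fragment using the new keys added in this patch (key names are taken from the
diff; the nesting mirrors the UI's setJobConfig paths and is illustrative
only, not an exhaustive config):

    process_config = {
        "model": {
            "arch": "wan22_14b",
            "model_kwargs": {
                "train_high_noise": True,   # train transformer_1 (timesteps ~1000-875)
                "train_low_noise": True,    # train transformer_2 (timesteps ~875-0)
            },
        },
        "train": {
            "switch_boundary_every": 1,     # steps before switching between trainable stages
        },
        "network": {
            "split_multistage_loras": True, # save separate *_high_noise / *_low_noise LoRA files
        },
    }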
--- .../diffusion_models/wan22/wan22_14b_model.py | 98 ++++++++++++++++--- extensions_built_in/sd_trainer/SDTrainer.py | 15 ++- jobs/process/BaseSDTrainProcess.py | 31 +++++- toolkit/config_modules.py | 5 +- toolkit/models/base_model.py | 9 ++ toolkit/stable_diffusion_model.py | 9 ++ ui/src/app/jobs/new/SimpleJob.tsx | 58 +++++++++-- ui/src/app/jobs/new/jobConfig.ts | 1 + ui/src/app/jobs/new/options.ts | 46 ++++++++- ui/src/components/formInputs.tsx | 4 +- ui/src/docs.tsx | 30 ++++++ ui/src/types.ts | 1 + 12 files changed, 268 insertions(+), 39 deletions(-) diff --git a/extensions_built_in/diffusion_models/wan22/wan22_14b_model.py b/extensions_built_in/diffusion_models/wan22/wan22_14b_model.py index 79886ea4..a1b75d49 100644 --- a/extensions_built_in/diffusion_models/wan22/wan22_14b_model.py +++ b/extensions_built_in/diffusion_models/wan22/wan22_14b_model.py @@ -1,6 +1,6 @@ from functools import partial import os -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Optional, Union, List from typing_extensions import Self import torch import yaml @@ -134,13 +134,15 @@ class DualWanTransformer3DModel(torch.nn.Module): getattr(self, t_name).to(self.device_torch) torch.cuda.empty_cache() self._active_transformer_name = t_name - + if self.transformer.device != hidden_states.device: if self.low_vram: # move other transformer to cpu - other_tname = 'transformer_1' if t_name == 'transformer_2' else 'transformer_2' + other_tname = ( + "transformer_1" if t_name == "transformer_2" else "transformer_2" + ) getattr(self, other_tname).to("cpu") - + self.transformer.to(hidden_states.device) return self.transformer( @@ -184,11 +186,33 @@ class Wan2214bModel(Wan225bModel): self.target_lora_modules = ["DualWanTransformer3DModel"] self._wan_cache = None + self.is_multistage = True + # multistage boundaries split the models up when sampling timesteps + # for wan 2.2 14b. 
the timesteps are 1000-875 for transformer 1 and 875-0 for transformer 2 + self.multistage_boundaries: List[float] = [0.875, 0.0] + + self.train_high_noise = model_config.model_kwargs.get("train_high_noise", True) + self.train_low_noise = model_config.model_kwargs.get("train_low_noise", True) + + self.trainable_multistage_boundaries: List[int] = [] + if self.train_high_noise: + self.trainable_multistage_boundaries.append(0) + if self.train_low_noise: + self.trainable_multistage_boundaries.append(1) + + if len(self.trainable_multistage_boundaries) == 0: + raise ValueError( + "At least one of train_high_noise or train_low_noise must be True in model.model_kwargs" + ) + @property def max_step_saves_to_keep_multiplier(self): # the cleanup mechanism checks this to see how many saves to keep # if we are training a LoRA, we need to set this to 2 so we keep both the high noise and low noise LoRAs at saves to keep - if self.network is not None: + if ( + self.network is not None + and self.network.network_config.split_multistage_loras + ): return 2 return 1 @@ -264,7 +288,7 @@ class Wan2214bModel(Wan225bModel): transformer_1.to(self.quantize_device, dtype=dtype) flush() - if self.model_config.quantize: + if self.model_config.quantize and self.model_config.accuracy_recovery_adapter is None: # todo handle two ARAs self.print_and_status_update("Quantizing Transformer 1") quantize_model(self, transformer_1) @@ -289,7 +313,7 @@ class Wan2214bModel(Wan225bModel): transformer_2.to(self.quantize_device, dtype=dtype) flush() - if self.model_config.quantize: + if self.model_config.quantize and self.model_config.accuracy_recovery_adapter is None: # todo handle two ARAs self.print_and_status_update("Quantizing Transformer 2") quantize_model(self, transformer_2) @@ -309,7 +333,13 @@ class Wan2214bModel(Wan225bModel): boundary_ratio=boundary_ratio_t2v, low_vram=self.model_config.low_vram, ) - + + if self.model_config.quantize and self.model_config.accuracy_recovery_adapter is not None: + # apply the accuracy recovery adapter to both transformers + self.print_and_status_update("Applying Accuracy Recovery Adapter to Transformers") + quantize_model(self, transformer) + flush() + return transformer def get_generation_pipeline(self): @@ -407,17 +437,20 @@ class Wan2214bModel(Wan225bModel): # just save as a combo lora save_file(state_dict, output_path, metadata=metadata) return - + # we need to build out both dictionaries for high and low noise LoRAs high_noise_lora = {} low_noise_lora = {} + + only_train_high_noise = self.train_high_noise and not self.train_low_noise + only_train_low_noise = self.train_low_noise and not self.train_high_noise for key in state_dict: - if ".transformer_1." in key: + if ".transformer_1." in key or only_train_high_noise: # this is a high noise LoRA new_key = key.replace(".transformer_1.", ".") high_noise_lora[new_key] = state_dict[key] - elif ".transformer_2." in key: + elif ".transformer_2." 
in key or only_train_low_noise: # this is a low noise LoRA new_key = key.replace(".transformer_2.", ".") low_noise_lora[new_key] = state_dict[key] @@ -439,11 +472,14 @@ class Wan2214bModel(Wan225bModel): def load_lora(self, file: str): # if it doesnt have high_noise or low_noise, it is a combo LoRA - if "_high_noise.safetensors" not in file and "_low_noise.safetensors" not in file: - # this is a combined LoRA, we need to split it up + if ( + "_high_noise.safetensors" not in file + and "_low_noise.safetensors" not in file + ): + # this is a combined LoRA, we dont need to split it up sd = load_file(file) return sd - + # we may have been passed the high_noise or the low_noise LoRA path, but we need to load both high_noise_lora_path = file.replace( "_low_noise.safetensors", "_high_noise.safetensors" @@ -454,7 +490,7 @@ class Wan2214bModel(Wan225bModel): combined_dict = {} - if os.path.exists(high_noise_lora_path): + if os.path.exists(high_noise_lora_path) and self.train_high_noise: # load the high noise LoRA high_noise_lora = load_file(high_noise_lora_path) for key in high_noise_lora: @@ -462,7 +498,7 @@ class Wan2214bModel(Wan225bModel): "diffusion_model.", "diffusion_model.transformer_1." ) combined_dict[new_key] = high_noise_lora[key] - if os.path.exists(low_noise_lora_path): + if os.path.exists(low_noise_lora_path) and self.train_low_noise: # load the low noise LoRA low_noise_lora = load_file(low_noise_lora_path) for key in low_noise_lora: @@ -470,5 +506,35 @@ class Wan2214bModel(Wan225bModel): "diffusion_model.", "diffusion_model.transformer_2." ) combined_dict[new_key] = low_noise_lora[key] + + # if we are not training both stages, we wont have transformer designations in the keys + if not self.train_high_noise and not self.train_low_noise: + new_dict = {} + for key in combined_dict: + if ".transformer_1." in key: + new_key = key.replace(".transformer_1.", ".") + elif ".transformer_2." in key: + new_key = key.replace(".transformer_2.", ".") + else: + new_key = key + new_dict[new_key] = combined_dict[key] + combined_dict = new_dict return combined_dict + + def get_model_to_train(self): + # todo, loras wont load right unless they have the transformer_1 or transformer_2 in the key. + # called when setting up the LoRA. We only need to get the model for the stages we want to train. 
+ if self.train_high_noise and self.train_low_noise: + # we are training both stages, return the unified model + return self.model + elif self.train_high_noise: + # we are only training the high noise stage, return transformer_1 + return self.model.transformer_1 + elif self.train_low_noise: + # we are only training the low noise stage, return transformer_2 + return self.model.transformer_2 + else: + raise ValueError( + "At least one of train_high_noise or train_low_noise must be True in model.model_kwargs" + ) diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index 22fd465a..8ca654af 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -1862,7 +1862,20 @@ class SDTrainer(BaseSDTrainProcess): total_loss = None self.optimizer.zero_grad() for batch in batch_list: + if self.sd.is_multistage: + # handle multistage switching + if self.steps_this_boundary >= self.train_config.switch_boundary_every: + # iterate to make sure we only train trainable_multistage_boundaries + while True: + self.steps_this_boundary = 0 + self.current_boundary_index += 1 + if self.current_boundary_index >= len(self.sd.multistage_boundaries): + self.current_boundary_index = 0 + if self.current_boundary_index in self.sd.trainable_multistage_boundaries: + # if this boundary is trainable, we can stop looking + break loss = self.train_single_accumulation(batch) + self.steps_this_boundary += 1 if total_loss is None: total_loss = loss else: @@ -1907,7 +1920,7 @@ class SDTrainer(BaseSDTrainProcess): self.adapter.restore_embeddings() loss_dict = OrderedDict( - {'loss': loss.item()} + {'loss': (total_loss / len(batch_list)).item()} ) self.end_of_training_loop() diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index daf3d33c..80e2e226 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -260,6 +260,9 @@ class BaseSDTrainProcess(BaseTrainProcess): torch.profiler.ProfilerActivity.CUDA, ], ) + + self.current_boundary_index = 0 + self.steps_this_boundary = 0 def post_process_generate_image_config_list(self, generate_image_config_list: List[GenerateImageConfig]): # override in subclass @@ -1171,6 +1174,24 @@ class BaseSDTrainProcess(BaseTrainProcess): self.sd.noise_scheduler.set_timesteps( num_train_timesteps, device=self.device_torch ) + if self.sd.is_multistage: + with self.timer('adjust_multistage_timesteps'): + # get our current sample range + boundaries = [1000] + self.sd.multistage_boundaries + boundary_max, boundary_min = boundaries[self.current_boundary_index], boundaries[self.current_boundary_index + 1] + lo = torch.searchsorted(self.sd.noise_scheduler.timesteps, -torch.tensor(boundary_max, device=self.sd.noise_scheduler.timesteps.device), right=False) + hi = torch.searchsorted(self.sd.noise_scheduler.timesteps, -torch.tensor(boundary_min, device=self.sd.noise_scheduler.timesteps.device), right=True) + first_idx = lo.item() if hi > lo else 0 + last_idx = (hi - 1).item() if hi > lo else 999 + + min_noise_steps = first_idx + max_noise_steps = last_idx + + # clip min max indicies + min_noise_steps = max(min_noise_steps, 0) + max_noise_steps = min(max_noise_steps, num_train_timesteps - 1) + + with self.timer('prepare_timesteps_indices'): content_or_style = self.train_config.content_or_style @@ -1209,11 +1230,11 @@ class BaseSDTrainProcess(BaseTrainProcess): 0, self.train_config.num_train_timesteps - 1, min_noise_steps, - max_noise_steps - 1 + 
max_noise_steps ) timestep_indices = timestep_indices.long().clamp( - min_noise_steps + 1, - max_noise_steps - 1 + min_noise_steps, + max_noise_steps ) elif content_or_style == 'balanced': @@ -1226,7 +1247,7 @@ class BaseSDTrainProcess(BaseTrainProcess): if self.train_config.noise_scheduler == 'flowmatch': # flowmatch uses indices, so we need to use indices min_idx = 0 - max_idx = max_noise_steps - 1 + max_idx = max_noise_steps timestep_indices = torch.randint( min_idx, max_idx, @@ -1676,7 +1697,7 @@ class BaseSDTrainProcess(BaseTrainProcess): self.network = NetworkClass( text_encoder=text_encoder, - unet=unet, + unet=self.sd.get_model_to_train(), lora_dim=self.network_config.linear, multiplier=1.0, alpha=self.network_config.linear_alpha, diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index 81415a10..403abc54 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -335,7 +335,7 @@ class TrainConfig: self.lr_scheduler = kwargs.get('lr_scheduler', 'constant') self.lr_scheduler_params = kwargs.get('lr_scheduler_params', {}) self.min_denoising_steps: int = kwargs.get('min_denoising_steps', 0) - self.max_denoising_steps: int = kwargs.get('max_denoising_steps', 1000) + self.max_denoising_steps: int = kwargs.get('max_denoising_steps', 999) self.batch_size: int = kwargs.get('batch_size', 1) self.orig_batch_size: int = self.batch_size self.dtype: str = kwargs.get('dtype', 'fp32') @@ -515,6 +515,9 @@ class TrainConfig: self.unconditional_prompt: str = kwargs.get('unconditional_prompt', '') if isinstance(self.guidance_loss_target, tuple): self.guidance_loss_target = list(self.guidance_loss_target) + + # for multi stage models, how often to switch the boundary + self.switch_boundary_every: int = kwargs.get('switch_boundary_every', 1) ModelArch = Literal['sd1', 'sd2', 'sd3', 'sdxl', 'pixart', 'pixart_sigma', 'auraflow', 'flux', 'flex1', 'flex2', 'lumina2', 'vega', 'ssd', 'wan21'] diff --git a/toolkit/models/base_model.py b/toolkit/models/base_model.py index 412e04d4..d446c25e 100644 --- a/toolkit/models/base_model.py +++ b/toolkit/models/base_model.py @@ -172,6 +172,11 @@ class BaseModel: self.sample_prompts_cache = None self.accuracy_recovery_adapter: Union[None, 'LoRASpecialNetwork'] = None + self.is_multistage = False + # a list of multistage boundaries starting with train step 1000 to first idx + self.multistage_boundaries: List[float] = [0.0] + # a list of trainable multistage boundaries + self.trainable_multistage_boundaries: List[int] = [0] # properties for old arch for backwards compatibility @property @@ -1502,3 +1507,7 @@ class BaseModel: def get_base_model_version(self) -> str: # override in child classes to get the base model version return "unknown" + + def get_model_to_train(self): + # called to get model to attach LoRAs to. 
Can be overridden in child classes + return self.unet diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index 183cbb8d..86908884 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -211,6 +211,12 @@ class StableDiffusion: self.sample_prompts_cache = None + self.is_multistage = False + # a list of multistage boundaries starting with train step 1000 to first idx + self.multistage_boundaries: List[float] = [0.0] + # a list of trainable multistage boundaries + self.trainable_multistage_boundaries: List[int] = [0] + # properties for old arch for backwards compatibility @property def is_xl(self): @@ -3123,3 +3129,6 @@ class StableDiffusion: if self.is_v2: return 'sd_2.1' return 'sd_1.5' + + def get_model_to_train(self): + return self.unet diff --git a/ui/src/app/jobs/new/SimpleJob.tsx b/ui/src/app/jobs/new/SimpleJob.tsx index b2e31294..c9c4d0f8 100644 --- a/ui/src/app/jobs/new/SimpleJob.tsx +++ b/ui/src/app/jobs/new/SimpleJob.tsx @@ -40,10 +40,25 @@ export default function SimpleJob({ const isVideoModel = !!(modelArch?.group === 'video'); - let topBarClass = 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 xl:grid-cols-5 gap-6'; + const numTopCards = useMemo(() => { + let count = 4; // job settings, model config, target config, save config + if (modelArch?.additionalSections?.includes('model.multistage')) { + count += 1; // add multistage card + } + if (!modelArch?.disableSections?.includes('model.quantize')) { + count += 1; // add quantization card + } + return count; + + }, [modelArch]); - if (modelArch?.disableSections?.includes('model.quantize')) { - topBarClass = 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 xl:grid-cols-4 gap-6'; + let topBarClass = 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 xl:grid-cols-4 gap-6'; + + if (numTopCards == 5) { + topBarClass = 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 xl:grid-cols-5 gap-6'; + } + if (numTopCards == 6) { + topBarClass = 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 xl:grid-cols-3 2xl:grid-cols-6 gap-6'; } const transformerQuantizationOptions: GroupedSelectOption[] | SelectOption[] = useMemo(() => { @@ -91,7 +106,7 @@ export default function SimpleJob({ <>
- + {/* Model Configuration Section */} - + )} - + {modelArch?.additionalSections?.includes('model.multistage') && ( + + + setJobConfig(value, 'config.process[0].model.model_kwargs.train_high_noise')} + /> + setJobConfig(value, 'config.process[0].model.model_kwargs.train_low_noise')} + /> + + setJobConfig(value, 'config.process[0].train.switch_boundary_every')} + placeholder="eg. 1" + docKey={'train.switch_boundary_every'} + min={1} + required + /> + + )} + )} - +
- +
- +
= props => { return (
{label && ( -