From b3e666daf4fce468a2776f37667513043c501081 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Sat, 16 Aug 2025 21:16:48 -0600 Subject: [PATCH] Fix issue with wan22 14b where timesteps were generated not in the current boundary. --- .../diffusion_models/wan22/wan22_14b_model.py | 11 ++++++----- extensions_built_in/sd_trainer/SDTrainer.py | 2 +- jobs/process/BaseSDTrainProcess.py | 12 ++++++------ version.py | 2 +- 4 files changed, 14 insertions(+), 13 deletions(-) diff --git a/extensions_built_in/diffusion_models/wan22/wan22_14b_model.py b/extensions_built_in/diffusion_models/wan22/wan22_14b_model.py index de4c7c59..abc1803e 100644 --- a/extensions_built_in/diffusion_models/wan22/wan22_14b_model.py +++ b/extensions_built_in/diffusion_models/wan22/wan22_14b_model.py @@ -21,7 +21,7 @@ from diffusers import WanTransformer3DModel from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO from torchvision.transforms import functional as TF -from toolkit.models.wan21.wan21 import AggressiveWanUnloadPipeline +from toolkit.models.wan21.wan21 import AggressiveWanUnloadPipeline, Wan21 from .wan22_5b_model import ( scheduler_config, time_text_monkeypatch, @@ -116,11 +116,12 @@ class DualWanTransformer3DModel(torch.nn.Module): encoder_hidden_states_image: Optional[torch.Tensor] = None, return_dict: bool = True, attention_kwargs: Optional[Dict[str, Any]] = None, + **kwargs ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: # determine if doing high noise or low noise by meaning the timestep. # timesteps are in the range of 0 to 1000, so we can use a threshold with torch.no_grad(): - if timestep.float().mean().item() >= self.boundary: + if timestep.float().mean().item() > self.boundary: t_name = "transformer_1" else: t_name = "transformer_2" @@ -159,10 +160,10 @@ class DualWanTransformer3DModel(torch.nn.Module): return self -class Wan2214bModel(Wan225bModel): +class Wan2214bModel(Wan21): arch = "wan22_14b" _wan_generation_scheduler_config = scheduler_configUniPC - _wan_expand_timesteps = True + _wan_expand_timesteps = False _wan_vae_path = "ai-toolkit/wan2.1-vae" def __init__( @@ -223,7 +224,7 @@ class Wan2214bModel(Wan225bModel): def load_model(self): # load model from patent parent. Wan21 not immediate parent # super().load_model() - super(Wan225bModel, self).load_model() + super().load_model() # we have to split up the model on the pipeline self.pipeline.transformer = self.model.transformer_1 diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index 8ca654af..d9d7aff5 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -1864,7 +1864,7 @@ class SDTrainer(BaseSDTrainProcess): for batch in batch_list: if self.sd.is_multistage: # handle multistage switching - if self.steps_this_boundary >= self.train_config.switch_boundary_every: + if self.steps_this_boundary >= self.train_config.switch_boundary_every or self.current_boundary_index not in self.sd.trainable_multistage_boundaries: # iterate to make sure we only train trainable_multistage_boundaries while True: self.steps_this_boundary = 0 diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index 80e2e226..84c941bb 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -1177,13 +1177,13 @@ class BaseSDTrainProcess(BaseTrainProcess): if self.sd.is_multistage: with self.timer('adjust_multistage_timesteps'): # get our current sample range - boundaries = [1000] + self.sd.multistage_boundaries + boundaries = [1] + self.sd.multistage_boundaries boundary_max, boundary_min = boundaries[self.current_boundary_index], boundaries[self.current_boundary_index + 1] - lo = torch.searchsorted(self.sd.noise_scheduler.timesteps, -torch.tensor(boundary_max, device=self.sd.noise_scheduler.timesteps.device), right=False) - hi = torch.searchsorted(self.sd.noise_scheduler.timesteps, -torch.tensor(boundary_min, device=self.sd.noise_scheduler.timesteps.device), right=True) - first_idx = lo.item() if hi > lo else 0 + asc_timesteps = torch.flip(self.sd.noise_scheduler.timesteps, dims=[0]) + lo = len(asc_timesteps) - torch.searchsorted(asc_timesteps, torch.tensor(boundary_max * 1000, device=asc_timesteps.device), right=False) + hi = len(asc_timesteps) - torch.searchsorted(asc_timesteps, torch.tensor(boundary_min * 1000, device=asc_timesteps.device), right=True) + first_idx = (lo - 1).item() if hi > lo else 0 last_idx = (hi - 1).item() if hi > lo else 999 - min_noise_steps = first_idx max_noise_steps = last_idx @@ -1246,7 +1246,7 @@ class BaseSDTrainProcess(BaseTrainProcess): max_idx = max_noise_steps - 1 if self.train_config.noise_scheduler == 'flowmatch': # flowmatch uses indices, so we need to use indices - min_idx = 0 + min_idx = min_noise_steps max_idx = max_noise_steps timestep_indices = torch.randint( min_idx, diff --git a/version.py b/version.py index 2cec286b..9b13ade4 100644 --- a/version.py +++ b/version.py @@ -1 +1 @@ -VERSION = "0.5.0" \ No newline at end of file +VERSION = "0.5.1" \ No newline at end of file