Fix issue with wan22 14b where timesteps were not generated in the current boundary.

Author: Jaret Burkett
Date: 2025-08-16 21:16:48 -06:00
parent 6fffadfc0e
commit b3e666daf4
4 changed files with 14 additions and 13 deletions

View File

@@ -21,7 +21,7 @@ from diffusers import WanTransformer3DModel
 from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
 from torchvision.transforms import functional as TF
-from toolkit.models.wan21.wan21 import AggressiveWanUnloadPipeline
+from toolkit.models.wan21.wan21 import AggressiveWanUnloadPipeline, Wan21
 from .wan22_5b_model import (
     scheduler_config,
     time_text_monkeypatch,
@@ -116,11 +116,12 @@ class DualWanTransformer3DModel(torch.nn.Module):
         encoder_hidden_states_image: Optional[torch.Tensor] = None,
         return_dict: bool = True,
         attention_kwargs: Optional[Dict[str, Any]] = None,
+        **kwargs
     ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
         # determine if doing high noise or low noise by meaning the timestep.
         # timesteps are in the range of 0 to 1000, so we can use a threshold
         with torch.no_grad():
-            if timestep.float().mean().item() >= self.boundary:
+            if timestep.float().mean().item() > self.boundary:
                 t_name = "transformer_1"
             else:
                 t_name = "transformer_2"
@@ -159,10 +160,10 @@ class DualWanTransformer3DModel(torch.nn.Module):
         return self
 
-class Wan2214bModel(Wan225bModel):
+class Wan2214bModel(Wan21):
     arch = "wan22_14b"
     _wan_generation_scheduler_config = scheduler_configUniPC
-    _wan_expand_timesteps = True
+    _wan_expand_timesteps = False
     _wan_vae_path = "ai-toolkit/wan2.1-vae"
 
     def __init__(
@@ -223,7 +224,7 @@ class Wan2214bModel(Wan225bModel):
     def load_model(self):
-        # load model from patent parent. Wan21 not immediate parent
-        # super().load_model()
-        super(Wan225bModel, self).load_model()
+        # super().load_model()
+        # super(Wan225bModel, self).load_model()
+        super().load_model()
         # we have to split up the model on the pipeline
         self.pipeline.transformer = self.model.transformer_1
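The switch to a plain super() call follows from the new base class: previously Wan225bModel was the immediate parent and had to be skipped to reach Wan21; now Wan21 is the immediate parent. A minimal MRO sketch with stand-in class names (Base/Middle/OldChild/NewChild are illustrative, not the real classes):

class Base:              # plays the role of Wan21
    def load_model(self):
        print("Base.load_model")

class Middle(Base):      # plays the role of Wan225bModel
    def load_model(self):
        print("Middle.load_model")

class OldChild(Middle):  # old Wan2214bModel: Middle is the immediate parent
    def load_model(self):
        # skip Middle in the MRO and call Base directly
        super(Middle, self).load_model()

class NewChild(Base):    # new Wan2214bModel: Base is the immediate parent
    def load_model(self):
        super().load_model()  # plain super() now reaches Base

OldChild().load_model()  # Base.load_model
NewChild().load_model()  # Base.load_model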

View File

@@ -1864,7 +1864,7 @@ class SDTrainer(BaseSDTrainProcess):
 for batch in batch_list:
     if self.sd.is_multistage:
         # handle multistage switching
-        if self.steps_this_boundary >= self.train_config.switch_boundary_every:
+        if self.steps_this_boundary >= self.train_config.switch_boundary_every or self.current_boundary_index not in self.sd.trainable_multistage_boundaries:
             # iterate to make sure we only train trainable_multistage_boundaries
             while True:
                 self.steps_this_boundary = 0
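The added condition forces a switch not only when the step budget for the current boundary is used up, but also when the current boundary index is not trainable. A minimal sketch of that selection rule, using a hypothetical next_boundary helper (none of these names are from the trainer itself; it assumes at least one trainable index exists):

def next_boundary(current_index: int,
                  steps_this_boundary: int,
                  switch_every: int,
                  num_boundaries: int,
                  trainable: set[int]) -> int:
    # switch when the budget is spent or the current boundary is not trainable
    if steps_this_boundary >= switch_every or current_index not in trainable:
        # advance until we land on a trainable boundary
        while True:
            current_index = (current_index + 1) % num_boundaries
            if current_index in trainable:
                break
    return current_index

# With boundaries {0, 1} and only boundary 1 trainable, index 0 is skipped immediately.
print(next_boundary(0, 0, 10, 2, trainable={1}))  # 1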

View File

@@ -1177,13 +1177,13 @@ class BaseSDTrainProcess(BaseTrainProcess):
 if self.sd.is_multistage:
     with self.timer('adjust_multistage_timesteps'):
         # get our current sample range
-        boundaries = [1000] + self.sd.multistage_boundaries
+        boundaries = [1] + self.sd.multistage_boundaries
         boundary_max, boundary_min = boundaries[self.current_boundary_index], boundaries[self.current_boundary_index + 1]
-        lo = torch.searchsorted(self.sd.noise_scheduler.timesteps, -torch.tensor(boundary_max, device=self.sd.noise_scheduler.timesteps.device), right=False)
-        hi = torch.searchsorted(self.sd.noise_scheduler.timesteps, -torch.tensor(boundary_min, device=self.sd.noise_scheduler.timesteps.device), right=True)
-        first_idx = lo.item() if hi > lo else 0
+        asc_timesteps = torch.flip(self.sd.noise_scheduler.timesteps, dims=[0])
+        lo = len(asc_timesteps) - torch.searchsorted(asc_timesteps, torch.tensor(boundary_max * 1000, device=asc_timesteps.device), right=False)
+        hi = len(asc_timesteps) - torch.searchsorted(asc_timesteps, torch.tensor(boundary_min * 1000, device=asc_timesteps.device), right=True)
+        first_idx = (lo - 1).item() if hi > lo else 0
         last_idx = (hi - 1).item() if hi > lo else 999
         min_noise_steps = first_idx
         max_noise_steps = last_idx
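The new index computation flips the scheduler's descending timesteps into ascending order so torch.searchsorted can map the fractional boundaries onto index ranges in the original (descending) array. A small sketch under assumed values (a linear 1000-step schedule and a 0.875 boundary, neither read from this repo's config):

import torch

timesteps = torch.linspace(1000, 1, 1000)        # descending, as schedulers store them
asc_timesteps = torch.flip(timesteps, dims=[0])  # ascending copy for searchsorted

def boundary_to_index_range(boundary_max: float, boundary_min: float):
    # boundaries are fractions of the 0..1000 timestep range
    lo = len(asc_timesteps) - torch.searchsorted(
        asc_timesteps, torch.tensor(boundary_max * 1000), right=False)
    hi = len(asc_timesteps) - torch.searchsorted(
        asc_timesteps, torch.tensor(boundary_min * 1000), right=True)
    first_idx = (lo - 1).item() if hi > lo else 0
    last_idx = (hi - 1).item() if hi > lo else 999
    return first_idx, last_idx

# High-noise stage: timesteps in (875, 1000] -> indices 0..124 of the descending array
print(boundary_to_index_range(1.0, 0.875))  # (0, 124)
# Low-noise stage: timesteps in (0, 875] -> indices 125..999
print(boundary_to_index_range(0.875, 0.0))  # (125, 999)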
@@ -1246,7 +1246,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
     max_idx = max_noise_steps - 1
 if self.train_config.noise_scheduler == 'flowmatch':
     # flowmatch uses indices, so we need to use indices
-    min_idx = 0
+    min_idx = min_noise_steps
     max_idx = max_noise_steps
 timestep_indices = torch.randint(
     min_idx,
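With an index-based scheduler like flowmatch, starting the randint range at min_noise_steps instead of 0 keeps every sampled index, and therefore every sampled timestep, inside the active boundary. A short sketch with illustrative values (the schedule and index range below are assumptions, not values from the config):

import torch

timesteps = torch.linspace(1000, 1, 1000)  # descending schedule
min_idx, max_idx = 0, 124                  # e.g. the high-noise range from the sketch above

# torch.randint samples in [min_idx, max_idx); indexing the descending schedule
# keeps every drawn timestep inside the active high-noise range here.
timestep_indices = torch.randint(min_idx, max_idx, (4,))
sampled_timesteps = timesteps[timestep_indices]
print(sampled_timesteps)  # all values fall in 877..1000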

View File

@@ -1 +1 @@
VERSION = "0.5.0" VERSION = "0.5.1"