mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-01 03:31:35 +00:00
Fix issue with wan22 14b where timesteps were generated not in the current boundary.
This commit is contained in:
@@ -21,7 +21,7 @@ from diffusers import WanTransformer3DModel
|
|||||||
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
|
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
|
||||||
from torchvision.transforms import functional as TF
|
from torchvision.transforms import functional as TF
|
||||||
|
|
||||||
from toolkit.models.wan21.wan21 import AggressiveWanUnloadPipeline
|
from toolkit.models.wan21.wan21 import AggressiveWanUnloadPipeline, Wan21
|
||||||
from .wan22_5b_model import (
|
from .wan22_5b_model import (
|
||||||
scheduler_config,
|
scheduler_config,
|
||||||
time_text_monkeypatch,
|
time_text_monkeypatch,
|
||||||
@@ -116,11 +116,12 @@ class DualWanTransformer3DModel(torch.nn.Module):
|
|||||||
encoder_hidden_states_image: Optional[torch.Tensor] = None,
|
encoder_hidden_states_image: Optional[torch.Tensor] = None,
|
||||||
return_dict: bool = True,
|
return_dict: bool = True,
|
||||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
|
**kwargs
|
||||||
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
|
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
|
||||||
# determine if doing high noise or low noise by meaning the timestep.
|
# determine if doing high noise or low noise by meaning the timestep.
|
||||||
# timesteps are in the range of 0 to 1000, so we can use a threshold
|
# timesteps are in the range of 0 to 1000, so we can use a threshold
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
if timestep.float().mean().item() >= self.boundary:
|
if timestep.float().mean().item() > self.boundary:
|
||||||
t_name = "transformer_1"
|
t_name = "transformer_1"
|
||||||
else:
|
else:
|
||||||
t_name = "transformer_2"
|
t_name = "transformer_2"
|
||||||
@@ -159,10 +160,10 @@ class DualWanTransformer3DModel(torch.nn.Module):
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
class Wan2214bModel(Wan225bModel):
|
class Wan2214bModel(Wan21):
|
||||||
arch = "wan22_14b"
|
arch = "wan22_14b"
|
||||||
_wan_generation_scheduler_config = scheduler_configUniPC
|
_wan_generation_scheduler_config = scheduler_configUniPC
|
||||||
_wan_expand_timesteps = True
|
_wan_expand_timesteps = False
|
||||||
_wan_vae_path = "ai-toolkit/wan2.1-vae"
|
_wan_vae_path = "ai-toolkit/wan2.1-vae"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -223,7 +224,7 @@ class Wan2214bModel(Wan225bModel):
|
|||||||
def load_model(self):
|
def load_model(self):
|
||||||
# load model from patent parent. Wan21 not immediate parent
|
# load model from patent parent. Wan21 not immediate parent
|
||||||
# super().load_model()
|
# super().load_model()
|
||||||
super(Wan225bModel, self).load_model()
|
super().load_model()
|
||||||
|
|
||||||
# we have to split up the model on the pipeline
|
# we have to split up the model on the pipeline
|
||||||
self.pipeline.transformer = self.model.transformer_1
|
self.pipeline.transformer = self.model.transformer_1
|
||||||
|
|||||||
@@ -1864,7 +1864,7 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
for batch in batch_list:
|
for batch in batch_list:
|
||||||
if self.sd.is_multistage:
|
if self.sd.is_multistage:
|
||||||
# handle multistage switching
|
# handle multistage switching
|
||||||
if self.steps_this_boundary >= self.train_config.switch_boundary_every:
|
if self.steps_this_boundary >= self.train_config.switch_boundary_every or self.current_boundary_index not in self.sd.trainable_multistage_boundaries:
|
||||||
# iterate to make sure we only train trainable_multistage_boundaries
|
# iterate to make sure we only train trainable_multistage_boundaries
|
||||||
while True:
|
while True:
|
||||||
self.steps_this_boundary = 0
|
self.steps_this_boundary = 0
|
||||||
|
|||||||
@@ -1177,13 +1177,13 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
if self.sd.is_multistage:
|
if self.sd.is_multistage:
|
||||||
with self.timer('adjust_multistage_timesteps'):
|
with self.timer('adjust_multistage_timesteps'):
|
||||||
# get our current sample range
|
# get our current sample range
|
||||||
boundaries = [1000] + self.sd.multistage_boundaries
|
boundaries = [1] + self.sd.multistage_boundaries
|
||||||
boundary_max, boundary_min = boundaries[self.current_boundary_index], boundaries[self.current_boundary_index + 1]
|
boundary_max, boundary_min = boundaries[self.current_boundary_index], boundaries[self.current_boundary_index + 1]
|
||||||
lo = torch.searchsorted(self.sd.noise_scheduler.timesteps, -torch.tensor(boundary_max, device=self.sd.noise_scheduler.timesteps.device), right=False)
|
asc_timesteps = torch.flip(self.sd.noise_scheduler.timesteps, dims=[0])
|
||||||
hi = torch.searchsorted(self.sd.noise_scheduler.timesteps, -torch.tensor(boundary_min, device=self.sd.noise_scheduler.timesteps.device), right=True)
|
lo = len(asc_timesteps) - torch.searchsorted(asc_timesteps, torch.tensor(boundary_max * 1000, device=asc_timesteps.device), right=False)
|
||||||
first_idx = lo.item() if hi > lo else 0
|
hi = len(asc_timesteps) - torch.searchsorted(asc_timesteps, torch.tensor(boundary_min * 1000, device=asc_timesteps.device), right=True)
|
||||||
|
first_idx = (lo - 1).item() if hi > lo else 0
|
||||||
last_idx = (hi - 1).item() if hi > lo else 999
|
last_idx = (hi - 1).item() if hi > lo else 999
|
||||||
|
|
||||||
min_noise_steps = first_idx
|
min_noise_steps = first_idx
|
||||||
max_noise_steps = last_idx
|
max_noise_steps = last_idx
|
||||||
|
|
||||||
@@ -1246,7 +1246,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
max_idx = max_noise_steps - 1
|
max_idx = max_noise_steps - 1
|
||||||
if self.train_config.noise_scheduler == 'flowmatch':
|
if self.train_config.noise_scheduler == 'flowmatch':
|
||||||
# flowmatch uses indices, so we need to use indices
|
# flowmatch uses indices, so we need to use indices
|
||||||
min_idx = 0
|
min_idx = min_noise_steps
|
||||||
max_idx = max_noise_steps
|
max_idx = max_noise_steps
|
||||||
timestep_indices = torch.randint(
|
timestep_indices = torch.randint(
|
||||||
min_idx,
|
min_idx,
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
VERSION = "0.5.0"
|
VERSION = "0.5.1"
|
||||||
Reference in New Issue
Block a user