mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 10:41:28 +00:00
Fix issue with wan22 14b where timesteps were not generated within the current boundary.
This commit is contained in:
@@ -21,7 +21,7 @@ from diffusers import WanTransformer3DModel
|
||||
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
|
||||
from torchvision.transforms import functional as TF
|
||||
|
||||
from toolkit.models.wan21.wan21 import AggressiveWanUnloadPipeline
|
||||
from toolkit.models.wan21.wan21 import AggressiveWanUnloadPipeline, Wan21
|
||||
from .wan22_5b_model import (
|
||||
scheduler_config,
|
||||
time_text_monkeypatch,
|
||||
@@ -116,11 +116,12 @@ class DualWanTransformer3DModel(torch.nn.Module):
|
||||
encoder_hidden_states_image: Optional[torch.Tensor] = None,
|
||||
return_dict: bool = True,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
**kwargs
|
||||
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
|
||||
# determine if doing high noise or low noise by taking the mean of the timestep.
|
||||
# timesteps are in the range of 0 to 1000, so we can use a threshold
|
||||
with torch.no_grad():
|
||||
if timestep.float().mean().item() >= self.boundary:
|
||||
if timestep.float().mean().item() > self.boundary:
|
||||
t_name = "transformer_1"
|
||||
else:
|
||||
t_name = "transformer_2"
|
||||
@@ -159,10 +160,10 @@ class DualWanTransformer3DModel(torch.nn.Module):
|
||||
return self
|
||||
|
||||
|
||||
class Wan2214bModel(Wan225bModel):
|
||||
class Wan2214bModel(Wan21):
|
||||
arch = "wan22_14b"
|
||||
_wan_generation_scheduler_config = scheduler_configUniPC
|
||||
_wan_expand_timesteps = True
|
||||
_wan_expand_timesteps = False
|
||||
_wan_vae_path = "ai-toolkit/wan2.1-vae"
|
||||
|
||||
def __init__(
|
||||
@@ -223,7 +224,7 @@ class Wan2214bModel(Wan225bModel):
|
||||
def load_model(self):
|
||||
# load model from the parent's parent. Wan21 is not the immediate parent
|
||||
# super().load_model()
|
||||
super(Wan225bModel, self).load_model()
|
||||
super().load_model()
|
||||
|
||||
# we have to split up the model on the pipeline
|
||||
self.pipeline.transformer = self.model.transformer_1
|
||||
|
||||
Reference in New Issue
Block a user