Fix bug that prevented training wan 2.2 with batch size greater than 1

This commit is contained in:
Jaret Burkett
2025-07-29 09:06:25 -06:00
parent f453e28ea3
commit 1d1199b15b

View File

@@ -1,3 +1,4 @@
from functools import partial
import torch
from toolkit.prompt_utils import PromptEmbeds
from PIL import Image
@@ -55,6 +56,29 @@ scheduler_config = {
"use_dynamic_shifting": False,
}
# TODO: this is a temporary monkeypatch to fix the time text embedding to allow for batch sizes greater than 1. Remove this when the diffusers library is fixed.
def time_text_monkeypatch(
self,
timestep: torch.Tensor,
encoder_hidden_states,
encoder_hidden_states_image = None,
timestep_seq_len = None,
):
timestep = self.timesteps_proj(timestep)
if timestep_seq_len is not None:
timestep = timestep.unflatten(0, (encoder_hidden_states.shape[0], timestep_seq_len))
time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
timestep = timestep.to(time_embedder_dtype)
temb = self.time_embedder(timestep).type_as(encoder_hidden_states)
timestep_proj = self.time_proj(self.act_fn(temb))
encoder_hidden_states = self.text_embedder(encoder_hidden_states)
if encoder_hidden_states_image is not None:
encoder_hidden_states_image = self.image_embedder(encoder_hidden_states_image)
return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image
class Wan22Model(Wan21):
arch = "wan22_5b"
@@ -80,6 +104,12 @@ class Wan22Model(Wan21):
)
self._wan_cache = None
def load_model(self):
super().load_model()
# patch the condition embedder
self.model.condition_embedder.forward = partial(time_text_monkeypatch, self.model.condition_embedder)
def get_bucket_divisibility(self):
# 16x compression and 2x2 patch size