diff --git a/extensions_built_in/diffusion_models/wan22/wan22_model.py b/extensions_built_in/diffusion_models/wan22/wan22_model.py
index 565a3fea..0b0c63ea 100644
--- a/extensions_built_in/diffusion_models/wan22/wan22_model.py
+++ b/extensions_built_in/diffusion_models/wan22/wan22_model.py
@@ -1,3 +1,4 @@
+from functools import partial
 import torch
 from toolkit.prompt_utils import PromptEmbeds
 from PIL import Image
@@ -55,6 +56,29 @@ scheduler_config = {
     "use_dynamic_shifting": False,
 }
 
+# TODO: this is a temporary monkeypatch to fix the time text embedding to allow for batch sizes greater than 1. Remove this when the diffusers library is fixed.
+def time_text_monkeypatch(
+    self,
+    timestep: torch.Tensor,
+    encoder_hidden_states,
+    encoder_hidden_states_image = None,
+    timestep_seq_len = None,
+):
+    timestep = self.timesteps_proj(timestep)
+    if timestep_seq_len is not None:
+        timestep = timestep.unflatten(0, (encoder_hidden_states.shape[0], timestep_seq_len))
+
+    time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
+    if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
+        timestep = timestep.to(time_embedder_dtype)
+    temb = self.time_embedder(timestep).type_as(encoder_hidden_states)
+    timestep_proj = self.time_proj(self.act_fn(temb))
+
+    encoder_hidden_states = self.text_embedder(encoder_hidden_states)
+    if encoder_hidden_states_image is not None:
+        encoder_hidden_states_image = self.image_embedder(encoder_hidden_states_image)
+
+    return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image
 
 class Wan22Model(Wan21):
     arch = "wan22_5b"
@@ -80,6 +104,12 @@ class Wan22Model(Wan21):
         )
 
         self._wan_cache = None
+
+    def load_model(self):
+        super().load_model()
+
+        # patch the condition embedder
+        self.model.condition_embedder.forward = partial(time_text_monkeypatch, self.model.condition_embedder)
 
     def get_bucket_divisibility(self):
         # 16x compression and 2x2 patch size
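
Reviewer note: the `load_model` override above rebinds `condition_embedder.forward` through `functools.partial`. Below is a minimal, self-contained sketch of that patching pattern. It uses a hypothetical `DummyEmbedder` module rather than the diffusers condition embedder, and only illustrates how assigning a `partial` to an instance's `forward` shadows the class method for that one object.

```python
# Minimal sketch (not the diffusers API) of the patching pattern used in
# load_model above. `DummyEmbedder` and `patched_forward` are hypothetical;
# functools.partial pre-binds the instance as the free function's `self`,
# and nn.Module stores the resulting callable as an instance attribute,
# so it takes precedence over the class-level forward for this object only.
from functools import partial

import torch
import torch.nn as nn


class DummyEmbedder(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

    def forward(self, x):
        return self.proj(x)


def patched_forward(self, x):
    # Replacement logic; `self` is supplied by partial, not by Python's
    # normal bound-method machinery.
    return self.proj(x) * 2


embedder = DummyEmbedder()
embedder.forward = partial(patched_forward, embedder)

out = embedder(torch.randn(2, 4))  # nn.Module.__call__ now dispatches to patched_forward
```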